feat: implement training loop and CLI (src/train.py, scripts/train.py)

AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch
CSV logging, best-val-loss checkpointing, early stopping (patience=5).
Same script handles both pre-training and fine-tuning via --init-from.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 11:15:39 +03:00
parent 10229be042
commit 733e1fde1f
2 changed files with 460 additions and 0 deletions
+340
View File
@@ -0,0 +1,340 @@
"""Training logic for ChordTransformer.
Public API:
TrainConfig -- dataclass collecting all hyperparameters
train -- run pre-training or fine-tuning
This module contains no argument parsing. See scripts/train.py for the CLI.
"""
from __future__ import annotations
import csv
import logging
import math
import random
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.dataset import ChordDataset
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class TrainConfig:
    """Hyperparameters and I/O locations describing a single training run.

    Only ``data_dir`` and ``output`` are required; every other field has a
    default.

    Attributes:
        data_dir: Directory produced by prepare_data.py. ``train()`` reads the
            ``train/`` and ``val/`` sub-directories beneath it.
        output: Path prefix from which the checkpoint (``.pt``) and the
            per-epoch log (``.log.csv``) file names are derived.
        init_from: Checkpoint to load model config and weights from before
            training (fine-tuning). ``None`` means start from scratch.
        epochs: Upper bound on the number of epochs.
        batch_size: Mini-batch size given to both DataLoaders.
        lr: Peak learning rate handed to AdamW.
        warmup_steps: Length of the linear warm-up, in optimizer steps, that
            precedes the cosine decay.
        weight_decay: AdamW weight decay. The training loop applies it only to
            parameters with two or more dimensions.
        seed: Seed used for the Python, NumPy and torch RNGs.
        device: ``'cpu'``, ``'cuda'`` or ``'auto'`` (cuda when available).
        patience: Number of consecutive epochs without a val-loss improvement
            after which training stops early.
        max_seq_len: Sequence length passed to ChordDataset and, when training
            from scratch, to ChordTransformer.
        d_model: Transformer hidden dimension.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads.
        d_ff: FFN inner dimension.
        dropout: Dropout probability.

    The five architecture fields (``d_model`` .. ``dropout``) are ignored when
    ``init_from`` is set, because the model config then comes from the
    checkpoint.
    """

    # --- I/O ---------------------------------------------------------------
    data_dir: Path
    output: Path
    init_from: Optional[Path] = None

    # --- Optimisation ------------------------------------------------------
    epochs: int = 30
    batch_size: int = 16
    lr: float = 3e-4
    warmup_steps: int = 200
    weight_decay: float = 0.1
    seed: int = 42
    device: str = "auto"
    patience: int = 5
    max_seq_len: int = 512

    # --- Architecture (ignored when init_from is set) ----------------------
    d_model: int = 192
    n_layers: int = 3
    n_heads: int = 6
    d_ff: int = 768
    dropout: float = 0.1
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_device(spec: str) -> torch.device:
if spec == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(spec)
def _set_seeds(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def _build_scheduler(
optimizer: torch.optim.Optimizer,
warmup_steps: int,
total_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
"""Linear warm-up followed by cosine decay to 0."""
def _lr_lambda(step: int) -> float:
if step < warmup_steps:
return step / max(1, warmup_steps)
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
def _make_attention_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
"""Return 1 where token != PAD, 0 elsewhere."""
return (batch != pad_id).long()
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.CrossEntropyLoss,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
) -> float:
    """Run one pass over *loader* and return the mean loss per non-PAD token.

    The pass is a training epoch when *optimizer* is given (model in train
    mode, gradients enabled, parameters updated) and an evaluation pass
    otherwise (model in eval mode, gradients disabled).

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)``.
        loader: Yields batches of token ids indexed as ``[B, T]``.
        criterion: Loss applied to flattened logits and targets. The token
            weighting below assumes it returns the mean over non-PAD targets
            (i.e. mean reduction with ``ignore_index`` set to the PAD id).
        device: Device each batch is moved to.
        optimizer: When given, enables training and is stepped every batch.
        scheduler: Optional LR scheduler, stepped once after each optimizer
            step. Ignored during evaluation.

    Returns:
        Sum of per-batch losses weighted by their non-PAD target count,
        divided by the total non-PAD target count; ``inf`` when no non-PAD
        target was seen.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_tokens = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)  # [B, T]
            # Causal LM: predict token t+1 from tokens 0..t
            input_ids = batch[:, :-1]
            targets = batch[:, 1:]

            non_pad = int((targets != _PAD_ID).sum().item())
            if non_pad == 0:
                # Every target is PAD: a mean-reduced loss over zero tokens is
                # NaN, which would corrupt both the gradients and the running
                # total. There is nothing to learn from this batch, so skip it.
                continue

            # The mask is derived from an on-device tensor, so it is already
            # on the right device.
            attn_mask = _make_attention_mask(input_ids, _PAD_ID)
            logits = model(input_ids, attention_mask=attn_mask)
            # logits: [B, T-1, V]   targets: [B, T-1]
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
            if training:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # The scheduler is optional; only step it when supplied.
                if scheduler is not None:
                    scheduler.step()
            # Re-weight the per-batch mean by its token count so the epoch
            # figure is a true per-token average across uneven batches.
            total_loss += loss.item() * non_pad
            total_tokens += non_pad
    if total_tokens == 0:
        return float("inf")
    return total_loss / total_tokens
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def train(cfg: TrainConfig) -> Path:
    """Run pre-training or fine-tuning according to *cfg*.

    Steps, in order: seed the RNGs, build the train/val DataLoaders from
    ``cfg.data_dir``, build the model (from ``cfg.init_from`` when set,
    otherwise from the architecture fields of *cfg*), then train with AdamW
    and a warm-up + cosine schedule. After every epoch one row is appended to
    the CSV log, and the checkpoint is overwritten whenever the validation
    loss improves. Training stops early after ``cfg.patience`` consecutive
    epochs without improvement.

    Args:
        cfg: Fully populated TrainConfig.

    Returns:
        Path of the best-checkpoint file. The file only exists if the
        validation loss improved at least once; otherwise a warning is logged
        and the (non-existent) path is still returned.
    """
    _set_seeds(cfg.seed)
    device = _resolve_device(cfg.device)
    log.info("device: %s", device)

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    train_dir = Path(cfg.data_dir) / "train"
    val_dir = Path(cfg.data_dir) / "val"
    train_ds = ChordDataset(train_dir, max_length=cfg.max_seq_len)
    val_ds = ChordDataset(val_dir, max_length=cfg.max_seq_len)
    log.info("train: %d periods val: %d periods", len(train_ds), len(val_ds))
    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        drop_last=False, num_workers=0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        drop_last=False, num_workers=0,
    )

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if cfg.init_from is not None:
        # Fine-tuning: architecture comes from the checkpoint, not from cfg.
        log.info("loading weights from %s", cfg.init_from)
        ckpt = torch.load(cfg.init_from, map_location=device, weights_only=True)
        model_cfg = ckpt["model_config"]
        model = ChordTransformer(**model_cfg)
        model.load_state_dict(ckpt["model_state"])
    else:
        model_cfg = dict(
            vocab_size=len(TOKEN_TO_ID),
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_ff=cfg.d_ff,
            max_seq_len=cfg.max_seq_len,
            dropout=cfg.dropout,
        )
        model = ChordTransformer(**model_cfg)
    model = model.to(device)
    # Module.parameters() yields each parameter tensor once even when it is
    # shared between layers, so the plain sum already is the unique count;
    # subtracting an embedding on top of that would under-report it.
    n_params = sum(p.numel() for p in model.parameters())
    log.info("model: %d unique parameters", n_params)

    # ------------------------------------------------------------------
    # Optimizer, scheduler, loss
    # ------------------------------------------------------------------
    # Weight decay goes to tensors with >= 2 dims (weight matrices); 1-D
    # tensors (biases, norm scales) are left undecayed.
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    optimizer = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr,
        betas=(0.9, 0.95),
    )
    # One scheduler step per batch, planned over the full epoch budget.
    total_steps = len(train_loader) * cfg.epochs
    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
    criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID)

    # ------------------------------------------------------------------
    # Logging setup
    # ------------------------------------------------------------------
    output_path = Path(cfg.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_path.with_suffix(".pt")
    log_csv_path = Path(str(output_path) + ".log.csv")
    # Truncate any previous log and write the header row.
    with open(log_csv_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["epoch", "train_loss", "val_loss", "val_ppl", "lr", "elapsed_s"])

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    checkpoint_saved = False
    run_start = time.monotonic()
    for epoch in range(1, cfg.epochs + 1):
        epoch_start = time.monotonic()
        train_loss = _run_epoch(
            model, train_loader, criterion, device, optimizer, scheduler
        )
        val_loss = _run_epoch(model, val_loader, criterion, device)
        val_ppl = math.exp(min(val_loss, 20))  # cap to avoid overflow
        elapsed = time.monotonic() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        log.info(
            "epoch %3d/%d train=%.4f val=%.4f ppl=%.1f lr=%.2e %.0fs",
            epoch, cfg.epochs, train_loss, val_loss, val_ppl, current_lr, elapsed,
        )
        # Append (rather than keep the file open) so the log survives a crash.
        with open(log_csv_path, "a", newline="") as fh:
            csv.writer(fh).writerow(
                [epoch, f"{train_loss:.6f}", f"{val_loss:.6f}",
                 f"{val_ppl:.2f}", f"{current_lr:.6e}", f"{elapsed:.1f}"]
            )

        # Checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model_config": model_cfg,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "val_loss": val_loss,
                    # Paths are stored as strings so the checkpoint holds
                    # only plain values.
                    "train_config": {
                        k: str(v) if isinstance(v, Path) else v
                        for k, v in vars(cfg).items()
                    },
                },
                checkpoint_path,
            )
            checkpoint_saved = True
            log.info(" → saved best checkpoint (val_loss=%.4f)", val_loss)
        else:
            # Also reached when val_loss is inf or NaN, since neither
            # compares below the current best.
            epochs_without_improvement += 1
            log.info(
                " no improvement for %d/%d epochs",
                epochs_without_improvement, cfg.patience,
            )
            if epochs_without_improvement >= cfg.patience:
                log.info("early stopping triggered at epoch %d", epoch)
                break

    total_elapsed = time.monotonic() - run_start
    log.info(
        "training finished: best val_loss=%.4f total=%.0fs checkpoint=%s",
        best_val_loss, total_elapsed, checkpoint_path,
    )
    if not checkpoint_saved:
        # The returned path would otherwise silently point at a missing file.
        log.warning(
            "no checkpoint was written to %s: validation loss never improved "
            "(check that %s contains data)",
            checkpoint_path, val_dir,
        )
    return checkpoint_path