feat: implement training loop and CLI (src/train.py, scripts/train.py)
AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch CSV logging, best-val-loss checkpointing, early stopping (patience=5). Same script handles both pre-training and fine-tuning via --init-from. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+340
@@ -0,0 +1,340 @@
|
||||
"""Training logic for ChordTransformer.
|
||||
|
||||
Public API:
|
||||
TrainConfig -- dataclass collecting all hyperparameters
|
||||
train -- run pre-training or fine-tuning
|
||||
|
||||
This module contains no argument parsing. See scripts/train.py for the CLI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.dataset import ChordDataset
|
||||
from src.model import ChordTransformer
|
||||
from src.tokenizer import TOKEN_TO_ID
|
||||
|
||||
# Id of the <PAD> token, looked up once at import time. It is used to build
# attention masks and as the ignore_index of the cross-entropy loss.
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]

# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class TrainConfig:
    """Hyperparameters and I/O locations describing a single training run.

    Only ``data_dir`` and ``output`` are required; everything else has a
    default.

    I/O:
        data_dir: Directory produced by prepare_data.py. It must contain
            ``train/`` and ``val/`` sub-directories with tokenized .pt files.
        output: Path prefix from which the checkpoint (.pt) and the CSV log
            (.log.csv) paths are derived.
        init_from: Checkpoint whose weights are loaded before training
            (fine-tuning). ``None`` means training starts from a freshly
            initialised model.

    Optimisation:
        epochs: Upper bound on the number of training epochs.
        batch_size: Mini-batch size handed to the DataLoader.
        lr: Peak learning rate for AdamW.
        warmup_steps: Linear warm-up steps before the cosine decay starts.
        weight_decay: AdamW weight-decay coefficient; ``train`` applies it
            only to parameters with two or more dimensions.
        seed: Random seed for reproducibility.
        device: ``'cpu'``, ``'cuda'`` or ``'auto'`` (CUDA when available).
        patience: Early-stopping patience, in epochs, measured on val loss.
        max_seq_len: Passed to ChordDataset and ChordTransformer.

    Architecture (ignored when ``init_from`` is set, because the model
    configuration is then read from the checkpoint):
        d_model: Transformer hidden dimension.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads.
        d_ff: FFN inner dimension.
        dropout: Dropout probability.
    """

    # --- I/O ---------------------------------------------------------------
    data_dir: Path
    output: Path
    init_from: Optional[Path] = None

    # --- Optimisation / run control ----------------------------------------
    epochs: int = 30
    batch_size: int = 16
    lr: float = 3e-4
    warmup_steps: int = 200
    weight_decay: float = 0.1
    seed: int = 42
    device: str = "auto"
    patience: int = 5
    max_seq_len: int = 512

    # --- Architecture (ignored when init_from is set) ----------------------
    d_model: int = 192
    n_layers: int = 3
    n_heads: int = 6
    d_ff: int = 768
    dropout: float = 0.1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_device(spec: str) -> torch.device:
    """Translate a device spec string into a ``torch.device``.

    ``"auto"`` picks CUDA when ``torch.cuda.is_available()`` reports it and
    CPU otherwise; any other value is handed to ``torch.device`` unchanged.
    """
    if spec != "auto":
        return torch.device(spec)
    chosen = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(chosen)
|
||||
|
||||
|
||||
def _set_seeds(seed: int) -> None:
    """Seed the Python ``random``, NumPy and PyTorch RNGs with *seed*.

    When CUDA is available, the generators of all CUDA devices are seeded
    as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _build_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear warm-up + cosine-decay learning-rate schedule.

    The returned scheduler multiplies each param group's base LR by a factor
    that

    * rises linearly from 0 to 1 over the first ``warmup_steps`` steps,
    * then follows half a cosine from 1 down to 0 until ``total_steps``,
    * and stays at 0 for any step beyond ``total_steps``.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_steps: Number of linear warm-up steps.
        total_steps: Total number of scheduler steps planned for the run.

    Returns:
        A ``LambdaLR`` that is meant to be stepped once per optimizer step.
    """
    # Loop-invariant; max(1, ...) avoids division by zero when
    # warmup_steps >= total_steps.
    decay_steps = max(1, total_steps - warmup_steps)

    def _lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Clamp to 1.0: without it the cosine starts rising again as soon as
        # the scheduler is stepped past total_steps, pushing the LR back up
        # towards its peak.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
|
||||
|
||||
|
||||
def _make_attention_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Build an int64 mask shaped like *batch*: 1 where the token differs
    from *pad_id*, 0 where it equals it."""
    is_real_token = torch.ne(batch, pad_id)
    return is_real_token.to(torch.long)
|
||||
|
||||
|
||||
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.CrossEntropyLoss,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
) -> float:
    """Run one pass over *loader* and return the mean loss per non-PAD token.

    The pass is a training epoch when *optimizer* is given (train mode,
    gradients enabled, one optimizer step per batch) and an evaluation pass
    otherwise (eval mode, no gradients).

    Args:
        model: Called as ``model(input_ids, attention_mask=...)`` and expected
            to return logits with the vocabulary on the last dimension.
        loader: Yields token-id tensors of shape ``[B, T]``.
        criterion: Loss applied to flattened logits and targets. The
            per-token average computed here assumes it returns the mean over
            the non-PAD targets of a batch.
        device: Device every batch is moved to.
        optimizer: If given, the model is updated after every batch.
        scheduler: If given (training only), stepped once per optimizer step.

    Returns:
        Token-weighted mean loss, or ``inf`` if no non-PAD target was seen.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_tokens = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)  # [B, T]

            # Causal LM: predict token t+1 from tokens 0..t
            input_ids = batch[:, :-1]
            targets = batch[:, 1:]

            non_pad = int((targets != _PAD_ID).sum().item())
            if non_pad == 0:
                # Every target is PAD: a mean-reduced loss over zero tokens
                # is NaN, which would corrupt the weights on backward and
                # the running total. Such a batch carries no signal; skip it.
                continue

            # The mask is derived from a tensor already on `device`, so no
            # extra transfer is needed.
            attn_mask = _make_attention_mask(input_ids, _PAD_ID)

            logits = model(input_ids, attention_mask=attn_mask)
            # logits: [B, T-1, V]  targets: [B, T-1]
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

            if training:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # The scheduler is optional; training without one is valid.
                if scheduler is not None:
                    scheduler.step()

            # Re-weight the per-batch mean by its token count so the epoch
            # result is a true per-token average across batches.
            total_loss += loss.item() * non_pad
            total_tokens += non_pad

    if total_tokens == 0:
        return float("inf")
    return total_loss / total_tokens
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def train(cfg: TrainConfig) -> Path:
    """Run pre-training or fine-tuning according to *cfg*.

    Builds the datasets, model, AdamW optimizer, warm-up/cosine schedule and
    PAD-ignoring cross-entropy loss, then trains for up to ``cfg.epochs``
    epochs. After every epoch one row is appended to the CSV log, the
    checkpoint is overwritten whenever the validation loss improves, and
    training stops early after ``cfg.patience`` epochs without improvement.

    Args:
        cfg: Fully populated TrainConfig.

    Returns:
        Path to the best checkpoint saved during training.

    Raises:
        ValueError: If the train or val split contains no samples. Without a
            validation set no checkpoint could ever be selected.
    """
    _set_seeds(cfg.seed)
    device = _resolve_device(cfg.device)
    log.info("device: %s", device)

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    train_dir = Path(cfg.data_dir) / "train"
    val_dir = Path(cfg.data_dir) / "val"
    train_ds = ChordDataset(train_dir, max_length=cfg.max_seq_len)
    val_ds = ChordDataset(val_dir, max_length=cfg.max_seq_len)
    log.info("train: %d periods val: %d periods", len(train_ds), len(val_ds))

    # Fail fast: with an empty val split the validation loss is always inf,
    # so no checkpoint would be written and the returned path would point to
    # a file that does not exist.
    if len(train_ds) == 0:
        raise ValueError(f"no training samples found in {train_dir}")
    if len(val_ds) == 0:
        raise ValueError(f"no validation samples found in {val_dir}")

    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        drop_last=False, num_workers=0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        drop_last=False, num_workers=0,
    )

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    vocab_size = len(TOKEN_TO_ID)

    if cfg.init_from is not None:
        # Fine-tuning: architecture and weights both come from the checkpoint,
        # so the architecture fields of cfg are not consulted.
        log.info("loading weights from %s", cfg.init_from)
        ckpt = torch.load(cfg.init_from, map_location=device, weights_only=True)
        model_cfg = ckpt["model_config"]
        model = ChordTransformer(**model_cfg)
        model.load_state_dict(ckpt["model_state"])
    else:
        model_cfg = dict(
            vocab_size=vocab_size,
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_ff=cfg.d_ff,
            max_seq_len=cfg.max_seq_len,
            dropout=cfg.dropout,
        )
        model = ChordTransformer(**model_cfg)

    model = model.to(device)
    # Module.parameters() yields each parameter object once, even when it is
    # shared between modules, so this sum already is the unique count.
    n_params = sum(p.numel() for p in model.parameters())
    log.info("model: %d unique parameters", n_params)

    # ------------------------------------------------------------------
    # Optimizer, scheduler, loss
    # ------------------------------------------------------------------
    # Apply weight decay only to tensors with >= 2 dimensions (weight
    # matrices); 1-D parameters such as biases and norm gains are exempt.
    decay_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.dim() >= 2
    ]
    nodecay_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.dim() < 2
    ]
    optimizer = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr,
        betas=(0.9, 0.95),
    )

    # One scheduler step per batch, hence batches-per-epoch * epochs.
    total_steps = len(train_loader) * cfg.epochs
    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)

    criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID)

    # ------------------------------------------------------------------
    # Logging setup
    # ------------------------------------------------------------------
    output_path = Path(cfg.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_path.with_suffix(".pt")
    log_csv_path = Path(str(output_path) + ".log.csv")

    # Start a fresh log containing only the header; rows are appended per
    # epoch so the file is complete up to the last finished epoch.
    with open(log_csv_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["epoch", "train_loss", "val_loss", "val_ppl", "lr", "elapsed_s"])

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    run_start = time.monotonic()

    for epoch in range(1, cfg.epochs + 1):
        epoch_start = time.monotonic()

        train_loss = _run_epoch(
            model, train_loader, criterion, device, optimizer, scheduler
        )
        val_loss = _run_epoch(model, val_loader, criterion, device)
        val_ppl = math.exp(min(val_loss, 20))  # cap to avoid overflow
        elapsed = time.monotonic() - epoch_start
        current_lr = scheduler.get_last_lr()[0]

        log.info(
            "epoch %3d/%d train=%.4f val=%.4f ppl=%.1f lr=%.2e %.0fs",
            epoch, cfg.epochs, train_loss, val_loss, val_ppl, current_lr, elapsed,
        )

        with open(log_csv_path, "a", newline="", encoding="utf-8") as fh:
            csv.writer(fh).writerow(
                [epoch, f"{train_loss:.6f}", f"{val_loss:.6f}",
                 f"{val_ppl:.2f}", f"{current_lr:.6e}", f"{elapsed:.1f}"]
            )

        # Checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model_config": model_cfg,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "val_loss": val_loss,
                    # Paths are stored as strings so the checkpoint holds
                    # only plain values.
                    "train_config": {
                        k: str(v) if isinstance(v, Path) else v
                        for k, v in cfg.__dict__.items()
                    },
                },
                checkpoint_path,
            )
            log.info(" → saved best checkpoint (val_loss=%.4f)", val_loss)
        else:
            epochs_without_improvement += 1
            log.info(
                " no improvement for %d/%d epochs",
                epochs_without_improvement, cfg.patience,
            )
            if epochs_without_improvement >= cfg.patience:
                log.info("early stopping triggered at epoch %d", epoch)
                break

    total_elapsed = time.monotonic() - run_start
    log.info(
        "training finished: best val_loss=%.4f total=%.0fs checkpoint=%s",
        best_val_loss, total_elapsed, checkpoint_path,
    )
    return checkpoint_path
|
||||
Reference in New Issue
Block a user