From 733e1fde1fc74703abde78ca0c0a298799dcddff Mon Sep 17 00:00:00 2001
From: Masahiko AMANO
Date: Wed, 20 May 2026 11:15:39 +0300
Subject: [PATCH] feat: implement training loop and CLI (src/train.py, scripts/train.py)

AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy,
per-epoch CSV logging, best-val-loss checkpointing, early stopping
(patience=5). Same script handles both pre-training and fine-tuning
via --init-from.

Co-Authored-By: Claude Sonnet 4.6
---
 scripts/train.py | 120 +++++++++++++++++
 src/train.py     | 372 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 492 insertions(+)
 create mode 100644 scripts/train.py
 create mode 100644 src/train.py

diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..1b9f76d
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,120 @@
+"""CLI entry point for pre-training and fine-tuning ChordTransformer.
+
+Usage (pre-training):
+    python scripts/train.py \\
+        --data-dir data/processed/pretrain \\
+        --output checkpoints/pretrained \\
+        --epochs 50 --batch-size 32 --lr 3e-4
+
+Usage (fine-tuning):
+    python scripts/train.py \\
+        --data-dir data/processed/finetune \\
+        --init-from checkpoints/pretrained.pt \\
+        --output checkpoints/finetuned \\
+        --epochs 15 --lr 1e-5
+
+The script saves:
+    <output>.pt        best checkpoint (lowest val loss)
+    <output>.log.csv   per-epoch metrics
+"""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from src.train import TrainConfig, train  # noqa: E402
+
+
+def _parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(
+        description="Train or fine-tune ChordTransformer on tokenized chord data.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # I/O
+    io = p.add_argument_group("I/O")
+    io.add_argument(
+        "--data-dir", required=True, type=Path,
+        help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
+    )
+    io.add_argument(
+        "--output", required=True, type=Path,
+        help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
+    )
+    io.add_argument(
+        "--init-from", type=Path, default=None,
+        help="Checkpoint to load weights from before training (fine-tuning mode).",
+    )
+
+    # Training
+    tr = p.add_argument_group("Training")
+    tr.add_argument("--epochs", type=int, default=30)
+    tr.add_argument("--batch-size", type=int, default=16)
+    tr.add_argument("--lr", type=float, default=3e-4)
+    tr.add_argument("--warmup-steps", type=int, default=200)
+    tr.add_argument("--weight-decay", type=float, default=0.1)
+    tr.add_argument("--patience", type=int, default=5,
+                    help="Early-stopping patience (epochs without val improvement).")
+    tr.add_argument("--seed", type=int, default=42)
+    tr.add_argument(
+        "--device", default="auto", choices=["auto", "cpu", "cuda"],
+        help="Compute device. 'auto' selects cuda when available.",
+    )
+
+    # Architecture (ignored when --init-from is given)
+    arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
+    arch.add_argument("--d-model", type=int, default=192)
+    arch.add_argument("--n-layers", type=int, default=3)
+    arch.add_argument("--n-heads", type=int, default=6)
+    arch.add_argument("--d-ff", type=int, default=768)
+    arch.add_argument("--dropout", type=float, default=0.1)
+    arch.add_argument("--max-seq-len", type=int, default=512)
+
+    # Logging
+    p.add_argument(
+        "--log-level", default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+    )
+
+    return p.parse_args()
+
+
+def main() -> None:
+    args = _parse_args()
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s %(levelname)s %(message)s",
+        datefmt="%H:%M:%S",
+    )
+
+    cfg = TrainConfig(
+        data_dir=args.data_dir,
+        output=args.output,
+        init_from=args.init_from,
+        epochs=args.epochs,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        warmup_steps=args.warmup_steps,
+        weight_decay=args.weight_decay,
+        seed=args.seed,
+        device=args.device,
+        patience=args.patience,
+        max_seq_len=args.max_seq_len,
+        d_model=args.d_model,
+        n_layers=args.n_layers,
+        n_heads=args.n_heads,
+        d_ff=args.d_ff,
+        dropout=args.dropout,
+    )
+
+    checkpoint = train(cfg)
+    print(f"best checkpoint: {checkpoint}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..3f48c21
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,372 @@
+"""Training logic for ChordTransformer.
+
+Public API:
+    TrainConfig -- dataclass collecting all hyperparameters
+    train -- run pre-training or fine-tuning
+
+This module contains no argument parsing. See scripts/train.py for the CLI.
+"""
+
+from __future__ import annotations
+
+import csv
+import logging
+import math
+import random
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from src.dataset import ChordDataset
+from src.model import ChordTransformer
+from src.tokenizer import TOKEN_TO_ID
+
+log = logging.getLogger(__name__)
+
+# NOTE(review): the key below reads as an empty string in this patch; confirm
+# it is the PAD token name defined in src/tokenizer.py.
+_PAD_ID: int = TOKEN_TO_ID[""]
+
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class TrainConfig:
+    """All hyperparameters and I/O paths for one training run.
+
+    Args:
+        data_dir: Directory produced by prepare_data.py; must contain
+            ``train/`` and ``val/`` sub-directories with tokenized .pt files.
+        output: Path prefix for the saved checkpoint (.pt) and log (.log.csv).
+        init_from: Optional checkpoint to load weights from before training
+            (used for fine-tuning).
+        epochs: Maximum number of training epochs.
+        batch_size: Mini-batch size for the DataLoader.
+        lr: Peak learning rate for AdamW.
+        warmup_steps: Number of linear warm-up steps before cosine decay.
+        weight_decay: AdamW weight decay (applied to parameters with >= 2 dims).
+        seed: Random seed for reproducibility.
+        device: 'cpu', 'cuda', or 'auto' (picks cuda if available).
+        patience: Early-stopping patience in epochs (on val loss).
+        max_seq_len: Passed to ChordDataset and ChordTransformer.
+        # Model architecture (ignored when init_from is set)
+        d_model: Transformer hidden dimension.
+        n_layers: Number of transformer blocks.
+        n_heads: Number of attention heads.
+        d_ff: FFN inner dimension.
+        dropout: Dropout probability.
+    """
+
+    data_dir: Path
+    output: Path
+    init_from: Optional[Path] = None
+    epochs: int = 30
+    batch_size: int = 16
+    lr: float = 3e-4
+    warmup_steps: int = 200
+    weight_decay: float = 0.1
+    seed: int = 42
+    device: str = "auto"
+    patience: int = 5
+    max_seq_len: int = 512
+    # Architecture (ignored when init_from is set)
+    d_model: int = 192
+    n_layers: int = 3
+    n_heads: int = 6
+    d_ff: int = 768
+    dropout: float = 0.1
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _resolve_device(spec: str) -> torch.device:
+    """Map a device spec to a torch.device; 'auto' prefers cuda when available."""
+    if spec == "auto":
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return torch.device(spec)
+
+
+def _set_seeds(seed: int) -> None:
+    """Seed the Python, NumPy and torch (CPU and CUDA) random generators."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+def _build_scheduler(
+    optimizer: torch.optim.Optimizer,
+    warmup_steps: int,
+    total_steps: int,
+) -> torch.optim.lr_scheduler.LambdaLR:
+    """Linear warm-up followed by cosine decay to 0."""
+
+    def _lr_lambda(step: int) -> float:
+        if step < warmup_steps:
+            return step / max(1, warmup_steps)
+        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+        # Clamp so the cosine cannot rise again if stepped past total_steps.
+        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
+
+    return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
+
+
+def _make_attention_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """Return 1 where token != PAD, 0 elsewhere."""
+    return (batch != pad_id).long()
+
+
+def _checkpoint_path(output: Path) -> Path:
+    """Return *output* with ``.pt`` appended (kept as-is if it already ends in .pt).
+
+    ``Path.with_suffix`` is not used because it replaces the text after the last
+    dot, so a prefix such as ``runs/lr0.1`` would turn into ``runs/lr0.pt``.
+    """
+    if output.suffix == ".pt":
+        return output
+    return Path(str(output) + ".pt")
+
+
+def _run_epoch(
+    model: nn.Module,
+    loader: DataLoader,
+    criterion: nn.CrossEntropyLoss,
+    device: torch.device,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+) -> float:
+    """Run one epoch; return mean loss over all non-PAD tokens.
+
+    The model is trained when *optimizer* is given and evaluated (no gradients)
+    otherwise. Returns ``inf`` when the loader yields no non-PAD target token.
+    """
+    training = optimizer is not None
+    model.train(training)
+
+    total_loss = 0.0
+    total_tokens = 0
+
+    with torch.set_grad_enabled(training):
+        for batch in loader:
+            batch = batch.to(device)  # [B, T]
+
+            # Causal LM: predict token t+1 from tokens 0..t
+            input_ids = batch[:, :-1]
+            targets = batch[:, 1:]
+            attn_mask = _make_attention_mask(input_ids, _PAD_ID)
+
+            # With every target ignored the mean-reduced loss is NaN (0 / 0);
+            # skip such batches so NaN reaches neither the weights nor the sum.
+            non_pad = int((targets != _PAD_ID).sum().item())
+            if non_pad == 0:
+                continue
+
+            logits = model(input_ids, attention_mask=attn_mask)
+            # logits: [B, T-1, V] targets: [B, T-1]
+            loss = criterion(
+                logits.reshape(-1, logits.size(-1)),
+                targets.reshape(-1),
+            )
+
+            if training:
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                if scheduler is not None:
+                    scheduler.step()
+
+            # Re-weight the batch mean by its token count (criterion ignores PAD)
+            total_loss += loss.item() * non_pad
+            total_tokens += non_pad
+
+    if total_tokens == 0:
+        return float("inf")
+    return total_loss / total_tokens
+
+
+# ---------------------------------------------------------------------------
+# Public entry point
+# ---------------------------------------------------------------------------
+
+
+def train(cfg: TrainConfig) -> Path:
+    """Run training according to *cfg*.
+
+    Args:
+        cfg: Fully populated TrainConfig.
+
+    Returns:
+        Path of the best checkpoint. The file exists only if at least one
+        epoch produced a finite validation loss.
+    """
+    _set_seeds(cfg.seed)
+    device = _resolve_device(cfg.device)
+    log.info("device: %s", device)
+
+    # ------------------------------------------------------------------
+    # Data
+    # ------------------------------------------------------------------
+    train_dir = Path(cfg.data_dir) / "train"
+    val_dir = Path(cfg.data_dir) / "val"
+    train_ds = ChordDataset(train_dir, max_length=cfg.max_seq_len)
+    val_ds = ChordDataset(val_dir, max_length=cfg.max_seq_len)
+    log.info("train: %d periods val: %d periods", len(train_ds), len(val_ds))
+
+    train_loader = DataLoader(
+        train_ds, batch_size=cfg.batch_size, shuffle=True,
+        drop_last=False, num_workers=0,
+    )
+    val_loader = DataLoader(
+        val_ds, batch_size=cfg.batch_size, shuffle=False,
+        drop_last=False, num_workers=0,
+    )
+
+    # ------------------------------------------------------------------
+    # Model
+    # ------------------------------------------------------------------
+    vocab_size = len(TOKEN_TO_ID)
+
+    if cfg.init_from is not None:
+        log.info("loading weights from %s", cfg.init_from)
+        ckpt = torch.load(cfg.init_from, map_location=device, weights_only=True)
+        model_cfg = ckpt["model_config"]
+        model = ChordTransformer(**model_cfg)
+        model.load_state_dict(ckpt["model_state"])
+        # The datasets above were built with cfg.max_seq_len, not the checkpoint's.
+        ckpt_seq_len = model_cfg.get("max_seq_len")
+        if ckpt_seq_len is not None and cfg.max_seq_len > ckpt_seq_len:
+            log.warning(
+                "max_seq_len=%d exceeds the checkpoint's max_seq_len=%d",
+                cfg.max_seq_len, ckpt_seq_len,
+            )
+    else:
+        model_cfg = dict(
+            vocab_size=vocab_size,
+            d_model=cfg.d_model,
+            n_layers=cfg.n_layers,
+            n_heads=cfg.n_heads,
+            d_ff=cfg.d_ff,
+            max_seq_len=cfg.max_seq_len,
+            dropout=cfg.dropout,
+        )
+        model = ChordTransformer(**model_cfg)
+
+    model = model.to(device)
+    # Module.parameters() already yields a shared (tied) tensor only once, so
+    # the plain sum is the unique-parameter count.
+    n_params = sum(p.numel() for p in model.parameters())
+    log.info("model: %d unique parameters", n_params)
+
+    # ------------------------------------------------------------------
+    # Optimizer, scheduler, loss
+    # ------------------------------------------------------------------
+    # Apply weight decay only to weight matrices, not biases / LayerNorm params.
+    trainable = [p for p in model.parameters() if p.requires_grad]
+    decay_params = [p for p in trainable if p.dim() >= 2]
+    nodecay_params = [p for p in trainable if p.dim() < 2]
+    optimizer = torch.optim.AdamW(
+        [
+            {"params": decay_params, "weight_decay": cfg.weight_decay},
+            {"params": nodecay_params, "weight_decay": 0.0},
+        ],
+        lr=cfg.lr,
+        betas=(0.9, 0.95),
+    )
+
+    total_steps = len(train_loader) * cfg.epochs
+    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
+
+    criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID)
+
+    # ------------------------------------------------------------------
+    # Logging setup
+    # ------------------------------------------------------------------
+    output_path = Path(cfg.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    checkpoint_path = _checkpoint_path(output_path)
+    log_csv_path = Path(str(output_path) + ".log.csv")
+
+    with open(log_csv_path, "w", newline="") as fh:
+        writer = csv.writer(fh)
+        writer.writerow(["epoch", "train_loss", "val_loss", "val_ppl", "lr", "elapsed_s"])
+
+    # ------------------------------------------------------------------
+    # Training loop
+    # ------------------------------------------------------------------
+    best_val_loss = float("inf")
+    epochs_without_improvement = 0
+    run_start = time.monotonic()
+
+    for epoch in range(1, cfg.epochs + 1):
+        epoch_start = time.monotonic()
+
+        train_loss = _run_epoch(
+            model, train_loader, criterion, device, optimizer, scheduler
+        )
+        val_loss = _run_epoch(model, val_loader, criterion, device)
+        val_ppl = math.exp(min(val_loss, 20))  # cap to avoid overflow
+        elapsed = time.monotonic() - epoch_start
+        current_lr = scheduler.get_last_lr()[0]
+
+        log.info(
+            "epoch %3d/%d train=%.4f val=%.4f ppl=%.1f lr=%.2e %.0fs",
+            epoch, cfg.epochs, train_loss, val_loss, val_ppl, current_lr, elapsed,
+        )
+
+        with open(log_csv_path, "a", newline="") as fh:
+            csv.writer(fh).writerow(
+                [epoch, f"{train_loss:.6f}", f"{val_loss:.6f}",
+                 f"{val_ppl:.2f}", f"{current_lr:.6e}", f"{elapsed:.1f}"]
+            )
+
+        # Checkpoint
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            epochs_without_improvement = 0
+            torch.save(
+                {
+                    "epoch": epoch,
+                    "model_config": model_cfg,
+                    "model_state": model.state_dict(),
+                    "optimizer_state": optimizer.state_dict(),
+                    "val_loss": val_loss,
+                    "train_config": {
+                        k: str(v) if isinstance(v, Path) else v
+                        for k, v in cfg.__dict__.items()
+                    },
+                },
+                checkpoint_path,
+            )
+            log.info(" → saved best checkpoint (val_loss=%.4f)", val_loss)
+        else:
+            epochs_without_improvement += 1
+            log.info(
+                " no improvement for %d/%d epochs",
+                epochs_without_improvement, cfg.patience,
+            )
+            if epochs_without_improvement >= cfg.patience:
+                log.info("early stopping triggered at epoch %d", epoch)
+                break
+
+    total_elapsed = time.monotonic() - run_start
+    if math.isinf(best_val_loss):
+        log.warning("no checkpoint was written: validation loss never improved")
+    log.info(
+        "training finished: best val_loss=%.4f total=%.0fs checkpoint=%s",
+        best_val_loss, total_elapsed, checkpoint_path,
+    )
+    return checkpoint_path