feat: implement training loop and CLI (src/train.py, scripts/train.py)

AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch
CSV logging, best-val-loss checkpointing, early stopping (patience=5).
Same script handles both pre-training and fine-tuning via --init-from.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 11:15:39 +03:00
parent 10229be042
commit 733e1fde1f
2 changed files with 460 additions and 0 deletions
+120
View File
@@ -0,0 +1,120 @@
"""CLI entry point for pre-training and fine-tuning ChordTransformer.
Usage (pre-training):
python scripts/train.py \\
--data-dir data/processed/pretrain \\
--output checkpoints/pretrained \\
--epochs 50 --batch-size 32 --lr 3e-4
Usage (fine-tuning):
python scripts/train.py \\
--data-dir data/processed/finetune \\
--init-from checkpoints/pretrained.pt \\
--output checkpoints/finetuned \\
--epochs 15 --lr 1e-5
The script saves:
<output>.pt best checkpoint (lowest val loss)
<output>.log.csv per-epoch metrics
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.train import TrainConfig, train # noqa: E402
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Train or fine-tune ChordTransformer on tokenized chord data.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# I/O
io = p.add_argument_group("I/O")
io.add_argument(
"--data-dir", required=True, type=Path,
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
)
io.add_argument(
"--output", required=True, type=Path,
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
)
io.add_argument(
"--init-from", type=Path, default=None,
help="Checkpoint to load weights from before training (fine-tuning mode).",
)
# Training
tr = p.add_argument_group("Training")
tr.add_argument("--epochs", type=int, default=30)
tr.add_argument("--batch-size", type=int, default=16)
tr.add_argument("--lr", type=float, default=3e-4)
tr.add_argument("--warmup-steps", type=int, default=200)
tr.add_argument("--weight-decay", type=float, default=0.1)
tr.add_argument("--patience", type=int, default=5,
help="Early-stopping patience (epochs without val improvement).")
tr.add_argument("--seed", type=int, default=42)
tr.add_argument(
"--device", default="auto", choices=["auto", "cpu", "cuda"],
help="Compute device. 'auto' selects cuda when available.",
)
# Architecture (ignored when --init-from is given)
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
arch.add_argument("--d-model", type=int, default=192)
arch.add_argument("--n-layers", type=int, default=3)
arch.add_argument("--n-heads", type=int, default=6)
arch.add_argument("--d-ff", type=int, default=768)
arch.add_argument("--dropout", type=float, default=0.1)
arch.add_argument("--max-seq-len", type=int, default=512)
# Logging
p.add_argument(
"--log-level", default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
)
return p.parse_args()
def main() -> None:
    """Parse the command line, configure logging, and run one training job.

    The parsed flags are copied one-to-one into a ``TrainConfig``; the path
    returned by ``train`` is printed to stdout when the run finishes.
    """
    args = _parse_args()

    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Every TrainConfig field has a same-named attribute on the namespace;
    # --log-level is the only flag that is not forwarded.
    config_fields = (
        "data_dir", "output", "init_from",
        "epochs", "batch_size", "lr", "warmup_steps", "weight_decay",
        "seed", "device", "patience", "max_seq_len",
        "d_model", "n_layers", "n_heads", "d_ff", "dropout",
    )
    cfg = TrainConfig(**{name: getattr(args, name) for name in config_fields})

    best_path = train(cfg)
    print(f"best checkpoint: {best_path}")


if __name__ == "__main__":
    main()
+340
View File
@@ -0,0 +1,340 @@
"""Training logic for ChordTransformer.
Public API:
TrainConfig -- dataclass collecting all hyperparameters
train -- run pre-training or fine-tuning
This module contains no argument parsing. See scripts/train.py for the CLI.
"""
from __future__ import annotations
import csv
import logging
import math
import random
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.dataset import ChordDataset
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class TrainConfig:
    """Hyperparameters and I/O locations describing a single training run.

    Attributes:
        data_dir: Root of the prepared data set; ``train/`` and ``val/``
            sub-directories are read from it.
        output: Path prefix from which the checkpoint (.pt) and the metrics
            log (.log.csv) file names are derived.
        init_from: Checkpoint whose model config and weights are loaded
            before training starts (fine-tuning).  ``None`` trains a freshly
            initialised model.
        epochs: Upper bound on the number of epochs.
        batch_size: Number of sequences per mini-batch.
        lr: Peak AdamW learning rate, reached at the end of warm-up.
        warmup_steps: Optimizer steps of linear warm-up before cosine decay.
        weight_decay: AdamW weight-decay coefficient.  Applied only to
            parameters with two or more dimensions; one-dimensional
            parameters (biases, norm gains) are not decayed.
        seed: Seed for the Python, NumPy and PyTorch random generators.
        device: ``'cpu'``, ``'cuda'`` or ``'auto'`` (cuda when available).
        patience: Consecutive epochs without validation-loss improvement
            tolerated before training stops early.
        max_seq_len: Maximum sequence length handed to ChordDataset and,
            for a fresh model, to ChordTransformer.
        d_model: Transformer hidden dimension.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads.
        d_ff: Inner dimension of the feed-forward network.
        dropout: Dropout probability.

    The five architecture fields (``d_model`` .. ``dropout``) are ignored
    when ``init_from`` is set, because the architecture is then taken from
    the checkpoint.
    """

    # I/O
    data_dir: Path
    output: Path
    init_from: Optional[Path] = None

    # Optimisation and run control
    epochs: int = 30
    batch_size: int = 16
    lr: float = 3e-4
    warmup_steps: int = 200
    weight_decay: float = 0.1
    seed: int = 42
    device: str = "auto"
    patience: int = 5
    max_seq_len: int = 512

    # Architecture (ignored when init_from is set)
    d_model: int = 192
    n_layers: int = 3
    n_heads: int = 6
    d_ff: int = 768
    dropout: float = 0.1
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_device(spec: str) -> torch.device:
if spec == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(spec)
def _set_seeds(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def _build_scheduler(
optimizer: torch.optim.Optimizer,
warmup_steps: int,
total_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
"""Linear warm-up followed by cosine decay to 0."""
def _lr_lambda(step: int) -> float:
if step < warmup_steps:
return step / max(1, warmup_steps)
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
def _make_attention_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
"""Return 1 where token != PAD, 0 elsewhere."""
return (batch != pad_id).long()
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.CrossEntropyLoss,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
) -> float:
    """Run one pass over *loader* and return the mean loss per non-PAD token.

    The pass is a training epoch when *optimizer* is given (gradients
    enabled, model in train mode, parameters updated after every batch) and
    an evaluation pass otherwise (gradients disabled, model in eval mode).

    Args:
        model: Called as ``model(input_ids, attention_mask=...)``; its output
            is treated as logits whose last dimension is the vocabulary.
        loader: Yields 2-D token-id tensors that are sliced along dim 1 into
            inputs (all but the last position) and targets (all but the
            first).
        criterion: Loss applied to the flattened logits and targets.  The
            accumulation below assumes it returns the mean over non-PAD
            targets.
        device: Device each batch is moved to.
        optimizer: Enables training mode when not ``None``.
        scheduler: Optional LR scheduler, stepped once after every optimizer
            step.

    Returns:
        The per-batch losses averaged with non-PAD target counts as weights,
        or ``float("inf")`` if no non-PAD target was seen.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_tokens = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)  # [B, T]

            # Causal LM: predict token t+1 from tokens 0..t
            input_ids = batch[:, :-1]
            targets = batch[:, 1:]

            non_pad = int((targets != _PAD_ID).sum().item())
            if non_pad == 0:
                # Nothing to learn from or score.  A mean-reduced loss over
                # zero valid targets is 0/0 (NaN), which would poison both
                # the running total and, in training, the gradients.
                continue

            # The mask is derived from a tensor already on `device`, so no
            # further transfer is needed.
            attn_mask = _make_attention_mask(input_ids, _PAD_ID)

            logits = model(input_ids, attention_mask=attn_mask)
            # logits: [B, T-1, V]   targets: [B, T-1]
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

            if training:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # The scheduler is optional; training without one is valid.
                if scheduler is not None:
                    scheduler.step()

            # Re-weight the batch mean by its token count so the epoch figure
            # is a per-token mean, not a per-batch mean.
            total_loss += loss.item() * non_pad
            total_tokens += non_pad

    if total_tokens == 0:
        return float("inf")
    return total_loss / total_tokens
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def train(cfg: TrainConfig) -> Path:
    """Run pre-training or fine-tuning according to *cfg*.

    Builds the train/val data loaders, creates the model (or restores it
    from ``cfg.init_from``), then trains with AdamW and a warm-up/cosine
    schedule.  After every epoch one row is appended to ``<output>.log.csv``
    and, whenever the validation loss improves, the checkpoint
    ``<output>.pt`` is overwritten.  Training stops after ``cfg.epochs``
    epochs or after ``cfg.patience`` consecutive epochs without improvement.

    Args:
        cfg: Fully populated TrainConfig.

    Returns:
        Path of the best checkpoint.  The file exists only if the validation
        loss improved at least once; a warning is logged otherwise.
    """
    _set_seeds(cfg.seed)
    device = _resolve_device(cfg.device)
    log.info("device: %s", device)

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    train_dir = Path(cfg.data_dir) / "train"
    val_dir = Path(cfg.data_dir) / "val"

    train_ds = ChordDataset(train_dir, max_length=cfg.max_seq_len)
    val_ds = ChordDataset(val_dir, max_length=cfg.max_seq_len)
    log.info("train: %d periods val: %d periods", len(train_ds), len(val_ds))

    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        drop_last=False, num_workers=0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        drop_last=False, num_workers=0,
    )

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    vocab_size = len(TOKEN_TO_ID)

    if cfg.init_from is not None:
        # Fine-tuning: the architecture comes from the checkpoint, so the
        # architecture fields of cfg are not consulted.
        # NOTE(review): cfg.max_seq_len is still used for the datasets above;
        # confirm it does not exceed the checkpoint's max_seq_len.
        log.info("loading weights from %s", cfg.init_from)
        ckpt = torch.load(cfg.init_from, map_location=device, weights_only=True)
        model_cfg = ckpt["model_config"]
        model = ChordTransformer(**model_cfg)
        model.load_state_dict(ckpt["model_state"])
    else:
        model_cfg = dict(
            vocab_size=vocab_size,
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_ff=cfg.d_ff,
            max_seq_len=cfg.max_seq_len,
            dropout=cfg.dropout,
        )
        model = ChordTransformer(**model_cfg)

    model = model.to(device)
    # NOTE(review): nn.Module.parameters() already yields a shared (tied)
    # parameter only once, so if token_emb is tied to the output layer this
    # subtraction removes it from the count entirely -- confirm the intent.
    n_params = sum(p.numel() for p in model.parameters()) - model.token_emb.weight.numel()
    log.info("model: %d unique parameters", n_params)

    # ------------------------------------------------------------------
    # Optimizer, scheduler, loss
    # ------------------------------------------------------------------
    # Apply weight decay only to weight matrices (dim >= 2), not to
    # one-dimensional parameters such as biases / LayerNorm params.
    trainable = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    optimizer = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=cfg.lr,
        betas=(0.9, 0.95),
    )

    # One scheduler step per batch, so the decay spans the whole run.
    total_steps = len(train_loader) * cfg.epochs
    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
    criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID)

    # ------------------------------------------------------------------
    # Logging setup
    # ------------------------------------------------------------------
    output_path = Path(cfg.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Append ".pt" instead of using with_suffix(): with_suffix would replace
    # an existing suffix ("run.v2" -> "run.pt") and disagree with the log
    # path, which is appended.  A prefix already ending in ".pt" is kept.
    if output_path.suffix == ".pt":
        checkpoint_path = output_path
    else:
        checkpoint_path = Path(str(output_path) + ".pt")
    log_csv_path = Path(str(output_path) + ".log.csv")

    # Start a fresh log; rows are appended per epoch so that a crash
    # mid-run still leaves the completed epochs on disk.
    with open(log_csv_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["epoch", "train_loss", "val_loss", "val_ppl", "lr", "elapsed_s"])

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    run_start = time.monotonic()

    for epoch in range(1, cfg.epochs + 1):
        epoch_start = time.monotonic()

        train_loss = _run_epoch(
            model, train_loader, criterion, device, optimizer, scheduler
        )
        val_loss = _run_epoch(model, val_loader, criterion, device)
        val_ppl = math.exp(min(val_loss, 20))  # cap to avoid overflow

        elapsed = time.monotonic() - epoch_start
        current_lr = scheduler.get_last_lr()[0]

        log.info(
            "epoch %3d/%d train=%.4f val=%.4f ppl=%.1f lr=%.2e %.0fs",
            epoch, cfg.epochs, train_loss, val_loss, val_ppl, current_lr, elapsed,
        )

        with open(log_csv_path, "a", newline="", encoding="utf-8") as fh:
            csv.writer(fh).writerow(
                [epoch, f"{train_loss:.6f}", f"{val_loss:.6f}",
                 f"{val_ppl:.2f}", f"{current_lr:.6e}", f"{elapsed:.1f}"]
            )

        # Checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model_config": model_cfg,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "val_loss": val_loss,
                    # Paths are stored as strings so the checkpoint holds
                    # only plain built-in values.
                    "train_config": {
                        k: str(v) if isinstance(v, Path) else v
                        for k, v in vars(cfg).items()
                    },
                },
                checkpoint_path,
            )
            log.info(" → saved best checkpoint (val_loss=%.4f)", val_loss)
        else:
            epochs_without_improvement += 1
            log.info(
                " no improvement for %d/%d epochs",
                epochs_without_improvement, cfg.patience,
            )
            if epochs_without_improvement >= cfg.patience:
                log.info("early stopping triggered at epoch %d", epoch)
                break

    total_elapsed = time.monotonic() - run_start
    if math.isinf(best_val_loss):
        # Either no epoch ran or the validation loss was never a finite
        # improvement (empty validation set, NaN loss): nothing was saved.
        log.warning(
            "validation loss never improved; no checkpoint was written to %s",
            checkpoint_path,
        )
    log.info(
        "training finished: best val_loss=%.4f total=%.0fs checkpoint=%s",
        best_val_loss, total_elapsed, checkpoint_path,
    )
    return checkpoint_path