feat: implement training loop and CLI (src/train.py, scripts/train.py)
AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch CSV logging, best-val-loss checkpointing, early stopping (patience=5). Same script handles both pre-training and fine-tuning via --init-from. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,120 @@
|
||||
"""CLI entry point for pre-training and fine-tuning ChordTransformer.
|
||||
|
||||
Usage (pre-training):
|
||||
python scripts/train.py \\
|
||||
--data-dir data/processed/pretrain \\
|
||||
--output checkpoints/pretrained \\
|
||||
--epochs 50 --batch-size 32 --lr 3e-4
|
||||
|
||||
Usage (fine-tuning):
|
||||
python scripts/train.py \\
|
||||
--data-dir data/processed/finetune \\
|
||||
--init-from checkpoints/pretrained.pt \\
|
||||
--output checkpoints/finetuned \\
|
||||
--epochs 15 --lr 1e-5
|
||||
|
||||
The script saves:
|
||||
<output>.pt best checkpoint (lowest val loss)
|
||||
<output>.log.csv per-epoch metrics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from src.train import TrainConfig, train # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Train or fine-tune ChordTransformer on tokenized chord data.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# I/O
|
||||
io = p.add_argument_group("I/O")
|
||||
io.add_argument(
|
||||
"--data-dir", required=True, type=Path,
|
||||
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
|
||||
)
|
||||
io.add_argument(
|
||||
"--output", required=True, type=Path,
|
||||
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
|
||||
)
|
||||
io.add_argument(
|
||||
"--init-from", type=Path, default=None,
|
||||
help="Checkpoint to load weights from before training (fine-tuning mode).",
|
||||
)
|
||||
|
||||
# Training
|
||||
tr = p.add_argument_group("Training")
|
||||
tr.add_argument("--epochs", type=int, default=30)
|
||||
tr.add_argument("--batch-size", type=int, default=16)
|
||||
tr.add_argument("--lr", type=float, default=3e-4)
|
||||
tr.add_argument("--warmup-steps", type=int, default=200)
|
||||
tr.add_argument("--weight-decay", type=float, default=0.1)
|
||||
tr.add_argument("--patience", type=int, default=5,
|
||||
help="Early-stopping patience (epochs without val improvement).")
|
||||
tr.add_argument("--seed", type=int, default=42)
|
||||
tr.add_argument(
|
||||
"--device", default="auto", choices=["auto", "cpu", "cuda"],
|
||||
help="Compute device. 'auto' selects cuda when available.",
|
||||
)
|
||||
|
||||
# Architecture (ignored when --init-from is given)
|
||||
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
|
||||
arch.add_argument("--d-model", type=int, default=192)
|
||||
arch.add_argument("--n-layers", type=int, default=3)
|
||||
arch.add_argument("--n-heads", type=int, default=6)
|
||||
arch.add_argument("--d-ff", type=int, default=768)
|
||||
arch.add_argument("--dropout", type=float, default=0.1)
|
||||
arch.add_argument("--max-seq-len", type=int, default=512)
|
||||
|
||||
# Logging
|
||||
p.add_argument(
|
||||
"--log-level", default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
)
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
logging.basicConfig(
|
||||
level=args.log_level,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
cfg = TrainConfig(
|
||||
data_dir=args.data_dir,
|
||||
output=args.output,
|
||||
init_from=args.init_from,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
warmup_steps=args.warmup_steps,
|
||||
weight_decay=args.weight_decay,
|
||||
seed=args.seed,
|
||||
device=args.device,
|
||||
patience=args.patience,
|
||||
max_seq_len=args.max_seq_len,
|
||||
d_model=args.d_model,
|
||||
n_layers=args.n_layers,
|
||||
n_heads=args.n_heads,
|
||||
d_ff=args.d_ff,
|
||||
dropout=args.dropout,
|
||||
)
|
||||
|
||||
checkpoint = train(cfg)
|
||||
print(f"best checkpoint: {checkpoint}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+340
@@ -0,0 +1,340 @@
|
||||
"""Training logic for ChordTransformer.
|
||||
|
||||
Public API:
|
||||
TrainConfig -- dataclass collecting all hyperparameters
|
||||
train -- run pre-training or fine-tuning
|
||||
|
||||
This module contains no argument parsing. See scripts/train.py for the CLI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.dataset import ChordDataset
|
||||
from src.model import ChordTransformer
|
||||
from src.tokenizer import TOKEN_TO_ID
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
"""All hyperparameters and I/O paths for one training run.
|
||||
|
||||
Args:
|
||||
data_dir: Directory produced by prepare_data.py; must contain
|
||||
``train/`` and ``val/`` sub-directories with tokenized .pt files.
|
||||
output: Path prefix for the saved checkpoint (.pt) and log (.log.csv).
|
||||
init_from: Optional checkpoint to load weights from before training
|
||||
(used for fine-tuning).
|
||||
epochs: Maximum number of training epochs.
|
||||
batch_size: Mini-batch size for the DataLoader.
|
||||
lr: Peak learning rate for AdamW.
|
||||
warmup_steps: Number of linear warm-up steps before cosine decay.
|
||||
weight_decay: L2 penalty for AdamW (applied to non-bias parameters).
|
||||
seed: Random seed for reproducibility.
|
||||
device: 'cpu', 'cuda', or 'auto' (picks cuda if available).
|
||||
patience: Early-stopping patience in epochs (on val loss).
|
||||
max_seq_len: Passed to ChordDataset and ChordTransformer.
|
||||
# Model architecture (ignored when init_from is set)
|
||||
d_model: Transformer hidden dimension.
|
||||
n_layers: Number of transformer blocks.
|
||||
n_heads: Number of attention heads.
|
||||
d_ff: FFN inner dimension.
|
||||
dropout: Dropout probability.
|
||||
"""
|
||||
|
||||
data_dir: Path
|
||||
output: Path
|
||||
init_from: Optional[Path] = None
|
||||
epochs: int = 30
|
||||
batch_size: int = 16
|
||||
lr: float = 3e-4
|
||||
warmup_steps: int = 200
|
||||
weight_decay: float = 0.1
|
||||
seed: int = 42
|
||||
device: str = "auto"
|
||||
patience: int = 5
|
||||
max_seq_len: int = 512
|
||||
# Architecture (ignored when init_from is set)
|
||||
d_model: int = 192
|
||||
n_layers: int = 3
|
||||
n_heads: int = 6
|
||||
d_ff: int = 768
|
||||
dropout: float = 0.1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_device(spec: str) -> torch.device:
|
||||
if spec == "auto":
|
||||
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(spec)
|
||||
|
||||
|
||||
def _set_seeds(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _build_scheduler(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
warmup_steps: int,
|
||||
total_steps: int,
|
||||
) -> torch.optim.lr_scheduler.LambdaLR:
|
||||
"""Linear warm-up followed by cosine decay to 0."""
|
||||
|
||||
def _lr_lambda(step: int) -> float:
|
||||
if step < warmup_steps:
|
||||
return step / max(1, warmup_steps)
|
||||
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
||||
return 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||
|
||||
return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
|
||||
|
||||
|
||||
def _make_attention_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
|
||||
"""Return 1 where token != PAD, 0 elsewhere."""
|
||||
return (batch != pad_id).long()
|
||||
|
||||
|
||||
def _run_epoch(
|
||||
model: nn.Module,
|
||||
loader: DataLoader,
|
||||
criterion: nn.CrossEntropyLoss,
|
||||
device: torch.device,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
|
||||
) -> float:
|
||||
"""Run one epoch; return mean loss over all non-PAD tokens."""
|
||||
training = optimizer is not None
|
||||
model.train(training)
|
||||
|
||||
total_loss = 0.0
|
||||
total_tokens = 0
|
||||
|
||||
with torch.set_grad_enabled(training):
|
||||
for batch in loader:
|
||||
batch = batch.to(device) # [B, T]
|
||||
attention_mask = _make_attention_mask(batch, _PAD_ID).to(device)
|
||||
|
||||
# Causal LM: predict token t+1 from tokens 0..t
|
||||
input_ids = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
attn_mask = attention_mask[:, :-1]
|
||||
|
||||
logits = model(input_ids, attention_mask=attn_mask)
|
||||
# logits: [B, T-1, V] targets: [B, T-1]
|
||||
loss = criterion(
|
||||
logits.reshape(-1, logits.size(-1)),
|
||||
targets.reshape(-1),
|
||||
)
|
||||
|
||||
if training:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
# Accumulate token-level loss (criterion already ignores PAD)
|
||||
non_pad = (targets != _PAD_ID).sum().item()
|
||||
total_loss += loss.item() * non_pad
|
||||
total_tokens += non_pad
|
||||
|
||||
if total_tokens == 0:
|
||||
return float("inf")
|
||||
return total_loss / total_tokens
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def train(cfg: TrainConfig) -> Path:
|
||||
"""Run training according to *cfg*.
|
||||
|
||||
Args:
|
||||
cfg: Fully populated TrainConfig.
|
||||
|
||||
Returns:
|
||||
Path to the best checkpoint saved during training.
|
||||
"""
|
||||
_set_seeds(cfg.seed)
|
||||
device = _resolve_device(cfg.device)
|
||||
log.info("device: %s", device)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Data
|
||||
# ------------------------------------------------------------------
|
||||
train_dir = Path(cfg.data_dir) / "train"
|
||||
val_dir = Path(cfg.data_dir) / "val"
|
||||
train_ds = ChordDataset(train_dir, max_length=cfg.max_seq_len)
|
||||
val_ds = ChordDataset(val_dir, max_length=cfg.max_seq_len)
|
||||
log.info("train: %d periods val: %d periods", len(train_ds), len(val_ds))
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg.batch_size, shuffle=True,
|
||||
drop_last=False, num_workers=0,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds, batch_size=cfg.batch_size, shuffle=False,
|
||||
drop_last=False, num_workers=0,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Model
|
||||
# ------------------------------------------------------------------
|
||||
vocab_size = len(TOKEN_TO_ID)
|
||||
|
||||
if cfg.init_from is not None:
|
||||
log.info("loading weights from %s", cfg.init_from)
|
||||
ckpt = torch.load(cfg.init_from, map_location=device, weights_only=True)
|
||||
model_cfg = ckpt["model_config"]
|
||||
model = ChordTransformer(**model_cfg)
|
||||
model.load_state_dict(ckpt["model_state"])
|
||||
else:
|
||||
model_cfg = dict(
|
||||
vocab_size=vocab_size,
|
||||
d_model=cfg.d_model,
|
||||
n_layers=cfg.n_layers,
|
||||
n_heads=cfg.n_heads,
|
||||
d_ff=cfg.d_ff,
|
||||
max_seq_len=cfg.max_seq_len,
|
||||
dropout=cfg.dropout,
|
||||
)
|
||||
model = ChordTransformer(**model_cfg)
|
||||
|
||||
model = model.to(device)
|
||||
n_params = sum(p.numel() for p in model.parameters()) - model.token_emb.weight.numel()
|
||||
log.info("model: %d unique parameters", n_params)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optimizer, scheduler, loss
|
||||
# ------------------------------------------------------------------
|
||||
# Apply weight decay only to weight matrices, not biases / LayerNorm params.
|
||||
decay_params = [
|
||||
p for n, p in model.named_parameters()
|
||||
if p.requires_grad and p.dim() >= 2
|
||||
]
|
||||
nodecay_params = [
|
||||
p for n, p in model.named_parameters()
|
||||
if p.requires_grad and p.dim() < 2
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
[
|
||||
{"params": decay_params, "weight_decay": cfg.weight_decay},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
],
|
||||
lr=cfg.lr,
|
||||
betas=(0.9, 0.95),
|
||||
)
|
||||
|
||||
total_steps = len(train_loader) * cfg.epochs
|
||||
scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Logging setup
|
||||
# ------------------------------------------------------------------
|
||||
output_path = Path(cfg.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint_path = output_path.with_suffix(".pt")
|
||||
log_csv_path = Path(str(output_path) + ".log.csv")
|
||||
|
||||
with open(log_csv_path, "w", newline="") as fh:
|
||||
writer = csv.writer(fh)
|
||||
writer.writerow(["epoch", "train_loss", "val_loss", "val_ppl", "lr", "elapsed_s"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop
|
||||
# ------------------------------------------------------------------
|
||||
best_val_loss = float("inf")
|
||||
epochs_without_improvement = 0
|
||||
run_start = time.monotonic()
|
||||
|
||||
for epoch in range(1, cfg.epochs + 1):
|
||||
epoch_start = time.monotonic()
|
||||
|
||||
train_loss = _run_epoch(
|
||||
model, train_loader, criterion, device, optimizer, scheduler
|
||||
)
|
||||
val_loss = _run_epoch(model, val_loader, criterion, device)
|
||||
val_ppl = math.exp(min(val_loss, 20)) # cap to avoid overflow
|
||||
elapsed = time.monotonic() - epoch_start
|
||||
current_lr = scheduler.get_last_lr()[0]
|
||||
|
||||
log.info(
|
||||
"epoch %3d/%d train=%.4f val=%.4f ppl=%.1f lr=%.2e %.0fs",
|
||||
epoch, cfg.epochs, train_loss, val_loss, val_ppl, current_lr, elapsed,
|
||||
)
|
||||
|
||||
with open(log_csv_path, "a", newline="") as fh:
|
||||
csv.writer(fh).writerow(
|
||||
[epoch, f"{train_loss:.6f}", f"{val_loss:.6f}",
|
||||
f"{val_ppl:.2f}", f"{current_lr:.6e}", f"{elapsed:.1f}"]
|
||||
)
|
||||
|
||||
# Checkpoint
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
epochs_without_improvement = 0
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_config": model_cfg,
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
"val_loss": val_loss,
|
||||
"train_config": {
|
||||
k: str(v) if isinstance(v, Path) else v
|
||||
for k, v in cfg.__dict__.items()
|
||||
},
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
log.info(" → saved best checkpoint (val_loss=%.4f)", val_loss)
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
log.info(
|
||||
" no improvement for %d/%d epochs",
|
||||
epochs_without_improvement, cfg.patience,
|
||||
)
|
||||
if epochs_without_improvement >= cfg.patience:
|
||||
log.info("early stopping triggered at epoch %d", epoch)
|
||||
break
|
||||
|
||||
total_elapsed = time.monotonic() - run_start
|
||||
log.info(
|
||||
"training finished: best val_loss=%.4f total=%.0fs checkpoint=%s",
|
||||
best_val_loss, total_elapsed, checkpoint_path,
|
||||
)
|
||||
return checkpoint_path
|
||||
Reference in New Issue
Block a user