"""Fine-tune ChordTransformer on the personal (user) chord corpus. Requires a pre-trained checkpoint produced by scripts/pretrain.py. Usage: # Full run (fine-tuning + plot + report) python scripts/train.py # Skip training; re-plot and report from existing CSV python scripts/train.py --skip-training Outputs written: checkpoints/finetuned.pt best checkpoint checkpoints/finetuned.log.csv per-epoch metrics checkpoints/finetuned_curves.png train/val loss plot """ from __future__ import annotations import argparse import csv import logging import math import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import torch sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.model import ChordTransformer from src.train import TrainConfig, train from src.tokenizer import TOKEN_TO_ID # --------------------------------------------------------------------------- # Paths # --------------------------------------------------------------------------- DATA_DIR = Path("data/processed/user") INIT_FROM = Path("checkpoints/pretrained.pt") CHECKPOINT = Path("checkpoints/finetuned.pt") LOG_CSV = Path("checkpoints/finetuned.log.csv") CURVES_PNG = Path("checkpoints/finetuned_curves.png") # --------------------------------------------------------------------------- # Training config # --------------------------------------------------------------------------- TRAIN_CFG = TrainConfig( data_dir=DATA_DIR, output=CHECKPOINT, init_from=INIT_FROM, epochs=15, batch_size=8, lr=1e-5, warmup_steps=20, seed=42, device="auto", max_seq_len=256, ) # --------------------------------------------------------------------------- # Plotting # --------------------------------------------------------------------------- def plot_curves(log_csv: Path, out_png: Path) -> None: epochs, train_losses, val_losses, val_ppls = [], [], [], [] with open(log_csv, newline="") as fh: for row in csv.DictReader(fh): epochs.append(int(row["epoch"])) train_losses.append(float(row["train_loss"])) val_losses.append(float(row["val_loss"])) val_ppls.append(float(row["val_ppl"])) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4)) ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5) ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5) best_epoch = epochs[val_losses.index(min(val_losses))] ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8, label=f"best epoch {best_epoch}") ax1.set_xlabel("epoch") ax1.set_ylabel("cross-entropy loss") ax1.set_title("Fine-tuning loss") ax1.legend() ax1.grid(True, alpha=0.3) ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5) ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8) ax2.set_xlabel("epoch") ax2.set_ylabel("perplexity") ax2.set_title("Val perplexity") ax2.grid(True, alpha=0.3) fig.tight_layout() out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=150) plt.close(fig) print(f"[plot] saved → {out_png}") # --------------------------------------------------------------------------- # Report # --------------------------------------------------------------------------- def print_report(log_csv: Path, checkpoint: Path) -> None: rows = [] with open(log_csv, newline="") as fh: rows = list(csv.DictReader(fh)) if not rows: print("[report] log CSV is empty — nothing to report") return val_losses = [float(r["val_loss"]) for r in rows] best_idx = val_losses.index(min(val_losses)) best_row = rows[best_idx] best_loss = float(best_row["val_loss"]) conv_epoch = next( (int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01), int(best_row["epoch"]), ) n_params = None if checkpoint.exists(): ckpt = torch.load(checkpoint, weights_only=True) model = ChordTransformer(**ckpt["model_config"]) tied = model.token_emb.weight.numel() n_params = sum(p.numel() for p in model.parameters()) - tied print() print("=" * 52) print(" FINE-TUNING REPORT") print("=" * 52) print(f" Total epochs run : {len(rows)}") print(f" Best epoch (val loss) : {best_row['epoch']}") print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)") print(f" Best val loss : {best_loss:.4f}") print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}") print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}") if n_params is not None: print(f" Unique parameters : {n_params:,}") print(f" Checkpoint : {checkpoint}") print(f" Log CSV : {log_csv}") print("=" * 52) print() print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}") print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}") for r in rows: marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else "" print( f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}" f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}" f" {float(r['lr']):>10.2e}{marker}" ) print() # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) ap.add_argument("--skip-training", action="store_true", help="Skip training; re-plot and report from existing CSV.") args = ap.parse_args() logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", ) if not args.skip_training: if not INIT_FROM.exists(): print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr) print("Run python scripts/pretrain.py first.", file=sys.stderr) sys.exit(1) if not DATA_DIR.exists(): print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr) print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr) sys.exit(1) n_train = len(list((DATA_DIR / "train").glob("*.pt"))) n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size est_epoch_s = n_batches * 1.5 device_label = "GPU" if torch.cuda.is_available() else "CPU" print( f"[train] {n_train} train files, {n_batches} batches/epoch\n" f"[train] estimated time on {device_label}: " f"~{est_epoch_s/60:.0f} min/epoch, " f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n" f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n" ) train(TRAIN_CFG) else: if not LOG_CSV.exists(): print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr) sys.exit(1) print(f"[skip-training] using existing log: {LOG_CSV}") plot_curves(LOG_CSV, CURVES_PNG) print_report(LOG_CSV, CHECKPOINT) if __name__ == "__main__": main()