"""Fine-tune ChordTransformer on the personal (user) chord corpus. Requires a pre-trained checkpoint produced by scripts/pretrain.py. Usage: # Full run (fine-tuning + plot + report) python scripts/train.py # Skip training; re-plot and report from existing CSV python scripts/train.py --skip-training Outputs written: checkpoints/finetuned.pt best checkpoint checkpoints/finetuned.log.csv per-epoch metrics checkpoints/finetuned_curves.png train/val loss plot """ from __future__ import annotations import argparse import csv import logging import math import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import torch sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.model import ChordTransformer from src.train import TrainConfig, train from src.tokenizer import TOKEN_TO_ID # --------------------------------------------------------------------------- # Paths # --------------------------------------------------------------------------- DATA_DIR = Path("data/processed/user") INIT_FROM = Path("checkpoints/pretrained.pt") CHECKPOINT = Path("checkpoints/finetuned.pt") LOG_CSV = Path("checkpoints/finetuned.log.csv") CURVES_PNG = Path("checkpoints/finetuned_curves.png") REPORT_TXT = Path("checkpoints/finetuned.report.txt") # --------------------------------------------------------------------------- # Training config # --------------------------------------------------------------------------- TRAIN_CFG = TrainConfig( data_dir=DATA_DIR, output=CHECKPOINT, init_from=INIT_FROM, # Small corpus (~45 train files) → ~6 batches/epoch. # 50 epochs × 6 = ~300 gradient steps; patience=10 gives a 60-step window. epochs=50, batch_size=8, lr=1e-5, warmup_steps=10, patience=10, seed=42, device="auto", # Must match pretrained checkpoint (max_seq_len=320). max_seq_len=320, ) # --------------------------------------------------------------------------- # Plotting # --------------------------------------------------------------------------- def plot_curves(log_csv: Path, out_png: Path) -> None: epochs, train_losses, val_losses, val_ppls = [], [], [], [] with open(log_csv, newline="") as fh: for row in csv.DictReader(fh): epochs.append(int(row["epoch"])) train_losses.append(float(row["train_loss"])) val_losses.append(float(row["val_loss"])) val_ppls.append(float(row["val_ppl"])) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4)) ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5) ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5) best_epoch = epochs[val_losses.index(min(val_losses))] ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8, label=f"best epoch {best_epoch}") ax1.set_xlabel("epoch") ax1.set_ylabel("cross-entropy loss") ax1.set_title("Fine-tuning loss") ax1.legend() ax1.grid(True, alpha=0.3) ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5) ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8) ax2.set_xlabel("epoch") ax2.set_ylabel("perplexity") ax2.set_title("Val perplexity") ax2.grid(True, alpha=0.3) fig.tight_layout() out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=150) plt.close(fig) print(f"[plot] saved -> {out_png}") # --------------------------------------------------------------------------- # Report # --------------------------------------------------------------------------- def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None: rows = [] with open(log_csv, newline="") as fh: rows = list(csv.DictReader(fh)) if not rows: print("[report] log CSV is empty — nothing to report") return val_losses = [float(r["val_loss"]) for r in rows] best_idx = val_losses.index(min(val_losses)) best_row = rows[best_idx] best_loss = float(best_row["val_loss"]) conv_epoch = next( (int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01), int(best_row["epoch"]), ) n_params = None if checkpoint.exists(): ckpt = torch.load(checkpoint, weights_only=True) model = ChordTransformer(**ckpt["model_config"]) tied = model.token_emb.weight.numel() n_params = sum(p.numel() for p in model.parameters()) - tied lines = [] lines += [ "", "=" * 52, " FINE-TUNING REPORT", "=" * 52, f" Total epochs run : {len(rows)}", f" Best epoch (val loss) : {best_row['epoch']}", f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)", f" Best val loss : {best_loss:.4f}", f" Best val perplexity : {float(best_row['val_ppl']):.2f}", f" Final train loss : {float(rows[-1]['train_loss']):.4f}", ] if n_params is not None: lines.append(f" Unique parameters : {n_params:,}") lines += [ f" Checkpoint : {checkpoint}", f" Log CSV : {log_csv}", "=" * 52, "", f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}", f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}", ] for r in rows: marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else "" lines.append( f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}" f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}" f" {float(r['lr']):>10.2e}{marker}" ) lines.append("") report_path.parent.mkdir(parents=True, exist_ok=True) report_path.write_text("\n".join(lines), encoding="utf-8") print(f"[report] saved -> {report_path}") # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) ap.add_argument("--skip-training", action="store_true", help="Skip training; re-plot and report from existing CSV.") args = ap.parse_args() logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", ) if not args.skip_training: if not INIT_FROM.exists(): print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr) print("Run python scripts/pretrain.py first.", file=sys.stderr) sys.exit(1) if not DATA_DIR.exists(): print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr) print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr) sys.exit(1) n_train = len(list((DATA_DIR / "train").glob("*.pt"))) n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size est_epoch_s = n_batches * 1.5 device_label = "GPU" if torch.cuda.is_available() else "CPU" print( f"[train] {n_train} train files, {n_batches} batches/epoch\n" f"[train] estimated time on {device_label}: " f"~{est_epoch_s/60:.0f} min/epoch, " f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n" f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n" ) train(TRAIN_CFG) else: if not LOG_CSV.exists(): print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr) sys.exit(1) print(f"[skip-training] using existing log: {LOG_CSV}") plot_curves(LOG_CSV, CURVES_PNG) write_report(LOG_CSV, CHECKPOINT, REPORT_TXT) if __name__ == "__main__": main()