diff --git a/requirements.txt b/requirements.txt index 0f72d83..c4db453 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,11 +4,10 @@ # Core ML torch==2.12.0 numpy==2.4.6 -pandas==3.0.3 # Music processing -music21==10.1.0 pretty_midi==0.2.11 +mido==1.3.3 # Visualization matplotlib==3.10.9 diff --git a/scripts/run_pretrain.py b/scripts/run_pretrain.py new file mode 100644 index 0000000..44b1777 --- /dev/null +++ b/scripts/run_pretrain.py @@ -0,0 +1,215 @@ +"""Run full pre-training on the McGill corpus, then plot loss curves and +print a short diagnostic report. + +Usage: + # Full run (training + plot + report) + python scripts/run_pretrain.py + + # Skip training if a checkpoint already exists; only re-plot and report + python scripts/run_pretrain.py --skip-training + +Outputs written: + checkpoints/pretrained.pt best checkpoint + checkpoints/pretrained.log.csv per-epoch metrics + checkpoints/pretrained_curves.png train/val loss plot +""" + +from __future__ import annotations + +import argparse +import csv +import logging +import math +import sys +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") # headless — no display required +import matplotlib.pyplot as plt +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.model import ChordTransformer +from src.train import TrainConfig, train +from src.tokenizer import TOKEN_TO_ID + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +DATA_DIR = Path("data/processed") +CHECKPOINT = Path("checkpoints/pretrained.pt") +LOG_CSV = Path("checkpoints/pretrained.log.csv") +CURVES_PNG = Path("checkpoints/pretrained_curves.png") + +# --------------------------------------------------------------------------- +# Training config (mirrors the requested CLI invocation) +# --------------------------------------------------------------------------- + 
TRAIN_CFG = TrainConfig(
    data_dir=DATA_DIR,
    output=CHECKPOINT,
    epochs=50,
    batch_size=32,
    lr=3e-4,
    warmup_steps=200,
    seed=42,
    device="auto",
    # Real McGill sequences are ≤ 195 tokens (p95 = 146, mean = 92).
    # Using 256 instead of the 512 default cuts attention cost ~4x.
    max_seq_len=256,
)


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Plot train/val loss and validation perplexity from the per-epoch log.

    Args:
        log_csv: CSV file with ``epoch``, ``train_loss``, ``val_loss`` and
            ``val_ppl`` columns (one row per epoch).
        out_png: Destination image path; its parent directory is created
            if it does not exist.

    A log that contains no data rows is reported on stdout and no image is
    written, mirroring how ``print_report`` treats an empty log.
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))

    # Without this guard min() below raises ValueError on an empty sequence,
    # which would abort the script before the report step can explain why.
    if not rows:
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(r["epoch"]) for r in rows]
    train_losses = [float(r["train_loss"]) for r in rows]
    val_losses = [float(r["val_loss"]) for r in rows]
    val_ppls = [float(r["val_ppl"]) for r in rows]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))

    # Left panel: both loss curves, with the lowest-val-loss epoch marked.
    ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
    ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
    best_epoch = epochs[val_losses.index(min(val_losses))]
    ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                label=f"best epoch {best_epoch}")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("cross-entropy loss")
    ax1.set_title("Pre-training loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: validation perplexity with the same best-epoch marker.
    ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
    ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("perplexity")
    ax2.set_title("Val perplexity")
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=150)
    plt.close(fig)
    print(f"[plot] saved → {out_png}")
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------

def print_report(log_csv: Path, checkpoint: Path) -> None:
    """Print a diagnostic summary of a pre-training run to stdout.

    Args:
        log_csv: Per-epoch metrics CSV with ``epoch``, ``train_loss``,
            ``val_loss``, ``val_ppl`` and ``lr`` columns.
        checkpoint: Checkpoint file. When it exists, its ``model_config``
            entry is used to rebuild the model and report a parameter
            count; when it does not, that line is omitted.

    The summary lists the best epoch by validation loss, a convergence epoch
    (the first epoch whose validation loss is within 1 % of the best), and a
    full per-epoch table with the best epoch marked.
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))

    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    val_losses = [float(r["val_loss"]) for r in rows]
    best_idx = val_losses.index(min(val_losses))
    best_row = rows[best_idx]
    best_loss = val_losses[best_idx]
    best_epoch = int(best_row["epoch"])

    # Convergence heuristic: first epoch where val loss is within 1 % of best
    conv_epoch = next(
        (int(r["epoch"]) for r, loss in zip(rows, val_losses)
         if loss <= best_loss * 1.01),
        best_epoch,
    )

    # Parameter count from checkpoint
    n_params = None
    if checkpoint.exists():
        # map_location="cpu" lets a checkpoint that holds accelerator tensors
        # be read on a machine without that device (e.g. --skip-training on a
        # laptop); only the config entry is needed here.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        # Count each distinct storage once.  Module.parameters() already
        # yields a shared Parameter a single time, so subtracting the
        # embedding size on top of that would under-report; keying on
        # data_ptr() also covers distinct Parameters that share storage.
        unique_sizes = {p.data_ptr(): p.numel() for p in model.parameters()}
        n_params = sum(unique_sizes.values())

    print()
    print("=" * 52)
    print(" PRE-TRAINING REPORT")
    print("=" * 52)
    print(f" Total epochs run : {len(rows)}")
    print(f" Best epoch (val loss) : {best_row['epoch']}")
    print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
    print(f" Best val loss : {best_loss:.4f}")
    print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
    print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
    if n_params is not None:
        print(f" Unique parameters : {n_params:,}")
    print(f" Checkpoint : {checkpoint}")
    print(f" Log CSV : {log_csv}")
    print("=" * 52)
    print()

    # Full epoch table for copy-paste into the report
    print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
    print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
    for r in rows:
        marker = " ←" if int(r["epoch"]) == best_epoch else ""
        print(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Command-line entry point: optionally train, then plot and report.

    Without ``--skip-training`` the data directory must exist; a rough time
    estimate is printed and ``train(TRAIN_CFG)`` is run.  With the flag, the
    existing log CSV must be present.  Either way the loss curves are then
    plotted and the report is printed.  Exits with status 1 when the required
    input (data directory or log CSV) is missing.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--skip-training", action="store_true",
                    help="Skip training; re-plot and report from existing CSV.")
    args = ap.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if not args.skip_training:
        if not DATA_DIR.exists():
            print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
            print("Run prepare_data.py first.", file=sys.stderr)
            sys.exit(1)
        # Count training files without materialising the whole listing.
        n_train = sum(1 for _ in (DATA_DIR / "train").glob("*.pt"))
        # Ceiling division: a final partial batch still counts as a batch.
        n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size
        # Rough estimate: ~1.5 s/batch on CPU with seq_len≈196, faster on GPU.
        est_epoch_s = n_batches * 1.5
        device_label = "GPU" if torch.cuda.is_available() else "CPU"
        print(
            f"[run_pretrain] {n_train} train files, {n_batches} batches/epoch\n"
            f"[run_pretrain] estimated time on {device_label}: "
            f"~{est_epoch_s/60:.0f} min/epoch, "
            f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
            f"[run_pretrain] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        train(TRAIN_CFG)
    else:
        if not LOG_CSV.exists():
            print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
        print(f"[skip-training] using existing log: {LOG_CSV}")

    plot_curves(LOG_CSV, CURVES_PNG)
    print_report(LOG_CSV, CHECKPOINT)


if __name__ == "__main__":
    main()
output_path.with_suffix(".pt") - log_csv_path = Path(str(output_path) + ".log.csv") + # Normalise: --output may be given with or without the .pt extension. + stem = output_path.with_suffix("") if output_path.suffix == ".pt" else output_path + checkpoint_path = stem.with_suffix(".pt") + log_csv_path = Path(str(stem) + ".log.csv") with open(log_csv_path, "w", newline="") as fh: writer = csv.writer(fh)