"""Evaluate and compare pretrained vs fine-tuned ChordTransformer.

Usage:
    python scripts/evaluate.py \\
        --pretrained checkpoints/pretrained.pt \\
        --finetuned checkpoints/finetuned.pt \\
        [--user-dir data/raw_user/H1K0] \\
        [--eval-dir data/processed/user/val] \\
        [--mcgill-dir data/processed/mcgill/train] \\
        [--output-dir output/eval] \\
        [--n-examples 3] [--mcgill-sample 500] \\
        [--temperature 1.0] [--top-p 0.9] \\
        [--tempo 90] [--seed 42] [--device auto]

Outputs:
    <output-dir>/perplexity.txt     perplexity table (pretrained vs finetuned)
    <output-dir>/distributions.png  4-panel chord distribution comparison
    <output-dir>/examples/          n paired .chord and .mid files per model
"""

from __future__ import annotations

import argparse
import logging
import random
import sys
from dataclasses import replace
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Make the project root importable so `src.*` resolves when run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.dataset import ChordDataset  # noqa: E402
from src.evaluate import (  # noqa: E402
    compare_distributions,
    compute_perplexity,
    extract_features,
    extract_features_from_tokens,
    plot_comparison,
)
from src.generate import generate_period  # noqa: E402
from src.midi_export import chord_file_to_midi  # noqa: E402
from src.model import ChordTransformer  # noqa: E402
from src.tokenizer import parse_chord_file, write_chord_file  # noqa: E402

log = logging.getLogger(__name__)

# Default perplexity eval locations, tried in this order when --eval-dir is omitted.
_DEFAULT_HOLDOUT_DIR = Path("data/processed/user/holdout")
_DEFAULT_VAL_DIR = Path("data/processed/user/val")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _load_model(path: Path, device: torch.device) -> ChordTransformer:
    """Load a ChordTransformer checkpoint onto *device* and put it in eval mode.

    The checkpoint must be a dict holding ``"model_config"`` (constructor kwargs)
    and ``"model_state"`` (the state dict).
    """
    ckpt = torch.load(path, map_location=device, weights_only=True)
    model = ChordTransformer(**ckpt["model_config"])
    model.load_state_dict(ckpt["model_state"])
    model.to(device)
    model.eval()
    return model


def _resolve_device(spec: str) -> torch.device:
    """Turn a ``--device`` value into a torch.device; ``"auto"`` prefers CUDA."""
    if spec == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(spec)


def _make_loader(data_dir: Path, max_seq_len: int, batch_size: int) -> DataLoader:
    """Build an unshuffled, single-process DataLoader over the .pt files in *data_dir*."""
    ds = ChordDataset(data_dir, max_length=max_seq_len)
    return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)


def _max_seq_len(model: ChordTransformer) -> int:
    """Return the model's maximum sequence length (used to size the eval loader)."""
    return model.max_seq_len


def _has_pt_files(directory: Path) -> bool:
    """Return True if *directory* exists and contains at least one ``*.pt`` file."""
    return directory.is_dir() and any(directory.glob("*.pt"))


def _mode_from_key(key: str) -> str:
    """Return the text after the last underscore of *key* (the whole key if none).

    NOTE(review): presumably keys look like ``<tonic>_<mode>`` — confirm against
    src.tokenizer.
    """
    return key.split("_")[-1]


def _generate_like(model: ChordTransformer, ref, *, temperature: float,
                   top_p: float, seed: int):
    """Generate one period conditioned on the metadata of *ref*.

    *ref* is a parsed period as returned by ``parse_chord_file``; its key, time,
    subdivision, style and function are copied into the generation conditions.
    """
    return generate_period(
        model=model,
        mode=_mode_from_key(ref.key),
        time=ref.time,
        subdivision=ref.subdivision,
        style=ref.style,
        function=ref.function,
        key=ref.key,
        temperature=temperature,
        top_p=top_p,
        seed=seed,
    )


def _generate_features(
    model: ChordTransformer,
    chord_files: list[Path],
    device: torch.device,
    temperature: float,
    top_p: float,
    seed: int,
) -> list[dict]:
    """Generate one period per chord file (matching its metadata) and extract features.

    File *i* is generated with seed ``seed + i``, so two models given the same
    arguments are sampled under identical seeds. Files that fail to parse or
    generate are skipped with a warning.

    *device* is not used here; ``generate_period`` is not given a device.
    NOTE(review): it presumably follows the model's device — confirm.
    """
    features: list[dict] = []
    model.eval()
    for i, path in enumerate(chord_files):
        try:
            ref = parse_chord_file(path)
        except Exception as exc:
            log.warning("skipping %s: %s", path.name, exc)
            continue
        try:
            period = _generate_like(
                model, ref, temperature=temperature, top_p=top_p, seed=seed + i
            )
            features.append(extract_features(period))
        except Exception as exc:
            log.warning("generation failed for %s: %s", path.name, exc)
    return features


def _load_user_features(chord_files: list[Path]) -> list[dict]:
    """Parse user .chord files and extract features, skipping unparseable files."""
    features: list[dict] = []
    for path in chord_files:
        try:
            features.append(extract_features(parse_chord_file(path)))
        except Exception as exc:
            log.warning("skipping %s: %s", path.name, exc)
    return features


def _load_pt_features(pt_files: list[Path]) -> list[dict]:
    """Load token sequences from .pt files and extract features.

    Each file must hold a dict with a ``"tokens"`` tensor. Tensors are mapped to
    CPU so files saved from a GPU session load on CPU-only machines. Unreadable
    files are skipped with a warning.
    """
    features: list[dict] = []
    for path in pt_files:
        try:
            data = torch.load(path, map_location="cpu", weights_only=True)
            tokens = data["tokens"].tolist()
            features.append(extract_features_from_tokens(tokens))
        except Exception as exc:
            log.warning("failed to load %s: %s", path.name, exc)
    return features


def _select_eval_dir(eval_dir: Path | None) -> tuple[Path | None, str]:
    """Choose the perplexity eval directory and a human-readable label for the report.

    An explicit *eval_dir* wins. Otherwise the holdout split is used if it has
    .pt files, then the val split. Returns ``(None, label)`` when there is
    nothing to evaluate.
    """
    if eval_dir is not None:
        if _has_pt_files(eval_dir):
            return eval_dir, str(eval_dir)
        print(f"[evaluate] WARNING: no .pt files in {eval_dir} — skipping perplexity")
        return None, str(eval_dir)
    if _has_pt_files(_DEFAULT_HOLDOUT_DIR):
        return _DEFAULT_HOLDOUT_DIR, str(_DEFAULT_HOLDOUT_DIR)
    if _has_pt_files(_DEFAULT_VAL_DIR):
        print("[evaluate] WARNING: holdout is empty, using val set for perplexity")
        return _DEFAULT_VAL_DIR, f"{_DEFAULT_VAL_DIR} [fallback — holdout is empty]"
    print("[evaluate] WARNING: no eval data found — skipping perplexity")
    return None, "N/A"


def _write_perplexity_report(
    path: Path,
    eval_label: str,
    ppl_pretrained: float | None,
    ppl_finetuned: float | None,
) -> None:
    """Write the perplexity table to *path*.

    "improvement" is the relative drop from pretrained to finetuned perplexity,
    so a positive value means the finetuned model fits the eval set better.
    """
    rule = "=" * 52 + "\n"
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(rule)
        fh.write(" PERPLEXITY EVALUATION\n")
        fh.write(rule)
        fh.write(f" Eval set : {eval_label}\n\n")
        if ppl_pretrained is not None and ppl_finetuned is not None:
            improvement = (ppl_pretrained - ppl_finetuned) / ppl_pretrained * 100
            fh.write(f" pretrained PPL = {ppl_pretrained:8.2f}\n")
            fh.write(f" finetuned PPL = {ppl_finetuned:8.2f}\n\n")
            fh.write(f" improvement = {improvement:+.1f}%\n")
        else:
            fh.write(" (no eval data available)\n")
        fh.write(rule)


def _write_examples(
    example_files: list[Path],
    models: list[tuple[str, ChordTransformer]],
    examples_dir: Path,
    *,
    temperature: float,
    top_p: float,
    tempo: int,
    base_seed: int,
) -> None:
    """Write paired .chord/.mid examples into *examples_dir*.

    Every model is conditioned on the same reference file with the same seed, so
    paired outputs differ only by model. A reference that fails to parse, or a
    model that fails to generate, is skipped with a warning.
    """
    for ex_i, ref_path in enumerate(example_files):
        try:
            ref = parse_chord_file(ref_path)
        except Exception as exc:
            log.warning("example %d: cannot parse %s: %s", ex_i + 1, ref_path.name, exc)
            continue
        stem = ref_path.stem
        # Offset by 1000 to keep these seeds apart from the distribution samples
        # (which use base_seed + i for each user file).
        ex_seed = base_seed + 1000 + ex_i
        for model_name, model in models:
            try:
                gen = _generate_like(
                    model, ref, temperature=temperature, top_p=top_p, seed=ex_seed
                )
                gen = replace(gen, title=f"Generated — {stem} ({model_name})")
                chord_out = examples_dir / f"{model_name}_{ex_i + 1:02d}_{stem}.chord"
                midi_out = chord_out.with_suffix(".mid")
                write_chord_file(gen, chord_out)
                # The MIDI file is rendered from the .chord file just written.
                chord_file_to_midi(chord_out, midi_out, tempo=tempo)
                print(f"[evaluate] {chord_out.name}")
                print(f"[evaluate] {midi_out.name}")
                print(f"[evaluate] {len(gen.bars)} bars: "
                      + " ".join(" ".join(b) for b in gen.bars))
                print()
            except Exception as exc:
                log.warning("example %d %s failed: %s", ex_i + 1, model_name, exc)


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser; the module docstring is the help description."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--pretrained", type=Path, required=True,
                    help="Pre-trained checkpoint (.pt).")
    ap.add_argument("--finetuned", type=Path, required=True,
                    help="Fine-tuned checkpoint (.pt).")
    ap.add_argument("--user-dir", type=Path, default=Path("data/raw_user/H1K0"),
                    help="Directory with user .chord files (default: data/raw_user/H1K0).")
    ap.add_argument("--eval-dir", type=Path, default=None,
                    help="Processed .pt files for perplexity eval. "
                         "Defaults to data/processed/user/holdout if non-empty, "
                         "otherwise data/processed/user/val.")
    ap.add_argument("--mcgill-dir", type=Path, default=Path("data/processed/mcgill/train"),
                    help="Processed McGill .pt files for distribution stats "
                         "(default: data/processed/mcgill/train).")
    ap.add_argument("--output-dir", type=Path, default=Path("output/eval"),
                    help="Output directory (default: output/eval).")
    ap.add_argument("--n-examples", type=int, default=3,
                    help="Number of paired qualitative examples to generate (default: 3).")
    ap.add_argument("--mcgill-sample", type=int, default=500,
                    help="Max McGill files for distribution stats (default: 500).")
    ap.add_argument("--temperature", type=float, default=1.0,
                    help="Sampling temperature (default: 1.0).")
    ap.add_argument("--top-p", type=float, default=0.9,
                    help="Nucleus sampling cutoff (default: 0.9).")
    ap.add_argument("--tempo", type=int, default=90,
                    help="MIDI tempo for example files in BPM (default: 90).")
    ap.add_argument("--seed", type=int, default=42,
                    help="Base random seed (default: 42).")
    ap.add_argument("--device", default="auto",
                    help="Compute device: cpu, cuda, or auto (default: auto).")
    return ap


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    """Run the evaluation.

    Steps: perplexity on held-out data, chord-distribution comparison, and paired
    qualitative examples.
    """
    args = _build_parser().parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Validate inputs
    for flag, path in [("--pretrained", args.pretrained), ("--finetuned", args.finetuned)]:
        if not path.exists():
            print(f"ERROR: {flag} not found: {path}", file=sys.stderr)
            sys.exit(1)

    device = _resolve_device(args.device)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Load models
    # ------------------------------------------------------------------
    print(f"[evaluate] loading pretrained <- {args.pretrained}")
    pretrained = _load_model(args.pretrained, device)
    print(f"[evaluate] loading finetuned <- {args.finetuned}")
    finetuned = _load_model(args.finetuned, device)

    # Both models are scored with the same loader, sized by the pretrained model.
    max_seq = _max_seq_len(pretrained)

    # ------------------------------------------------------------------
    # Perplexity
    # ------------------------------------------------------------------
    eval_dir, eval_label = _select_eval_dir(args.eval_dir)

    ppl_pretrained: float | None = None
    ppl_finetuned: float | None = None
    if eval_dir is not None:
        n_eval = len(list(eval_dir.glob("*.pt")))
        print(f"[evaluate] computing perplexity on {n_eval} files in {eval_dir} ...")
        loader = _make_loader(eval_dir, max_seq, batch_size=8)
        ppl_pretrained = compute_perplexity(pretrained, loader, device)
        ppl_finetuned = compute_perplexity(finetuned, loader, device)
        print(f"[evaluate] pretrained PPL = {ppl_pretrained:.2f}")
        print(f"[evaluate] finetuned PPL = {ppl_finetuned:.2f}")

    ppl_path = args.output_dir / "perplexity.txt"
    _write_perplexity_report(ppl_path, eval_label, ppl_pretrained, ppl_finetuned)
    print(f"[evaluate] saved -> {ppl_path}")

    # ------------------------------------------------------------------
    # Distribution stats
    # ------------------------------------------------------------------
    # 1. User corpus (ground truth)
    user_files = sorted(args.user_dir.glob("*.chord")) if args.user_dir.exists() else []
    if not user_files:
        print(f"[evaluate] WARNING: no .chord files found in {args.user_dir}")
    user_feats = _load_user_features(user_files)
    print(f"[evaluate] user corpus: {len(user_feats)} periods")

    # 2. McGill sample (optional); subsampled reproducibly from --seed.
    mcgill_feats: list[dict] = []
    if args.mcgill_dir.exists():
        mcgill_files = sorted(args.mcgill_dir.glob("*.pt"))
        if len(mcgill_files) > args.mcgill_sample:
            rng = random.Random(args.seed)
            mcgill_files = rng.sample(mcgill_files, args.mcgill_sample)
        mcgill_feats = _load_pt_features(mcgill_files)
        print(f"[evaluate] McGill sample: {len(mcgill_feats)} periods")
    else:
        print(f"[evaluate] McGill dir not found ({args.mcgill_dir}) — skipping")

    # 3. Generated samples (one per user file, matching conditions). Both models
    #    use the same seeds so their outputs are directly comparable.
    print(f"[evaluate] generating {len(user_files)} samples from pretrained ...")
    pre_feats = _generate_features(
        pretrained, user_files, device, args.temperature, args.top_p, args.seed
    )
    print(f"[evaluate] generating {len(user_files)} samples from finetuned ...")
    ft_feats = _generate_features(
        finetuned, user_files, device, args.temperature, args.top_p, args.seed
    )

    # Build named groups for comparison
    named_groups: dict[str, list[dict]] = {}
    if mcgill_feats:
        named_groups["McGill (pretrain corpus)"] = mcgill_feats
    named_groups["user corpus"] = user_feats
    named_groups["pretrained output"] = pre_feats
    named_groups["finetuned output"] = ft_feats

    distributions = compare_distributions(named_groups)
    dist_path = args.output_dir / "distributions.png"
    plot_comparison(
        distributions,
        dist_path,
        title="Chord distribution: corpus vs model outputs",
    )
    print(f"[evaluate] saved -> {dist_path}")

    # ------------------------------------------------------------------
    # Qualitative examples (conditioned on the first n user .chord files)
    # ------------------------------------------------------------------
    examples_dir = args.output_dir / "examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    example_files = user_files[: args.n_examples]
    if not example_files:
        print("[evaluate] no user files found — skipping qualitative examples")
    else:
        print(f"[evaluate] generating {len(example_files)} paired qualitative examples ...")
        _write_examples(
            example_files,
            [("pretrained", pretrained), ("finetuned", finetuned)],
            examples_dir,
            temperature=args.temperature,
            top_p=args.top_p,
            tempo=args.tempo,
            base_seed=args.seed,
        )

    print(f"\n[evaluate] done. Outputs in {args.output_dir}/")


if __name__ == "__main__":
    main()