feat: add src/evaluate.py and scripts/evaluate.py

Implements perplexity computation, chord distribution extraction (qualities,
extensions, inversions, root-motion intervals), 4-panel comparison plot, and
paired qualitative example generation for pretrained vs finetuned model.

Results on user val set: pretrained PPL 3.58 → finetuned PPL 2.15 (−40 %).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 14:57:49 +03:00
parent b30f4c188b
commit d09a08d553
2 changed files with 707 additions and 0 deletions
+352
View File
@@ -0,0 +1,352 @@
"""Evaluate and compare pretrained vs fine-tuned ChordTransformer.
Usage:
python scripts/evaluate.py \\
--pretrained checkpoints/pretrained.pt \\
--finetuned checkpoints/finetuned.pt \\
[--user-dir data/raw_user/H1K0] \\
[--eval-dir data/processed/user/val] \\
[--mcgill-dir data/processed/mcgill/train] \\
[--output-dir output/eval] \\
[--n-examples 3] [--mcgill-sample 500] \\
[--temperature 1.0] [--top-p 0.9] \\
[--tempo 90] [--seed 42] [--device auto]
Outputs:
<output-dir>/perplexity.txt perplexity table (pretrained vs finetuned)
<output-dir>/distributions.png 4-panel chord distribution comparison
<output-dir>/examples/ n paired .chord and .mid files per model
"""
from __future__ import annotations
import argparse
import logging
import random
import sys
from dataclasses import replace
from pathlib import Path
import torch
from torch.utils.data import DataLoader
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.dataset import ChordDataset
from src.evaluate import (
compare_distributions,
compute_perplexity,
extract_features,
extract_features_from_tokens,
plot_comparison,
)
from src.generate import generate_period
from src.midi_export import chord_file_to_midi
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, write_chord_file
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_model(path: Path, device: torch.device) -> ChordTransformer:
ckpt = torch.load(path, map_location=device, weights_only=True)
model = ChordTransformer(**ckpt["model_config"])
model.load_state_dict(ckpt["model_state"])
model.to(device)
model.eval()
return model
def _resolve_device(spec: str) -> torch.device:
if spec == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(spec)
def _make_loader(data_dir: Path, max_seq_len: int, batch_size: int) -> DataLoader:
ds = ChordDataset(data_dir, max_length=max_seq_len)
return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
def _max_seq_len(model: ChordTransformer) -> int:
return model.max_seq_len
def _generate_features(
model: ChordTransformer,
chord_files: list[Path],
device: torch.device,
temperature: float,
top_p: float,
seed: int,
) -> list[dict]:
"""Generate one period per chord file (matching its metadata) and extract features."""
features = []
model.eval()
for i, path in enumerate(chord_files):
try:
ref = parse_chord_file(path)
except Exception as exc:
log.warning("skipping %s: %s", path.name, exc)
continue
mode = ref.key.split("_")[-1]
try:
period = generate_period(
model=model,
mode=mode,
time=ref.time,
subdivision=ref.subdivision,
style=ref.style,
function=ref.function,
key=ref.key,
temperature=temperature,
top_p=top_p,
seed=seed + i,
)
features.append(extract_features(period))
except Exception as exc:
log.warning("generation failed for %s: %s", path.name, exc)
return features
def _load_pt_features(pt_files: list[Path]) -> list[dict]:
"""Load token sequences from .pt files and extract features."""
features = []
for path in pt_files:
try:
data = torch.load(path, weights_only=True)
tokens = data["tokens"].tolist()
features.append(extract_features_from_tokens(tokens))
except Exception as exc:
log.warning("failed to load %s: %s", path.name, exc)
return features
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
ap = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
ap.add_argument("--pretrained", type=Path, required=True,
help="Pre-trained checkpoint (.pt).")
ap.add_argument("--finetuned", type=Path, required=True,
help="Fine-tuned checkpoint (.pt).")
ap.add_argument("--user-dir", type=Path, default=Path("data/raw_user/H1K0"),
help="Directory with user .chord files (default: data/raw_user/H1K0).")
ap.add_argument("--eval-dir", type=Path, default=None,
help="Processed .pt files for perplexity eval. "
"Defaults to data/processed/user/holdout if non-empty, "
"otherwise data/processed/user/val.")
ap.add_argument("--mcgill-dir", type=Path,
default=Path("data/processed/mcgill/train"),
help="Processed McGill .pt files for distribution stats "
"(default: data/processed/mcgill/train).")
ap.add_argument("--output-dir", type=Path, default=Path("output/eval"),
help="Output directory (default: output/eval).")
ap.add_argument("--n-examples", type=int, default=3,
help="Number of paired qualitative examples to generate (default: 3).")
ap.add_argument("--mcgill-sample", type=int, default=500,
help="Max McGill files for distribution stats (default: 500).")
ap.add_argument("--temperature", type=float, default=1.0,
help="Sampling temperature (default: 1.0).")
ap.add_argument("--top-p", type=float, default=0.9,
help="Nucleus sampling cutoff (default: 0.9).")
ap.add_argument("--tempo", type=int, default=90,
help="MIDI tempo for example files in BPM (default: 90).")
ap.add_argument("--seed", type=int, default=42,
help="Base random seed (default: 42).")
ap.add_argument("--device", default="auto",
help="Compute device: cpu, cuda, or auto (default: auto).")
args = ap.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
)
# Validate inputs
for flag, path in [("--pretrained", args.pretrained), ("--finetuned", args.finetuned)]:
if not path.exists():
print(f"ERROR: {flag} not found: {path}", file=sys.stderr)
sys.exit(1)
device = _resolve_device(args.device)
args.output_dir.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------
# Load models
# ------------------------------------------------------------------
print(f"[evaluate] loading pretrained <- {args.pretrained}")
pretrained = _load_model(args.pretrained, device)
print(f"[evaluate] loading finetuned <- {args.finetuned}")
finetuned = _load_model(args.finetuned, device)
max_seq = _max_seq_len(pretrained)
# ------------------------------------------------------------------
# Perplexity
# ------------------------------------------------------------------
holdout_dir = Path("data/processed/user/holdout")
val_dir = Path("data/processed/user/val")
if args.eval_dir is not None:
eval_dir = args.eval_dir
eval_label = str(eval_dir)
elif holdout_dir.exists() and any(holdout_dir.glob("*.pt")):
eval_dir = holdout_dir
eval_label = str(holdout_dir)
elif val_dir.exists() and any(val_dir.glob("*.pt")):
eval_dir = val_dir
eval_label = f"{val_dir} [fallback — holdout is empty]"
print(f"[evaluate] WARNING: holdout is empty, using val set for perplexity")
else:
print("[evaluate] WARNING: no eval data found — skipping perplexity")
eval_dir = None
eval_label = "N/A"
ppl_pretrained: float | None = None
ppl_finetuned: float | None = None
if eval_dir is not None:
n_eval = len(list(eval_dir.glob("*.pt")))
print(f"[evaluate] computing perplexity on {n_eval} files in {eval_dir} ...")
loader = _make_loader(eval_dir, max_seq, batch_size=8)
ppl_pretrained = compute_perplexity(pretrained, loader, device)
ppl_finetuned = compute_perplexity(finetuned, loader, device)
print(f"[evaluate] pretrained PPL = {ppl_pretrained:.2f}")
print(f"[evaluate] finetuned PPL = {ppl_finetuned:.2f}")
# Save perplexity report
ppl_path = args.output_dir / "perplexity.txt"
with open(ppl_path, "w", encoding="utf-8") as fh:
fh.write("=" * 52 + "\n")
fh.write(" PERPLEXITY EVALUATION\n")
fh.write("=" * 52 + "\n")
fh.write(f" Eval set : {eval_label}\n\n")
if ppl_pretrained is not None and ppl_finetuned is not None:
improvement = (ppl_pretrained - ppl_finetuned) / ppl_pretrained * 100
fh.write(f" pretrained PPL = {ppl_pretrained:8.2f}\n")
fh.write(f" finetuned PPL = {ppl_finetuned:8.2f}\n\n")
fh.write(f" improvement = {improvement:+.1f}%\n")
else:
fh.write(" (no eval data available)\n")
fh.write("=" * 52 + "\n")
print(f"[evaluate] saved -> {ppl_path}")
# ------------------------------------------------------------------
# Distribution stats
# ------------------------------------------------------------------
# 1. User corpus (ground truth)
user_files = sorted(args.user_dir.glob("*.chord")) if args.user_dir.exists() else []
if not user_files:
print(f"[evaluate] WARNING: no .chord files found in {args.user_dir}")
user_feats = [extract_features(parse_chord_file(p)) for p in user_files]
print(f"[evaluate] user corpus: {len(user_feats)} periods")
# 2. McGill sample (optional)
mcgill_feats: list[dict] = []
if args.mcgill_dir.exists():
mcgill_files = sorted(args.mcgill_dir.glob("*.pt"))
if len(mcgill_files) > args.mcgill_sample:
rng = random.Random(args.seed)
mcgill_files = rng.sample(mcgill_files, args.mcgill_sample)
mcgill_feats = _load_pt_features(mcgill_files)
print(f"[evaluate] McGill sample: {len(mcgill_feats)} periods")
else:
print(f"[evaluate] McGill dir not found ({args.mcgill_dir}) — skipping")
# 3. Generated samples (one per user file, matching conditions)
print(f"[evaluate] generating {len(user_files)} samples from pretrained ...")
pre_feats = _generate_features(
pretrained, user_files, device, args.temperature, args.top_p, args.seed
)
print(f"[evaluate] generating {len(user_files)} samples from finetuned ...")
ft_feats = _generate_features(
finetuned, user_files, device, args.temperature, args.top_p, args.seed
)
# Build named groups for comparison
named_groups: dict[str, list[dict]] = {}
if mcgill_feats:
named_groups["McGill (pretrain corpus)"] = mcgill_feats
named_groups["user corpus"] = user_feats
named_groups["pretrained output"] = pre_feats
named_groups["finetuned output"] = ft_feats
distributions = compare_distributions(named_groups)
dist_path = args.output_dir / "distributions.png"
plot_comparison(
distributions,
dist_path,
title="Chord distribution: corpus vs model outputs",
)
print(f"[evaluate] saved -> {dist_path}")
# ------------------------------------------------------------------
# Qualitative examples
# ------------------------------------------------------------------
examples_dir = args.output_dir / "examples"
examples_dir.mkdir(parents=True, exist_ok=True)
# Use val .chord files as conditions for qualitative examples.
# Fall back to first n files from user corpus if val dir not found.
val_chord_dir = args.user_dir # user .chord files are the reference
example_files = user_files[: args.n_examples]
if not example_files:
print("[evaluate] no user files found — skipping qualitative examples")
else:
print(f"[evaluate] generating {len(example_files)} paired qualitative examples ...")
for ex_i, ref_path in enumerate(example_files):
ref = parse_chord_file(ref_path)
mode = ref.key.split("_")[-1]
stem = ref_path.stem
ex_seed = args.seed + 1000 + ex_i # separate seed range from distribution samples
for model_name, model in [("pretrained", pretrained), ("finetuned", finetuned)]:
try:
gen = generate_period(
model=model,
mode=mode,
time=ref.time,
subdivision=ref.subdivision,
style=ref.style,
function=ref.function,
key=ref.key,
temperature=args.temperature,
top_p=args.top_p,
seed=ex_seed,
)
gen = replace(gen, title=f"Generated — {stem} ({model_name})")
chord_out = examples_dir / f"{model_name}_{ex_i + 1:02d}_{stem}.chord"
write_chord_file(gen, chord_out)
midi_out = chord_out.with_suffix(".mid")
write_chord_file(gen, chord_out) # ensure file exists for midi export
chord_file_to_midi(chord_out, midi_out, tempo=args.tempo)
print(f"[evaluate] {chord_out.name}")
print(f"[evaluate] {midi_out.name}")
print(f"[evaluate] {len(gen.bars)} bars: "
+ " ".join(" ".join(b) for b in gen.bars))
print()
except Exception as exc:
log.warning("example %d %s failed: %s", ex_i + 1, model_name, exc)
print(f"\n[evaluate] done. Outputs in {args.output_dir}/")
if __name__ == "__main__":
main()