feat: add src/evaluate.py and scripts/evaluate.py

Implements perplexity computation, chord distribution extraction (qualities,
extensions, inversions, root-motion intervals), 4-panel comparison plot, and
paired qualitative example generation for pretrained vs finetuned model.

Results on user val set: pretrained PPL 3.58 → finetuned PPL 2.15 (−40 %).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 14:57:49 +03:00
parent b30f4c188b
commit d09a08d553
2 changed files with 707 additions and 0 deletions
+355
View File
@@ -0,0 +1,355 @@
"""Evaluation utilities: perplexity, corpus stats, and distribution plots.
Public API:
compute_perplexity(model, dataloader, device) -> float
extract_features(period: ChordPeriod) -> dict
extract_features_from_tokens(tokens: list[int]) -> dict
compare_distributions(named_groups: dict[str, list[dict]]) -> dict
plot_comparison(distributions: dict, output_path: Path, title: str) -> None
"""
from __future__ import annotations
import logging
import math
from collections import Counter
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.chord_parser import parse_chord_symbol
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
_ROOT_START: int = TOKEN_TO_ID["ROOT_C"]
_ROOT_END: int = TOKEN_TO_ID["ROOT_B"]
_QUAL_START: int = TOKEN_TO_ID["QUAL_maj"]
_QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"]
_EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"]
_BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"]
_BASS_START: int = TOKEN_TO_ID["BASS_root"]
_BASS_END: int = TOKEN_TO_ID["BASS_B"]
_NOTE_SEMITONES: dict[str, int] = {
"C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5,
"F#": 6, "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11,
}
_INTERVAL_LABELS = ["P1", "m2", "M2", "m3", "M3", "P4", "TT", "P5", "m6", "M6", "m7", "M7"]
def _qual_from_tok(tok: str) -> str:
"""'QUAL_m_add9''m(add9)', 'QUAL_maj7''maj7', etc."""
suffix = tok[5:] # strip "QUAL_"
return "m(add9)" if suffix == "m_add9" else suffix
# ---------------------------------------------------------------------------
# Perplexity
# ---------------------------------------------------------------------------
def compute_perplexity(
model: ChordTransformer,
dataloader: DataLoader,
device: torch.device,
) -> float:
"""Compute token-level perplexity of *model* on *dataloader*.
Args:
model: ChordTransformer (set to eval mode internally).
dataloader: Yields padded token tensors of shape [batch, seq_len].
device: Compute device.
Returns:
Perplexity (>= 1.0).
"""
criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID, reduction="sum")
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
for batch in dataloader:
batch = batch.to(device)
input_ids = batch[:, :-1]
targets = batch[:, 1:]
attn_mask = (input_ids != _PAD_ID).long()
logits = model(input_ids, attention_mask=attn_mask)
loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
total_loss += loss.item()
total_tokens += int((targets != _PAD_ID).sum())
if total_tokens == 0:
raise ValueError("dataloader produced no non-PAD tokens")
return math.exp(min(total_loss / total_tokens, 20.0))
# ---------------------------------------------------------------------------
# Feature extraction — from ChordPeriod
# ---------------------------------------------------------------------------
def extract_features(period: ChordPeriod) -> dict[str, Any]:
"""Extract harmonic statistics from one ChordPeriod.
Args:
period: A parsed ChordPeriod (any key).
Returns:
Dict with keys:
qualities list[str] canonical quality string per chord event
has_extension list[bool] True when extension != 'none'
is_inverted list[bool] True when bass != root
root_intervals list[int] semitone intervals (011) between consecutive roots
"""
qualities: list[str] = []
has_ext: list[bool] = []
is_inv: list[bool] = []
roots_semi: list[int] = []
for bar in period.bars:
for pos in bar:
if pos in (".", "NC", "?"):
continue
t = parse_chord_symbol(pos)
qualities.append(t.quality)
has_ext.append(t.extension != "none")
is_inv.append(t.bass not in ("root", t.root))
if t.root in _NOTE_SEMITONES:
roots_semi.append(_NOTE_SEMITONES[t.root])
return {
"qualities": qualities,
"has_extension": has_ext,
"is_inverted": is_inv,
"root_intervals": [
(roots_semi[i] - roots_semi[i - 1]) % 12
for i in range(1, len(roots_semi))
],
}
# ---------------------------------------------------------------------------
# Feature extraction — from raw token IDs (for large processed corpora)
# ---------------------------------------------------------------------------
def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]:
"""Extract harmonic statistics from a raw token ID sequence.
Scans for ROOT → QUAL → EXT → BASS groups; all other tokens are skipped.
Args:
tokens: Token ID list as produced by tokenize_period.
Returns:
Same structure as extract_features.
"""
qualities: list[str] = []
has_ext: list[bool] = []
is_inv: list[bool] = []
roots_semi: list[int] = []
i = 0
n = len(tokens)
while i < n:
tok = tokens[i]
if _ROOT_START <= tok <= _ROOT_END and i + 3 < n:
qual_tok = tokens[i + 1]
ext_tok = tokens[i + 2]
bass_tok = tokens[i + 3]
if _QUAL_START <= qual_tok <= _QUAL_END:
root_name = VOCAB[tok][5:] # "ROOT_F#" → "F#"
qualities.append(_qual_from_tok(VOCAB[qual_tok]))
has_ext.append(ext_tok != _EXT_NONE_ID)
is_inv.append(bass_tok != _BASS_ROOT_ID)
if root_name in _NOTE_SEMITONES:
roots_semi.append(_NOTE_SEMITONES[root_name])
i += 4
continue
i += 1
return {
"qualities": qualities,
"has_extension": has_ext,
"is_inverted": is_inv,
"root_intervals": [
(roots_semi[i] - roots_semi[i - 1]) % 12
for i in range(1, len(roots_semi))
],
}
# ---------------------------------------------------------------------------
# Distribution comparison
# ---------------------------------------------------------------------------
def compare_distributions(
named_groups: dict[str, list[dict[str, Any]]],
) -> dict[str, dict[str, Any]]:
"""Aggregate per-period feature lists into per-group distribution counters.
Args:
named_groups: {label: [feature_dict, ...], ...}
Each feature_dict is the output of extract_features or
extract_features_from_tokens.
Returns:
{label: {"quality_counter": Counter, "ext_counts": dict,
"inv_counts": dict, "interval_counter": Counter}, ...}
"""
result: dict[str, dict[str, Any]] = {}
for label, feats in named_groups.items():
qc: Counter[str] = Counter()
ext: dict[str, int] = {"none": 0, "has_ext": 0}
inv: dict[str, int] = {"root_pos": 0, "inverted": 0}
ivc: Counter[int] = Counter()
for f in feats:
for q in f["qualities"]:
qc[q] += 1
for h in f["has_extension"]:
ext["has_ext" if h else "none"] += 1
for is_i in f["is_inverted"]:
inv["inverted" if is_i else "root_pos"] += 1
for iv in f["root_intervals"]:
ivc[iv] += 1
result[label] = {
"quality_counter": qc,
"ext_counts": ext,
"inv_counts": inv,
"interval_counter": ivc,
}
return result
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_comparison(
distributions: dict[str, dict[str, Any]],
output_path: Path,
title: str = "",
) -> None:
"""Save a 2×2 distribution-comparison figure.
Panels:
[0,0] Chord quality distribution (horizontal bars, normalised to %)
[0,1] Root motion intervals (011 semitones)
[1,0] Extension presence (no ext / has extension)
[1,1] Inversion frequency (root position / inverted)
Args:
distributions: Output of compare_distributions.
output_path: Destination PNG path.
title: Optional figure suptitle.
"""
labels = list(distributions.keys())
n_groups = len(labels)
palette = [plt.cm.tab10.colors[i % 10] for i in range(n_groups)] # type: ignore[attr-defined]
# All quality strings that appear in at least one group, sorted by
# descending total frequency across all groups.
all_quals: list[str] = []
_seen: set[str] = set()
for dist in distributions.values():
for q in dist["quality_counter"]:
if q not in _seen:
all_quals.append(q)
_seen.add(q)
all_quals.sort(
key=lambda q: -sum(d["quality_counter"].get(q, 0) for d in distributions.values())
)
fig, axes = plt.subplots(2, 2, figsize=(14, 11))
if title:
fig.suptitle(title, fontsize=13)
# ---------- [0,0] Quality distribution ----------
ax = axes[0, 0]
n_q = len(all_quals)
bar_h = 0.8 / n_groups
y = np.arange(n_q)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["quality_counter"].values()) or 1
vals = [dist["quality_counter"].get(q, 0) / total * 100 for q in all_quals]
ax.barh(
y + (gi - n_groups / 2 + 0.5) * bar_h, vals,
height=bar_h, label=label, color=palette[gi], alpha=0.85,
)
ax.set_yticks(y)
ax.set_yticklabels(all_quals, fontsize=8)
ax.invert_yaxis()
ax.set_xlabel("% of chord events")
ax.set_title("Chord quality distribution")
ax.legend(fontsize=8)
ax.grid(axis="x", alpha=0.3)
# ---------- [0,1] Root motion intervals ----------
ax = axes[0, 1]
bar_w = 0.8 / n_groups
x = np.arange(12)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["interval_counter"].values()) or 1
vals = [dist["interval_counter"].get(iv, 0) / total * 100 for iv in range(12)]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * bar_w, vals,
width=bar_w, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9)
ax.set_ylabel("% of root motions")
ax.set_title("Root motion intervals")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
# ---------- [1,0] Extension presence ----------
ax = axes[1, 0]
ext_keys = ["none", "has_ext"]
ext_xlabels = ["no extension", "has extension"]
x = np.arange(2)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["ext_counts"].values()) or 1
vals = [dist["ext_counts"].get(k, 0) / total * 100 for k in ext_keys]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(ext_xlabels)
ax.set_ylabel("% of chord events")
ax.set_title("Extension presence")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
# ---------- [1,1] Inversion frequency ----------
ax = axes[1, 1]
inv_keys = ["root_pos", "inverted"]
inv_xlabels = ["root position", "inverted"]
x = np.arange(2)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["inv_counts"].values()) or 1
vals = [dist["inv_counts"].get(k, 0) / total * 100 for k in inv_keys]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(inv_xlabels)
ax.set_ylabel("% of chord events")
ax.set_title("Inversion frequency")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close(fig)
log.info("saved distribution plot -> %s", output_path)