d09a08d553
Implements perplexity computation, chord distribution extraction (qualities, extensions, inversions, root-motion intervals), 4-panel comparison plot, and paired qualitative example generation for pretrained vs finetuned model. Results on user val set: pretrained PPL 3.58 → finetuned PPL 2.15 (−40 %). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""Evaluation utilities: perplexity, corpus stats, and distribution plots.

Public API:
    compute_perplexity(model, dataloader, device) -> float
    extract_features(period: ChordPeriod) -> dict
    extract_features_from_tokens(tokens: list[int]) -> dict
    compare_distributions(named_groups: dict[str, list[dict]]) -> dict
    plot_comparison(distributions: dict, output_path: Path, title: str) -> None
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import math
|
||
from collections import Counter
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.utils.data import DataLoader
|
||
|
||
from src.chord_parser import parse_chord_symbol
|
||
from src.model import ChordTransformer
|
||
from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod
|
||
|
||
log = logging.getLogger(__name__)
|
||
|
||
# Token-ID anchors, looked up once from the tokenizer's TOKEN_TO_ID table.

# Padding token: ignored by the perplexity loss and excluded from token counts.
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
# Inclusive ID bounds used to range-test ROOT_* tokens when scanning sequences.
_ROOT_START: int = TOKEN_TO_ID["ROOT_C"]
_ROOT_END: int = TOKEN_TO_ID["ROOT_B"]
# Inclusive ID bounds used to range-test QUAL_* tokens.
_QUAL_START: int = TOKEN_TO_ID["QUAL_maj"]
_QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"]
# Extension token meaning "no extension"; any other value counts as extended.
_EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"]
# Bass token meaning "bass note is the chord root".
_BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"]
# NOTE(review): same ID as _BASS_ROOT_ID; presumably BASS_root..BASS_B is a
# contiguous ID range in the vocabulary — confirm against src.tokenizer.
_BASS_START: int = TOKEN_TO_ID["BASS_root"]
_BASS_END: int = TOKEN_TO_ID["BASS_B"]
|
||
|
||
# Sharp-spelled note name -> semitones above C (0-11).  Built from the
# chromatic order so each name's value is simply its position.
_NOTE_SEMITONES: dict[str, int] = {
    name: semitone
    for semitone, name in enumerate(
        ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    )
}
# Interval names indexed by semitone distance 0-11; used as plot tick labels.
_INTERVAL_LABELS = "P1 m2 M2 m3 M3 P4 TT P5 m6 M6 m7 M7".split()
|
||
|
||
|
||
def _qual_from_tok(tok: str) -> str:
    """Turn a quality vocabulary token into its canonical quality string.

    The first five characters (the ``QUAL_`` prefix) are dropped without
    checking what they are.  The remainder is returned unchanged, except
    ``m_add9`` which is rendered as ``m(add9)``.

    'QUAL_m_add9' → 'm(add9)', 'QUAL_maj7' → 'maj7'.
    """
    quality = tok[len("QUAL_"):]
    if quality == "m_add9":
        return "m(add9)"
    return quality
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Perplexity
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def compute_perplexity(
    model: ChordTransformer,
    dataloader: DataLoader,
    device: torch.device,
) -> float:
    """Return the token-level perplexity of *model* over *dataloader*.

    Each batch is split into inputs (every position but the last) and
    next-token targets (every position but the first).  The summed
    cross-entropy over non-PAD targets is accumulated across all batches and
    divided by the number of non-PAD targets.

    Args:
        model: ChordTransformer.  It is switched to eval mode here and is
            not switched back afterwards.
        dataloader: Yields padded token tensors of shape [batch, seq_len].
        device: Device each batch is moved to before the forward pass.

    Returns:
        exp(mean loss per non-PAD target).  The mean loss is capped at 20.0
        before exponentiation, so the result never exceeds exp(20).

    Raises:
        ValueError: If no non-PAD target token was seen.
    """
    model.eval()
    nll_sum = 0.0
    n_targets = 0

    with torch.no_grad():
        for tokens in dataloader:
            tokens = tokens.to(device)
            inputs, targets = tokens[:, :-1], tokens[:, 1:]
            # PAD positions in the input are masked out of attention.
            logits = model(inputs, attention_mask=(inputs != _PAD_ID).long())
            # Sum (not mean) so batches of different sizes weigh correctly.
            batch_nll = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=_PAD_ID,
                reduction="sum",
            )
            nll_sum += batch_nll.item()
            n_targets += int((targets != _PAD_ID).sum())

    if n_targets == 0:
        raise ValueError("dataloader produced no non-PAD tokens")

    mean_nll = nll_sum / n_targets
    return math.exp(min(mean_nll, 20.0))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Feature extraction — from ChordPeriod
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def extract_features(period: ChordPeriod) -> dict[str, Any]:
    """Collect per-chord harmonic features from a single ChordPeriod.

    Every position of every bar in ``period.bars`` is visited in order.
    The placeholders ".", "NC" and "?" are ignored; every other position is
    passed to ``parse_chord_symbol``.

    Args:
        period: A parsed ChordPeriod (any key).

    Returns:
        Dict with keys:
            qualities       list[str]   parsed quality of each chord event
            has_extension   list[bool]  True when extension != 'none'
            is_inverted     list[bool]  True when bass is neither 'root'
                                        nor equal to the chord's root
            root_intervals  list[int]   (next - previous) mod 12 between
                                        consecutive recorded roots (0–11)
    """
    placeholders = (".", "NC", "?")
    quality_seq: list[str] = []
    ext_flags: list[bool] = []
    inv_flags: list[bool] = []
    root_pcs: list[int] = []

    for bar in period.bars:
        for symbol in bar:
            if symbol in placeholders:
                continue
            chord = parse_chord_symbol(symbol)
            quality_seq.append(chord.quality)
            ext_flags.append(chord.extension != "none")
            in_root_position = chord.bass in ("root", chord.root)
            inv_flags.append(not in_root_position)
            # A root missing from _NOTE_SEMITONES is left out of the interval
            # chain (the chord still counts in the three lists above), so the
            # next interval then spans its two neighbours.
            pitch_class = _NOTE_SEMITONES.get(chord.root)
            if pitch_class is not None:
                root_pcs.append(pitch_class)

    intervals = [(cur - prev) % 12 for prev, cur in zip(root_pcs, root_pcs[1:])]
    return {
        "qualities": quality_seq,
        "has_extension": ext_flags,
        "is_inverted": inv_flags,
        "root_intervals": intervals,
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Feature extraction — from raw token IDs (for large processed corpora)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]:
    """Extract harmonic statistics from a raw token ID sequence.

    Scans left to right for four-token ROOT → QUAL → EXT → BASS groups.  A
    window is accepted as a chord only when its first token lies in the
    ROOT id range, its second in the QUAL id range and its fourth in the
    BASS id range; the third token is only compared against ``EXT_none``.
    After an accepted group the scan jumps past it; otherwise it advances by
    one token.  All other tokens are skipped.

    Requiring a real BASS token keeps truncated or malformed groups (for
    example a ROOT/QUAL/EXT followed by padding) from being counted as
    inverted chords.

    Args:
        tokens: Token ID list as produced by tokenize_period.

    Returns:
        Same structure as extract_features:
            qualities       list[str]   canonical quality string per chord
            has_extension   list[bool]  True when the EXT token != EXT_none
            is_inverted     list[bool]  True when the bass is neither
                                        BASS_root nor the same note as the root
            root_intervals  list[int]   (next - previous) mod 12 between
                                        consecutive recorded roots (0–11)
    """
    qualities: list[str] = []
    has_ext: list[bool] = []
    is_inv: list[bool] = []
    roots_semi: list[int] = []

    i = 0
    n = len(tokens)
    # A group needs four tokens, so no group can start after index n - 4.
    while i + 3 < n:
        root_tok, qual_tok, ext_tok, bass_tok = tokens[i:i + 4]
        is_chord_group = (
            _ROOT_START <= root_tok <= _ROOT_END
            and _QUAL_START <= qual_tok <= _QUAL_END
            and _BASS_START <= bass_tok <= _BASS_END
        )
        if not is_chord_group:
            i += 1
            continue

        root_name = VOCAB[root_tok][5:]  # "ROOT_F#" → "F#"
        bass_name = VOCAB[bass_tok][5:]  # "BASS_root" → "root", "BASS_G" → "G"
        qualities.append(_qual_from_tok(VOCAB[qual_tok]))
        has_ext.append(ext_tok != _EXT_NONE_ID)
        # Same rule as extract_features: a bass that names the root itself is
        # still root position, not an inversion.
        is_inv.append(bass_name not in ("root", root_name))
        if root_name in _NOTE_SEMITONES:
            roots_semi.append(_NOTE_SEMITONES[root_name])
        i += 4

    return {
        "qualities": qualities,
        "has_extension": has_ext,
        "is_inverted": is_inv,
        "root_intervals": [
            (cur - prev) % 12 for prev, cur in zip(roots_semi, roots_semi[1:])
        ],
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Distribution comparison
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def compare_distributions(
    named_groups: dict[str, list[dict[str, Any]]],
) -> dict[str, dict[str, Any]]:
    """Pool per-period feature dicts into one set of tallies per group.

    Args:
        named_groups: {label: [feature_dict, ...], ...}
            Each feature_dict is the output of extract_features or
            extract_features_from_tokens.

    Returns:
        {label: {"quality_counter": Counter, "ext_counts": dict,
                 "inv_counts": dict, "interval_counter": Counter}, ...}

        "ext_counts" always has the keys "none" and "has_ext";
        "inv_counts" always has the keys "root_pos" and "inverted".
        A group with no feature dicts yields empty counters and zero counts.
    """
    summary: dict[str, dict[str, Any]] = {}

    for group_name, feature_dicts in named_groups.items():
        quality_tally: Counter[str] = Counter()
        ext_tally: dict[str, int] = {"none": 0, "has_ext": 0}
        inv_tally: dict[str, int] = {"root_pos": 0, "inverted": 0}
        interval_tally: Counter[int] = Counter()

        for features in feature_dicts:
            quality_tally.update(features["qualities"])

            ext_flags = list(features["has_extension"])
            n_extended = sum(1 for flag in ext_flags if flag)
            ext_tally["has_ext"] += n_extended
            ext_tally["none"] += len(ext_flags) - n_extended

            inv_flags = list(features["is_inverted"])
            n_inverted = sum(1 for flag in inv_flags if flag)
            inv_tally["inverted"] += n_inverted
            inv_tally["root_pos"] += len(inv_flags) - n_inverted

            interval_tally.update(features["root_intervals"])

        summary[group_name] = {
            "quality_counter": quality_tally,
            "ext_counts": ext_tally,
            "inv_counts": inv_tally,
            "interval_counter": interval_tally,
        }

    return summary
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Plotting
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _draw_grouped_bars(
    ax: Any,
    distributions: dict[str, dict[str, Any]],
    counter_key: str,
    categories: Any,
    palette: list[Any],
    *,
    horizontal: bool = False,
) -> Any:
    """Draw one bar series per group on *ax* and return the category positions.

    For each group, the counts stored under *counter_key* are normalised to
    percent of that group's own total (a zero total is treated as 1 so an
    empty group plots as all-zero bars).  The groups share a 0.8-wide slot
    per category, centred on the integer category position.
    """
    n_groups = len(distributions)
    thickness = 0.8 / n_groups
    positions = np.arange(len(categories))
    draw = ax.barh if horizontal else ax.bar
    size_kwarg = "height" if horizontal else "width"

    for gi, (label, dist) in enumerate(distributions.items()):
        counts = dist[counter_key]
        total = sum(counts.values()) or 1
        vals = [counts.get(cat, 0) / total * 100 for cat in categories]
        draw(
            positions + (gi - n_groups / 2 + 0.5) * thickness,
            vals,
            label=label,
            color=palette[gi],
            alpha=0.85,
            **{size_kwarg: thickness},
        )
    return positions


def plot_comparison(
    distributions: dict[str, dict[str, Any]],
    output_path: Path,
    title: str = "",
) -> None:
    """Save a 2×2 distribution-comparison figure.

    Panels:
        [0,0]  Chord quality distribution (horizontal bars, normalised to %)
        [0,1]  Root motion intervals (0–11 semitones)
        [1,0]  Extension presence (no ext / has extension)
        [1,1]  Inversion frequency (root position / inverted)

    Args:
        distributions: Output of compare_distributions; must contain at
            least one group.
        output_path: Destination PNG path (a str is accepted too).  Missing
            parent directories are created.
        title: Optional figure suptitle.

    Raises:
        ValueError: If *distributions* is empty.
    """
    # Fail before any figure exists; previously this surfaced later as a
    # ZeroDivisionError from the bar-width computation.
    if not distributions:
        raise ValueError("distributions is empty; nothing to plot")

    output_path = Path(output_path)
    palette = [plt.cm.tab10.colors[i % 10] for i in range(len(distributions))]  # type: ignore[attr-defined]

    # Every quality seen in any group, ordered by descending total count
    # across groups; ties keep first-seen order (sorted() is stable).
    quality_totals: dict[str, int] = {}
    for dist in distributions.values():
        for quality, count in dist["quality_counter"].items():
            quality_totals[quality] = quality_totals.get(quality, 0) + count
    all_quals = sorted(quality_totals, key=lambda q: -quality_totals[q])

    fig, axes = plt.subplots(2, 2, figsize=(14, 11))
    try:
        if title:
            fig.suptitle(title, fontsize=13)

        # ---------- [0,0] Quality distribution ----------
        ax = axes[0, 0]
        y = _draw_grouped_bars(
            ax, distributions, "quality_counter", all_quals, palette, horizontal=True
        )
        ax.set_yticks(y)
        ax.set_yticklabels(all_quals, fontsize=8)
        ax.invert_yaxis()  # most frequent quality at the top
        ax.set_xlabel("% of chord events")
        ax.set_title("Chord quality distribution")
        ax.legend(fontsize=8)
        ax.grid(axis="x", alpha=0.3)

        # ---------- [0,1] Root motion intervals ----------
        ax = axes[0, 1]
        x = _draw_grouped_bars(ax, distributions, "interval_counter", range(12), palette)
        ax.set_xticks(x)
        ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9)
        ax.set_ylabel("% of root motions")
        ax.set_title("Root motion intervals")
        ax.legend(fontsize=8)
        ax.grid(axis="y", alpha=0.3)

        # ---------- [1,0] Extension presence ----------
        ax = axes[1, 0]
        x = _draw_grouped_bars(ax, distributions, "ext_counts", ["none", "has_ext"], palette)
        ax.set_xticks(x)
        ax.set_xticklabels(["no extension", "has extension"])
        ax.set_ylabel("% of chord events")
        ax.set_title("Extension presence")
        ax.legend(fontsize=8)
        ax.grid(axis="y", alpha=0.3)

        # ---------- [1,1] Inversion frequency ----------
        ax = axes[1, 1]
        x = _draw_grouped_bars(ax, distributions, "inv_counts", ["root_pos", "inverted"], palette)
        ax.set_xticks(x)
        ax.set_xticklabels(["root position", "inverted"])
        ax.set_ylabel("% of chord events")
        ax.set_title("Inversion frequency")
        ax.legend(fontsize=8)
        ax.grid(axis="y", alpha=0.3)

        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close even when drawing or saving fails, so the figure does not
        # stay registered with pyplot.
        plt.close(fig)

    log.info("saved distribution plot -> %s", output_path)
|