"""Evaluation utilities: perplexity, corpus stats, and distribution plots. Public API: compute_perplexity(model, dataloader, device) -> float extract_features(period: ChordPeriod) -> dict extract_features_from_tokens(tokens: list[int]) -> dict compare_distributions(named_groups: dict[str, list[dict]]) -> dict plot_comparison(distributions: dict, output_path: Path, title: str) -> None """ from __future__ import annotations import logging import math from collections import Counter from pathlib import Path from typing import Any import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from src.chord_parser import parse_chord_symbol from src.model import ChordTransformer from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod log = logging.getLogger(__name__) _PAD_ID: int = TOKEN_TO_ID[""] _ROOT_START: int = TOKEN_TO_ID["ROOT_C"] _ROOT_END: int = TOKEN_TO_ID["ROOT_B"] _QUAL_START: int = TOKEN_TO_ID["QUAL_maj"] _QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"] _EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"] _BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"] _BASS_START: int = TOKEN_TO_ID["BASS_root"] _BASS_END: int = TOKEN_TO_ID["BASS_B"] _NOTE_SEMITONES: dict[str, int] = { "C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5, "F#": 6, "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11, } _INTERVAL_LABELS = ["P1", "m2", "M2", "m3", "M3", "P4", "TT", "P5", "m6", "M6", "m7", "M7"] def _qual_from_tok(tok: str) -> str: """'QUAL_m_add9' → 'm(add9)', 'QUAL_maj7' → 'maj7', etc.""" suffix = tok[5:] # strip "QUAL_" return "m(add9)" if suffix == "m_add9" else suffix # --------------------------------------------------------------------------- # Perplexity # --------------------------------------------------------------------------- def compute_perplexity( model: ChordTransformer, dataloader: DataLoader, device: torch.device, ) -> float: """Compute token-level perplexity of *model* on *dataloader*. Args: model: ChordTransformer (set to eval mode internally). dataloader: Yields padded token tensors of shape [batch, seq_len]. device: Compute device. Returns: Perplexity (>= 1.0). """ criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID, reduction="sum") model.eval() total_loss = 0.0 total_tokens = 0 with torch.no_grad(): for batch in dataloader: batch = batch.to(device) input_ids = batch[:, :-1] targets = batch[:, 1:] attn_mask = (input_ids != _PAD_ID).long() logits = model(input_ids, attention_mask=attn_mask) loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) total_loss += loss.item() total_tokens += int((targets != _PAD_ID).sum()) if total_tokens == 0: raise ValueError("dataloader produced no non-PAD tokens") return math.exp(min(total_loss / total_tokens, 20.0)) # --------------------------------------------------------------------------- # Feature extraction — from ChordPeriod # --------------------------------------------------------------------------- def extract_features(period: ChordPeriod) -> dict[str, Any]: """Extract harmonic statistics from one ChordPeriod. Args: period: A parsed ChordPeriod (any key). Returns: Dict with keys: qualities list[str] canonical quality string per chord event has_extension list[bool] True when extension != 'none' is_inverted list[bool] True when bass != root root_intervals list[int] semitone intervals (0–11) between consecutive roots """ qualities: list[str] = [] has_ext: list[bool] = [] is_inv: list[bool] = [] roots_semi: list[int] = [] for bar in period.bars: for pos in bar: if pos in (".", "NC", "?"): continue t = parse_chord_symbol(pos) qualities.append(t.quality) has_ext.append(t.extension != "none") is_inv.append(t.bass not in ("root", t.root)) if t.root in _NOTE_SEMITONES: roots_semi.append(_NOTE_SEMITONES[t.root]) return { "qualities": qualities, "has_extension": has_ext, "is_inverted": is_inv, "root_intervals": [ (roots_semi[i] - roots_semi[i - 1]) % 12 for i in range(1, len(roots_semi)) ], } # --------------------------------------------------------------------------- # Feature extraction — from raw token IDs (for large processed corpora) # --------------------------------------------------------------------------- def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]: """Extract harmonic statistics from a raw token ID sequence. Scans for ROOT → QUAL → EXT → BASS groups; all other tokens are skipped. Args: tokens: Token ID list as produced by tokenize_period. Returns: Same structure as extract_features. """ qualities: list[str] = [] has_ext: list[bool] = [] is_inv: list[bool] = [] roots_semi: list[int] = [] i = 0 n = len(tokens) while i < n: tok = tokens[i] if _ROOT_START <= tok <= _ROOT_END and i + 3 < n: qual_tok = tokens[i + 1] ext_tok = tokens[i + 2] bass_tok = tokens[i + 3] if _QUAL_START <= qual_tok <= _QUAL_END: root_name = VOCAB[tok][5:] # "ROOT_F#" → "F#" qualities.append(_qual_from_tok(VOCAB[qual_tok])) has_ext.append(ext_tok != _EXT_NONE_ID) is_inv.append(bass_tok != _BASS_ROOT_ID) if root_name in _NOTE_SEMITONES: roots_semi.append(_NOTE_SEMITONES[root_name]) i += 4 continue i += 1 return { "qualities": qualities, "has_extension": has_ext, "is_inverted": is_inv, "root_intervals": [ (roots_semi[i] - roots_semi[i - 1]) % 12 for i in range(1, len(roots_semi)) ], } # --------------------------------------------------------------------------- # Distribution comparison # --------------------------------------------------------------------------- def compare_distributions( named_groups: dict[str, list[dict[str, Any]]], ) -> dict[str, dict[str, Any]]: """Aggregate per-period feature lists into per-group distribution counters. Args: named_groups: {label: [feature_dict, ...], ...} Each feature_dict is the output of extract_features or extract_features_from_tokens. Returns: {label: {"quality_counter": Counter, "ext_counts": dict, "inv_counts": dict, "interval_counter": Counter}, ...} """ result: dict[str, dict[str, Any]] = {} for label, feats in named_groups.items(): qc: Counter[str] = Counter() ext: dict[str, int] = {"none": 0, "has_ext": 0} inv: dict[str, int] = {"root_pos": 0, "inverted": 0} ivc: Counter[int] = Counter() for f in feats: for q in f["qualities"]: qc[q] += 1 for h in f["has_extension"]: ext["has_ext" if h else "none"] += 1 for is_i in f["is_inverted"]: inv["inverted" if is_i else "root_pos"] += 1 for iv in f["root_intervals"]: ivc[iv] += 1 result[label] = { "quality_counter": qc, "ext_counts": ext, "inv_counts": inv, "interval_counter": ivc, } return result # --------------------------------------------------------------------------- # Plotting # --------------------------------------------------------------------------- def plot_comparison( distributions: dict[str, dict[str, Any]], output_path: Path, title: str = "", ) -> None: """Save a 2×2 distribution-comparison figure. Panels: [0,0] Chord quality distribution (horizontal bars, normalised to %) [0,1] Root motion intervals (0–11 semitones) [1,0] Extension presence (no ext / has extension) [1,1] Inversion frequency (root position / inverted) Args: distributions: Output of compare_distributions. output_path: Destination PNG path. title: Optional figure suptitle. """ labels = list(distributions.keys()) n_groups = len(labels) palette = [plt.cm.tab10.colors[i % 10] for i in range(n_groups)] # type: ignore[attr-defined] # All quality strings that appear in at least one group, sorted by # descending total frequency across all groups. all_quals: list[str] = [] _seen: set[str] = set() for dist in distributions.values(): for q in dist["quality_counter"]: if q not in _seen: all_quals.append(q) _seen.add(q) all_quals.sort( key=lambda q: -sum(d["quality_counter"].get(q, 0) for d in distributions.values()) ) fig, axes = plt.subplots(2, 2, figsize=(14, 11)) if title: fig.suptitle(title, fontsize=13) # ---------- [0,0] Quality distribution ---------- ax = axes[0, 0] n_q = len(all_quals) bar_h = 0.8 / n_groups y = np.arange(n_q) for gi, (label, dist) in enumerate(distributions.items()): total = sum(dist["quality_counter"].values()) or 1 vals = [dist["quality_counter"].get(q, 0) / total * 100 for q in all_quals] ax.barh( y + (gi - n_groups / 2 + 0.5) * bar_h, vals, height=bar_h, label=label, color=palette[gi], alpha=0.85, ) ax.set_yticks(y) ax.set_yticklabels(all_quals, fontsize=8) ax.invert_yaxis() ax.set_xlabel("% of chord events") ax.set_title("Chord quality distribution") ax.legend(fontsize=8) ax.grid(axis="x", alpha=0.3) # ---------- [0,1] Root motion intervals ---------- ax = axes[0, 1] bar_w = 0.8 / n_groups x = np.arange(12) for gi, (label, dist) in enumerate(distributions.items()): total = sum(dist["interval_counter"].values()) or 1 vals = [dist["interval_counter"].get(iv, 0) / total * 100 for iv in range(12)] ax.bar( x + (gi - n_groups / 2 + 0.5) * bar_w, vals, width=bar_w, label=label, color=palette[gi], alpha=0.85, ) ax.set_xticks(x) ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9) ax.set_ylabel("% of root motions") ax.set_title("Root motion intervals") ax.legend(fontsize=8) ax.grid(axis="y", alpha=0.3) # ---------- [1,0] Extension presence ---------- ax = axes[1, 0] ext_keys = ["none", "has_ext"] ext_xlabels = ["no extension", "has extension"] x = np.arange(2) for gi, (label, dist) in enumerate(distributions.items()): total = sum(dist["ext_counts"].values()) or 1 vals = [dist["ext_counts"].get(k, 0) / total * 100 for k in ext_keys] ax.bar( x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals, width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85, ) ax.set_xticks(x) ax.set_xticklabels(ext_xlabels) ax.set_ylabel("% of chord events") ax.set_title("Extension presence") ax.legend(fontsize=8) ax.grid(axis="y", alpha=0.3) # ---------- [1,1] Inversion frequency ---------- ax = axes[1, 1] inv_keys = ["root_pos", "inverted"] inv_xlabels = ["root position", "inverted"] x = np.arange(2) for gi, (label, dist) in enumerate(distributions.items()): total = sum(dist["inv_counts"].values()) or 1 vals = [dist["inv_counts"].get(k, 0) / total * 100 for k in inv_keys] ax.bar( x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals, width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85, ) ax.set_xticks(x) ax.set_xticklabels(inv_xlabels) ax.set_ylabel("% of chord events") ax.set_title("Inversion frequency") ax.legend(fontsize=8) ax.grid(axis="y", alpha=0.3) fig.tight_layout() output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) log.info("saved distribution plot -> %s", output_path)