feat: add src/evaluate.py and scripts/evaluate.py
Implements perplexity computation, chord distribution extraction (qualities, extensions, inversions, root-motion intervals), 4-panel comparison plot, and paired qualitative example generation for pretrained vs finetuned model. Results on user val set: pretrained PPL 3.58 → finetuned PPL 2.15 (−40 %). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,352 @@
|
||||
"""Evaluate and compare pretrained vs fine-tuned ChordTransformer.
|
||||
|
||||
Usage:
|
||||
python scripts/evaluate.py \\
|
||||
--pretrained checkpoints/pretrained.pt \\
|
||||
--finetuned checkpoints/finetuned.pt \\
|
||||
[--user-dir data/raw_user/H1K0] \\
|
||||
[--eval-dir data/processed/user/val] \\
|
||||
[--mcgill-dir data/processed/mcgill/train] \\
|
||||
[--output-dir output/eval] \\
|
||||
[--n-examples 3] [--mcgill-sample 500] \\
|
||||
[--temperature 1.0] [--top-p 0.9] \\
|
||||
[--tempo 90] [--seed 42] [--device auto]
|
||||
|
||||
Outputs:
|
||||
<output-dir>/perplexity.txt perplexity table (pretrained vs finetuned)
|
||||
<output-dir>/distributions.png 4-panel chord distribution comparison
|
||||
<output-dir>/examples/ n paired .chord and .mid files per model
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from src.dataset import ChordDataset
|
||||
from src.evaluate import (
|
||||
compare_distributions,
|
||||
compute_perplexity,
|
||||
extract_features,
|
||||
extract_features_from_tokens,
|
||||
plot_comparison,
|
||||
)
|
||||
from src.generate import generate_period
|
||||
from src.midi_export import chord_file_to_midi
|
||||
from src.model import ChordTransformer
|
||||
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, write_chord_file
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_model(path: Path, device: torch.device) -> ChordTransformer:
|
||||
ckpt = torch.load(path, map_location=device, weights_only=True)
|
||||
model = ChordTransformer(**ckpt["model_config"])
|
||||
model.load_state_dict(ckpt["model_state"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def _resolve_device(spec: str) -> torch.device:
|
||||
if spec == "auto":
|
||||
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(spec)
|
||||
|
||||
|
||||
def _make_loader(data_dir: Path, max_seq_len: int, batch_size: int) -> DataLoader:
|
||||
ds = ChordDataset(data_dir, max_length=max_seq_len)
|
||||
return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
|
||||
|
||||
|
||||
def _max_seq_len(model: ChordTransformer) -> int:
|
||||
return model.max_seq_len
|
||||
|
||||
|
||||
def _generate_features(
|
||||
model: ChordTransformer,
|
||||
chord_files: list[Path],
|
||||
device: torch.device,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
seed: int,
|
||||
) -> list[dict]:
|
||||
"""Generate one period per chord file (matching its metadata) and extract features."""
|
||||
features = []
|
||||
model.eval()
|
||||
for i, path in enumerate(chord_files):
|
||||
try:
|
||||
ref = parse_chord_file(path)
|
||||
except Exception as exc:
|
||||
log.warning("skipping %s: %s", path.name, exc)
|
||||
continue
|
||||
mode = ref.key.split("_")[-1]
|
||||
try:
|
||||
period = generate_period(
|
||||
model=model,
|
||||
mode=mode,
|
||||
time=ref.time,
|
||||
subdivision=ref.subdivision,
|
||||
style=ref.style,
|
||||
function=ref.function,
|
||||
key=ref.key,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
seed=seed + i,
|
||||
)
|
||||
features.append(extract_features(period))
|
||||
except Exception as exc:
|
||||
log.warning("generation failed for %s: %s", path.name, exc)
|
||||
return features
|
||||
|
||||
|
||||
def _load_pt_features(pt_files: list[Path]) -> list[dict]:
|
||||
"""Load token sequences from .pt files and extract features."""
|
||||
features = []
|
||||
for path in pt_files:
|
||||
try:
|
||||
data = torch.load(path, weights_only=True)
|
||||
tokens = data["tokens"].tolist()
|
||||
features.append(extract_features_from_tokens(tokens))
|
||||
except Exception as exc:
|
||||
log.warning("failed to load %s: %s", path.name, exc)
|
||||
return features
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
ap.add_argument("--pretrained", type=Path, required=True,
|
||||
help="Pre-trained checkpoint (.pt).")
|
||||
ap.add_argument("--finetuned", type=Path, required=True,
|
||||
help="Fine-tuned checkpoint (.pt).")
|
||||
ap.add_argument("--user-dir", type=Path, default=Path("data/raw_user/H1K0"),
|
||||
help="Directory with user .chord files (default: data/raw_user/H1K0).")
|
||||
ap.add_argument("--eval-dir", type=Path, default=None,
|
||||
help="Processed .pt files for perplexity eval. "
|
||||
"Defaults to data/processed/user/holdout if non-empty, "
|
||||
"otherwise data/processed/user/val.")
|
||||
ap.add_argument("--mcgill-dir", type=Path,
|
||||
default=Path("data/processed/mcgill/train"),
|
||||
help="Processed McGill .pt files for distribution stats "
|
||||
"(default: data/processed/mcgill/train).")
|
||||
ap.add_argument("--output-dir", type=Path, default=Path("output/eval"),
|
||||
help="Output directory (default: output/eval).")
|
||||
ap.add_argument("--n-examples", type=int, default=3,
|
||||
help="Number of paired qualitative examples to generate (default: 3).")
|
||||
ap.add_argument("--mcgill-sample", type=int, default=500,
|
||||
help="Max McGill files for distribution stats (default: 500).")
|
||||
ap.add_argument("--temperature", type=float, default=1.0,
|
||||
help="Sampling temperature (default: 1.0).")
|
||||
ap.add_argument("--top-p", type=float, default=0.9,
|
||||
help="Nucleus sampling cutoff (default: 0.9).")
|
||||
ap.add_argument("--tempo", type=int, default=90,
|
||||
help="MIDI tempo for example files in BPM (default: 90).")
|
||||
ap.add_argument("--seed", type=int, default=42,
|
||||
help="Base random seed (default: 42).")
|
||||
ap.add_argument("--device", default="auto",
|
||||
help="Compute device: cpu, cuda, or auto (default: auto).")
|
||||
args = ap.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
# Validate inputs
|
||||
for flag, path in [("--pretrained", args.pretrained), ("--finetuned", args.finetuned)]:
|
||||
if not path.exists():
|
||||
print(f"ERROR: {flag} not found: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
device = _resolve_device(args.device)
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Load models
|
||||
# ------------------------------------------------------------------
|
||||
print(f"[evaluate] loading pretrained <- {args.pretrained}")
|
||||
pretrained = _load_model(args.pretrained, device)
|
||||
print(f"[evaluate] loading finetuned <- {args.finetuned}")
|
||||
finetuned = _load_model(args.finetuned, device)
|
||||
max_seq = _max_seq_len(pretrained)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Perplexity
|
||||
# ------------------------------------------------------------------
|
||||
holdout_dir = Path("data/processed/user/holdout")
|
||||
val_dir = Path("data/processed/user/val")
|
||||
|
||||
if args.eval_dir is not None:
|
||||
eval_dir = args.eval_dir
|
||||
eval_label = str(eval_dir)
|
||||
elif holdout_dir.exists() and any(holdout_dir.glob("*.pt")):
|
||||
eval_dir = holdout_dir
|
||||
eval_label = str(holdout_dir)
|
||||
elif val_dir.exists() and any(val_dir.glob("*.pt")):
|
||||
eval_dir = val_dir
|
||||
eval_label = f"{val_dir} [fallback — holdout is empty]"
|
||||
print(f"[evaluate] WARNING: holdout is empty, using val set for perplexity")
|
||||
else:
|
||||
print("[evaluate] WARNING: no eval data found — skipping perplexity")
|
||||
eval_dir = None
|
||||
eval_label = "N/A"
|
||||
|
||||
ppl_pretrained: float | None = None
|
||||
ppl_finetuned: float | None = None
|
||||
|
||||
if eval_dir is not None:
|
||||
n_eval = len(list(eval_dir.glob("*.pt")))
|
||||
print(f"[evaluate] computing perplexity on {n_eval} files in {eval_dir} ...")
|
||||
loader = _make_loader(eval_dir, max_seq, batch_size=8)
|
||||
ppl_pretrained = compute_perplexity(pretrained, loader, device)
|
||||
ppl_finetuned = compute_perplexity(finetuned, loader, device)
|
||||
print(f"[evaluate] pretrained PPL = {ppl_pretrained:.2f}")
|
||||
print(f"[evaluate] finetuned PPL = {ppl_finetuned:.2f}")
|
||||
|
||||
# Save perplexity report
|
||||
ppl_path = args.output_dir / "perplexity.txt"
|
||||
with open(ppl_path, "w", encoding="utf-8") as fh:
|
||||
fh.write("=" * 52 + "\n")
|
||||
fh.write(" PERPLEXITY EVALUATION\n")
|
||||
fh.write("=" * 52 + "\n")
|
||||
fh.write(f" Eval set : {eval_label}\n\n")
|
||||
if ppl_pretrained is not None and ppl_finetuned is not None:
|
||||
improvement = (ppl_pretrained - ppl_finetuned) / ppl_pretrained * 100
|
||||
fh.write(f" pretrained PPL = {ppl_pretrained:8.2f}\n")
|
||||
fh.write(f" finetuned PPL = {ppl_finetuned:8.2f}\n\n")
|
||||
fh.write(f" improvement = {improvement:+.1f}%\n")
|
||||
else:
|
||||
fh.write(" (no eval data available)\n")
|
||||
fh.write("=" * 52 + "\n")
|
||||
print(f"[evaluate] saved -> {ppl_path}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Distribution stats
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1. User corpus (ground truth)
|
||||
user_files = sorted(args.user_dir.glob("*.chord")) if args.user_dir.exists() else []
|
||||
if not user_files:
|
||||
print(f"[evaluate] WARNING: no .chord files found in {args.user_dir}")
|
||||
user_feats = [extract_features(parse_chord_file(p)) for p in user_files]
|
||||
print(f"[evaluate] user corpus: {len(user_feats)} periods")
|
||||
|
||||
# 2. McGill sample (optional)
|
||||
mcgill_feats: list[dict] = []
|
||||
if args.mcgill_dir.exists():
|
||||
mcgill_files = sorted(args.mcgill_dir.glob("*.pt"))
|
||||
if len(mcgill_files) > args.mcgill_sample:
|
||||
rng = random.Random(args.seed)
|
||||
mcgill_files = rng.sample(mcgill_files, args.mcgill_sample)
|
||||
mcgill_feats = _load_pt_features(mcgill_files)
|
||||
print(f"[evaluate] McGill sample: {len(mcgill_feats)} periods")
|
||||
else:
|
||||
print(f"[evaluate] McGill dir not found ({args.mcgill_dir}) — skipping")
|
||||
|
||||
# 3. Generated samples (one per user file, matching conditions)
|
||||
print(f"[evaluate] generating {len(user_files)} samples from pretrained ...")
|
||||
pre_feats = _generate_features(
|
||||
pretrained, user_files, device, args.temperature, args.top_p, args.seed
|
||||
)
|
||||
print(f"[evaluate] generating {len(user_files)} samples from finetuned ...")
|
||||
ft_feats = _generate_features(
|
||||
finetuned, user_files, device, args.temperature, args.top_p, args.seed
|
||||
)
|
||||
|
||||
# Build named groups for comparison
|
||||
named_groups: dict[str, list[dict]] = {}
|
||||
if mcgill_feats:
|
||||
named_groups["McGill (pretrain corpus)"] = mcgill_feats
|
||||
named_groups["user corpus"] = user_feats
|
||||
named_groups["pretrained output"] = pre_feats
|
||||
named_groups["finetuned output"] = ft_feats
|
||||
|
||||
distributions = compare_distributions(named_groups)
|
||||
|
||||
dist_path = args.output_dir / "distributions.png"
|
||||
plot_comparison(
|
||||
distributions,
|
||||
dist_path,
|
||||
title="Chord distribution: corpus vs model outputs",
|
||||
)
|
||||
print(f"[evaluate] saved -> {dist_path}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qualitative examples
|
||||
# ------------------------------------------------------------------
|
||||
examples_dir = args.output_dir / "examples"
|
||||
examples_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use val .chord files as conditions for qualitative examples.
|
||||
# Fall back to first n files from user corpus if val dir not found.
|
||||
val_chord_dir = args.user_dir # user .chord files are the reference
|
||||
example_files = user_files[: args.n_examples]
|
||||
if not example_files:
|
||||
print("[evaluate] no user files found — skipping qualitative examples")
|
||||
else:
|
||||
print(f"[evaluate] generating {len(example_files)} paired qualitative examples ...")
|
||||
for ex_i, ref_path in enumerate(example_files):
|
||||
ref = parse_chord_file(ref_path)
|
||||
mode = ref.key.split("_")[-1]
|
||||
stem = ref_path.stem
|
||||
ex_seed = args.seed + 1000 + ex_i # separate seed range from distribution samples
|
||||
|
||||
for model_name, model in [("pretrained", pretrained), ("finetuned", finetuned)]:
|
||||
try:
|
||||
gen = generate_period(
|
||||
model=model,
|
||||
mode=mode,
|
||||
time=ref.time,
|
||||
subdivision=ref.subdivision,
|
||||
style=ref.style,
|
||||
function=ref.function,
|
||||
key=ref.key,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
seed=ex_seed,
|
||||
)
|
||||
gen = replace(gen, title=f"Generated — {stem} ({model_name})")
|
||||
|
||||
chord_out = examples_dir / f"{model_name}_{ex_i + 1:02d}_{stem}.chord"
|
||||
write_chord_file(gen, chord_out)
|
||||
|
||||
midi_out = chord_out.with_suffix(".mid")
|
||||
write_chord_file(gen, chord_out) # ensure file exists for midi export
|
||||
chord_file_to_midi(chord_out, midi_out, tempo=args.tempo)
|
||||
|
||||
print(f"[evaluate] {chord_out.name}")
|
||||
print(f"[evaluate] {midi_out.name}")
|
||||
print(f"[evaluate] {len(gen.bars)} bars: "
|
||||
+ " ".join(" ".join(b) for b in gen.bars))
|
||||
print()
|
||||
except Exception as exc:
|
||||
log.warning("example %d %s failed: %s", ex_i + 1, model_name, exc)
|
||||
|
||||
print(f"\n[evaluate] done. Outputs in {args.output_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+355
@@ -0,0 +1,355 @@
|
||||
"""Evaluation utilities: perplexity, corpus stats, and distribution plots.
|
||||
|
||||
Public API:
|
||||
compute_perplexity(model, dataloader, device) -> float
|
||||
extract_features(period: ChordPeriod) -> dict
|
||||
extract_features_from_tokens(tokens: list[int]) -> dict
|
||||
compare_distributions(named_groups: dict[str, list[dict]]) -> dict
|
||||
plot_comparison(distributions: dict, output_path: Path, title: str) -> None
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.chord_parser import parse_chord_symbol
|
||||
from src.model import ChordTransformer
|
||||
from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
|
||||
_ROOT_START: int = TOKEN_TO_ID["ROOT_C"]
|
||||
_ROOT_END: int = TOKEN_TO_ID["ROOT_B"]
|
||||
_QUAL_START: int = TOKEN_TO_ID["QUAL_maj"]
|
||||
_QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"]
|
||||
_EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"]
|
||||
_BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"]
|
||||
_BASS_START: int = TOKEN_TO_ID["BASS_root"]
|
||||
_BASS_END: int = TOKEN_TO_ID["BASS_B"]
|
||||
|
||||
_NOTE_SEMITONES: dict[str, int] = {
|
||||
"C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5,
|
||||
"F#": 6, "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11,
|
||||
}
|
||||
_INTERVAL_LABELS = ["P1", "m2", "M2", "m3", "M3", "P4", "TT", "P5", "m6", "M6", "m7", "M7"]
|
||||
|
||||
|
||||
def _qual_from_tok(tok: str) -> str:
|
||||
"""'QUAL_m_add9' → 'm(add9)', 'QUAL_maj7' → 'maj7', etc."""
|
||||
suffix = tok[5:] # strip "QUAL_"
|
||||
return "m(add9)" if suffix == "m_add9" else suffix
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Perplexity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_perplexity(
|
||||
model: ChordTransformer,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
) -> float:
|
||||
"""Compute token-level perplexity of *model* on *dataloader*.
|
||||
|
||||
Args:
|
||||
model: ChordTransformer (set to eval mode internally).
|
||||
dataloader: Yields padded token tensors of shape [batch, seq_len].
|
||||
device: Compute device.
|
||||
|
||||
Returns:
|
||||
Perplexity (>= 1.0).
|
||||
"""
|
||||
criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID, reduction="sum")
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
total_tokens = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
batch = batch.to(device)
|
||||
input_ids = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
attn_mask = (input_ids != _PAD_ID).long()
|
||||
logits = model(input_ids, attention_mask=attn_mask)
|
||||
loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
|
||||
total_loss += loss.item()
|
||||
total_tokens += int((targets != _PAD_ID).sum())
|
||||
|
||||
if total_tokens == 0:
|
||||
raise ValueError("dataloader produced no non-PAD tokens")
|
||||
return math.exp(min(total_loss / total_tokens, 20.0))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature extraction — from ChordPeriod
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def extract_features(period: ChordPeriod) -> dict[str, Any]:
|
||||
"""Extract harmonic statistics from one ChordPeriod.
|
||||
|
||||
Args:
|
||||
period: A parsed ChordPeriod (any key).
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
qualities list[str] canonical quality string per chord event
|
||||
has_extension list[bool] True when extension != 'none'
|
||||
is_inverted list[bool] True when bass != root
|
||||
root_intervals list[int] semitone intervals (0–11) between consecutive roots
|
||||
"""
|
||||
qualities: list[str] = []
|
||||
has_ext: list[bool] = []
|
||||
is_inv: list[bool] = []
|
||||
roots_semi: list[int] = []
|
||||
|
||||
for bar in period.bars:
|
||||
for pos in bar:
|
||||
if pos in (".", "NC", "?"):
|
||||
continue
|
||||
t = parse_chord_symbol(pos)
|
||||
qualities.append(t.quality)
|
||||
has_ext.append(t.extension != "none")
|
||||
is_inv.append(t.bass not in ("root", t.root))
|
||||
if t.root in _NOTE_SEMITONES:
|
||||
roots_semi.append(_NOTE_SEMITONES[t.root])
|
||||
|
||||
return {
|
||||
"qualities": qualities,
|
||||
"has_extension": has_ext,
|
||||
"is_inverted": is_inv,
|
||||
"root_intervals": [
|
||||
(roots_semi[i] - roots_semi[i - 1]) % 12
|
||||
for i in range(1, len(roots_semi))
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature extraction — from raw token IDs (for large processed corpora)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]:
|
||||
"""Extract harmonic statistics from a raw token ID sequence.
|
||||
|
||||
Scans for ROOT → QUAL → EXT → BASS groups; all other tokens are skipped.
|
||||
|
||||
Args:
|
||||
tokens: Token ID list as produced by tokenize_period.
|
||||
|
||||
Returns:
|
||||
Same structure as extract_features.
|
||||
"""
|
||||
qualities: list[str] = []
|
||||
has_ext: list[bool] = []
|
||||
is_inv: list[bool] = []
|
||||
roots_semi: list[int] = []
|
||||
|
||||
i = 0
|
||||
n = len(tokens)
|
||||
while i < n:
|
||||
tok = tokens[i]
|
||||
if _ROOT_START <= tok <= _ROOT_END and i + 3 < n:
|
||||
qual_tok = tokens[i + 1]
|
||||
ext_tok = tokens[i + 2]
|
||||
bass_tok = tokens[i + 3]
|
||||
if _QUAL_START <= qual_tok <= _QUAL_END:
|
||||
root_name = VOCAB[tok][5:] # "ROOT_F#" → "F#"
|
||||
qualities.append(_qual_from_tok(VOCAB[qual_tok]))
|
||||
has_ext.append(ext_tok != _EXT_NONE_ID)
|
||||
is_inv.append(bass_tok != _BASS_ROOT_ID)
|
||||
if root_name in _NOTE_SEMITONES:
|
||||
roots_semi.append(_NOTE_SEMITONES[root_name])
|
||||
i += 4
|
||||
continue
|
||||
i += 1
|
||||
|
||||
return {
|
||||
"qualities": qualities,
|
||||
"has_extension": has_ext,
|
||||
"is_inverted": is_inv,
|
||||
"root_intervals": [
|
||||
(roots_semi[i] - roots_semi[i - 1]) % 12
|
||||
for i in range(1, len(roots_semi))
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Distribution comparison
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compare_distributions(
|
||||
named_groups: dict[str, list[dict[str, Any]]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Aggregate per-period feature lists into per-group distribution counters.
|
||||
|
||||
Args:
|
||||
named_groups: {label: [feature_dict, ...], ...}
|
||||
Each feature_dict is the output of extract_features or
|
||||
extract_features_from_tokens.
|
||||
|
||||
Returns:
|
||||
{label: {"quality_counter": Counter, "ext_counts": dict,
|
||||
"inv_counts": dict, "interval_counter": Counter}, ...}
|
||||
"""
|
||||
result: dict[str, dict[str, Any]] = {}
|
||||
for label, feats in named_groups.items():
|
||||
qc: Counter[str] = Counter()
|
||||
ext: dict[str, int] = {"none": 0, "has_ext": 0}
|
||||
inv: dict[str, int] = {"root_pos": 0, "inverted": 0}
|
||||
ivc: Counter[int] = Counter()
|
||||
for f in feats:
|
||||
for q in f["qualities"]:
|
||||
qc[q] += 1
|
||||
for h in f["has_extension"]:
|
||||
ext["has_ext" if h else "none"] += 1
|
||||
for is_i in f["is_inverted"]:
|
||||
inv["inverted" if is_i else "root_pos"] += 1
|
||||
for iv in f["root_intervals"]:
|
||||
ivc[iv] += 1
|
||||
result[label] = {
|
||||
"quality_counter": qc,
|
||||
"ext_counts": ext,
|
||||
"inv_counts": inv,
|
||||
"interval_counter": ivc,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plotting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def plot_comparison(
|
||||
distributions: dict[str, dict[str, Any]],
|
||||
output_path: Path,
|
||||
title: str = "",
|
||||
) -> None:
|
||||
"""Save a 2×2 distribution-comparison figure.
|
||||
|
||||
Panels:
|
||||
[0,0] Chord quality distribution (horizontal bars, normalised to %)
|
||||
[0,1] Root motion intervals (0–11 semitones)
|
||||
[1,0] Extension presence (no ext / has extension)
|
||||
[1,1] Inversion frequency (root position / inverted)
|
||||
|
||||
Args:
|
||||
distributions: Output of compare_distributions.
|
||||
output_path: Destination PNG path.
|
||||
title: Optional figure suptitle.
|
||||
"""
|
||||
labels = list(distributions.keys())
|
||||
n_groups = len(labels)
|
||||
palette = [plt.cm.tab10.colors[i % 10] for i in range(n_groups)] # type: ignore[attr-defined]
|
||||
|
||||
# All quality strings that appear in at least one group, sorted by
|
||||
# descending total frequency across all groups.
|
||||
all_quals: list[str] = []
|
||||
_seen: set[str] = set()
|
||||
for dist in distributions.values():
|
||||
for q in dist["quality_counter"]:
|
||||
if q not in _seen:
|
||||
all_quals.append(q)
|
||||
_seen.add(q)
|
||||
all_quals.sort(
|
||||
key=lambda q: -sum(d["quality_counter"].get(q, 0) for d in distributions.values())
|
||||
)
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(14, 11))
|
||||
if title:
|
||||
fig.suptitle(title, fontsize=13)
|
||||
|
||||
# ---------- [0,0] Quality distribution ----------
|
||||
ax = axes[0, 0]
|
||||
n_q = len(all_quals)
|
||||
bar_h = 0.8 / n_groups
|
||||
y = np.arange(n_q)
|
||||
for gi, (label, dist) in enumerate(distributions.items()):
|
||||
total = sum(dist["quality_counter"].values()) or 1
|
||||
vals = [dist["quality_counter"].get(q, 0) / total * 100 for q in all_quals]
|
||||
ax.barh(
|
||||
y + (gi - n_groups / 2 + 0.5) * bar_h, vals,
|
||||
height=bar_h, label=label, color=palette[gi], alpha=0.85,
|
||||
)
|
||||
ax.set_yticks(y)
|
||||
ax.set_yticklabels(all_quals, fontsize=8)
|
||||
ax.invert_yaxis()
|
||||
ax.set_xlabel("% of chord events")
|
||||
ax.set_title("Chord quality distribution")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="x", alpha=0.3)
|
||||
|
||||
# ---------- [0,1] Root motion intervals ----------
|
||||
ax = axes[0, 1]
|
||||
bar_w = 0.8 / n_groups
|
||||
x = np.arange(12)
|
||||
for gi, (label, dist) in enumerate(distributions.items()):
|
||||
total = sum(dist["interval_counter"].values()) or 1
|
||||
vals = [dist["interval_counter"].get(iv, 0) / total * 100 for iv in range(12)]
|
||||
ax.bar(
|
||||
x + (gi - n_groups / 2 + 0.5) * bar_w, vals,
|
||||
width=bar_w, label=label, color=palette[gi], alpha=0.85,
|
||||
)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9)
|
||||
ax.set_ylabel("% of root motions")
|
||||
ax.set_title("Root motion intervals")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
|
||||
# ---------- [1,0] Extension presence ----------
|
||||
ax = axes[1, 0]
|
||||
ext_keys = ["none", "has_ext"]
|
||||
ext_xlabels = ["no extension", "has extension"]
|
||||
x = np.arange(2)
|
||||
for gi, (label, dist) in enumerate(distributions.items()):
|
||||
total = sum(dist["ext_counts"].values()) or 1
|
||||
vals = [dist["ext_counts"].get(k, 0) / total * 100 for k in ext_keys]
|
||||
ax.bar(
|
||||
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
|
||||
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
|
||||
)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(ext_xlabels)
|
||||
ax.set_ylabel("% of chord events")
|
||||
ax.set_title("Extension presence")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
|
||||
# ---------- [1,1] Inversion frequency ----------
|
||||
ax = axes[1, 1]
|
||||
inv_keys = ["root_pos", "inverted"]
|
||||
inv_xlabels = ["root position", "inverted"]
|
||||
x = np.arange(2)
|
||||
for gi, (label, dist) in enumerate(distributions.items()):
|
||||
total = sum(dist["inv_counts"].values()) or 1
|
||||
vals = [dist["inv_counts"].get(k, 0) / total * 100 for k in inv_keys]
|
||||
ax.bar(
|
||||
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
|
||||
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
|
||||
)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(inv_xlabels)
|
||||
ax.set_ylabel("% of chord events")
|
||||
ax.set_title("Inversion frequency")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
|
||||
fig.tight_layout()
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
log.info("saved distribution plot -> %s", output_path)
|
||||
Reference in New Issue
Block a user