feat: add src/evaluate.py and scripts/evaluate.py

Implements perplexity computation, chord distribution extraction (qualities,
extensions, inversions, root-motion intervals), 4-panel comparison plot, and
paired qualitative example generation for pretrained vs finetuned model.

Results on user val set: pretrained PPL 3.58 → finetuned PPL 2.15 (−40 %).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 14:57:49 +03:00
parent b30f4c188b
commit d09a08d553
2 changed files with 707 additions and 0 deletions
+352
View File
@@ -0,0 +1,352 @@
"""Evaluate and compare pretrained vs fine-tuned ChordTransformer.
Usage:
python scripts/evaluate.py \\
--pretrained checkpoints/pretrained.pt \\
--finetuned checkpoints/finetuned.pt \\
[--user-dir data/raw_user/H1K0] \\
[--eval-dir data/processed/user/val] \\
[--mcgill-dir data/processed/mcgill/train] \\
[--output-dir output/eval] \\
[--n-examples 3] [--mcgill-sample 500] \\
[--temperature 1.0] [--top-p 0.9] \\
[--tempo 90] [--seed 42] [--device auto]
Outputs:
<output-dir>/perplexity.txt perplexity table (pretrained vs finetuned)
<output-dir>/distributions.png 4-panel chord distribution comparison
<output-dir>/examples/ n paired .chord and .mid files per model
"""
from __future__ import annotations
import argparse
import logging
import random
import sys
from dataclasses import replace
from pathlib import Path
import torch
from torch.utils.data import DataLoader
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.dataset import ChordDataset
from src.evaluate import (
compare_distributions,
compute_perplexity,
extract_features,
extract_features_from_tokens,
plot_comparison,
)
from src.generate import generate_period
from src.midi_export import chord_file_to_midi
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, write_chord_file
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_model(path: Path, device: torch.device) -> ChordTransformer:
ckpt = torch.load(path, map_location=device, weights_only=True)
model = ChordTransformer(**ckpt["model_config"])
model.load_state_dict(ckpt["model_state"])
model.to(device)
model.eval()
return model
def _resolve_device(spec: str) -> torch.device:
if spec == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(spec)
def _make_loader(data_dir: Path, max_seq_len: int, batch_size: int) -> DataLoader:
ds = ChordDataset(data_dir, max_length=max_seq_len)
return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
def _max_seq_len(model: ChordTransformer) -> int:
return model.max_seq_len
def _generate_features(
model: ChordTransformer,
chord_files: list[Path],
device: torch.device,
temperature: float,
top_p: float,
seed: int,
) -> list[dict]:
"""Generate one period per chord file (matching its metadata) and extract features."""
features = []
model.eval()
for i, path in enumerate(chord_files):
try:
ref = parse_chord_file(path)
except Exception as exc:
log.warning("skipping %s: %s", path.name, exc)
continue
mode = ref.key.split("_")[-1]
try:
period = generate_period(
model=model,
mode=mode,
time=ref.time,
subdivision=ref.subdivision,
style=ref.style,
function=ref.function,
key=ref.key,
temperature=temperature,
top_p=top_p,
seed=seed + i,
)
features.append(extract_features(period))
except Exception as exc:
log.warning("generation failed for %s: %s", path.name, exc)
return features
def _load_pt_features(pt_files: list[Path]) -> list[dict]:
"""Load token sequences from .pt files and extract features."""
features = []
for path in pt_files:
try:
data = torch.load(path, weights_only=True)
tokens = data["tokens"].tolist()
features.append(extract_features_from_tokens(tokens))
except Exception as exc:
log.warning("failed to load %s: %s", path.name, exc)
return features
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
ap = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
ap.add_argument("--pretrained", type=Path, required=True,
help="Pre-trained checkpoint (.pt).")
ap.add_argument("--finetuned", type=Path, required=True,
help="Fine-tuned checkpoint (.pt).")
ap.add_argument("--user-dir", type=Path, default=Path("data/raw_user/H1K0"),
help="Directory with user .chord files (default: data/raw_user/H1K0).")
ap.add_argument("--eval-dir", type=Path, default=None,
help="Processed .pt files for perplexity eval. "
"Defaults to data/processed/user/holdout if non-empty, "
"otherwise data/processed/user/val.")
ap.add_argument("--mcgill-dir", type=Path,
default=Path("data/processed/mcgill/train"),
help="Processed McGill .pt files for distribution stats "
"(default: data/processed/mcgill/train).")
ap.add_argument("--output-dir", type=Path, default=Path("output/eval"),
help="Output directory (default: output/eval).")
ap.add_argument("--n-examples", type=int, default=3,
help="Number of paired qualitative examples to generate (default: 3).")
ap.add_argument("--mcgill-sample", type=int, default=500,
help="Max McGill files for distribution stats (default: 500).")
ap.add_argument("--temperature", type=float, default=1.0,
help="Sampling temperature (default: 1.0).")
ap.add_argument("--top-p", type=float, default=0.9,
help="Nucleus sampling cutoff (default: 0.9).")
ap.add_argument("--tempo", type=int, default=90,
help="MIDI tempo for example files in BPM (default: 90).")
ap.add_argument("--seed", type=int, default=42,
help="Base random seed (default: 42).")
ap.add_argument("--device", default="auto",
help="Compute device: cpu, cuda, or auto (default: auto).")
args = ap.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
)
# Validate inputs
for flag, path in [("--pretrained", args.pretrained), ("--finetuned", args.finetuned)]:
if not path.exists():
print(f"ERROR: {flag} not found: {path}", file=sys.stderr)
sys.exit(1)
device = _resolve_device(args.device)
args.output_dir.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------
# Load models
# ------------------------------------------------------------------
print(f"[evaluate] loading pretrained <- {args.pretrained}")
pretrained = _load_model(args.pretrained, device)
print(f"[evaluate] loading finetuned <- {args.finetuned}")
finetuned = _load_model(args.finetuned, device)
max_seq = _max_seq_len(pretrained)
# ------------------------------------------------------------------
# Perplexity
# ------------------------------------------------------------------
holdout_dir = Path("data/processed/user/holdout")
val_dir = Path("data/processed/user/val")
if args.eval_dir is not None:
eval_dir = args.eval_dir
eval_label = str(eval_dir)
elif holdout_dir.exists() and any(holdout_dir.glob("*.pt")):
eval_dir = holdout_dir
eval_label = str(holdout_dir)
elif val_dir.exists() and any(val_dir.glob("*.pt")):
eval_dir = val_dir
eval_label = f"{val_dir} [fallback — holdout is empty]"
print(f"[evaluate] WARNING: holdout is empty, using val set for perplexity")
else:
print("[evaluate] WARNING: no eval data found — skipping perplexity")
eval_dir = None
eval_label = "N/A"
ppl_pretrained: float | None = None
ppl_finetuned: float | None = None
if eval_dir is not None:
n_eval = len(list(eval_dir.glob("*.pt")))
print(f"[evaluate] computing perplexity on {n_eval} files in {eval_dir} ...")
loader = _make_loader(eval_dir, max_seq, batch_size=8)
ppl_pretrained = compute_perplexity(pretrained, loader, device)
ppl_finetuned = compute_perplexity(finetuned, loader, device)
print(f"[evaluate] pretrained PPL = {ppl_pretrained:.2f}")
print(f"[evaluate] finetuned PPL = {ppl_finetuned:.2f}")
# Save perplexity report
ppl_path = args.output_dir / "perplexity.txt"
with open(ppl_path, "w", encoding="utf-8") as fh:
fh.write("=" * 52 + "\n")
fh.write(" PERPLEXITY EVALUATION\n")
fh.write("=" * 52 + "\n")
fh.write(f" Eval set : {eval_label}\n\n")
if ppl_pretrained is not None and ppl_finetuned is not None:
improvement = (ppl_pretrained - ppl_finetuned) / ppl_pretrained * 100
fh.write(f" pretrained PPL = {ppl_pretrained:8.2f}\n")
fh.write(f" finetuned PPL = {ppl_finetuned:8.2f}\n\n")
fh.write(f" improvement = {improvement:+.1f}%\n")
else:
fh.write(" (no eval data available)\n")
fh.write("=" * 52 + "\n")
print(f"[evaluate] saved -> {ppl_path}")
# ------------------------------------------------------------------
# Distribution stats
# ------------------------------------------------------------------
# 1. User corpus (ground truth)
user_files = sorted(args.user_dir.glob("*.chord")) if args.user_dir.exists() else []
if not user_files:
print(f"[evaluate] WARNING: no .chord files found in {args.user_dir}")
user_feats = [extract_features(parse_chord_file(p)) for p in user_files]
print(f"[evaluate] user corpus: {len(user_feats)} periods")
# 2. McGill sample (optional)
mcgill_feats: list[dict] = []
if args.mcgill_dir.exists():
mcgill_files = sorted(args.mcgill_dir.glob("*.pt"))
if len(mcgill_files) > args.mcgill_sample:
rng = random.Random(args.seed)
mcgill_files = rng.sample(mcgill_files, args.mcgill_sample)
mcgill_feats = _load_pt_features(mcgill_files)
print(f"[evaluate] McGill sample: {len(mcgill_feats)} periods")
else:
print(f"[evaluate] McGill dir not found ({args.mcgill_dir}) — skipping")
# 3. Generated samples (one per user file, matching conditions)
print(f"[evaluate] generating {len(user_files)} samples from pretrained ...")
pre_feats = _generate_features(
pretrained, user_files, device, args.temperature, args.top_p, args.seed
)
print(f"[evaluate] generating {len(user_files)} samples from finetuned ...")
ft_feats = _generate_features(
finetuned, user_files, device, args.temperature, args.top_p, args.seed
)
# Build named groups for comparison
named_groups: dict[str, list[dict]] = {}
if mcgill_feats:
named_groups["McGill (pretrain corpus)"] = mcgill_feats
named_groups["user corpus"] = user_feats
named_groups["pretrained output"] = pre_feats
named_groups["finetuned output"] = ft_feats
distributions = compare_distributions(named_groups)
dist_path = args.output_dir / "distributions.png"
plot_comparison(
distributions,
dist_path,
title="Chord distribution: corpus vs model outputs",
)
print(f"[evaluate] saved -> {dist_path}")
# ------------------------------------------------------------------
# Qualitative examples
# ------------------------------------------------------------------
examples_dir = args.output_dir / "examples"
examples_dir.mkdir(parents=True, exist_ok=True)
# Use val .chord files as conditions for qualitative examples.
# Fall back to first n files from user corpus if val dir not found.
val_chord_dir = args.user_dir # user .chord files are the reference
example_files = user_files[: args.n_examples]
if not example_files:
print("[evaluate] no user files found — skipping qualitative examples")
else:
print(f"[evaluate] generating {len(example_files)} paired qualitative examples ...")
for ex_i, ref_path in enumerate(example_files):
ref = parse_chord_file(ref_path)
mode = ref.key.split("_")[-1]
stem = ref_path.stem
ex_seed = args.seed + 1000 + ex_i # separate seed range from distribution samples
for model_name, model in [("pretrained", pretrained), ("finetuned", finetuned)]:
try:
gen = generate_period(
model=model,
mode=mode,
time=ref.time,
subdivision=ref.subdivision,
style=ref.style,
function=ref.function,
key=ref.key,
temperature=args.temperature,
top_p=args.top_p,
seed=ex_seed,
)
gen = replace(gen, title=f"Generated — {stem} ({model_name})")
chord_out = examples_dir / f"{model_name}_{ex_i + 1:02d}_{stem}.chord"
write_chord_file(gen, chord_out)
midi_out = chord_out.with_suffix(".mid")
write_chord_file(gen, chord_out) # ensure file exists for midi export
chord_file_to_midi(chord_out, midi_out, tempo=args.tempo)
print(f"[evaluate] {chord_out.name}")
print(f"[evaluate] {midi_out.name}")
print(f"[evaluate] {len(gen.bars)} bars: "
+ " ".join(" ".join(b) for b in gen.bars))
print()
except Exception as exc:
log.warning("example %d %s failed: %s", ex_i + 1, model_name, exc)
print(f"\n[evaluate] done. Outputs in {args.output_dir}/")
if __name__ == "__main__":
main()
+355
View File
@@ -0,0 +1,355 @@
"""Evaluation utilities: perplexity, corpus stats, and distribution plots.
Public API:
compute_perplexity(model, dataloader, device) -> float
extract_features(period: ChordPeriod) -> dict
extract_features_from_tokens(tokens: list[int]) -> dict
compare_distributions(named_groups: dict[str, list[dict]]) -> dict
plot_comparison(distributions: dict, output_path: Path, title: str) -> None
"""
from __future__ import annotations
import logging
import math
from collections import Counter
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.chord_parser import parse_chord_symbol
from src.model import ChordTransformer
from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
_ROOT_START: int = TOKEN_TO_ID["ROOT_C"]
_ROOT_END: int = TOKEN_TO_ID["ROOT_B"]
_QUAL_START: int = TOKEN_TO_ID["QUAL_maj"]
_QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"]
_EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"]
_BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"]
_BASS_START: int = TOKEN_TO_ID["BASS_root"]
_BASS_END: int = TOKEN_TO_ID["BASS_B"]
_NOTE_SEMITONES: dict[str, int] = {
"C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5,
"F#": 6, "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11,
}
_INTERVAL_LABELS = ["P1", "m2", "M2", "m3", "M3", "P4", "TT", "P5", "m6", "M6", "m7", "M7"]
def _qual_from_tok(tok: str) -> str:
"""'QUAL_m_add9''m(add9)', 'QUAL_maj7''maj7', etc."""
suffix = tok[5:] # strip "QUAL_"
return "m(add9)" if suffix == "m_add9" else suffix
# ---------------------------------------------------------------------------
# Perplexity
# ---------------------------------------------------------------------------
def compute_perplexity(
model: ChordTransformer,
dataloader: DataLoader,
device: torch.device,
) -> float:
"""Compute token-level perplexity of *model* on *dataloader*.
Args:
model: ChordTransformer (set to eval mode internally).
dataloader: Yields padded token tensors of shape [batch, seq_len].
device: Compute device.
Returns:
Perplexity (>= 1.0).
"""
criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID, reduction="sum")
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
for batch in dataloader:
batch = batch.to(device)
input_ids = batch[:, :-1]
targets = batch[:, 1:]
attn_mask = (input_ids != _PAD_ID).long()
logits = model(input_ids, attention_mask=attn_mask)
loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
total_loss += loss.item()
total_tokens += int((targets != _PAD_ID).sum())
if total_tokens == 0:
raise ValueError("dataloader produced no non-PAD tokens")
return math.exp(min(total_loss / total_tokens, 20.0))
# ---------------------------------------------------------------------------
# Feature extraction — from ChordPeriod
# ---------------------------------------------------------------------------
def extract_features(period: ChordPeriod) -> dict[str, Any]:
"""Extract harmonic statistics from one ChordPeriod.
Args:
period: A parsed ChordPeriod (any key).
Returns:
Dict with keys:
qualities list[str] canonical quality string per chord event
has_extension list[bool] True when extension != 'none'
is_inverted list[bool] True when bass != root
root_intervals list[int] semitone intervals (011) between consecutive roots
"""
qualities: list[str] = []
has_ext: list[bool] = []
is_inv: list[bool] = []
roots_semi: list[int] = []
for bar in period.bars:
for pos in bar:
if pos in (".", "NC", "?"):
continue
t = parse_chord_symbol(pos)
qualities.append(t.quality)
has_ext.append(t.extension != "none")
is_inv.append(t.bass not in ("root", t.root))
if t.root in _NOTE_SEMITONES:
roots_semi.append(_NOTE_SEMITONES[t.root])
return {
"qualities": qualities,
"has_extension": has_ext,
"is_inverted": is_inv,
"root_intervals": [
(roots_semi[i] - roots_semi[i - 1]) % 12
for i in range(1, len(roots_semi))
],
}
# ---------------------------------------------------------------------------
# Feature extraction — from raw token IDs (for large processed corpora)
# ---------------------------------------------------------------------------
def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]:
"""Extract harmonic statistics from a raw token ID sequence.
Scans for ROOT → QUAL → EXT → BASS groups; all other tokens are skipped.
Args:
tokens: Token ID list as produced by tokenize_period.
Returns:
Same structure as extract_features.
"""
qualities: list[str] = []
has_ext: list[bool] = []
is_inv: list[bool] = []
roots_semi: list[int] = []
i = 0
n = len(tokens)
while i < n:
tok = tokens[i]
if _ROOT_START <= tok <= _ROOT_END and i + 3 < n:
qual_tok = tokens[i + 1]
ext_tok = tokens[i + 2]
bass_tok = tokens[i + 3]
if _QUAL_START <= qual_tok <= _QUAL_END:
root_name = VOCAB[tok][5:] # "ROOT_F#" → "F#"
qualities.append(_qual_from_tok(VOCAB[qual_tok]))
has_ext.append(ext_tok != _EXT_NONE_ID)
is_inv.append(bass_tok != _BASS_ROOT_ID)
if root_name in _NOTE_SEMITONES:
roots_semi.append(_NOTE_SEMITONES[root_name])
i += 4
continue
i += 1
return {
"qualities": qualities,
"has_extension": has_ext,
"is_inverted": is_inv,
"root_intervals": [
(roots_semi[i] - roots_semi[i - 1]) % 12
for i in range(1, len(roots_semi))
],
}
# ---------------------------------------------------------------------------
# Distribution comparison
# ---------------------------------------------------------------------------
def compare_distributions(
named_groups: dict[str, list[dict[str, Any]]],
) -> dict[str, dict[str, Any]]:
"""Aggregate per-period feature lists into per-group distribution counters.
Args:
named_groups: {label: [feature_dict, ...], ...}
Each feature_dict is the output of extract_features or
extract_features_from_tokens.
Returns:
{label: {"quality_counter": Counter, "ext_counts": dict,
"inv_counts": dict, "interval_counter": Counter}, ...}
"""
result: dict[str, dict[str, Any]] = {}
for label, feats in named_groups.items():
qc: Counter[str] = Counter()
ext: dict[str, int] = {"none": 0, "has_ext": 0}
inv: dict[str, int] = {"root_pos": 0, "inverted": 0}
ivc: Counter[int] = Counter()
for f in feats:
for q in f["qualities"]:
qc[q] += 1
for h in f["has_extension"]:
ext["has_ext" if h else "none"] += 1
for is_i in f["is_inverted"]:
inv["inverted" if is_i else "root_pos"] += 1
for iv in f["root_intervals"]:
ivc[iv] += 1
result[label] = {
"quality_counter": qc,
"ext_counts": ext,
"inv_counts": inv,
"interval_counter": ivc,
}
return result
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_comparison(
distributions: dict[str, dict[str, Any]],
output_path: Path,
title: str = "",
) -> None:
"""Save a 2×2 distribution-comparison figure.
Panels:
[0,0] Chord quality distribution (horizontal bars, normalised to %)
[0,1] Root motion intervals (011 semitones)
[1,0] Extension presence (no ext / has extension)
[1,1] Inversion frequency (root position / inverted)
Args:
distributions: Output of compare_distributions.
output_path: Destination PNG path.
title: Optional figure suptitle.
"""
labels = list(distributions.keys())
n_groups = len(labels)
palette = [plt.cm.tab10.colors[i % 10] for i in range(n_groups)] # type: ignore[attr-defined]
# All quality strings that appear in at least one group, sorted by
# descending total frequency across all groups.
all_quals: list[str] = []
_seen: set[str] = set()
for dist in distributions.values():
for q in dist["quality_counter"]:
if q not in _seen:
all_quals.append(q)
_seen.add(q)
all_quals.sort(
key=lambda q: -sum(d["quality_counter"].get(q, 0) for d in distributions.values())
)
fig, axes = plt.subplots(2, 2, figsize=(14, 11))
if title:
fig.suptitle(title, fontsize=13)
# ---------- [0,0] Quality distribution ----------
ax = axes[0, 0]
n_q = len(all_quals)
bar_h = 0.8 / n_groups
y = np.arange(n_q)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["quality_counter"].values()) or 1
vals = [dist["quality_counter"].get(q, 0) / total * 100 for q in all_quals]
ax.barh(
y + (gi - n_groups / 2 + 0.5) * bar_h, vals,
height=bar_h, label=label, color=palette[gi], alpha=0.85,
)
ax.set_yticks(y)
ax.set_yticklabels(all_quals, fontsize=8)
ax.invert_yaxis()
ax.set_xlabel("% of chord events")
ax.set_title("Chord quality distribution")
ax.legend(fontsize=8)
ax.grid(axis="x", alpha=0.3)
# ---------- [0,1] Root motion intervals ----------
ax = axes[0, 1]
bar_w = 0.8 / n_groups
x = np.arange(12)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["interval_counter"].values()) or 1
vals = [dist["interval_counter"].get(iv, 0) / total * 100 for iv in range(12)]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * bar_w, vals,
width=bar_w, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9)
ax.set_ylabel("% of root motions")
ax.set_title("Root motion intervals")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
# ---------- [1,0] Extension presence ----------
ax = axes[1, 0]
ext_keys = ["none", "has_ext"]
ext_xlabels = ["no extension", "has extension"]
x = np.arange(2)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["ext_counts"].values()) or 1
vals = [dist["ext_counts"].get(k, 0) / total * 100 for k in ext_keys]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(ext_xlabels)
ax.set_ylabel("% of chord events")
ax.set_title("Extension presence")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
# ---------- [1,1] Inversion frequency ----------
ax = axes[1, 1]
inv_keys = ["root_pos", "inverted"]
inv_xlabels = ["root position", "inverted"]
x = np.arange(2)
for gi, (label, dist) in enumerate(distributions.items()):
total = sum(dist["inv_counts"].values()) or 1
vals = [dist["inv_counts"].get(k, 0) / total * 100 for k in inv_keys]
ax.bar(
x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals,
width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85,
)
ax.set_xticks(x)
ax.set_xticklabels(inv_xlabels)
ax.set_ylabel("% of chord events")
ax.set_title("Inversion frequency")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close(fig)
log.info("saved distribution plot -> %s", output_path)