diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100644 index 0000000..ea2c516 --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,352 @@ +"""Evaluate and compare pretrained vs fine-tuned ChordTransformer. + +Usage: + python scripts/evaluate.py \\ + --pretrained checkpoints/pretrained.pt \\ + --finetuned checkpoints/finetuned.pt \\ + [--user-dir data/raw_user/H1K0] \\ + [--eval-dir data/processed/user/val] \\ + [--mcgill-dir data/processed/mcgill/train] \\ + [--output-dir output/eval] \\ + [--n-examples 3] [--mcgill-sample 500] \\ + [--temperature 1.0] [--top-p 0.9] \\ + [--tempo 90] [--seed 42] [--device auto] + +Outputs: + /perplexity.txt perplexity table (pretrained vs finetuned) + /distributions.png 4-panel chord distribution comparison + /examples/ n paired .chord and .mid files per model +""" + +from __future__ import annotations + +import argparse +import logging +import random +import sys +from dataclasses import replace +from pathlib import Path + +import torch +from torch.utils.data import DataLoader + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.dataset import ChordDataset +from src.evaluate import ( + compare_distributions, + compute_perplexity, + extract_features, + extract_features_from_tokens, + plot_comparison, +) +from src.generate import generate_period +from src.midi_export import chord_file_to_midi +from src.model import ChordTransformer +from src.tokenizer import TOKEN_TO_ID, parse_chord_file, write_chord_file + +log = logging.getLogger(__name__) + +_PAD_ID: int = TOKEN_TO_ID[""] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _load_model(path: Path, device: torch.device) -> ChordTransformer: + ckpt = torch.load(path, map_location=device, weights_only=True) + model = ChordTransformer(**ckpt["model_config"]) + model.load_state_dict(ckpt["model_state"]) + model.to(device) + model.eval() + return model + + +def _resolve_device(spec: str) -> torch.device: + if spec == "auto": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(spec) + + +def _make_loader(data_dir: Path, max_seq_len: int, batch_size: int) -> DataLoader: + ds = ChordDataset(data_dir, max_length=max_seq_len) + return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0) + + +def _max_seq_len(model: ChordTransformer) -> int: + return model.max_seq_len + + +def _generate_features( + model: ChordTransformer, + chord_files: list[Path], + device: torch.device, + temperature: float, + top_p: float, + seed: int, +) -> list[dict]: + """Generate one period per chord file (matching its metadata) and extract features.""" + features = [] + model.eval() + for i, path in enumerate(chord_files): + try: + ref = parse_chord_file(path) + except Exception as exc: + log.warning("skipping %s: %s", path.name, exc) + continue + mode = ref.key.split("_")[-1] + try: + period = generate_period( + model=model, + mode=mode, + time=ref.time, + subdivision=ref.subdivision, + style=ref.style, + function=ref.function, + key=ref.key, + temperature=temperature, + top_p=top_p, + seed=seed + i, + ) + features.append(extract_features(period)) + except Exception as exc: + log.warning("generation failed for %s: %s", path.name, exc) + return features + + +def _load_pt_features(pt_files: list[Path]) -> list[dict]: + """Load token sequences from .pt files and extract features.""" + features = [] + for path in pt_files: + try: + data = torch.load(path, weights_only=True) + tokens = data["tokens"].tolist() + features.append(extract_features_from_tokens(tokens)) + except Exception as exc: + log.warning("failed to load %s: %s", path.name, exc) + return features + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main() -> None: + ap = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + ap.add_argument("--pretrained", type=Path, required=True, + help="Pre-trained checkpoint (.pt).") + ap.add_argument("--finetuned", type=Path, required=True, + help="Fine-tuned checkpoint (.pt).") + ap.add_argument("--user-dir", type=Path, default=Path("data/raw_user/H1K0"), + help="Directory with user .chord files (default: data/raw_user/H1K0).") + ap.add_argument("--eval-dir", type=Path, default=None, + help="Processed .pt files for perplexity eval. " + "Defaults to data/processed/user/holdout if non-empty, " + "otherwise data/processed/user/val.") + ap.add_argument("--mcgill-dir", type=Path, + default=Path("data/processed/mcgill/train"), + help="Processed McGill .pt files for distribution stats " + "(default: data/processed/mcgill/train).") + ap.add_argument("--output-dir", type=Path, default=Path("output/eval"), + help="Output directory (default: output/eval).") + ap.add_argument("--n-examples", type=int, default=3, + help="Number of paired qualitative examples to generate (default: 3).") + ap.add_argument("--mcgill-sample", type=int, default=500, + help="Max McGill files for distribution stats (default: 500).") + ap.add_argument("--temperature", type=float, default=1.0, + help="Sampling temperature (default: 1.0).") + ap.add_argument("--top-p", type=float, default=0.9, + help="Nucleus sampling cutoff (default: 0.9).") + ap.add_argument("--tempo", type=int, default=90, + help="MIDI tempo for example files in BPM (default: 90).") + ap.add_argument("--seed", type=int, default=42, + help="Base random seed (default: 42).") + ap.add_argument("--device", default="auto", + help="Compute device: cpu, cuda, or auto (default: auto).") + args = ap.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + ) + + # Validate inputs + for flag, path in [("--pretrained", args.pretrained), ("--finetuned", args.finetuned)]: + if not path.exists(): + print(f"ERROR: {flag} not found: {path}", file=sys.stderr) + sys.exit(1) + + device = _resolve_device(args.device) + args.output_dir.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------ + # Load models + # ------------------------------------------------------------------ + print(f"[evaluate] loading pretrained <- {args.pretrained}") + pretrained = _load_model(args.pretrained, device) + print(f"[evaluate] loading finetuned <- {args.finetuned}") + finetuned = _load_model(args.finetuned, device) + max_seq = _max_seq_len(pretrained) + + # ------------------------------------------------------------------ + # Perplexity + # ------------------------------------------------------------------ + holdout_dir = Path("data/processed/user/holdout") + val_dir = Path("data/processed/user/val") + + if args.eval_dir is not None: + eval_dir = args.eval_dir + eval_label = str(eval_dir) + elif holdout_dir.exists() and any(holdout_dir.glob("*.pt")): + eval_dir = holdout_dir + eval_label = str(holdout_dir) + elif val_dir.exists() and any(val_dir.glob("*.pt")): + eval_dir = val_dir + eval_label = f"{val_dir} [fallback — holdout is empty]" + print(f"[evaluate] WARNING: holdout is empty, using val set for perplexity") + else: + print("[evaluate] WARNING: no eval data found — skipping perplexity") + eval_dir = None + eval_label = "N/A" + + ppl_pretrained: float | None = None + ppl_finetuned: float | None = None + + if eval_dir is not None: + n_eval = len(list(eval_dir.glob("*.pt"))) + print(f"[evaluate] computing perplexity on {n_eval} files in {eval_dir} ...") + loader = _make_loader(eval_dir, max_seq, batch_size=8) + ppl_pretrained = compute_perplexity(pretrained, loader, device) + ppl_finetuned = compute_perplexity(finetuned, loader, device) + print(f"[evaluate] pretrained PPL = {ppl_pretrained:.2f}") + print(f"[evaluate] finetuned PPL = {ppl_finetuned:.2f}") + + # Save perplexity report + ppl_path = args.output_dir / "perplexity.txt" + with open(ppl_path, "w", encoding="utf-8") as fh: + fh.write("=" * 52 + "\n") + fh.write(" PERPLEXITY EVALUATION\n") + fh.write("=" * 52 + "\n") + fh.write(f" Eval set : {eval_label}\n\n") + if ppl_pretrained is not None and ppl_finetuned is not None: + improvement = (ppl_pretrained - ppl_finetuned) / ppl_pretrained * 100 + fh.write(f" pretrained PPL = {ppl_pretrained:8.2f}\n") + fh.write(f" finetuned PPL = {ppl_finetuned:8.2f}\n\n") + fh.write(f" improvement = {improvement:+.1f}%\n") + else: + fh.write(" (no eval data available)\n") + fh.write("=" * 52 + "\n") + print(f"[evaluate] saved -> {ppl_path}") + + # ------------------------------------------------------------------ + # Distribution stats + # ------------------------------------------------------------------ + + # 1. User corpus (ground truth) + user_files = sorted(args.user_dir.glob("*.chord")) if args.user_dir.exists() else [] + if not user_files: + print(f"[evaluate] WARNING: no .chord files found in {args.user_dir}") + user_feats = [extract_features(parse_chord_file(p)) for p in user_files] + print(f"[evaluate] user corpus: {len(user_feats)} periods") + + # 2. McGill sample (optional) + mcgill_feats: list[dict] = [] + if args.mcgill_dir.exists(): + mcgill_files = sorted(args.mcgill_dir.glob("*.pt")) + if len(mcgill_files) > args.mcgill_sample: + rng = random.Random(args.seed) + mcgill_files = rng.sample(mcgill_files, args.mcgill_sample) + mcgill_feats = _load_pt_features(mcgill_files) + print(f"[evaluate] McGill sample: {len(mcgill_feats)} periods") + else: + print(f"[evaluate] McGill dir not found ({args.mcgill_dir}) — skipping") + + # 3. Generated samples (one per user file, matching conditions) + print(f"[evaluate] generating {len(user_files)} samples from pretrained ...") + pre_feats = _generate_features( + pretrained, user_files, device, args.temperature, args.top_p, args.seed + ) + print(f"[evaluate] generating {len(user_files)} samples from finetuned ...") + ft_feats = _generate_features( + finetuned, user_files, device, args.temperature, args.top_p, args.seed + ) + + # Build named groups for comparison + named_groups: dict[str, list[dict]] = {} + if mcgill_feats: + named_groups["McGill (pretrain corpus)"] = mcgill_feats + named_groups["user corpus"] = user_feats + named_groups["pretrained output"] = pre_feats + named_groups["finetuned output"] = ft_feats + + distributions = compare_distributions(named_groups) + + dist_path = args.output_dir / "distributions.png" + plot_comparison( + distributions, + dist_path, + title="Chord distribution: corpus vs model outputs", + ) + print(f"[evaluate] saved -> {dist_path}") + + # ------------------------------------------------------------------ + # Qualitative examples + # ------------------------------------------------------------------ + examples_dir = args.output_dir / "examples" + examples_dir.mkdir(parents=True, exist_ok=True) + + # Use val .chord files as conditions for qualitative examples. + # Fall back to first n files from user corpus if val dir not found. + val_chord_dir = args.user_dir # user .chord files are the reference + example_files = user_files[: args.n_examples] + if not example_files: + print("[evaluate] no user files found — skipping qualitative examples") + else: + print(f"[evaluate] generating {len(example_files)} paired qualitative examples ...") + for ex_i, ref_path in enumerate(example_files): + ref = parse_chord_file(ref_path) + mode = ref.key.split("_")[-1] + stem = ref_path.stem + ex_seed = args.seed + 1000 + ex_i # separate seed range from distribution samples + + for model_name, model in [("pretrained", pretrained), ("finetuned", finetuned)]: + try: + gen = generate_period( + model=model, + mode=mode, + time=ref.time, + subdivision=ref.subdivision, + style=ref.style, + function=ref.function, + key=ref.key, + temperature=args.temperature, + top_p=args.top_p, + seed=ex_seed, + ) + gen = replace(gen, title=f"Generated — {stem} ({model_name})") + + chord_out = examples_dir / f"{model_name}_{ex_i + 1:02d}_{stem}.chord" + write_chord_file(gen, chord_out) + + midi_out = chord_out.with_suffix(".mid") + write_chord_file(gen, chord_out) # ensure file exists for midi export + chord_file_to_midi(chord_out, midi_out, tempo=args.tempo) + + print(f"[evaluate] {chord_out.name}") + print(f"[evaluate] {midi_out.name}") + print(f"[evaluate] {len(gen.bars)} bars: " + + " ".join(" ".join(b) for b in gen.bars)) + print() + except Exception as exc: + log.warning("example %d %s failed: %s", ex_i + 1, model_name, exc) + + print(f"\n[evaluate] done. Outputs in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/src/evaluate.py b/src/evaluate.py new file mode 100644 index 0000000..e30a29b --- /dev/null +++ b/src/evaluate.py @@ -0,0 +1,355 @@ +"""Evaluation utilities: perplexity, corpus stats, and distribution plots. + +Public API: + compute_perplexity(model, dataloader, device) -> float + extract_features(period: ChordPeriod) -> dict + extract_features_from_tokens(tokens: list[int]) -> dict + compare_distributions(named_groups: dict[str, list[dict]]) -> dict + plot_comparison(distributions: dict, output_path: Path, title: str) -> None +""" + +from __future__ import annotations + +import logging +import math +from collections import Counter +from pathlib import Path +from typing import Any + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from src.chord_parser import parse_chord_symbol +from src.model import ChordTransformer +from src.tokenizer import TOKEN_TO_ID, VOCAB, ChordPeriod + +log = logging.getLogger(__name__) + +_PAD_ID: int = TOKEN_TO_ID[""] +_ROOT_START: int = TOKEN_TO_ID["ROOT_C"] +_ROOT_END: int = TOKEN_TO_ID["ROOT_B"] +_QUAL_START: int = TOKEN_TO_ID["QUAL_maj"] +_QUAL_END: int = TOKEN_TO_ID["QUAL_m_add9"] +_EXT_NONE_ID: int = TOKEN_TO_ID["EXT_none"] +_BASS_ROOT_ID: int = TOKEN_TO_ID["BASS_root"] +_BASS_START: int = TOKEN_TO_ID["BASS_root"] +_BASS_END: int = TOKEN_TO_ID["BASS_B"] + +_NOTE_SEMITONES: dict[str, int] = { + "C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5, + "F#": 6, "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11, +} +_INTERVAL_LABELS = ["P1", "m2", "M2", "m3", "M3", "P4", "TT", "P5", "m6", "M6", "m7", "M7"] + + +def _qual_from_tok(tok: str) -> str: + """'QUAL_m_add9' → 'm(add9)', 'QUAL_maj7' → 'maj7', etc.""" + suffix = tok[5:] # strip "QUAL_" + return "m(add9)" if suffix == "m_add9" else suffix + + +# --------------------------------------------------------------------------- +# Perplexity +# --------------------------------------------------------------------------- + +def compute_perplexity( + model: ChordTransformer, + dataloader: DataLoader, + device: torch.device, +) -> float: + """Compute token-level perplexity of *model* on *dataloader*. + + Args: + model: ChordTransformer (set to eval mode internally). + dataloader: Yields padded token tensors of shape [batch, seq_len]. + device: Compute device. + + Returns: + Perplexity (>= 1.0). + """ + criterion = nn.CrossEntropyLoss(ignore_index=_PAD_ID, reduction="sum") + model.eval() + total_loss = 0.0 + total_tokens = 0 + + with torch.no_grad(): + for batch in dataloader: + batch = batch.to(device) + input_ids = batch[:, :-1] + targets = batch[:, 1:] + attn_mask = (input_ids != _PAD_ID).long() + logits = model(input_ids, attention_mask=attn_mask) + loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) + total_loss += loss.item() + total_tokens += int((targets != _PAD_ID).sum()) + + if total_tokens == 0: + raise ValueError("dataloader produced no non-PAD tokens") + return math.exp(min(total_loss / total_tokens, 20.0)) + + +# --------------------------------------------------------------------------- +# Feature extraction — from ChordPeriod +# --------------------------------------------------------------------------- + +def extract_features(period: ChordPeriod) -> dict[str, Any]: + """Extract harmonic statistics from one ChordPeriod. + + Args: + period: A parsed ChordPeriod (any key). + + Returns: + Dict with keys: + qualities list[str] canonical quality string per chord event + has_extension list[bool] True when extension != 'none' + is_inverted list[bool] True when bass != root + root_intervals list[int] semitone intervals (0–11) between consecutive roots + """ + qualities: list[str] = [] + has_ext: list[bool] = [] + is_inv: list[bool] = [] + roots_semi: list[int] = [] + + for bar in period.bars: + for pos in bar: + if pos in (".", "NC", "?"): + continue + t = parse_chord_symbol(pos) + qualities.append(t.quality) + has_ext.append(t.extension != "none") + is_inv.append(t.bass not in ("root", t.root)) + if t.root in _NOTE_SEMITONES: + roots_semi.append(_NOTE_SEMITONES[t.root]) + + return { + "qualities": qualities, + "has_extension": has_ext, + "is_inverted": is_inv, + "root_intervals": [ + (roots_semi[i] - roots_semi[i - 1]) % 12 + for i in range(1, len(roots_semi)) + ], + } + + +# --------------------------------------------------------------------------- +# Feature extraction — from raw token IDs (for large processed corpora) +# --------------------------------------------------------------------------- + +def extract_features_from_tokens(tokens: list[int]) -> dict[str, Any]: + """Extract harmonic statistics from a raw token ID sequence. + + Scans for ROOT → QUAL → EXT → BASS groups; all other tokens are skipped. + + Args: + tokens: Token ID list as produced by tokenize_period. + + Returns: + Same structure as extract_features. + """ + qualities: list[str] = [] + has_ext: list[bool] = [] + is_inv: list[bool] = [] + roots_semi: list[int] = [] + + i = 0 + n = len(tokens) + while i < n: + tok = tokens[i] + if _ROOT_START <= tok <= _ROOT_END and i + 3 < n: + qual_tok = tokens[i + 1] + ext_tok = tokens[i + 2] + bass_tok = tokens[i + 3] + if _QUAL_START <= qual_tok <= _QUAL_END: + root_name = VOCAB[tok][5:] # "ROOT_F#" → "F#" + qualities.append(_qual_from_tok(VOCAB[qual_tok])) + has_ext.append(ext_tok != _EXT_NONE_ID) + is_inv.append(bass_tok != _BASS_ROOT_ID) + if root_name in _NOTE_SEMITONES: + roots_semi.append(_NOTE_SEMITONES[root_name]) + i += 4 + continue + i += 1 + + return { + "qualities": qualities, + "has_extension": has_ext, + "is_inverted": is_inv, + "root_intervals": [ + (roots_semi[i] - roots_semi[i - 1]) % 12 + for i in range(1, len(roots_semi)) + ], + } + + +# --------------------------------------------------------------------------- +# Distribution comparison +# --------------------------------------------------------------------------- + +def compare_distributions( + named_groups: dict[str, list[dict[str, Any]]], +) -> dict[str, dict[str, Any]]: + """Aggregate per-period feature lists into per-group distribution counters. + + Args: + named_groups: {label: [feature_dict, ...], ...} + Each feature_dict is the output of extract_features or + extract_features_from_tokens. + + Returns: + {label: {"quality_counter": Counter, "ext_counts": dict, + "inv_counts": dict, "interval_counter": Counter}, ...} + """ + result: dict[str, dict[str, Any]] = {} + for label, feats in named_groups.items(): + qc: Counter[str] = Counter() + ext: dict[str, int] = {"none": 0, "has_ext": 0} + inv: dict[str, int] = {"root_pos": 0, "inverted": 0} + ivc: Counter[int] = Counter() + for f in feats: + for q in f["qualities"]: + qc[q] += 1 + for h in f["has_extension"]: + ext["has_ext" if h else "none"] += 1 + for is_i in f["is_inverted"]: + inv["inverted" if is_i else "root_pos"] += 1 + for iv in f["root_intervals"]: + ivc[iv] += 1 + result[label] = { + "quality_counter": qc, + "ext_counts": ext, + "inv_counts": inv, + "interval_counter": ivc, + } + return result + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + +def plot_comparison( + distributions: dict[str, dict[str, Any]], + output_path: Path, + title: str = "", +) -> None: + """Save a 2×2 distribution-comparison figure. + + Panels: + [0,0] Chord quality distribution (horizontal bars, normalised to %) + [0,1] Root motion intervals (0–11 semitones) + [1,0] Extension presence (no ext / has extension) + [1,1] Inversion frequency (root position / inverted) + + Args: + distributions: Output of compare_distributions. + output_path: Destination PNG path. + title: Optional figure suptitle. + """ + labels = list(distributions.keys()) + n_groups = len(labels) + palette = [plt.cm.tab10.colors[i % 10] for i in range(n_groups)] # type: ignore[attr-defined] + + # All quality strings that appear in at least one group, sorted by + # descending total frequency across all groups. + all_quals: list[str] = [] + _seen: set[str] = set() + for dist in distributions.values(): + for q in dist["quality_counter"]: + if q not in _seen: + all_quals.append(q) + _seen.add(q) + all_quals.sort( + key=lambda q: -sum(d["quality_counter"].get(q, 0) for d in distributions.values()) + ) + + fig, axes = plt.subplots(2, 2, figsize=(14, 11)) + if title: + fig.suptitle(title, fontsize=13) + + # ---------- [0,0] Quality distribution ---------- + ax = axes[0, 0] + n_q = len(all_quals) + bar_h = 0.8 / n_groups + y = np.arange(n_q) + for gi, (label, dist) in enumerate(distributions.items()): + total = sum(dist["quality_counter"].values()) or 1 + vals = [dist["quality_counter"].get(q, 0) / total * 100 for q in all_quals] + ax.barh( + y + (gi - n_groups / 2 + 0.5) * bar_h, vals, + height=bar_h, label=label, color=palette[gi], alpha=0.85, + ) + ax.set_yticks(y) + ax.set_yticklabels(all_quals, fontsize=8) + ax.invert_yaxis() + ax.set_xlabel("% of chord events") + ax.set_title("Chord quality distribution") + ax.legend(fontsize=8) + ax.grid(axis="x", alpha=0.3) + + # ---------- [0,1] Root motion intervals ---------- + ax = axes[0, 1] + bar_w = 0.8 / n_groups + x = np.arange(12) + for gi, (label, dist) in enumerate(distributions.items()): + total = sum(dist["interval_counter"].values()) or 1 + vals = [dist["interval_counter"].get(iv, 0) / total * 100 for iv in range(12)] + ax.bar( + x + (gi - n_groups / 2 + 0.5) * bar_w, vals, + width=bar_w, label=label, color=palette[gi], alpha=0.85, + ) + ax.set_xticks(x) + ax.set_xticklabels(_INTERVAL_LABELS, rotation=40, ha="right", fontsize=9) + ax.set_ylabel("% of root motions") + ax.set_title("Root motion intervals") + ax.legend(fontsize=8) + ax.grid(axis="y", alpha=0.3) + + # ---------- [1,0] Extension presence ---------- + ax = axes[1, 0] + ext_keys = ["none", "has_ext"] + ext_xlabels = ["no extension", "has extension"] + x = np.arange(2) + for gi, (label, dist) in enumerate(distributions.items()): + total = sum(dist["ext_counts"].values()) or 1 + vals = [dist["ext_counts"].get(k, 0) / total * 100 for k in ext_keys] + ax.bar( + x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals, + width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85, + ) + ax.set_xticks(x) + ax.set_xticklabels(ext_xlabels) + ax.set_ylabel("% of chord events") + ax.set_title("Extension presence") + ax.legend(fontsize=8) + ax.grid(axis="y", alpha=0.3) + + # ---------- [1,1] Inversion frequency ---------- + ax = axes[1, 1] + inv_keys = ["root_pos", "inverted"] + inv_xlabels = ["root position", "inverted"] + x = np.arange(2) + for gi, (label, dist) in enumerate(distributions.items()): + total = sum(dist["inv_counts"].values()) or 1 + vals = [dist["inv_counts"].get(k, 0) / total * 100 for k in inv_keys] + ax.bar( + x + (gi - n_groups / 2 + 0.5) * (0.8 / n_groups), vals, + width=0.8 / n_groups, label=label, color=palette[gi], alpha=0.85, + ) + ax.set_xticks(x) + ax.set_xticklabels(inv_xlabels) + ax.set_ylabel("% of chord events") + ax.set_title("Inversion frequency") + ax.legend(fontsize=8) + ax.grid(axis="y", alpha=0.3) + + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + log.info("saved distribution plot -> %s", output_path)