feat: add dataset, prepare_data pipeline and fix McGill converter

- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout
  split, logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real
  McGill v2 format (2-column annotation, each bar in its own pipe group,
  interval bass notation e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-19 18:09:46 +03:00
parent ea32bf43b2
commit 84ba7b4743
7 changed files with 876 additions and 314 deletions
+3
View File
@@ -35,6 +35,9 @@ checkpoints/*.ckpt
# Processed data (reproducible from source) # Processed data (reproducible from source)
data/processed/*.pt data/processed/*.pt
data/processed/*.pkl data/processed/*.pkl
data/processed/train/
data/processed/val/
data/processed/holdout/
# External corpora (download separately; too large for git) # External corpora (download separately; too large for git)
data/raw_external/ data/raw_external/
+222
View File
@@ -0,0 +1,222 @@
"""Tokenize .chord files into .pt tensors for model training.
Usage:
python scripts/prepare_data.py --input-dir data/raw_user \\
--output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42]
Arguments:
--input-dir Root directory to search recursively for .chord files.
--output-dir Output directory. Subdirs train/, val/, holdout/ are created.
--split-ratios Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1.
--seed Random seed for reproducible shuffling. Default: 42.
--log-level Logging verbosity. Default: INFO.
Files found under any "holdout" directory within --input-dir are written to
<output-dir>/holdout/ and never participate in the train/val split.
"""
from __future__ import annotations
import argparse
import logging
import random
import sys
from collections import Counter
from pathlib import Path
import torch
# Allow running as a script from the project root without installing the package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.tokenizer import parse_chord_file, tokenize_period # noqa: E402
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_holdout(path: Path, input_dir: Path) -> bool:
"""True when the path lives under a 'holdout' sub-directory of input_dir."""
try:
rel = path.relative_to(input_dir)
except ValueError:
return False
return "holdout" in rel.parts
def _parse_ratios(s: str) -> tuple[float, float]:
parts = s.split("/")
if len(parts) != 2:
raise argparse.ArgumentTypeError(
f"split-ratios must be TRAIN/VAL format, got {s!r}"
)
try:
train_r, val_r = float(parts[0]), float(parts[1])
except ValueError:
raise argparse.ArgumentTypeError(
f"split-ratios values must be floats, got {s!r}"
)
total = train_r + val_r
if abs(total - 1.0) > 1e-6:
raise argparse.ArgumentTypeError(
f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}"
)
return train_r, val_r
def _process_file(path: Path) -> dict | None:
"""Parse and tokenize one .chord file. Returns None on any error."""
try:
period = parse_chord_file(path)
ids = tokenize_period(period)
tokens = torch.tensor(ids, dtype=torch.long)
meta = {
"title": period.title,
"key": period.key,
"style": period.style,
"function": period.function,
"time": period.time,
"source_file": str(path),
"n_tokens": len(ids),
}
return {"tokens": tokens, "meta": meta}
except Exception as exc:
log.warning("Skipping %s: %s", path, exc)
return None
def _save(data: dict, out_dir: Path, stem: str) -> None:
out_path = out_dir / f"{stem}.pt"
if out_path.exists():
log.warning("Overwriting existing output file: %s", out_path)
torch.save(data, out_path)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(
description="Tokenize .chord files into .pt tensors for model training.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--input-dir", required=True, type=Path,
help="Root directory containing .chord files (searched recursively).",
)
parser.add_argument(
"--output-dir", required=True, type=Path,
help="Output directory; train/, val/, holdout/ subdirs are created.",
)
parser.add_argument(
"--split-ratios", default="0.9/0.1",
help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.",
)
parser.add_argument(
"--seed", type=int, default=42,
help="Random seed for reproducible shuffling. Default: 42.",
)
parser.add_argument(
"--log-level", default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging verbosity. Default: INFO.",
)
args = parser.parse_args(argv)
logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s")
train_ratio, _val_ratio = _parse_ratios(args.split_ratios)
input_dir: Path = args.input_dir.resolve()
output_dir: Path = args.output_dir.resolve()
if not input_dir.exists():
log.error("Input directory does not exist: %s", input_dir)
sys.exit(1)
for subdir in ("train", "val", "holdout"):
(output_dir / subdir).mkdir(parents=True, exist_ok=True)
all_files = sorted(input_dir.rglob("*.chord"))
if not all_files:
log.warning("No .chord files found in %s", input_dir)
return
holdout_files = [f for f in all_files if _is_holdout(f, input_dir)]
regular_files = [f for f in all_files if not _is_holdout(f, input_dir)]
log.info(
"Found %d .chord files total (%d holdout, %d regular)",
len(all_files), len(holdout_files), len(regular_files),
)
# --- Holdout ---
holdout_records: list[dict] = []
for path in holdout_files:
data = _process_file(path)
if data is not None:
holdout_records.append(data)
_save(data, output_dir / "holdout", path.stem)
# --- Train / val split ---
random.seed(args.seed)
shuffled = list(regular_files)
random.shuffle(shuffled)
n_train = round(len(shuffled) * train_ratio)
train_paths = shuffled[:n_train]
val_paths = shuffled[n_train:]
train_records: list[dict] = []
for path in train_paths:
data = _process_file(path)
if data is not None:
train_records.append(data)
_save(data, output_dir / "train", path.stem)
val_records: list[dict] = []
for path in val_paths:
data = _process_file(path)
if data is not None:
val_records.append(data)
_save(data, output_dir / "val", path.stem)
# --- Stats ---
all_records = train_records + val_records + holdout_records
if not all_records:
log.warning("No files were successfully processed.")
return
token_lengths = [r["meta"]["n_tokens"] for r in all_records]
style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records)
function_counts: Counter[str] = Counter(r["meta"]["function"] for r in all_records)
log.info("--- Processing summary ---")
log.info("Total processed: %d (train=%d, val=%d, holdout=%d)",
len(all_records), len(train_records), len(val_records), len(holdout_records))
skipped = len(all_files) - len(all_records)
if skipped:
log.warning("Skipped due to errors: %d", skipped)
log.info("Token lengths: mean=%.1f, max=%d",
sum(token_lengths) / len(token_lengths), max(token_lengths))
log.info("Style distribution:")
for style, count in sorted(style_counts.items()):
log.info(" %-16s %d", style, count)
log.info("Function distribution:")
for func, count in sorted(function_counts.items()):
log.info(" %-16s %d", func, count)
if __name__ == "__main__":
main()
+52
View File
@@ -0,0 +1,52 @@
"""PyTorch Dataset for tokenized .chord period files.
Public API:
ChordDataset — Dataset that loads pre-tokenized .pt files from a directory.
"""
from __future__ import annotations
import logging
from pathlib import Path
import torch
from torch.utils.data import Dataset
from src.tokenizer import TOKEN_TO_ID
log = logging.getLogger(__name__)
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
class ChordDataset(Dataset):
"""Dataset over a directory of tokenized .pt period files.
Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``.
``__getitem__`` returns a fixed-length LongTensor: the token sequence is
truncated to *max_length* if too long, or right-padded with <PAD> if short.
Args:
data_dir: Directory containing .pt files (non-recursive).
max_length: Fixed output sequence length. Default 512.
"""
def __init__(self, data_dir: Path, max_length: int = 512) -> None:
self._max_length = max_length
self._files: list[Path] = sorted(Path(data_dir).glob("*.pt"))
if not self._files:
log.warning("ChordDataset: no .pt files found in %s", data_dir)
def __len__(self) -> int:
return len(self._files)
def __getitem__(self, idx: int) -> torch.Tensor:
data = torch.load(self._files[idx], weights_only=True)
tokens: torch.Tensor = data["tokens"]
length = tokens.shape[0]
if length >= self._max_length:
return tokens[: self._max_length]
pad = torch.full((self._max_length - length,), _PAD_ID, dtype=tokens.dtype)
return torch.cat([tokens, pad])
+260 -238
View File
@@ -1,19 +1,24 @@
"""Convert McGill Billboard dataset (salami_chords.txt) to .chord files. """Convert McGill Billboard dataset (salami_chords.txt) to .chord files.
McGill Billboard format: McGill Billboard v2 format:
Each song is a subdirectory (e.g. 0003/, 0004/) containing salami_chords.txt. Each song is a subdirectory (e.g. 0003/, 0004/) containing salami_chords.txt.
The file has a header (# key: value) followed by tab-separated data lines: Header: # key: value lines (artist, title, metre, tonic).
<timestamp>\\t<section_label>\\t<chord> Data: tab-separated pairs <timestamp>\\t<annotation> where annotation is:
- "silence" / "end" — structural boundary (no chord data)
- "[Letter[, function,]] | bar1 | bar2 | ... |"
Each | ... | group is ONE BAR; space-separated tokens inside are
beat-level chord changes within that bar.
- "| ... | xN" — the bar(s) repeated N times
Section labels: 'Z' (silence/boundary), a letter (e.g. 'A', 'B,verse'), or '.' (continuation). Bass notes in Harte may be absolute (e.g. '/E') or scale-degree intervals
Chords: Harte notation (e.g. C:maj, Bb:min7, N for no chord, X for unknown). (e.g. '/5' = perfect fifth, '/b3' = minor third above root).
Public API: Public API:
convert_dataset(dataset_dir, output_dir) -- convert entire dataset directory convert_dataset(dataset_dir, output_dir) -- convert entire dataset
convert_song(song_dir, output_dir) -- convert one song directory convert_song(song_dir, output_dir) -- convert one song directory
CLI: CLI:
python -m src.external_converters.mcgill_to_chord <dataset_dir> [--out <output_dir>] python -m src.external_converters.mcgill_to_chord <dataset_dir> [--out ]
Example: Example:
python -m src.external_converters.mcgill_to_chord data/raw_external/mcgill/ \\ python -m src.external_converters.mcgill_to_chord data/raw_external/mcgill/ \\
@@ -25,14 +30,35 @@ from __future__ import annotations
import argparse import argparse
import logging import logging
import re import re
import statistics
from collections import Counter from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Note tables
# ---------------------------------------------------------------------------
_CHROMATIC: list[str] = [
"C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"
]
_NOTE_INDEX: dict[str, int] = {n: i for i, n in enumerate(_CHROMATIC)}
_FLAT_TO_SHARP: dict[str, str] = {
"Cb": "B", "Db": "C#", "Eb": "D#", "Fb": "E",
"Gb": "F#", "Ab": "G#", "Bb": "A#",
}
_VALID_NOTES: frozenset[str] = frozenset(_CHROMATIC)
# Harte scale-degree intervals: semitones above root
_HARTE_INTERVAL: dict[str, int] = {
"1": 0, "b2": 1, "2": 2, "b3": 3, "3": 4, "4": 5,
"#4": 6, "b5": 6, "5": 7, "#5": 8, "b6": 8, "6": 9,
"b7": 10, "7": 11,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Harte quality → (our_quality, our_extension) # Harte quality → (our_quality, our_extension)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -63,12 +89,11 @@ _HARTE_QUALITY: dict[str, tuple[str, str]] = {
"13": ("7", "13"), "13": ("7", "13"),
"maj13": ("maj7", "13"), "maj13": ("maj7", "13"),
"min13": ("m7", "13"), "min13": ("m7", "13"),
"1": ("maj", "none"), # root only → major "1": ("maj", "none"),
"5": ("maj", "none"), # power chord → major (no 3rd) "5": ("maj", "none"),
"": ("maj", "none"), # bare root "": ("maj", "none"),
} }
# Parenthetical alterations in Harte (e.g. '7(b9)') → our extension token
_HARTE_PAREN_EXT: dict[str, str] = { _HARTE_PAREN_EXT: dict[str, str] = {
"b9": "b9", "b9": "b9",
"#9": "#9", "#9": "#9",
@@ -79,7 +104,6 @@ _HARTE_PAREN_EXT: dict[str, str] = {
"9": "9", "9": "9",
} }
# McGill Billboard section function strings → our function tokens
_FUNCTION_MAP: dict[str, str] = { _FUNCTION_MAP: dict[str, str] = {
"intro": "intro", "intro": "intro",
"verse": "verse", "verse": "verse",
@@ -92,7 +116,7 @@ _FUNCTION_MAP: dict[str, str] = {
"bridge": "bridge", "bridge": "bridge",
"outro": "outro", "outro": "outro",
"coda": "outro", "coda": "outro",
"end": "outro", "ending": "outro",
"interlude": "interlude", "interlude": "interlude",
"instrumental": "interlude", "instrumental": "interlude",
"solo": "interlude", "solo": "interlude",
@@ -101,18 +125,8 @@ _FUNCTION_MAP: dict[str, str] = {
"other": "other", "other": "other",
} }
_VALID_NOTES: frozenset[str] = frozenset(
{"C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"}
)
_FLAT_TO_SHARP: dict[str, str] = {
"Cb": "B", "Db": "C#", "Eb": "D#", "Fb": "E",
"Gb": "F#", "Ab": "G#", "Bb": "A#",
}
_VALID_TIMES: frozenset[str] = frozenset({"4/4", "3/4", "6/8", "2/4", "12/8"}) _VALID_TIMES: frozenset[str] = frozenset({"4/4", "3/4", "6/8", "2/4", "12/8"})
# Quality families used for mode inference
_MAJOR_QUALITIES: frozenset[str] = frozenset( _MAJOR_QUALITIES: frozenset[str] = frozenset(
{"maj", "maj7", "6", "add9", "aug", "sus2", "sus4", "7sus4", "aug7"} {"maj", "maj7", "6", "add9", "aug", "sus2", "sus4", "7sus4", "aug7"}
) )
@@ -120,25 +134,6 @@ _MINOR_QUALITIES: frozenset[str] = frozenset(
{"m", "m7", "mM7", "m6", "m7b5", "dim", "dim7"} {"m", "m7", "mM7", "m6", "m7b5", "dim", "dim7"}
) )
# ---------------------------------------------------------------------------
# Internal data structures
# ---------------------------------------------------------------------------
@dataclass
class _ChordEvent:
start: float
duration: float # seconds
harte: str # Harte chord string: 'N', 'X', 'C:maj', etc.
@dataclass
class _Section:
letter: str # section letter, e.g. 'A', 'B'
function: str # our function token, e.g. 'verse', 'chorus'
events: list[_ChordEvent] = field(default_factory=list)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Note / chord helpers # Note / chord helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -150,35 +145,49 @@ def _normalize_note(raw: str) -> Optional[str]:
return note if note in _VALID_NOTES else None return note if note in _VALID_NOTES else None
def _resolve_harte_bass(root: str, bass_str: str) -> Optional[str]:
"""Convert Harte bass notation to an absolute sharp note name.
Supports absolute notes ('E', 'Bb') and scale-degree intervals ('5', 'b3').
"""
bass_str = bass_str.strip()
if not bass_str:
return None
# Absolute note: starts with AG
if bass_str[0] in "ABCDEFG":
if len(bass_str) >= 2 and bass_str[1] in "#b":
raw, tail = bass_str[:2], bass_str[2:]
else:
raw, tail = bass_str[:1], bass_str[1:]
if tail:
return None
return _normalize_note(raw)
# Scale-degree interval
interval = _HARTE_INTERVAL.get(bass_str)
if interval is None:
return None
root_idx = _NOTE_INDEX[root]
return _CHROMATIC[(root_idx + interval) % 12]
def _harte_to_chord_symbol(harte: str) -> Optional[str]: def _harte_to_chord_symbol(harte: str) -> Optional[str]:
"""Convert a Harte chord string to our .chord format symbol. """Convert a Harte chord string to our .chord symbol.
Args: Args:
harte: Harte notation string, e.g. 'C:maj', 'Bb:min7', 'E:hdim7/G#'. harte: Harte notation, e.g. 'C:maj', 'Bb:min7', 'F:maj/5', 'G:7(b9)'.
Returns: Returns:
Our chord symbol (e.g. 'Cmaj', 'A#m7', 'Em7b5/G#'), or None for Our chord symbol (e.g. 'Cmaj', 'A#m7', 'Fmaj/C'), or None for
N (no chord), X (unknown), or any unparseable input. N (no chord), X (unknown), or any unparseable input.
""" """
harte = harte.strip() harte = harte.strip()
if harte in ("N", "X", ""): if harte in ("N", "X", ""):
return None return None
# Extract slash bass note (rightmost '/') # Extract slash bass (rightmost '/')
bass_note = "root" bass_raw: Optional[str] = None
if "/" in harte: if "/" in harte:
main, bass_raw = harte.rsplit("/", 1) harte, bass_raw = harte.rsplit("/", 1)
if len(bass_raw) >= 2 and bass_raw[1] in "#b":
raw_b, tail = bass_raw[:2], bass_raw[2:]
else:
raw_b, tail = bass_raw[:1], bass_raw[1:]
if tail or not raw_b:
return None
bn = _normalize_note(raw_b)
if bn is None:
return None
bass_note = bn
harte = main
# Split root from quality on first ':' # Split root from quality on first ':'
if ":" in harte: if ":" in harte:
@@ -202,6 +211,14 @@ def _harte_to_chord_symbol(harte: str) -> Optional[str]:
if root is None: if root is None:
return None return None
# Resolve bass now that root is known
bass_note = "root"
if bass_raw is not None:
resolved = _resolve_harte_bass(root, bass_raw)
if resolved is None:
return None
bass_note = resolved
# Parse quality — handle parenthetical alterations like '7(b9)' # Parse quality — handle parenthetical alterations like '7(b9)'
m = re.match(r'^([^(]*)\(([^)]+)\)$', quality_str) m = re.match(r'^([^(]*)\(([^)]+)\)$', quality_str)
if m: if m:
@@ -231,17 +248,15 @@ def _harte_to_chord_symbol(harte: str) -> Optional[str]:
def _parse_salami_file( def _parse_salami_file(
path: Path, path: Path,
) -> tuple[dict[str, str], list[tuple[float, str, str]]]: ) -> tuple[dict[str, str], list[tuple[float, str]]]:
"""Parse a salami_chords.txt file. """Parse a salami_chords.txt file.
Returns: Returns:
(header, events) where header maps lowercase field names to values, (header, data_lines) where header maps lowercase field names to values
and events is a list of (timestamp, label, chord) triples. and data_lines is a list of (timestamp, annotation_string) pairs.
label may be 'Z', a section letter (possibly with ',function'), or '.'.
chord is in Harte notation or '' when the column is absent.
""" """
header: dict[str, str] = {} header: dict[str, str] = {}
events: list[tuple[float, str, str]] = [] data_lines: list[tuple[float, str]] = []
for raw in path.read_text(encoding="utf-8").splitlines(): for raw in path.read_text(encoding="utf-8").splitlines():
line = raw.strip() line = raw.strip()
@@ -253,126 +268,118 @@ def _parse_salami_file(
k, v = content.split(":", 1) k, v = content.split(":", 1)
header[k.strip().lower()] = v.strip() header[k.strip().lower()] = v.strip()
continue continue
parts = line.split("\t") parts = line.split("\t", 1)
if len(parts) < 2: if len(parts) < 2:
continue continue
try: try:
ts = float(parts[0]) ts = float(parts[0])
except ValueError: except ValueError:
continue continue
label = parts[1].strip() data_lines.append((ts, parts[1].strip()))
chord = parts[2].strip() if len(parts) > 2 else ""
events.append((ts, label, chord))
return header, events return header, data_lines
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Section extraction # Annotation line parsing
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _parse_section_label(label: str) -> tuple[str, str]: def _parse_annotation_line(
"""Parse 'A,verse' → (letter='A', function='verse').""" annotation: str,
if "," in label: ) -> tuple[Optional[str], Optional[str], list[str]]:
letter, func_raw = label.split(",", 1) """Parse one annotation string into (section_letter, function, bar_strings).
func = _FUNCTION_MAP.get(func_raw.strip().lower(), "other")
bar_strings is a list of bar content strings, one per bar.
Returns (None, None, []) for silence/end/empty/continuation-only lines.
"""
annotation = annotation.strip()
if not annotation or annotation.lower() in ("silence", "end"):
return None, None, []
if annotation.startswith("->"):
return None, None, []
section_letter: Optional[str] = None
function: Optional[str] = None
first_pipe = annotation.find("|")
if first_pipe == -1:
prefix = annotation
bar_section = ""
else: else:
letter = label prefix = annotation[:first_pipe]
func = "other" bar_section = annotation[first_pipe:]
return letter.strip(), func
# Parse optional section header before first '|'
if prefix.strip():
parts = [p.strip() for p in prefix.rstrip(",").split(",")]
if parts and len(parts[0]) == 1 and parts[0].isupper():
section_letter = parts[0]
if len(parts) > 1 and parts[1]:
function = _FUNCTION_MAP.get(parts[1].lower(), "other")
if not bar_section:
return section_letter, function, []
# Split on '|': odd-indexed parts are bar contents, last part is trailing
raw_parts = bar_section.split("|")
# raw_parts[0] is before first '|' (empty or whitespace)
# raw_parts[-1] is after last '|' (trailing annotation / xN)
trailing = raw_parts[-1].strip() if raw_parts else ""
intermediate = raw_parts[1:-1] # bar contents between pipes
bar_strings = [p.strip() for p in intermediate if p.strip()]
# Handle xN repeat: "x4" in trailing → repeat all bars N times
xN = re.match(r"x(\d+)\b", trailing)
if xN and bar_strings:
bar_strings = bar_strings * int(xN.group(1))
return section_letter, function, bar_strings
def _extract_sections( def _bar_str_to_positions(bar_content: str, n_positions: int) -> Optional[list[str]]:
events: list[tuple[float, str, str]], """Convert bar content string to a fixed-length position list.
) -> list[_Section]:
"""Group raw event triples into _Section objects with _ChordEvent lists."""
sections: list[_Section] = []
current: Optional[_Section] = None
timestamps = [e[0] for e in events]
for i, (ts, label, chord) in enumerate(events): Distributes space-separated chord elements across n_positions slots.
dur = timestamps[i + 1] - ts if i + 1 < len(timestamps) else 0.0 Returns None if any element is an unrecognized chord symbol.
if label in ("Z", ""):
current = None
continue
if label == ".":
if current is not None and chord and dur > 0:
current.events.append(_ChordEvent(ts, dur, chord))
continue
# New section starts here
letter, func = _parse_section_label(label)
current = _Section(letter=letter, function=func)
sections.append(current)
if chord and dur > 0:
current.events.append(_ChordEvent(ts, dur, chord))
return sections
# ---------------------------------------------------------------------------
# Bar quantization
# ---------------------------------------------------------------------------
def _estimate_bar_duration(durations: list[float]) -> float:
"""Estimate duration of one bar in seconds.
Uses the median of non-trivial chord durations as a proxy for one bar.
Clamped to [1.0, 5.0] s (covers ~48240 BPM in 4/4).
Falls back to 2.0 s when fewer than 3 samples.
""" """
valid = [d for d in durations if d > 0.5] # Filter out performance annotations: keep only chord-like tokens
if len(valid) < 3: raw_elements = bar_content.split()
return 2.0 elements = [e for e in raw_elements if _is_chord_element(e)]
return max(1.0, min(5.0, statistics.median(valid)))
positions: list[str] = ["."] * n_positions
n = len(elements)
if n == 0:
return positions
def _expected_positions(time: str, subdivision: int) -> int: for i, elem in enumerate(elements):
"""Number of positions per bar for the given time signature and subdivision.""" pos_idx = i * n_positions // n
num, denom = (int(x) for x in time.split("/")) if elem == ".":
return (num * subdivision) // denom continue # explicit hold — leave slot as "."
elif elem == "N":
if positions[pos_idx] == ".":
def _section_to_bars( positions[pos_idx] = "NC"
section: _Section, elif elem == "X":
bar_duration: float, if positions[pos_idx] == ".":
time: str, positions[pos_idx] = "?"
subdivision: int,
) -> Optional[list[list[str]]]:
"""Convert a section's chord events to a list of bars.
Returns None if any event contains an unrecognized Harte chord symbol;
the caller will skip the section and log a reason.
"""
positions_per_bar = _expected_positions(time, subdivision)
bars: list[list[str]] = []
for event in section.events:
if event.harte == "N":
first_pos = "NC"
elif event.harte == "X":
first_pos = "?"
else: else:
sym = _harte_to_chord_symbol(event.harte) sym = _harte_to_chord_symbol(elem)
if sym is None: if sym is None:
log.debug( log.debug("unrecognized Harte chord %r in bar %r", elem, bar_content)
"unrecognized Harte chord %r in section %s",
event.harte, section.letter,
)
return None return None
first_pos = sym if positions[pos_idx] == ".":
positions[pos_idx] = sym
n_bars = max(1, round(event.duration / bar_duration)) return positions
bars.append([first_pos] + ["."] * (positions_per_bar - 1))
for _ in range(n_bars - 1):
# Hold chord across additional bars
bars.append(["."] * positions_per_bar)
return bars
def _is_chord_element(elem: str) -> bool:
"""True if elem is a chord token, hold marker, or NC/unknown."""
if elem in (".", "N", "X"):
return True
# Chord: starts with a note letter
return bool(elem) and elem[0] in "ABCDEFG"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -380,45 +387,42 @@ def _section_to_bars(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _infer_mode(tonic: str, sections: list[_Section]) -> str: def _infer_mode(tonic: str, harte_chords: list[str]) -> str:
"""Determine 'major' or 'minor' from tonic chord quality distribution. """Determine 'major' or 'minor' from tonic chord quality distribution.
Counts occurrences of the tonic root in major-family vs minor-family Returns 'major' on a tie or when no data is available.
qualities across all sections. Returns 'major' on a tie or no data.
""" """
major_count = 0 major_count = 0
minor_count = 0 minor_count = 0
for section in sections: for harte in harte_chords:
for event in section.events: if not harte or harte in ("N", "X", "."):
if not event.harte or event.harte in ("N", "X"): continue
continue colon = harte.find(":")
# Extract root without a full Harte parse root_part = harte[:colon] if colon != -1 else harte
colon = event.harte.find(":") root_str = root_part.split("/")[0]
root_part = event.harte[:colon] if colon != -1 else event.harte if len(root_str) >= 2 and root_str[1] in "#b":
root_str = root_part.split("/")[0] raw_root = root_str[:2]
if len(root_str) >= 2 and root_str[1] in "#b": else:
raw_root = root_str[:2] raw_root = root_str[:1]
else: if not raw_root:
raw_root = root_str[:1] continue
if not raw_root: root = _normalize_note(raw_root)
continue if root != tonic:
root = _normalize_note(raw_root) continue
if root != tonic: quality_str = harte[colon + 1:] if colon != -1 else ""
continue slash_pos = quality_str.find("/")
# Extract quality if slash_pos != -1:
quality_str = event.harte[colon + 1:] if colon != -1 else "" quality_str = quality_str[:slash_pos]
if "/" in quality_str: base = re.sub(r"\([^)]*\)", "", quality_str).strip()
quality_str = quality_str[: quality_str.index("/")] result = _HARTE_QUALITY.get(base)
base = re.sub(r'\([^)]*\)', "", quality_str).strip() if result is None:
result = _HARTE_QUALITY.get(base) continue
if result is None: our_quality = result[0]
continue if our_quality in _MAJOR_QUALITIES:
our_quality = result[0] major_count += 1
if our_quality in _MAJOR_QUALITIES: elif our_quality in _MINOR_QUALITIES:
major_count += 1 minor_count += 1
elif our_quality in _MINOR_QUALITIES:
minor_count += 1
return "minor" if minor_count > major_count else "major" return "minor" if minor_count > major_count else "major"
@@ -441,6 +445,11 @@ def _parse_metre(metre: str) -> tuple[Optional[str], int]:
return None, 0 return None, 0
def _expected_positions(time: str, subdivision: int) -> int:
num, denom = (int(x) for x in time.split("/"))
return (num * subdivision) // denom
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# File writing # File writing
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -455,7 +464,6 @@ def _write_chord_file(
function: Optional[str], function: Optional[str],
bars: list[list[str]], bars: list[list[str]],
) -> None: ) -> None:
"""Write a harmonic period to a .chord file."""
lines = [ lines = [
f"# title: {title}", f"# title: {title}",
f"# key: {key}", f"# key: {key}",
@@ -463,12 +471,12 @@ def _write_chord_file(
f"# subdivision: {subdivision}", f"# subdivision: {subdivision}",
"# style: other", "# style: other",
] ]
if function: if function and function != "unspecified":
lines.append(f"# function: {function}") lines.append(f"# function: {function}")
lines.append("") # blank line before body lines.append("")
for i in range(0, len(bars), 4): for i in range(0, len(bars), 4):
chunk = bars[i : i + 4] chunk = bars[i: i + 4]
line = " ".join(f"| {' '.join(b)}" for b in chunk) + " |" line = " ".join(f"| {' '.join(b)}" for b in chunk) + " |"
lines.append(line) lines.append(line)
@@ -484,8 +492,8 @@ def convert_song(song_dir: Path, output_dir: Path) -> int:
"""Convert one McGill Billboard song directory to .chord files. """Convert one McGill Billboard song directory to .chord files.
Args: Args:
song_dir: Directory containing salami_chords.txt (e.g. 0003/). song_dir: Directory containing salami_chords.txt.
output_dir: Destination directory for .chord files (created if absent). output_dir: Destination directory for .chord files.
Returns: Returns:
Number of .chord files successfully written. Number of .chord files successfully written.
@@ -496,13 +504,12 @@ def convert_song(song_dir: Path, output_dir: Path) -> int:
return 0 return 0
try: try:
header, raw_events = _parse_salami_file(salami) header, data_lines = _parse_salami_file(salami)
except Exception as exc: except Exception as exc:
log.error("failed to parse %s: %s", salami, exc) log.error("failed to parse %s: %s", salami, exc)
return 0 return 0
song_id = song_dir.name song_id = song_dir.name
time_sig, subdivision = _parse_metre(header.get("metre", "4/4")) time_sig, subdivision = _parse_metre(header.get("metre", "4/4"))
if time_sig is None: if time_sig is None:
log.warning( log.warning(
@@ -513,57 +520,75 @@ def convert_song(song_dir: Path, output_dir: Path) -> int:
tonic_raw = header.get("tonic", "C").strip() tonic_raw = header.get("tonic", "C").strip()
tonic = _normalize_note(tonic_raw) or "C" tonic = _normalize_note(tonic_raw) or "C"
sections = _extract_sections(raw_events) # Collect all Harte tokens for mode inference
if not sections: all_harte: list[str] = []
log.warning("no sections found in %s", salami) for _, annotation in data_lines:
return 0 _, _, bar_groups = _parse_annotation_line(annotation)
for bg in bar_groups:
all_harte.extend(bg.split())
all_durations = [ mode = _infer_mode(tonic, all_harte)
e.duration
for s in sections
for e in s.events
if e.harte not in ("N", "X", "") and e.duration > 0.5
]
bar_duration = _estimate_bar_duration(all_durations)
mode = _infer_mode(tonic, sections)
key = f"{tonic}_{mode}" key = f"{tonic}_{mode}"
artist = header.get("artist", "unknown") artist = header.get("artist", "unknown")
song_title = header.get("title", "unknown") song_title = header.get("title", "unknown")
n_positions = _expected_positions(time_sig, subdivision)
# Group annotation lines into sections
sections: list[tuple[str, list[list[str]]]] = []
current_function = "unspecified"
current_bars: list[list[str]] = []
current_valid = True
for _, annotation in data_lines:
letter, func, bar_groups = _parse_annotation_line(annotation)
if letter is not None:
# New section boundary — save current section if non-empty
if current_bars and current_valid:
sections.append((current_function, current_bars))
current_bars = []
current_valid = True
current_function = func if func is not None else "unspecified"
if not current_valid:
continue
for bg in bar_groups:
positions = _bar_str_to_positions(bg, n_positions)
if positions is None:
current_valid = False
break
current_bars.append(positions)
# Save the final section
if current_bars and current_valid:
sections.append((current_function, current_bars))
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
n_saved = 0 n_saved = 0
skip_reasons: Counter[str] = Counter() skip_reasons: Counter[str] = Counter()
for idx, section in enumerate(sections): for idx, (func, bars) in enumerate(sections):
bars = _section_to_bars(section, bar_duration, time_sig, subdivision)
if bars is None:
skip_reasons["unrecognized_chord"] += 1
continue
n = len(bars) n = len(bars)
if n < 4: if n < 4:
log.debug( log.debug(
"section %s in %s: %d bar(s) < 4, skipping", "section %d in %s: %d bar(s) < 4, skipping", idx, song_id, n
section.letter, song_id, n,
) )
skip_reasons["too_short"] += 1 skip_reasons["too_short"] += 1
continue continue
if n > 16: if n > 16:
log.debug( log.debug(
"section %s in %s: %d bars > 16, skipping", "section %d in %s: %d bars > 16, skipping", idx, song_id, n
section.letter, song_id, n,
) )
skip_reasons["too_long"] += 1 skip_reasons["too_long"] += 1
continue continue
func = section.function
filename = f"mcgill_{song_id}_{idx:02d}_{func}.chord" filename = f"mcgill_{song_id}_{idx:02d}_{func}.chord"
out_path = output_dir / filename out_path = output_dir / filename
period_title = f"{artist} - {song_title} ({section.letter},{func})" period_title = f"{artist} - {song_title} ({func})"
_write_chord_file( _write_chord_file(
out_path, period_title, key, time_sig, subdivision, out_path, period_title, key, time_sig, subdivision, func, bars
func if func != "unspecified" else None, bars,
) )
n_saved += 1 n_saved += 1
log.debug("wrote %s", out_path.name) log.debug("wrote %s", out_path.name)
@@ -581,10 +606,6 @@ def convert_song(song_dir: Path, output_dir: Path) -> int:
def convert_dataset(dataset_dir: Path, output_dir: Path) -> tuple[int, int]: def convert_dataset(dataset_dir: Path, output_dir: Path) -> tuple[int, int]:
"""Convert all song directories in a McGill Billboard dataset. """Convert all song directories in a McGill Billboard dataset.
Args:
dataset_dir: Root directory containing per-song subdirectories.
output_dir: Destination directory for .chord files.
Returns: Returns:
(n_saved, n_empty) where n_empty counts songs that produced no output. (n_saved, n_empty) where n_empty counts songs that produced no output.
""" """
@@ -606,7 +627,7 @@ def convert_dataset(dataset_dir: Path, output_dir: Path) -> tuple[int, int]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# CLI entry point # CLI
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
if __name__ == "__main__": if __name__ == "__main__":
@@ -615,7 +636,8 @@ if __name__ == "__main__":
epilog=( epilog=(
"Example:\n" "Example:\n"
" python -m src.external_converters.mcgill_to_chord " " python -m src.external_converters.mcgill_to_chord "
"data/raw_external/mcgill/ --out data/raw_external/mcgill_converted/" "data/raw_external/mcgill/billboard-2.0-salami_chords/ "
"--out data/raw_external/mcgill_chord/"
), ),
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
@@ -625,9 +647,9 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--out", type=Path, "--out", type=Path,
default=Path("data/raw_external/mcgill_converted"), default=Path("data/raw_external/mcgill_chord"),
metavar="output_dir", metavar="output_dir",
help="destination for .chord files (default: data/raw_external/mcgill_converted/)", help="destination for .chord files (default: data/raw_external/mcgill_chord/)",
) )
parser.add_argument( parser.add_argument(
"--log-level", default="INFO", "--log-level", default="INFO",
+4 -10
View File
@@ -3,13 +3,7 @@
# metre: 4/4 # metre: 4/4
# tonic: C # tonic: C
0.000000 Z 0.000000 silence
4.000000 A,verse C:maj 4.000000 A, verse, | C:maj | F:maj | G:7 | C:maj |
8.000000 . F:maj 20.000000 B, chorus, | F:maj | C:maj | G:7 | C:maj |
12.000000 . G:7 36.000000 silence
16.000000 . C:maj
20.000000 B,chorus F:maj
24.000000 . C:maj
28.000000 . G:7
32.000000 . C:maj
36.000000 Z
+176
View File
@@ -0,0 +1,176 @@
"""Tests for ChordDataset in src/dataset.py."""
from pathlib import Path
import torch
import pytest
from src.dataset import ChordDataset
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, tokenize_period
FIXTURES = Path(__file__).parent / "fixtures"
_PAD_ID = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _write_pt(tmp_path: Path, stem: str, n_tokens: int) -> Path:
"""Write a dummy .pt file with sequential token IDs."""
tokens = torch.arange(n_tokens, dtype=torch.long)
path = tmp_path / f"{stem}.pt"
torch.save({"tokens": tokens, "meta": {"style": "user", "function": "verse"}}, path)
return path
def _write_real_pt(tmp_path: Path, fixture_name: str) -> tuple[Path, int]:
"""Tokenize a real fixture and write its .pt file. Returns (path, n_tokens)."""
period = parse_chord_file(FIXTURES / fixture_name)
ids = tokenize_period(period)
tokens = torch.tensor(ids, dtype=torch.long)
out = tmp_path / f"{fixture_name}.pt"
torch.save({"tokens": tokens, "meta": {"style": period.style}}, out)
return out, len(ids)
# ---------------------------------------------------------------------------
# Length and file discovery
# ---------------------------------------------------------------------------
class TestChordDatasetLength:
def test_empty_directory(self, tmp_path):
ds = ChordDataset(tmp_path)
assert len(ds) == 0
def test_single_file(self, tmp_path):
_write_pt(tmp_path, "a", 10)
assert len(ChordDataset(tmp_path)) == 1
def test_multiple_files(self, tmp_path):
for name in ("a", "b", "c"):
_write_pt(tmp_path, name, 10)
assert len(ChordDataset(tmp_path)) == 3
def test_non_pt_files_ignored(self, tmp_path):
_write_pt(tmp_path, "a", 10)
(tmp_path / "notes.txt").write_text("ignored")
(tmp_path / "model.pth").write_text("ignored")
assert len(ChordDataset(tmp_path)) == 1
# ---------------------------------------------------------------------------
# Output shape
# ---------------------------------------------------------------------------
class TestChordDatasetShape:
def test_returns_tensor(self, tmp_path):
_write_pt(tmp_path, "a", 50)
item = ChordDataset(tmp_path)[0]
assert isinstance(item, torch.Tensor)
def test_dtype_is_long(self, tmp_path):
_write_pt(tmp_path, "a", 50)
item = ChordDataset(tmp_path)[0]
assert item.dtype == torch.long
def test_shape_equals_max_length_when_shorter(self, tmp_path):
_write_pt(tmp_path, "a", 50)
assert ChordDataset(tmp_path, max_length=100)[0].shape[0] == 100
def test_shape_equals_max_length_when_longer(self, tmp_path):
_write_pt(tmp_path, "a", 600)
assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512
def test_shape_equals_max_length_exact(self, tmp_path):
_write_pt(tmp_path, "a", 512)
assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512
def test_custom_max_length(self, tmp_path):
_write_pt(tmp_path, "a", 30)
assert ChordDataset(tmp_path, max_length=64)[0].shape[0] == 64
# ---------------------------------------------------------------------------
# Padding
# ---------------------------------------------------------------------------
class TestChordDatasetPadding:
def test_trailing_tokens_are_pad_id(self, tmp_path):
n = 50
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=100)[0]
assert (item[n:] == _PAD_ID).all()
def test_prefix_matches_original_tokens(self, tmp_path):
n = 50
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=100)[0]
expected = torch.arange(n, dtype=torch.long)
assert (item[:n] == expected).all()
def test_no_padding_when_exact_length(self, tmp_path):
n = 100
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=n)[0]
expected = torch.arange(n, dtype=torch.long)
assert (item == expected).all()
# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------
class TestChordDatasetTruncation:
def test_truncated_length(self, tmp_path):
_write_pt(tmp_path, "a", 600)
item = ChordDataset(tmp_path, max_length=512)[0]
assert item.shape[0] == 512
def test_truncated_prefix_matches_original(self, tmp_path):
_write_pt(tmp_path, "a", 600)
item = ChordDataset(tmp_path, max_length=512)[0]
expected = torch.arange(512, dtype=torch.long)
assert (item == expected).all()
# ---------------------------------------------------------------------------
# Real fixture round-trip
# ---------------------------------------------------------------------------
class TestChordDatasetRealFixture:
def test_bos_at_position_zero(self, tmp_path):
_write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert item[0] == TOKEN_TO_ID["<BOS>"]
def test_eos_at_correct_position(self, tmp_path):
_, n = _write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert item[n - 1] == TOKEN_TO_ID["<EOS>"]
def test_tokens_after_eos_are_pad(self, tmp_path):
_, n = _write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert (item[n:] == _PAD_ID).all()
def test_all_valid_fixture_files_loadable(self, tmp_path):
for name in (
"valid_c_major.chord",
"valid_fsharp_major.chord",
"valid_b_minor.chord",
"valid_gsharp_minor.chord",
):
_write_real_pt(tmp_path, name)
ds = ChordDataset(tmp_path, max_length=512)
assert len(ds) == 4
for i in range(4):
item = ds[i]
assert item.shape[0] == 512
assert item[0] == TOKEN_TO_ID["<BOS>"]
+159 -66
View File
@@ -1,9 +1,9 @@
"""Tests for src/external_converters/mcgill_to_chord.py. """Tests for src/external_converters/mcgill_to_chord.py.
Fixture: tests/fixtures/mcgill_test/0001/salami_chords.txt Fixture: tests/fixtures/mcgill_test/0001/salami_chords.txt
4/4 song in C major, two sections: 4/4 song in C major, two sections in the real McGill v2 2-column format:
Section A (verse): C:maj F:maj G:7 C:maj — 4 chords × 4.0 s each A, verse : | C:maj | F:maj | G:7 | C:maj | (4 bars)
Section B (chorus): F:maj C:maj G:7 C:maj — 4 chords × 4.0 s each B, chorus : | F:maj | C:maj | G:7 | C:maj | (4 bars)
Expected output: 2 .chord files, each with 4 bars, key=C_major, time=4/4. Expected output: 2 .chord files, each with 4 bars, key=C_major, time=4/4.
""" """
@@ -13,13 +13,11 @@ from pathlib import Path
import pytest import pytest
from src.external_converters.mcgill_to_chord import ( from src.external_converters.mcgill_to_chord import (
_estimate_bar_duration, _bar_str_to_positions,
_extract_sections,
_harte_to_chord_symbol, _harte_to_chord_symbol,
_infer_mode, _parse_annotation_line,
_parse_metre, _parse_metre,
_parse_salami_file, _parse_salami_file,
_section_to_bars,
convert_song, convert_song,
) )
from src.tokenizer import parse_chord_file from src.tokenizer import parse_chord_file
@@ -34,17 +32,13 @@ TEST_SONG = FIXTURES / "0001"
class TestHarteConversion: class TestHarteConversion:
"""Unit tests for individual Harte → .chord symbol conversion."""
def test_simple_major(self): def test_simple_major(self):
assert _harte_to_chord_symbol("C:maj") == "Cmaj" assert _harte_to_chord_symbol("C:maj") == "Cmaj"
def test_flat_minor_seventh(self): def test_flat_minor_seventh(self):
# Bb normalises to A#
assert _harte_to_chord_symbol("Bb:min7") == "A#m7" assert _harte_to_chord_symbol("Bb:min7") == "A#m7"
def test_half_diminished(self): def test_half_diminished(self):
# hdim7 = half-diminished 7th = our m7b5
assert _harte_to_chord_symbol("E:hdim7") == "Em7b5" assert _harte_to_chord_symbol("E:hdim7") == "Em7b5"
def test_dominant_seventh(self): def test_dominant_seventh(self):
@@ -62,13 +56,24 @@ class TestHarteConversion:
def test_augmented(self): def test_augmented(self):
assert _harte_to_chord_symbol("C:aug") == "Caug" assert _harte_to_chord_symbol("C:aug") == "Caug"
def test_slash_chord(self): def test_slash_chord_absolute_bass(self):
assert _harte_to_chord_symbol("C:maj/E") == "Cmaj/E" assert _harte_to_chord_symbol("C:maj/E") == "Cmaj/E"
def test_slash_chord_flat_bass(self): def test_slash_chord_flat_bass_normalised(self):
# Flat bass note also normalised to sharp
assert _harte_to_chord_symbol("G:maj/Bb") == "Gmaj/A#" assert _harte_to_chord_symbol("G:maj/Bb") == "Gmaj/A#"
def test_slash_chord_interval_fifth(self):
# '/5' = perfect 5th (7 semitones) above root C → G
assert _harte_to_chord_symbol("C:maj/5") == "Cmaj/G"
def test_slash_chord_interval_b3(self):
# '/b3' = minor 3rd (3 semitones) above root F → Ab = G#
assert _harte_to_chord_symbol("F:min/b3") == "Fm/G#"
def test_slash_chord_interval_3(self):
# '/3' = major 3rd (4 semitones) above root C → E
assert _harte_to_chord_symbol("C:7/3") == "C7/E"
def test_no_chord_returns_none(self): def test_no_chord_returns_none(self):
assert _harte_to_chord_symbol("N") is None assert _harte_to_chord_symbol("N") is None
@@ -79,7 +84,6 @@ class TestHarteConversion:
assert _harte_to_chord_symbol("") is None assert _harte_to_chord_symbol("") is None
def test_extended_dominant_ninth(self): def test_extended_dominant_ninth(self):
# G:9 → dominant 7 + extension 9
assert _harte_to_chord_symbol("G:9") == "G79" assert _harte_to_chord_symbol("G:9") == "G79"
def test_major_ninth(self): def test_major_ninth(self):
@@ -96,14 +100,15 @@ class TestHarteConversion:
def test_output_is_parseable(self): def test_output_is_parseable(self):
from src.chord_parser import parse_chord_symbol from src.chord_parser import parse_chord_symbol
for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7", "C:maj/E"): for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7",
"C:maj/E", "C:maj/5", "F:min/b3"):
sym = _harte_to_chord_symbol(harte) sym = _harte_to_chord_symbol(harte)
assert sym is not None assert sym is not None
parse_chord_symbol(sym) # must not raise parse_chord_symbol(sym)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helper units # Salami file parsing (2-column format)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -115,60 +120,150 @@ class TestParseSalamiFile:
assert header["metre"] == "4/4" assert header["metre"] == "4/4"
assert header["tonic"] == "C" assert header["tonic"] == "C"
def test_events_count(self): def test_data_line_count(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") _, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
# 10 data lines total (including Z lines) # 4 lines: silence, A/verse, B/chorus, silence
assert len(events) == 10 assert len(lines) == 4
def test_first_event_is_silence(self): def test_first_line_is_silence(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") _, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
ts, label, chord = events[0] ts, annotation = lines[0]
assert ts == 0.0 assert ts == 0.0
assert label == "Z" assert annotation == "silence"
def test_returns_two_tuples(self):
_, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
for item in lines:
assert len(item) == 2
class TestExtractSections: # ---------------------------------------------------------------------------
def test_two_sections(self): # Annotation line parsing
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") # ---------------------------------------------------------------------------
sections = _extract_sections(events)
assert len(sections) == 2
def test_section_functions(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
sections = _extract_sections(events)
assert sections[0].function == "verse"
assert sections[1].function == "chorus"
def test_events_per_section(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
sections = _extract_sections(events)
assert len(sections[0].events) == 4
assert len(sections[1].events) == 4
def test_chord_values(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
sections = _extract_sections(events)
hartes = [e.harte for e in sections[0].events]
assert hartes == ["C:maj", "F:maj", "G:7", "C:maj"]
class TestEstimateBarDuration: class TestParseAnnotationLine:
def test_uniform_durations(self): def test_silence_returns_empty(self):
assert _estimate_bar_duration([2.0, 2.0, 2.0, 2.0]) == 2.0 letter, func, bars = _parse_annotation_line("silence")
assert letter is None and func is None and bars == []
def test_mixed_durations(self): def test_end_returns_empty(self):
# Median of [2, 2, 2, 4, 4] = 2 → bar_dur = 2 letter, func, bars = _parse_annotation_line("end")
assert _estimate_bar_duration([2.0, 2.0, 2.0, 4.0, 4.0]) == 2.0 assert letter is None and func is None and bars == []
def test_too_few_samples_returns_default(self): def test_continuation_arrow_returns_empty(self):
assert _estimate_bar_duration([]) == 2.0 letter, func, bars = _parse_annotation_line("->")
assert _estimate_bar_duration([3.0]) == 2.0 assert bars == []
def test_clamp_upper(self): def test_section_letter_extracted(self):
assert _estimate_bar_duration([10.0, 10.0, 10.0]) == 5.0 letter, _, _ = _parse_annotation_line("A, verse, | C:maj | F:maj |")
assert letter == "A"
def test_clamp_lower(self): def test_function_extracted(self):
assert _estimate_bar_duration([0.3, 0.3, 0.3]) == 2.0 # all < 0.5, falls back _, func, _ = _parse_annotation_line("A, verse, | C:maj | F:maj |")
assert func == "verse"
def test_chorus_function(self):
_, func, _ = _parse_annotation_line("B, chorus, | F:maj | C:maj |")
assert func == "chorus"
def test_bar_count(self):
_, _, bars = _parse_annotation_line(
"A, verse, | C:maj | F:maj | G:7 | C:maj |"
)
assert len(bars) == 4
def test_bar_contents(self):
_, _, bars = _parse_annotation_line(
"A, verse, | C:maj | F:maj | G:7 | C:maj |"
)
assert bars == ["C:maj", "F:maj", "G:7", "C:maj"]
def test_continuation_line_no_letter(self):
letter, func, bars = _parse_annotation_line("| C:maj | F:maj |")
assert letter is None
assert func is None
assert bars == ["C:maj", "F:maj"]
def test_repeat_xN(self):
_, _, bars = _parse_annotation_line("| C:maj | x4")
assert bars == ["C:maj"] * 4
def test_trailing_annotation_ignored(self):
_, _, bars = _parse_annotation_line(
"A, intro, | Ab:maj | Db:maj | Ab:maj | G:7 |, (synth)"
)
assert len(bars) == 4
assert bars[0] == "Ab:maj"
def test_multi_chord_bar_preserved(self):
_, _, bars = _parse_annotation_line("| G:hdim7 C:7 | F:min |")
assert bars[0] == "G:hdim7 C:7"
assert bars[1] == "F:min"
# ---------------------------------------------------------------------------
# Bar string to positions
# ---------------------------------------------------------------------------
class TestBarStrToPositions:
def test_single_chord_fills_position_zero(self):
pos = _bar_str_to_positions("C:maj", 4)
assert pos[0] == "Cmaj"
def test_single_chord_rest_are_holds(self):
pos = _bar_str_to_positions("C:maj", 4)
assert pos[1:] == [".", ".", "."]
def test_two_chords_distributed(self):
pos = _bar_str_to_positions("C:maj D:min", 4)
assert pos[0] == "Cmaj"
assert pos[2] == "Dm"
assert pos[1] == "."
assert pos[3] == "."
def test_four_chords_direct_map(self):
# Harte notation: 4 elements → 4 positions, direct 1-to-1 mapping
pos = _bar_str_to_positions("C:maj A:min F:maj G:7", 4)
assert pos == ["Cmaj", "Am", "Fmaj", "G7"]
def test_explicit_hold_tokens(self):
pos = _bar_str_to_positions("C:maj . F:maj .", 4)
assert pos == ["Cmaj", ".", "Fmaj", "."]
def test_nc_mapped(self):
pos = _bar_str_to_positions("N", 4)
assert pos[0] == "NC"
def test_unknown_mapped(self):
pos = _bar_str_to_positions("X", 4)
assert pos[0] == "?"
def test_unrecognized_returns_none(self):
# Starts with a note letter so passes filter, but quality is unknown
assert _bar_str_to_positions("C:xyz", 4) is None
def test_performance_annotation_filtered(self):
# "(voice" is not a chord — should be ignored
pos = _bar_str_to_positions("C:maj (voice", 4)
assert pos is not None
assert pos[0] == "Cmaj"
def test_result_length(self):
for n in (3, 4, 6):
pos = _bar_str_to_positions("C:maj", n)
assert len(pos) == n
def test_interval_bass_resolved(self):
# C:maj/5 → Cmaj/G
pos = _bar_str_to_positions("C:maj/5", 4)
assert pos[0] == "Cmaj/G"
# ---------------------------------------------------------------------------
# Metre parsing
# ---------------------------------------------------------------------------
class TestParseMetre: class TestParseMetre:
@@ -196,8 +291,6 @@ class TestParseMetre:
class TestFullConversion: class TestFullConversion:
"""Integration tests: convert_song with fixture produces valid .chord files."""
def test_returns_two_periods(self, tmp_path): def test_returns_two_periods(self, tmp_path):
assert convert_song(TEST_SONG, tmp_path) == 2 assert convert_song(TEST_SONG, tmp_path) == 2
@@ -208,7 +301,7 @@ class TestFullConversion:
def test_output_files_are_parseable(self, tmp_path): def test_output_files_are_parseable(self, tmp_path):
convert_song(TEST_SONG, tmp_path) convert_song(TEST_SONG, tmp_path)
for f in tmp_path.glob("*.chord"): for f in tmp_path.glob("*.chord"):
assert parse_chord_file(f) is not None # must not raise assert parse_chord_file(f) is not None
def test_verse_has_four_bars(self, tmp_path): def test_verse_has_four_bars(self, tmp_path):
convert_song(TEST_SONG, tmp_path) convert_song(TEST_SONG, tmp_path)
@@ -257,7 +350,7 @@ class TestFullConversion:
for bar in p.bars: for bar in p.bars:
first = bar[0] first = bar[0]
if first not in (".", "NC", "?"): if first not in (".", "NC", "?"):
parse_chord_symbol(first) # must not raise parse_chord_symbol(first)
def test_missing_salami_returns_zero(self, tmp_path): def test_missing_salami_returns_zero(self, tmp_path):
empty_song = tmp_path / "empty" empty_song = tmp_path / "empty"