diff --git a/.gitignore b/.gitignore index 83016e0..bb3bc3a 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,9 @@ checkpoints/*.ckpt # Processed data (reproducible from source) data/processed/*.pt data/processed/*.pkl +data/processed/train/ +data/processed/val/ +data/processed/holdout/ # External corpora (download separately; too large for git) data/raw_external/ diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py new file mode 100644 index 0000000..87136f9 --- /dev/null +++ b/scripts/prepare_data.py @@ -0,0 +1,222 @@ +"""Tokenize .chord files into .pt tensors for model training. + +Usage: + python scripts/prepare_data.py --input-dir data/raw_user \\ + --output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42] + +Arguments: + --input-dir Root directory to search recursively for .chord files. + --output-dir Output directory. Subdirs train/, val/, holdout/ are created. + --split-ratios Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1. + --seed Random seed for reproducible shuffling. Default: 42. + --log-level Logging verbosity. Default: INFO. + +Files found under any "holdout" directory within --input-dir are written to +<output-dir>/holdout/ and never participate in the train/val split. +""" + +from __future__ import annotations + +import argparse +import logging +import random +import sys +from collections import Counter +from pathlib import Path + +import torch + +# Allow running as a script from the project root without installing the package. 
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.tokenizer import parse_chord_file, tokenize_period # noqa: E402 + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _is_holdout(path: Path, input_dir: Path) -> bool: + """True when the path lives under a 'holdout' sub-directory of input_dir.""" + try: + rel = path.relative_to(input_dir) + except ValueError: + return False + return "holdout" in rel.parts + + +def _parse_ratios(s: str) -> tuple[float, float]: + parts = s.split("/") + if len(parts) != 2: + raise argparse.ArgumentTypeError( + f"split-ratios must be TRAIN/VAL format, got {s!r}" + ) + try: + train_r, val_r = float(parts[0]), float(parts[1]) + except ValueError: + raise argparse.ArgumentTypeError( + f"split-ratios values must be floats, got {s!r}" + ) + total = train_r + val_r + if abs(total - 1.0) > 1e-6: + raise argparse.ArgumentTypeError( + f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}" + ) + return train_r, val_r + + +def _process_file(path: Path) -> dict | None: + """Parse and tokenize one .chord file. 
Returns None on any error.""" + try: + period = parse_chord_file(path) + ids = tokenize_period(period) + tokens = torch.tensor(ids, dtype=torch.long) + meta = { + "title": period.title, + "key": period.key, + "style": period.style, + "function": period.function, + "time": period.time, + "source_file": str(path), + "n_tokens": len(ids), + } + return {"tokens": tokens, "meta": meta} + except Exception as exc: + log.warning("Skipping %s: %s", path, exc) + return None + + +def _save(data: dict, out_dir: Path, stem: str) -> None: + out_path = out_dir / f"{stem}.pt" + if out_path.exists(): + log.warning("Overwriting existing output file: %s", out_path) + torch.save(data, out_path) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Tokenize .chord files into .pt tensors for model training.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--input-dir", required=True, type=Path, + help="Root directory containing .chord files (searched recursively).", + ) + parser.add_argument( + "--output-dir", required=True, type=Path, + help="Output directory; train/, val/, holdout/ subdirs are created.", + ) + parser.add_argument( + "--split-ratios", default="0.9/0.1", + help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.", + ) + parser.add_argument( + "--seed", type=int, default=42, + help="Random seed for reproducible shuffling. Default: 42.", + ) + parser.add_argument( + "--log-level", default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging verbosity. 
Default: INFO.", + ) + args = parser.parse_args(argv) + + logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s") + + train_ratio, _val_ratio = _parse_ratios(args.split_ratios) + + input_dir: Path = args.input_dir.resolve() + output_dir: Path = args.output_dir.resolve() + + if not input_dir.exists(): + log.error("Input directory does not exist: %s", input_dir) + sys.exit(1) + + for subdir in ("train", "val", "holdout"): + (output_dir / subdir).mkdir(parents=True, exist_ok=True) + + all_files = sorted(input_dir.rglob("*.chord")) + if not all_files: + log.warning("No .chord files found in %s", input_dir) + return + + holdout_files = [f for f in all_files if _is_holdout(f, input_dir)] + regular_files = [f for f in all_files if not _is_holdout(f, input_dir)] + + log.info( + "Found %d .chord files total (%d holdout, %d regular)", + len(all_files), len(holdout_files), len(regular_files), + ) + + # --- Holdout --- + holdout_records: list[dict] = [] + for path in holdout_files: + data = _process_file(path) + if data is not None: + holdout_records.append(data) + _save(data, output_dir / "holdout", path.stem) + + # --- Train / val split --- + random.seed(args.seed) + shuffled = list(regular_files) + random.shuffle(shuffled) + + n_train = round(len(shuffled) * train_ratio) + train_paths = shuffled[:n_train] + val_paths = shuffled[n_train:] + + train_records: list[dict] = [] + for path in train_paths: + data = _process_file(path) + if data is not None: + train_records.append(data) + _save(data, output_dir / "train", path.stem) + + val_records: list[dict] = [] + for path in val_paths: + data = _process_file(path) + if data is not None: + val_records.append(data) + _save(data, output_dir / "val", path.stem) + + # --- Stats --- + all_records = train_records + val_records + holdout_records + if not all_records: + log.warning("No files were successfully processed.") + return + + token_lengths = [r["meta"]["n_tokens"] for r in all_records] + 
+ style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records) + function_counts: Counter[str] = Counter(r["meta"]["function"] for r in all_records) + + log.info("--- Processing summary ---") + log.info("Total processed: %d (train=%d, val=%d, holdout=%d)", + len(all_records), len(train_records), len(val_records), len(holdout_records)) + + skipped = len(all_files) - len(all_records) + if skipped: + log.warning("Skipped due to errors: %d", skipped) + + log.info("Token lengths: mean=%.1f, max=%d", + sum(token_lengths) / len(token_lengths), max(token_lengths)) + + log.info("Style distribution:") + for style, count in sorted(style_counts.items()): + log.info(" %-16s %d", style, count) + + log.info("Function distribution:") + for func, count in sorted(function_counts.items()): + log.info(" %-16s %d", func, count) + + +if __name__ == "__main__": + main() diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..182e833 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,52 @@ +"""PyTorch Dataset for tokenized .chord period files. + +Public API: + ChordDataset — Dataset that loads pre-tokenized .pt files from a directory. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import torch +from torch.utils.data import Dataset + +from src.tokenizer import TOKEN_TO_ID + +log = logging.getLogger(__name__) + +_PAD_ID: int = TOKEN_TO_ID[""] + + +class ChordDataset(Dataset): + """Dataset over a directory of tokenized .pt period files. + + Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``. + ``__getitem__`` returns a fixed-length LongTensor: the token sequence is + truncated to *max_length* if too long, or right-padded with the pad token if short. + + Args: + data_dir: Directory containing .pt files (non-recursive). + max_length: Fixed output sequence length. Default 512. 
+ """ + + def __init__(self, data_dir: Path, max_length: int = 512) -> None: + self._max_length = max_length + self._files: list[Path] = sorted(Path(data_dir).glob("*.pt")) + if not self._files: + log.warning("ChordDataset: no .pt files found in %s", data_dir) + + def __len__(self) -> int: + return len(self._files) + + def __getitem__(self, idx: int) -> torch.Tensor: + data = torch.load(self._files[idx], weights_only=True) + tokens: torch.Tensor = data["tokens"] + + length = tokens.shape[0] + if length >= self._max_length: + return tokens[: self._max_length] + + pad = torch.full((self._max_length - length,), _PAD_ID, dtype=tokens.dtype) + return torch.cat([tokens, pad]) diff --git a/src/external_converters/mcgill_to_chord.py b/src/external_converters/mcgill_to_chord.py index e2913b7..4be362c 100644 --- a/src/external_converters/mcgill_to_chord.py +++ b/src/external_converters/mcgill_to_chord.py @@ -1,19 +1,24 @@ """Convert McGill Billboard dataset (salami_chords.txt) to .chord files. -McGill Billboard format: +McGill Billboard v2 format: Each song is a subdirectory (e.g. 0003/, 0004/) containing salami_chords.txt. - The file has a header (# key: value) followed by tab-separated data lines: - \\t\\t + Header: # key: value lines (artist, title, metre, tonic). + Data: tab-separated pairs <timestamp>\\t<annotation> where annotation is: + - "silence" / "end" — structural boundary (no chord data) + - "[Letter[, function,]] | bar1 | bar2 | ... |" + Each | ... | group is ONE BAR; space-separated tokens inside are + beat-level chord changes within that bar. + - "| ... | xN" — the bar(s) repeated N times - Section labels: 'Z' (silence/boundary), a letter (e.g. 'A', 'B,verse'), or '.' (continuation). - Chords: Harte notation (e.g. C:maj, Bb:min7, N for no chord, X for unknown). + Bass notes in Harte may be absolute (e.g. '/E') or scale-degree intervals + (e.g. '/5' = perfect fifth, '/b3' = minor third above root). 
Public API: - convert_dataset(dataset_dir, output_dir) -- convert entire dataset directory + convert_dataset(dataset_dir, output_dir) -- convert entire dataset convert_song(song_dir, output_dir) -- convert one song directory CLI: - python -m src.external_converters.mcgill_to_chord [--out ] + python -m src.external_converters.mcgill_to_chord [--out …] Example: python -m src.external_converters.mcgill_to_chord data/raw_external/mcgill/ \\ @@ -25,14 +30,35 @@ from __future__ import annotations import argparse import logging import re -import statistics from collections import Counter -from dataclasses import dataclass, field from pathlib import Path from typing import Optional log = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Note tables +# --------------------------------------------------------------------------- + +_CHROMATIC: list[str] = [ + "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B" +] +_NOTE_INDEX: dict[str, int] = {n: i for i, n in enumerate(_CHROMATIC)} + +_FLAT_TO_SHARP: dict[str, str] = { + "Cb": "B", "Db": "C#", "Eb": "D#", "Fb": "E", + "Gb": "F#", "Ab": "G#", "Bb": "A#", +} + +_VALID_NOTES: frozenset[str] = frozenset(_CHROMATIC) + +# Harte scale-degree intervals: semitones above root +_HARTE_INTERVAL: dict[str, int] = { + "1": 0, "b2": 1, "2": 2, "b3": 3, "3": 4, "4": 5, + "#4": 6, "b5": 6, "5": 7, "#5": 8, "b6": 8, "6": 9, + "b7": 10, "7": 11, +} + # --------------------------------------------------------------------------- # Harte quality → (our_quality, our_extension) # --------------------------------------------------------------------------- @@ -63,12 +89,11 @@ _HARTE_QUALITY: dict[str, tuple[str, str]] = { "13": ("7", "13"), "maj13": ("maj7", "13"), "min13": ("m7", "13"), - "1": ("maj", "none"), # root only → major - "5": ("maj", "none"), # power chord → major (no 3rd) - "": ("maj", "none"), # bare root + "1": ("maj", "none"), + "5": ("maj", "none"), + "": 
("maj", "none"), } -# Parenthetical alterations in Harte (e.g. '7(b9)') → our extension token _HARTE_PAREN_EXT: dict[str, str] = { "b9": "b9", "#9": "#9", @@ -79,7 +104,6 @@ _HARTE_PAREN_EXT: dict[str, str] = { "9": "9", } -# McGill Billboard section function strings → our function tokens _FUNCTION_MAP: dict[str, str] = { "intro": "intro", "verse": "verse", @@ -92,7 +116,7 @@ _FUNCTION_MAP: dict[str, str] = { "bridge": "bridge", "outro": "outro", "coda": "outro", - "end": "outro", + "ending": "outro", "interlude": "interlude", "instrumental": "interlude", "solo": "interlude", @@ -101,18 +125,8 @@ _FUNCTION_MAP: dict[str, str] = { "other": "other", } -_VALID_NOTES: frozenset[str] = frozenset( - {"C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"} -) - -_FLAT_TO_SHARP: dict[str, str] = { - "Cb": "B", "Db": "C#", "Eb": "D#", "Fb": "E", - "Gb": "F#", "Ab": "G#", "Bb": "A#", -} - _VALID_TIMES: frozenset[str] = frozenset({"4/4", "3/4", "6/8", "2/4", "12/8"}) -# Quality families used for mode inference _MAJOR_QUALITIES: frozenset[str] = frozenset( {"maj", "maj7", "6", "add9", "aug", "sus2", "sus4", "7sus4", "aug7"} ) @@ -120,25 +134,6 @@ _MINOR_QUALITIES: frozenset[str] = frozenset( {"m", "m7", "mM7", "m6", "m7b5", "dim", "dim7"} ) -# --------------------------------------------------------------------------- -# Internal data structures -# --------------------------------------------------------------------------- - - -@dataclass -class _ChordEvent: - start: float - duration: float # seconds - harte: str # Harte chord string: 'N', 'X', 'C:maj', etc. - - -@dataclass -class _Section: - letter: str # section letter, e.g. 'A', 'B' - function: str # our function token, e.g. 
'verse', 'chorus' - events: list[_ChordEvent] = field(default_factory=list) - - # --------------------------------------------------------------------------- # Note / chord helpers # --------------------------------------------------------------------------- @@ -150,35 +145,49 @@ def _normalize_note(raw: str) -> Optional[str]: return note if note in _VALID_NOTES else None +def _resolve_harte_bass(root: str, bass_str: str) -> Optional[str]: + """Convert Harte bass notation to an absolute sharp note name. + + Supports absolute notes ('E', 'Bb') and scale-degree intervals ('5', 'b3'). + """ + bass_str = bass_str.strip() + if not bass_str: + return None + # Absolute note: starts with A–G + if bass_str[0] in "ABCDEFG": + if len(bass_str) >= 2 and bass_str[1] in "#b": + raw, tail = bass_str[:2], bass_str[2:] + else: + raw, tail = bass_str[:1], bass_str[1:] + if tail: + return None + return _normalize_note(raw) + # Scale-degree interval + interval = _HARTE_INTERVAL.get(bass_str) + if interval is None: + return None + root_idx = _NOTE_INDEX[root] + return _CHROMATIC[(root_idx + interval) % 12] + + def _harte_to_chord_symbol(harte: str) -> Optional[str]: - """Convert a Harte chord string to our .chord format symbol. + """Convert a Harte chord string to our .chord symbol. Args: - harte: Harte notation string, e.g. 'C:maj', 'Bb:min7', 'E:hdim7/G#'. + harte: Harte notation, e.g. 'C:maj', 'Bb:min7', 'F:maj/5', 'G:7(b9)'. Returns: - Our chord symbol (e.g. 'Cmaj', 'A#m7', 'Em7b5/G#'), or None for + Our chord symbol (e.g. 'Cmaj', 'A#m7', 'Fmaj/C'), or None for N (no chord), X (unknown), or any unparseable input. 
""" harte = harte.strip() if harte in ("N", "X", ""): return None - # Extract slash bass note (rightmost '/') - bass_note = "root" + # Extract slash bass (rightmost '/') + bass_raw: Optional[str] = None if "/" in harte: - main, bass_raw = harte.rsplit("/", 1) - if len(bass_raw) >= 2 and bass_raw[1] in "#b": - raw_b, tail = bass_raw[:2], bass_raw[2:] - else: - raw_b, tail = bass_raw[:1], bass_raw[1:] - if tail or not raw_b: - return None - bn = _normalize_note(raw_b) - if bn is None: - return None - bass_note = bn - harte = main + harte, bass_raw = harte.rsplit("/", 1) # Split root from quality on first ':' if ":" in harte: @@ -202,6 +211,14 @@ def _harte_to_chord_symbol(harte: str) -> Optional[str]: if root is None: return None + # Resolve bass now that root is known + bass_note = "root" + if bass_raw is not None: + resolved = _resolve_harte_bass(root, bass_raw) + if resolved is None: + return None + bass_note = resolved + # Parse quality — handle parenthetical alterations like '7(b9)' m = re.match(r'^([^(]*)\(([^)]+)\)$', quality_str) if m: @@ -231,17 +248,15 @@ def _harte_to_chord_symbol(harte: str) -> Optional[str]: def _parse_salami_file( path: Path, -) -> tuple[dict[str, str], list[tuple[float, str, str]]]: +) -> tuple[dict[str, str], list[tuple[float, str]]]: """Parse a salami_chords.txt file. Returns: - (header, events) where header maps lowercase field names to values, - and events is a list of (timestamp, label, chord) triples. - label may be 'Z', a section letter (possibly with ',function'), or '.'. - chord is in Harte notation or '' when the column is absent. + (header, data_lines) where header maps lowercase field names to values + and data_lines is a list of (timestamp, annotation_string) pairs. 
""" header: dict[str, str] = {} - events: list[tuple[float, str, str]] = [] + data_lines: list[tuple[float, str]] = [] for raw in path.read_text(encoding="utf-8").splitlines(): line = raw.strip() @@ -253,126 +268,118 @@ def _parse_salami_file( k, v = content.split(":", 1) header[k.strip().lower()] = v.strip() continue - parts = line.split("\t") + parts = line.split("\t", 1) if len(parts) < 2: continue try: ts = float(parts[0]) except ValueError: continue - label = parts[1].strip() - chord = parts[2].strip() if len(parts) > 2 else "" - events.append((ts, label, chord)) + data_lines.append((ts, parts[1].strip())) - return header, events + return header, data_lines # --------------------------------------------------------------------------- -# Section extraction +# Annotation line parsing # --------------------------------------------------------------------------- -def _parse_section_label(label: str) -> tuple[str, str]: - """Parse 'A,verse' → (letter='A', function='verse').""" - if "," in label: - letter, func_raw = label.split(",", 1) - func = _FUNCTION_MAP.get(func_raw.strip().lower(), "other") +def _parse_annotation_line( + annotation: str, +) -> tuple[Optional[str], Optional[str], list[str]]: + """Parse one annotation string into (section_letter, function, bar_strings). + + bar_strings is a list of bar content strings, one per bar. + Returns (None, None, []) for silence/end/empty/continuation-only lines. 
+ """ + annotation = annotation.strip() + if not annotation or annotation.lower() in ("silence", "end"): + return None, None, [] + if annotation.startswith("->"): + return None, None, [] + + section_letter: Optional[str] = None + function: Optional[str] = None + + first_pipe = annotation.find("|") + if first_pipe == -1: + prefix = annotation + bar_section = "" else: - letter = label - func = "other" - return letter.strip(), func + prefix = annotation[:first_pipe] + bar_section = annotation[first_pipe:] + + # Parse optional section header before first '|' + if prefix.strip(): + parts = [p.strip() for p in prefix.rstrip(",").split(",")] + if parts and len(parts[0]) == 1 and parts[0].isupper(): + section_letter = parts[0] + if len(parts) > 1 and parts[1]: + function = _FUNCTION_MAP.get(parts[1].lower(), "other") + + if not bar_section: + return section_letter, function, [] + + # Split on '|': odd-indexed parts are bar contents, last part is trailing + raw_parts = bar_section.split("|") + # raw_parts[0] is before first '|' (empty or whitespace) + # raw_parts[-1] is after last '|' (trailing annotation / xN) + trailing = raw_parts[-1].strip() if raw_parts else "" + intermediate = raw_parts[1:-1] # bar contents between pipes + + bar_strings = [p.strip() for p in intermediate if p.strip()] + + # Handle xN repeat: "x4" in trailing → repeat all bars N times + xN = re.match(r"x(\d+)\b", trailing) + if xN and bar_strings: + bar_strings = bar_strings * int(xN.group(1)) + + return section_letter, function, bar_strings -def _extract_sections( - events: list[tuple[float, str, str]], -) -> list[_Section]: - """Group raw event triples into _Section objects with _ChordEvent lists.""" - sections: list[_Section] = [] - current: Optional[_Section] = None - timestamps = [e[0] for e in events] +def _bar_str_to_positions(bar_content: str, n_positions: int) -> Optional[list[str]]: + """Convert bar content string to a fixed-length position list. 
- for i, (ts, label, chord) in enumerate(events): - dur = timestamps[i + 1] - ts if i + 1 < len(timestamps) else 0.0 - - if label in ("Z", ""): - current = None - continue - - if label == ".": - if current is not None and chord and dur > 0: - current.events.append(_ChordEvent(ts, dur, chord)) - continue - - # New section starts here - letter, func = _parse_section_label(label) - current = _Section(letter=letter, function=func) - sections.append(current) - if chord and dur > 0: - current.events.append(_ChordEvent(ts, dur, chord)) - - return sections - - -# --------------------------------------------------------------------------- -# Bar quantization -# --------------------------------------------------------------------------- - - -def _estimate_bar_duration(durations: list[float]) -> float: - """Estimate duration of one bar in seconds. - - Uses the median of non-trivial chord durations as a proxy for one bar. - Clamped to [1.0, 5.0] s (covers ~48–240 BPM in 4/4). - Falls back to 2.0 s when fewer than 3 samples. + Distributes space-separated chord elements across n_positions slots. + Returns None if any element is an unrecognized chord symbol. """ - valid = [d for d in durations if d > 0.5] - if len(valid) < 3: - return 2.0 - return max(1.0, min(5.0, statistics.median(valid))) + # Filter out performance annotations: keep only chord-like tokens + raw_elements = bar_content.split() + elements = [e for e in raw_elements if _is_chord_element(e)] + positions: list[str] = ["."] * n_positions + n = len(elements) + if n == 0: + return positions -def _expected_positions(time: str, subdivision: int) -> int: - """Number of positions per bar for the given time signature and subdivision.""" - num, denom = (int(x) for x in time.split("/")) - return (num * subdivision) // denom - - -def _section_to_bars( - section: _Section, - bar_duration: float, - time: str, - subdivision: int, -) -> Optional[list[list[str]]]: - """Convert a section's chord events to a list of bars. 
- - Returns None if any event contains an unrecognized Harte chord symbol; - the caller will skip the section and log a reason. - """ - positions_per_bar = _expected_positions(time, subdivision) - bars: list[list[str]] = [] - - for event in section.events: - if event.harte == "N": - first_pos = "NC" - elif event.harte == "X": - first_pos = "?" + for i, elem in enumerate(elements): + pos_idx = i * n_positions // n + if elem == ".": + continue # explicit hold — leave slot as "." + elif elem == "N": + if positions[pos_idx] == ".": + positions[pos_idx] = "NC" + elif elem == "X": + if positions[pos_idx] == ".": + positions[pos_idx] = "?" else: - sym = _harte_to_chord_symbol(event.harte) + sym = _harte_to_chord_symbol(elem) if sym is None: - log.debug( - "unrecognized Harte chord %r in section %s", - event.harte, section.letter, - ) + log.debug("unrecognized Harte chord %r in bar %r", elem, bar_content) return None - first_pos = sym + if positions[pos_idx] == ".": + positions[pos_idx] = sym - n_bars = max(1, round(event.duration / bar_duration)) - bars.append([first_pos] + ["."] * (positions_per_bar - 1)) - for _ in range(n_bars - 1): - # Hold chord across additional bars - bars.append(["."] * positions_per_bar) + return positions - return bars + +def _is_chord_element(elem: str) -> bool: + """True if elem is a chord token, hold marker, or NC/unknown.""" + if elem in (".", "N", "X"): + return True + # Chord: starts with a note letter + return bool(elem) and elem[0] in "ABCDEFG" # --------------------------------------------------------------------------- @@ -380,45 +387,42 @@ def _section_to_bars( # --------------------------------------------------------------------------- -def _infer_mode(tonic: str, sections: list[_Section]) -> str: +def _infer_mode(tonic: str, harte_chords: list[str]) -> str: """Determine 'major' or 'minor' from tonic chord quality distribution. - Counts occurrences of the tonic root in major-family vs minor-family - qualities across all sections. 
Returns 'major' on a tie or no data. + Returns 'major' on a tie or when no data is available. """ major_count = 0 minor_count = 0 - for section in sections: - for event in section.events: - if not event.harte or event.harte in ("N", "X"): - continue - # Extract root without a full Harte parse - colon = event.harte.find(":") - root_part = event.harte[:colon] if colon != -1 else event.harte - root_str = root_part.split("/")[0] - if len(root_str) >= 2 and root_str[1] in "#b": - raw_root = root_str[:2] - else: - raw_root = root_str[:1] - if not raw_root: - continue - root = _normalize_note(raw_root) - if root != tonic: - continue - # Extract quality - quality_str = event.harte[colon + 1:] if colon != -1 else "" - if "/" in quality_str: - quality_str = quality_str[: quality_str.index("/")] - base = re.sub(r'\([^)]*\)', "", quality_str).strip() - result = _HARTE_QUALITY.get(base) - if result is None: - continue - our_quality = result[0] - if our_quality in _MAJOR_QUALITIES: - major_count += 1 - elif our_quality in _MINOR_QUALITIES: - minor_count += 1 + for harte in harte_chords: + if not harte or harte in ("N", "X", "."): + continue + colon = harte.find(":") + root_part = harte[:colon] if colon != -1 else harte + root_str = root_part.split("/")[0] + if len(root_str) >= 2 and root_str[1] in "#b": + raw_root = root_str[:2] + else: + raw_root = root_str[:1] + if not raw_root: + continue + root = _normalize_note(raw_root) + if root != tonic: + continue + quality_str = harte[colon + 1:] if colon != -1 else "" + slash_pos = quality_str.find("/") + if slash_pos != -1: + quality_str = quality_str[:slash_pos] + base = re.sub(r"\([^)]*\)", "", quality_str).strip() + result = _HARTE_QUALITY.get(base) + if result is None: + continue + our_quality = result[0] + if our_quality in _MAJOR_QUALITIES: + major_count += 1 + elif our_quality in _MINOR_QUALITIES: + minor_count += 1 return "minor" if minor_count > major_count else "major" @@ -441,6 +445,11 @@ def _parse_metre(metre: str) -> 
tuple[Optional[str], int]: return None, 0 +def _expected_positions(time: str, subdivision: int) -> int: + num, denom = (int(x) for x in time.split("/")) + return (num * subdivision) // denom + + # --------------------------------------------------------------------------- # File writing # --------------------------------------------------------------------------- @@ -455,7 +464,6 @@ def _write_chord_file( function: Optional[str], bars: list[list[str]], ) -> None: - """Write a harmonic period to a .chord file.""" lines = [ f"# title: {title}", f"# key: {key}", @@ -463,12 +471,12 @@ def _write_chord_file( f"# subdivision: {subdivision}", "# style: other", ] - if function: + if function and function != "unspecified": lines.append(f"# function: {function}") - lines.append("") # blank line before body + lines.append("") for i in range(0, len(bars), 4): - chunk = bars[i : i + 4] + chunk = bars[i: i + 4] line = " ".join(f"| {' '.join(b)}" for b in chunk) + " |" lines.append(line) @@ -484,8 +492,8 @@ def convert_song(song_dir: Path, output_dir: Path) -> int: """Convert one McGill Billboard song directory to .chord files. Args: - song_dir: Directory containing salami_chords.txt (e.g. 0003/). - output_dir: Destination directory for .chord files (created if absent). + song_dir: Directory containing salami_chords.txt. + output_dir: Destination directory for .chord files. Returns: Number of .chord files successfully written. 
@@ -496,13 +504,12 @@ def convert_song(song_dir: Path, output_dir: Path) -> int: return 0 try: - header, raw_events = _parse_salami_file(salami) + header, data_lines = _parse_salami_file(salami) except Exception as exc: log.error("failed to parse %s: %s", salami, exc) return 0 song_id = song_dir.name - time_sig, subdivision = _parse_metre(header.get("metre", "4/4")) if time_sig is None: log.warning( @@ -513,57 +520,75 @@ def convert_song(song_dir: Path, output_dir: Path) -> int: tonic_raw = header.get("tonic", "C").strip() tonic = _normalize_note(tonic_raw) or "C" - sections = _extract_sections(raw_events) - if not sections: - log.warning("no sections found in %s", salami) - return 0 + # Collect all Harte tokens for mode inference + all_harte: list[str] = [] + for _, annotation in data_lines: + _, _, bar_groups = _parse_annotation_line(annotation) + for bg in bar_groups: + all_harte.extend(bg.split()) - all_durations = [ - e.duration - for s in sections - for e in s.events - if e.harte not in ("N", "X", "") and e.duration > 0.5 - ] - bar_duration = _estimate_bar_duration(all_durations) - mode = _infer_mode(tonic, sections) + mode = _infer_mode(tonic, all_harte) key = f"{tonic}_{mode}" - artist = header.get("artist", "unknown") song_title = header.get("title", "unknown") + n_positions = _expected_positions(time_sig, subdivision) + + # Group annotation lines into sections + sections: list[tuple[str, list[list[str]]]] = [] + current_function = "unspecified" + current_bars: list[list[str]] = [] + current_valid = True + + for _, annotation in data_lines: + letter, func, bar_groups = _parse_annotation_line(annotation) + + if letter is not None: + # New section boundary — save current section if non-empty + if current_bars and current_valid: + sections.append((current_function, current_bars)) + current_bars = [] + current_valid = True + current_function = func if func is not None else "unspecified" + + if not current_valid: + continue + + for bg in bar_groups: + positions 
= _bar_str_to_positions(bg, n_positions) + if positions is None: + current_valid = False + break + current_bars.append(positions) + + # Save the final section + if current_bars and current_valid: + sections.append((current_function, current_bars)) + output_dir.mkdir(parents=True, exist_ok=True) n_saved = 0 skip_reasons: Counter[str] = Counter() - for idx, section in enumerate(sections): - bars = _section_to_bars(section, bar_duration, time_sig, subdivision) - if bars is None: - skip_reasons["unrecognized_chord"] += 1 - continue - + for idx, (func, bars) in enumerate(sections): n = len(bars) if n < 4: log.debug( - "section %s in %s: %d bar(s) < 4, skipping", - section.letter, song_id, n, + "section %d in %s: %d bar(s) < 4, skipping", idx, song_id, n ) skip_reasons["too_short"] += 1 continue if n > 16: log.debug( - "section %s in %s: %d bars > 16, skipping", - section.letter, song_id, n, + "section %d in %s: %d bars > 16, skipping", idx, song_id, n ) skip_reasons["too_long"] += 1 continue - func = section.function filename = f"mcgill_{song_id}_{idx:02d}_{func}.chord" out_path = output_dir / filename - period_title = f"{artist} - {song_title} ({section.letter},{func})" + period_title = f"{artist} - {song_title} ({func})" _write_chord_file( - out_path, period_title, key, time_sig, subdivision, - func if func != "unspecified" else None, bars, + out_path, period_title, key, time_sig, subdivision, func, bars ) n_saved += 1 log.debug("wrote %s", out_path.name) @@ -581,10 +606,6 @@ def convert_song(song_dir: Path, output_dir: Path) -> int: def convert_dataset(dataset_dir: Path, output_dir: Path) -> tuple[int, int]: """Convert all song directories in a McGill Billboard dataset. - Args: - dataset_dir: Root directory containing per-song subdirectories. - output_dir: Destination directory for .chord files. - Returns: (n_saved, n_empty) where n_empty counts songs that produced no output. 
""" @@ -606,7 +627,7 @@ def convert_dataset(dataset_dir: Path, output_dir: Path) -> tuple[int, int]: # --------------------------------------------------------------------------- -# CLI entry point +# CLI # --------------------------------------------------------------------------- if __name__ == "__main__": @@ -615,7 +636,8 @@ if __name__ == "__main__": epilog=( "Example:\n" " python -m src.external_converters.mcgill_to_chord " - "data/raw_external/mcgill/ --out data/raw_external/mcgill_converted/" + "data/raw_external/mcgill/billboard-2.0-salami_chords/ " + "--out data/raw_external/mcgill_chord/" ), formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -625,9 +647,9 @@ if __name__ == "__main__": ) parser.add_argument( "--out", type=Path, - default=Path("data/raw_external/mcgill_converted"), + default=Path("data/raw_external/mcgill_chord"), metavar="output_dir", - help="destination for .chord files (default: data/raw_external/mcgill_converted/)", + help="destination for .chord files (default: data/raw_external/mcgill_chord/)", ) parser.add_argument( "--log-level", default="INFO", diff --git a/tests/fixtures/mcgill_test/0001/salami_chords.txt b/tests/fixtures/mcgill_test/0001/salami_chords.txt index 1925b8b..6018066 100644 --- a/tests/fixtures/mcgill_test/0001/salami_chords.txt +++ b/tests/fixtures/mcgill_test/0001/salami_chords.txt @@ -3,13 +3,7 @@ # metre: 4/4 # tonic: C -0.000000 Z -4.000000 A,verse C:maj -8.000000 . F:maj -12.000000 . G:7 -16.000000 . C:maj -20.000000 B,chorus F:maj -24.000000 . C:maj -28.000000 . G:7 -32.000000 . 
C:maj -36.000000 Z +0.000000 silence +4.000000 A, verse, | C:maj | F:maj | G:7 | C:maj | +20.000000 B, chorus, | F:maj | C:maj | G:7 | C:maj | +36.000000 silence diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..e6da403 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,176 @@ +"""Tests for ChordDataset in src/dataset.py.""" + +from pathlib import Path + +import torch +import pytest + +from src.dataset import ChordDataset +from src.tokenizer import TOKEN_TO_ID, parse_chord_file, tokenize_period + +FIXTURES = Path(__file__).parent / "fixtures" +_PAD_ID = TOKEN_TO_ID[""] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_pt(tmp_path: Path, stem: str, n_tokens: int) -> Path: + """Write a dummy .pt file with sequential token IDs.""" + tokens = torch.arange(n_tokens, dtype=torch.long) + path = tmp_path / f"{stem}.pt" + torch.save({"tokens": tokens, "meta": {"style": "user", "function": "verse"}}, path) + return path + + +def _write_real_pt(tmp_path: Path, fixture_name: str) -> tuple[Path, int]: + """Tokenize a real fixture and write its .pt file. 
Returns (path, n_tokens).""" + period = parse_chord_file(FIXTURES / fixture_name) + ids = tokenize_period(period) + tokens = torch.tensor(ids, dtype=torch.long) + out = tmp_path / f"{fixture_name}.pt" + torch.save({"tokens": tokens, "meta": {"style": period.style}}, out) + return out, len(ids) + + +# --------------------------------------------------------------------------- +# Length and file discovery +# --------------------------------------------------------------------------- + + +class TestChordDatasetLength: + def test_empty_directory(self, tmp_path): + ds = ChordDataset(tmp_path) + assert len(ds) == 0 + + def test_single_file(self, tmp_path): + _write_pt(tmp_path, "a", 10) + assert len(ChordDataset(tmp_path)) == 1 + + def test_multiple_files(self, tmp_path): + for name in ("a", "b", "c"): + _write_pt(tmp_path, name, 10) + assert len(ChordDataset(tmp_path)) == 3 + + def test_non_pt_files_ignored(self, tmp_path): + _write_pt(tmp_path, "a", 10) + (tmp_path / "notes.txt").write_text("ignored") + (tmp_path / "model.pth").write_text("ignored") + assert len(ChordDataset(tmp_path)) == 1 + + +# --------------------------------------------------------------------------- +# Output shape +# --------------------------------------------------------------------------- + + +class TestChordDatasetShape: + def test_returns_tensor(self, tmp_path): + _write_pt(tmp_path, "a", 50) + item = ChordDataset(tmp_path)[0] + assert isinstance(item, torch.Tensor) + + def test_dtype_is_long(self, tmp_path): + _write_pt(tmp_path, "a", 50) + item = ChordDataset(tmp_path)[0] + assert item.dtype == torch.long + + def test_shape_equals_max_length_when_shorter(self, tmp_path): + _write_pt(tmp_path, "a", 50) + assert ChordDataset(tmp_path, max_length=100)[0].shape[0] == 100 + + def test_shape_equals_max_length_when_longer(self, tmp_path): + _write_pt(tmp_path, "a", 600) + assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512 + + def test_shape_equals_max_length_exact(self, tmp_path): 
+ _write_pt(tmp_path, "a", 512) + assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512 + + def test_custom_max_length(self, tmp_path): + _write_pt(tmp_path, "a", 30) + assert ChordDataset(tmp_path, max_length=64)[0].shape[0] == 64 + + +# --------------------------------------------------------------------------- +# Padding +# --------------------------------------------------------------------------- + + +class TestChordDatasetPadding: + def test_trailing_tokens_are_pad_id(self, tmp_path): + n = 50 + _write_pt(tmp_path, "a", n) + item = ChordDataset(tmp_path, max_length=100)[0] + assert (item[n:] == _PAD_ID).all() + + def test_prefix_matches_original_tokens(self, tmp_path): + n = 50 + _write_pt(tmp_path, "a", n) + item = ChordDataset(tmp_path, max_length=100)[0] + expected = torch.arange(n, dtype=torch.long) + assert (item[:n] == expected).all() + + def test_no_padding_when_exact_length(self, tmp_path): + n = 100 + _write_pt(tmp_path, "a", n) + item = ChordDataset(tmp_path, max_length=n)[0] + expected = torch.arange(n, dtype=torch.long) + assert (item == expected).all() + + +# --------------------------------------------------------------------------- +# Truncation +# --------------------------------------------------------------------------- + + +class TestChordDatasetTruncation: + def test_truncated_length(self, tmp_path): + _write_pt(tmp_path, "a", 600) + item = ChordDataset(tmp_path, max_length=512)[0] + assert item.shape[0] == 512 + + def test_truncated_prefix_matches_original(self, tmp_path): + _write_pt(tmp_path, "a", 600) + item = ChordDataset(tmp_path, max_length=512)[0] + expected = torch.arange(512, dtype=torch.long) + assert (item == expected).all() + + +# --------------------------------------------------------------------------- +# Real fixture round-trip +# --------------------------------------------------------------------------- + + +class TestChordDatasetRealFixture: + def test_bos_at_position_zero(self, tmp_path): + 
_write_real_pt(tmp_path, "valid_c_major.chord") + item = ChordDataset(tmp_path, max_length=512)[0] + assert item[0] == TOKEN_TO_ID[""] + + def test_eos_at_correct_position(self, tmp_path): + _, n = _write_real_pt(tmp_path, "valid_c_major.chord") + item = ChordDataset(tmp_path, max_length=512)[0] + assert item[n - 1] == TOKEN_TO_ID[""] + + def test_tokens_after_eos_are_pad(self, tmp_path): + _, n = _write_real_pt(tmp_path, "valid_c_major.chord") + item = ChordDataset(tmp_path, max_length=512)[0] + assert (item[n:] == _PAD_ID).all() + + def test_all_valid_fixture_files_loadable(self, tmp_path): + for name in ( + "valid_c_major.chord", + "valid_fsharp_major.chord", + "valid_b_minor.chord", + "valid_gsharp_minor.chord", + ): + _write_real_pt(tmp_path, name) + ds = ChordDataset(tmp_path, max_length=512) + assert len(ds) == 4 + for i in range(4): + item = ds[i] + assert item.shape[0] == 512 + assert item[0] == TOKEN_TO_ID[""] diff --git a/tests/test_mcgill_converter.py b/tests/test_mcgill_converter.py index 33bdcea..e5e661e 100644 --- a/tests/test_mcgill_converter.py +++ b/tests/test_mcgill_converter.py @@ -1,9 +1,9 @@ """Tests for src/external_converters/mcgill_to_chord.py. Fixture: tests/fixtures/mcgill_test/0001/salami_chords.txt - 4/4 song in C major, two sections: - Section A (verse): C:maj F:maj G:7 C:maj — 4 chords × 4.0 s each - Section B (chorus): F:maj C:maj G:7 C:maj — 4 chords × 4.0 s each + 4/4 song in C major, two sections in the real McGill v2 2-column format: + A, verse : | C:maj | F:maj | G:7 | C:maj | (4 bars) + B, chorus : | F:maj | C:maj | G:7 | C:maj | (4 bars) Expected output: 2 .chord files, each with 4 bars, key=C_major, time=4/4. 
""" @@ -13,13 +13,11 @@ from pathlib import Path import pytest from src.external_converters.mcgill_to_chord import ( - _estimate_bar_duration, - _extract_sections, + _bar_str_to_positions, _harte_to_chord_symbol, - _infer_mode, + _parse_annotation_line, _parse_metre, _parse_salami_file, - _section_to_bars, convert_song, ) from src.tokenizer import parse_chord_file @@ -34,17 +32,13 @@ TEST_SONG = FIXTURES / "0001" class TestHarteConversion: - """Unit tests for individual Harte → .chord symbol conversion.""" - def test_simple_major(self): assert _harte_to_chord_symbol("C:maj") == "Cmaj" def test_flat_minor_seventh(self): - # Bb normalises to A# assert _harte_to_chord_symbol("Bb:min7") == "A#m7" def test_half_diminished(self): - # hdim7 = half-diminished 7th = our m7b5 assert _harte_to_chord_symbol("E:hdim7") == "Em7b5" def test_dominant_seventh(self): @@ -62,13 +56,24 @@ class TestHarteConversion: def test_augmented(self): assert _harte_to_chord_symbol("C:aug") == "Caug" - def test_slash_chord(self): + def test_slash_chord_absolute_bass(self): assert _harte_to_chord_symbol("C:maj/E") == "Cmaj/E" - def test_slash_chord_flat_bass(self): - # Flat bass note also normalised to sharp + def test_slash_chord_flat_bass_normalised(self): assert _harte_to_chord_symbol("G:maj/Bb") == "Gmaj/A#" + def test_slash_chord_interval_fifth(self): + # '/5' = perfect 5th (7 semitones) above root C → G + assert _harte_to_chord_symbol("C:maj/5") == "Cmaj/G" + + def test_slash_chord_interval_b3(self): + # '/b3' = minor 3rd (3 semitones) above root F → Ab = G# + assert _harte_to_chord_symbol("F:min/b3") == "Fm/G#" + + def test_slash_chord_interval_3(self): + # '/3' = major 3rd (4 semitones) above root C → E + assert _harte_to_chord_symbol("C:7/3") == "C7/E" + def test_no_chord_returns_none(self): assert _harte_to_chord_symbol("N") is None @@ -79,7 +84,6 @@ class TestHarteConversion: assert _harte_to_chord_symbol("") is None def test_extended_dominant_ninth(self): - # G:9 → dominant 7 + 
extension 9 assert _harte_to_chord_symbol("G:9") == "G79" def test_major_ninth(self): @@ -96,14 +100,15 @@ class TestHarteConversion: def test_output_is_parseable(self): from src.chord_parser import parse_chord_symbol - for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7", "C:maj/E"): + for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7", + "C:maj/E", "C:maj/5", "F:min/b3"): sym = _harte_to_chord_symbol(harte) assert sym is not None - parse_chord_symbol(sym) # must not raise + parse_chord_symbol(sym) # --------------------------------------------------------------------------- -# Helper units +# Salami file parsing (2-column format) # --------------------------------------------------------------------------- @@ -115,60 +120,150 @@ class TestParseSalamiFile: assert header["metre"] == "4/4" assert header["tonic"] == "C" - def test_events_count(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - # 10 data lines total (including Z lines) - assert len(events) == 10 + def test_data_line_count(self): + _, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt") + # 4 lines: silence, A/verse, B/chorus, silence + assert len(lines) == 4 - def test_first_event_is_silence(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - ts, label, chord = events[0] + def test_first_line_is_silence(self): + _, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt") + ts, annotation = lines[0] assert ts == 0.0 - assert label == "Z" + assert annotation == "silence" + + def test_returns_two_tuples(self): + _, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt") + for item in lines: + assert len(item) == 2 -class TestExtractSections: - def test_two_sections(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - sections = _extract_sections(events) - assert len(sections) == 2 - - def test_section_functions(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - sections = 
_extract_sections(events) - assert sections[0].function == "verse" - assert sections[1].function == "chorus" - - def test_events_per_section(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - sections = _extract_sections(events) - assert len(sections[0].events) == 4 - assert len(sections[1].events) == 4 - - def test_chord_values(self): - _, events = _parse_salami_file(TEST_SONG / "salami_chords.txt") - sections = _extract_sections(events) - hartes = [e.harte for e in sections[0].events] - assert hartes == ["C:maj", "F:maj", "G:7", "C:maj"] +# --------------------------------------------------------------------------- +# Annotation line parsing +# --------------------------------------------------------------------------- -class TestEstimateBarDuration: - def test_uniform_durations(self): - assert _estimate_bar_duration([2.0, 2.0, 2.0, 2.0]) == 2.0 +class TestParseAnnotationLine: + def test_silence_returns_empty(self): + letter, func, bars = _parse_annotation_line("silence") + assert letter is None and func is None and bars == [] - def test_mixed_durations(self): - # Median of [2, 2, 2, 4, 4] = 2 → bar_dur = 2 - assert _estimate_bar_duration([2.0, 2.0, 2.0, 4.0, 4.0]) == 2.0 + def test_end_returns_empty(self): + letter, func, bars = _parse_annotation_line("end") + assert letter is None and func is None and bars == [] - def test_too_few_samples_returns_default(self): - assert _estimate_bar_duration([]) == 2.0 - assert _estimate_bar_duration([3.0]) == 2.0 + def test_continuation_arrow_returns_empty(self): + letter, func, bars = _parse_annotation_line("->") + assert bars == [] - def test_clamp_upper(self): - assert _estimate_bar_duration([10.0, 10.0, 10.0]) == 5.0 + def test_section_letter_extracted(self): + letter, _, _ = _parse_annotation_line("A, verse, | C:maj | F:maj |") + assert letter == "A" - def test_clamp_lower(self): - assert _estimate_bar_duration([0.3, 0.3, 0.3]) == 2.0 # all < 0.5, falls back + def test_function_extracted(self): + 
_, func, _ = _parse_annotation_line("A, verse, | C:maj | F:maj |") + assert func == "verse" + + def test_chorus_function(self): + _, func, _ = _parse_annotation_line("B, chorus, | F:maj | C:maj |") + assert func == "chorus" + + def test_bar_count(self): + _, _, bars = _parse_annotation_line( + "A, verse, | C:maj | F:maj | G:7 | C:maj |" + ) + assert len(bars) == 4 + + def test_bar_contents(self): + _, _, bars = _parse_annotation_line( + "A, verse, | C:maj | F:maj | G:7 | C:maj |" + ) + assert bars == ["C:maj", "F:maj", "G:7", "C:maj"] + + def test_continuation_line_no_letter(self): + letter, func, bars = _parse_annotation_line("| C:maj | F:maj |") + assert letter is None + assert func is None + assert bars == ["C:maj", "F:maj"] + + def test_repeat_xN(self): + _, _, bars = _parse_annotation_line("| C:maj | x4") + assert bars == ["C:maj"] * 4 + + def test_trailing_annotation_ignored(self): + _, _, bars = _parse_annotation_line( + "A, intro, | Ab:maj | Db:maj | Ab:maj | G:7 |, (synth)" + ) + assert len(bars) == 4 + assert bars[0] == "Ab:maj" + + def test_multi_chord_bar_preserved(self): + _, _, bars = _parse_annotation_line("| G:hdim7 C:7 | F:min |") + assert bars[0] == "G:hdim7 C:7" + assert bars[1] == "F:min" + + +# --------------------------------------------------------------------------- +# Bar string to positions +# --------------------------------------------------------------------------- + + +class TestBarStrToPositions: + def test_single_chord_fills_position_zero(self): + pos = _bar_str_to_positions("C:maj", 4) + assert pos[0] == "Cmaj" + + def test_single_chord_rest_are_holds(self): + pos = _bar_str_to_positions("C:maj", 4) + assert pos[1:] == [".", ".", "."] + + def test_two_chords_distributed(self): + pos = _bar_str_to_positions("C:maj D:min", 4) + assert pos[0] == "Cmaj" + assert pos[2] == "Dm" + assert pos[1] == "." + assert pos[3] == "." 
+ + def test_four_chords_direct_map(self): + # Harte notation: 4 elements → 4 positions, direct 1-to-1 mapping + pos = _bar_str_to_positions("C:maj A:min F:maj G:7", 4) + assert pos == ["Cmaj", "Am", "Fmaj", "G7"] + + def test_explicit_hold_tokens(self): + pos = _bar_str_to_positions("C:maj . F:maj .", 4) + assert pos == ["Cmaj", ".", "Fmaj", "."] + + def test_nc_mapped(self): + pos = _bar_str_to_positions("N", 4) + assert pos[0] == "NC" + + def test_unknown_mapped(self): + pos = _bar_str_to_positions("X", 4) + assert pos[0] == "?" + + def test_unrecognized_returns_none(self): + # Starts with a note letter so passes filter, but quality is unknown + assert _bar_str_to_positions("C:xyz", 4) is None + + def test_performance_annotation_filtered(self): + # "(voice" is not a chord — should be ignored + pos = _bar_str_to_positions("C:maj (voice", 4) + assert pos is not None + assert pos[0] == "Cmaj" + + def test_result_length(self): + for n in (3, 4, 6): + pos = _bar_str_to_positions("C:maj", n) + assert len(pos) == n + + def test_interval_bass_resolved(self): + # C:maj/5 → Cmaj/G + pos = _bar_str_to_positions("C:maj/5", 4) + assert pos[0] == "Cmaj/G" + + +# --------------------------------------------------------------------------- +# Metre parsing +# --------------------------------------------------------------------------- class TestParseMetre: @@ -196,8 +291,6 @@ class TestParseMetre: class TestFullConversion: - """Integration tests: convert_song with fixture produces valid .chord files.""" - def test_returns_two_periods(self, tmp_path): assert convert_song(TEST_SONG, tmp_path) == 2 @@ -208,7 +301,7 @@ class TestFullConversion: def test_output_files_are_parseable(self, tmp_path): convert_song(TEST_SONG, tmp_path) for f in tmp_path.glob("*.chord"): - assert parse_chord_file(f) is not None # must not raise + assert parse_chord_file(f) is not None def test_verse_has_four_bars(self, tmp_path): convert_song(TEST_SONG, tmp_path) @@ -257,7 +350,7 @@ class 
TestFullConversion: for bar in p.bars: first = bar[0] if first not in (".", "NC", "?"): - parse_chord_symbol(first) # must not raise + parse_chord_symbol(first) def test_missing_salami_returns_zero(self, tmp_path): empty_song = tmp_path / "empty"