feat: add dataset, prepare_data pipeline and fix McGill converter

- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout
  split, logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real
  McGill v2 format (2-column annotation, each bar in its own pipe group,
  interval bass notation e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-19 18:09:46 +03:00
parent ea32bf43b2
commit 84ba7b4743
7 changed files with 876 additions and 314 deletions
+159 -66
View File
@@ -1,9 +1,9 @@
"""Tests for src/external_converters/mcgill_to_chord.py.
Fixture: tests/fixtures/mcgill_test/0001/salami_chords.txt
4/4 song in C major, two sections:
Section A (verse): C:maj F:maj G:7 C:maj — 4 chords × 4.0 s each
Section B (chorus): F:maj C:maj G:7 C:maj — 4 chords × 4.0 s each
4/4 song in C major, two sections in the real McGill v2 2-column format:
A, verse : | C:maj | F:maj | G:7 | C:maj | (4 bars)
B, chorus : | F:maj | C:maj | G:7 | C:maj | (4 bars)
Expected output: 2 .chord files, each with 4 bars, key=C_major, time=4/4.
"""
@@ -13,13 +13,11 @@ from pathlib import Path
import pytest
from src.external_converters.mcgill_to_chord import (
_estimate_bar_duration,
_extract_sections,
_bar_str_to_positions,
_harte_to_chord_symbol,
_infer_mode,
_parse_annotation_line,
_parse_metre,
_parse_salami_file,
_section_to_bars,
convert_song,
)
from src.tokenizer import parse_chord_file
@@ -34,17 +32,13 @@ TEST_SONG = FIXTURES / "0001"
class TestHarteConversion:
"""Unit tests for individual Harte → .chord symbol conversion."""
def test_simple_major(self):
assert _harte_to_chord_symbol("C:maj") == "Cmaj"
def test_flat_minor_seventh(self):
# Bb normalises to A#
assert _harte_to_chord_symbol("Bb:min7") == "A#m7"
def test_half_diminished(self):
# hdim7 = half-diminished 7th = our m7b5
assert _harte_to_chord_symbol("E:hdim7") == "Em7b5"
def test_dominant_seventh(self):
@@ -62,13 +56,24 @@ class TestHarteConversion:
def test_augmented(self):
assert _harte_to_chord_symbol("C:aug") == "Caug"
def test_slash_chord(self):
def test_slash_chord_absolute_bass(self):
assert _harte_to_chord_symbol("C:maj/E") == "Cmaj/E"
def test_slash_chord_flat_bass(self):
# Flat bass note also normalised to sharp
def test_slash_chord_flat_bass_normalised(self):
assert _harte_to_chord_symbol("G:maj/Bb") == "Gmaj/A#"
def test_slash_chord_interval_fifth(self):
    """Interval bass '/5' on a C root resolves to an explicit G bass.

    A perfect 5th is 7 semitones above the root, so C → G.
    """
    converted = _harte_to_chord_symbol("C:maj/5")
    assert converted == "Cmaj/G"
def test_slash_chord_interval_b3(self):
    """Interval bass '/b3' on an F root resolves to G# (the sharp spelling of Ab).

    A minor 3rd is 3 semitones above the root, so F → Ab = G#.
    """
    converted = _harte_to_chord_symbol("F:min/b3")
    assert converted == "Fm/G#"
def test_slash_chord_interval_3(self):
    """Interval bass '/3' on a C root resolves to an explicit E bass.

    A major 3rd is 4 semitones above the root, so C → E.
    """
    converted = _harte_to_chord_symbol("C:7/3")
    assert converted == "C7/E"
def test_no_chord_returns_none(self):
assert _harte_to_chord_symbol("N") is None
@@ -79,7 +84,6 @@ class TestHarteConversion:
assert _harte_to_chord_symbol("") is None
def test_extended_dominant_ninth(self):
# G:9 → dominant 7 + extension 9
assert _harte_to_chord_symbol("G:9") == "G79"
def test_major_ninth(self):
@@ -96,14 +100,15 @@ class TestHarteConversion:
def test_output_is_parseable(self):
from src.chord_parser import parse_chord_symbol
for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7", "C:maj/E"):
for harte in ("C:maj", "Bb:min7", "E:hdim7", "G:7", "D:maj7",
"C:maj/E", "C:maj/5", "F:min/b3"):
sym = _harte_to_chord_symbol(harte)
assert sym is not None
parse_chord_symbol(sym) # must not raise
parse_chord_symbol(sym)
# ---------------------------------------------------------------------------
# Helper units
# Salami file parsing (2-column format)
# ---------------------------------------------------------------------------
@@ -115,60 +120,150 @@ class TestParseSalamiFile:
assert header["metre"] == "4/4"
assert header["tonic"] == "C"
def test_events_count(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
# 10 data lines total (including Z lines)
assert len(events) == 10
def test_data_line_count(self):
_, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
# 4 lines: silence, A/verse, B/chorus, silence
assert len(lines) == 4
def test_first_event_is_silence(self):
_, events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
ts, label, chord = events[0]
def test_first_line_is_silence(self):
_, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
ts, annotation = lines[0]
assert ts == 0.0
assert label == "Z"
assert annotation == "silence"
def test_returns_two_tuples(self):
_, lines = _parse_salami_file(TEST_SONG / "salami_chords.txt")
for item in lines:
assert len(item) == 2
class TestExtractSections:
    """Checks for `_extract_sections` applied to the fixture's parsed events."""

    @staticmethod
    def _fixture_sections():
        """Parse the fixture salami file and return the sections extracted from it."""
        _header, parsed_events = _parse_salami_file(TEST_SONG / "salami_chords.txt")
        return _extract_sections(parsed_events)

    def test_two_sections(self):
        """The fixture yields exactly two sections."""
        assert len(self._fixture_sections()) == 2

    def test_section_functions(self):
        """The first section is the verse, the second the chorus."""
        found = self._fixture_sections()
        assert found[0].function == "verse"
        assert found[1].function == "chorus"

    def test_events_per_section(self):
        """Each of the two sections carries four events."""
        found = self._fixture_sections()
        assert len(found[0].events) == 4
        assert len(found[1].events) == 4

    def test_chord_values(self):
        """The first section's events keep their Harte labels in order."""
        first_section = self._fixture_sections()[0]
        labels = [event.harte for event in first_section.events]
        assert labels == ["C:maj", "F:maj", "G:7", "C:maj"]
# ---------------------------------------------------------------------------
# Annotation line parsing
# ---------------------------------------------------------------------------
class TestEstimateBarDuration:
def test_uniform_durations(self):
assert _estimate_bar_duration([2.0, 2.0, 2.0, 2.0]) == 2.0
class TestParseAnnotationLine:
def test_silence_returns_empty(self):
letter, func, bars = _parse_annotation_line("silence")
assert letter is None and func is None and bars == []
def test_mixed_durations(self):
# Median of [2, 2, 2, 4, 4] = 2 → bar_dur = 2
assert _estimate_bar_duration([2.0, 2.0, 2.0, 4.0, 4.0]) == 2.0
def test_end_returns_empty(self):
letter, func, bars = _parse_annotation_line("end")
assert letter is None and func is None and bars == []
def test_too_few_samples_returns_default(self):
assert _estimate_bar_duration([]) == 2.0
assert _estimate_bar_duration([3.0]) == 2.0
def test_continuation_arrow_returns_empty(self):
letter, func, bars = _parse_annotation_line("->")
assert bars == []
def test_clamp_upper(self):
assert _estimate_bar_duration([10.0, 10.0, 10.0]) == 5.0
def test_section_letter_extracted(self):
letter, _, _ = _parse_annotation_line("A, verse, | C:maj | F:maj |")
assert letter == "A"
def test_clamp_lower(self):
assert _estimate_bar_duration([0.3, 0.3, 0.3]) == 2.0 # all < 0.5, falls back
def test_function_extracted(self):
    """The 'verse' label of an annotation line is returned as the function."""
    annotation = "A, verse, | C:maj | F:maj |"
    _letter, section_function, _bars = _parse_annotation_line(annotation)
    assert section_function == "verse"
def test_chorus_function(self):
    """The 'chorus' label of an annotation line is returned as the function."""
    annotation = "B, chorus, | F:maj | C:maj |"
    _letter, section_function, _bars = _parse_annotation_line(annotation)
    assert section_function == "chorus"
def test_bar_count(self):
    """Four pipe-delimited groups produce four bars."""
    annotation = "A, verse, | C:maj | F:maj | G:7 | C:maj |"
    _letter, _function, bar_list = _parse_annotation_line(annotation)
    assert len(bar_list) == 4
def test_bar_contents(self):
    """Bars come back as the Harte strings between the pipes, in order."""
    annotation = "A, verse, | C:maj | F:maj | G:7 | C:maj |"
    expected = ["C:maj", "F:maj", "G:7", "C:maj"]
    _letter, _function, bar_list = _parse_annotation_line(annotation)
    assert bar_list == expected
def test_continuation_line_no_letter(self):
    """A line that starts directly with bars has no letter and no function."""
    section_letter, section_function, bar_list = _parse_annotation_line(
        "| C:maj | F:maj |"
    )
    assert section_letter is None
    assert section_function is None
    assert bar_list == ["C:maj", "F:maj"]
def test_repeat_xN(self):
    """A trailing 'x4' marker expands the single bar into four copies."""
    _letter, _function, bar_list = _parse_annotation_line("| C:maj | x4")
    assert bar_list == ["C:maj", "C:maj", "C:maj", "C:maj"]
def test_trailing_annotation_ignored(self):
    """Text after the last pipe, such as '(synth)', does not add or alter bars."""
    annotation = "A, intro, | Ab:maj | Db:maj | Ab:maj | G:7 |, (synth)"
    _letter, _function, bar_list = _parse_annotation_line(annotation)
    assert len(bar_list) == 4
    assert bar_list[0] == "Ab:maj"
def test_multi_chord_bar_preserved(self):
    """Several chords inside one pipe group stay together as one bar string."""
    _letter, _function, bar_list = _parse_annotation_line("| G:hdim7 C:7 | F:min |")
    first_bar = bar_list[0]
    second_bar = bar_list[1]
    assert first_bar == "G:hdim7 C:7"
    assert second_bar == "F:min"
# ---------------------------------------------------------------------------
# Bar string to positions
# ---------------------------------------------------------------------------
class TestBarStrToPositions:
    """Checks for `_bar_str_to_positions`.

    Each test passes one Harte bar string plus a slot count and inspects the
    returned list of per-position symbols ("." marks a hold).
    """

    def test_single_chord_fills_position_zero(self):
        """A lone chord lands in the first slot."""
        slots = _bar_str_to_positions("C:maj", 4)
        assert slots[0] == "Cmaj"

    def test_single_chord_rest_are_holds(self):
        """Every slot after a lone chord is a hold."""
        slots = _bar_str_to_positions("C:maj", 4)
        assert slots[1:] == ["."] * 3

    def test_two_chords_distributed(self):
        """Two chords in four slots land on slots 0 and 2, with holds between."""
        slots = _bar_str_to_positions("C:maj D:min", 4)
        assert slots[0] == "Cmaj"
        assert slots[2] == "Dm"
        assert slots[1] == "."
        assert slots[3] == "."

    def test_four_chords_direct_map(self):
        """Four elements in four slots map one-to-one."""
        expected = ["Cmaj", "Am", "Fmaj", "G7"]
        assert _bar_str_to_positions("C:maj A:min F:maj G:7", 4) == expected

    def test_explicit_hold_tokens(self):
        """Explicit '.' tokens in the input are kept as holds."""
        expected = ["Cmaj", ".", "Fmaj", "."]
        assert _bar_str_to_positions("C:maj . F:maj .", 4) == expected

    def test_nc_mapped(self):
        """Harte 'N' comes out as 'NC'."""
        slots = _bar_str_to_positions("N", 4)
        assert slots[0] == "NC"

    def test_unknown_mapped(self):
        """Harte 'X' comes out as '?'."""
        slots = _bar_str_to_positions("X", 4)
        assert slots[0] == "?"

    def test_unrecognized_returns_none(self):
        """A token with a valid root letter but unknown quality gives None."""
        # The root letter gets it past the chord filter; the quality is bogus.
        outcome = _bar_str_to_positions("C:xyz", 4)
        assert outcome is None

    def test_performance_annotation_filtered(self):
        """Non-chord residue such as '(voice' is dropped, not treated as a chord."""
        slots = _bar_str_to_positions("C:maj (voice", 4)
        assert slots is not None
        assert slots[0] == "Cmaj"

    def test_result_length(self):
        """The result always has as many entries as slots requested."""
        for width in (3, 4, 6):
            slots = _bar_str_to_positions("C:maj", width)
            assert len(slots) == width

    def test_interval_bass_resolved(self):
        """Interval bass notation is resolved inside a bar: C:maj/5 → Cmaj/G."""
        slots = _bar_str_to_positions("C:maj/5", 4)
        assert slots[0] == "Cmaj/G"
# ---------------------------------------------------------------------------
# Metre parsing
# ---------------------------------------------------------------------------
class TestParseMetre:
@@ -196,8 +291,6 @@ class TestParseMetre:
class TestFullConversion:
"""Integration tests: convert_song with fixture produces valid .chord files."""
def test_returns_two_periods(self, tmp_path):
assert convert_song(TEST_SONG, tmp_path) == 2
@@ -208,7 +301,7 @@ class TestFullConversion:
def test_output_files_are_parseable(self, tmp_path):
convert_song(TEST_SONG, tmp_path)
for f in tmp_path.glob("*.chord"):
assert parse_chord_file(f) is not None # must not raise
assert parse_chord_file(f) is not None
def test_verse_has_four_bars(self, tmp_path):
convert_song(TEST_SONG, tmp_path)
@@ -257,7 +350,7 @@ class TestFullConversion:
for bar in p.bars:
first = bar[0]
if first not in (".", "NC", "?"):
parse_chord_symbol(first) # must not raise
parse_chord_symbol(first)
def test_missing_salami_returns_zero(self, tmp_path):
empty_song = tmp_path / "empty"