feat: add dataset, prepare_data pipeline and fix McGill converter

- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout
  split, logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real
  McGill v2 format (2-column annotation, each bar in its own pipe group,
  interval bass notation e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-19 18:09:46 +03:00
parent ea32bf43b2
commit 84ba7b4743
7 changed files with 876 additions and 314 deletions
+176
View File
@@ -0,0 +1,176 @@
"""Tests for ChordDataset in src/dataset.py."""
from pathlib import Path
import torch
import pytest
from src.dataset import ChordDataset
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, tokenize_period
FIXTURES = Path(__file__).parent / "fixtures"
_PAD_ID = TOKEN_TO_ID["<PAD>"]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _write_pt(tmp_path: Path, stem: str, n_tokens: int) -> Path:
"""Write a dummy .pt file with sequential token IDs."""
tokens = torch.arange(n_tokens, dtype=torch.long)
path = tmp_path / f"{stem}.pt"
torch.save({"tokens": tokens, "meta": {"style": "user", "function": "verse"}}, path)
return path
def _write_real_pt(tmp_path: Path, fixture_name: str) -> tuple[Path, int]:
"""Tokenize a real fixture and write its .pt file. Returns (path, n_tokens)."""
period = parse_chord_file(FIXTURES / fixture_name)
ids = tokenize_period(period)
tokens = torch.tensor(ids, dtype=torch.long)
out = tmp_path / f"{fixture_name}.pt"
torch.save({"tokens": tokens, "meta": {"style": period.style}}, out)
return out, len(ids)
# ---------------------------------------------------------------------------
# Length and file discovery
# ---------------------------------------------------------------------------
class TestChordDatasetLength:
def test_empty_directory(self, tmp_path):
ds = ChordDataset(tmp_path)
assert len(ds) == 0
def test_single_file(self, tmp_path):
_write_pt(tmp_path, "a", 10)
assert len(ChordDataset(tmp_path)) == 1
def test_multiple_files(self, tmp_path):
for name in ("a", "b", "c"):
_write_pt(tmp_path, name, 10)
assert len(ChordDataset(tmp_path)) == 3
def test_non_pt_files_ignored(self, tmp_path):
_write_pt(tmp_path, "a", 10)
(tmp_path / "notes.txt").write_text("ignored")
(tmp_path / "model.pth").write_text("ignored")
assert len(ChordDataset(tmp_path)) == 1
# ---------------------------------------------------------------------------
# Output shape
# ---------------------------------------------------------------------------
class TestChordDatasetShape:
def test_returns_tensor(self, tmp_path):
_write_pt(tmp_path, "a", 50)
item = ChordDataset(tmp_path)[0]
assert isinstance(item, torch.Tensor)
def test_dtype_is_long(self, tmp_path):
_write_pt(tmp_path, "a", 50)
item = ChordDataset(tmp_path)[0]
assert item.dtype == torch.long
def test_shape_equals_max_length_when_shorter(self, tmp_path):
_write_pt(tmp_path, "a", 50)
assert ChordDataset(tmp_path, max_length=100)[0].shape[0] == 100
def test_shape_equals_max_length_when_longer(self, tmp_path):
_write_pt(tmp_path, "a", 600)
assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512
def test_shape_equals_max_length_exact(self, tmp_path):
_write_pt(tmp_path, "a", 512)
assert ChordDataset(tmp_path, max_length=512)[0].shape[0] == 512
def test_custom_max_length(self, tmp_path):
_write_pt(tmp_path, "a", 30)
assert ChordDataset(tmp_path, max_length=64)[0].shape[0] == 64
# ---------------------------------------------------------------------------
# Padding
# ---------------------------------------------------------------------------
class TestChordDatasetPadding:
def test_trailing_tokens_are_pad_id(self, tmp_path):
n = 50
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=100)[0]
assert (item[n:] == _PAD_ID).all()
def test_prefix_matches_original_tokens(self, tmp_path):
n = 50
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=100)[0]
expected = torch.arange(n, dtype=torch.long)
assert (item[:n] == expected).all()
def test_no_padding_when_exact_length(self, tmp_path):
n = 100
_write_pt(tmp_path, "a", n)
item = ChordDataset(tmp_path, max_length=n)[0]
expected = torch.arange(n, dtype=torch.long)
assert (item == expected).all()
# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------
class TestChordDatasetTruncation:
def test_truncated_length(self, tmp_path):
_write_pt(tmp_path, "a", 600)
item = ChordDataset(tmp_path, max_length=512)[0]
assert item.shape[0] == 512
def test_truncated_prefix_matches_original(self, tmp_path):
_write_pt(tmp_path, "a", 600)
item = ChordDataset(tmp_path, max_length=512)[0]
expected = torch.arange(512, dtype=torch.long)
assert (item == expected).all()
# ---------------------------------------------------------------------------
# Real fixture round-trip
# ---------------------------------------------------------------------------
class TestChordDatasetRealFixture:
def test_bos_at_position_zero(self, tmp_path):
_write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert item[0] == TOKEN_TO_ID["<BOS>"]
def test_eos_at_correct_position(self, tmp_path):
_, n = _write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert item[n - 1] == TOKEN_TO_ID["<EOS>"]
def test_tokens_after_eos_are_pad(self, tmp_path):
_, n = _write_real_pt(tmp_path, "valid_c_major.chord")
item = ChordDataset(tmp_path, max_length=512)[0]
assert (item[n:] == _PAD_ID).all()
def test_all_valid_fixture_files_loadable(self, tmp_path):
for name in (
"valid_c_major.chord",
"valid_fsharp_major.chord",
"valid_b_minor.chord",
"valid_gsharp_minor.chord",
):
_write_real_pt(tmp_path, name)
ds = ChordDataset(tmp_path, max_length=512)
assert len(ds) == 4
for i in range(4):
item = ds[i]
assert item.shape[0] == 512
assert item[0] == TOKEN_TO_ID["<BOS>"]