84ba7b4743
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate - scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout split, logs token length stats and style/function distributions - src/external_converters/mcgill_to_chord.py: rewrite parser for real McGill v2 format (2-column annotation, each bar in its own pipe group, interval bass notation e.g. /5 and /b3) - .gitignore: exclude data/processed/train, val, holdout subdirectories - tests: 37 new tests for ChordDataset and converter (260 total, all pass) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
177 lines
6.1 KiB
Python
177 lines
6.1 KiB
Python
"""Tests for ChordDataset in src/dataset.py."""
|
|
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import pytest
|
|
|
|
from src.dataset import ChordDataset
|
|
from src.tokenizer import TOKEN_TO_ID, parse_chord_file, tokenize_period
|
|
|
|
FIXTURES = Path(__file__).parent / "fixtures"
|
|
_PAD_ID = TOKEN_TO_ID["<PAD>"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _write_pt(tmp_path: Path, stem: str, n_tokens: int) -> Path:
|
|
"""Write a dummy .pt file with sequential token IDs."""
|
|
tokens = torch.arange(n_tokens, dtype=torch.long)
|
|
path = tmp_path / f"{stem}.pt"
|
|
torch.save({"tokens": tokens, "meta": {"style": "user", "function": "verse"}}, path)
|
|
return path
|
|
|
|
|
|
def _write_real_pt(tmp_path: Path, fixture_name: str) -> tuple[Path, int]:
    """Tokenize the fixture *fixture_name* and save the result as a ``.pt`` file.

    The fixture is read from ``FIXTURES / fixture_name`` and the output is
    written to ``tmp_path / "<fixture_name>.pt"``. The saved dict holds the
    token IDs as a ``torch.long`` tensor under ``"tokens"`` and the parsed
    period's ``style`` under ``"meta"``.

    Returns:
        ``(path, n_tokens)`` -- the written file and the number of token IDs.
    """
    parsed = parse_chord_file(FIXTURES / fixture_name)
    token_ids = tokenize_period(parsed)
    payload = {
        "tokens": torch.tensor(token_ids, dtype=torch.long),
        "meta": {"style": parsed.style},
    }
    destination = tmp_path / f"{fixture_name}.pt"
    torch.save(payload, destination)
    return destination, len(token_ids)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Length and file discovery
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChordDatasetLength:
    """``len(ChordDataset(directory))`` is expected to count its ``.pt`` files."""

    def test_empty_directory(self, tmp_path):
        """A directory containing nothing yields a zero-length dataset."""
        dataset = ChordDataset(tmp_path)
        assert len(dataset) == 0

    def test_single_file(self, tmp_path):
        """One ``.pt`` file yields a dataset of length one."""
        _write_pt(tmp_path, "a", 10)
        dataset = ChordDataset(tmp_path)
        assert len(dataset) == 1

    def test_multiple_files(self, tmp_path):
        """Three ``.pt`` files yield a dataset of length three."""
        for stem in ("a", "b", "c"):
            _write_pt(tmp_path, stem, 10)
        dataset = ChordDataset(tmp_path)
        assert len(dataset) == 3

    def test_non_pt_files_ignored(self, tmp_path):
        """Files whose suffix is not ``.pt`` do not contribute to the length."""
        _write_pt(tmp_path, "a", 10)
        # Distractors: an unrelated suffix and a near-miss (".pth").
        for distractor in ("notes.txt", "model.pth"):
            (tmp_path / distractor).write_text("ignored")
        dataset = ChordDataset(tmp_path)
        assert len(dataset) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Output shape
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChordDatasetShape:
    """Type, dtype and first-dimension size of items returned by indexing."""

    def test_returns_tensor(self, tmp_path):
        """Indexing the dataset gives back a ``torch.Tensor``."""
        _write_pt(tmp_path, "a", 50)
        dataset = ChordDataset(tmp_path)
        sample = dataset[0]
        assert isinstance(sample, torch.Tensor)

    def test_dtype_is_long(self, tmp_path):
        """The returned tensor has dtype ``torch.long``."""
        _write_pt(tmp_path, "a", 50)
        dataset = ChordDataset(tmp_path)
        sample = dataset[0]
        assert sample.dtype == torch.long

    def test_shape_equals_max_length_when_shorter(self, tmp_path):
        """A 50-token file comes back with 100 entries when ``max_length=100``."""
        _write_pt(tmp_path, "a", 50)
        dataset = ChordDataset(tmp_path, max_length=100)
        sample = dataset[0]
        assert sample.shape[0] == 100

    def test_shape_equals_max_length_when_longer(self, tmp_path):
        """A 600-token file comes back with 512 entries when ``max_length=512``."""
        _write_pt(tmp_path, "a", 600)
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        assert sample.shape[0] == 512

    def test_shape_equals_max_length_exact(self, tmp_path):
        """A 512-token file comes back with 512 entries when ``max_length=512``."""
        _write_pt(tmp_path, "a", 512)
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        assert sample.shape[0] == 512

    def test_custom_max_length(self, tmp_path):
        """A 30-token file comes back with 64 entries when ``max_length=64``."""
        _write_pt(tmp_path, "a", 30)
        dataset = ChordDataset(tmp_path, max_length=64)
        sample = dataset[0]
        assert sample.shape[0] == 64
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Padding
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChordDatasetPadding:
    """Contents of items whose source file has at most ``max_length`` tokens."""

    def test_trailing_tokens_are_pad_id(self, tmp_path):
        """Every position from index ``n`` onward equals ``_PAD_ID``."""
        n = 50
        _write_pt(tmp_path, "a", n)
        dataset = ChordDataset(tmp_path, max_length=100)
        sample = dataset[0]
        tail = sample[n:]
        assert (tail == _PAD_ID).all()

    def test_prefix_matches_original_tokens(self, tmp_path):
        """The first ``n`` positions equal the token IDs that were written."""
        n = 50
        _write_pt(tmp_path, "a", n)
        dataset = ChordDataset(tmp_path, max_length=100)
        sample = dataset[0]
        # _write_pt stores sequential IDs, so the expected prefix is arange(n).
        written = torch.arange(n, dtype=torch.long)
        head = sample[:n]
        assert (head == written).all()

    def test_no_padding_when_exact_length(self, tmp_path):
        """With ``max_length == n`` the item equals the written token IDs."""
        n = 100
        _write_pt(tmp_path, "a", n)
        dataset = ChordDataset(tmp_path, max_length=n)
        sample = dataset[0]
        written = torch.arange(n, dtype=torch.long)
        assert (sample == written).all()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChordDatasetTruncation:
    """Contents of items whose source file has more than ``max_length`` tokens."""

    def test_truncated_length(self, tmp_path):
        """A 600-token file comes back with exactly 512 entries."""
        _write_pt(tmp_path, "a", 600)
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        assert sample.shape[0] == 512

    def test_truncated_prefix_matches_original(self, tmp_path):
        """The 512 returned entries equal the first 512 written token IDs."""
        _write_pt(tmp_path, "a", 600)
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        # _write_pt stores sequential IDs, so the kept prefix is arange(512).
        kept = torch.arange(512, dtype=torch.long)
        assert (sample == kept).all()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Real fixture round-trip
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChordDatasetRealFixture:
    """Round-trip checks using tokenized ``.chord`` fixture files."""

    def test_bos_at_position_zero(self, tmp_path):
        """The first entry of the item is the ``<BOS>`` token ID."""
        _write_real_pt(tmp_path, "valid_c_major.chord")
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        assert sample[0] == TOKEN_TO_ID["<BOS>"]

    def test_eos_at_correct_position(self, tmp_path):
        """The entry at index ``n - 1`` is the ``<EOS>`` token ID."""
        _, n = _write_real_pt(tmp_path, "valid_c_major.chord")
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        last_real = n - 1
        assert sample[last_real] == TOKEN_TO_ID["<EOS>"]

    def test_tokens_after_eos_are_pad(self, tmp_path):
        """Every entry from index ``n`` onward equals ``_PAD_ID``."""
        _, n = _write_real_pt(tmp_path, "valid_c_major.chord")
        dataset = ChordDataset(tmp_path, max_length=512)
        sample = dataset[0]
        tail = sample[n:]
        assert (tail == _PAD_ID).all()

    def test_all_valid_fixture_files_loadable(self, tmp_path):
        """All four valid fixtures load, have 512 entries and start with ``<BOS>``."""
        fixture_names = (
            "valid_c_major.chord",
            "valid_fsharp_major.chord",
            "valid_b_minor.chord",
            "valid_gsharp_minor.chord",
        )
        for fixture_name in fixture_names:
            _write_real_pt(tmp_path, fixture_name)

        dataset = ChordDataset(tmp_path, max_length=512)
        assert len(dataset) == 4

        # Index explicitly so each item is fetched through __getitem__.
        for index in range(4):
            sample = dataset[index]
            assert sample.shape[0] == 512
            assert sample[0] == TOKEN_TO_ID["<BOS>"]
|