"""Tests for ChordDataset in src/dataset.py.""" from pathlib import Path import torch import pytest from src.dataset import ChordDataset from src.tokenizer import TOKEN_TO_ID, parse_chord_file, tokenize_period FIXTURES = Path(__file__).parent / "fixtures" _PAD_ID = TOKEN_TO_ID[""] # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _write_pt(tmp_path: Path, stem: str, n_tokens: int) -> Path: """Write a dummy .pt file with sequential token IDs.""" tokens = torch.arange(n_tokens, dtype=torch.long) path = tmp_path / f"{stem}.pt" torch.save({"tokens": tokens, "meta": {"style": "user", "function": "verse"}}, path) return path def _write_real_pt(tmp_path: Path, fixture_name: str) -> tuple[Path, int]: """Tokenize a real fixture and write its .pt file. Returns (path, n_tokens).""" period = parse_chord_file(FIXTURES / fixture_name) ids = tokenize_period(period) tokens = torch.tensor(ids, dtype=torch.long) out = tmp_path / f"{fixture_name}.pt" torch.save({"tokens": tokens, "meta": {"style": period.style}}, out) return out, len(ids) # --------------------------------------------------------------------------- # Length and file discovery # --------------------------------------------------------------------------- class TestChordDatasetLength: def test_empty_directory(self, tmp_path): ds = ChordDataset(tmp_path) assert len(ds) == 0 def test_single_file(self, tmp_path): _write_pt(tmp_path, "a", 10) assert len(ChordDataset(tmp_path)) == 1 def test_multiple_files(self, tmp_path): for name in ("a", "b", "c"): _write_pt(tmp_path, name, 10) assert len(ChordDataset(tmp_path)) == 3 def test_non_pt_files_ignored(self, tmp_path): _write_pt(tmp_path, "a", 10) (tmp_path / "notes.txt").write_text("ignored") (tmp_path / "model.pth").write_text("ignored") assert len(ChordDataset(tmp_path)) == 1 # --------------------------------------------------------------------------- # 
# Output shape
# ---------------------------------------------------------------------------


class TestChordDatasetShape:
    """Checks on the type, dtype and length of a single dataset item."""

    def test_returns_tensor(self, tmp_path):
        """Indexing the dataset gives back a ``torch.Tensor``."""
        _write_pt(tmp_path, "a", 50)
        dataset = ChordDataset(tmp_path)
        assert isinstance(dataset[0], torch.Tensor)

    def test_dtype_is_long(self, tmp_path):
        """The returned tensor has dtype ``torch.long``."""
        _write_pt(tmp_path, "a", 50)
        dataset = ChordDataset(tmp_path)
        assert dataset[0].dtype == torch.long

    def test_shape_equals_max_length_when_shorter(self, tmp_path):
        """A 50-token file comes back with length 100 when max_length=100."""
        _write_pt(tmp_path, "a", 50)
        item = ChordDataset(tmp_path, max_length=100)[0]
        assert item.shape[0] == 100

    def test_shape_equals_max_length_when_longer(self, tmp_path):
        """A 600-token file comes back with length 512 when max_length=512."""
        _write_pt(tmp_path, "a", 600)
        item = ChordDataset(tmp_path, max_length=512)[0]
        assert item.shape[0] == 512

    def test_shape_equals_max_length_exact(self, tmp_path):
        """A 512-token file comes back with length 512 when max_length=512."""
        _write_pt(tmp_path, "a", 512)
        item = ChordDataset(tmp_path, max_length=512)[0]
        assert item.shape[0] == 512

    def test_custom_max_length(self, tmp_path):
        """A 30-token file comes back with length 64 when max_length=64."""
        _write_pt(tmp_path, "a", 30)
        item = ChordDataset(tmp_path, max_length=64)[0]
        assert item.shape[0] == 64


# ---------------------------------------------------------------------------
# Padding
# ---------------------------------------------------------------------------


class TestChordDatasetPadding:
    """Checks on the content of items shorter than or equal to max_length."""

    def test_trailing_tokens_are_pad_id(self, tmp_path):
        """Every position from index 50 onwards equals ``_PAD_ID``."""
        n_written = 50
        _write_pt(tmp_path, "a", n_written)
        item = ChordDataset(tmp_path, max_length=100)[0]
        tail = item[n_written:]
        assert torch.all(tail == _PAD_ID)

    def test_prefix_matches_original_tokens(self, tmp_path):
        """The first 50 positions equal the ids that were written."""
        n_written = 50
        _write_pt(tmp_path, "a", n_written)
        item = ChordDataset(tmp_path, max_length=100)[0]
        # _write_pt stores sequential ids, so the prefix is a plain arange.
        written = torch.arange(n_written, dtype=torch.long)
        assert torch.all(item[:n_written] == written)

    def test_no_padding_when_exact_length(self, tmp_path):
        """With max_length equal to the file length the item is unchanged."""
        n_written = 100
        _write_pt(tmp_path, "a", n_written)
        item = ChordDataset(tmp_path, max_length=n_written)[0]
        written = torch.arange(n_written, dtype=torch.long)
        assert torch.all(item == written)


# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------


class TestChordDatasetTruncation:
    """Checks on items built from files longer than max_length."""

    def test_truncated_length(self, tmp_path):
        """A 600-token file yields an item of length 512."""
        _write_pt(tmp_path, "a", 600)
        dataset = ChordDataset(tmp_path, max_length=512)
        assert dataset[0].shape[0] == 512

    def test_truncated_prefix_matches_original(self, tmp_path):
        """The item equals the first 512 ids that were written."""
        _write_pt(tmp_path, "a", 600)
        item = ChordDataset(tmp_path, max_length=512)[0]
        leading_ids = torch.arange(512, dtype=torch.long)
        assert torch.all(item == leading_ids)


# ---------------------------------------------------------------------------
# Real fixture round-trip
# ---------------------------------------------------------------------------


class TestChordDatasetRealFixture:
    """Checks using tokenized fixture files instead of synthetic ids.

    NOTE(review): the BOS and EOS assertions below both look up the
    empty-string key in ``TOKEN_TO_ID``, which is also the key used for
    ``_PAD_ID``. Confirm the intended special-token names against
    src.tokenizer.
    """

    def test_bos_at_position_zero(self, tmp_path):
        """Position 0 of the item equals ``TOKEN_TO_ID[""]``."""
        _write_real_pt(tmp_path, "valid_c_major.chord")
        dataset = ChordDataset(tmp_path, max_length=512)
        first_token = dataset[0][0]
        assert first_token == TOKEN_TO_ID[""]

    def test_eos_at_correct_position(self, tmp_path):
        """The last tokenized position equals ``TOKEN_TO_ID[""]``."""
        _, n_ids = _write_real_pt(tmp_path, "valid_c_major.chord")
        dataset = ChordDataset(tmp_path, max_length=512)
        last_real_token = dataset[0][n_ids - 1]
        assert last_real_token == TOKEN_TO_ID[""]

    def test_tokens_after_eos_are_pad(self, tmp_path):
        """Every position past the tokenized length equals ``_PAD_ID``."""
        _, n_ids = _write_real_pt(tmp_path, "valid_c_major.chord")
        item = ChordDataset(tmp_path, max_length=512)[0]
        assert torch.all(item[n_ids:] == _PAD_ID)

    def test_all_valid_fixture_files_loadable(self, tmp_path):
        """All four valid fixtures load with length 512 and the expected head."""
        fixture_names = (
            "valid_c_major.chord",
            "valid_fsharp_major.chord",
            "valid_b_minor.chord",
            "valid_gsharp_minor.chord",
        )
        for fixture_name in fixture_names:
            _write_real_pt(tmp_path, fixture_name)

        dataset = ChordDataset(tmp_path, max_length=512)
        assert len(dataset) == 4

        for index in range(4):
            item = dataset[index]
            assert item.shape[0] == 512
            assert item[0] == TOKEN_TO_ID[""]