"""Round-trip tests for tokenize_period / detokenize_to_period in src/tokenizer.py. Fixture files used (all from tests/fixtures/): valid_c_major.chord — already canonical; 8 bars, chorus, has F/A slash chord valid_fsharp_major.chord — F# major → C major (shift=6); 4 bars valid_b_minor.chord — B minor → A minor (shift=10); 4 bars valid_gsharp_minor.chord — G# minor → A minor (shift=1); 4 bars """ from pathlib import Path import pytest from src.chord_parser import ChordTokens, parse_chord_symbol from src.tokenizer import ( ID_TO_TOKEN, TOKEN_TO_ID, VOCAB, ChordPeriod, detokenize_to_period, parse_chord_file, tokenize_period, transpose_to_canonical, ) FIXTURES = Path(__file__).parent / "fixtures" VALID_FIXTURES = [ "valid_c_major.chord", "valid_fsharp_major.chord", "valid_b_minor.chord", "valid_gsharp_minor.chord", ] # --------------------------------------------------------------------------- # Vocabulary # --------------------------------------------------------------------------- class TestVocabulary: def test_vocab_has_85_tokens(self): assert len(VOCAB) == 85 def test_no_duplicate_tokens(self): assert len(set(VOCAB)) == 85 def test_token_to_id_covers_all_vocab(self): assert len(TOKEN_TO_ID) == 85 def test_id_to_token_covers_all_vocab(self): assert len(ID_TO_TOKEN) == 85 def test_ids_are_contiguous_from_zero(self): for i, tok in enumerate(VOCAB): assert TOKEN_TO_ID[tok] == i def test_id_to_token_is_inverse_of_token_to_id(self): for i, tok in enumerate(VOCAB): assert ID_TO_TOKEN[i] == tok def test_special_tokens_at_front(self): assert VOCAB[:4] == ["", "", "", ""] def test_structural_tokens_at_end(self): assert VOCAB[-3:] == ["HOLD", "NC", "BAR"] def test_all_roots_present(self): for note in ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"): assert f"ROOT_{note}" in TOKEN_TO_ID def test_all_qualities_present(self): for qual_tok in ( "QUAL_maj", "QUAL_m", "QUAL_dim", "QUAL_aug", "QUAL_sus2", "QUAL_sus4", "QUAL_maj7", "QUAL_m7", "QUAL_7", "QUAL_m7b5", "QUAL_dim7", "QUAL_mM7", "QUAL_7sus4", "QUAL_aug7", "QUAL_6", "QUAL_m6", "QUAL_add9", "QUAL_m_add9", ): assert qual_tok in TOKEN_TO_ID def test_all_extensions_present(self): for ext in ("none", "9", "b9", "#9", "11", "#11", "13", "b13"): assert f"EXT_{ext}" in TOKEN_TO_ID def test_all_bass_notes_present(self): assert "BASS_root" in TOKEN_TO_ID for note in ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"): assert f"BASS_{note}" in TOKEN_TO_ID # --------------------------------------------------------------------------- # Tokenize structure # --------------------------------------------------------------------------- class TestTokenizeStructure: def test_starts_with_bos(self): p = parse_chord_file(FIXTURES / "valid_c_major.chord") assert tokenize_period(p)[0] == TOKEN_TO_ID[""] def test_ends_with_eos(self): p = parse_chord_file(FIXTURES / "valid_c_major.chord") assert tokenize_period(p)[-1] == TOKEN_TO_ID[""] def test_metadata_order_after_bos(self): p = parse_chord_file(FIXTURES / "valid_c_major.chord") toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] assert toks[1] == "MODE_major" assert toks[2] == "TIME_4/4" assert toks[3] == "SUB_4" assert toks[4] == "STYLE_other" # 'unspecified' is not in VOCAB → falls back to STYLE_other assert toks[5] == "FUNC_chorus" def test_bar_token_count_matches_bar_count(self): p = parse_chord_file(FIXTURES / "valid_c_major.chord") ids = tokenize_period(p) assert sum(1 for i in ids if i == TOKEN_TO_ID["BAR"]) == len(p.bars) def test_minor_period_emits_mode_minor(self): p = parse_chord_file(FIXTURES / "valid_b_minor.chord") toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] assert toks[1] == "MODE_minor" def test_missing_function_emits_func_unspecified(self): p = parse_chord_file(FIXTURES / "valid_b_minor.chord") toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] assert toks[5] == "FUNC_unspecified" def test_all_ids_in_vocab_range(self): for fixture_name in VALID_FIXTURES: p = parse_chord_file(FIXTURES / fixture_name) assert all(0 <= i < 85 for i in tokenize_period(p)) def test_non_canonical_key_transposed_before_encoding(self): # F# major: first chord F#maj7 → Cmaj7 after shift=6; ROOT_C is at index 6. p = parse_chord_file(FIXTURES / "valid_fsharp_major.chord") ids = tokenize_period(p) assert ID_TO_TOKEN[ids[6]] == "ROOT_C" # --------------------------------------------------------------------------- # Round-trip # --------------------------------------------------------------------------- class TestRoundTrip: @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) def test_chord_symbols_survive_round_trip(self, fixture_name): p = parse_chord_file(FIXTURES / fixture_name) canonical = transpose_to_canonical(p) recovered = detokenize_to_period(tokenize_period(p)) assert len(recovered.bars) == len(canonical.bars) for bar_c, bar_r in zip(canonical.bars, recovered.bars): assert len(bar_c) == len(bar_r) for sym_c, sym_r in zip(bar_c, bar_r): if sym_c in (".", "NC", "?"): assert sym_r == sym_c else: assert parse_chord_symbol(sym_r) == parse_chord_symbol(sym_c) @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) def test_metadata_survives_round_trip(self, fixture_name): p = parse_chord_file(FIXTURES / fixture_name) canonical = transpose_to_canonical(p) recovered = detokenize_to_period(tokenize_period(p)) assert recovered.key == canonical.key assert recovered.time == canonical.time assert recovered.subdivision == canonical.subdivision assert recovered.style == canonical.style assert recovered.function == canonical.function @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) def test_bar_count_survives_round_trip(self, fixture_name): p = parse_chord_file(FIXTURES / fixture_name) recovered = detokenize_to_period(tokenize_period(p)) assert len(recovered.bars) == len(p.bars) def test_c_major_chord_identity(self): # C major is already canonical — chord tokens must reproduce exactly. p = parse_chord_file(FIXTURES / "valid_c_major.chord") recovered = detokenize_to_period(tokenize_period(p)) assert recovered.key == "C_major" for bar_p, bar_r in zip(p.bars, recovered.bars): for sym_p, sym_r in zip(bar_p, bar_r): if sym_p not in (".", "NC", "?"): assert parse_chord_symbol(sym_r) == parse_chord_symbol(sym_p) def test_slash_chord_root_and_bass_survive(self): # F/A in C major: root=F maj, bass=A must survive the round-trip. p = parse_chord_file(FIXTURES / "valid_c_major.chord") recovered = detokenize_to_period(tokenize_period(p)) t = parse_chord_symbol(recovered.bars[2][0]) assert t == ChordTokens("F", "maj", "none", "A") def test_fsharp_major_tonic_becomes_c_after_round_trip(self): p = parse_chord_file(FIXTURES / "valid_fsharp_major.chord") recovered = detokenize_to_period(tokenize_period(p)) t = parse_chord_symbol(recovered.bars[0][0]) assert t.root == "C" assert t.quality == "maj7" def test_b_minor_tonic_becomes_am_after_round_trip(self): p = parse_chord_file(FIXTURES / "valid_b_minor.chord") recovered = detokenize_to_period(tokenize_period(p)) t = parse_chord_symbol(recovered.bars[0][0]) assert t.root == "A" assert t.quality == "m"