From 868af4ac425c4222b577e6590cbe00d6e3f89a09 Mon Sep 17 00:00:00 2001 From: Masahiko AMANO Date: Tue, 19 May 2026 15:47:28 +0300 Subject: [PATCH] feat: add vocabulary constants and tokenize/detokenize to tokenizer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds VOCAB (81 tokens), TOKEN_TO_ID, and ID_TO_TOKEN per spec §5.2. tokenize_period() transposes to C/Am then emits BOS + metadata tokens + per-bar chord/HOLD/NC tokens + BAR + EOS. detokenize_to_period() is the exact inverse, returning a ChordPeriod in canonical key. The m(add9) quality maps to QUAL_m_add9 in the vocab (parentheses not valid in token names) via _qual_token/_token_qual helpers. 36 new tests cover vocabulary integrity, token sequence structure, and full round-trip fidelity for all four valid fixture files. Co-Authored-By: Claude Sonnet 4.6 --- src/tokenizer.py | 200 +++++++++++++++++++++++++++++++++++++- tests/test_tokenizer.py | 210 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 407 insertions(+), 3 deletions(-) create mode 100644 tests/test_tokenizer.py diff --git a/src/tokenizer.py b/src/tokenizer.py index 50ddd87..1d7da3e 100644 --- a/src/tokenizer.py +++ b/src/tokenizer.py @@ -1,10 +1,17 @@ -"""Parser and canonical transposer for .chord files. +"""Parser, transposer, and tokenizer for .chord files. -Public API (token-ID conversion will be added in the next step): +Public API: parse_chord_file(path: Path) -> ChordPeriod transpose_to_canonical(period: ChordPeriod) -> ChordPeriod + tokenize_period(period: ChordPeriod) -> list[int] + detokenize_to_period(token_ids: list[int]) -> ChordPeriod -See docs/chord_format_spec.md for the format specification. +Vocabulary constants: + VOCAB -- 81-token ordered list; index == token ID + TOKEN_TO_ID -- {token_string: id} + ID_TO_TOKEN -- alias for VOCAB + +See docs/chord_format_spec.md §5.2 for the vocabulary specification. """ from __future__ import annotations @@ -69,6 +76,44 @@ _VALID_FUNCTIONS: frozenset[str] = frozenset({ }) +# --------------------------------------------------------------------------- +# Token vocabulary (§5.2) +# --------------------------------------------------------------------------- + +VOCAB: list[str] = [ + # Special (4) + "", "", "", "", + # Mode (2) + "MODE_major", "MODE_minor", + # Time signature (5) + "TIME_4/4", "TIME_3/4", "TIME_6/8", "TIME_2/4", "TIME_12/8", + # Subdivision (2) + "SUB_4", "SUB_8", + # Style (5) + "STYLE_user", "STYLE_jpop", "STYLE_classical", "STYLE_jazz", "STYLE_other", + # Function (9) + "FUNC_verse", "FUNC_prechorus", "FUNC_chorus", "FUNC_bridge", + "FUNC_intro", "FUNC_outro", "FUNC_interlude", "FUNC_other", "FUNC_unspecified", + # Chord root — 12 pitch classes, sharps only (12) + "ROOT_C", "ROOT_C#", "ROOT_D", "ROOT_D#", "ROOT_E", "ROOT_F", + "ROOT_F#", "ROOT_G", "ROOT_G#", "ROOT_A", "ROOT_A#", "ROOT_B", + # Chord quality (18) + "QUAL_maj", "QUAL_m", "QUAL_dim", "QUAL_aug", "QUAL_sus2", "QUAL_sus4", + "QUAL_maj7", "QUAL_m7", "QUAL_7", "QUAL_m7b5", "QUAL_dim7", "QUAL_mM7", + "QUAL_7sus4", "QUAL_aug7", "QUAL_6", "QUAL_m6", "QUAL_add9", "QUAL_m_add9", + # Extension (8) + "EXT_none", "EXT_9", "EXT_b9", "EXT_#9", "EXT_11", "EXT_#11", "EXT_13", "EXT_b13", + # Bass note — 'root' sentinel + 12 pitch classes (13) + "BASS_root", "BASS_C", "BASS_C#", "BASS_D", "BASS_D#", "BASS_E", "BASS_F", + "BASS_F#", "BASS_G", "BASS_G#", "BASS_A", "BASS_A#", "BASS_B", + # Structural (3) + "HOLD", "NC", "BAR", +] + +TOKEN_TO_ID: dict[str, int] = {tok: i for i, tok in enumerate(VOCAB)} +ID_TO_TOKEN: list[str] = VOCAB + + # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- @@ -124,6 +169,17 @@ def _transpose_symbol(symbol: str, shift: int, fname: str, bar_no: int) -> str: return _tokens_to_symbol(ChordTokens(new_root, t.quality, t.extension, new_bass)) +def _qual_token(quality: str) -> str: + """Map canonical quality string → QUAL_x token name.""" + return "QUAL_m_add9" if quality == "m(add9)" else f"QUAL_{quality}" + + +def _token_qual(token: str) -> str: + """Map QUAL_x token name → canonical quality string.""" + suffix = token[5:] # strip "QUAL_" + return "m(add9)" if suffix == "m_add9" else suffix + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- @@ -277,3 +333,141 @@ def transpose_to_canonical(period: ChordPeriod) -> ChordPeriod: canonical_key = "C_major" if mode == "major" else "A_minor" return replace(period, key=canonical_key, bars=new_bars) + + +def tokenize_period(period: ChordPeriod) -> list[int]: + """Transpose a period to canonical key and encode it as a token ID sequence. + + Args: + period: A ChordPeriod as returned by parse_chord_file. + + Returns: + List of integer token IDs: , metadata tokens, per-bar chord + tokens interleaved with HOLD/NC, each bar closed by BAR, then . + + Raises: + ChordFormatError: If a chord symbol cannot be parsed during transposition. + """ + p = transpose_to_canonical(period) + mode = "major" if p.key == "C_major" else "minor" + + ids: list[int] = [TOKEN_TO_ID[""]] + ids.append(TOKEN_TO_ID[f"MODE_{mode}"]) + ids.append(TOKEN_TO_ID[f"TIME_{p.time}"]) + ids.append(TOKEN_TO_ID[f"SUB_{p.subdivision}"]) + ids.append(TOKEN_TO_ID[f"STYLE_{p.style}"]) + ids.append(TOKEN_TO_ID[f"FUNC_{p.function}"]) + + for bar in p.bars: + for pos in bar: + if pos == ".": + ids.append(TOKEN_TO_ID["HOLD"]) + elif pos == "NC": + ids.append(TOKEN_TO_ID["NC"]) + elif pos == "?": + ids.append(TOKEN_TO_ID[""]) + else: + t = parse_chord_symbol(pos) + ids.append(TOKEN_TO_ID[f"ROOT_{t.root}"]) + ids.append(TOKEN_TO_ID[_qual_token(t.quality)]) + ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"]) + ids.append(TOKEN_TO_ID[f"BASS_{t.bass}"]) + ids.append(TOKEN_TO_ID["BAR"]) + + ids.append(TOKEN_TO_ID[""]) + return ids + + +def detokenize_to_period(token_ids: list[int]) -> ChordPeriod: + """Convert a token ID sequence back to a ChordPeriod in canonical key (C/Am). + + Args: + token_ids: Sequence produced by tokenize_period. + + Returns: + ChordPeriod with key='C_major' or 'A_minor', title='detokenized'. + + Raises: + ChordFormatError: If the sequence is structurally malformed. + """ + tokens = [ID_TO_TOKEN[i] for i in token_ids] + n = len(tokens) + idx = 0 + + def _consume(prefix: str) -> str: + nonlocal idx + if idx >= n: + raise ChordFormatError( + f"unexpected end of token sequence; expected '{prefix}...'" + ) + tok = tokens[idx] + if not tok.startswith(prefix): + raise ChordFormatError( + f"expected token starting with '{prefix}', got {tok!r} at position {idx}" + ) + idx += 1 + return tok[len(prefix):] + + if not tokens or tokens[0] != "": + got = repr(tokens[0]) if tokens else "empty sequence" + raise ChordFormatError(f"token sequence must start with , got {got}") + idx += 1 + + mode = _consume("MODE_") + time = _consume("TIME_") + subdivision = int(_consume("SUB_")) + style = _consume("STYLE_") + function = _consume("FUNC_") + + key = "C_major" if mode == "major" else "A_minor" + + bars: list[list[str]] = [] + current_bar: list[str] = [] + + while idx < n: + tok = tokens[idx] + idx += 1 + + if tok == "": + break + elif tok == "BAR": + bars.append(current_bar) + current_bar = [] + elif tok == "HOLD": + current_bar.append(".") + elif tok == "NC": + current_bar.append("NC") + elif tok == "": + current_bar.append("?") + elif tok.startswith("ROOT_"): + if idx + 3 > n: + raise ChordFormatError( + "incomplete chord token group near end of sequence" + ) + qual_tok = tokens[idx]; idx += 1 + ext_tok = tokens[idx]; idx += 1 + bass_tok = tokens[idx]; idx += 1 + root = tok[5:] # strip "ROOT_" + quality = _token_qual(qual_tok) + extension = ext_tok[4:] # strip "EXT_" + bass = bass_tok[5:] # strip "BASS_" + current_bar.append( + _tokens_to_symbol(ChordTokens(root, quality, extension, bass)) + ) + else: + raise ChordFormatError(f"unexpected token in bar body: {tok!r}") + + if current_bar: + raise ChordFormatError( + "token sequence ended without closing BAR before " + ) + + return ChordPeriod( + title="detokenized", + key=key, + time=time, + subdivision=subdivision, + style=style, + function=function, + bars=bars, + ) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..3f2656f --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,210 @@ +"""Round-trip tests for tokenize_period / detokenize_to_period in src/tokenizer.py. + +Fixture files used (all from tests/fixtures/): + valid_c_major.chord — already canonical; 8 bars, chorus, has F/A slash chord + valid_fsharp_major.chord — F# major → C major (shift=6); 4 bars + valid_b_minor.chord — B minor → A minor (shift=10); 4 bars + valid_gsharp_minor.chord — G# minor → A minor (shift=1); 4 bars +""" + +from pathlib import Path + +import pytest + +from src.chord_parser import ChordTokens, parse_chord_symbol +from src.tokenizer import ( + ID_TO_TOKEN, + TOKEN_TO_ID, + VOCAB, + ChordPeriod, + detokenize_to_period, + parse_chord_file, + tokenize_period, + transpose_to_canonical, +) + +FIXTURES = Path(__file__).parent / "fixtures" + +VALID_FIXTURES = [ + "valid_c_major.chord", + "valid_fsharp_major.chord", + "valid_b_minor.chord", + "valid_gsharp_minor.chord", +] + + +# --------------------------------------------------------------------------- +# Vocabulary +# --------------------------------------------------------------------------- + + +class TestVocabulary: + def test_vocab_has_81_tokens(self): + assert len(VOCAB) == 81 + + def test_no_duplicate_tokens(self): + assert len(set(VOCAB)) == 81 + + def test_token_to_id_covers_all_vocab(self): + assert len(TOKEN_TO_ID) == 81 + + def test_id_to_token_covers_all_vocab(self): + assert len(ID_TO_TOKEN) == 81 + + def test_ids_are_contiguous_from_zero(self): + for i, tok in enumerate(VOCAB): + assert TOKEN_TO_ID[tok] == i + + def test_id_to_token_is_inverse_of_token_to_id(self): + for i, tok in enumerate(VOCAB): + assert ID_TO_TOKEN[i] == tok + + def test_special_tokens_at_front(self): + assert VOCAB[:4] == ["", "", "", ""] + + def test_structural_tokens_at_end(self): + assert VOCAB[-3:] == ["HOLD", "NC", "BAR"] + + def test_all_roots_present(self): + for note in ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"): + assert f"ROOT_{note}" in TOKEN_TO_ID + + def test_all_qualities_present(self): + for qual_tok in ( + "QUAL_maj", "QUAL_m", "QUAL_dim", "QUAL_aug", + "QUAL_sus2", "QUAL_sus4", "QUAL_maj7", "QUAL_m7", + "QUAL_7", "QUAL_m7b5", "QUAL_dim7", "QUAL_mM7", + "QUAL_7sus4", "QUAL_aug7", "QUAL_6", "QUAL_m6", + "QUAL_add9", "QUAL_m_add9", + ): + assert qual_tok in TOKEN_TO_ID + + def test_all_extensions_present(self): + for ext in ("none", "9", "b9", "#9", "11", "#11", "13", "b13"): + assert f"EXT_{ext}" in TOKEN_TO_ID + + def test_all_bass_notes_present(self): + assert "BASS_root" in TOKEN_TO_ID + for note in ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"): + assert f"BASS_{note}" in TOKEN_TO_ID + + +# --------------------------------------------------------------------------- +# Tokenize structure +# --------------------------------------------------------------------------- + + +class TestTokenizeStructure: + def test_starts_with_bos(self): + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + assert tokenize_period(p)[0] == TOKEN_TO_ID[""] + + def test_ends_with_eos(self): + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + assert tokenize_period(p)[-1] == TOKEN_TO_ID[""] + + def test_metadata_order_after_bos(self): + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] + assert toks[1] == "MODE_major" + assert toks[2] == "TIME_4/4" + assert toks[3] == "SUB_4" + assert toks[4] == "STYLE_user" + assert toks[5] == "FUNC_chorus" + + def test_bar_token_count_matches_bar_count(self): + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + ids = tokenize_period(p) + assert sum(1 for i in ids if i == TOKEN_TO_ID["BAR"]) == len(p.bars) + + def test_minor_period_emits_mode_minor(self): + p = parse_chord_file(FIXTURES / "valid_b_minor.chord") + toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] + assert toks[1] == "MODE_minor" + + def test_missing_function_emits_func_unspecified(self): + p = parse_chord_file(FIXTURES / "valid_b_minor.chord") + toks = [ID_TO_TOKEN[i] for i in tokenize_period(p)] + assert toks[5] == "FUNC_unspecified" + + def test_all_ids_in_vocab_range(self): + for fixture_name in VALID_FIXTURES: + p = parse_chord_file(FIXTURES / fixture_name) + assert all(0 <= i < 81 for i in tokenize_period(p)) + + def test_non_canonical_key_transposed_before_encoding(self): + # F# major: first chord F#maj7 → Cmaj7 after shift=6; ROOT_C is at index 6. + p = parse_chord_file(FIXTURES / "valid_fsharp_major.chord") + ids = tokenize_period(p) + assert ID_TO_TOKEN[ids[6]] == "ROOT_C" + + +# --------------------------------------------------------------------------- +# Round-trip +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) + def test_chord_symbols_survive_round_trip(self, fixture_name): + p = parse_chord_file(FIXTURES / fixture_name) + canonical = transpose_to_canonical(p) + recovered = detokenize_to_period(tokenize_period(p)) + + assert len(recovered.bars) == len(canonical.bars) + for bar_c, bar_r in zip(canonical.bars, recovered.bars): + assert len(bar_c) == len(bar_r) + for sym_c, sym_r in zip(bar_c, bar_r): + if sym_c in (".", "NC", "?"): + assert sym_r == sym_c + else: + assert parse_chord_symbol(sym_r) == parse_chord_symbol(sym_c) + + @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) + def test_metadata_survives_round_trip(self, fixture_name): + p = parse_chord_file(FIXTURES / fixture_name) + canonical = transpose_to_canonical(p) + recovered = detokenize_to_period(tokenize_period(p)) + + assert recovered.key == canonical.key + assert recovered.time == canonical.time + assert recovered.subdivision == canonical.subdivision + assert recovered.style == canonical.style + assert recovered.function == canonical.function + + @pytest.mark.parametrize("fixture_name", VALID_FIXTURES) + def test_bar_count_survives_round_trip(self, fixture_name): + p = parse_chord_file(FIXTURES / fixture_name) + recovered = detokenize_to_period(tokenize_period(p)) + assert len(recovered.bars) == len(p.bars) + + def test_c_major_chord_identity(self): + # C major is already canonical — chord tokens must reproduce exactly. + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + recovered = detokenize_to_period(tokenize_period(p)) + assert recovered.key == "C_major" + for bar_p, bar_r in zip(p.bars, recovered.bars): + for sym_p, sym_r in zip(bar_p, bar_r): + if sym_p not in (".", "NC", "?"): + assert parse_chord_symbol(sym_r) == parse_chord_symbol(sym_p) + + def test_slash_chord_root_and_bass_survive(self): + # F/A in C major: root=F maj, bass=A must survive the round-trip. + p = parse_chord_file(FIXTURES / "valid_c_major.chord") + recovered = detokenize_to_period(tokenize_period(p)) + t = parse_chord_symbol(recovered.bars[2][0]) + assert t == ChordTokens("F", "maj", "none", "A") + + def test_fsharp_major_tonic_becomes_c_after_round_trip(self): + p = parse_chord_file(FIXTURES / "valid_fsharp_major.chord") + recovered = detokenize_to_period(tokenize_period(p)) + t = parse_chord_symbol(recovered.bars[0][0]) + assert t.root == "C" + assert t.quality == "maj7" + + def test_b_minor_tonic_becomes_am_after_round_trip(self): + p = parse_chord_file(FIXTURES / "valid_b_minor.chord") + recovered = detokenize_to_period(tokenize_period(p)) + t = parse_chord_symbol(recovered.bars[0][0]) + assert t.root == "A" + assert t.quality == "m"