feat: add vocabulary constants and tokenize/detokenize to tokenizer.py

Adds VOCAB (81 tokens), TOKEN_TO_ID, and ID_TO_TOKEN per spec §5.2.
tokenize_period() transposes to C/Am then emits BOS + metadata tokens +
per-bar chord/HOLD/NC tokens + BAR + EOS.  detokenize_to_period() is the
exact inverse, returning a ChordPeriod in canonical key.  The m(add9)
quality maps to QUAL_m_add9 in the vocab (parentheses not valid in token
names) via _qual_token/_token_qual helpers.

36 new tests cover vocabulary integrity, token sequence structure,
and full round-trip fidelity for all four valid fixture files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-19 15:47:28 +03:00
parent 355341dab9
commit 868af4ac42
2 changed files with 407 additions and 3 deletions
+197 -3
View File
@@ -1,10 +1,17 @@
"""Parser and canonical transposer for .chord files.
"""Parser, transposer, and tokenizer for .chord files.
Public API (token-ID conversion will be added in the next step):
Public API:
parse_chord_file(path: Path) -> ChordPeriod
transpose_to_canonical(period: ChordPeriod) -> ChordPeriod
tokenize_period(period: ChordPeriod) -> list[int]
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
See docs/chord_format_spec.md for the format specification.
Vocabulary constants:
VOCAB -- 81-token ordered list; index == token ID
TOKEN_TO_ID -- {token_string: id}
ID_TO_TOKEN -- alias for VOCAB
See docs/chord_format_spec.md §5.2 for the vocabulary specification.
"""
from __future__ import annotations
@@ -69,6 +76,44 @@ _VALID_FUNCTIONS: frozenset[str] = frozenset({
})
# ---------------------------------------------------------------------------
# Token vocabulary (§5.2)
# ---------------------------------------------------------------------------
VOCAB: list[str] = [
# Special (4)
"<BOS>", "<EOS>", "<PAD>", "<UNK>",
# Mode (2)
"MODE_major", "MODE_minor",
# Time signature (5)
"TIME_4/4", "TIME_3/4", "TIME_6/8", "TIME_2/4", "TIME_12/8",
# Subdivision (2)
"SUB_4", "SUB_8",
# Style (5)
"STYLE_user", "STYLE_jpop", "STYLE_classical", "STYLE_jazz", "STYLE_other",
# Function (9)
"FUNC_verse", "FUNC_prechorus", "FUNC_chorus", "FUNC_bridge",
"FUNC_intro", "FUNC_outro", "FUNC_interlude", "FUNC_other", "FUNC_unspecified",
# Chord root — 12 pitch classes, sharps only (12)
"ROOT_C", "ROOT_C#", "ROOT_D", "ROOT_D#", "ROOT_E", "ROOT_F",
"ROOT_F#", "ROOT_G", "ROOT_G#", "ROOT_A", "ROOT_A#", "ROOT_B",
# Chord quality (18)
"QUAL_maj", "QUAL_m", "QUAL_dim", "QUAL_aug", "QUAL_sus2", "QUAL_sus4",
"QUAL_maj7", "QUAL_m7", "QUAL_7", "QUAL_m7b5", "QUAL_dim7", "QUAL_mM7",
"QUAL_7sus4", "QUAL_aug7", "QUAL_6", "QUAL_m6", "QUAL_add9", "QUAL_m_add9",
# Extension (8)
"EXT_none", "EXT_9", "EXT_b9", "EXT_#9", "EXT_11", "EXT_#11", "EXT_13", "EXT_b13",
# Bass note — 'root' sentinel + 12 pitch classes (13)
"BASS_root", "BASS_C", "BASS_C#", "BASS_D", "BASS_D#", "BASS_E", "BASS_F",
"BASS_F#", "BASS_G", "BASS_G#", "BASS_A", "BASS_A#", "BASS_B",
# Structural (3)
"HOLD", "NC", "BAR",
]
TOKEN_TO_ID: dict[str, int] = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_TOKEN: list[str] = VOCAB
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
@@ -124,6 +169,17 @@ def _transpose_symbol(symbol: str, shift: int, fname: str, bar_no: int) -> str:
return _tokens_to_symbol(ChordTokens(new_root, t.quality, t.extension, new_bass))
def _qual_token(quality: str) -> str:
"""Map canonical quality string → QUAL_x token name."""
return "QUAL_m_add9" if quality == "m(add9)" else f"QUAL_{quality}"
def _token_qual(token: str) -> str:
"""Map QUAL_x token name → canonical quality string."""
suffix = token[5:] # strip "QUAL_"
return "m(add9)" if suffix == "m_add9" else suffix
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@@ -277,3 +333,141 @@ def transpose_to_canonical(period: ChordPeriod) -> ChordPeriod:
canonical_key = "C_major" if mode == "major" else "A_minor"
return replace(period, key=canonical_key, bars=new_bars)
def tokenize_period(period: ChordPeriod) -> list[int]:
"""Transpose a period to canonical key and encode it as a token ID sequence.
Args:
period: A ChordPeriod as returned by parse_chord_file.
Returns:
List of integer token IDs: <BOS>, metadata tokens, per-bar chord
tokens interleaved with HOLD/NC, each bar closed by BAR, then <EOS>.
Raises:
ChordFormatError: If a chord symbol cannot be parsed during transposition.
"""
p = transpose_to_canonical(period)
mode = "major" if p.key == "C_major" else "minor"
ids: list[int] = [TOKEN_TO_ID["<BOS>"]]
ids.append(TOKEN_TO_ID[f"MODE_{mode}"])
ids.append(TOKEN_TO_ID[f"TIME_{p.time}"])
ids.append(TOKEN_TO_ID[f"SUB_{p.subdivision}"])
ids.append(TOKEN_TO_ID[f"STYLE_{p.style}"])
ids.append(TOKEN_TO_ID[f"FUNC_{p.function}"])
for bar in p.bars:
for pos in bar:
if pos == ".":
ids.append(TOKEN_TO_ID["HOLD"])
elif pos == "NC":
ids.append(TOKEN_TO_ID["NC"])
elif pos == "?":
ids.append(TOKEN_TO_ID["<UNK>"])
else:
t = parse_chord_symbol(pos)
ids.append(TOKEN_TO_ID[f"ROOT_{t.root}"])
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{t.bass}"])
ids.append(TOKEN_TO_ID["BAR"])
ids.append(TOKEN_TO_ID["<EOS>"])
return ids
def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
"""Convert a token ID sequence back to a ChordPeriod in canonical key (C/Am).
Args:
token_ids: Sequence produced by tokenize_period.
Returns:
ChordPeriod with key='C_major' or 'A_minor', title='detokenized'.
Raises:
ChordFormatError: If the sequence is structurally malformed.
"""
tokens = [ID_TO_TOKEN[i] for i in token_ids]
n = len(tokens)
idx = 0
def _consume(prefix: str) -> str:
nonlocal idx
if idx >= n:
raise ChordFormatError(
f"unexpected end of token sequence; expected '{prefix}...'"
)
tok = tokens[idx]
if not tok.startswith(prefix):
raise ChordFormatError(
f"expected token starting with '{prefix}', got {tok!r} at position {idx}"
)
idx += 1
return tok[len(prefix):]
if not tokens or tokens[0] != "<BOS>":
got = repr(tokens[0]) if tokens else "empty sequence"
raise ChordFormatError(f"token sequence must start with <BOS>, got {got}")
idx += 1
mode = _consume("MODE_")
time = _consume("TIME_")
subdivision = int(_consume("SUB_"))
style = _consume("STYLE_")
function = _consume("FUNC_")
key = "C_major" if mode == "major" else "A_minor"
bars: list[list[str]] = []
current_bar: list[str] = []
while idx < n:
tok = tokens[idx]
idx += 1
if tok == "<EOS>":
break
elif tok == "BAR":
bars.append(current_bar)
current_bar = []
elif tok == "HOLD":
current_bar.append(".")
elif tok == "NC":
current_bar.append("NC")
elif tok == "<UNK>":
current_bar.append("?")
elif tok.startswith("ROOT_"):
if idx + 3 > n:
raise ChordFormatError(
"incomplete chord token group near end of sequence"
)
qual_tok = tokens[idx]; idx += 1
ext_tok = tokens[idx]; idx += 1
bass_tok = tokens[idx]; idx += 1
root = tok[5:] # strip "ROOT_"
quality = _token_qual(qual_tok)
extension = ext_tok[4:] # strip "EXT_"
bass = bass_tok[5:] # strip "BASS_"
current_bar.append(
_tokens_to_symbol(ChordTokens(root, quality, extension, bass))
)
else:
raise ChordFormatError(f"unexpected token in bar body: {tok!r}")
if current_bar:
raise ChordFormatError(
"token sequence ended without closing BAR before <EOS>"
)
return ChordPeriod(
title="detokenized",
key=key,
time=time,
subdivision=subdivision,
style=style,
function=function,
bars=bars,
)