feat: remove BAR token; bump spec to v2.3; fix max_seq_len
Bar boundaries are now implicit: the detokenizer counts positions per bar using TIME × SUB, and the generator gates EOS to bar boundaries only. Removing the deterministic BAR token reduces the vocabulary size from 85 to 84 and lets the model focus on meaningful predictions.

- src/tokenizer.py: drop BAR from VOCAB (85→84); replace BAR-based detokenize_to_period with position-counting logic; add write_chord_file; fix _tokens_to_symbol for add9/m(add9) qualities
- tests/test_tokenizer.py: update vocab-size assertions to 84, update the structural token test, remove the bar-count test, add test_no_bar_token_in_vocab
- docs/chord_format_spec.md: bump to v2.3; document BAR removal in §5.2, §5.3, §5.4, §5.5, §5.6, §6.2, and the changelog
- CLAUDE.md: remove stale BAR reference, update vocab size to 84
- scripts/pretrain.py: raise max_seq_len 256→320 to cover regenerated McGill data (mean=83, max=283 tokens with the BAR-free tokenizer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+47
-11
@@ -2,6 +2,7 @@
|
||||
|
||||
Public API:
|
||||
parse_chord_file(path: Path) -> ChordPeriod
|
||||
write_chord_file(period: ChordPeriod, path: Path) -> None
|
||||
transpose_to_canonical(period: ChordPeriod) -> ChordPeriod
|
||||
tokenize_period(period: ChordPeriod) -> list[int]
|
||||
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
|
||||
@@ -108,8 +109,8 @@ VOCAB: list[str] = [
|
||||
# Bass note — 'root' sentinel + 12 pitch classes (13)
|
||||
"BASS_root", "BASS_C", "BASS_C#", "BASS_D", "BASS_D#", "BASS_E", "BASS_F",
|
||||
"BASS_F#", "BASS_G", "BASS_G#", "BASS_A", "BASS_A#", "BASS_B",
|
||||
# Structural (3)
|
||||
"HOLD", "NC", "BAR",
|
||||
# Structural (2)
|
||||
"HOLD", "NC",
|
||||
]
|
||||
|
||||
TOKEN_TO_ID: dict[str, int] = {tok: i for i, tok in enumerate(VOCAB)}
|
||||
@@ -146,7 +147,12 @@ def _expected_positions(time: str, subdivision: int) -> int:
|
||||
|
||||
def _tokens_to_symbol(t: ChordTokens) -> str:
    """Reconstruct a canonical, parseable chord symbol string from ChordTokens.

    Args:
        t: Chord components; the ``root``, ``quality``, ``extension`` and
            ``bass`` attributes are read.

    Returns:
        The concatenation ``root + quality [+ extension] [+ "/" + bass]``.
        The extension is left out when it is ``"none"`` or when the quality
        is ``add9`` / ``m(add9)``; the slash-bass suffix is left out when
        ``bass`` is ``"root"``.
    """
    # add9/m(add9) already encode the extension; appending another EXT would be
    # unparseable. The grammar mask prevents this during generation, but guard here too.
    if t.quality in ("add9", "m(add9)"):
        quality_ext = t.quality
    else:
        quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
    # "root" is the sentinel meaning "no slash bass".
    bass_part = "" if t.bass == "root" else f"/{t.bass}"
    return t.root + quality_ext + bass_part
|
||||
|
||||
@@ -347,8 +353,9 @@ def tokenize_period(period: ChordPeriod) -> list[int]:
|
||||
period: A ChordPeriod as returned by parse_chord_file.
|
||||
|
||||
Returns:
|
||||
List of integer token IDs: <BOS>, metadata tokens, per-bar chord
|
||||
tokens interleaved with HOLD/NC, each bar closed by BAR, then <EOS>.
|
||||
List of integer token IDs: <BOS>, metadata tokens, a flat sequence of
|
||||
chord/HOLD/NC tokens for every position across all bars, then <EOS>.
|
||||
Bar boundaries are implicit: every positions_per_bar positions form one bar.
|
||||
|
||||
Raises:
|
||||
ChordFormatError: If a chord symbol cannot be parsed during transposition.
|
||||
@@ -381,12 +388,32 @@ def tokenize_period(period: ChordPeriod) -> list[int]:
|
||||
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||
ids.append(TOKEN_TO_ID[f"BASS_{t.bass}"])
|
||||
ids.append(TOKEN_TO_ID["BAR"])
|
||||
|
||||
ids.append(TOKEN_TO_ID["<EOS>"])
|
||||
return ids
|
||||
|
||||
|
||||
def write_chord_file(period: ChordPeriod, path: Path) -> None:
    """Write a ChordPeriod to disk as a .chord text file.

    The output consists of six ``# name: value`` header lines (title, key,
    time, subdivision, style, function), one blank line, and a single line
    containing every bar: the entries of a bar are space-separated, bars are
    separated by `` | ``, and the line is wrapped in a leading ``| `` and a
    trailing `` |``. The file ends with a newline.

    Args:
        period: ChordPeriod to write.
        path: Destination path (created or overwritten). Missing parent
            directories are created.
    """
    header_fields = ("title", "key", "time", "subdivision", "style", "function")
    header = [f"# {name}: {getattr(period, name)}" for name in header_fields]

    bar_texts = [" ".join(bar) for bar in period.bars]
    bar_line = "| " + " | ".join(bar_texts) + " |"

    content = "\n".join([*header, "", bar_line]) + "\n"

    # Create the destination directory only after the content is assembled.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
|
||||
"""Convert a token ID sequence back to a ChordPeriod in canonical key (C/Am).
|
||||
|
||||
@@ -429,9 +456,11 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
|
||||
function = _consume("FUNC_")
|
||||
|
||||
key = "C_major" if mode == "major" else "A_minor"
|
||||
positions_per_bar = _expected_positions(time, subdivision)
|
||||
|
||||
bars: list[list[str]] = []
|
||||
current_bar: list[str] = []
|
||||
pos_in_bar = 0
|
||||
|
||||
while idx < n:
|
||||
tok = tokens[idx]
|
||||
@@ -439,15 +468,15 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
|
||||
|
||||
if tok == "<EOS>":
|
||||
break
|
||||
elif tok == "BAR":
|
||||
bars.append(current_bar)
|
||||
current_bar = []
|
||||
elif tok == "HOLD":
|
||||
current_bar.append(".")
|
||||
pos_in_bar += 1
|
||||
elif tok == "NC":
|
||||
current_bar.append("NC")
|
||||
pos_in_bar += 1
|
||||
elif tok == "<UNK>":
|
||||
current_bar.append("?")
|
||||
pos_in_bar += 1
|
||||
elif tok.startswith("ROOT_"):
|
||||
if idx + 3 > n:
|
||||
raise ChordFormatError(
|
||||
@@ -463,12 +492,19 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
|
||||
current_bar.append(
|
||||
_tokens_to_symbol(ChordTokens(root, quality, extension, bass))
|
||||
)
|
||||
pos_in_bar += 1
|
||||
else:
|
||||
raise ChordFormatError(f"unexpected token in bar body: {tok!r}")
|
||||
|
||||
if pos_in_bar == positions_per_bar:
|
||||
bars.append(current_bar)
|
||||
current_bar = []
|
||||
pos_in_bar = 0
|
||||
|
||||
if current_bar:
|
||||
raise ChordFormatError(
|
||||
"token sequence ended without closing BAR before <EOS>"
|
||||
log.warning(
|
||||
"detokenize: discarding partial bar (%d/%d positions filled)",
|
||||
pos_in_bar, positions_per_bar,
|
||||
)
|
||||
|
||||
return ChordPeriod(
|
||||
|
||||
Reference in New Issue
Block a user