feat: remove BAR token; bump spec to v2.3; fix max_seq_len

Bar boundaries are now implicit — the detokenizer counts positions per bar
using TIME × SUB, and the generator gates EOS to bar boundaries only.
Removing the deterministic BAR token reduces vocab size from 85 to 84 and
lets the model focus on meaningful predictions.

- src/tokenizer.py: drop BAR from VOCAB (85→84); replace BAR-based
  detokenize_to_period with position-counting logic; add write_chord_file;
  fix _tokens_to_symbol for add9/m(add9) qualities
- tests/test_tokenizer.py: update vocab-size assertions (now 84) and the
  structural-token test; remove bar-count test; add test_no_bar_token_in_vocab
- docs/chord_format_spec.md: bump to v2.3; document BAR removal in §5.2,
  §5.3, §5.4, §5.5, §5.6, §6.2, and changelog
- CLAUDE.md: remove stale BAR reference, update vocab size to 84
- scripts/pretrain.py: raise max_seq_len 256→320 to cover regenerated
  McGill data (mean=83, max=283 tokens with BAR-free tokenizer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 13:56:34 +03:00
parent 329952b02e
commit 4aead2ea20
5 changed files with 92 additions and 54 deletions
+47 -11
View File
@@ -2,6 +2,7 @@
Public API:
parse_chord_file(path: Path) -> ChordPeriod
write_chord_file(period: ChordPeriod, path: Path) -> None
transpose_to_canonical(period: ChordPeriod) -> ChordPeriod
tokenize_period(period: ChordPeriod) -> list[int]
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
@@ -108,8 +109,8 @@ VOCAB: list[str] = [
# Bass note — 'root' sentinel + 12 pitch classes (13)
"BASS_root", "BASS_C", "BASS_C#", "BASS_D", "BASS_D#", "BASS_E", "BASS_F",
"BASS_F#", "BASS_G", "BASS_G#", "BASS_A", "BASS_A#", "BASS_B",
# Structural (3)
"HOLD", "NC", "BAR",
# Structural (2)
"HOLD", "NC",
]
# Inverse of VOCAB: maps each token string to its integer ID (its index in VOCAB).
TOKEN_TO_ID: dict[str, int] = {tok: i for i, tok in enumerate(VOCAB)}
@@ -146,7 +147,12 @@ def _expected_positions(time: str, subdivision: int) -> int:
def _tokens_to_symbol(t: ChordTokens) -> str:
"""Reconstruct a canonical, parseable chord symbol string from ChordTokens."""
quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
# add9/m(add9) already encode the extension; appending another EXT would be
# unparseable. The grammar mask prevents this during generation, but guard here too.
if t.quality in ("add9", "m(add9)"):
quality_ext = t.quality
else:
quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
bass_part = "" if t.bass == "root" else f"/{t.bass}"
return t.root + quality_ext + bass_part
@@ -347,8 +353,9 @@ def tokenize_period(period: ChordPeriod) -> list[int]:
period: A ChordPeriod as returned by parse_chord_file.
Returns:
List of integer token IDs: <BOS>, metadata tokens, per-bar chord
tokens interleaved with HOLD/NC, each bar closed by BAR, then <EOS>.
List of integer token IDs: <BOS>, metadata tokens, a flat sequence of
chord/HOLD/NC tokens for every position across all bars, then <EOS>.
Bar boundaries are implicit: every positions_per_bar positions form one bar.
Raises:
ChordFormatError: If a chord symbol cannot be parsed during transposition.
@@ -381,12 +388,32 @@ def tokenize_period(period: ChordPeriod) -> list[int]:
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{t.bass}"])
ids.append(TOKEN_TO_ID["BAR"])
ids.append(TOKEN_TO_ID["<EOS>"])
return ids
def write_chord_file(period: ChordPeriod, path: Path) -> None:
    """Write *period* to *path* in the textual ``.chord`` layout.

    The output consists of six ``# name: value`` header lines (title, key,
    time, subdivision, style, function), one blank line, and a single line
    holding every bar, with bars delimited by ``|`` and the symbols inside a
    bar separated by spaces. The text is UTF-8 encoded and newline-terminated.

    Args:
        period: ChordPeriod to write.
        path: Destination path (created or overwritten). Missing parent
            directories are created first.
    """
    header = [
        f"# title: {period.title}",
        f"# key: {period.key}",
        f"# time: {period.time}",
        f"# subdivision: {period.subdivision}",
        f"# style: {period.style}",
        f"# function: {period.function}",
    ]
    rendered_bars = [" ".join(symbols) for symbols in period.bars]
    bar_line = "| " + " | ".join(rendered_bars) + " |"
    text = "\n".join([*header, "", bar_line]) + "\n"

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
"""Convert a token ID sequence back to a ChordPeriod in canonical key (C/Am).
@@ -429,9 +456,11 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
function = _consume("FUNC_")
key = "C_major" if mode == "major" else "A_minor"
positions_per_bar = _expected_positions(time, subdivision)
bars: list[list[str]] = []
current_bar: list[str] = []
pos_in_bar = 0
while idx < n:
tok = tokens[idx]
@@ -439,15 +468,15 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
if tok == "<EOS>":
break
elif tok == "BAR":
bars.append(current_bar)
current_bar = []
elif tok == "HOLD":
current_bar.append(".")
pos_in_bar += 1
elif tok == "NC":
current_bar.append("NC")
pos_in_bar += 1
elif tok == "<UNK>":
current_bar.append("?")
pos_in_bar += 1
elif tok.startswith("ROOT_"):
if idx + 3 > n:
raise ChordFormatError(
@@ -463,12 +492,19 @@ def detokenize_to_period(token_ids: list[int]) -> ChordPeriod:
current_bar.append(
_tokens_to_symbol(ChordTokens(root, quality, extension, bass))
)
pos_in_bar += 1
else:
raise ChordFormatError(f"unexpected token in bar body: {tok!r}")
if pos_in_bar == positions_per_bar:
bars.append(current_bar)
current_bar = []
pos_in_bar = 0
if current_bar:
raise ChordFormatError(
"token sequence ended without closing BAR before <EOS>"
log.warning(
"detokenize: discarding partial bar (%d/%d positions filled)",
pos_in_bar, positions_per_bar,
)
return ChordPeriod(