fix: support '.' and 'NC' in --prefix argument

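_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type. The previous
len(encoded_prefix) // 4 calculation assumed every prefix element was a
four-token chord, which does not hold for single-token '.' and 'NC'.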

Fixes ChordParseError when dots were passed in --prefix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:25:41 +03:00
parent 2e6e934564
commit f6ce2a41d3
3 changed files with 85 additions and 18 deletions
+25 -17
View File
@@ -145,22 +145,33 @@ def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
# Prefix encoding
# ---------------------------------------------------------------------------
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
"""Encode chord symbols (in user key) as canonical body token IDs.
def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
"""Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs.
*shift* is the semitone offset from user key to canonical (C/Am).
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
'.' encodes as a single HOLD token; 'NC' encodes as a single NC token.
A chord symbol encodes as four tokens: ROOT QUAL EXT BASS.
Each element (chord, hold, or NC) counts as one bar position.
Returns (ids, n_positions).
"""
ids: list[int] = []
n_positions = 0
for sym in chord_symbols:
t = parse_chord_symbol(sym)
root = _transpose_note(t.root, shift)
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
return ids
if sym == ".":
ids.append(_HOLD)
elif sym == "NC":
ids.append(_NC)
else:
t = parse_chord_symbol(sym)
root = _transpose_note(t.root, shift)
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
n_positions += 1
return ids, n_positions
# ---------------------------------------------------------------------------
@@ -233,14 +244,11 @@ def generate_period(
positions_per_bar = _expected_positions(time, subdivision)
encoded_prefix: list[int] = []
pos_in_bar = 0
if prefix:
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
ids.extend(encoded_prefix)
# Compute starting position-in-bar after the prefix
# Each chord in prefix = 4 tokens = 1 position
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
pos_in_bar = n_prefix_positions % positions_per_bar
last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max