fix: support '.' and 'NC' in --prefix argument
_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type.
Fixes ChordParseError when dots were passed in --prefix.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+25
-17
@@ -145,22 +145,33 @@ def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
|
||||
# Prefix encoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
|
||||
"""Encode chord symbols (in user key) as canonical body token IDs.
|
||||
def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
|
||||
"""Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs.
|
||||
|
||||
*shift* is the semitone offset from user key to canonical (C/Am).
|
||||
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
|
||||
'.' encodes as a single HOLD token; 'NC' encodes as a single NC token.
|
||||
A chord symbol encodes as four tokens: ROOT QUAL EXT BASS.
|
||||
Each element (chord, hold, or NC) counts as one bar position.
|
||||
|
||||
Returns (ids, n_positions).
|
||||
"""
|
||||
ids: list[int] = []
|
||||
n_positions = 0
|
||||
for sym in chord_symbols:
|
||||
t = parse_chord_symbol(sym)
|
||||
root = _transpose_note(t.root, shift)
|
||||
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
||||
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
|
||||
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
||||
return ids
|
||||
if sym == ".":
|
||||
ids.append(_HOLD)
|
||||
elif sym == "NC":
|
||||
ids.append(_NC)
|
||||
else:
|
||||
t = parse_chord_symbol(sym)
|
||||
root = _transpose_note(t.root, shift)
|
||||
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
||||
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
|
||||
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
||||
n_positions += 1
|
||||
return ids, n_positions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -233,14 +244,11 @@ def generate_period(
|
||||
|
||||
positions_per_bar = _expected_positions(time, subdivision)
|
||||
|
||||
encoded_prefix: list[int] = []
|
||||
pos_in_bar = 0
|
||||
if prefix:
|
||||
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
|
||||
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
|
||||
ids.extend(encoded_prefix)
|
||||
|
||||
# Compute starting position-in-bar after the prefix
|
||||
# Each chord in prefix = 4 tokens = 1 position
|
||||
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
|
||||
pos_in_bar = n_prefix_positions % positions_per_bar
|
||||
|
||||
last_id = ids[-1]
|
||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||
|
||||
Reference in New Issue
Block a user