fix: support '.' and 'NC' in --prefix argument

_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type.

Fixes ChordParseError when dots were passed in --prefix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:25:41 +03:00
parent 2e6e934564
commit f6ce2a41d3
3 changed files with 85 additions and 18 deletions
+2 -1
View File
@@ -79,7 +79,8 @@ def main() -> None:
help="Optional output MIDI file path.") help="Optional output MIDI file path.")
ap.add_argument("--prefix", default=None, ap.add_argument("--prefix", default=None,
help='Space-separated chord symbols in the requested key, ' help='Space-separated chord symbols in the requested key, '
'e.g. "Cmaj7 Am7". Used as generation context.') 'e.g. "Cmaj7 . Am7 .". Use "." for held positions '
'and "NC" for no-chord positions.')
ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor", ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
help="Do not prepend the tonic chord when --prefix is not given.") help="Do not prepend the tonic chord when --prefix is not given.")
ap.add_argument("--temperature", type=float, default=1.0, ap.add_argument("--temperature", type=float, default=1.0,
+25 -17
View File
@@ -145,22 +145,33 @@ def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
# Prefix encoding # Prefix encoding
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]: def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
"""Encode chord symbols (in user key) as canonical body token IDs. """Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs.
*shift* is the semitone offset from user key to canonical (C/Am). *shift* is the semitone offset from user key to canonical (C/Am).
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...]. '.' encodes as a single HOLD token; 'NC' encodes as a single NC token.
A chord symbol encodes as four tokens: ROOT QUAL EXT BASS.
Each element (chord, hold, or NC) counts as one bar position.
Returns (ids, n_positions).
""" """
ids: list[int] = [] ids: list[int] = []
n_positions = 0
for sym in chord_symbols: for sym in chord_symbols:
t = parse_chord_symbol(sym) if sym == ".":
root = _transpose_note(t.root, shift) ids.append(_HOLD)
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift) elif sym == "NC":
ids.append(TOKEN_TO_ID[f"ROOT_{root}"]) ids.append(_NC)
ids.append(TOKEN_TO_ID[_qual_token(t.quality)]) else:
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"]) t = parse_chord_symbol(sym)
ids.append(TOKEN_TO_ID[f"BASS_{bass}"]) root = _transpose_note(t.root, shift)
return ids bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
n_positions += 1
return ids, n_positions
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -233,14 +244,11 @@ def generate_period(
positions_per_bar = _expected_positions(time, subdivision) positions_per_bar = _expected_positions(time, subdivision)
encoded_prefix: list[int] = [] pos_in_bar = 0
if prefix: if prefix:
encoded_prefix = _encode_prefix(prefix, shift_to_canonical) encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
ids.extend(encoded_prefix) ids.extend(encoded_prefix)
pos_in_bar = n_prefix_positions % positions_per_bar
# Compute starting position-in-bar after the prefix
# Each chord in prefix = 4 tokens = 1 position
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
last_id = ids[-1] last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
+58
View File
@@ -0,0 +1,58 @@
"""Tests for src/generate.py — prefix encoding and position tracking."""
import pytest
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
from src.tokenizer import TOKEN_TO_ID
def test_encode_prefix_chord_only():
# "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
ids, n_pos = _encode_prefix(["Cmaj7"], shift=0)
assert len(ids) == 4
assert n_pos == 1
assert ids[0] == TOKEN_TO_ID["ROOT_C"]
assert ids[3] == TOKEN_TO_ID["BASS_root"]
def test_encode_prefix_hold():
ids, n_pos = _encode_prefix(["."], shift=0)
assert ids == [_HOLD]
assert n_pos == 1
def test_encode_prefix_nc():
ids, n_pos = _encode_prefix(["NC"], shift=0)
assert ids == [_NC]
assert n_pos == 1
def test_encode_prefix_mixed():
# "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on
symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
# Ab major → shift to canonical C major = (0 - 8) % 12 = 4
shift = (0 - 8) % 12
ids, n_pos = _encode_prefix(symbols, shift=shift)
assert n_pos == 8 # 8 bar positions
# Dots become single HOLD tokens; chords become 4-token groups
# Total tokens: 4 chords × 4 + 4 holds × 1 = 20
assert len(ids) == 20
# Every dot position must be a HOLD
hold_positions = [1, 3, 5, 7] # indices in the symbol list
# Reconstruct expected token positions: chord at 0→tokens 0-3, dot→token 4, etc.
token_idx = 0
for sym in symbols:
if sym == ".":
assert ids[token_idx] == _HOLD
token_idx += 1
else:
assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
token_idx += 4
def test_encode_prefix_position_count_with_holds():
# Mixed prefix: 2 chords and 2 holds = 4 positions
ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0)
assert n_pos == 4
assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens