fix: support '.' and 'NC' in --prefix argument
_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type.
Fixes ChordParseError when dots were passed in --prefix.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
"""Tests for src/generate.py — prefix encoding and position tracking."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
|
||||
from src.tokenizer import TOKEN_TO_ID
|
||||
|
||||
|
||||
def test_encode_prefix_chord_only():
    """A lone chord symbol yields a four-token group counted as one position."""
    # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
    token_ids, position_count = _encode_prefix(["Cmaj7"], shift=0)

    assert len(token_ids) == 4
    assert position_count == 1

    # Only the outer slots of the group are pinned: root first, bass last.
    expected_by_slot = {0: "ROOT_C", 3: "BASS_root"}
    for slot, token_name in expected_by_slot.items():
        assert token_ids[slot] == TOKEN_TO_ID[token_name]
|
||||
|
||||
|
||||
def test_encode_prefix_hold():
    """A single '.' encodes to exactly one HOLD token and one position."""
    prefix = ["."]
    encoded, position_count = _encode_prefix(prefix, shift=0)

    assert encoded == [_HOLD]
    assert position_count == 1
|
||||
|
||||
|
||||
def test_encode_prefix_nc():
    """A single 'NC' encodes to exactly one no-chord token and one position."""
    prefix = ["NC"]
    encoded, position_count = _encode_prefix(prefix, shift=0)

    assert encoded == [_NC]
    assert position_count == 1
|
||||
|
||||
|
||||
def test_encode_prefix_mixed():
    """Chords interleaved with holds encode without error and stay aligned.

    Regression test for the prefix that previously raised a parse error:
    each chord must expand to a four-token group, each '.' to one HOLD
    token, and every symbol must count as one position.
    """
    # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major → shift to canonical C major = (0 - 8) % 12 = 4
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)

    assert n_pos == 8  # 8 bar positions
    # Dots become single HOLD tokens; chords become 4-token groups
    # Total tokens: 4 chords × 4 + 4 holds × 1 = 20
    assert len(ids) == 20

    # Walk the token stream in step with the symbols: a chord consumes
    # tokens [i, i+3] (root first, bass last), a dot consumes one HOLD token.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            # Every dot position must be a HOLD
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4
|
||||
|
||||
|
||||
def test_encode_prefix_position_count_with_holds():
    """Positions are counted per symbol, not per emitted token."""
    # Mixed prefix: 2 chords and 2 holds = 4 positions
    prefix = ["Am", ".", "G", "."]
    chord_count, hold_count = 2, 2

    token_ids, position_count = _encode_prefix(prefix, shift=0)

    assert position_count == 4
    # Four tokens per chord plus one per hold: 10 tokens
    assert len(token_ids) == chord_count * 4 + hold_count * 1
|
||||
Reference in New Issue
Block a user