Files
hamori/tests/test_generate.py
T
H1K0 f6ce2a41d3 fix: support '.' and 'NC' in --prefix argument
_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type.

Fixes ChordParseError when dots were passed in --prefix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 20:25:41 +03:00

59 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for src/generate.py — prefix encoding and position tracking."""
import pytest
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
from src.tokenizer import TOKEN_TO_ID
def test_encode_prefix_chord_only():
    """A single chord symbol encodes to one 4-token group and one bar position."""
    # Expected stream for "Cmaj7" with shift=0: ROOT_C QUAL_maj7 EXT_none BASS_root
    token_ids, positions = _encode_prefix(["Cmaj7"], shift=0)
    assert len(token_ids) == 4
    assert positions == 1
    root_id, bass_id = token_ids[0], token_ids[-1]
    assert root_id == TOKEN_TO_ID["ROOT_C"]
    assert bass_id == TOKEN_TO_ID["BASS_root"]
def test_encode_prefix_hold():
    """A '.' symbol encodes to exactly one HOLD token and counts as one position."""
    expected_ids = [_HOLD]
    result_ids, position_count = _encode_prefix(["."], shift=0)
    assert result_ids == expected_ids
    assert position_count == 1
def test_encode_prefix_nc():
    """An 'NC' (no-chord) symbol encodes to exactly one NC token and one position."""
    expected_ids = [_NC]
    result_ids, position_count = _encode_prefix(["NC"], shift=0)
    assert result_ids == expected_ids
    assert position_count == 1
def test_encode_prefix_mixed():
    """Encode a chord/hold mix that previously failed with a parse error.

    Each chord symbol is expected to become a 4-token group (starting with a
    ROOT token and ending with a BASS token). Each '.' is expected to become a
    single HOLD token. Every symbol advances one bar position.
    """
    # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major -> canonical C major: (0 - 8) % 12 = 4 (Ab is pitch class 8)
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)

    # One bar position per symbol, regardless of token type (8 here)
    assert n_pos == len(symbols)
    # 4 chords x 4 tokens + 4 holds x 1 token = 20
    assert len(ids) == 20

    # Walk the token stream symbol by symbol. A dot must map to exactly one
    # HOLD token. A chord group must start with a ROOT token and end with a
    # BASS token.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4
    # The walk must consume the entire stream: no stray or missing tokens.
    assert token_idx == len(ids)
def test_encode_prefix_position_count_with_holds():
    """Holds count as bar positions even though each emits a single token."""
    # Mixed prefix: 2 chords (4 tokens each) and 2 holds (1 token each)
    chord_count, hold_count = 2, 2
    token_ids, positions = _encode_prefix(["Am", ".", "G", "."], shift=0)
    assert positions == chord_count + hold_count  # 4 positions
    assert len(token_ids) == chord_count * 4 + hold_count * 1  # 10 tokens