fix: support '.' and 'NC' in --prefix argument
_encode_prefix now handles hold ('.') and no-chord ('NC') tokens
alongside chord symbols, and returns (ids, n_positions) so that
pos_in_bar is tracked correctly regardless of token type.
Fixes ChordParseError when dots were passed in --prefix.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+2
-1
@@ -79,7 +79,8 @@ def main() -> None:
|
|||||||
help="Optional output MIDI file path.")
|
help="Optional output MIDI file path.")
|
||||||
ap.add_argument("--prefix", default=None,
|
ap.add_argument("--prefix", default=None,
|
||||||
help='Space-separated chord symbols in the requested key, '
|
help='Space-separated chord symbols in the requested key, '
|
||||||
'e.g. "Cmaj7 Am7". Used as generation context.')
|
'e.g. "Cmaj7 . Am7 .". Use "." for held positions '
|
||||||
|
'and "NC" for no-chord positions.')
|
||||||
ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
|
ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
|
||||||
help="Do not prepend the tonic chord when --prefix is not given.")
|
help="Do not prepend the tonic chord when --prefix is not given.")
|
||||||
ap.add_argument("--temperature", type=float, default=1.0,
|
ap.add_argument("--temperature", type=float, default=1.0,
|
||||||
|
|||||||
+18
-10
@@ -145,14 +145,24 @@ def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
|
|||||||
# Prefix encoding
|
# Prefix encoding
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
|
def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
|
||||||
"""Encode chord symbols (in user key) as canonical body token IDs.
|
"""Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs.
|
||||||
|
|
||||||
*shift* is the semitone offset from user key to canonical (C/Am).
|
*shift* is the semitone offset from user key to canonical (C/Am).
|
||||||
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
|
'.' encodes as a single HOLD token; 'NC' encodes as a single NC token.
|
||||||
|
A chord symbol encodes as four tokens: ROOT QUAL EXT BASS.
|
||||||
|
Each element (chord, hold, or NC) counts as one bar position.
|
||||||
|
|
||||||
|
Returns (ids, n_positions).
|
||||||
"""
|
"""
|
||||||
ids: list[int] = []
|
ids: list[int] = []
|
||||||
|
n_positions = 0
|
||||||
for sym in chord_symbols:
|
for sym in chord_symbols:
|
||||||
|
if sym == ".":
|
||||||
|
ids.append(_HOLD)
|
||||||
|
elif sym == "NC":
|
||||||
|
ids.append(_NC)
|
||||||
|
else:
|
||||||
t = parse_chord_symbol(sym)
|
t = parse_chord_symbol(sym)
|
||||||
root = _transpose_note(t.root, shift)
|
root = _transpose_note(t.root, shift)
|
||||||
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
||||||
@@ -160,7 +170,8 @@ def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
|
|||||||
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||||
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||||
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
||||||
return ids
|
n_positions += 1
|
||||||
|
return ids, n_positions
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -233,14 +244,11 @@ def generate_period(
|
|||||||
|
|
||||||
positions_per_bar = _expected_positions(time, subdivision)
|
positions_per_bar = _expected_positions(time, subdivision)
|
||||||
|
|
||||||
encoded_prefix: list[int] = []
|
pos_in_bar = 0
|
||||||
if prefix:
|
if prefix:
|
||||||
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
|
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
|
||||||
ids.extend(encoded_prefix)
|
ids.extend(encoded_prefix)
|
||||||
|
pos_in_bar = n_prefix_positions % positions_per_bar
|
||||||
# Compute starting position-in-bar after the prefix
|
|
||||||
# Each chord in prefix = 4 tokens = 1 position
|
|
||||||
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
|
|
||||||
|
|
||||||
last_id = ids[-1]
|
last_id = ids[-1]
|
||||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for src/generate.py — prefix encoding and position tracking."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
|
||||||
|
from src.tokenizer import TOKEN_TO_ID
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_prefix_chord_only():
    """A lone chord symbol encodes to four tokens and counts as one position."""
    # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
    tokens, positions = _encode_prefix(["Cmaj7"], shift=0)
    assert (len(tokens), positions) == (4, 1)
    first_token, last_token = tokens[0], tokens[3]
    assert first_token == TOKEN_TO_ID["ROOT_C"]
    assert last_token == TOKEN_TO_ID["BASS_root"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_prefix_hold():
    """A single '.' encodes to exactly one HOLD token and one position."""
    tokens, positions = _encode_prefix(["."], shift=0)
    assert (tokens, positions) == ([_HOLD], 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_prefix_nc():
    """A single 'NC' encodes to exactly one NC token and one position."""
    tokens, positions = _encode_prefix(["NC"], shift=0)
    assert (tokens, positions) == ([_NC], 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_prefix_mixed():
    """Chords interleaved with holds encode without error and line up token-by-token.

    Regression test for the prefix that previously failed with a parse error.
    """
    # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major → shift to canonical C major = (0 - 8) % 12 = 4
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)
    assert n_pos == 8  # 8 bar positions
    # Dots become single HOLD tokens; chords become 4-token groups
    # Total tokens: 4 chords × 4 + 4 holds × 1 = 20
    assert len(ids) == 20
    # Walk the symbols and the token stream in lockstep:
    # chord at 0 → tokens 0-3, dot → token 4, etc.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            # Every dot position must be a HOLD
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4
    # The walk must account for every emitted token, with nothing left over.
    assert token_idx == len(ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_prefix_position_count_with_holds():
    """Holds count as bar positions even though they emit only one token each."""
    # Mixed prefix: 2 chords and 2 holds = 4 positions
    symbols = ["Am", ".", "G", "."]
    tokens, positions = _encode_prefix(symbols, shift=0)
    assert positions == 4
    expected_token_count = 2 * 4 + 2 * 1  # 10 tokens
    assert len(tokens) == expected_token_count
|
||||||
Reference in New Issue
Block a user