diff --git a/scripts/generate.py b/scripts/generate.py index 22e39aa..6873f04 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -79,7 +79,8 @@ def main() -> None: help="Optional output MIDI file path.") ap.add_argument("--prefix", default=None, help='Space-separated chord symbols in the requested key, ' - 'e.g. "Cmaj7 Am7". Used as generation context.') + 'e.g. "Cmaj7 . Am7 .". Use "." for held positions ' + 'and "NC" for no-chord positions.') ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor", help="Do not prepend the tonic chord when --prefix is not given.") ap.add_argument("--temperature", type=float, default=1.0, diff --git a/src/generate.py b/src/generate.py index 2e068de..7dad9d8 100644 --- a/src/generate.py +++ b/src/generate.py @@ -145,22 +145,33 @@ def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod: # Prefix encoding # --------------------------------------------------------------------------- -def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]: - """Encode chord symbols (in user key) as canonical body token IDs. +def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]: + """Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs. *shift* is the semitone offset from user key to canonical (C/Am). - Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...]. + '.' encodes as a single HOLD token; 'NC' encodes as a single NC token. + A chord symbol encodes as four tokens: ROOT QUAL EXT BASS. + Each element (chord, hold, or NC) counts as one bar position. + + Returns (ids, n_positions). """ ids: list[int] = [] + n_positions = 0 for sym in chord_symbols: - t = parse_chord_symbol(sym) - root = _transpose_note(t.root, shift) - bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift) - ids.append(TOKEN_TO_ID[f"ROOT_{root}"]) - ids.append(TOKEN_TO_ID[_qual_token(t.quality)]) - ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"]) - ids.append(TOKEN_TO_ID[f"BASS_{bass}"]) - return ids + if sym == ".": + ids.append(_HOLD) + elif sym == "NC": + ids.append(_NC) + else: + t = parse_chord_symbol(sym) + root = _transpose_note(t.root, shift) + bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift) + ids.append(TOKEN_TO_ID[f"ROOT_{root}"]) + ids.append(TOKEN_TO_ID[_qual_token(t.quality)]) + ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"]) + ids.append(TOKEN_TO_ID[f"BASS_{bass}"]) + n_positions += 1 + return ids, n_positions # --------------------------------------------------------------------------- @@ -233,14 +244,11 @@ def generate_period( positions_per_bar = _expected_positions(time, subdivision) - encoded_prefix: list[int] = [] + pos_in_bar = 0 if prefix: - encoded_prefix = _encode_prefix(prefix, shift_to_canonical) + encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical) ids.extend(encoded_prefix) - - # Compute starting position-in-bar after the prefix - # Each chord in prefix = 4 tokens = 1 position - pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar + pos_in_bar = n_prefix_positions % positions_per_bar last_id = ids[-1] context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max diff --git a/tests/test_generate.py b/tests/test_generate.py new file mode 100644 index 0000000..d2719f7 --- /dev/null +++ b/tests/test_generate.py @@ -0,0 +1,58 @@ +"""Tests for src/generate.py — prefix encoding and position tracking.""" + +import pytest + +from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END +from src.tokenizer import TOKEN_TO_ID + + +def test_encode_prefix_chord_only(): + # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root + ids, n_pos = _encode_prefix(["Cmaj7"], shift=0) + assert len(ids) == 4 + assert n_pos == 1 + assert ids[0] == TOKEN_TO_ID["ROOT_C"] + assert ids[3] == TOKEN_TO_ID["BASS_root"] + + +def test_encode_prefix_hold(): + ids, n_pos = _encode_prefix(["."], shift=0) + assert ids == [_HOLD] + assert n_pos == 1 + + +def test_encode_prefix_nc(): + ids, n_pos = _encode_prefix(["NC"], shift=0) + assert ids == [_NC] + assert n_pos == 1 + + +def test_encode_prefix_mixed(): + # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on + symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."] + # Ab major → shift to canonical C major = (0 - 8) % 12 = 4 + shift = (0 - 8) % 12 + ids, n_pos = _encode_prefix(symbols, shift=shift) + assert n_pos == 8 # 8 bar positions + # Dots become single HOLD tokens; chords become 4-token groups + # Total tokens: 4 chords × 4 + 4 holds × 1 = 20 + assert len(ids) == 20 + # Every dot position must be a HOLD + hold_positions = [1, 3, 5, 7] # indices in the symbol list + # Reconstruct expected token positions: chord at 0→tokens 0-3, dot→token 4, etc. + token_idx = 0 + for sym in symbols: + if sym == ".": + assert ids[token_idx] == _HOLD + token_idx += 1 + else: + assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"] + assert _BASS_START <= ids[token_idx + 3] <= _BASS_END + token_idx += 4 + + +def test_encode_prefix_position_count_with_holds(): + # Mixed prefix: 2 chords and 2 holds = 4 positions + ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0) + assert n_pos == 4 + assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens