"""Tests for src/generate.py — prefix encoding and position tracking.""" import pytest import torch import torch.nn as nn from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period from src.tokenizer import TOKEN_TO_ID, VOCAB # --------------------------------------------------------------------------- # Mock model that outputs uniform logits (EOS suppressed so generation runs # until the bar-count limit or max_tokens). # --------------------------------------------------------------------------- class _UniformModel(nn.Module): """Always returns zero logits except EOS=-1000, forcing non-EOS sampling.""" def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512): super().__init__() self.max_seq_len = max_seq_len self._vocab_size = vocab_size self._dummy = nn.Parameter(torch.zeros(1)) # gives .parameters() something def forward(self, x: torch.Tensor) -> torch.Tensor: b, s = x.shape logits = torch.zeros(b, s, self._vocab_size, device=x.device) logits[:, :, _EOS] = -1000.0 return logits class _EosHungryModel(nn.Module): """Strongly prefers EOS at every step — simulates a model that wants to stop early.""" def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512): super().__init__() self.max_seq_len = max_seq_len self._vocab_size = vocab_size self._dummy = nn.Parameter(torch.zeros(1)) def forward(self, x: torch.Tensor) -> torch.Tensor: b, s = x.shape logits = torch.zeros(b, s, self._vocab_size, device=x.device) logits[:, :, _EOS] = 1000.0 # model desperately wants to emit EOS return logits def test_encode_prefix_chord_only(): # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root ids, n_pos = _encode_prefix(["Cmaj7"], shift=0) assert len(ids) == 4 assert n_pos == 1 assert ids[0] == TOKEN_TO_ID["ROOT_C"] assert ids[3] == TOKEN_TO_ID["BASS_root"] def test_encode_prefix_hold(): ids, n_pos = _encode_prefix(["."], shift=0) assert ids == [_HOLD] assert n_pos == 1 def test_encode_prefix_nc(): ids, n_pos = _encode_prefix(["NC"], shift=0) assert ids == [_NC] assert n_pos == 1 def test_encode_prefix_mixed(): # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."] # Ab major → shift to canonical C major = (0 - 8) % 12 = 4 shift = (0 - 8) % 12 ids, n_pos = _encode_prefix(symbols, shift=shift) assert n_pos == 8 # 8 bar positions # Dots become single HOLD tokens; chords become 4-token groups # Total tokens: 4 chords × 4 + 4 holds × 1 = 20 assert len(ids) == 20 # Every dot position must be a HOLD hold_positions = [1, 3, 5, 7] # indices in the symbol list # Reconstruct expected token positions: chord at 0→tokens 0-3, dot→token 4, etc. token_idx = 0 for sym in symbols: if sym == ".": assert ids[token_idx] == _HOLD token_idx += 1 else: assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"] assert _BASS_START <= ids[token_idx + 3] <= _BASS_END token_idx += 4 def test_encode_prefix_position_count_with_holds(): # Mixed prefix: 2 chords and 2 holds = 4 positions ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0) assert n_pos == 4 assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens # --------------------------------------------------------------------------- # n_bars tests # --------------------------------------------------------------------------- def test_generate_exact_bars(): model = _UniformModel() period = generate_period( model=model, mode="major", time="4/4", subdivision=4, style="H1K0", function="verse", key="C_major", n_bars=4, seed=0, ) assert len(period.bars) == 4 def test_generate_exact_bars_various(): model = _UniformModel() for n in (1, 2, 8, 16): period = generate_period( model=model, mode="major", time="4/4", subdivision=4, style="H1K0", function="verse", key="C_major", n_bars=n, seed=0, ) assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}" def test_generate_bars_with_prefix(): # 4-position prefix = 1 bar; n_bars=4 → 3 more bars generated → 4 total model = _UniformModel() period = generate_period( model=model, mode="major", time="4/4", subdivision=4, style="H1K0", function="verse", key="C_major", prefix=["C", ".", ".", "."], n_bars=4, seed=0, ) assert len(period.bars) == 4 def test_generate_bars_overrides_early_eos(): # Model desperately wants EOS — n_bars must prevent it from stopping early model = _EosHungryModel() period = generate_period( model=model, mode="major", time="4/4", subdivision=4, style="H1K0", function="verse", key="C_major", n_bars=4, seed=0, ) assert len(period.bars) == 4 def test_generate_no_bars_arg_still_works(): # Without n_bars the model generates until EOS or max_tokens model = _UniformModel() period = generate_period( model=model, mode="major", time="4/4", subdivision=4, style="H1K0", function="verse", key="C_major", max_tokens=64, seed=0, ) assert len(period.bars) >= 1