9e73fa5d32
generate_period() now accepts n_bars=N to stop after exactly N complete bars. bars_completed is seeded from the prefix length so --bars counts the full output, not just the generated tail. scripts/generate.py exposes this as --bars (default: None = model decides). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
129 lines
4.5 KiB
Python
129 lines
4.5 KiB
Python
"""Tests for src/generate.py — prefix encoding and position tracking."""
|
||
|
||
import pytest
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
|
||
from src.tokenizer import TOKEN_TO_ID, VOCAB
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Mock model that outputs uniform logits (EOS suppressed so generation runs
|
||
# until the bar-count limit or max_tokens).
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class _UniformModel(nn.Module):
    """Stub model: every logit is zero except EOS, which is pinned to -1000.

    The large negative EOS logit is meant to keep generation from stopping
    on EOS, so tests run until the bar-count limit or ``max_tokens``.
    """

    def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
        super().__init__()
        self._vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # One dummy parameter so that .parameters() is not empty.
        self._dummy = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, seq, vocab) tensor of zeros with the EOS column at -1000."""
        batch, seq_len = x.shape
        out = torch.zeros(batch, seq_len, self._vocab_size, device=x.device)
        out[..., _EOS] = -1000.0
        return out
|
||
|
||
|
||
def test_encode_prefix_chord_only():
    """A single chord symbol encodes to four tokens covering one position."""
    # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
    tokens, positions = _encode_prefix(["Cmaj7"], shift=0)
    assert len(tokens) == 4
    assert positions == 1
    first, last = tokens[0], tokens[3]
    assert first == TOKEN_TO_ID["ROOT_C"]
    assert last == TOKEN_TO_ID["BASS_root"]
|
||
|
||
|
||
def test_encode_prefix_hold():
    """A lone "." encodes to exactly one HOLD token and one position."""
    tokens, positions = _encode_prefix(["."], shift=0)
    assert positions == 1
    assert tokens == [_HOLD]
|
||
|
||
|
||
def test_encode_prefix_nc():
    """A lone "NC" encodes to exactly one NC token and one position."""
    tokens, positions = _encode_prefix(["NC"], shift=0)
    assert positions == 1
    assert tokens == [_NC]
|
||
|
||
|
||
def test_encode_prefix_mixed():
    """Chords and holds interleaved: check counts and per-symbol token layout."""
    # "Abmaj7 . Fm7 . Db . Bbm ." — what the user tried and got a parse error on
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major → shift to canonical C major = (0 - 8) % 12 = 4
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)
    assert n_pos == 8  # 8 bar positions
    # Dots become single HOLD tokens; chords become 4-token groups
    # Total tokens: 4 chords × 4 + 4 holds × 1 = 20
    assert len(ids) == 20
    # Walk the symbols in order, tracking where each one starts in the token
    # stream: chord at 0 → tokens 0-3, dot → token 4, etc.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            # Every dot position must be a HOLD
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            # A chord group starts with a ROOT token and ends with a BASS token
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4
|
||
|
||
|
||
def test_encode_prefix_position_count_with_holds():
    """Holds count as positions: 2 chords + 2 holds give 4 positions, 10 tokens."""
    prefix = ["Am", ".", "G", "."]
    tokens, positions = _encode_prefix(prefix, shift=0)
    # Mixed prefix: 2 chords and 2 holds = 4 positions
    assert positions == 4
    # Each chord contributes 4 tokens, each hold contributes 1
    expected_tokens = 2 * 4 + 2 * 1
    assert len(tokens) == expected_tokens
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# n_bars tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def test_generate_exact_bars():
    """With n_bars=4 and no prefix, the generated period has exactly 4 bars."""
    settings = dict(
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
    )
    result = generate_period(model=_UniformModel(), n_bars=4, seed=0, **settings)
    assert len(result.bars) == 4
|
||
|
||
|
||
def test_generate_exact_bars_various():
    """The n_bars limit is honoured for several bar counts, reusing one model."""
    model = _UniformModel()
    settings = dict(
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
    )
    bar_counts = (1, 2, 8, 16)
    for n in bar_counts:
        period = generate_period(model=model, n_bars=n, seed=0, **settings)
        assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}"
|
||
|
||
|
||
def test_generate_bars_with_prefix():
    """n_bars counts the prefix too: a 1-bar prefix plus generation gives 4 bars."""
    # 4-position prefix = 1 bar; n_bars=4 → 3 more bars generated → 4 total
    one_bar_prefix = ["C", ".", ".", "."]
    settings = dict(
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
    )
    result = generate_period(
        model=_UniformModel(),
        prefix=one_bar_prefix,
        n_bars=4,
        seed=0,
        **settings,
    )
    assert len(result.bars) == 4
|
||
|
||
|
||
def test_generate_no_bars_arg_still_works():
    """Omitting n_bars still yields at least one bar under a max_tokens cap."""
    # Without n_bars the model generates until EOS or max_tokens
    settings = dict(
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
    )
    result = generate_period(
        model=_UniformModel(),
        max_tokens=64,
        seed=0,
        **settings,
    )
    assert len(result.bars) >= 1
|