Files
hamori/tests/test_generate.py
T
H1K0 7c0d147956 fix: --bars now suppresses early EOS until target bar count is reached
Previously the model could emit EOS before reaching n_bars because the
EOS-suppression was only applied via the n_bars break, not the grammar
bias. Fixed by masking EOS to -inf in the logit bias while
bars_completed < n_bars.

Added _EosHungryModel fixture and test_generate_bars_overrides_early_eos
to catch this regression class.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 20:34:42 +03:00

155 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for src/generate.py — prefix encoding and position tracking."""
import pytest
import torch
import torch.nn as nn
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
from src.tokenizer import TOKEN_TO_ID, VOCAB
# ---------------------------------------------------------------------------
# Mock model that outputs uniform logits (EOS suppressed so generation runs
# until the bar-count limit or max_tokens).
# ---------------------------------------------------------------------------
class _UniformModel(nn.Module):
    """Mock model: every logit is zero except the EOS entry, which is -1000.

    All non-EOS tokens therefore receive the same score, while EOS is pushed
    far below them so that sampling is steered away from it.
    """

    def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
        super().__init__()
        self._vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Registers one parameter so that .parameters() is not empty.
        self._dummy = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (batch, seq, vocab_size) logits on the same device as *x*."""
        batch, seq_len = x.shape
        out = torch.zeros((batch, seq_len, self._vocab_size), device=x.device)
        # Penalise EOS at every batch row and sequence step.
        out[..., _EOS] = -1000.0
        return out
class _EosHungryModel(nn.Module):
    """Mock model that scores EOS at +1000 and every other token at zero.

    Simulates a model that wants to stop at every step, so that callers can
    check EOS suppression.
    """

    def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
        super().__init__()
        self._vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Registers one parameter so that .parameters() is not empty.
        self._dummy = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (batch, seq, vocab_size) logits on the same device as *x*."""
        batch, seq_len = x.shape
        out = torch.zeros((batch, seq_len, self._vocab_size), device=x.device)
        # Make EOS overwhelmingly preferred at every batch row and step.
        out[..., _EOS] = 1000.0
        return out
def test_encode_prefix_chord_only():
    """A single chord symbol encodes to one four-token group at one position.

    "Cmaj7" with shift=0 is expected to become
    ROOT_C QUAL_maj7 EXT_none BASS_root; only the first and last tokens of
    the group are checked here.
    """
    token_ids, positions = _encode_prefix(["Cmaj7"], shift=0)

    assert len(token_ids) == 4
    assert positions == 1

    root_tok = token_ids[0]
    bass_tok = token_ids[3]
    assert root_tok == TOKEN_TO_ID["ROOT_C"]
    assert bass_tok == TOKEN_TO_ID["BASS_root"]
def test_encode_prefix_hold():
    """A lone "." encodes to a single HOLD token occupying one position."""
    expected_ids = [_HOLD]
    token_ids, positions = _encode_prefix(["."], shift=0)
    assert token_ids == expected_ids
    assert positions == 1
def test_encode_prefix_nc():
    """The "NC" symbol encodes to a single NC token occupying one position."""
    expected_ids = [_NC]
    token_ids, positions = _encode_prefix(["NC"], shift=0)
    assert token_ids == expected_ids
    assert positions == 1
def test_encode_prefix_mixed():
    """Interleaved chords and holds encode to the expected token layout.

    Uses "Abmaj7 . Fm7 . Db . Bbm ." — the input a user tried and got a
    parse error on.
    """
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major -> shift to canonical C major = (0 - 8) % 12 = 4
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)

    assert n_pos == 8  # 8 bar positions
    # Dots become single HOLD tokens; chords become 4-token groups.
    # Total tokens: 4 chords * 4 + 4 holds * 1 = 20
    assert len(ids) == 20

    # Walk the token stream alongside the symbols: a chord occupies tokens
    # [i, i + 3], a dot occupies exactly one token.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            # First token of a chord group must be a root, last one a bass.
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4

    # The walk must account for every emitted token.
    assert token_idx == len(ids)
def test_encode_prefix_position_count_with_holds():
    """Two chords plus two holds give four positions and ten tokens."""
    n_chords, n_holds = 2, 2
    token_ids, positions = _encode_prefix(["Am", ".", "G", "."], shift=0)

    assert positions == 4
    # Each chord contributes 4 tokens, each hold contributes 1.
    expected_len = n_chords * 4 + n_holds * 1
    assert len(token_ids) == expected_len
# ---------------------------------------------------------------------------
# n_bars tests
# ---------------------------------------------------------------------------
def test_generate_exact_bars():
    """With n_bars=4 and no prefix, the returned period has exactly 4 bars."""
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        n_bars=4,
        seed=0,
    )
    bar_count = len(period.bars)
    assert bar_count == 4
def test_generate_exact_bars_various():
    """The requested bar count is honoured for several values of n_bars."""
    # Arguments shared by every call; one model instance is reused throughout.
    common = dict(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        seed=0,
    )
    for n in (1, 2, 8, 16):
        period = generate_period(n_bars=n, **common)
        assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}"
def test_generate_bars_with_prefix():
    """Bars supplied via the prefix count toward the n_bars total.

    The 4-position prefix is one bar; with n_bars=4 three more bars are
    generated, for 4 in total.
    """
    one_bar_prefix = ["C", ".", ".", "."]
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        prefix=one_bar_prefix,
        n_bars=4,
        seed=0,
    )
    bar_count = len(period.bars)
    assert bar_count == 4
def test_generate_bars_overrides_early_eos():
    """n_bars must keep generation going even when the model favours EOS.

    The mock model puts a +1000 logit on EOS at every step; the period must
    still reach the requested four bars.
    """
    eos_biased = _EosHungryModel()
    period = generate_period(
        model=eos_biased,
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        n_bars=4,
        seed=0,
    )
    bar_count = len(period.bars)
    assert bar_count == 4
def test_generate_no_bars_arg_still_works():
    """Omitting n_bars still yields a period with at least one bar.

    Without n_bars the model generates until EOS or max_tokens.
    """
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        max_tokens=64,
        seed=0,
    )
    bar_count = len(period.bars)
    assert bar_count >= 1