Files
hamori/tests/test_generate.py
T
H1K0 9e73fa5d32 feat: add --bars arg to control output length
generate_period() now accepts n_bars=N to stop after exactly N complete
bars. bars_completed is seeded from the prefix length so --bars counts
the full output, not just the generated tail.

scripts/generate.py exposes this as --bars (default: None = model decides).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 20:29:44 +03:00

129 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for src/generate.py — prefix encoding, position tracking, and n_bars output length."""
import pytest
import torch
import torch.nn as nn
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
from src.tokenizer import TOKEN_TO_ID, VOCAB
# ---------------------------------------------------------------------------
# Mock model that outputs uniform logits (EOS suppressed so generation runs
# until the bar-count limit or max_tokens).
# ---------------------------------------------------------------------------
class _UniformModel(nn.Module):
    """Stub model that emits flat logits with EOS pushed far below everything else.

    Every vocabulary entry gets a logit of 0.0 except EOS, which gets -1000.0,
    so sampling practically never picks EOS and generation is ended by the
    caller's limits (n_bars or max_tokens) instead.
    """

    def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
        super().__init__()
        self._vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # One throwaway parameter so that .parameters() is not empty; the
        # generator presumably inspects it (e.g. for the device) — unverified here.
        self._dummy = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, seq) input to (batch, seq, vocab) logits."""
        batch, seq_len = x.shape
        out = torch.zeros((batch, seq_len, self._vocab_size), device=x.device)
        out[..., _EOS] = -1000.0  # effectively zero probability for EOS
        return out
def test_encode_prefix_chord_only():
    """A lone chord symbol encodes to one 4-token group at a single position."""
    # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
    token_ids, positions = _encode_prefix(["Cmaj7"], shift=0)
    assert len(token_ids) == 4
    assert positions == 1
    first, last = token_ids[0], token_ids[3]
    assert first == TOKEN_TO_ID["ROOT_C"]
    assert last == TOKEN_TO_ID["BASS_root"]
def test_encode_prefix_hold():
    """A '.' continuation encodes to exactly one HOLD token at one position."""
    encoded, n_positions = _encode_prefix(["."], shift=0)
    expected_tokens = [_HOLD]
    assert encoded == expected_tokens
    assert n_positions == 1
def test_encode_prefix_nc():
    """An 'NC' (no chord) symbol encodes to exactly one NC token at one position."""
    encoded, n_positions = _encode_prefix(["NC"], shift=0)
    expected_tokens = [_NC]
    assert encoded == expected_tokens
    assert n_positions == 1
def test_encode_prefix_mixed():
    """Chords and '.' holds interleave correctly in the encoded prefix.

    Regression case: this exact progression ("Abmaj7 . Fm7 . Db . Bbm .")
    previously produced a parse error.
    """
    symbols = ["Abmaj7", ".", "Fm7", ".", "Db", ".", "Bbm", "."]
    # Ab major is transposed to canonical C major: (0 - 8) % 12 == 4.
    shift = (0 - 8) % 12
    ids, n_pos = _encode_prefix(symbols, shift=shift)

    # Every symbol, chord or hold, occupies exactly one bar position.
    assert n_pos == len(symbols) == 8

    # Chords become 4-token groups; dots become a single HOLD token:
    # 4 chords x 4 + 4 holds x 1 = 20 tokens.
    assert len(ids) == 20

    # Walk the symbols alongside the tokens and check each group lines up.
    token_idx = 0
    for sym in symbols:
        if sym == ".":
            assert ids[token_idx] == _HOLD
            token_idx += 1
        else:
            assert _ROOT_START <= ids[token_idx] <= TOKEN_TO_ID["ROOT_B"]
            assert _BASS_START <= ids[token_idx + 3] <= _BASS_END
            token_idx += 4

    # The walk must consume exactly the encoded tokens, no more, no less.
    assert token_idx == len(ids)
def test_encode_prefix_position_count_with_holds():
    """Positions count symbols, not tokens, when chords and holds are mixed."""
    # Two chords plus two holds occupy four positions.
    tokens, positions = _encode_prefix(["Am", ".", "G", "."], shift=0)
    assert positions == 4
    chord_tokens = 2 * 4
    hold_tokens = 2 * 1
    assert len(tokens) == chord_tokens + hold_tokens  # 10 tokens
# ---------------------------------------------------------------------------
# n_bars tests
# ---------------------------------------------------------------------------
def test_generate_exact_bars():
    """Requesting n_bars=4 with no prefix yields exactly four bars."""
    requested = 4
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        n_bars=requested,
        seed=0,
    )
    assert len(period.bars) == requested
@pytest.mark.parametrize("n", [1, 2, 8, 16])
def test_generate_exact_bars_various(n):
    """generate_period() honours n_bars exactly across several lengths.

    Parametrized so each bar count is reported as its own test case; the
    previous in-test loop stopped at the first mismatch and hid the rest.
    """
    model = _UniformModel()
    period = generate_period(
        model=model, mode="major", time="4/4", subdivision=4,
        style="H1K0", function="verse", key="C_major",
        n_bars=n, seed=0,
    )
    assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}"
def test_generate_bars_with_prefix():
    """n_bars counts the prefix bars too, not only the generated tail."""
    # At 4/4 with subdivision 4, this 4-position prefix fills one bar, so only
    # three more bars should be generated to reach 4 bars in total.
    one_bar_prefix = ["C", ".", ".", "."]
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        prefix=one_bar_prefix,
        n_bars=4,
        seed=0,
    )
    assert len(period.bars) == 4
def test_generate_no_bars_arg_still_works():
    """Omitting n_bars keeps the old stopping rule (EOS or the max_tokens budget)."""
    # _UniformModel practically never emits EOS, so the 64-token budget is what
    # ends generation here.
    period = generate_period(
        model=_UniformModel(),
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        max_tokens=64,
        seed=0,
    )
    assert len(period.bars) > 0