feat: add --bars arg to control output length

generate_period() now accepts n_bars=N to stop after exactly N complete
bars. bars_completed is seeded from the prefix length so --bars counts
the full output, not just the generated tail.

scripts/generate.py exposes this as --bars (default: None = model decides).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:29:44 +03:00
parent f6ce2a41d3
commit 9e73fa5d32
3 changed files with 87 additions and 2 deletions
+72 -2
View File
@@ -1,9 +1,31 @@
"""Tests for src/generate.py — prefix encoding and position tracking."""
import pytest
import torch
import torch.nn as nn
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
from src.tokenizer import TOKEN_TO_ID
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
from src.tokenizer import TOKEN_TO_ID, VOCAB
# ---------------------------------------------------------------------------
# Mock model that outputs uniform logits (EOS suppressed so generation runs
# until the bar-count limit or max_tokens).
# ---------------------------------------------------------------------------
class _UniformModel(nn.Module):
"""Always returns zero logits except EOS=-1000, forcing non-EOS sampling."""
def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
super().__init__()
self.max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._dummy = nn.Parameter(torch.zeros(1)) # gives .parameters() something
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s = x.shape
logits = torch.zeros(b, s, self._vocab_size, device=x.device)
logits[:, :, _EOS] = -1000.0
return logits
def test_encode_prefix_chord_only():
@@ -56,3 +78,51 @@ def test_encode_prefix_position_count_with_holds():
ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0)
assert n_pos == 4
assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens
# ---------------------------------------------------------------------------
# n_bars tests
# ---------------------------------------------------------------------------
def test_generate_exact_bars():
model = _UniformModel()
period = generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
n_bars=4, seed=0,
)
assert len(period.bars) == 4
def test_generate_exact_bars_various():
model = _UniformModel()
for n in (1, 2, 8, 16):
period = generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
n_bars=n, seed=0,
)
assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}"
def test_generate_bars_with_prefix():
# 4-position prefix = 1 bar; n_bars=4 → 3 more bars generated → 4 total
model = _UniformModel()
period = generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
prefix=["C", ".", ".", "."],
n_bars=4, seed=0,
)
assert len(period.bars) == 4
def test_generate_no_bars_arg_still_works():
# Without n_bars the model generates until EOS or max_tokens
model = _UniformModel()
period = generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
max_tokens=64, seed=0,
)
assert len(period.bars) >= 1