fix: --bars now suppresses early EOS until target bar count is reached

Previously the model could emit EOS before reaching n_bars because the
EOS-suppression was only applied via the n_bars break, not the grammar
bias. Fixed by masking EOS to -inf in the logit bias while
bars_completed < n_bars.

Added _EosHungryModel fixture and test_generate_bars_overrides_early_eos
to catch this regression class.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:34:42 +03:00
parent 9e73fa5d32
commit 7c0d147956
2 changed files with 28 additions and 0 deletions
+2
View File
@@ -268,6 +268,8 @@ def generate_period(
inp = torch.tensor([ids], dtype=torch.long, device=device)
logits = model(inp)[0, -1] # [vocab_size]
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early
logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id)
+26
View File
@@ -28,6 +28,21 @@ class _UniformModel(nn.Module):
return logits
class _EosHungryModel(nn.Module):
"""Strongly prefers EOS at every step — simulates a model that wants to stop early."""
def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
super().__init__()
self.max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._dummy = nn.Parameter(torch.zeros(1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s = x.shape
logits = torch.zeros(b, s, self._vocab_size, device=x.device)
logits[:, :, _EOS] = 1000.0 # model desperately wants to emit EOS
return logits
def test_encode_prefix_chord_only():
# "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
ids, n_pos = _encode_prefix(["Cmaj7"], shift=0)
@@ -117,6 +132,17 @@ def test_generate_bars_with_prefix():
assert len(period.bars) == 4
def test_generate_bars_overrides_early_eos():
# Model desperately wants EOS — n_bars must prevent it from stopping early
model = _EosHungryModel()
period = generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
n_bars=4, seed=0,
)
assert len(period.bars) == 4
def test_generate_no_bars_arg_still_works():
# Without n_bars the model generates until EOS or max_tokens
model = _UniformModel()