fix: --bars now suppresses early EOS until target bar count is reached

Previously the model could emit EOS before reaching n_bars because the
EOS-suppression was only applied via the n_bars break, not the grammar
bias. Fixed by masking EOS to -inf in the logit bias while
bars_completed < n_bars.

Added _EosHungryModel fixture and test_generate_bars_overrides_early_eos
to catch this regression class.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:34:42 +03:00
parent 9e73fa5d32
commit 7c0d147956
2 changed files with 28 additions and 0 deletions
+2
View File
@@ -268,6 +268,8 @@ def generate_period(
inp = torch.tensor([ids], dtype=torch.long, device=device)
logits = model(inp)[0, -1] # [vocab_size]
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early
logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id)