fix: --bars now suppresses early EOS until target bar count is reached
Previously the model could emit EOS before reaching n_bars because the EOS-suppression was only applied via the n_bars break, not the grammar bias. Fixed by masking EOS to -inf in the logit bias while bars_completed < n_bars. Added _EosHungryModel fixture and test_generate_bars_overrides_early_eos to catch this regression class. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -268,6 +268,8 @@ def generate_period(
|
||||
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
||||
logits = model(inp)[0, -1] # [vocab_size]
|
||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||
if n_bars is not None and bars_completed < n_bars:
|
||||
bias[_EOS] = float("-inf") # don't let the model stop early
|
||||
logits = logits + bias.to(device)
|
||||
token_id = _sample_top_p(logits, temperature, top_p)
|
||||
ids.append(token_id)
|
||||
|
||||
Reference in New Issue
Block a user