fix: --bars now suppresses early EOS until target bar count is reached
Previously the model could emit EOS before reaching n_bars because the EOS-suppression was only applied via the n_bars break, not the grammar bias. Fixed by masking EOS to -inf in the logit bias while bars_completed < n_bars. Added _EosHungryModel fixture and test_generate_bars_overrides_early_eos to catch this regression class. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -28,6 +28,21 @@ class _UniformModel(nn.Module):
|
||||
return logits
|
||||
|
||||
|
||||
class _EosHungryModel(nn.Module):
|
||||
"""Strongly prefers EOS at every step — simulates a model that wants to stop early."""
|
||||
def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
|
||||
super().__init__()
|
||||
self.max_seq_len = max_seq_len
|
||||
self._vocab_size = vocab_size
|
||||
self._dummy = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, s = x.shape
|
||||
logits = torch.zeros(b, s, self._vocab_size, device=x.device)
|
||||
logits[:, :, _EOS] = 1000.0 # model desperately wants to emit EOS
|
||||
return logits
|
||||
|
||||
|
||||
def test_encode_prefix_chord_only():
|
||||
# "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
|
||||
ids, n_pos = _encode_prefix(["Cmaj7"], shift=0)
|
||||
@@ -117,6 +132,17 @@ def test_generate_bars_with_prefix():
|
||||
assert len(period.bars) == 4
|
||||
|
||||
|
||||
def test_generate_bars_overrides_early_eos():
|
||||
# Model desperately wants EOS — n_bars must prevent it from stopping early
|
||||
model = _EosHungryModel()
|
||||
period = generate_period(
|
||||
model=model, mode="major", time="4/4", subdivision=4,
|
||||
style="H1K0", function="verse", key="C_major",
|
||||
n_bars=4, seed=0,
|
||||
)
|
||||
assert len(period.bars) == 4
|
||||
|
||||
|
||||
def test_generate_no_bars_arg_still_works():
|
||||
# Without n_bars the model generates until EOS or max_tokens
|
||||
model = _UniformModel()
|
||||
|
||||
Reference in New Issue
Block a user