fix: --bars now suppresses early EOS until target bar count is reached
Previously the model could emit EOS before reaching n_bars, because early stopping was guarded only by the n_bars break and not by the grammar bias. This is fixed by masking EOS to -inf in the logit bias while bars_completed < n_bars. Added the _EosHungryModel fixture and test_generate_bars_overrides_early_eos to catch this class of regression.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -268,6 +268,8 @@ def generate_period(
|
|||||||
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
||||||
logits = model(inp)[0, -1] # [vocab_size]
|
logits = model(inp)[0, -1] # [vocab_size]
|
||||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||||
|
if n_bars is not None and bars_completed < n_bars:
|
||||||
|
bias[_EOS] = float("-inf") # don't let the model stop early
|
||||||
logits = logits + bias.to(device)
|
logits = logits + bias.to(device)
|
||||||
token_id = _sample_top_p(logits, temperature, top_p)
|
token_id = _sample_top_p(logits, temperature, top_p)
|
||||||
ids.append(token_id)
|
ids.append(token_id)
|
||||||
|
|||||||
@@ -28,6 +28,21 @@ class _UniformModel(nn.Module):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class _EosHungryModel(nn.Module):
    """Stub model whose logits overwhelmingly favour EOS at every position.

    Simulates a model that wants to stop generating early.
    """

    def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
        super().__init__()
        self._vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Single placeholder parameter; the forward pass never reads it.
        self._dummy = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return all-zero logits for *x*, except a huge score on the EOS slot."""
        batch, seq_len = x.shape
        out = torch.zeros(batch, seq_len, self._vocab_size, device=x.device)
        # The model desperately wants to emit EOS at every step.
        out[..., _EOS] = 1000.0
        return out
|
||||||
|
|
||||||
|
|
||||||
def test_encode_prefix_chord_only():
|
def test_encode_prefix_chord_only():
|
||||||
# "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
|
# "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root
|
||||||
ids, n_pos = _encode_prefix(["Cmaj7"], shift=0)
|
ids, n_pos = _encode_prefix(["Cmaj7"], shift=0)
|
||||||
@@ -117,6 +132,17 @@ def test_generate_bars_with_prefix():
|
|||||||
assert len(period.bars) == 4
|
assert len(period.bars) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_bars_overrides_early_eos():
    """An EOS-favouring model must still yield exactly ``n_bars`` bars."""
    # The model pushes hard for EOS; n_bars has to keep it from stopping early.
    eos_hungry = _EosHungryModel()
    generation_kwargs = dict(
        mode="major",
        time="4/4",
        subdivision=4,
        style="H1K0",
        function="verse",
        key="C_major",
        n_bars=4,
        seed=0,
    )
    period = generate_period(model=eos_hungry, **generation_kwargs)
    assert len(period.bars) == 4
|
||||||
|
|
||||||
|
|
||||||
def test_generate_no_bars_arg_still_works():
|
def test_generate_no_bars_arg_still_works():
|
||||||
# Without n_bars the model generates until EOS or max_tokens
|
# Without n_bars the model generates until EOS or max_tokens
|
||||||
model = _UniformModel()
|
model = _UniformModel()
|
||||||
|
|||||||
Reference in New Issue
Block a user