diff --git a/src/generate.py b/src/generate.py index a2d8d7f..96bf984 100644 --- a/src/generate.py +++ b/src/generate.py @@ -268,6 +268,8 @@ def generate_period( inp = torch.tensor([ids], dtype=torch.long, device=device) logits = model(inp)[0, -1] # [vocab_size] bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar) + if n_bars is not None and bars_completed < n_bars: + bias[_EOS] = float("-inf") # don't let the model stop early logits = logits + bias.to(device) token_id = _sample_top_p(logits, temperature, top_p) ids.append(token_id) diff --git a/tests/test_generate.py b/tests/test_generate.py index 8434e71..5c74517 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -28,6 +28,21 @@ class _UniformModel(nn.Module): return logits +class _EosHungryModel(nn.Module): + """Strongly prefers EOS at every step — simulates a model that wants to stop early.""" + def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512): + super().__init__() + self.max_seq_len = max_seq_len + self._vocab_size = vocab_size + self._dummy = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s = x.shape + logits = torch.zeros(b, s, self._vocab_size, device=x.device) + logits[:, :, _EOS] = 1000.0 # model desperately wants to emit EOS + return logits + + def test_encode_prefix_chord_only(): # "Cmaj7" in C major (shift=0) → ROOT_C QUAL_maj7 EXT_none BASS_root ids, n_pos = _encode_prefix(["Cmaj7"], shift=0) @@ -117,6 +132,17 @@ def test_generate_bars_with_prefix(): assert len(period.bars) == 4 +def test_generate_bars_overrides_early_eos(): + # Model desperately wants EOS — n_bars must prevent it from stopping early + model = _EosHungryModel() + period = generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + n_bars=4, seed=0, + ) + assert len(period.bars) == 4 + + def test_generate_no_bars_arg_still_works(): # Without n_bars the model generates until EOS or max_tokens model = _UniformModel()