fix: clone grammar bias per step in generate_period

_grammar_bias returned a shared module-level singleton that the loop
mutated in place (EOS block + repetition penalty). The penalty thus
accumulated across positions within a call and persisted across calls,
collapsing output to HOLD/NC until process restart. Clone the bias each
step so edits stay local. Add regression tests guarding the invariant.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 21:14:04 +03:00
parent d1af7bceb8
commit dc037b0895
2 changed files with 54 additions and 1 deletions
+4 -1
View File
@@ -322,7 +322,10 @@ def generate_period(
break break
inp = torch.tensor([ids], dtype=torch.long, device=device) inp = torch.tensor([ids], dtype=torch.long, device=device)
logits = model(inp)[0, -1] # [vocab_size] logits = model(inp)[0, -1] # [vocab_size]
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar) # _grammar_bias returns a shared module-level singleton; clone before
# any in-place edit so EOS-blocking and the repetition penalty never
# leak across positions or across calls.
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone()
if n_bars is not None and bars_completed < n_bars: if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early bias[_EOS] = float("-inf") # don't let the model stop early
+50
View File
@@ -4,6 +4,7 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import src.generate as gen
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
from src.tokenizer import TOKEN_TO_ID, VOCAB from src.tokenizer import TOKEN_TO_ID, VOCAB
@@ -152,3 +153,52 @@ def test_generate_no_bars_arg_still_works():
max_tokens=64, seed=0, max_tokens=64, seed=0,
) )
assert len(period.bars) >= 1 assert len(period.bars) >= 1
# ---------------------------------------------------------------------------
# Regression: generation must not mutate shared module-level bias tensors.
#
# _grammar_bias() returns shared singletons; the loop applied the EOS block and
# the repetition penalty *in place*, permanently corrupting them. Within one call
# the penalty accumulated across positions (collapsing output to HOLD/NC); across
# calls the corruption persisted in the process until restart. Both symptoms are
# the same root cause: the per-step bias must be a fresh copy.
# ---------------------------------------------------------------------------
# Module-level bias singletons that the generation loop must never mutate.
_SHARED_BIASES = [
"_BIAS_FREE", "_BIAS_FREE_EOS", "_BIAS_QUAL",
"_BIAS_EXT", "_BIAS_EXT_NONE", "_BIAS_BASS",
]
def test_generation_does_not_mutate_shared_bias_tensors():
before = {name: getattr(gen, name).clone() for name in _SHARED_BIASES}
model = _UniformModel()
generate_period(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
n_bars=8, repetition_penalty=1.0, seed=0,
)
for name in _SHARED_BIASES:
assert torch.equal(getattr(gen, name), before[name]), (
f"{name} was mutated in place by generate_period"
)
def test_repeated_generation_is_reproducible_with_penalty():
# Same seed + same params must give identical output no matter how many times
# generation ran before. State bleed between calls made this fail.
model = _UniformModel()
kwargs = dict(
model=model, mode="major", time="4/4", subdivision=4,
style="H1K0", function="verse", key="C_major",
n_bars=8, repetition_penalty=1.0, seed=123,
)
first = generate_period(**kwargs)
# A differently-seeded call in between perturbs any shared state.
generate_period(**{**kwargs, "seed": 7})
second = generate_period(**kwargs)
assert first.bars == second.bars