fix: clone grammar bias per step in generate_period
_grammar_bias returned a shared module-level singleton that the loop mutated in place (EOS block + repetition penalty). The penalty thus accumulated across positions within a call and persisted across calls, collapsing output to HOLD/NC until process restart. Clone the bias each step so edits stay local. Add regression tests guarding the invariant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+4
-1
@@ -322,7 +322,10 @@ def generate_period(
|
|||||||
break
|
break
|
||||||
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
||||||
logits = model(inp)[0, -1] # [vocab_size]
|
logits = model(inp)[0, -1] # [vocab_size]
|
||||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
# _grammar_bias returns a shared module-level singleton; clone before
|
||||||
|
# any in-place edit so EOS-blocking and the repetition penalty never
|
||||||
|
# leak across positions or across calls.
|
||||||
|
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone()
|
||||||
if n_bars is not None and bars_completed < n_bars:
|
if n_bars is not None and bars_completed < n_bars:
|
||||||
bias[_EOS] = float("-inf") # don't let the model stop early
|
bias[_EOS] = float("-inf") # don't let the model stop early
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import src.generate as gen
|
||||||
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
|
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
|
||||||
from src.tokenizer import TOKEN_TO_ID, VOCAB
|
from src.tokenizer import TOKEN_TO_ID, VOCAB
|
||||||
|
|
||||||
@@ -152,3 +153,52 @@ def test_generate_no_bars_arg_still_works():
|
|||||||
max_tokens=64, seed=0,
|
max_tokens=64, seed=0,
|
||||||
)
|
)
|
||||||
assert len(period.bars) >= 1
|
assert len(period.bars) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression: generation must not mutate shared module-level bias tensors.
|
||||||
|
#
|
||||||
|
# _grammar_bias() returns shared singletons; the loop applied the EOS block and
|
||||||
|
# the repetition penalty *in place*, permanently corrupting them. Within one call
|
||||||
|
# the penalty accumulated across positions (collapsing output to HOLD/NC); across
|
||||||
|
# calls the corruption persisted in the process until restart. Both symptoms are
|
||||||
|
# the same root cause: the per-step bias must be a fresh copy.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Module-level bias singletons that the generation loop must never mutate.
|
||||||
|
_SHARED_BIASES = [
|
||||||
|
"_BIAS_FREE", "_BIAS_FREE_EOS", "_BIAS_QUAL",
|
||||||
|
"_BIAS_EXT", "_BIAS_EXT_NONE", "_BIAS_BASS",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_generation_does_not_mutate_shared_bias_tensors():
|
||||||
|
before = {name: getattr(gen, name).clone() for name in _SHARED_BIASES}
|
||||||
|
|
||||||
|
model = _UniformModel()
|
||||||
|
generate_period(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
n_bars=8, repetition_penalty=1.0, seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name in _SHARED_BIASES:
|
||||||
|
assert torch.equal(getattr(gen, name), before[name]), (
|
||||||
|
f"{name} was mutated in place by generate_period"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_generation_is_reproducible_with_penalty():
|
||||||
|
# Same seed + same params must give identical output no matter how many times
|
||||||
|
# generation ran before. State bleed between calls made this fail.
|
||||||
|
model = _UniformModel()
|
||||||
|
kwargs = dict(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
n_bars=8, repetition_penalty=1.0, seed=123,
|
||||||
|
)
|
||||||
|
first = generate_period(**kwargs)
|
||||||
|
# A differently-seeded call in between perturbs any shared state.
|
||||||
|
generate_period(**{**kwargs, "seed": 7})
|
||||||
|
second = generate_period(**kwargs)
|
||||||
|
assert first.bars == second.bars
|
||||||
|
|||||||
Reference in New Issue
Block a user