fix: clone grammar bias per step in generate_period
_grammar_bias returned a shared module-level singleton that the loop mutated in place (EOS block + repetition penalty). The penalty thus accumulated across positions within a call and persisted across calls, collapsing output to HOLD/NC until process restart. Clone the bias each step so edits stay local. Add regression tests guarding the invariant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import src.generate as gen
|
||||
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
|
||||
from src.tokenizer import TOKEN_TO_ID, VOCAB
|
||||
|
||||
@@ -152,3 +153,52 @@ def test_generate_no_bars_arg_still_works():
|
||||
max_tokens=64, seed=0,
|
||||
)
|
||||
assert len(period.bars) >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: generation must not mutate shared module-level bias tensors.
|
||||
#
|
||||
# _grammar_bias() returns shared singletons; the loop applied the EOS block and
|
||||
# the repetition penalty *in place*, permanently corrupting them. Within one call
|
||||
# the penalty accumulated across positions (collapsing output to HOLD/NC); across
|
||||
# calls the corruption persisted in the process until restart. Both symptoms are
|
||||
# the same root cause: the per-step bias must be a fresh copy.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level bias singletons that the generation loop must never mutate.
|
||||
_SHARED_BIASES = [
|
||||
"_BIAS_FREE", "_BIAS_FREE_EOS", "_BIAS_QUAL",
|
||||
"_BIAS_EXT", "_BIAS_EXT_NONE", "_BIAS_BASS",
|
||||
]
|
||||
|
||||
|
||||
def test_generation_does_not_mutate_shared_bias_tensors():
|
||||
before = {name: getattr(gen, name).clone() for name in _SHARED_BIASES}
|
||||
|
||||
model = _UniformModel()
|
||||
generate_period(
|
||||
model=model, mode="major", time="4/4", subdivision=4,
|
||||
style="H1K0", function="verse", key="C_major",
|
||||
n_bars=8, repetition_penalty=1.0, seed=0,
|
||||
)
|
||||
|
||||
for name in _SHARED_BIASES:
|
||||
assert torch.equal(getattr(gen, name), before[name]), (
|
||||
f"{name} was mutated in place by generate_period"
|
||||
)
|
||||
|
||||
|
||||
def test_repeated_generation_is_reproducible_with_penalty():
|
||||
# Same seed + same params must give identical output no matter how many times
|
||||
# generation ran before. State bleed between calls made this fail.
|
||||
model = _UniformModel()
|
||||
kwargs = dict(
|
||||
model=model, mode="major", time="4/4", subdivision=4,
|
||||
style="H1K0", function="verse", key="C_major",
|
||||
n_bars=8, repetition_penalty=1.0, seed=123,
|
||||
)
|
||||
first = generate_period(**kwargs)
|
||||
# A differently-seeded call in between perturbs any shared state.
|
||||
generate_period(**{**kwargs, "seed": 7})
|
||||
second = generate_period(**kwargs)
|
||||
assert first.bars == second.bars
|
||||
|
||||
Reference in New Issue
Block a user