From dc037b0895b76e890d02cb7dea61d78669f6fa46 Mon Sep 17 00:00:00 2001 From: Masahiko AMANO Date: Thu, 4 Jun 2026 21:14:04 +0300 Subject: [PATCH] fix: clone grammar bias per step in generate_period _grammar_bias returned a shared module-level singleton that the loop mutated in place (EOS block + repetition penalty). The penalty thus accumulated across positions within a call and persisted across calls, collapsing output to HOLD/NC until process restart. Clone the bias each step so edits stay local. Add regression tests guarding the invariant. Co-Authored-By: Claude Opus 4.8 --- src/generate.py | 5 ++++- tests/test_generate.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/generate.py b/src/generate.py index 951543c..c5c11ef 100644 --- a/src/generate.py +++ b/src/generate.py @@ -322,7 +322,10 @@ def generate_period( break inp = torch.tensor([ids], dtype=torch.long, device=device) logits = model(inp)[0, -1] # [vocab_size] - bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar) + # _grammar_bias returns a shared module-level singleton; clone before + # any in-place edit so EOS-blocking and the repetition penalty never + # leak across positions or across calls. + bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone() if n_bars is not None and bars_completed < n_bars: bias[_EOS] = float("-inf") # don't let the model stop early diff --git a/tests/test_generate.py b/tests/test_generate.py index 5c74517..8cad067 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -4,6 +4,7 @@ import pytest import torch import torch.nn as nn +import src.generate as gen from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period from src.tokenizer import TOKEN_TO_ID, VOCAB @@ -152,3 +153,52 @@ def test_generate_no_bars_arg_still_works(): max_tokens=64, seed=0, ) assert len(period.bars) >= 1 + + +# --------------------------------------------------------------------------- +# Regression: generation must not mutate shared module-level bias tensors. +# +# _grammar_bias() returns shared singletons; the loop applied the EOS block and +# the repetition penalty *in place*, permanently corrupting them. Within one call +# the penalty accumulated across positions (collapsing output to HOLD/NC); across +# calls the corruption persisted in the process until restart. Both symptoms are +# the same root cause: the per-step bias must be a fresh copy. +# --------------------------------------------------------------------------- + +# Module-level bias singletons that the generation loop must never mutate. +_SHARED_BIASES = [ + "_BIAS_FREE", "_BIAS_FREE_EOS", "_BIAS_QUAL", + "_BIAS_EXT", "_BIAS_EXT_NONE", "_BIAS_BASS", +] + + +def test_generation_does_not_mutate_shared_bias_tensors(): + before = {name: getattr(gen, name).clone() for name in _SHARED_BIASES} + + model = _UniformModel() + generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + n_bars=8, repetition_penalty=1.0, seed=0, + ) + + for name in _SHARED_BIASES: + assert torch.equal(getattr(gen, name), before[name]), ( + f"{name} was mutated in place by generate_period" + ) + + +def test_repeated_generation_is_reproducible_with_penalty(): + # Same seed + same params must give identical output no matter how many times + # generation ran before. State bleed between calls made this fail. + model = _UniformModel() + kwargs = dict( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + n_bars=8, repetition_penalty=1.0, seed=123, + ) + first = generate_period(**kwargs) + # A differently-seeded call in between perturbs any shared state. + generate_period(**{**kwargs, "seed": 7}) + second = generate_period(**kwargs) + assert first.bars == second.bars