diff --git a/scripts/generate.py b/scripts/generate.py index 6873f04..f6dfb23 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -89,6 +89,8 @@ def main() -> None: help="Nucleus sampling cutoff (default: 0.9).") ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens", help="Hard cap on generated tokens (default: 300).") + ap.add_argument("--bars", type=int, default=None, + help="Stop after this many complete bars (default: let the model decide).") ap.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.") ap.add_argument("--tempo", type=int, default=90, @@ -136,6 +138,7 @@ def main() -> None: temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens, + n_bars=args.bars, seed=args.seed, ) diff --git a/src/generate.py b/src/generate.py index 7dad9d8..a2d8d7f 100644 --- a/src/generate.py +++ b/src/generate.py @@ -190,6 +190,7 @@ def generate_period( temperature: float = 1.0, top_p: float = 0.9, max_tokens: int = 300, + n_bars: Optional[int] = None, seed: Optional[int] = None, ) -> ChordPeriod: """Generate one harmonic period autoregressively. @@ -210,6 +211,8 @@ def generate_period( temperature: Sampling temperature (> 1 = more random). top_p: Nucleus cutoff probability (0 < top_p <= 1). max_tokens: Hard cap on generated tokens. + n_bars: Stop after this many complete bars in the output. + Counts bars from the prefix too. None = let the model decide. seed: RNG seed for reproducibility. Returns: @@ -245,10 +248,15 @@ def generate_period( positions_per_bar = _expected_positions(time, subdivision) pos_in_bar = 0 + bars_completed = 0 if prefix: encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical) ids.extend(encoded_prefix) pos_in_bar = n_prefix_positions % positions_per_bar + bars_completed = n_prefix_positions // positions_per_bar + + if n_bars is not None and bars_completed >= n_bars: + log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars) last_id = ids[-1] context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max @@ -268,6 +276,10 @@ def generate_period( # Advance position counter when a body position is completed if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC): pos_in_bar = (pos_in_bar + 1) % positions_per_bar + if pos_in_bar == 0: + bars_completed += 1 + if n_bars is not None and bars_completed >= n_bars: + break if token_id == _EOS: break diff --git a/tests/test_generate.py b/tests/test_generate.py index d2719f7..8434e71 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -1,9 +1,31 @@ """Tests for src/generate.py — prefix encoding and position tracking.""" import pytest +import torch +import torch.nn as nn -from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END -from src.tokenizer import TOKEN_TO_ID +from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period +from src.tokenizer import TOKEN_TO_ID, VOCAB + + +# --------------------------------------------------------------------------- +# Mock model that outputs uniform logits (EOS suppressed so generation runs +# until the bar-count limit or max_tokens). +# --------------------------------------------------------------------------- + +class _UniformModel(nn.Module): + """Always returns zero logits except EOS=-1000, forcing non-EOS sampling.""" + def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512): + super().__init__() + self.max_seq_len = max_seq_len + self._vocab_size = vocab_size + self._dummy = nn.Parameter(torch.zeros(1)) # gives .parameters() something + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s = x.shape + logits = torch.zeros(b, s, self._vocab_size, device=x.device) + logits[:, :, _EOS] = -1000.0 + return logits def test_encode_prefix_chord_only(): @@ -56,3 +78,51 @@ def test_encode_prefix_position_count_with_holds(): ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0) assert n_pos == 4 assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens + + +# --------------------------------------------------------------------------- +# n_bars tests +# --------------------------------------------------------------------------- + +def test_generate_exact_bars(): + model = _UniformModel() + period = generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + n_bars=4, seed=0, + ) + assert len(period.bars) == 4 + + +def test_generate_exact_bars_various(): + model = _UniformModel() + for n in (1, 2, 8, 16): + period = generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + n_bars=n, seed=0, + ) + assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}" + + +def test_generate_bars_with_prefix(): + # 4-position prefix = 1 bar; n_bars=4 → 3 more bars generated → 4 total + model = _UniformModel() + period = generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + prefix=["C", ".", ".", "."], + n_bars=4, seed=0, + ) + assert len(period.bars) == 4 + + +def test_generate_no_bars_arg_still_works(): + # Without n_bars the model generates until EOS or max_tokens + model = _UniformModel() + period = generate_period( + model=model, mode="major", time="4/4", subdivision=4, + style="H1K0", function="verse", key="C_major", + max_tokens=64, seed=0, + ) + assert len(period.bars) >= 1