feat: add --bars arg to control output length
generate_period() now accepts n_bars=N to stop after exactly N complete bars. bars_completed is seeded from the prefix length so --bars counts the full output, not just the generated tail. scripts/generate.py exposes this as --bars (default: None = model decides). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -89,6 +89,8 @@ def main() -> None:
|
|||||||
help="Nucleus sampling cutoff (default: 0.9).")
|
help="Nucleus sampling cutoff (default: 0.9).")
|
||||||
ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
|
ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
|
||||||
help="Hard cap on generated tokens (default: 300).")
|
help="Hard cap on generated tokens (default: 300).")
|
||||||
|
ap.add_argument("--bars", type=int, default=None,
|
||||||
|
help="Stop after this many complete bars (default: let the model decide).")
|
||||||
ap.add_argument("--seed", type=int, default=None,
|
ap.add_argument("--seed", type=int, default=None,
|
||||||
help="Random seed for reproducibility.")
|
help="Random seed for reproducibility.")
|
||||||
ap.add_argument("--tempo", type=int, default=90,
|
ap.add_argument("--tempo", type=int, default=90,
|
||||||
@@ -136,6 +138,7 @@ def main() -> None:
|
|||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
|
n_bars=args.bars,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -190,6 +190,7 @@ def generate_period(
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 0.9,
|
top_p: float = 0.9,
|
||||||
max_tokens: int = 300,
|
max_tokens: int = 300,
|
||||||
|
n_bars: Optional[int] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
) -> ChordPeriod:
|
) -> ChordPeriod:
|
||||||
"""Generate one harmonic period autoregressively.
|
"""Generate one harmonic period autoregressively.
|
||||||
@@ -210,6 +211,8 @@ def generate_period(
|
|||||||
temperature: Sampling temperature (> 1 = more random).
|
temperature: Sampling temperature (> 1 = more random).
|
||||||
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||||
max_tokens: Hard cap on generated tokens.
|
max_tokens: Hard cap on generated tokens.
|
||||||
|
n_bars: Stop after this many complete bars in the output.
|
||||||
|
Counts bars from the prefix too. None = let the model decide.
|
||||||
seed: RNG seed for reproducibility.
|
seed: RNG seed for reproducibility.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -245,10 +248,15 @@ def generate_period(
|
|||||||
positions_per_bar = _expected_positions(time, subdivision)
|
positions_per_bar = _expected_positions(time, subdivision)
|
||||||
|
|
||||||
pos_in_bar = 0
|
pos_in_bar = 0
|
||||||
|
bars_completed = 0
|
||||||
if prefix:
|
if prefix:
|
||||||
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
|
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
|
||||||
ids.extend(encoded_prefix)
|
ids.extend(encoded_prefix)
|
||||||
pos_in_bar = n_prefix_positions % positions_per_bar
|
pos_in_bar = n_prefix_positions % positions_per_bar
|
||||||
|
bars_completed = n_prefix_positions // positions_per_bar
|
||||||
|
|
||||||
|
if n_bars is not None and bars_completed >= n_bars:
|
||||||
|
log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars)
|
||||||
|
|
||||||
last_id = ids[-1]
|
last_id = ids[-1]
|
||||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||||
@@ -268,6 +276,10 @@ def generate_period(
|
|||||||
# Advance position counter when a body position is completed
|
# Advance position counter when a body position is completed
|
||||||
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||||
|
if pos_in_bar == 0:
|
||||||
|
bars_completed += 1
|
||||||
|
if n_bars is not None and bars_completed >= n_bars:
|
||||||
|
break
|
||||||
|
|
||||||
if token_id == _EOS:
|
if token_id == _EOS:
|
||||||
break
|
break
|
||||||
|
|||||||
+72
-2
@@ -1,9 +1,31 @@
|
|||||||
"""Tests for src/generate.py — prefix encoding and position tracking."""
|
"""Tests for src/generate.py — prefix encoding and position tracking."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from src.generate import _encode_prefix, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END
|
from src.generate import _encode_prefix, _EOS, _HOLD, _NC, _ROOT_START, _BASS_START, _BASS_END, generate_period
|
||||||
from src.tokenizer import TOKEN_TO_ID
|
from src.tokenizer import TOKEN_TO_ID, VOCAB
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mock model that outputs uniform logits (EOS suppressed so generation runs
|
||||||
|
# until the bar-count limit or max_tokens).
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class _UniformModel(nn.Module):
|
||||||
|
"""Always returns zero logits except EOS=-1000, forcing non-EOS sampling."""
|
||||||
|
def __init__(self, vocab_size: int = len(VOCAB), max_seq_len: int = 512):
|
||||||
|
super().__init__()
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self._vocab_size = vocab_size
|
||||||
|
self._dummy = nn.Parameter(torch.zeros(1)) # gives .parameters() something
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, s = x.shape
|
||||||
|
logits = torch.zeros(b, s, self._vocab_size, device=x.device)
|
||||||
|
logits[:, :, _EOS] = -1000.0
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def test_encode_prefix_chord_only():
|
def test_encode_prefix_chord_only():
|
||||||
@@ -56,3 +78,51 @@ def test_encode_prefix_position_count_with_holds():
|
|||||||
ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0)
|
ids, n_pos = _encode_prefix(["Am", ".", "G", "."], shift=0)
|
||||||
assert n_pos == 4
|
assert n_pos == 4
|
||||||
assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens
|
assert len(ids) == 2 * 4 + 2 * 1 # 10 tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# n_bars tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_generate_exact_bars():
|
||||||
|
model = _UniformModel()
|
||||||
|
period = generate_period(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
n_bars=4, seed=0,
|
||||||
|
)
|
||||||
|
assert len(period.bars) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_exact_bars_various():
|
||||||
|
model = _UniformModel()
|
||||||
|
for n in (1, 2, 8, 16):
|
||||||
|
period = generate_period(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
n_bars=n, seed=0,
|
||||||
|
)
|
||||||
|
assert len(period.bars) == n, f"expected {n} bars, got {len(period.bars)}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_bars_with_prefix():
|
||||||
|
# 4-position prefix = 1 bar; n_bars=4 → 3 more bars generated → 4 total
|
||||||
|
model = _UniformModel()
|
||||||
|
period = generate_period(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
prefix=["C", ".", ".", "."],
|
||||||
|
n_bars=4, seed=0,
|
||||||
|
)
|
||||||
|
assert len(period.bars) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_no_bars_arg_still_works():
|
||||||
|
# Without n_bars the model generates until EOS or max_tokens
|
||||||
|
model = _UniformModel()
|
||||||
|
period = generate_period(
|
||||||
|
model=model, mode="major", time="4/4", subdivision=4,
|
||||||
|
style="H1K0", function="verse", key="C_major",
|
||||||
|
max_tokens=64, seed=0,
|
||||||
|
)
|
||||||
|
assert len(period.bars) >= 1
|
||||||
|
|||||||
Reference in New Issue
Block a user