feat: add --bars arg to control output length

generate_period() now accepts n_bars=N to stop after exactly N complete
bars. bars_completed is seeded from the prefix length so --bars counts
the full output, not just the generated tail.

scripts/generate.py exposes this as --bars (default: None = model decides).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:29:44 +03:00
parent f6ce2a41d3
commit 9e73fa5d32
3 changed files with 87 additions and 2 deletions
+12
View File
@@ -190,6 +190,7 @@ def generate_period(
temperature: float = 1.0,
top_p: float = 0.9,
max_tokens: int = 300,
n_bars: Optional[int] = None,
seed: Optional[int] = None,
) -> ChordPeriod:
"""Generate one harmonic period autoregressively.
@@ -210,6 +211,8 @@ def generate_period(
temperature: Sampling temperature (> 1 = more random).
top_p: Nucleus cutoff probability (0 < top_p <= 1).
max_tokens: Hard cap on generated tokens.
n_bars: Stop after this many complete bars in the output.
Counts bars from the prefix too. None = let the model decide.
seed: RNG seed for reproducibility.
Returns:
@@ -245,10 +248,15 @@ def generate_period(
positions_per_bar = _expected_positions(time, subdivision)
pos_in_bar = 0
bars_completed = 0
if prefix:
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
ids.extend(encoded_prefix)
pos_in_bar = n_prefix_positions % positions_per_bar
bars_completed = n_prefix_positions // positions_per_bar
if n_bars is not None and bars_completed >= n_bars:
log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars)
last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
@@ -268,6 +276,10 @@ def generate_period(
# Advance position counter when a body position is completed
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
if pos_in_bar == 0:
bars_completed += 1
if n_bars is not None and bars_completed >= n_bars:
break
if token_id == _EOS:
break