feat: add --bars arg to control output length
generate_period() now accepts n_bars=N to stop after exactly N complete bars. bars_completed is seeded from the prefix length so --bars counts the full output, not just the generated tail. scripts/generate.py exposes this as --bars (default: None = model decides). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -190,6 +190,7 @@ def generate_period(
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 0.9,
|
||||
max_tokens: int = 300,
|
||||
n_bars: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> ChordPeriod:
|
||||
"""Generate one harmonic period autoregressively.
|
||||
@@ -210,6 +211,8 @@ def generate_period(
|
||||
temperature: Sampling temperature (> 1 = more random).
|
||||
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||
max_tokens: Hard cap on generated tokens.
|
||||
n_bars: Stop after this many complete bars in the output.
|
||||
Counts bars from the prefix too. None = let the model decide.
|
||||
seed: RNG seed for reproducibility.
|
||||
|
||||
Returns:
|
||||
@@ -245,10 +248,15 @@ def generate_period(
|
||||
positions_per_bar = _expected_positions(time, subdivision)
|
||||
|
||||
pos_in_bar = 0
|
||||
bars_completed = 0
|
||||
if prefix:
|
||||
encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
|
||||
ids.extend(encoded_prefix)
|
||||
pos_in_bar = n_prefix_positions % positions_per_bar
|
||||
bars_completed = n_prefix_positions // positions_per_bar
|
||||
|
||||
if n_bars is not None and bars_completed >= n_bars:
|
||||
log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars)
|
||||
|
||||
last_id = ids[-1]
|
||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||
@@ -268,6 +276,10 @@ def generate_period(
|
||||
# Advance position counter when a body position is completed
|
||||
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||
if pos_in_bar == 0:
|
||||
bars_completed += 1
|
||||
if n_bars is not None and bars_completed >= n_bars:
|
||||
break
|
||||
|
||||
if token_id == _EOS:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user