feat: add bigram repetition penalty to generate_period

Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events.
At each FREE position, subtracts penalty * count(prev→root) from ROOT logits,
capped at 3.0 to prevent NC/HOLD flooding at extreme values.

Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0
(aggressive). Default 0.0 keeps backward compatibility.

Added --repetition-penalty flag to scripts/generate.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:19:24 +03:00
parent 1a63b8e4d8
commit f00a6c1b3a
2 changed files with 98 additions and 15 deletions
+6
View File
@@ -91,6 +91,11 @@ def main() -> None:
help="Hard cap on generated tokens (default: 300).") help="Hard cap on generated tokens (default: 300).")
ap.add_argument("--bars", type=int, default=None, ap.add_argument("--bars", type=int, default=None,
help="Stop after this many complete bars (default: let the model decide).") help="Stop after this many complete bars (default: let the model decide).")
ap.add_argument("--repetition-penalty", type=float, default=0.0,
dest="repetition_penalty",
help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
"from ROOT logits at each chord-change position. "
"0.0 = disabled (default). Suggested: 0.5–1.5.")
ap.add_argument("--seed", type=int, default=None, ap.add_argument("--seed", type=int, default=None,
help="Random seed for reproducibility.") help="Random seed for reproducibility.")
ap.add_argument("--tempo", type=int, default=90, ap.add_argument("--tempo", type=int, default=90,
@@ -140,6 +145,7 @@ def main() -> None:
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
n_bars=args.bars, n_bars=args.bars,
seed=args.seed, seed=args.seed,
repetition_penalty=args.repetition_penalty,
) )
# Give generated periods a readable title # Give generated periods a readable title
+78 -1
View File
@@ -3,13 +3,14 @@
Public API: Public API:
generate_period(model, mode, time, subdivision, style, function, key, generate_period(model, mode, time, subdivision, style, function, key,
prefix=None, temperature=1.0, top_p=0.9, prefix=None, temperature=1.0, top_p=0.9,
max_tokens=300, seed=None) -> ChordPeriod max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import random import random
from collections import Counter
from dataclasses import replace from dataclasses import replace
from typing import Optional from typing import Optional
@@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]}) _ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
# ---------------------------------------------------------------------------
# Repetition-penalty helpers
# ---------------------------------------------------------------------------
def _scan_chord_roots(ids: list[int]) -> list[int]:
    """Collect the ROOT token of each complete chord group found in *ids*.

    A chord group is four consecutive tokens whose IDs fall, in order, inside
    the ROOT, QUAL, EXT and BASS ID ranges.  The scan is left-to-right and
    non-overlapping: after a match the cursor jumps past the whole group,
    otherwise it advances by a single token.

    Args:
        ids: Token IDs to scan.

    Returns:
        ROOT token IDs of the matched groups, in order of appearance.
    """
    found: list[int] = []
    total = len(ids)
    cursor = 0
    # A group needs four tokens, so no match can start in the last three slots.
    while cursor + 3 < total:
        root, qual, ext, bass = ids[cursor:cursor + 4]
        is_chord_group = (
            _ROOT_START <= root <= _ROOT_END
            and _QUAL_START <= qual <= _QUAL_END
            and _EXT_START <= ext <= _EXT_END
            and _BASS_START <= bass <= _BASS_END
        )
        if is_chord_group:
            found.append(root)
            cursor += 4
        else:
            cursor += 1
    return found
def _init_bigram_state(
    ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
    """Derive the initial repetition-tracking state from tokens already buffered.

    Scans *ids* for complete chord groups and tallies every adjacent pair of
    chord roots, so that tracking can start from the prefix / metadata tokens
    rather than from an empty history.

    Args:
        ids: Token IDs currently in the generation buffer.

    Returns:
        ``(last_chord_root_id, bigram_counts)``.  ``last_chord_root_id`` is
        ``None`` when *ids* contains no complete chord; ``bigram_counts`` maps
        ``(previous_root, next_root)`` to its number of occurrences.
    """
    chord_roots = _scan_chord_roots(ids)
    # Pairing the sequence with itself shifted by one yields each consecutive
    # root transition exactly once; Counter tallies the repeats.
    transitions: Counter[tuple[int, int]] = Counter(
        zip(chord_roots, chord_roots[1:])
    )
    if not chord_roots:
        return None, transitions
    return chord_roots[-1], transitions
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor: def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
"""Additive logit bias enforcing token grammar after *last_id*. """Additive logit bias enforcing token grammar after *last_id*.
@@ -192,6 +235,7 @@ def generate_period(
max_tokens: int = 300, max_tokens: int = 300,
n_bars: Optional[int] = None, n_bars: Optional[int] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
repetition_penalty: float = 0.0,
) -> ChordPeriod: ) -> ChordPeriod:
"""Generate one harmonic period autoregressively. """Generate one harmonic period autoregressively.
@@ -214,6 +258,12 @@ def generate_period(
n_bars: Stop after this many complete bars in the output. n_bars: Stop after this many complete bars in the output.
Counts bars from the prefix too. None = let the model decide. Counts bars from the prefix too. None = let the model decide.
seed: RNG seed for reproducibility. seed: RNG seed for reproducibility.
repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
FREE position. Specifically, for each candidate root R,
subtracts min(penalty * count(last_root → R), 3.0) from logits,
where count is the number of times that bigram has appeared
so far, including chords already present in the prefix.
0.0 = disabled (default). Suggested range: 0.5–1.5.
Returns: Returns:
ChordPeriod in *key*. ChordPeriod in *key*.
@@ -261,6 +311,11 @@ def generate_period(
last_id = ids[-1] last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
# Repetition-penalty state: seed from prefix/metadata tokens already in ids.
last_chord_root: int | None
bigram_counts: Counter[tuple[int, int]]
last_chord_root, bigram_counts = _init_bigram_state(ids)
with torch.no_grad(): with torch.no_grad():
for _ in range(max_tokens): for _ in range(max_tokens):
if len(ids) >= context_limit: if len(ids) >= context_limit:
@@ -270,11 +325,33 @@ def generate_period(
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar) bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
if n_bars is not None and bars_completed < n_bars: if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early bias[_EOS] = float("-inf") # don't let the model stop early
# Bigram repetition penalty — applied at FREE positions only.
# Penalises root transitions that have already occurred in this period.
if (
repetition_penalty > 0.0
and last_chord_root is not None
and not _is_mid_chord(last_id)
):
for root_id in range(_ROOT_START, _ROOT_END + 1):
count = bigram_counts.get((last_chord_root, root_id), 0)
if count:
# Cap total reduction at 3.0 logits so NC/HOLD don't
# flood the distribution when all roots are heavily penalised.
bias[root_id] -= min(repetition_penalty * count, 3.0)
logits = logits + bias.to(device) logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p) token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id) ids.append(token_id)
last_id = token_id last_id = token_id
# When a complete chord group is closed, update bigram state.
if _BASS_START <= token_id <= _BASS_END:
new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS
if last_chord_root is not None:
bigram_counts[(last_chord_root, new_root)] += 1
last_chord_root = new_root
# Advance position counter when a body position is completed # Advance position counter when a body position is completed
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC): if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
pos_in_bar = (pos_in_bar + 1) % positions_per_bar pos_in_bar = (pos_in_bar + 1) % positions_per_bar