feat: add bigram repetition penalty to generate_period

Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events.
At each FREE position, subtracts penalty * count(prev→root) from ROOT logits,
capped at 3.0 to prevent NC/HOLD flooding at extreme values.

Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0
(aggressive). Default 0.0 keeps backward compatibility.

Added --repetition-penalty flag to scripts/generate.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:19:24 +03:00
parent 1a63b8e4d8
commit f00a6c1b3a
2 changed files with 98 additions and 15 deletions
+6
View File
@@ -91,6 +91,11 @@ def main() -> None:
help="Hard cap on generated tokens (default: 300).")
ap.add_argument("--bars", type=int, default=None,
help="Stop after this many complete bars (default: let the model decide).")
ap.add_argument("--repetition-penalty", type=float, default=0.0,
dest="repetition_penalty",
help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
"from ROOT logits at each chord-change position. "
"0.0 = disabled (default). Suggested: 0.51.5.")
ap.add_argument("--seed", type=int, default=None,
help="Random seed for reproducibility.")
ap.add_argument("--tempo", type=int, default=90,
@@ -140,6 +145,7 @@ def main() -> None:
max_tokens=args.max_tokens,
n_bars=args.bars,
seed=args.seed,
repetition_penalty=args.repetition_penalty,
)
# Give generated periods a readable title
+78 -1
View File
@@ -3,13 +3,14 @@
Public API:
generate_period(model, mode, time, subdivision, style, function, key,
prefix=None, temperature=1.0, top_p=0.9,
max_tokens=300, seed=None) -> ChordPeriod
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
"""
from __future__ import annotations
import logging
import random
from collections import Counter
from dataclasses import replace
from typing import Optional
@@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
# ---------------------------------------------------------------------------
# Repetition-penalty helpers
# ---------------------------------------------------------------------------
def _scan_chord_roots(ids: list[int]) -> list[int]:
"""Return ROOT token IDs for every complete ROOT-QUAL-EXT-BASS group in *ids*."""
roots: list[int] = []
i, n = 0, len(ids)
while i < n:
tok = ids[i]
if (
_ROOT_START <= tok <= _ROOT_END
and i + 3 < n
and _QUAL_START <= ids[i + 1] <= _QUAL_END
and _EXT_START <= ids[i + 2] <= _EXT_END
and _BASS_START <= ids[i + 3] <= _BASS_END
):
roots.append(tok)
i += 4
continue
i += 1
return roots
def _init_bigram_state(
ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
"""Build bigram counts and last-root from token ids already in the buffer.
Used to seed repetition tracking from the prefix / metadata tokens.
Returns:
(last_chord_root_id, bigram_counts)
last_chord_root_id is None when no complete chord has been seen yet.
"""
roots = _scan_chord_roots(ids)
counts: Counter[tuple[int, int]] = Counter()
for j in range(1, len(roots)):
counts[(roots[j - 1], roots[j])] += 1
return (roots[-1] if roots else None), counts
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
"""Additive logit bias enforcing token grammar after *last_id*.
@@ -192,6 +235,7 @@ def generate_period(
max_tokens: int = 300,
n_bars: Optional[int] = None,
seed: Optional[int] = None,
repetition_penalty: float = 0.0,
) -> ChordPeriod:
"""Generate one harmonic period autoregressively.
@@ -214,6 +258,12 @@ def generate_period(
n_bars: Stop after this many complete bars in the output.
Counts bars from the prefix too. None = let the model decide.
seed: RNG seed for reproducibility.
repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
FREE position. Specifically, for each candidate root R,
subtracts penalty * count(last_root → R) from logits,
where count is the number of times that bigram has appeared
in the generated body so far. 0.0 = disabled (default).
Suggested range: 0.51.5.
Returns:
ChordPeriod in *key*.
@@ -261,6 +311,11 @@ def generate_period(
last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
# Repetition-penalty state: seed from prefix/metadata tokens already in ids.
last_chord_root: int | None
bigram_counts: Counter[tuple[int, int]]
last_chord_root, bigram_counts = _init_bigram_state(ids)
with torch.no_grad():
for _ in range(max_tokens):
if len(ids) >= context_limit:
@@ -270,11 +325,33 @@ def generate_period(
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early
# Bigram repetition penalty — applied at FREE positions only.
# Penalises root transitions that have already occurred in this period.
if (
repetition_penalty > 0.0
and last_chord_root is not None
and not _is_mid_chord(last_id)
):
for root_id in range(_ROOT_START, _ROOT_END + 1):
count = bigram_counts.get((last_chord_root, root_id), 0)
if count:
# Cap total reduction at 3.0 logits so NC/HOLD don't
# flood the distribution when all roots are heavily penalised.
bias[root_id] -= min(repetition_penalty * count, 3.0)
logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id)
last_id = token_id
# When a complete chord group is closed, update bigram state.
if _BASS_START <= token_id <= _BASS_END:
new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS
if last_chord_root is not None:
bigram_counts[(last_chord_root, new_root)] += 1
last_chord_root = new_root
# Advance position counter when a body position is completed
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
pos_in_bar = (pos_in_bar + 1) % positions_per_bar