feat: add bigram repetition penalty to generate_period

Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events.
At each FREE position, subtracts penalty * count(prev→root) from ROOT logits,
capped at 3.0 to prevent NC/HOLD flooding at extreme values.

Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0
(aggressive). Default 0.0 keeps backward compatibility.

Added --repetition-penalty flag to scripts/generate.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:19:24 +03:00
parent 1a63b8e4d8
commit f00a6c1b3a
2 changed files with 98 additions and 15 deletions
+92 -15
View File
@@ -3,13 +3,14 @@
Public API:
generate_period(model, mode, time, subdivision, style, function, key,
prefix=None, temperature=1.0, top_p=0.9,
max_tokens=300, seed=None) -> ChordPeriod
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
"""
from __future__ import annotations
import logging
import random
from collections import Counter
from dataclasses import replace
from typing import Optional
@@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
# ---------------------------------------------------------------------------
# Repetition-penalty helpers
# ---------------------------------------------------------------------------
def _scan_chord_roots(ids: list[int]) -> list[int]:
"""Return ROOT token IDs for every complete ROOT-QUAL-EXT-BASS group in *ids*."""
roots: list[int] = []
i, n = 0, len(ids)
while i < n:
tok = ids[i]
if (
_ROOT_START <= tok <= _ROOT_END
and i + 3 < n
and _QUAL_START <= ids[i + 1] <= _QUAL_END
and _EXT_START <= ids[i + 2] <= _EXT_END
and _BASS_START <= ids[i + 3] <= _BASS_END
):
roots.append(tok)
i += 4
continue
i += 1
return roots
def _init_bigram_state(
ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
"""Build bigram counts and last-root from token ids already in the buffer.
Used to seed repetition tracking from the prefix / metadata tokens.
Returns:
(last_chord_root_id, bigram_counts)
last_chord_root_id is None when no complete chord has been seen yet.
"""
roots = _scan_chord_roots(ids)
counts: Counter[tuple[int, int]] = Counter()
for j in range(1, len(roots)):
counts[(roots[j - 1], roots[j])] += 1
return (roots[-1] if roots else None), counts
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
"""Additive logit bias enforcing token grammar after *last_id*.
@@ -192,6 +235,7 @@ def generate_period(
max_tokens: int = 300,
n_bars: Optional[int] = None,
seed: Optional[int] = None,
repetition_penalty: float = 0.0,
) -> ChordPeriod:
"""Generate one harmonic period autoregressively.
@@ -200,20 +244,26 @@ def generate_period(
The result is transposed to *key* before being returned.
Args:
model: Loaded ChordTransformer in eval mode.
mode: 'major' or 'minor'.
time: Time signature string, e.g. '4/4'.
subdivision: Positions per beat unit (4 or 8).
style: Style tag, e.g. 'H1K0'.
function: Section label, e.g. 'chorus'.
key: Target output key, e.g. 'F#_major' or 'B_minor'.
prefix: Chord symbols (in *key*) prepended as body context.
temperature: Sampling temperature (> 1 = more random).
top_p: Nucleus cutoff probability (0 < top_p <= 1).
max_tokens: Hard cap on generated tokens.
n_bars: Stop after this many complete bars in the output.
Counts bars from the prefix too. None = let the model decide.
seed: RNG seed for reproducibility.
model: Loaded ChordTransformer in eval mode.
mode: 'major' or 'minor'.
time: Time signature string, e.g. '4/4'.
subdivision: Positions per beat unit (4 or 8).
style: Style tag, e.g. 'H1K0'.
function: Section label, e.g. 'chorus'.
key: Target output key, e.g. 'F#_major' or 'B_minor'.
prefix: Chord symbols (in *key*) prepended as body context.
temperature: Sampling temperature (> 1 = more random).
top_p: Nucleus cutoff probability (0 < top_p <= 1).
max_tokens: Hard cap on generated tokens.
n_bars: Stop after this many complete bars in the output.
Counts bars from the prefix too. None = let the model decide.
seed: RNG seed for reproducibility.
repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
FREE position. Specifically, for each candidate root R,
subtracts penalty * count(last_root → R) from logits,
where count is the number of times that bigram has appeared
in the generated body so far. 0.0 = disabled (default).
Suggested range: 0.51.5.
Returns:
ChordPeriod in *key*.
@@ -261,6 +311,11 @@ def generate_period(
last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
# Repetition-penalty state: seed from prefix/metadata tokens already in ids.
last_chord_root: int | None
bigram_counts: Counter[tuple[int, int]]
last_chord_root, bigram_counts = _init_bigram_state(ids)
with torch.no_grad():
for _ in range(max_tokens):
if len(ids) >= context_limit:
@@ -270,11 +325,33 @@ def generate_period(
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
if n_bars is not None and bars_completed < n_bars:
bias[_EOS] = float("-inf") # don't let the model stop early
# Bigram repetition penalty — applied at FREE positions only.
# Penalises root transitions that have already occurred in this period.
if (
repetition_penalty > 0.0
and last_chord_root is not None
and not _is_mid_chord(last_id)
):
for root_id in range(_ROOT_START, _ROOT_END + 1):
count = bigram_counts.get((last_chord_root, root_id), 0)
if count:
# Cap total reduction at 3.0 logits so NC/HOLD don't
# flood the distribution when all roots are heavily penalised.
bias[root_id] -= min(repetition_penalty * count, 3.0)
logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id)
last_id = token_id
# When a complete chord group is closed, update bigram state.
if _BASS_START <= token_id <= _BASS_END:
new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS
if last_chord_root is not None:
bigram_counts[(last_chord_root, new_root)] += 1
last_chord_root = new_root
# Advance position counter when a body position is completed
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
pos_in_bar = (pos_in_bar + 1) % positions_per_bar