feat: add bigram repetition penalty to generate_period
Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events. At each FREE position, subtracts penalty * count(prev→root) from ROOT logits, capped at 3.0 to prevent NC/HOLD flooding at extreme values. Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0 (aggressive). Default 0.0 keeps backward compatibility. Added --repetition-penalty flag to scripts/generate.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+92
-15
@@ -3,13 +3,14 @@
|
||||
Public API:
|
||||
generate_period(model, mode, time, subdivision, style, function, key,
|
||||
prefix=None, temperature=1.0, top_p=0.9,
|
||||
max_tokens=300, seed=None) -> ChordPeriod
|
||||
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
from collections import Counter
|
||||
from dataclasses import replace
|
||||
from typing import Optional
|
||||
|
||||
@@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
|
||||
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Repetition-penalty helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _scan_chord_roots(ids: list[int]) -> list[int]:
|
||||
"""Return ROOT token IDs for every complete ROOT-QUAL-EXT-BASS group in *ids*."""
|
||||
roots: list[int] = []
|
||||
i, n = 0, len(ids)
|
||||
while i < n:
|
||||
tok = ids[i]
|
||||
if (
|
||||
_ROOT_START <= tok <= _ROOT_END
|
||||
and i + 3 < n
|
||||
and _QUAL_START <= ids[i + 1] <= _QUAL_END
|
||||
and _EXT_START <= ids[i + 2] <= _EXT_END
|
||||
and _BASS_START <= ids[i + 3] <= _BASS_END
|
||||
):
|
||||
roots.append(tok)
|
||||
i += 4
|
||||
continue
|
||||
i += 1
|
||||
return roots
|
||||
|
||||
|
||||
def _init_bigram_state(
|
||||
ids: list[int],
|
||||
) -> tuple[int | None, Counter[tuple[int, int]]]:
|
||||
"""Build bigram counts and last-root from token ids already in the buffer.
|
||||
|
||||
Used to seed repetition tracking from the prefix / metadata tokens.
|
||||
|
||||
Returns:
|
||||
(last_chord_root_id, bigram_counts)
|
||||
last_chord_root_id is None when no complete chord has been seen yet.
|
||||
"""
|
||||
roots = _scan_chord_roots(ids)
|
||||
counts: Counter[tuple[int, int]] = Counter()
|
||||
for j in range(1, len(roots)):
|
||||
counts[(roots[j - 1], roots[j])] += 1
|
||||
return (roots[-1] if roots else None), counts
|
||||
|
||||
|
||||
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
|
||||
"""Additive logit bias enforcing token grammar after *last_id*.
|
||||
|
||||
@@ -192,6 +235,7 @@ def generate_period(
|
||||
max_tokens: int = 300,
|
||||
n_bars: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
repetition_penalty: float = 0.0,
|
||||
) -> ChordPeriod:
|
||||
"""Generate one harmonic period autoregressively.
|
||||
|
||||
@@ -200,20 +244,26 @@ def generate_period(
|
||||
The result is transposed to *key* before being returned.
|
||||
|
||||
Args:
|
||||
model: Loaded ChordTransformer in eval mode.
|
||||
mode: 'major' or 'minor'.
|
||||
time: Time signature string, e.g. '4/4'.
|
||||
subdivision: Positions per beat unit (4 or 8).
|
||||
style: Style tag, e.g. 'H1K0'.
|
||||
function: Section label, e.g. 'chorus'.
|
||||
key: Target output key, e.g. 'F#_major' or 'B_minor'.
|
||||
prefix: Chord symbols (in *key*) prepended as body context.
|
||||
temperature: Sampling temperature (> 1 = more random).
|
||||
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||
max_tokens: Hard cap on generated tokens.
|
||||
n_bars: Stop after this many complete bars in the output.
|
||||
Counts bars from the prefix too. None = let the model decide.
|
||||
seed: RNG seed for reproducibility.
|
||||
model: Loaded ChordTransformer in eval mode.
|
||||
mode: 'major' or 'minor'.
|
||||
time: Time signature string, e.g. '4/4'.
|
||||
subdivision: Positions per beat unit (4 or 8).
|
||||
style: Style tag, e.g. 'H1K0'.
|
||||
function: Section label, e.g. 'chorus'.
|
||||
key: Target output key, e.g. 'F#_major' or 'B_minor'.
|
||||
prefix: Chord symbols (in *key*) prepended as body context.
|
||||
temperature: Sampling temperature (> 1 = more random).
|
||||
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||
max_tokens: Hard cap on generated tokens.
|
||||
n_bars: Stop after this many complete bars in the output.
|
||||
Counts bars from the prefix too. None = let the model decide.
|
||||
seed: RNG seed for reproducibility.
|
||||
repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
|
||||
FREE position. Specifically, for each candidate root R,
|
||||
subtracts penalty * count(last_root → R) from logits,
|
||||
where count is the number of times that bigram has appeared
|
||||
in the generated body so far. 0.0 = disabled (default).
|
||||
Suggested range: 0.5–1.5.
|
||||
|
||||
Returns:
|
||||
ChordPeriod in *key*.
|
||||
@@ -261,6 +311,11 @@ def generate_period(
|
||||
last_id = ids[-1]
|
||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||
|
||||
# Repetition-penalty state: seed from prefix/metadata tokens already in ids.
|
||||
last_chord_root: int | None
|
||||
bigram_counts: Counter[tuple[int, int]]
|
||||
last_chord_root, bigram_counts = _init_bigram_state(ids)
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(max_tokens):
|
||||
if len(ids) >= context_limit:
|
||||
@@ -270,11 +325,33 @@ def generate_period(
|
||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||
if n_bars is not None and bars_completed < n_bars:
|
||||
bias[_EOS] = float("-inf") # don't let the model stop early
|
||||
|
||||
# Bigram repetition penalty — applied at FREE positions only.
|
||||
# Penalises root transitions that have already occurred in this period.
|
||||
if (
|
||||
repetition_penalty > 0.0
|
||||
and last_chord_root is not None
|
||||
and not _is_mid_chord(last_id)
|
||||
):
|
||||
for root_id in range(_ROOT_START, _ROOT_END + 1):
|
||||
count = bigram_counts.get((last_chord_root, root_id), 0)
|
||||
if count:
|
||||
# Cap total reduction at 3.0 logits so NC/HOLD don't
|
||||
# flood the distribution when all roots are heavily penalised.
|
||||
bias[root_id] -= min(repetition_penalty * count, 3.0)
|
||||
|
||||
logits = logits + bias.to(device)
|
||||
token_id = _sample_top_p(logits, temperature, top_p)
|
||||
ids.append(token_id)
|
||||
last_id = token_id
|
||||
|
||||
# When a complete chord group is closed, update bigram state.
|
||||
if _BASS_START <= token_id <= _BASS_END:
|
||||
new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS
|
||||
if last_chord_root is not None:
|
||||
bigram_counts[(last_chord_root, new_root)] += 1
|
||||
last_chord_root = new_root
|
||||
|
||||
# Advance position counter when a body position is completed
|
||||
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||
|
||||
Reference in New Issue
Block a user