feat: add bigram repetition penalty to generate_period
Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events. At each FREE position, subtracts penalty * count(prev→root) from each candidate ROOT logit; the reduction for any single root is capped at 3.0 so that NC/HOLD do not flood the distribution at extreme values. Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0 (aggressive). Default 0.0 disables the penalty and keeps backward compatibility. Added --repetition-penalty flag to scripts/generate.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -91,6 +91,11 @@ def main() -> None:
|
|||||||
help="Hard cap on generated tokens (default: 300).")
|
help="Hard cap on generated tokens (default: 300).")
|
||||||
ap.add_argument("--bars", type=int, default=None,
|
ap.add_argument("--bars", type=int, default=None,
|
||||||
help="Stop after this many complete bars (default: let the model decide).")
|
help="Stop after this many complete bars (default: let the model decide).")
|
||||||
|
ap.add_argument("--repetition-penalty", type=float, default=0.0,
|
||||||
|
dest="repetition_penalty",
|
||||||
|
help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
|
||||||
|
"from ROOT logits at each chord-change position. "
|
||||||
|
"0.0 = disabled (default). Suggested: 0.5–1.5.")
|
||||||
ap.add_argument("--seed", type=int, default=None,
|
ap.add_argument("--seed", type=int, default=None,
|
||||||
help="Random seed for reproducibility.")
|
help="Random seed for reproducibility.")
|
||||||
ap.add_argument("--tempo", type=int, default=90,
|
ap.add_argument("--tempo", type=int, default=90,
|
||||||
@@ -140,6 +145,7 @@ def main() -> None:
|
|||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
n_bars=args.bars,
|
n_bars=args.bars,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
|
repetition_penalty=args.repetition_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Give generated periods a readable title
|
# Give generated periods a readable title
|
||||||
|
|||||||
+78
-1
@@ -3,13 +3,14 @@
|
|||||||
Public API:
|
Public API:
|
||||||
generate_period(model, mode, time, subdivision, style, function, key,
|
generate_period(model, mode, time, subdivision, style, function, key,
|
||||||
prefix=None, temperature=1.0, top_p=0.9,
|
prefix=None, temperature=1.0, top_p=0.9,
|
||||||
max_tokens=300, seed=None) -> ChordPeriod
|
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from collections import Counter
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
|
|||||||
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
|
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Repetition-penalty helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _scan_chord_roots(ids: list[int]) -> list[int]:
    """Collect the ROOT token ID of each full chord group found in *ids*.

    A chord group is four consecutive tokens whose IDs fall, in order, inside
    the ROOT, QUAL, EXT and BASS ranges.  The scan runs left to right: a
    matched group is consumed as a unit, any other token is skipped one at a
    time.  A group cut short by the end of *ids* is not counted.

    Args:
        ids: Token IDs to scan.

    Returns:
        The ROOT token IDs of the matched groups, in order of appearance.
    """
    found: list[int] = []
    total = len(ids)
    pos = 0
    while pos < total:
        group = ids[pos:pos + 4]
        is_chord = (
            len(group) == 4
            and _ROOT_START <= group[0] <= _ROOT_END
            and _QUAL_START <= group[1] <= _QUAL_END
            and _EXT_START <= group[2] <= _EXT_END
            and _BASS_START <= group[3] <= _BASS_END
        )
        if is_chord:
            found.append(group[0])
            pos += 4  # jump past the whole ROOT-QUAL-EXT-BASS group
        else:
            pos += 1
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _init_bigram_state(
    ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
    """Derive the initial repetition-tracking state from tokens already buffered.

    Used to seed repetition tracking from the prefix / metadata tokens.

    Args:
        ids: Token IDs currently in the buffer.

    Returns:
        A pair ``(last_chord_root_id, bigram_counts)``.
        ``last_chord_root_id`` is the ROOT ID of the final complete chord in
        *ids*, or None when no complete chord has been seen yet.
        ``bigram_counts`` maps each ``(previous_root, next_root)`` pair of
        consecutive complete chords to the number of times it occurs.
    """
    chord_roots = _scan_chord_roots(ids)
    # Pair every root with its successor; Counter tallies each transition.
    transitions: Counter[tuple[int, int]] = Counter(
        zip(chord_roots, chord_roots[1:])
    )
    if not chord_roots:
        return None, transitions
    return chord_roots[-1], transitions
|
||||||
|
|
||||||
|
|
||||||
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
|
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
|
||||||
"""Additive logit bias enforcing token grammar after *last_id*.
|
"""Additive logit bias enforcing token grammar after *last_id*.
|
||||||
|
|
||||||
@@ -192,6 +235,7 @@ def generate_period(
|
|||||||
max_tokens: int = 300,
|
max_tokens: int = 300,
|
||||||
n_bars: Optional[int] = None,
|
n_bars: Optional[int] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
repetition_penalty: float = 0.0,
|
||||||
) -> ChordPeriod:
|
) -> ChordPeriod:
|
||||||
"""Generate one harmonic period autoregressively.
|
"""Generate one harmonic period autoregressively.
|
||||||
|
|
||||||
@@ -214,6 +258,12 @@ def generate_period(
|
|||||||
n_bars: Stop after this many complete bars in the output.
|
n_bars: Stop after this many complete bars in the output.
|
||||||
Counts bars from the prefix too. None = let the model decide.
|
Counts bars from the prefix too. None = let the model decide.
|
||||||
seed: RNG seed for reproducibility.
|
seed: RNG seed for reproducibility.
|
||||||
|
repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
|
||||||
|
FREE position. Specifically, for each candidate root R,
|
||||||
|
subtracts penalty * count(last_root → R) from logits,
|
||||||
|
where count is the number of times that bigram has appeared
|
||||||
|
in the prefix and generated body so far; the reduction applied to any single root is capped at 3.0. 0.0 = disabled (default).
|
||||||
|
Suggested range: 0.5–1.5.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChordPeriod in *key*.
|
ChordPeriod in *key*.
|
||||||
@@ -261,6 +311,11 @@ def generate_period(
|
|||||||
last_id = ids[-1]
|
last_id = ids[-1]
|
||||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||||
|
|
||||||
|
# Repetition-penalty state: seed from prefix/metadata tokens already in ids.
|
||||||
|
last_chord_root: int | None
|
||||||
|
bigram_counts: Counter[tuple[int, int]]
|
||||||
|
last_chord_root, bigram_counts = _init_bigram_state(ids)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(max_tokens):
|
for _ in range(max_tokens):
|
||||||
if len(ids) >= context_limit:
|
if len(ids) >= context_limit:
|
||||||
@@ -270,11 +325,33 @@ def generate_period(
|
|||||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||||
if n_bars is not None and bars_completed < n_bars:
|
if n_bars is not None and bars_completed < n_bars:
|
||||||
bias[_EOS] = float("-inf") # don't let the model stop early
|
bias[_EOS] = float("-inf") # don't let the model stop early
|
||||||
|
|
||||||
|
# Bigram repetition penalty — applied at FREE positions only.
|
||||||
|
# Penalises root transitions that have already occurred in this period.
|
||||||
|
if (
|
||||||
|
repetition_penalty > 0.0
|
||||||
|
and last_chord_root is not None
|
||||||
|
and not _is_mid_chord(last_id)
|
||||||
|
):
|
||||||
|
for root_id in range(_ROOT_START, _ROOT_END + 1):
|
||||||
|
count = bigram_counts.get((last_chord_root, root_id), 0)
|
||||||
|
if count:
|
||||||
|
# Cap total reduction at 3.0 logits so NC/HOLD don't
|
||||||
|
# flood the distribution when all roots are heavily penalised.
|
||||||
|
bias[root_id] -= min(repetition_penalty * count, 3.0)
|
||||||
|
|
||||||
logits = logits + bias.to(device)
|
logits = logits + bias.to(device)
|
||||||
token_id = _sample_top_p(logits, temperature, top_p)
|
token_id = _sample_top_p(logits, temperature, top_p)
|
||||||
ids.append(token_id)
|
ids.append(token_id)
|
||||||
last_id = token_id
|
last_id = token_id
|
||||||
|
|
||||||
|
# When a complete chord group is closed, update bigram state.
|
||||||
|
if _BASS_START <= token_id <= _BASS_END:
|
||||||
|
new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS
|
||||||
|
if last_chord_root is not None:
|
||||||
|
bigram_counts[(last_chord_root, new_root)] += 1
|
||||||
|
last_chord_root = new_root
|
||||||
|
|
||||||
# Advance position counter when a body position is completed
|
# Advance position counter when a body position is completed
|
||||||
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||||
|
|||||||
Reference in New Issue
Block a user