diff --git a/scripts/generate.py b/scripts/generate.py index f6dfb23..239c9b7 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -91,6 +91,11 @@ def main() -> None: help="Hard cap on generated tokens (default: 300).") ap.add_argument("--bars", type=int, default=None, help="Stop after this many complete bars (default: let the model decide).") + ap.add_argument("--repetition-penalty", type=float, default=0.0, + dest="repetition_penalty", + help="Bigram repetition penalty: subtracts penalty * count(prev→root) " + "from ROOT logits at each chord-change position. " + "0.0 = disabled (default). Suggested: 0.5–1.5.") ap.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.") ap.add_argument("--tempo", type=int, default=90, @@ -140,6 +145,7 @@ def main() -> None: max_tokens=args.max_tokens, n_bars=args.bars, seed=args.seed, + repetition_penalty=args.repetition_penalty, ) # Give generated periods a readable title diff --git a/src/generate.py b/src/generate.py index 96bf984..951543c 100644 --- a/src/generate.py +++ b/src/generate.py @@ -3,13 +3,14 @@ Public API: generate_period(model, mode, time, subdivision, style, function, key, prefix=None, temperature=1.0, top_p=0.9, - max_tokens=300, seed=None) -> ChordPeriod + max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod """ from __future__ import annotations import logging import random +from collections import Counter from dataclasses import replace from typing import Optional @@ -67,6 +68,48 @@ _BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1))) _ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]}) +# --------------------------------------------------------------------------- +# Repetition-penalty helpers +# --------------------------------------------------------------------------- + +def _scan_chord_roots(ids: list[int]) -> list[int]: + """Return ROOT token IDs for every complete ROOT-QUAL-EXT-BASS group in *ids*.""" + roots: list[int] = [] + i, n = 0, len(ids) + while i < n: + tok = ids[i] + if ( + _ROOT_START <= tok <= _ROOT_END + and i + 3 < n + and _QUAL_START <= ids[i + 1] <= _QUAL_END + and _EXT_START <= ids[i + 2] <= _EXT_END + and _BASS_START <= ids[i + 3] <= _BASS_END + ): + roots.append(tok) + i += 4 + continue + i += 1 + return roots + + +def _init_bigram_state( + ids: list[int], +) -> tuple[int | None, Counter[tuple[int, int]]]: + """Build bigram counts and last-root from token ids already in the buffer. + + Used to seed repetition tracking from the prefix / metadata tokens. + + Returns: + (last_chord_root_id, bigram_counts) + last_chord_root_id is None when no complete chord has been seen yet. + """ + roots = _scan_chord_roots(ids) + counts: Counter[tuple[int, int]] = Counter() + for j in range(1, len(roots)): + counts[(roots[j - 1], roots[j])] += 1 + return (roots[-1] if roots else None), counts + + def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor: """Additive logit bias enforcing token grammar after *last_id*. @@ -192,6 +235,7 @@ def generate_period( max_tokens: int = 300, n_bars: Optional[int] = None, seed: Optional[int] = None, + repetition_penalty: float = 0.0, ) -> ChordPeriod: """Generate one harmonic period autoregressively. @@ -200,20 +244,26 @@ def generate_period( The result is transposed to *key* before being returned. Args: - model: Loaded ChordTransformer in eval mode. - mode: 'major' or 'minor'. - time: Time signature string, e.g. '4/4'. - subdivision: Positions per beat unit (4 or 8). - style: Style tag, e.g. 'H1K0'. - function: Section label, e.g. 'chorus'. - key: Target output key, e.g. 'F#_major' or 'B_minor'. - prefix: Chord symbols (in *key*) prepended as body context. - temperature: Sampling temperature (> 1 = more random). - top_p: Nucleus cutoff probability (0 < top_p <= 1). - max_tokens: Hard cap on generated tokens. - n_bars: Stop after this many complete bars in the output. - Counts bars from the prefix too. None = let the model decide. - seed: RNG seed for reproducibility. + model: Loaded ChordTransformer in eval mode. + mode: 'major' or 'minor'. + time: Time signature string, e.g. '4/4'. + subdivision: Positions per beat unit (4 or 8). + style: Style tag, e.g. 'H1K0'. + function: Section label, e.g. 'chorus'. + key: Target output key, e.g. 'F#_major' or 'B_minor'. + prefix: Chord symbols (in *key*) prepended as body context. + temperature: Sampling temperature (> 1 = more random). + top_p: Nucleus cutoff probability (0 < top_p <= 1). + max_tokens: Hard cap on generated tokens. + n_bars: Stop after this many complete bars in the output. + Counts bars from the prefix too. None = let the model decide. + seed: RNG seed for reproducibility. + repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each + FREE position. Specifically, for each candidate root R, + subtracts penalty * count(last_root → R) from logits, + where count is the number of times that bigram has appeared + in the generated body so far. 0.0 = disabled (default). + Suggested range: 0.5–1.5. Returns: ChordPeriod in *key*. @@ -261,6 +311,11 @@ def generate_period( last_id = ids[-1] context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max + # Repetition-penalty state: seed from prefix/metadata tokens already in ids. + last_chord_root: int | None + bigram_counts: Counter[tuple[int, int]] + last_chord_root, bigram_counts = _init_bigram_state(ids) + with torch.no_grad(): for _ in range(max_tokens): if len(ids) >= context_limit: @@ -270,11 +325,33 @@ def generate_period( bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar) if n_bars is not None and bars_completed < n_bars: bias[_EOS] = float("-inf") # don't let the model stop early + + # Bigram repetition penalty — applied at FREE positions only. + # Penalises root transitions that have already occurred in this period. + if ( + repetition_penalty > 0.0 + and last_chord_root is not None + and not _is_mid_chord(last_id) + ): + for root_id in range(_ROOT_START, _ROOT_END + 1): + count = bigram_counts.get((last_chord_root, root_id), 0) + if count: + # Cap total reduction at 3.0 logits so NC/HOLD don't + # flood the distribution when all roots are heavily penalised. + bias[root_id] -= min(repetition_penalty * count, 3.0) + logits = logits + bias.to(device) token_id = _sample_top_p(logits, temperature, top_p) ids.append(token_id) last_id = token_id + # When a complete chord group is closed, update bigram state. + if _BASS_START <= token_id <= _BASS_END: + new_root = ids[-4] # ROOT is 3 slots before the just-appended BASS + if last_chord_root is not None: + bigram_counts[(last_chord_root, new_root)] += 1 + last_chord_root = new_root + # Advance position counter when a body position is completed if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC): pos_in_bar = (pos_in_bar + 1) % positions_per_bar