dc037b0895
_grammar_bias returned a shared module-level singleton that the loop mutated in place (EOS block + repetition penalty). The penalty thus accumulated across positions within a call and persisted across calls, collapsing output to HOLD/NC until process restart. Clone the bias each step so edits stay local. Add regression tests guarding the invariant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
386 lines
15 KiB
Python
386 lines
15 KiB
Python
"""Autoregressive generation of harmonic periods.
|
||
|
||
Public API:
|
||
generate_period(model, mode, time, subdivision, style, function, key,
|
||
prefix=None, temperature=1.0, top_p=0.9,
|
||
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import random
|
||
from collections import Counter
|
||
from dataclasses import replace
|
||
from typing import Optional
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
from src.chord_parser import parse_chord_symbol
|
||
from src.model import ChordTransformer
|
||
from src.tokenizer import (
|
||
VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
|
||
_qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
|
||
_expected_positions,
|
||
)
|
||
|
||
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
# ---------------------------------------------------------------------------
# Every *_START / *_END pair below is an inclusive bound; the rest of this
# module tests family membership with ``START <= tok <= END``.
# NOTE(review): that test presumes each family occupies a contiguous ID
# range in the tokenizer vocabulary -- confirm against src/tokenizer.py.

_VOCAB_SIZE = len(VOCAB)

# Stand-alone tokens: end-of-sequence, hold, and no-chord.
_EOS, _HOLD, _NC = (TOKEN_TO_ID[name] for name in ("<EOS>", "HOLD", "NC"))

# Chord-group families, listed in ROOT -> QUAL -> EXT -> BASS order.
_ROOT_START, _ROOT_END = TOKEN_TO_ID["ROOT_C"], TOKEN_TO_ID["ROOT_B"]
_QUAL_START, _QUAL_END = TOKEN_TO_ID["QUAL_maj"], TOKEN_TO_ID["QUAL_m_add9"]
_EXT_START, _EXT_END = TOKEN_TO_ID["EXT_none"], TOKEN_TO_ID["EXT_b13"]
_BASS_START, _BASS_END = TOKEN_TO_ID["BASS_root"], TOKEN_TO_ID["BASS_B"]
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Grammar masks
|
||
# Additive logit bias: 0.0 = allowed, -inf = blocked.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _make_bias(allowed: list[int]) -> torch.Tensor:
    """Build an additive logit mask over the whole vocabulary.

    Args:
        allowed: Token IDs that may be sampled.

    Returns:
        A 1-D tensor of length ``_VOCAB_SIZE`` holding 0.0 at every index in
        *allowed* and ``-inf`` everywhere else.
    """
    mask = torch.full((_VOCAB_SIZE,), float("-inf"))
    if allowed:
        # One indexed assignment opens all permitted slots at once.
        mask[torch.tensor(allowed, dtype=torch.long)] = 0.0
    return mask
|
||
|
||
|
||
# Tokens that may open a new bar position: any chord root, HOLD, or NC.
_free_ids = [*range(_ROOT_START, _ROOT_END + 1), _HOLD, _NC]

_BIAS_FREE = _make_bias(_free_ids)               # mid-bar: no EOS
_BIAS_FREE_EOS = _make_bias([*_free_ids, _EOS])  # bar boundary: EOS allowed
_BIAS_QUAL = _make_bias([*range(_QUAL_START, _QUAL_END + 1)])
_BIAS_EXT = _make_bias([*range(_EXT_START, _EXT_END + 1)])
_BIAS_EXT_NONE = _make_bias([TOKEN_TO_ID["EXT_none"]])  # only for add9 qualities
_BIAS_BASS = _make_bias([*range(_BASS_START, _BASS_END + 1)])

# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
_ADD9_QUAL_IDS = frozenset(
    TOKEN_TO_ID[name] for name in ("QUAL_add9", "QUAL_m_add9")
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Repetition-penalty helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _scan_chord_roots(ids: list[int]) -> list[int]:
    """Collect the ROOT token of each complete chord group found in *ids*.

    A chord group is four consecutive tokens whose IDs fall, in order, in the
    ROOT, QUAL, EXT and BASS ranges. The scan moves left to right: a matched
    group is consumed whole (advance by four), anything else is skipped one
    token at a time. A ROOT with fewer than three tokens after it is ignored.

    Returns:
        The ROOT token IDs in order of appearance.
    """
    found: list[int] = []
    cursor = 0
    total = len(ids)
    while cursor < total:
        window = ids[cursor:cursor + 4]
        is_chord = (
            len(window) == 4
            and _ROOT_START <= window[0] <= _ROOT_END
            and _QUAL_START <= window[1] <= _QUAL_END
            and _EXT_START <= window[2] <= _EXT_END
            and _BASS_START <= window[3] <= _BASS_END
        )
        if is_chord:
            found.append(window[0])
            cursor += 4
        else:
            cursor += 1
    return found
|
||
|
||
|
||
def _init_bigram_state(
    ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
    """Seed repetition tracking from tokens already present in the buffer.

    Every complete chord group in *ids* contributes its ROOT token; each pair
    of consecutive roots is counted as one (previous_root, next_root) bigram.

    Returns:
        (last_chord_root_id, bigram_counts)
        last_chord_root_id is None when *ids* holds no complete chord.
    """
    roots = _scan_chord_roots(ids)
    # Pair each root with its successor; zip stops at the shorter slice.
    transitions: Counter[tuple[int, int]] = Counter(zip(roots, roots[1:]))
    most_recent = roots[-1] if roots else None
    return most_recent, transitions
|
||
|
||
|
||
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
    """Pick the additive logit mask for the token that follows *last_id*.

    Grammar: ROOT -> QUAL -> EXT -> BASS. After an add9 quality only EXT_none
    is permitted. Any other *last_id* (BASS, HOLD, NC, or the last metadata
    token) is a FREE position where a root, HOLD or NC may follow; EOS is
    allowed there only at a bar boundary (pos_in_bar == 0).

    The returned tensor is one of the shared module-level ``_BIAS_*`` objects,
    not a copy: callers must ``.clone()`` it before editing it in place.

    Args:
        last_id: ID of the most recent token in the sequence.
        pos_in_bar: Zero-based position inside the current bar.
        positions_per_bar: Not read by this function.
    """
    follows_root = _ROOT_START <= last_id <= _ROOT_END
    follows_qual = _QUAL_START <= last_id <= _QUAL_END
    follows_ext = _EXT_START <= last_id <= _EXT_END

    if follows_root:
        return _BIAS_QUAL
    if follows_qual:
        if last_id in _ADD9_QUAL_IDS:
            return _BIAS_EXT_NONE
        return _BIAS_EXT
    if follows_ext:
        return _BIAS_BASS

    # FREE position: EOS becomes valid only once the bar is complete.
    at_bar_boundary = pos_in_bar == 0
    return _BIAS_FREE_EOS if at_bar_boundary else _BIAS_FREE
|
||
|
||
|
||
def _is_mid_chord(last_id: int) -> bool:
    """Report whether *last_id* leaves a chord group unfinished.

    True when *last_id* is a ROOT, QUAL or EXT token, i.e. the next required
    token is QUAL, EXT or BASS respectively.
    """
    open_group_ranges = (
        (_ROOT_START, _ROOT_END),
        (_QUAL_START, _QUAL_END),
        (_EXT_START, _EXT_END),
    )
    return any(low <= last_id <= high for low, high in open_group_ranges)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Sampling
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """Sample one token index via nucleus (top-p) sampling.

    Tokens are ranked by probability; a token stays in the nucleus while the
    total probability of the tokens ranked above it is below *top_p*. The
    surviving probabilities are renormalised and one token is drawn.

    Args:
        logits: 1-D tensor of unnormalised scores; ``-inf`` entries get zero
            probability. The tensor is not modified.
        temperature: Softmax temperature. Values <= 1e-8 are clamped to 1e-8,
            which makes sampling effectively greedy.
        top_p: Nucleus cutoff. Values <= 0 select the single most probable
            token instead of failing.

    Returns:
        Index into *logits* of the sampled token.
    """
    scaled = logits / max(temperature, 1e-8)
    probs = F.softmax(scaled, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)

    # Probability mass held by the tokens ranked strictly above each token.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    cut = mass_before >= top_p
    # Always keep the most probable token. Without this, top_p <= 0 removes
    # every candidate, the renormalisation divides by zero, and
    # torch.multinomial is handed NaNs. For top_p > 0 this is a no-op because
    # mass_before[0] is always 0.
    cut[0] = False

    sorted_probs[cut] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    chosen = int(torch.multinomial(sorted_probs, 1).item())
    return int(sorted_idx[chosen].item())
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Transposition helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _transpose_note(note: str, shift: int) -> str:
    """Return the note name *shift* semitones away from *note*, wrapping at 12."""
    pitch_class = (_NOTE_INDEX[note] + shift) % 12
    return _CHROMATIC[pitch_class]
|
||
|
||
|
||
def _to_canonical_shift(mode: str, tonic: str) -> int:
    """Semitone shift that moves *tonic* onto the canonical tonic.

    The canonical pitch index is 0 (C) when *mode* is exactly "major" and
    9 (A) for any other value. The result lies in the range 0..11.
    """
    canonical_index = 9
    if mode == "major":
        canonical_index = 0
    return (canonical_index - _NOTE_INDEX[tonic]) % 12
|
||
|
||
|
||
def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
    """Move a canonical (C/Am) ChordPeriod into *target_key*.

    *target_key* is split on "_": the first piece names the tonic and the
    last piece the mode. Any mode other than the literal "major" is treated
    as minor (canonical tonic A, index 9).

    Returns:
        A copy of *period* with ``key`` set to *target_key*. The bars are
        re-spelled only when the required shift is non-zero.
    """
    pieces = target_key.split("_")
    tonic = _parse_note_from_key(pieces[0])
    home_index = 9 if pieces[-1] != "major" else 0
    semitones = (_NOTE_INDEX[tonic] - home_index) % 12

    # Already in the canonical key: only the key label changes.
    if semitones == 0:
        return replace(period, key=target_key)

    # NOTE(review): the label and bar number are presumably used by
    # _transpose_symbol for error reporting -- confirm in src/tokenizer.py.
    source_label = "<generate>"
    shifted_bars = []
    for bar_no, bar in enumerate(period.bars, start=1):
        shifted_bars.append(
            [_transpose_symbol(sym, semitones, source_label, bar_no) for sym in bar]
        )
    return replace(period, key=target_key, bars=shifted_bars)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Prefix encoding
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
    """Turn user-supplied prefix symbols into canonical body token IDs.

    *shift* is the semitone offset from the user's key to canonical (C/Am).
    '.' becomes one HOLD token and 'NC' one NC token. Every other symbol is
    parsed as a chord and becomes four tokens, ROOT QUAL EXT BASS, with root
    and any explicit bass note transposed by *shift*.

    Returns:
        (ids, n_positions), where n_positions counts one bar position per
        input element regardless of how many tokens it produced.
    """
    single_token = {".": _HOLD, "NC": _NC}
    encoded: list[int] = []
    n_positions = 0
    for n_positions, symbol in enumerate(chord_symbols, start=1):
        if symbol in single_token:
            encoded.append(single_token[symbol])
            continue
        chord = parse_chord_symbol(symbol)
        root_name = _transpose_note(chord.root, shift)
        # "root" marks a chord with no separate bass note; it is not a pitch
        # name, so it must not be transposed.
        if chord.bass == "root":
            bass_name = "root"
        else:
            bass_name = _transpose_note(chord.bass, shift)
        encoded.extend(
            [
                TOKEN_TO_ID[f"ROOT_{root_name}"],
                TOKEN_TO_ID[_qual_token(chord.quality)],
                TOKEN_TO_ID[f"EXT_{chord.extension}"],
                TOKEN_TO_ID[f"BASS_{bass_name}"],
            ]
        )
    return encoded, n_positions
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Public API
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def generate_period(
    model: ChordTransformer,
    mode: str,
    time: str,
    subdivision: int,
    style: str,
    function: str,
    key: str,
    prefix: Optional[list[str]] = None,
    temperature: float = 1.0,
    top_p: float = 0.9,
    max_tokens: int = 300,
    n_bars: Optional[int] = None,
    seed: Optional[int] = None,
    repetition_penalty: float = 0.0,
) -> ChordPeriod:
    """Generate one harmonic period autoregressively.

    The model generates in canonical key (C major / A minor); bar boundaries
    are determined by position counting in the detokenizer (no BAR tokens).
    The result is transposed to *key* before being returned.

    Each step feeds the whole token buffer to the model, adds a grammar mask
    (plus the optional repetition penalty) to the last-position logits, and
    draws the next token with nucleus sampling. If the loop ends without an
    EOS, any unfinished chord group is dropped, the current bar is padded
    with HOLD tokens, and EOS is appended.

    Side effects: when *seed* is given, the global ``torch`` and ``random``
    generators are reseeded; ``model.eval()`` is always called.

    Args:
        model:       Loaded ChordTransformer in eval mode.
        mode:        'major' or 'minor'.
        time:        Time signature string, e.g. '4/4'.
        subdivision: Positions per beat unit (4 or 8).
        style:       Style tag, e.g. 'H1K0'. Unknown tags fall back to
                     STYLE_other with a logged warning.
        function:    Section label, e.g. 'chorus'. Unknown labels fall back
                     to FUNC_unspecified with a logged warning.
        key:         Target output key, e.g. 'F#_major' or 'B_minor'.
        prefix:      Chord symbols (in *key*) prepended as body context.
        temperature: Sampling temperature (> 1 = more random).
        top_p:       Nucleus cutoff probability (0 < top_p <= 1).
        max_tokens:  Hard cap on generated tokens.
        n_bars:      Stop after this many complete bars in the output.
                     Counts bars from the prefix too. None = let the model decide.
                     If the prefix already spans this many bars, a warning is
                     logged and generation still proceeds.
        seed:        RNG seed for reproducibility.
        repetition_penalty: Per-occurrence penalty subtracted from ROOT logits at each
                     FREE position. Specifically, for each candidate root R,
                     subtracts penalty * count(last_root → R) from logits,
                     where count is the number of times that bigram has appeared
                     in the generated body so far. 0.0 = disabled (default).
                     Suggested range: 0.5–1.5.
                     The reduction per root is capped at 3.0 logits.

    Returns:
        ChordPeriod in *key*.

    Raises:
        KeyError: presumably, when *mode*, *time* or *subdivision* has no
            matching token -- they index TOKEN_TO_ID directly with no
            fallback (assumes TOKEN_TO_ID is a plain dict; confirm).
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    device = next(model.parameters()).device
    model.eval()

    # Offset used to move the user's prefix from *key* into the canonical key.
    key_parts = key.split("_")
    tonic = _parse_note_from_key(key_parts[0])
    shift_to_canonical = _to_canonical_shift(mode, tonic)

    # Unknown style / function tags degrade to a catch-all token, not an error.
    style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
    func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
    if style_tok == "STYLE_other" and style != "other":
        log.warning("unknown style %r — using STYLE_other", style)
    if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
        log.warning("unknown function %r — using FUNC_unspecified", function)

    # Metadata header that conditions the model; body tokens follow it.
    ids: list[int] = [
        TOKEN_TO_ID["<BOS>"],
        TOKEN_TO_ID[f"MODE_{mode}"],
        TOKEN_TO_ID[f"TIME_{time}"],
        TOKEN_TO_ID[f"SUB_{subdivision}"],
        TOKEN_TO_ID[style_tok],
        TOKEN_TO_ID[func_tok],
    ]

    positions_per_bar = _expected_positions(time, subdivision)

    # Bar bookkeeping starts after whatever the prefix already filled.
    pos_in_bar = 0
    bars_completed = 0
    if prefix:
        encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
        ids.extend(encoded_prefix)
        pos_in_bar = n_prefix_positions % positions_per_bar
        bars_completed = n_prefix_positions // positions_per_bar

    if n_bars is not None and bars_completed >= n_bars:
        log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars)

    last_id = ids[-1]
    context_limit = model.max_seq_len - 1  # leave one slot so seq_len never hits max

    # Repetition-penalty state: seed from prefix/metadata tokens already in ids.
    last_chord_root: int | None
    bigram_counts: Counter[tuple[int, int]]
    last_chord_root, bigram_counts = _init_bigram_state(ids)

    with torch.no_grad():
        for _ in range(max_tokens):
            if len(ids) >= context_limit:
                break
            inp = torch.tensor([ids], dtype=torch.long, device=device)
            logits = model(inp)[0, -1]  # [vocab_size]
            # _grammar_bias returns a shared module-level singleton; clone before
            # any in-place edit so EOS-blocking and the repetition penalty never
            # leak across positions or across calls.
            bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone()
            if n_bars is not None and bars_completed < n_bars:
                bias[_EOS] = float("-inf")  # don't let the model stop early

            # Bigram repetition penalty — applied at FREE positions only.
            # Penalises root transitions that have already occurred in this period.
            if (
                repetition_penalty > 0.0
                and last_chord_root is not None
                and not _is_mid_chord(last_id)
            ):
                for root_id in range(_ROOT_START, _ROOT_END + 1):
                    count = bigram_counts.get((last_chord_root, root_id), 0)
                    if count:
                        # Cap total reduction at 3.0 logits so NC/HOLD don't
                        # flood the distribution when all roots are heavily penalised.
                        bias[root_id] -= min(repetition_penalty * count, 3.0)

            logits = logits + bias.to(device)
            token_id = _sample_top_p(logits, temperature, top_p)
            ids.append(token_id)
            last_id = token_id

            # When a complete chord group is closed, update bigram state.
            if _BASS_START <= token_id <= _BASS_END:
                new_root = ids[-4]  # ROOT is 3 slots before the just-appended BASS
                if last_chord_root is not None:
                    bigram_counts[(last_chord_root, new_root)] += 1
                last_chord_root = new_root

            # Advance position counter when a body position is completed
            if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
                pos_in_bar = (pos_in_bar + 1) % positions_per_bar
                if pos_in_bar == 0:
                    bars_completed += 1
                    # Requested length reached: stop without sampling an EOS;
                    # the block after the loop appends it.
                    if n_bars is not None and bars_completed >= n_bars:
                        break

            if token_id == _EOS:
                break

    # Ensure sequence ends with EOS at a bar boundary.
    if ids[-1] != _EOS:
        if _is_mid_chord(last_id):
            # Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
            # Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
            while ids and _is_mid_chord(ids[-1]):
                ids.pop()
        # Pad the partial bar with HOLDs so EOS lands on a bar boundary.
        while pos_in_bar > 0:
            ids.append(_HOLD)
            pos_in_bar = (pos_in_bar + 1) % positions_per_bar
        ids.append(_EOS)

    log.debug("generated %d tokens total (prompt + body)", len(ids))

    period = detokenize_to_period(ids)
    return _transpose_to_key(period, key)
|