Files
hamori/src/generate.py
H1K0 dc037b0895 fix: clone grammar bias per step in generate_period
_grammar_bias returned a shared module-level singleton that the loop
mutated in place (EOS block + repetition penalty). The penalty thus
accumulated across positions within a call and persisted across calls,
collapsing output to HOLD/NC until process restart. Clone the bias each
step so edits stay local. Add regression tests guarding the invariant.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-04 21:14:04 +03:00

386 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Autoregressive generation of harmonic periods.
Public API:
generate_period(model, mode, time, subdivision, style, function, key,
prefix=None, temperature=1.0, top_p=0.9,
max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod
"""
from __future__ import annotations
import logging
import random
from collections import Counter
from dataclasses import replace
from typing import Optional
import torch
import torch.nn.functional as F
from src.chord_parser import parse_chord_symbol
from src.model import ChordTransformer
from src.tokenizer import (
VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
_qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
_expected_positions,
)
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
# Every *_START/*_END pair below is an inclusive id range; the range checks
# in this module rely on each token family occupying a contiguous id block.
# ---------------------------------------------------------------------------
_VOCAB_SIZE = len(VOCAB)

# Single special tokens.
_EOS, _HOLD, _NC = (TOKEN_TO_ID[name] for name in ("<EOS>", "HOLD", "NC"))

# Inclusive [start, end] id ranges for the four slots of a chord group.
_ROOT_START, _ROOT_END = TOKEN_TO_ID["ROOT_C"], TOKEN_TO_ID["ROOT_B"]
_QUAL_START, _QUAL_END = TOKEN_TO_ID["QUAL_maj"], TOKEN_TO_ID["QUAL_m_add9"]
_EXT_START, _EXT_END = TOKEN_TO_ID["EXT_none"], TOKEN_TO_ID["EXT_b13"]
_BASS_START, _BASS_END = TOKEN_TO_ID["BASS_root"], TOKEN_TO_ID["BASS_B"]
# ---------------------------------------------------------------------------
# Grammar masks
# Additive logit bias: 0.0 = allowed, -inf = blocked.
# ---------------------------------------------------------------------------
def _make_bias(allowed: list[int]) -> torch.Tensor:
    """Build a fresh ``[_VOCAB_SIZE]`` additive logit mask.

    Every id listed in *allowed* gets 0.0 and every other id gets -inf, so
    adding the mask to logits removes all non-allowed tokens from sampling.
    """
    bias = torch.full((_VOCAB_SIZE,), float("-inf"))
    if allowed:
        # Advanced indexing writes all allowed slots in one operation.
        bias[torch.tensor(allowed, dtype=torch.long)] = 0.0
    return bias
# FREE slot (start of a new bar position): any ROOT, HOLD, or NC.
_free_ids = [*range(_ROOT_START, _ROOT_END + 1), _HOLD, _NC]
_BIAS_FREE = _make_bias(_free_ids)               # mid-bar: no EOS
_BIAS_FREE_EOS = _make_bias([*_free_ids, _EOS])  # bar boundary: EOS allowed
# Slots inside a chord group.
_BIAS_QUAL = _make_bias([*range(_QUAL_START, _QUAL_END + 1)])
_BIAS_EXT = _make_bias([*range(_EXT_START, _EXT_END + 1)])
_BIAS_EXT_NONE = _make_bias([TOKEN_TO_ID["EXT_none"]])  # only for add9 qualities
_BIAS_BASS = _make_bias([*range(_BASS_START, _BASS_END + 1)])
# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
_ADD9_QUAL_IDS = frozenset(TOKEN_TO_ID[q] for q in ("QUAL_add9", "QUAL_m_add9"))
# ---------------------------------------------------------------------------
# Repetition-penalty helpers
# ---------------------------------------------------------------------------
def _scan_chord_roots(ids: list[int]) -> list[int]:
    """Collect the ROOT id of each complete ROOT-QUAL-EXT-BASS group in *ids*.

    The scan runs left to right. A full four-token group starting at the
    current position is consumed whole; otherwise the scan advances by one
    token. Incomplete or malformed groups contribute nothing.
    """
    found: list[int] = []
    pos = 0
    total = len(ids)
    while pos < total:
        group = ids[pos:pos + 4]
        is_chord = (
            len(group) == 4
            and _ROOT_START <= group[0] <= _ROOT_END
            and _QUAL_START <= group[1] <= _QUAL_END
            and _EXT_START <= group[2] <= _EXT_END
            and _BASS_START <= group[3] <= _BASS_END
        )
        if is_chord:
            found.append(group[0])
            pos += 4
        else:
            pos += 1
    return found
def _init_bigram_state(
    ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
    """Seed repetition tracking from the tokens already in the buffer.

    Used on the prefix / metadata tokens before sampling starts. Only
    complete ROOT-QUAL-EXT-BASS groups (as found by _scan_chord_roots)
    count. HOLD/NC tokens between two chords do not break their pairing.

    Returns:
        (last_chord_root_id, bigram_counts): the ROOT id of the last complete
        chord (None when there is none), and a Counter keyed by each
        (previous_root, next_root) pair of successive chords.
    """
    roots = _scan_chord_roots(ids)
    pair_counts: Counter[tuple[int, int]] = Counter(zip(roots, roots[1:]))
    last_root = roots[-1] if roots else None
    return last_root, pair_counts
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
    """Return the additive logit bias that enforces token grammar after *last_id*.

    A chord is spelled ROOT -> QUAL -> EXT -> BASS. After BASS, HOLD, NC, or
    the last metadata token, the next slot is FREE (ROOT, HOLD, or NC).
    EOS is allowed only at bar boundaries (pos_in_bar == 0).

    *positions_per_bar* is accepted but not used by this function.

    The result is one of the shared module-level ``_BIAS_*`` tensors, not a
    copy. Callers must ``.clone()`` it before any in-place edit, or the edit
    leaks into every later call.
    """
    if _ROOT_START <= last_id <= _ROOT_END:
        return _BIAS_QUAL
    if _QUAL_START <= last_id <= _QUAL_END:
        # add9 qualities already carry the ninth, so only EXT_none may follow.
        if last_id in _ADD9_QUAL_IDS:
            return _BIAS_EXT_NONE
        return _BIAS_EXT
    if _EXT_START <= last_id <= _EXT_END:
        return _BIAS_BASS
    # FREE slot.
    if pos_in_bar != 0:
        return _BIAS_FREE      # mid-bar: EOS not yet valid
    return _BIAS_FREE_EOS      # bar boundary: EOS is valid
def _is_mid_chord(last_id: int) -> bool:
    """Report whether a chord group is still open after *last_id*.

    True when *last_id* is a ROOT, QUAL, or EXT token, i.e. the next token
    the grammar requires is QUAL, EXT, or BASS respectively.
    """
    open_slots = (
        (_ROOT_START, _ROOT_END),
        (_QUAL_START, _QUAL_END),
        (_EXT_START, _EXT_END),
    )
    return any(lo <= last_id <= hi for lo, hi in open_slots)
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
"""Sample one token index via nucleus (top-p) sampling."""
logits = logits / max(temperature, 1e-8)
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
cut = (cumulative - sorted_probs) >= top_p
sorted_probs[cut] = 0.0
sorted_probs /= sorted_probs.sum()
chosen = torch.multinomial(sorted_probs, 1).item()
return int(sorted_idx[chosen].item())
# ---------------------------------------------------------------------------
# Transposition helpers
# ---------------------------------------------------------------------------
def _transpose_note(note: str, shift: int) -> str:
    """Return the _CHROMATIC spelling of *note* moved by *shift* semitones (mod 12)."""
    semitone = _NOTE_INDEX[note] + shift
    return _CHROMATIC[semitone % 12]
def _to_canonical_shift(mode: str, tonic: str) -> int:
    """Semitone shift (0-11) that moves *tonic* onto the canonical tonic.

    The canonical tonic is C (index 0) for mode "major". Any other mode uses
    A (index 9).
    """
    if mode == "major":
        canonical = 0
    else:
        canonical = 9
    return (canonical - _NOTE_INDEX[tonic]) % 12
def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
    """Transpose a canonical (C major / A minor) ChordPeriod to *target_key*.

    *target_key* is read as underscore-separated parts:
    - The first part names the tonic.
    - The last part names the mode; anything other than "major" is treated
      as relative to A.

    The returned period always carries ``key=target_key``.
    """
    tonic = _parse_note_from_key(target_key.partition("_")[0])
    mode = target_key.rpartition("_")[2]
    origin = 0 if mode == "major" else 9
    shift = (_NOTE_INDEX[tonic] - origin) % 12
    if not shift:
        # Same pitch class as canonical: only the key label changes.
        return replace(period, key=target_key)
    transposed_bars = []
    for bar_no, bar in enumerate(period.bars, start=1):
        transposed_bars.append(
            [_transpose_symbol(sym, shift, "<generate>", bar_no) for sym in bar]
        )
    return replace(period, key=target_key, bars=transposed_bars)
# ---------------------------------------------------------------------------
# Prefix encoding
# ---------------------------------------------------------------------------
def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
    """Encode prefix symbols as canonical body token IDs.

    *shift* is the semitone offset from the user key to canonical (C/Am),
    applied to chord roots and explicit bass notes.

    Encoding per element (each element fills exactly one bar position):
    - '.'          -> HOLD
    - 'NC'         -> NC
    - chord symbol -> ROOT QUAL EXT BASS (four tokens)

    Returns:
        (ids, n_positions)
    """
    token_ids: list[int] = []
    position_count = 0
    for symbol in chord_symbols:
        position_count += 1
        if symbol == ".":
            token_ids.append(_HOLD)
            continue
        if symbol == "NC":
            token_ids.append(_NC)
            continue
        chord = parse_chord_symbol(symbol)
        root_name = _transpose_note(chord.root, shift)
        if chord.bass == "root":
            bass_name = "root"
        else:
            bass_name = _transpose_note(chord.bass, shift)
        token_ids.extend((
            TOKEN_TO_ID[f"ROOT_{root_name}"],
            TOKEN_TO_ID[_qual_token(chord.quality)],
            TOKEN_TO_ID[f"EXT_{chord.extension}"],
            TOKEN_TO_ID[f"BASS_{bass_name}"],
        ))
    return token_ids, position_count
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def generate_period(
    model: ChordTransformer,
    mode: str,
    time: str,
    subdivision: int,
    style: str,
    function: str,
    key: str,
    prefix: Optional[list[str]] = None,
    temperature: float = 1.0,
    top_p: float = 0.9,
    max_tokens: int = 300,
    n_bars: Optional[int] = None,
    seed: Optional[int] = None,
    repetition_penalty: float = 0.0,
) -> ChordPeriod:
    """Generate one harmonic period autoregressively.

    The model generates in canonical key (C major / A minor); bar boundaries
    are determined by position counting in the detokenizer (no BAR tokens).
    The result is transposed to *key* before being returned.

    Args:
        model: Loaded ChordTransformer (switched to eval mode here).
        mode: 'major' or 'minor'. Selects the MODE token and the canonical
            shift applied to *prefix*.
        time: Time signature string, e.g. '4/4'.
        subdivision: Positions per beat unit (4 or 8).
        style: Style tag, e.g. 'H1K0'. Unknown tags fall back to STYLE_other.
        function: Section label, e.g. 'chorus'. Unknown labels fall back to
            FUNC_unspecified.
        key: Target output key, e.g. 'F#_major' or 'B_minor'. Its suffix
            should agree with *mode*: the output is transposed back using the
            suffix, so a major/minor disagreement is logged.
        prefix: Chord symbols (in *key*) prepended as body context; '.' is a
            held position and 'NC' is no-chord.
        temperature: Sampling temperature (> 1 = more random).
        top_p: Nucleus cutoff probability (0 < top_p <= 1).
        max_tokens: Hard cap on generated tokens.
        n_bars: Stop after this many complete bars in the output. Counts bars
            from the prefix too. If the prefix already reaches it, nothing is
            sampled; only a trailing partial bar is padded with HOLDs.
            None = let the model decide.
        seed: RNG seed for reproducibility (seeds torch and ``random``
            globally).
        repetition_penalty: Per-occurrence penalty subtracted from ROOT logits
            at each FREE position. For each candidate root R, subtracts
            ``min(penalty * count(last_root -> R), 3.0)``. Here count is how
            often that root bigram has occurred so far in this period, prefix
            chords included. 0.0 = disabled (default).
            Suggested range: 0.5 to 1.5.

    Returns:
        ChordPeriod in *key*.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    device = next(model.parameters()).device
    model.eval()

    key_parts = key.split("_")
    tonic = _parse_note_from_key(key_parts[0])
    shift_to_canonical = _to_canonical_shift(mode, tonic)
    # The prefix is shifted using *mode*, but _transpose_to_key derives the
    # inverse shift from the key's last "_" part. If only one of them says
    # "major", the two shifts are not inverses of each other.
    if (key_parts[-1] == "major") != (mode == "major"):
        log.warning(
            "key %r and mode %r disagree on major/minor; prefix and output "
            "transpositions will not match",
            key,
            mode,
        )

    style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
    func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
    if style_tok == "STYLE_other" and style != "other":
        log.warning("unknown style %r — using STYLE_other", style)
    if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
        log.warning("unknown function %r — using FUNC_unspecified", function)

    ids: list[int] = [
        TOKEN_TO_ID["<BOS>"],
        TOKEN_TO_ID[f"MODE_{mode}"],
        TOKEN_TO_ID[f"TIME_{time}"],
        TOKEN_TO_ID[f"SUB_{subdivision}"],
        TOKEN_TO_ID[style_tok],
        TOKEN_TO_ID[func_tok],
    ]

    positions_per_bar = _expected_positions(time, subdivision)
    pos_in_bar = 0
    bars_completed = 0
    if prefix:
        encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
        ids.extend(encoded_prefix)
        pos_in_bar = n_prefix_positions % positions_per_bar
        bars_completed = n_prefix_positions // positions_per_bar
        if n_bars is not None and bars_completed >= n_bars:
            log.warning("prefix already spans %d bars (>= requested %d)", bars_completed, n_bars)

    # If the bar target is already met (by the prefix, or because n_bars <= 0),
    # do not sample at all. Otherwise the loop would keep going until the next
    # bar boundary and overshoot n_bars by a whole bar.
    target_met = n_bars is not None and bars_completed >= n_bars
    token_budget = 0 if target_met else max_tokens

    last_id = ids[-1]
    context_limit = model.max_seq_len - 1  # leave one slot so seq_len never hits max

    # Repetition-penalty state: seed from prefix/metadata tokens already in ids.
    last_chord_root: int | None
    bigram_counts: Counter[tuple[int, int]]
    last_chord_root, bigram_counts = _init_bigram_state(ids)

    with torch.no_grad():
        for _ in range(token_budget):
            if len(ids) >= context_limit:
                break
            inp = torch.tensor([ids], dtype=torch.long, device=device)
            logits = model(inp)[0, -1]  # [vocab_size]

            # _grammar_bias returns a shared module-level singleton; clone before
            # any in-place edit so EOS-blocking and the repetition penalty never
            # leak across positions or across calls.
            bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone()
            if n_bars is not None and bars_completed < n_bars:
                bias[_EOS] = float("-inf")  # don't let the model stop early

            # Bigram repetition penalty — applied at FREE positions only.
            # Penalises root transitions that have already occurred in this period.
            if (
                repetition_penalty > 0.0
                and last_chord_root is not None
                and not _is_mid_chord(last_id)
            ):
                for root_id in range(_ROOT_START, _ROOT_END + 1):
                    count = bigram_counts.get((last_chord_root, root_id), 0)
                    if count:
                        # Cap each root's reduction at 3.0 logits so NC/HOLD don't
                        # flood the distribution when all roots are heavily penalised.
                        bias[root_id] -= min(repetition_penalty * count, 3.0)

            logits = logits + bias.to(device)
            token_id = _sample_top_p(logits, temperature, top_p)
            ids.append(token_id)
            last_id = token_id

            # When a complete chord group is closed, update bigram state.
            if _BASS_START <= token_id <= _BASS_END:
                new_root = ids[-4]  # ROOT is 3 slots before the just-appended BASS
                if last_chord_root is not None:
                    bigram_counts[(last_chord_root, new_root)] += 1
                last_chord_root = new_root

            # Advance position counter when a body position is completed.
            if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
                pos_in_bar = (pos_in_bar + 1) % positions_per_bar
                if pos_in_bar == 0:
                    bars_completed += 1
                    if n_bars is not None and bars_completed >= n_bars:
                        break

            if token_id == _EOS:
                break

    # Ensure sequence ends with EOS at a bar boundary.
    if ids[-1] != _EOS:
        if _is_mid_chord(last_id):
            # Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
            # Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
            while ids and _is_mid_chord(ids[-1]):
                ids.pop()
        # Pad the partial bar with HOLDs so EOS lands on a bar boundary.
        while pos_in_bar > 0:
            ids.append(_HOLD)
            pos_in_bar = (pos_in_bar + 1) % positions_per_bar
        ids.append(_EOS)

    log.debug("generated %d tokens total (prompt + body)", len(ids))
    period = detokenize_to_period(ids)
    return _transpose_to_key(period, key)