"""Autoregressive generation of harmonic periods. Public API: generate_period(model, mode, time, subdivision, style, function, key, prefix=None, temperature=1.0, top_p=0.9, max_tokens=300, seed=None, repetition_penalty=0.0) -> ChordPeriod """ from __future__ import annotations import logging import random from collections import Counter from dataclasses import replace from typing import Optional import torch import torch.nn.functional as F from src.chord_parser import parse_chord_symbol from src.model import ChordTransformer from src.tokenizer import ( VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period, _qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol, _expected_positions, ) log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py) # --------------------------------------------------------------------------- _VOCAB_SIZE = len(VOCAB) _EOS = TOKEN_TO_ID[""] _HOLD = TOKEN_TO_ID["HOLD"] _NC = TOKEN_TO_ID["NC"] _ROOT_START = TOKEN_TO_ID["ROOT_C"] _ROOT_END = TOKEN_TO_ID["ROOT_B"] # inclusive _QUAL_START = TOKEN_TO_ID["QUAL_maj"] _QUAL_END = TOKEN_TO_ID["QUAL_m_add9"] # inclusive _EXT_START = TOKEN_TO_ID["EXT_none"] _EXT_END = TOKEN_TO_ID["EXT_b13"] # inclusive _BASS_START = TOKEN_TO_ID["BASS_root"] _BASS_END = TOKEN_TO_ID["BASS_B"] # inclusive # --------------------------------------------------------------------------- # Grammar masks # Additive logit bias: 0.0 = allowed, -inf = blocked. 
# ---------------------------------------------------------------------------


def _make_bias(allowed: list[int]) -> torch.Tensor:
    """Return a vocab-sized additive mask: 0.0 at *allowed* ids, -inf elsewhere."""
    b = torch.full((_VOCAB_SIZE,), float("-inf"))
    for i in allowed:
        b[i] = 0.0
    return b


_free_ids = list(range(_ROOT_START, _ROOT_END + 1)) + [_HOLD, _NC]
_BIAS_FREE = _make_bias(_free_ids)  # mid-bar: no EOS
_BIAS_FREE_EOS = _make_bias(_free_ids + [_EOS])  # bar boundary: EOS allowed
_BIAS_QUAL = _make_bias(list(range(_QUAL_START, _QUAL_END + 1)))
_BIAS_EXT = _make_bias(list(range(_EXT_START, _EXT_END + 1)))
_BIAS_EXT_NONE = _make_bias([TOKEN_TO_ID["EXT_none"]])  # only for add9 qualities
_BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))

# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})


# ---------------------------------------------------------------------------
# Repetition-penalty helpers
# ---------------------------------------------------------------------------


def _scan_chord_roots(ids: list[int]) -> list[int]:
    """Return ROOT token IDs for every complete ROOT-QUAL-EXT-BASS group in *ids*."""
    roots: list[int] = []
    i, n = 0, len(ids)
    while i < n:
        tok = ids[i]
        if (
            _ROOT_START <= tok <= _ROOT_END
            and i + 3 < n  # all four tokens of the group must be present
            and _QUAL_START <= ids[i + 1] <= _QUAL_END
            and _EXT_START <= ids[i + 2] <= _EXT_END
            and _BASS_START <= ids[i + 3] <= _BASS_END
        ):
            roots.append(tok)
            i += 4  # skip past the whole group
            continue
        i += 1
    return roots


def _init_bigram_state(
    ids: list[int],
) -> tuple[int | None, Counter[tuple[int, int]]]:
    """Build bigram counts and last-root from token ids already in the buffer.

    Used to seed repetition tracking from the prefix / metadata tokens.

    Returns:
        (last_chord_root_id, bigram_counts)
        last_chord_root_id is None when no complete chord has been seen yet.
    """
    roots = _scan_chord_roots(ids)
    # Count every consecutive (previous_root, next_root) pair.
    counts: Counter[tuple[int, int]] = Counter(zip(roots, roots[1:]))
    return (roots[-1] if roots else None), counts


def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
    """Additive logit bias enforcing token grammar after *last_id*.

    EOS is allowed only at bar boundaries (pos_in_bar == 0).

    The returned tensor is one of the shared module-level masks; callers that
    want to edit it in place must clone it first. *positions_per_bar* is
    accepted but not consulted here.
    """
    if _ROOT_START <= last_id <= _ROOT_END:
        return _BIAS_QUAL
    if _QUAL_START <= last_id <= _QUAL_END:
        return _BIAS_EXT_NONE if last_id in _ADD9_QUAL_IDS else _BIAS_EXT
    if _EXT_START <= last_id <= _EXT_END:
        return _BIAS_BASS
    # FREE position: after BASS, HOLD, NC, or the last metadata token
    if pos_in_bar == 0:
        return _BIAS_FREE_EOS  # at bar boundary — EOS is valid
    return _BIAS_FREE  # mid-bar — EOS not yet valid


def _is_mid_chord(last_id: int) -> bool:
    """True when the next required token is QUAL, EXT, or BASS."""
    return (
        (_ROOT_START <= last_id <= _ROOT_END)
        or (_QUAL_START <= last_id <= _QUAL_END)
        or (_EXT_START <= last_id <= _EXT_END)
    )


# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------


def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """Sample one token index via nucleus (top-p) sampling.

    Args:
        logits: 1-D tensor of unnormalised scores (one per vocabulary entry).
        temperature: Divisor applied to the logits; clamped below at 1e-8.
        top_p: Cumulative-probability cutoff. Tokens are kept, in descending
            probability order, while the mass *before* them is below top_p.
            The most likely token is always kept, so top_p <= 0 degrades to
            greedy selection instead of producing an all-zero distribution.

    Returns:
        The sampled index into *logits*.
    """
    logits = logits / max(temperature, 1e-8)
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A token is cut when the mass preceding it already reaches top_p.
    cut = (cumulative - sorted_probs) >= top_p
    # Without this guard, top_p <= 0 cuts every entry, the renormalisation
    # below becomes 0/0 = NaN, and torch.multinomial raises.
    cut[0] = False
    sorted_probs[cut] = 0.0
    sorted_probs /= sorted_probs.sum()
    chosen = torch.multinomial(sorted_probs, 1).item()
    return int(sorted_idx[chosen].item())


# ---------------------------------------------------------------------------
# Transposition helpers
# ---------------------------------------------------------------------------


def _transpose_note(note: str, shift: int) -> str:
    """Return the name in _CHROMATIC that lies *shift* semitones above *note*."""
    return _CHROMATIC[(_NOTE_INDEX[note] + shift) % 12]


def _to_canonical_shift(mode: str, tonic: str) -> int:
    """Semitone shift to transpose from *tonic* to canonical (C=0 / A=9)."""
    canonical = 0 if mode == "major" else 9
    return (canonical - _NOTE_INDEX[tonic]) % 12


def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
    """Transpose a canonical (C/Am) ChordPeriod to *target_key*.

    *target_key* is split on '_'; the first part names the tonic and the last
    part the mode (e.g. 'F#_major'). When no shift is needed only the ``key``
    field is replaced.
    """
    parts = target_key.split("_")
    tonic = _parse_note_from_key(parts[0])
    mode = parts[-1]
    canonical = 0 if mode == "major" else 9
    shift = (_NOTE_INDEX[tonic] - canonical) % 12
    if shift == 0:
        return replace(period, key=target_key)
    # NOTE(review): passed through to _transpose_symbol together with the
    # 1-based bar number — presumably a source label for diagnostics. The
    # empty value may be a stripped placeholder; confirm against the helper.
    fname = ""
    new_bars = [
        [_transpose_symbol(sym, shift, fname, bar_no) for sym in bar]
        for bar_no, bar in enumerate(period.bars, start=1)
    ]
    return replace(period, key=target_key, bars=new_bars)


# ---------------------------------------------------------------------------
# Prefix encoding
# ---------------------------------------------------------------------------


def _encode_prefix(chord_symbols: list[str], shift: int) -> tuple[list[int], int]:
    """Encode prefix tokens (chord symbols, '.', 'NC') as canonical body token IDs.

    *shift* is the semitone offset from user key to canonical (C/Am).
    '.' encodes as a single HOLD token; 'NC' encodes as a single NC token.
    A chord symbol encodes as four tokens: ROOT QUAL EXT BASS.
    Each element (chord, hold, or NC) counts as one bar position.

    Returns (ids, n_positions).
    """
    ids: list[int] = []
    n_positions = 0
    for sym in chord_symbols:
        if sym == ".":
            ids.append(_HOLD)
        elif sym == "NC":
            ids.append(_NC)
        else:
            t = parse_chord_symbol(sym)
            root = _transpose_note(t.root, shift)
            # "root" is a sentinel meaning "bass equals the chord root"; it is
            # key-independent and therefore not transposed.
            bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
            ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
            ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
            ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
            ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
        n_positions += 1
    return ids, n_positions


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def generate_period(
    model: ChordTransformer,
    mode: str,
    time: str,
    subdivision: int,
    style: str,
    function: str,
    key: str,
    prefix: Optional[list[str]] = None,
    temperature: float = 1.0,
    top_p: float = 0.9,
    max_tokens: int = 300,
    n_bars: Optional[int] = None,
    seed: Optional[int] = None,
    repetition_penalty: float = 0.0,
) -> ChordPeriod:
    """Generate one harmonic period autoregressively.

    The model generates in canonical key (C major / A minor); bar boundaries
    are determined by position counting in the detokenizer (no BAR tokens).
    The result is transposed to *key* before being returned.

    Args:
        model: Loaded ChordTransformer in eval mode.
        mode: 'major' or 'minor'.
        time: Time signature string, e.g. '4/4'.
        subdivision: Positions per beat unit (4 or 8).
        style: Style tag, e.g. 'H1K0'.
        function: Section label, e.g. 'chorus'.
        key: Target output key, e.g. 'F#_major' or 'B_minor'.
        prefix: Chord symbols (in *key*) prepended as body context.
        temperature: Sampling temperature (> 1 = more random).
        top_p: Nucleus cutoff probability (0 < top_p <= 1). Values <= 0
            select the most likely token at every step.
        max_tokens: Hard cap on generated tokens.
        n_bars: Stop after this many complete bars in the output.
            Counts bars from the prefix too. None = let the model decide.
        seed: RNG seed for reproducibility.
        repetition_penalty: Per-occurrence penalty subtracted from ROOT logits
            at each FREE position. Specifically, for each candidate root R,
            subtracts penalty * count(last_root → R) from logits, where count
            is the number of times that bigram has appeared in the generated
            body so far. 0.0 = disabled (default). Suggested range: 0.5–1.5.

    Returns:
        ChordPeriod in *key*.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    device = next(model.parameters()).device
    model.eval()

    key_parts = key.split("_")
    tonic = _parse_note_from_key(key_parts[0])
    # The shift uses the *mode* argument, while the final transposition reads
    # the mode from *key*; callers are expected to pass consistent values.
    shift_to_canonical = _to_canonical_shift(mode, tonic)

    # Unknown style / function labels fall back to catch-all tokens.
    style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
    func_tok = (
        f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
    )
    if style_tok == "STYLE_other" and style != "other":
        log.warning("unknown style %r — using STYLE_other", style)
    if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
        log.warning("unknown function %r — using FUNC_unspecified", function)

    # Prompt: leading marker token followed by the five metadata tokens.
    # NOTE(review): the leading token is looked up with an empty-string key,
    # the same key _EOS uses; presumably an angle-bracketed marker name was
    # lost upstream — confirm against src/tokenizer.py.
    ids: list[int] = [
        TOKEN_TO_ID[""],
        TOKEN_TO_ID[f"MODE_{mode}"],
        TOKEN_TO_ID[f"TIME_{time}"],
        TOKEN_TO_ID[f"SUB_{subdivision}"],
        TOKEN_TO_ID[style_tok],
        TOKEN_TO_ID[func_tok],
    ]

    positions_per_bar = _expected_positions(time, subdivision)
    pos_in_bar = 0
    bars_completed = 0

    if prefix:
        encoded_prefix, n_prefix_positions = _encode_prefix(prefix, shift_to_canonical)
        ids.extend(encoded_prefix)
        pos_in_bar = n_prefix_positions % positions_per_bar
        bars_completed = n_prefix_positions // positions_per_bar
        if n_bars is not None and bars_completed >= n_bars:
            log.warning(
                "prefix already spans %d bars (>= requested %d)",
                bars_completed,
                n_bars,
            )

    last_id = ids[-1]
    context_limit = model.max_seq_len - 1  # leave one slot so seq_len never hits max

    # Repetition-penalty state: seed from prefix/metadata tokens already in ids.
    last_chord_root: int | None
    bigram_counts: Counter[tuple[int, int]]
    last_chord_root, bigram_counts = _init_bigram_state(ids)

    with torch.no_grad():
        for _ in range(max_tokens):
            if len(ids) >= context_limit:
                break

            inp = torch.tensor([ids], dtype=torch.long, device=device)
            logits = model(inp)[0, -1]  # [vocab_size]

            # _grammar_bias returns a shared module-level singleton; clone before
            # any in-place edit so EOS-blocking and the repetition penalty never
            # leak across positions or across calls.
            bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar).clone()
            if n_bars is not None and bars_completed < n_bars:
                bias[_EOS] = float("-inf")  # don't let the model stop early

            # Bigram repetition penalty — applied at FREE positions only.
            # Penalises root transitions that have already occurred in this period.
            if (
                repetition_penalty > 0.0
                and last_chord_root is not None
                and not _is_mid_chord(last_id)
            ):
                for root_id in range(_ROOT_START, _ROOT_END + 1):
                    count = bigram_counts.get((last_chord_root, root_id), 0)
                    if count:
                        # Cap total reduction at 3.0 logits so NC/HOLD don't
                        # flood the distribution when all roots are heavily penalised.
                        bias[root_id] -= min(repetition_penalty * count, 3.0)

            logits = logits + bias.to(device)
            token_id = _sample_top_p(logits, temperature, top_p)
            ids.append(token_id)
            last_id = token_id

            # When a complete chord group is closed, update bigram state.
            if _BASS_START <= token_id <= _BASS_END:
                new_root = ids[-4]  # ROOT is 3 slots before the just-appended BASS
                if last_chord_root is not None:
                    bigram_counts[(last_chord_root, new_root)] += 1
                last_chord_root = new_root

            # Advance position counter when a body position is completed
            if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
                pos_in_bar = (pos_in_bar + 1) % positions_per_bar
                if pos_in_bar == 0:
                    bars_completed += 1
                    # NOTE(review): this stop check is placed at the point a
                    # bar has just been completed — confirm intended nesting.
                    if n_bars is not None and bars_completed >= n_bars:
                        break

            if token_id == _EOS:
                break

    # Ensure sequence ends with EOS at a bar boundary.
    if ids[-1] != _EOS:
        if _is_mid_chord(last_id):
            # Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
            # Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
            while ids and _is_mid_chord(ids[-1]):
                ids.pop()
        # Pad the partial bar with HOLDs so EOS lands on a bar boundary.
        while pos_in_bar > 0:
            ids.append(_HOLD)
            pos_in_bar = (pos_in_bar + 1) % positions_per_bar
        ids.append(_EOS)

    log.debug("generated %d tokens total (prompt + body)", len(ids))
    period = detokenize_to_period(ids)
    return _transpose_to_key(period, key)