feat: add generate module and CLI; fix tokenizer minor issues

src/generate.py: autoregressive generation with top-p sampling, grammar
masking (ROOT→QUAL→EXT→BASS; EOS only at bar boundary), key transposition,
and optional chord prefix.  Partial bars on context truncation are padded
with HOLDs rather than discarded.

scripts/generate.py: CLI wrapping generate_period — accepts mode, key,
time, subdivision, style, function, prefix, temperature, top-p, seed,
tempo; writes .chord and optional MIDI.

src/tokenizer.py: fix docstring vocab size (81→84); normalize redundant
BASS_<note>==root to no slash in _tokens_to_symbol.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 14:28:44 +03:00
parent 8a73394df9
commit e657d9edb5
3 changed files with 446 additions and 2 deletions
+283
View File
@@ -0,0 +1,283 @@
"""Autoregressive generation of harmonic periods.
Public API:
generate_period(model, mode, time, subdivision, style, function, key,
prefix=None, temperature=1.0, top_p=0.9,
max_tokens=300, seed=None) -> ChordPeriod
"""
from __future__ import annotations
import logging
import random
from dataclasses import replace
from typing import Optional
import torch
import torch.nn.functional as F
from src.chord_parser import parse_chord_symbol
from src.model import ChordTransformer
from src.tokenizer import (
VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
_qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
_expected_positions,
)
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
# ---------------------------------------------------------------------------
_VOCAB_SIZE = len(VOCAB)
_EOS = TOKEN_TO_ID["<EOS>"]
_HOLD = TOKEN_TO_ID["HOLD"]
_NC = TOKEN_TO_ID["NC"]
_ROOT_START = TOKEN_TO_ID["ROOT_C"]
_ROOT_END = TOKEN_TO_ID["ROOT_B"] # inclusive
_QUAL_START = TOKEN_TO_ID["QUAL_maj"]
_QUAL_END = TOKEN_TO_ID["QUAL_m_add9"] # inclusive
_EXT_START = TOKEN_TO_ID["EXT_none"]
_EXT_END = TOKEN_TO_ID["EXT_b13"] # inclusive
_BASS_START = TOKEN_TO_ID["BASS_root"]
_BASS_END = TOKEN_TO_ID["BASS_B"] # inclusive
# ---------------------------------------------------------------------------
# Grammar masks
# Additive logit bias: 0.0 = allowed, -inf = blocked.
# ---------------------------------------------------------------------------
def _make_bias(allowed: list[int]) -> torch.Tensor:
b = torch.full((_VOCAB_SIZE,), float("-inf"))
for i in allowed:
b[i] = 0.0
return b
_free_ids = list(range(_ROOT_START, _ROOT_END + 1)) + [_HOLD, _NC]
_BIAS_FREE = _make_bias(_free_ids) # mid-bar: no EOS
_BIAS_FREE_EOS= _make_bias(_free_ids + [_EOS]) # bar boundary: EOS allowed
_BIAS_QUAL = _make_bias(list(range(_QUAL_START, _QUAL_END + 1)))
_BIAS_EXT = _make_bias(list(range(_EXT_START, _EXT_END + 1)))
_BIAS_EXT_NONE= _make_bias([TOKEN_TO_ID["EXT_none"]]) # only for add9 qualities
_BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
"""Additive logit bias enforcing token grammar after *last_id*.
EOS is allowed only at bar boundaries (pos_in_bar == 0).
"""
if _ROOT_START <= last_id <= _ROOT_END:
return _BIAS_QUAL
if _QUAL_START <= last_id <= _QUAL_END:
return _BIAS_EXT_NONE if last_id in _ADD9_QUAL_IDS else _BIAS_EXT
if _EXT_START <= last_id <= _EXT_END:
return _BIAS_BASS
# FREE position: after BASS, HOLD, NC, or the last metadata token
if pos_in_bar == 0:
return _BIAS_FREE_EOS # at bar boundary — EOS is valid
return _BIAS_FREE # mid-bar — EOS not yet valid
def _is_mid_chord(last_id: int) -> bool:
"""True when the next required token is QUAL, EXT, or BASS."""
return (
(_ROOT_START <= last_id <= _ROOT_END) or
(_QUAL_START <= last_id <= _QUAL_END) or
(_EXT_START <= last_id <= _EXT_END)
)
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
"""Sample one token index via nucleus (top-p) sampling."""
logits = logits / max(temperature, 1e-8)
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
cut = (cumulative - sorted_probs) >= top_p
sorted_probs[cut] = 0.0
sorted_probs /= sorted_probs.sum()
chosen = torch.multinomial(sorted_probs, 1).item()
return int(sorted_idx[chosen].item())
# ---------------------------------------------------------------------------
# Transposition helpers
# ---------------------------------------------------------------------------
def _transpose_note(note: str, shift: int) -> str:
return _CHROMATIC[(_NOTE_INDEX[note] + shift) % 12]
def _to_canonical_shift(mode: str, tonic: str) -> int:
"""Semitone shift to transpose from *tonic* to canonical (C=0 / A=9)."""
canonical = 0 if mode == "major" else 9
return (canonical - _NOTE_INDEX[tonic]) % 12
def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
"""Transpose a canonical (C/Am) ChordPeriod to *target_key*."""
parts = target_key.split("_")
tonic = _parse_note_from_key(parts[0])
mode = parts[-1]
canonical = 0 if mode == "major" else 9
shift = (_NOTE_INDEX[tonic] - canonical) % 12
if shift == 0:
return replace(period, key=target_key)
fname = "<generate>"
new_bars = [
[_transpose_symbol(sym, shift, fname, bar_no) for sym in bar]
for bar_no, bar in enumerate(period.bars, start=1)
]
return replace(period, key=target_key, bars=new_bars)
# ---------------------------------------------------------------------------
# Prefix encoding
# ---------------------------------------------------------------------------
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
"""Encode chord symbols (in user key) as canonical body token IDs.
*shift* is the semitone offset from user key to canonical (C/Am).
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
"""
ids: list[int] = []
for sym in chord_symbols:
t = parse_chord_symbol(sym)
root = _transpose_note(t.root, shift)
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
return ids
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def generate_period(
model: ChordTransformer,
mode: str,
time: str,
subdivision: int,
style: str,
function: str,
key: str,
prefix: Optional[list[str]] = None,
temperature: float = 1.0,
top_p: float = 0.9,
max_tokens: int = 300,
seed: Optional[int] = None,
) -> ChordPeriod:
"""Generate one harmonic period autoregressively.
The model generates in canonical key (C major / A minor); bar boundaries
are determined by position counting in the detokenizer (no BAR tokens).
The result is transposed to *key* before being returned.
Args:
model: Loaded ChordTransformer in eval mode.
mode: 'major' or 'minor'.
time: Time signature string, e.g. '4/4'.
subdivision: Positions per beat unit (4 or 8).
style: Style tag, e.g. 'H1K0'.
function: Section label, e.g. 'chorus'.
key: Target output key, e.g. 'F#_major' or 'B_minor'.
prefix: Chord symbols (in *key*) prepended as body context.
temperature: Sampling temperature (> 1 = more random).
top_p: Nucleus cutoff probability (0 < top_p <= 1).
max_tokens: Hard cap on generated tokens.
seed: RNG seed for reproducibility.
Returns:
ChordPeriod in *key*.
"""
if seed is not None:
torch.manual_seed(seed)
random.seed(seed)
device = next(model.parameters()).device
model.eval()
key_parts = key.split("_")
tonic = _parse_note_from_key(key_parts[0])
shift_to_canonical = _to_canonical_shift(mode, tonic)
style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
if style_tok == "STYLE_other" and style != "other":
log.warning("unknown style %r — using STYLE_other", style)
if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
log.warning("unknown function %r — using FUNC_unspecified", function)
ids: list[int] = [
TOKEN_TO_ID["<BOS>"],
TOKEN_TO_ID[f"MODE_{mode}"],
TOKEN_TO_ID[f"TIME_{time}"],
TOKEN_TO_ID[f"SUB_{subdivision}"],
TOKEN_TO_ID[style_tok],
TOKEN_TO_ID[func_tok],
]
positions_per_bar = _expected_positions(time, subdivision)
encoded_prefix: list[int] = []
if prefix:
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
ids.extend(encoded_prefix)
# Compute starting position-in-bar after the prefix
# Each chord in prefix = 4 tokens = 1 position
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
last_id = ids[-1]
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
with torch.no_grad():
for _ in range(max_tokens):
if len(ids) >= context_limit:
break
inp = torch.tensor([ids], dtype=torch.long, device=device)
logits = model(inp)[0, -1] # [vocab_size]
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
logits = logits + bias.to(device)
token_id = _sample_top_p(logits, temperature, top_p)
ids.append(token_id)
last_id = token_id
# Advance position counter when a body position is completed
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
if token_id == _EOS:
break
# Ensure sequence ends with EOS at a bar boundary.
if ids[-1] != _EOS:
if _is_mid_chord(last_id):
# Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
# Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
while ids and _is_mid_chord(ids[-1]):
ids.pop()
# Pad the partial bar with HOLDs so EOS lands on a bar boundary.
while pos_in_bar > 0:
ids.append(_HOLD)
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
ids.append(_EOS)
log.debug("generated %d tokens total (prompt + body)", len(ids))
period = detokenize_to_period(ids)
return _transpose_to_key(period, key)