feat: add generate module and CLI; fix tokenizer minor issues
src/generate.py: autoregressive generation with top-p sampling, grammar masking (ROOT→QUAL→EXT→BASS; EOS only at bar boundary), key transposition, and optional chord prefix. Partial bars on context truncation are padded with HOLDs rather than discarded. scripts/generate.py: CLI wrapping generate_period — accepts mode, key, time, subdivision, style, function, prefix, temperature, top-p, seed, tempo; writes .chord and optional MIDI. src/tokenizer.py: fix docstring vocab size (81→84); normalize redundant BASS_<note>==root to no slash in _tokens_to_symbol. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+283
@@ -0,0 +1,283 @@
|
||||
"""Autoregressive generation of harmonic periods.
|
||||
|
||||
Public API:
|
||||
generate_period(model, mode, time, subdivision, style, function, key,
|
||||
prefix=None, temperature=1.0, top_p=0.9,
|
||||
max_tokens=300, seed=None) -> ChordPeriod
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import replace
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from src.chord_parser import parse_chord_symbol
|
||||
from src.model import ChordTransformer
|
||||
from src.tokenizer import (
|
||||
VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
|
||||
_qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
|
||||
_expected_positions,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VOCAB_SIZE = len(VOCAB)
|
||||
_EOS = TOKEN_TO_ID["<EOS>"]
|
||||
_HOLD = TOKEN_TO_ID["HOLD"]
|
||||
_NC = TOKEN_TO_ID["NC"]
|
||||
_ROOT_START = TOKEN_TO_ID["ROOT_C"]
|
||||
_ROOT_END = TOKEN_TO_ID["ROOT_B"] # inclusive
|
||||
_QUAL_START = TOKEN_TO_ID["QUAL_maj"]
|
||||
_QUAL_END = TOKEN_TO_ID["QUAL_m_add9"] # inclusive
|
||||
_EXT_START = TOKEN_TO_ID["EXT_none"]
|
||||
_EXT_END = TOKEN_TO_ID["EXT_b13"] # inclusive
|
||||
_BASS_START = TOKEN_TO_ID["BASS_root"]
|
||||
_BASS_END = TOKEN_TO_ID["BASS_B"] # inclusive
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grammar masks
|
||||
# Additive logit bias: 0.0 = allowed, -inf = blocked.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_bias(allowed: list[int]) -> torch.Tensor:
|
||||
b = torch.full((_VOCAB_SIZE,), float("-inf"))
|
||||
for i in allowed:
|
||||
b[i] = 0.0
|
||||
return b
|
||||
|
||||
|
||||
_free_ids = list(range(_ROOT_START, _ROOT_END + 1)) + [_HOLD, _NC]
|
||||
_BIAS_FREE = _make_bias(_free_ids) # mid-bar: no EOS
|
||||
_BIAS_FREE_EOS= _make_bias(_free_ids + [_EOS]) # bar boundary: EOS allowed
|
||||
_BIAS_QUAL = _make_bias(list(range(_QUAL_START, _QUAL_END + 1)))
|
||||
_BIAS_EXT = _make_bias(list(range(_EXT_START, _EXT_END + 1)))
|
||||
_BIAS_EXT_NONE= _make_bias([TOKEN_TO_ID["EXT_none"]]) # only for add9 qualities
|
||||
_BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
|
||||
|
||||
# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
|
||||
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
|
||||
|
||||
|
||||
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
|
||||
"""Additive logit bias enforcing token grammar after *last_id*.
|
||||
|
||||
EOS is allowed only at bar boundaries (pos_in_bar == 0).
|
||||
"""
|
||||
if _ROOT_START <= last_id <= _ROOT_END:
|
||||
return _BIAS_QUAL
|
||||
if _QUAL_START <= last_id <= _QUAL_END:
|
||||
return _BIAS_EXT_NONE if last_id in _ADD9_QUAL_IDS else _BIAS_EXT
|
||||
if _EXT_START <= last_id <= _EXT_END:
|
||||
return _BIAS_BASS
|
||||
# FREE position: after BASS, HOLD, NC, or the last metadata token
|
||||
if pos_in_bar == 0:
|
||||
return _BIAS_FREE_EOS # at bar boundary — EOS is valid
|
||||
return _BIAS_FREE # mid-bar — EOS not yet valid
|
||||
|
||||
|
||||
def _is_mid_chord(last_id: int) -> bool:
|
||||
"""True when the next required token is QUAL, EXT, or BASS."""
|
||||
return (
|
||||
(_ROOT_START <= last_id <= _ROOT_END) or
|
||||
(_QUAL_START <= last_id <= _QUAL_END) or
|
||||
(_EXT_START <= last_id <= _EXT_END)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
|
||||
"""Sample one token index via nucleus (top-p) sampling."""
|
||||
logits = logits / max(temperature, 1e-8)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
|
||||
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
||||
cut = (cumulative - sorted_probs) >= top_p
|
||||
sorted_probs[cut] = 0.0
|
||||
sorted_probs /= sorted_probs.sum()
|
||||
chosen = torch.multinomial(sorted_probs, 1).item()
|
||||
return int(sorted_idx[chosen].item())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transposition helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _transpose_note(note: str, shift: int) -> str:
|
||||
return _CHROMATIC[(_NOTE_INDEX[note] + shift) % 12]
|
||||
|
||||
|
||||
def _to_canonical_shift(mode: str, tonic: str) -> int:
|
||||
"""Semitone shift to transpose from *tonic* to canonical (C=0 / A=9)."""
|
||||
canonical = 0 if mode == "major" else 9
|
||||
return (canonical - _NOTE_INDEX[tonic]) % 12
|
||||
|
||||
|
||||
def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
|
||||
"""Transpose a canonical (C/Am) ChordPeriod to *target_key*."""
|
||||
parts = target_key.split("_")
|
||||
tonic = _parse_note_from_key(parts[0])
|
||||
mode = parts[-1]
|
||||
canonical = 0 if mode == "major" else 9
|
||||
shift = (_NOTE_INDEX[tonic] - canonical) % 12
|
||||
if shift == 0:
|
||||
return replace(period, key=target_key)
|
||||
fname = "<generate>"
|
||||
new_bars = [
|
||||
[_transpose_symbol(sym, shift, fname, bar_no) for sym in bar]
|
||||
for bar_no, bar in enumerate(period.bars, start=1)
|
||||
]
|
||||
return replace(period, key=target_key, bars=new_bars)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefix encoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
|
||||
"""Encode chord symbols (in user key) as canonical body token IDs.
|
||||
|
||||
*shift* is the semitone offset from user key to canonical (C/Am).
|
||||
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
|
||||
"""
|
||||
ids: list[int] = []
|
||||
for sym in chord_symbols:
|
||||
t = parse_chord_symbol(sym)
|
||||
root = _transpose_note(t.root, shift)
|
||||
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
||||
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
|
||||
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
||||
return ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_period(
|
||||
model: ChordTransformer,
|
||||
mode: str,
|
||||
time: str,
|
||||
subdivision: int,
|
||||
style: str,
|
||||
function: str,
|
||||
key: str,
|
||||
prefix: Optional[list[str]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 0.9,
|
||||
max_tokens: int = 300,
|
||||
seed: Optional[int] = None,
|
||||
) -> ChordPeriod:
|
||||
"""Generate one harmonic period autoregressively.
|
||||
|
||||
The model generates in canonical key (C major / A minor); bar boundaries
|
||||
are determined by position counting in the detokenizer (no BAR tokens).
|
||||
The result is transposed to *key* before being returned.
|
||||
|
||||
Args:
|
||||
model: Loaded ChordTransformer in eval mode.
|
||||
mode: 'major' or 'minor'.
|
||||
time: Time signature string, e.g. '4/4'.
|
||||
subdivision: Positions per beat unit (4 or 8).
|
||||
style: Style tag, e.g. 'H1K0'.
|
||||
function: Section label, e.g. 'chorus'.
|
||||
key: Target output key, e.g. 'F#_major' or 'B_minor'.
|
||||
prefix: Chord symbols (in *key*) prepended as body context.
|
||||
temperature: Sampling temperature (> 1 = more random).
|
||||
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||
max_tokens: Hard cap on generated tokens.
|
||||
seed: RNG seed for reproducibility.
|
||||
|
||||
Returns:
|
||||
ChordPeriod in *key*.
|
||||
"""
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
|
||||
key_parts = key.split("_")
|
||||
tonic = _parse_note_from_key(key_parts[0])
|
||||
shift_to_canonical = _to_canonical_shift(mode, tonic)
|
||||
|
||||
style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
|
||||
func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
|
||||
if style_tok == "STYLE_other" and style != "other":
|
||||
log.warning("unknown style %r — using STYLE_other", style)
|
||||
if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
|
||||
log.warning("unknown function %r — using FUNC_unspecified", function)
|
||||
|
||||
ids: list[int] = [
|
||||
TOKEN_TO_ID["<BOS>"],
|
||||
TOKEN_TO_ID[f"MODE_{mode}"],
|
||||
TOKEN_TO_ID[f"TIME_{time}"],
|
||||
TOKEN_TO_ID[f"SUB_{subdivision}"],
|
||||
TOKEN_TO_ID[style_tok],
|
||||
TOKEN_TO_ID[func_tok],
|
||||
]
|
||||
|
||||
positions_per_bar = _expected_positions(time, subdivision)
|
||||
|
||||
encoded_prefix: list[int] = []
|
||||
if prefix:
|
||||
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
|
||||
ids.extend(encoded_prefix)
|
||||
|
||||
# Compute starting position-in-bar after the prefix
|
||||
# Each chord in prefix = 4 tokens = 1 position
|
||||
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
|
||||
|
||||
last_id = ids[-1]
|
||||
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(max_tokens):
|
||||
if len(ids) >= context_limit:
|
||||
break
|
||||
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
||||
logits = model(inp)[0, -1] # [vocab_size]
|
||||
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||
logits = logits + bias.to(device)
|
||||
token_id = _sample_top_p(logits, temperature, top_p)
|
||||
ids.append(token_id)
|
||||
last_id = token_id
|
||||
|
||||
# Advance position counter when a body position is completed
|
||||
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||
|
||||
if token_id == _EOS:
|
||||
break
|
||||
|
||||
# Ensure sequence ends with EOS at a bar boundary.
|
||||
if ids[-1] != _EOS:
|
||||
if _is_mid_chord(last_id):
|
||||
# Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
|
||||
# Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
|
||||
while ids and _is_mid_chord(ids[-1]):
|
||||
ids.pop()
|
||||
# Pad the partial bar with HOLDs so EOS lands on a bar boundary.
|
||||
while pos_in_bar > 0:
|
||||
ids.append(_HOLD)
|
||||
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||
ids.append(_EOS)
|
||||
|
||||
log.debug("generated %d tokens total (prompt + body)", len(ids))
|
||||
|
||||
period = detokenize_to_period(ids)
|
||||
return _transpose_to_key(period, key)
|
||||
+3
-2
@@ -8,7 +8,7 @@ Public API:
|
||||
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
|
||||
|
||||
Vocabulary constants:
|
||||
VOCAB -- 81-token ordered list; index == token ID
|
||||
VOCAB -- 84-token ordered list; index == token ID
|
||||
TOKEN_TO_ID -- {token_string: id}
|
||||
ID_TO_TOKEN -- alias for VOCAB
|
||||
|
||||
@@ -153,7 +153,8 @@ def _tokens_to_symbol(t: ChordTokens) -> str:
|
||||
quality_ext = t.quality
|
||||
else:
|
||||
quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
|
||||
bass_part = "" if t.bass == "root" else f"/{t.bass}"
|
||||
# Normalize BASS_<note> == root note to no slash (same as BASS_root sentinel).
|
||||
bass_part = "" if t.bass in ("root", t.root) else f"/{t.bass}"
|
||||
return t.root + quality_ext + bass_part
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user