From e657d9edb57b033a2c5e07a7e1f7d85715648702 Mon Sep 17 00:00:00 2001
From: Masahiko AMANO
Date: Wed, 20 May 2026 14:28:44 +0300
Subject: [PATCH] feat: add generate module and CLI; fix tokenizer minor issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

src/generate.py: autoregressive generation with top-p sampling, grammar
masking (ROOT→QUAL→EXT→BASS; EOS only at bar boundary), key transposition,
and optional chord prefix. Partial bars on context truncation are padded
with HOLDs rather than discarded.

scripts/generate.py: CLI wrapping generate_period — accepts mode, key,
time, subdivision, style, function, prefix, temperature, top-p, seed,
tempo; writes .chord and optional MIDI.

src/tokenizer.py: fix docstring vocab size (81→84); normalize redundant
BASS_==root to no slash in _tokens_to_symbol.

Co-Authored-By: Claude Sonnet 4.6
---
 scripts/generate.py | 176 +++++++++++++++++++++++++
 src/generate.py     | 305 ++++++++++++++++++++++++++++++++++++++++++++
 src/tokenizer.py    |   5 +-
 3 files changed, 484 insertions(+), 2 deletions(-)
 create mode 100644 scripts/generate.py
 create mode 100644 src/generate.py

diff --git a/scripts/generate.py b/scripts/generate.py
new file mode 100644
index 0000000..adee989
--- /dev/null
+++ b/scripts/generate.py
@@ -0,0 +1,176 @@
+"""Generate a harmonic period using a trained ChordTransformer.
+
+Usage:
+    python scripts/generate.py \\
+        --checkpoint checkpoints/finetuned.pt \\
+        --mode major --key F# \\
+        --style H1K0 --function chorus \\
+        --time 4/4 --subdivision 4 \\
+        --output out.chord \\
+        [--midi out.mid] \\
+        [--prefix "Cmaj7 Am7"] \\
+        [--temperature 1.0] [--top-p 0.9] \\
+        [--max-tokens 300] [--seed 42] \\
+        [--tempo 90]
+
+Outputs:
+    OUTPUT.chord   generated .chord file in the requested key
+    MIDI.mid       (if --midi) MIDI rendering of the period
+"""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from dataclasses import replace
+from pathlib import Path
+
+import torch
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from src.generate import generate_period
+from src.midi_export import chord_file_to_midi
+from src.model import ChordTransformer
+from src.tokenizer import write_chord_file
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _load_model(checkpoint: Path, device: str) -> ChordTransformer:
+    """Load a ChordTransformer from *checkpoint* onto *device* in eval mode.
+
+    The checkpoint is read with ``weights_only=True`` and must provide the
+    keys ``model_config`` (constructor kwargs) and ``model_state``.
+    """
+    ckpt = torch.load(checkpoint, map_location=device, weights_only=True)
+    model = ChordTransformer(**ckpt["model_config"])
+    model.load_state_dict(ckpt["model_state"])
+    model.to(device)
+    model.eval()
+    return model
+
+
+def _ensure_suffix(path: Path, suffix: str) -> Path:
+    """Return *path* with *suffix* appended unless it already ends with it.
+
+    Unlike ``Path.with_suffix`` this never replaces an existing extension,
+    so ``take.v2`` becomes ``take.v2.chord`` rather than ``take.chord``.
+    """
+    if path.suffix == suffix:
+        return path
+    return path.with_name(path.name + suffix)
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main() -> None:
+    """Parse CLI arguments, generate one period, write .chord and optional MIDI."""
+    ap = argparse.ArgumentParser(
+        description=__doc__,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    ap.add_argument("--checkpoint", type=Path, required=True,
+                    help="Path to .pt checkpoint (pretrained or finetuned).")
+    ap.add_argument("--mode", choices=["major", "minor"], required=True,
+                    help="Tonal mode.")
+    ap.add_argument("--key", required=True,
+                    help="Root note of the output key, e.g. F#, Bb, C.")
+    ap.add_argument("--style", default="H1K0",
+                    help="Style tag (default: H1K0).")
+    ap.add_argument("--function", default="unspecified",
+                    help="Section label: verse, chorus, bridge, ... (default: unspecified).")
+    ap.add_argument("--time", default="4/4",
+                    help="Time signature (default: 4/4).")
+    ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8],
+                    help="Positions per beat unit (default: 4).")
+    ap.add_argument("--output", type=Path, required=True,
+                    help="Output file path. Extension .chord is appended if missing.")
+    ap.add_argument("--midi", type=Path, default=None,
+                    help="Optional output MIDI file path.")
+    ap.add_argument("--prefix", default=None,
+                    help='Space-separated chord symbols in the requested key, '
+                         'e.g. "Cmaj7 Am7". Used as generation context.')
+    ap.add_argument("--temperature", type=float, default=1.0,
+                    help="Sampling temperature (default: 1.0).")
+    ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
+                    help="Nucleus sampling cutoff (default: 0.9).")
+    ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
+                    help="Hard cap on generated tokens (default: 300).")
+    ap.add_argument("--seed", type=int, default=None,
+                    help="Random seed for reproducibility.")
+    ap.add_argument("--tempo", type=int, default=90,
+                    help="MIDI playback tempo in BPM (default: 90).")
+    ap.add_argument("--device", default="auto",
+                    help="Compute device: cpu, cuda, or auto (default: auto).")
+    args = ap.parse_args()
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(message)s",
+        datefmt="%H:%M:%S",
+    )
+
+    if not args.checkpoint.exists():
+        print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr)
+        sys.exit(1)
+
+    if args.device == "auto":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    else:
+        device = args.device
+
+    model = _load_model(args.checkpoint, device)
+
+    target_key = f"{args.key}_{args.mode}"
+    prefix_chords = args.prefix.split() if args.prefix else None
+
+    period = generate_period(
+        model=model,
+        mode=args.mode,
+        time=args.time,
+        subdivision=args.subdivision,
+        style=args.style,
+        function=args.function,
+        key=target_key,
+        prefix=prefix_chords,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.max_tokens,
+        seed=args.seed,
+    )
+
+    # Give generated periods a readable title
+    period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})")
+
+    # Append .chord when missing (never replace another extension), as the
+    # --output help text promises.
+    out_path = _ensure_suffix(args.output, ".chord")
+
+    write_chord_file(period, out_path)
+    print(f"[generate] written -> {out_path}")
+
+    if args.midi:
+        midi_path = args.midi if args.midi.suffix == ".mid" else args.midi.with_suffix(".mid")
+        chord_file_to_midi(out_path, midi_path, tempo=args.tempo)
+        print(f"[generate] MIDI -> {midi_path}")
+
+    # Quick summary to stdout
+    print()
+    print(f" Key: {period.key}")
+    print(f" Time: {period.time} subdivision={period.subdivision}")
+    print(f" Style: {period.style} function={period.function}")
+    print(f" Bars: {len(period.bars)}")
+    print()
+    for i, bar in enumerate(period.bars, 1):
+        print(f" Bar {i:3d}: {' '.join(bar)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/generate.py b/src/generate.py
new file mode 100644
index 0000000..2e068de
--- /dev/null
+++ b/src/generate.py
@@ -0,0 +1,305 @@
+"""Autoregressive generation of harmonic periods.
+
+Public API:
+    generate_period(model, mode, time, subdivision, style, function, key,
+                    prefix=None, temperature=1.0, top_p=0.9,
+                    max_tokens=300, seed=None) -> ChordPeriod
+"""
+
+from __future__ import annotations
+
+import logging
+import random
+from dataclasses import replace
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from src.chord_parser import parse_chord_symbol
+from src.model import ChordTransformer
+from src.tokenizer import (
+    VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
+    _qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
+    _expected_positions,
+)
+
+log = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
+# ---------------------------------------------------------------------------
+
+_VOCAB_SIZE = len(VOCAB)
+# NOTE(review): the EOS lookup key below (and the BOS key in generate_period)
+# reads as an empty string in this copy of the patch; an angle-bracketed
+# special-token name was presumably stripped in transit -- confirm against
+# src/tokenizer.py before applying.
+_EOS = TOKEN_TO_ID[""]
+_HOLD = TOKEN_TO_ID["HOLD"]
+_NC = TOKEN_TO_ID["NC"]
+_ROOT_START = TOKEN_TO_ID["ROOT_C"]
+_ROOT_END = TOKEN_TO_ID["ROOT_B"]  # inclusive
+_QUAL_START = TOKEN_TO_ID["QUAL_maj"]
+_QUAL_END = TOKEN_TO_ID["QUAL_m_add9"]  # inclusive
+_EXT_START = TOKEN_TO_ID["EXT_none"]
+_EXT_END = TOKEN_TO_ID["EXT_b13"]  # inclusive
+_BASS_START = TOKEN_TO_ID["BASS_root"]
+_BASS_END = TOKEN_TO_ID["BASS_B"]  # inclusive
+
+# ---------------------------------------------------------------------------
+# Grammar masks
+# Additive logit bias: 0.0 = allowed, -inf = blocked.
+# ---------------------------------------------------------------------------
+
+def _make_bias(allowed: list[int]) -> torch.Tensor:
+    """Return a [_VOCAB_SIZE] additive bias: 0.0 at *allowed* ids, -inf elsewhere."""
+    b = torch.full((_VOCAB_SIZE,), float("-inf"))
+    for i in allowed:
+        b[i] = 0.0
+    return b
+
+
+_free_ids = list(range(_ROOT_START, _ROOT_END + 1)) + [_HOLD, _NC]
+_BIAS_FREE = _make_bias(_free_ids)  # mid-bar: no EOS
+_BIAS_FREE_EOS = _make_bias(_free_ids + [_EOS])  # bar boundary: EOS allowed
+_BIAS_QUAL = _make_bias(list(range(_QUAL_START, _QUAL_END + 1)))
+_BIAS_EXT = _make_bias(list(range(_EXT_START, _EXT_END + 1)))
+_BIAS_EXT_NONE = _make_bias([TOKEN_TO_ID["EXT_none"]])  # only for add9 qualities
+_BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
+
+# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
+_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
+
+
+def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
+    """Additive logit bias enforcing token grammar after *last_id*.
+
+    EOS is allowed only at bar boundaries (pos_in_bar == 0).
+    """
+    if _ROOT_START <= last_id <= _ROOT_END:
+        return _BIAS_QUAL
+    if _QUAL_START <= last_id <= _QUAL_END:
+        return _BIAS_EXT_NONE if last_id in _ADD9_QUAL_IDS else _BIAS_EXT
+    if _EXT_START <= last_id <= _EXT_END:
+        return _BIAS_BASS
+    # FREE position: after BASS, HOLD, NC, or the last metadata token
+    if pos_in_bar == 0:
+        return _BIAS_FREE_EOS  # at bar boundary — EOS is valid
+    return _BIAS_FREE  # mid-bar — EOS not yet valid
+
+
+def _is_mid_chord(last_id: int) -> bool:
+    """True when the next required token is QUAL, EXT, or BASS."""
+    return (
+        (_ROOT_START <= last_id <= _ROOT_END) or
+        (_QUAL_START <= last_id <= _QUAL_END) or
+        (_EXT_START <= last_id <= _EXT_END)
+    )
+
+
+# ---------------------------------------------------------------------------
+# Sampling
+# ---------------------------------------------------------------------------
+
+def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
+    """Sample one token index via nucleus (top-p) sampling.
+
+    Args:
+        logits: 1-D tensor of unnormalised scores; -inf entries are never drawn.
+        temperature: Softmax temperature; values <= 1e-8 are clamped to 1e-8.
+        top_p: Cumulative-probability cutoff. The most likely token is always
+            kept, so top_p <= 0 degrades to greedy decoding instead of
+            producing an all-zero distribution.
+
+    Returns:
+        Index into *logits* of the sampled token.
+    """
+    logits = logits / max(temperature, 1e-8)
+    probs = F.softmax(logits, dim=-1)
+    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+    cumulative = torch.cumsum(sorted_probs, dim=-1)
+    # Drop tokens whose preceding cumulative mass already reaches top_p.
+    cut = (cumulative - sorted_probs) >= top_p
+    # Never drop the top token: with top_p <= 0 every entry would be cut and
+    # the renormalisation below would divide by zero.
+    cut[..., 0] = False
+    sorted_probs[cut] = 0.0
+    sorted_probs /= sorted_probs.sum()
+    chosen = torch.multinomial(sorted_probs, 1).item()
+    return int(sorted_idx[chosen].item())
+
+
+# ---------------------------------------------------------------------------
+# Transposition helpers
+# ---------------------------------------------------------------------------
+
+def _transpose_note(note: str, shift: int) -> str:
+    """Return the _CHROMATIC entry *shift* steps after *note*, wrapping mod 12."""
+    return _CHROMATIC[(_NOTE_INDEX[note] + shift) % 12]
+
+
+def _to_canonical_shift(mode: str, tonic: str) -> int:
+    """Semitone shift to transpose from *tonic* to canonical (C=0 / A=9)."""
+    canonical = 0 if mode == "major" else 9
+    return (canonical - _NOTE_INDEX[tonic]) % 12
+
+
+def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
+    """Transpose a canonical (C/Am) ChordPeriod to *target_key*."""
+    parts = target_key.split("_")
+    tonic = _parse_note_from_key(parts[0])
+    mode = parts[-1]
+    canonical = 0 if mode == "major" else 9
+    shift = (_NOTE_INDEX[tonic] - canonical) % 12
+    if shift == 0:
+        return replace(period, key=target_key)
+    # NOTE(review): empty here; possibly a stripped "<...>" placeholder -- confirm.
+    fname = ""
+    new_bars = [
+        [_transpose_symbol(sym, shift, fname, bar_no) for sym in bar]
+        for bar_no, bar in enumerate(period.bars, start=1)
+    ]
+    return replace(period, key=target_key, bars=new_bars)
+
+
+# ---------------------------------------------------------------------------
+# Prefix encoding
+# ---------------------------------------------------------------------------
+
+def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
+    """Encode chord symbols (in user key) as canonical body token IDs.
+
+    *shift* is the semitone offset from user key to canonical (C/Am).
+    Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
+    """
+    ids: list[int] = []
+    for sym in chord_symbols:
+        t = parse_chord_symbol(sym)
+        root = _transpose_note(t.root, shift)
+        bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
+        ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
+        ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
+        ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
+        ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
+    return ids
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+def generate_period(
+    model: ChordTransformer,
+    mode: str,
+    time: str,
+    subdivision: int,
+    style: str,
+    function: str,
+    key: str,
+    prefix: Optional[list[str]] = None,
+    temperature: float = 1.0,
+    top_p: float = 0.9,
+    max_tokens: int = 300,
+    seed: Optional[int] = None,
+) -> ChordPeriod:
+    """Generate one harmonic period autoregressively.
+
+    The model generates in canonical key (C major / A minor); bar boundaries
+    are determined by position counting in the detokenizer (no BAR tokens).
+    The result is transposed to *key* before being returned.
+
+    Args:
+        model: Loaded ChordTransformer in eval mode.
+        mode: 'major' or 'minor'.
+        time: Time signature string, e.g. '4/4'.
+        subdivision: Positions per beat unit (4 or 8).
+        style: Style tag, e.g. 'H1K0'.
+        function: Section label, e.g. 'chorus'.
+        key: Target output key, e.g. 'F#_major' or 'B_minor'.
+        prefix: Chord symbols (in *key*) prepended as body context.
+        temperature: Sampling temperature (> 1 = more random).
+        top_p: Nucleus cutoff probability (0 < top_p <= 1).
+        max_tokens: Hard cap on generated tokens.
+        seed: RNG seed for reproducibility.
+
+    Returns:
+        ChordPeriod in *key*.
+    """
+    if seed is not None:
+        torch.manual_seed(seed)
+        random.seed(seed)
+
+    device = next(model.parameters()).device
+    model.eval()
+
+    key_parts = key.split("_")
+    tonic = _parse_note_from_key(key_parts[0])
+    shift_to_canonical = _to_canonical_shift(mode, tonic)
+
+    style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
+    func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
+    if style_tok == "STYLE_other" and style != "other":
+        log.warning("unknown style %r — using STYLE_other", style)
+    if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
+        log.warning("unknown function %r — using FUNC_unspecified", function)
+
+    ids: list[int] = [
+        TOKEN_TO_ID[""],  # NOTE(review): BOS key looks stripped; see _EOS note
+        TOKEN_TO_ID[f"MODE_{mode}"],
+        TOKEN_TO_ID[f"TIME_{time}"],
+        TOKEN_TO_ID[f"SUB_{subdivision}"],
+        TOKEN_TO_ID[style_tok],
+        TOKEN_TO_ID[func_tok],
+    ]
+
+    positions_per_bar = _expected_positions(time, subdivision)
+
+    encoded_prefix: list[int] = []
+    if prefix:
+        encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
+        ids.extend(encoded_prefix)
+
+    # Compute starting position-in-bar after the prefix
+    # Each chord in prefix = 4 tokens = 1 position
+    pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
+
+    last_id = ids[-1]
+    context_limit = model.max_seq_len - 1  # leave one slot so seq_len never hits max
+
+    with torch.no_grad():
+        for _ in range(max_tokens):
+            if len(ids) >= context_limit:
+                break
+            inp = torch.tensor([ids], dtype=torch.long, device=device)
+            logits = model(inp)[0, -1]  # [vocab_size]
+            bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
+            logits = logits + bias.to(device)
+            token_id = _sample_top_p(logits, temperature, top_p)
+            ids.append(token_id)
+            last_id = token_id
+
+            # Advance position counter when a body position is completed
+            if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
+                pos_in_bar = (pos_in_bar + 1) % positions_per_bar
+
+            if token_id == _EOS:
+                break
+
+    # Ensure sequence ends with EOS at a bar boundary.
+    if ids[-1] != _EOS:
+        if _is_mid_chord(last_id):
+            # Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
+            # Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
+            while ids and _is_mid_chord(ids[-1]):
+                ids.pop()
+        # Pad the partial bar with HOLDs so EOS lands on a bar boundary.
+        while pos_in_bar > 0:
+            ids.append(_HOLD)
+            pos_in_bar = (pos_in_bar + 1) % positions_per_bar
+        ids.append(_EOS)
+
+    log.debug("generated %d tokens total (prompt + body)", len(ids))
+
+    period = detokenize_to_period(ids)
+    return _transpose_to_key(period, key)
diff --git a/src/tokenizer.py b/src/tokenizer.py
index a412494..63ea7ca 100644
--- a/src/tokenizer.py
+++ b/src/tokenizer.py
@@ -8,6 +8,6 @@ Public API:
     detokenize_to_period(token_ids: list[int]) -> ChordPeriod
 
 Vocabulary constants:
-    VOCAB -- 81-token ordered list; index == token ID
+    VOCAB -- 84-token ordered list; index == token ID
     TOKEN_TO_ID -- {token_string: id}
     ID_TO_TOKEN -- alias for VOCAB
@@ -153,5 +153,6 @@ def _tokens_to_symbol(t: ChordTokens) -> str:
         quality_ext = t.quality
     else:
         quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
-    bass_part = "" if t.bass == "root" else f"/{t.bass}"
+    # Normalize BASS_ == root note to no slash (same as BASS_root sentinel).
+    bass_part = "" if t.bass in ("root", t.root) else f"/{t.bass}"
     return t.root + quality_ext + bass_part