feat: add generate module and CLI; fix tokenizer minor issues
src/generate.py: autoregressive generation with top-p sampling, grammar masking (ROOT→QUAL→EXT→BASS; EOS only at bar boundary), key transposition, and optional chord prefix. Partial bars on context truncation are padded with HOLDs rather than discarded. scripts/generate.py: CLI wrapping generate_period — accepts mode, key, time, subdivision, style, function, prefix, temperature, top-p, seed, tempo; writes .chord and optional MIDI. src/tokenizer.py: fix docstring vocab size (81→84); normalize redundant BASS_<note>==root to no slash in _tokens_to_symbol. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,160 @@
|
|||||||
|
"""Generate a harmonic period using a trained ChordTransformer.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/generate.py \\
|
||||||
|
--checkpoint checkpoints/finetuned.pt \\
|
||||||
|
--mode major --key F# \\
|
||||||
|
--style H1K0 --function chorus \\
|
||||||
|
--time 4/4 --subdivision 4 \\
|
||||||
|
--output out.chord \\
|
||||||
|
[--midi out.mid] \\
|
||||||
|
[--prefix "Cmaj7 Am7"] \\
|
||||||
|
[--temperature 1.0] [--top-p 0.9] \\
|
||||||
|
[--max-tokens 300] [--seed 42] \\
|
||||||
|
[--tempo 90]
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
<output> generated .chord file in the requested key
|
||||||
|
<output>.mid (if --midi) MIDI rendering of the period
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from dataclasses import replace
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from src.generate import generate_period
|
||||||
|
from src.midi_export import chord_file_to_midi
|
||||||
|
from src.model import ChordTransformer
|
||||||
|
from src.tokenizer import write_chord_file
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _load_model(checkpoint: Path, device: str) -> ChordTransformer:
|
||||||
|
ckpt = torch.load(checkpoint, map_location=device, weights_only=True)
|
||||||
|
model = ChordTransformer(**ckpt["model_config"])
|
||||||
|
model.load_state_dict(ckpt["model_state"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
ap.add_argument("--checkpoint", type=Path, required=True,
|
||||||
|
help="Path to .pt checkpoint (pretrained or finetuned).")
|
||||||
|
ap.add_argument("--mode", choices=["major", "minor"], required=True,
|
||||||
|
help="Tonal mode.")
|
||||||
|
ap.add_argument("--key", required=True,
|
||||||
|
help="Root note of the output key, e.g. F#, Bb, C.")
|
||||||
|
ap.add_argument("--style", default="H1K0",
|
||||||
|
help="Style tag (default: H1K0).")
|
||||||
|
ap.add_argument("--function", default="unspecified",
|
||||||
|
help="Section label: verse, chorus, bridge, ... (default: unspecified).")
|
||||||
|
ap.add_argument("--time", default="4/4",
|
||||||
|
help="Time signature (default: 4/4).")
|
||||||
|
ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8],
|
||||||
|
help="Positions per beat unit (default: 4).")
|
||||||
|
ap.add_argument("--output", type=Path, required=True,
|
||||||
|
help="Output file path. Extension .chord is appended if missing.")
|
||||||
|
ap.add_argument("--midi", type=Path, default=None,
|
||||||
|
help="Optional output MIDI file path.")
|
||||||
|
ap.add_argument("--prefix", default=None,
|
||||||
|
help='Space-separated chord symbols in the requested key, '
|
||||||
|
'e.g. "Cmaj7 Am7". Used as generation context.')
|
||||||
|
ap.add_argument("--temperature", type=float, default=1.0,
|
||||||
|
help="Sampling temperature (default: 1.0).")
|
||||||
|
ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
|
||||||
|
help="Nucleus sampling cutoff (default: 0.9).")
|
||||||
|
ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
|
||||||
|
help="Hard cap on generated tokens (default: 300).")
|
||||||
|
ap.add_argument("--seed", type=int, default=None,
|
||||||
|
help="Random seed for reproducibility.")
|
||||||
|
ap.add_argument("--tempo", type=int, default=90,
|
||||||
|
help="MIDI playback tempo in BPM (default: 90).")
|
||||||
|
ap.add_argument("--device", default="auto",
|
||||||
|
help="Compute device: cpu, cuda, or auto (default: auto).")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args.checkpoint.exists():
|
||||||
|
print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.device == "auto":
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
else:
|
||||||
|
device = args.device
|
||||||
|
|
||||||
|
model = _load_model(args.checkpoint, device)
|
||||||
|
|
||||||
|
target_key = f"{args.key}_{args.mode}"
|
||||||
|
prefix_chords = args.prefix.split() if args.prefix else None
|
||||||
|
|
||||||
|
period = generate_period(
|
||||||
|
model=model,
|
||||||
|
mode=args.mode,
|
||||||
|
time=args.time,
|
||||||
|
subdivision=args.subdivision,
|
||||||
|
style=args.style,
|
||||||
|
function=args.function,
|
||||||
|
key=target_key,
|
||||||
|
prefix=prefix_chords,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
seed=args.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Give generated periods a readable title
|
||||||
|
period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})")
|
||||||
|
|
||||||
|
# Ensure .chord extension
|
||||||
|
out_path = args.output
|
||||||
|
if out_path.suffix != ".chord":
|
||||||
|
out_path = out_path.with_suffix(".chord")
|
||||||
|
|
||||||
|
write_chord_file(period, out_path)
|
||||||
|
print(f"[generate] written -> {out_path}")
|
||||||
|
|
||||||
|
if args.midi:
|
||||||
|
midi_path = args.midi if args.midi.suffix == ".mid" else args.midi.with_suffix(".mid")
|
||||||
|
chord_file_to_midi(out_path, midi_path, tempo=args.tempo)
|
||||||
|
print(f"[generate] MIDI -> {midi_path}")
|
||||||
|
|
||||||
|
# Quick summary to stdout
|
||||||
|
print()
|
||||||
|
print(f" Key: {period.key}")
|
||||||
|
print(f" Time: {period.time} subdivision={period.subdivision}")
|
||||||
|
print(f" Style: {period.style} function={period.function}")
|
||||||
|
print(f" Bars: {len(period.bars)}")
|
||||||
|
print()
|
||||||
|
for i, bar in enumerate(period.bars, 1):
|
||||||
|
print(f" Bar {i:3d}: {' '.join(bar)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
+283
@@ -0,0 +1,283 @@
|
|||||||
|
"""Autoregressive generation of harmonic periods.
|
||||||
|
|
||||||
|
Public API:
|
||||||
|
generate_period(model, mode, time, subdivision, style, function, key,
|
||||||
|
prefix=None, temperature=1.0, top_p=0.9,
|
||||||
|
max_tokens=300, seed=None) -> ChordPeriod
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from dataclasses import replace
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from src.chord_parser import parse_chord_symbol
|
||||||
|
from src.model import ChordTransformer
|
||||||
|
from src.tokenizer import (
|
||||||
|
VOCAB, TOKEN_TO_ID, ChordPeriod, detokenize_to_period,
|
||||||
|
_qual_token, _parse_note_from_key, _NOTE_INDEX, _CHROMATIC, _transpose_symbol,
|
||||||
|
_expected_positions,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Token ID ranges (derived from the 84-token vocabulary in src/tokenizer.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_VOCAB_SIZE = len(VOCAB)
|
||||||
|
_EOS = TOKEN_TO_ID["<EOS>"]
|
||||||
|
_HOLD = TOKEN_TO_ID["HOLD"]
|
||||||
|
_NC = TOKEN_TO_ID["NC"]
|
||||||
|
_ROOT_START = TOKEN_TO_ID["ROOT_C"]
|
||||||
|
_ROOT_END = TOKEN_TO_ID["ROOT_B"] # inclusive
|
||||||
|
_QUAL_START = TOKEN_TO_ID["QUAL_maj"]
|
||||||
|
_QUAL_END = TOKEN_TO_ID["QUAL_m_add9"] # inclusive
|
||||||
|
_EXT_START = TOKEN_TO_ID["EXT_none"]
|
||||||
|
_EXT_END = TOKEN_TO_ID["EXT_b13"] # inclusive
|
||||||
|
_BASS_START = TOKEN_TO_ID["BASS_root"]
|
||||||
|
_BASS_END = TOKEN_TO_ID["BASS_B"] # inclusive
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Grammar masks
|
||||||
|
# Additive logit bias: 0.0 = allowed, -inf = blocked.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_bias(allowed: list[int]) -> torch.Tensor:
|
||||||
|
b = torch.full((_VOCAB_SIZE,), float("-inf"))
|
||||||
|
for i in allowed:
|
||||||
|
b[i] = 0.0
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
_free_ids = list(range(_ROOT_START, _ROOT_END + 1)) + [_HOLD, _NC]
|
||||||
|
_BIAS_FREE = _make_bias(_free_ids) # mid-bar: no EOS
|
||||||
|
_BIAS_FREE_EOS= _make_bias(_free_ids + [_EOS]) # bar boundary: EOS allowed
|
||||||
|
_BIAS_QUAL = _make_bias(list(range(_QUAL_START, _QUAL_END + 1)))
|
||||||
|
_BIAS_EXT = _make_bias(list(range(_EXT_START, _EXT_END + 1)))
|
||||||
|
_BIAS_EXT_NONE= _make_bias([TOKEN_TO_ID["EXT_none"]]) # only for add9 qualities
|
||||||
|
_BIAS_BASS = _make_bias(list(range(_BASS_START, _BASS_END + 1)))
|
||||||
|
|
||||||
|
# QUAL_add9 and QUAL_m_add9 already encode the 9th; further extension is invalid
|
||||||
|
_ADD9_QUAL_IDS = frozenset({TOKEN_TO_ID["QUAL_add9"], TOKEN_TO_ID["QUAL_m_add9"]})
|
||||||
|
|
||||||
|
|
||||||
|
def _grammar_bias(last_id: int, pos_in_bar: int, positions_per_bar: int) -> torch.Tensor:
|
||||||
|
"""Additive logit bias enforcing token grammar after *last_id*.
|
||||||
|
|
||||||
|
EOS is allowed only at bar boundaries (pos_in_bar == 0).
|
||||||
|
"""
|
||||||
|
if _ROOT_START <= last_id <= _ROOT_END:
|
||||||
|
return _BIAS_QUAL
|
||||||
|
if _QUAL_START <= last_id <= _QUAL_END:
|
||||||
|
return _BIAS_EXT_NONE if last_id in _ADD9_QUAL_IDS else _BIAS_EXT
|
||||||
|
if _EXT_START <= last_id <= _EXT_END:
|
||||||
|
return _BIAS_BASS
|
||||||
|
# FREE position: after BASS, HOLD, NC, or the last metadata token
|
||||||
|
if pos_in_bar == 0:
|
||||||
|
return _BIAS_FREE_EOS # at bar boundary — EOS is valid
|
||||||
|
return _BIAS_FREE # mid-bar — EOS not yet valid
|
||||||
|
|
||||||
|
|
||||||
|
def _is_mid_chord(last_id: int) -> bool:
|
||||||
|
"""True when the next required token is QUAL, EXT, or BASS."""
|
||||||
|
return (
|
||||||
|
(_ROOT_START <= last_id <= _ROOT_END) or
|
||||||
|
(_QUAL_START <= last_id <= _QUAL_END) or
|
||||||
|
(_EXT_START <= last_id <= _EXT_END)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sampling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
|
||||||
|
"""Sample one token index via nucleus (top-p) sampling."""
|
||||||
|
logits = logits / max(temperature, 1e-8)
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
|
||||||
|
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
||||||
|
cut = (cumulative - sorted_probs) >= top_p
|
||||||
|
sorted_probs[cut] = 0.0
|
||||||
|
sorted_probs /= sorted_probs.sum()
|
||||||
|
chosen = torch.multinomial(sorted_probs, 1).item()
|
||||||
|
return int(sorted_idx[chosen].item())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Transposition helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _transpose_note(note: str, shift: int) -> str:
|
||||||
|
return _CHROMATIC[(_NOTE_INDEX[note] + shift) % 12]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_canonical_shift(mode: str, tonic: str) -> int:
|
||||||
|
"""Semitone shift to transpose from *tonic* to canonical (C=0 / A=9)."""
|
||||||
|
canonical = 0 if mode == "major" else 9
|
||||||
|
return (canonical - _NOTE_INDEX[tonic]) % 12
|
||||||
|
|
||||||
|
|
||||||
|
def _transpose_to_key(period: ChordPeriod, target_key: str) -> ChordPeriod:
|
||||||
|
"""Transpose a canonical (C/Am) ChordPeriod to *target_key*."""
|
||||||
|
parts = target_key.split("_")
|
||||||
|
tonic = _parse_note_from_key(parts[0])
|
||||||
|
mode = parts[-1]
|
||||||
|
canonical = 0 if mode == "major" else 9
|
||||||
|
shift = (_NOTE_INDEX[tonic] - canonical) % 12
|
||||||
|
if shift == 0:
|
||||||
|
return replace(period, key=target_key)
|
||||||
|
fname = "<generate>"
|
||||||
|
new_bars = [
|
||||||
|
[_transpose_symbol(sym, shift, fname, bar_no) for sym in bar]
|
||||||
|
for bar_no, bar in enumerate(period.bars, start=1)
|
||||||
|
]
|
||||||
|
return replace(period, key=target_key, bars=new_bars)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prefix encoding
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _encode_prefix(chord_symbols: list[str], shift: int) -> list[int]:
|
||||||
|
"""Encode chord symbols (in user key) as canonical body token IDs.
|
||||||
|
|
||||||
|
*shift* is the semitone offset from user key to canonical (C/Am).
|
||||||
|
Returns a flat list: ROOT QUAL EXT BASS [ROOT QUAL EXT BASS ...].
|
||||||
|
"""
|
||||||
|
ids: list[int] = []
|
||||||
|
for sym in chord_symbols:
|
||||||
|
t = parse_chord_symbol(sym)
|
||||||
|
root = _transpose_note(t.root, shift)
|
||||||
|
bass = "root" if t.bass == "root" else _transpose_note(t.bass, shift)
|
||||||
|
ids.append(TOKEN_TO_ID[f"ROOT_{root}"])
|
||||||
|
ids.append(TOKEN_TO_ID[_qual_token(t.quality)])
|
||||||
|
ids.append(TOKEN_TO_ID[f"EXT_{t.extension}"])
|
||||||
|
ids.append(TOKEN_TO_ID[f"BASS_{bass}"])
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_period(
|
||||||
|
model: ChordTransformer,
|
||||||
|
mode: str,
|
||||||
|
time: str,
|
||||||
|
subdivision: int,
|
||||||
|
style: str,
|
||||||
|
function: str,
|
||||||
|
key: str,
|
||||||
|
prefix: Optional[list[str]] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_tokens: int = 300,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
) -> ChordPeriod:
|
||||||
|
"""Generate one harmonic period autoregressively.
|
||||||
|
|
||||||
|
The model generates in canonical key (C major / A minor); bar boundaries
|
||||||
|
are determined by position counting in the detokenizer (no BAR tokens).
|
||||||
|
The result is transposed to *key* before being returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Loaded ChordTransformer in eval mode.
|
||||||
|
mode: 'major' or 'minor'.
|
||||||
|
time: Time signature string, e.g. '4/4'.
|
||||||
|
subdivision: Positions per beat unit (4 or 8).
|
||||||
|
style: Style tag, e.g. 'H1K0'.
|
||||||
|
function: Section label, e.g. 'chorus'.
|
||||||
|
key: Target output key, e.g. 'F#_major' or 'B_minor'.
|
||||||
|
prefix: Chord symbols (in *key*) prepended as body context.
|
||||||
|
temperature: Sampling temperature (> 1 = more random).
|
||||||
|
top_p: Nucleus cutoff probability (0 < top_p <= 1).
|
||||||
|
max_tokens: Hard cap on generated tokens.
|
||||||
|
seed: RNG seed for reproducibility.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChordPeriod in *key*.
|
||||||
|
"""
|
||||||
|
if seed is not None:
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
key_parts = key.split("_")
|
||||||
|
tonic = _parse_note_from_key(key_parts[0])
|
||||||
|
shift_to_canonical = _to_canonical_shift(mode, tonic)
|
||||||
|
|
||||||
|
style_tok = f"STYLE_{style}" if f"STYLE_{style}" in TOKEN_TO_ID else "STYLE_other"
|
||||||
|
func_tok = f"FUNC_{function}" if f"FUNC_{function}" in TOKEN_TO_ID else "FUNC_unspecified"
|
||||||
|
if style_tok == "STYLE_other" and style != "other":
|
||||||
|
log.warning("unknown style %r — using STYLE_other", style)
|
||||||
|
if func_tok == "FUNC_unspecified" and function not in ("unspecified", ""):
|
||||||
|
log.warning("unknown function %r — using FUNC_unspecified", function)
|
||||||
|
|
||||||
|
ids: list[int] = [
|
||||||
|
TOKEN_TO_ID["<BOS>"],
|
||||||
|
TOKEN_TO_ID[f"MODE_{mode}"],
|
||||||
|
TOKEN_TO_ID[f"TIME_{time}"],
|
||||||
|
TOKEN_TO_ID[f"SUB_{subdivision}"],
|
||||||
|
TOKEN_TO_ID[style_tok],
|
||||||
|
TOKEN_TO_ID[func_tok],
|
||||||
|
]
|
||||||
|
|
||||||
|
positions_per_bar = _expected_positions(time, subdivision)
|
||||||
|
|
||||||
|
encoded_prefix: list[int] = []
|
||||||
|
if prefix:
|
||||||
|
encoded_prefix = _encode_prefix(prefix, shift_to_canonical)
|
||||||
|
ids.extend(encoded_prefix)
|
||||||
|
|
||||||
|
# Compute starting position-in-bar after the prefix
|
||||||
|
# Each chord in prefix = 4 tokens = 1 position
|
||||||
|
pos_in_bar = (len(encoded_prefix) // 4) % positions_per_bar
|
||||||
|
|
||||||
|
last_id = ids[-1]
|
||||||
|
context_limit = model.max_seq_len - 1 # leave one slot so seq_len never hits max
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(max_tokens):
|
||||||
|
if len(ids) >= context_limit:
|
||||||
|
break
|
||||||
|
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
||||||
|
logits = model(inp)[0, -1] # [vocab_size]
|
||||||
|
bias = _grammar_bias(last_id, pos_in_bar, positions_per_bar)
|
||||||
|
logits = logits + bias.to(device)
|
||||||
|
token_id = _sample_top_p(logits, temperature, top_p)
|
||||||
|
ids.append(token_id)
|
||||||
|
last_id = token_id
|
||||||
|
|
||||||
|
# Advance position counter when a body position is completed
|
||||||
|
if (_BASS_START <= token_id <= _BASS_END) or token_id in (_HOLD, _NC):
|
||||||
|
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||||
|
|
||||||
|
if token_id == _EOS:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Ensure sequence ends with EOS at a bar boundary.
|
||||||
|
if ids[-1] != _EOS:
|
||||||
|
if _is_mid_chord(last_id):
|
||||||
|
# Drop any incomplete ROOT-QUAL-EXT-BASS group from the tail.
|
||||||
|
# Mid-chord tokens never advance pos_in_bar, so the counter stays valid.
|
||||||
|
while ids and _is_mid_chord(ids[-1]):
|
||||||
|
ids.pop()
|
||||||
|
# Pad the partial bar with HOLDs so EOS lands on a bar boundary.
|
||||||
|
while pos_in_bar > 0:
|
||||||
|
ids.append(_HOLD)
|
||||||
|
pos_in_bar = (pos_in_bar + 1) % positions_per_bar
|
||||||
|
ids.append(_EOS)
|
||||||
|
|
||||||
|
log.debug("generated %d tokens total (prompt + body)", len(ids))
|
||||||
|
|
||||||
|
period = detokenize_to_period(ids)
|
||||||
|
return _transpose_to_key(period, key)
|
||||||
+3
-2
@@ -8,7 +8,7 @@ Public API:
|
|||||||
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
|
detokenize_to_period(token_ids: list[int]) -> ChordPeriod
|
||||||
|
|
||||||
Vocabulary constants:
|
Vocabulary constants:
|
||||||
VOCAB -- 81-token ordered list; index == token ID
|
VOCAB -- 84-token ordered list; index == token ID
|
||||||
TOKEN_TO_ID -- {token_string: id}
|
TOKEN_TO_ID -- {token_string: id}
|
||||||
ID_TO_TOKEN -- alias for VOCAB
|
ID_TO_TOKEN -- alias for VOCAB
|
||||||
|
|
||||||
@@ -153,7 +153,8 @@ def _tokens_to_symbol(t: ChordTokens) -> str:
|
|||||||
quality_ext = t.quality
|
quality_ext = t.quality
|
||||||
else:
|
else:
|
||||||
quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
|
quality_ext = t.quality + ("" if t.extension == "none" else t.extension)
|
||||||
bass_part = "" if t.bass == "root" else f"/{t.bass}"
|
# Normalize BASS_<note> == root note to no slash (same as BASS_root sentinel).
|
||||||
|
bass_part = "" if t.bass in ("root", t.root) else f"/{t.bass}"
|
||||||
return t.root + quality_ext + bass_part
|
return t.root + quality_ext + bass_part
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user