Files
hamori/scripts/generate.py
T
H1K0 9e73fa5d32 feat: add --bars arg to control output length
generate_period() now accepts n_bars=N to stop after exactly N complete
bars. bars_completed is seeded from the prefix length so --bars counts
the full output, not just the generated tail.

scripts/generate.py exposes this as --bars (default: None = model decides).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 20:29:44 +03:00

174 lines
6.4 KiB
Python

"""Generate a harmonic period using a trained ChordTransformer.
Usage:
python scripts/generate.py \\
--checkpoint checkpoints/finetuned.pt \\
--mode major --key F# \\
--style H1K0 --function chorus \\
--time 4/4 --subdivision 4 \\
--output out.chord \\
[--midi out.mid] \\
[--prefix "Cmaj7 Am7"] \\
[--temperature 1.0] [--top-p 0.9] \\
[--max-tokens 300] [--seed 42] \\
[--bars 8] [--tempo 90] \\
[--no-tonic-anchor] [--device auto]
Outputs:
<output> generated .chord file in the requested key (a different extension is replaced with .chord)
<midi> (if --midi) MIDI rendering of the period, written to the --midi path (extension forced to .mid)
"""
from __future__ import annotations
import argparse
import logging
import sys
from dataclasses import replace
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.generate import generate_period
from src.midi_export import chord_file_to_midi
from src.model import ChordTransformer
from src.tokenizer import write_chord_file
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_model(checkpoint: Path, device: str) -> ChordTransformer:
    """Rebuild a ``ChordTransformer`` from a checkpoint file, ready for inference.

    The checkpoint is read with ``torch.load`` and must contain two entries:
    ``"model_config"`` (keyword arguments for the ``ChordTransformer``
    constructor) and ``"model_state"`` (the state dict to load into it).

    Args:
        checkpoint: Path to the ``.pt`` checkpoint file.
        device: Device string used both as ``map_location`` while loading
            and as the target of ``model.to``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # weights_only=True is passed through to torch.load unchanged.
    saved = torch.load(checkpoint, map_location=device, weights_only=True)

    net = ChordTransformer(**saved["model_config"])
    net.load_state_dict(saved["model_state"])

    net.to(device)
    net.eval()
    return net
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Run the command-line generator end to end.

    Steps, in order:
      1. Parse the command line and reject out-of-range numeric options.
      2. Configure logging and check that the checkpoint file exists.
      3. Pick the compute device and load the model.
      4. Call ``generate_period`` with the requested conditioning and
         sampling settings.
      5. Write the ``.chord`` file, optionally render it to MIDI, and print
         a bar-by-bar summary to stdout.

    Exits with status 2 (through ``argparse``) when a numeric option is out
    of range, and with status 1 when the checkpoint file does not exist.
    """
    ap = _build_parser()
    args = ap.parse_args()
    # Fail fast on nonsensical numbers, before the model is loaded.
    _validate_args(ap, args)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if not args.checkpoint.exists():
        print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr)
        sys.exit(1)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    model = _load_model(args.checkpoint, device)
    target_key = f"{args.key}_{args.mode}"

    if args.prefix:
        prefix_chords = args.prefix.split()
    elif not args.no_tonic_anchor:
        # Default: anchor to tonic so generation stays in key.
        prefix_chords = [args.key + ("m" if args.mode == "minor" else "")]
    else:
        prefix_chords = None

    period = generate_period(
        model=model,
        mode=args.mode,
        time=args.time,
        subdivision=args.subdivision,
        style=args.style,
        function=args.function,
        key=target_key,
        prefix=prefix_chords,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        n_bars=args.bars,
        seed=args.seed,
    )

    # Give generated periods a readable title.
    period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})")

    # Ensure a .chord extension. Path.with_suffix swaps out any other
    # extension rather than stacking (out.txt -> out.chord).
    out_path = args.output
    if out_path.suffix != ".chord":
        out_path = out_path.with_suffix(".chord")
    write_chord_file(period, out_path)
    print(f"[generate] written -> {out_path}")

    if args.midi is not None:
        # The MIDI is rendered from the .chord file that was just written.
        midi_path = args.midi if args.midi.suffix == ".mid" else args.midi.with_suffix(".mid")
        chord_file_to_midi(out_path, midi_path, tempo=args.tempo)
        print(f"[generate] MIDI -> {midi_path}")

    # Quick summary to stdout.
    print()
    print(f" Key: {period.key}")
    print(f" Time: {period.time} subdivision={period.subdivision}")
    print(f" Style: {period.style} function={period.function}")
    print(f" Bars: {len(period.bars)}")
    print()
    for i, bar in enumerate(period.bars, 1):
        print(f" Bar {i:3d}: {' '.join(bar)}")


# ---------------------------------------------------------------------------
# CLI helpers used by main()
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every command-line option."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--checkpoint", type=Path, required=True,
                    help="Path to .pt checkpoint (pretrained or finetuned).")
    ap.add_argument("--mode", choices=["major", "minor"], required=True,
                    help="Tonal mode.")
    ap.add_argument("--key", required=True,
                    help="Root note of the output key, e.g. F#, Bb, C.")
    ap.add_argument("--style", default="H1K0",
                    help="Style tag (default: H1K0).")
    ap.add_argument("--function", default="unspecified",
                    help="Section label: verse, chorus, bridge, ... (default: unspecified).")
    ap.add_argument("--time", default="4/4",
                    help="Time signature (default: 4/4).")
    ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8],
                    help="Positions per beat unit (default: 4).")
    ap.add_argument("--output", type=Path, required=True,
                    help="Output file path. Extension .chord is appended if missing.")
    ap.add_argument("--midi", type=Path, default=None,
                    help="Optional output MIDI file path.")
    ap.add_argument("--prefix", default=None,
                    help='Space-separated chord symbols in the requested key, '
                         'e.g. "Cmaj7 . Am7 .". Use "." for held positions '
                         'and "NC" for no-chord positions.')
    ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
                    help="Do not prepend the tonic chord when --prefix is not given.")
    ap.add_argument("--temperature", type=float, default=1.0,
                    help="Sampling temperature (default: 1.0).")
    ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
                    help="Nucleus sampling cutoff (default: 0.9).")
    ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
                    help="Hard cap on generated tokens (default: 300).")
    ap.add_argument("--bars", type=int, default=None,
                    help="Stop after this many complete bars (default: let the model decide).")
    ap.add_argument("--seed", type=int, default=None,
                    help="Random seed for reproducibility.")
    ap.add_argument("--tempo", type=int, default=90,
                    help="MIDI playback tempo in BPM (default: 90).")
    ap.add_argument("--device", default="auto",
                    help="Compute device: cpu, cuda, or auto (default: auto).")
    return ap


def _validate_args(ap: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject numeric options whose values cannot produce a sensible run.

    Every failure goes through ``ap.error``, which prints the usage line
    plus the message to stderr and exits with status 2.
    """
    if args.bars is not None and args.bars < 1:
        ap.error("--bars must be a positive integer")
    if args.max_tokens < 1:
        ap.error("--max-tokens must be a positive integer")
    if args.tempo < 1:
        ap.error("--tempo must be a positive integer (BPM)")
    # Only negative temperatures are refused here; how a temperature of
    # exactly 0 is treated is left to generate_period.
    if args.temperature < 0:
        ap.error("--temperature must not be negative")
    if not 0.0 <= args.top_p <= 1.0:
        ap.error("--top-p must be between 0 and 1")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()