"""Generate a harmonic period using a trained ChordTransformer. Usage: python scripts/generate.py \\ --checkpoint checkpoints/finetuned.pt \\ --mode major --key F# \\ --style H1K0 --function chorus \\ --time 4/4 --subdivision 4 \\ --output out.chord \\ [--midi out.mid] \\ [--prefix "Cmaj7 Am7"] \\ [--temperature 1.0] [--top-p 0.9] \\ [--max-tokens 300] [--seed 42] \\ [--tempo 90] Outputs: generated .chord file in the requested key .mid (if --midi) MIDI rendering of the period """ from __future__ import annotations import argparse import logging import sys from dataclasses import replace from pathlib import Path import torch sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.generate import generate_period from src.midi_export import chord_file_to_midi from src.model import ChordTransformer from src.tokenizer import write_chord_file # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _load_model(checkpoint: Path, device: str) -> ChordTransformer: ckpt = torch.load(checkpoint, map_location=device, weights_only=True) model = ChordTransformer(**ckpt["model_config"]) model.load_state_dict(ckpt["model_state"]) model.to(device) model.eval() return model # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: ap = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter, ) ap.add_argument("--checkpoint", type=Path, required=True, help="Path to .pt checkpoint (pretrained or finetuned).") ap.add_argument("--mode", choices=["major", "minor"], required=True, help="Tonal mode.") ap.add_argument("--key", required=True, help="Root note of the output key, e.g. F#, Bb, C.") ap.add_argument("--style", default="H1K0", help="Style tag (default: H1K0).") ap.add_argument("--function", default="unspecified", help="Section label: verse, chorus, bridge, ... (default: unspecified).") ap.add_argument("--time", default="4/4", help="Time signature (default: 4/4).") ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8], help="Positions per beat unit (default: 4).") ap.add_argument("--output", type=Path, required=True, help="Output file path. Extension .chord is appended if missing.") ap.add_argument("--midi", type=Path, default=None, help="Optional output MIDI file path.") ap.add_argument("--prefix", default=None, help='Space-separated chord symbols in the requested key, ' 'e.g. "Cmaj7 Am7". Used as generation context.') ap.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature (default: 1.0).") ap.add_argument("--top-p", type=float, default=0.9, dest="top_p", help="Nucleus sampling cutoff (default: 0.9).") ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens", help="Hard cap on generated tokens (default: 300).") ap.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.") ap.add_argument("--tempo", type=int, default=90, help="MIDI playback tempo in BPM (default: 90).") ap.add_argument("--device", default="auto", help="Compute device: cpu, cuda, or auto (default: auto).") args = ap.parse_args() logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", ) if not args.checkpoint.exists(): print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr) sys.exit(1) if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device model = _load_model(args.checkpoint, device) target_key = f"{args.key}_{args.mode}" prefix_chords = args.prefix.split() if args.prefix else None period = generate_period( model=model, mode=args.mode, time=args.time, subdivision=args.subdivision, style=args.style, function=args.function, key=target_key, prefix=prefix_chords, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens, seed=args.seed, ) # Give generated periods a readable title period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})") # Ensure .chord extension out_path = args.output if out_path.suffix != ".chord": out_path = out_path.with_suffix(".chord") write_chord_file(period, out_path) print(f"[generate] written -> {out_path}") if args.midi: midi_path = args.midi if args.midi.suffix == ".mid" else args.midi.with_suffix(".mid") chord_file_to_midi(out_path, midi_path, tempo=args.tempo) print(f"[generate] MIDI -> {midi_path}") # Quick summary to stdout print() print(f" Key: {period.key}") print(f" Time: {period.time} subdivision={period.subdivision}") print(f" Style: {period.style} function={period.function}") print(f" Bars: {len(period.bars)}") print() for i, bar in enumerate(period.bars, 1): print(f" Bar {i:3d}: {' '.join(bar)}") if __name__ == "__main__": main()