f00a6c1b3a
Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events. At each FREE position, subtracts penalty * count(prev→root) from ROOT logits, capped at 3.0 to prevent NC/HOLD flooding at extreme values. Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0 (aggressive). Default 0.0 keeps backward compatibility. Added --repetition-penalty flag to scripts/generate.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
180 lines
6.8 KiB
Python
180 lines
6.8 KiB
Python
"""Generate a harmonic period using a trained ChordTransformer.

Usage:
    python scripts/generate.py \\
        --checkpoint checkpoints/finetuned.pt \\
        --mode major --key F# \\
        --style H1K0 --function chorus \\
        --time 4/4 --subdivision 4 \\
        --output out.chord \\
        [--midi out.mid] \\
        [--prefix "Cmaj7 Am7" | --no-tonic-anchor] \\
        [--temperature 1.0] [--top-p 0.9] \\
        [--max-tokens 300] [--bars 8] [--seed 42] \\
        [--repetition-penalty 0.0] \\
        [--tempo 90] [--device auto]

Outputs:
    <output>   generated .chord file in the requested key (.chord appended if missing)
    <midi>     (only with --midi) MIDI rendering of the period
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import logging
|
||
import sys
|
||
from dataclasses import replace
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
|
||
# Make the repository root importable so the ``src.*`` imports resolve when
# this file is executed directly as a script (e.g. ``python scripts/generate.py``).
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT))
|
||
|
||
from src.generate import generate_period
|
||
from src.midi_export import chord_file_to_midi
|
||
from src.model import ChordTransformer
|
||
from src.tokenizer import write_chord_file
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _load_model(checkpoint: Path, device: str) -> ChordTransformer:
    """Rebuild a ChordTransformer from a saved checkpoint, ready for inference.

    The checkpoint is expected to be a mapping holding the constructor keyword
    arguments under ``"model_config"`` and the parameter state dict under
    ``"model_state"``.

    Args:
        checkpoint: Path to the ``.pt`` file.
        device: Torch device string the tensors are mapped to and the model is
            moved onto.

    Returns:
        The restored model, switched to evaluation mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    saved = torch.load(checkpoint, map_location=device, weights_only=True)

    net = ChordTransformer(**saved["model_config"])
    net.load_state_dict(saved["model_state"])

    net.to(device)
    net.eval()
    return net
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser; its description is the module docstring."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--checkpoint", type=Path, required=True,
                    help="Path to .pt checkpoint (pretrained or finetuned).")
    ap.add_argument("--mode", choices=["major", "minor"], required=True,
                    help="Tonal mode.")
    ap.add_argument("--key", required=True,
                    help="Root note of the output key, e.g. F#, Bb, C.")
    ap.add_argument("--style", default="H1K0",
                    help="Style tag (default: H1K0).")
    ap.add_argument("--function", default="unspecified",
                    help="Section label: verse, chorus, bridge, ... (default: unspecified).")
    ap.add_argument("--time", default="4/4",
                    help="Time signature (default: 4/4).")
    ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8],
                    help="Positions per beat unit (default: 4).")
    ap.add_argument("--output", type=Path, required=True,
                    help="Output file path. Extension .chord is appended if missing.")
    ap.add_argument("--midi", type=Path, default=None,
                    help="Optional output MIDI file path "
                         "(.mid is appended unless it already ends in .mid/.midi).")
    ap.add_argument("--prefix", default=None,
                    help='Space-separated chord symbols in the requested key, '
                         'e.g. "Cmaj7 . Am7 .". Use "." for held positions '
                         'and "NC" for no-chord positions.')
    ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
                    help="Do not prepend the tonic chord when --prefix is not given.")
    ap.add_argument("--temperature", type=float, default=1.0,
                    help="Sampling temperature (default: 1.0).")
    ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
                    help="Nucleus sampling cutoff (default: 0.9).")
    ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
                    help="Hard cap on generated tokens (default: 300).")
    ap.add_argument("--bars", type=int, default=None,
                    help="Stop after this many complete bars (default: let the model decide).")
    ap.add_argument("--repetition-penalty", type=float, default=0.0,
                    dest="repetition_penalty",
                    help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
                         "from ROOT logits at each chord-change position. "
                         "0.0 = disabled (default). Suggested: 0.5–1.5.")
    ap.add_argument("--seed", type=int, default=None,
                    help="Random seed for reproducibility.")
    ap.add_argument("--tempo", type=int, default=90,
                    help="MIDI playback tempo in BPM (default: 90).")
    ap.add_argument("--device", default="auto",
                    help="Compute device: cpu, cuda, or auto (default: auto).")
    return ap


def _validate_args(ap: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject option values that can never be meaningful (``ap.error`` exits with status 2)."""
    # Comparisons are phrased as `not (valid)` so that NaN, which argparse's
    # `float` type happily accepts, is rejected as well.
    if not 0.0 < args.top_p <= 1.0:
        ap.error(f"--top-p must be in (0, 1], got {args.top_p}")
    if not args.temperature >= 0.0:
        ap.error(f"--temperature must be non-negative, got {args.temperature}")
    if not args.repetition_penalty >= 0.0:
        ap.error(f"--repetition-penalty must be non-negative, got {args.repetition_penalty}")
    if args.max_tokens < 1:
        ap.error(f"--max-tokens must be at least 1, got {args.max_tokens}")
    if args.bars is not None and args.bars < 1:
        ap.error(f"--bars must be at least 1, got {args.bars}")
    if args.tempo < 1:
        ap.error(f"--tempo must be a positive BPM value, got {args.tempo}")


def _with_extension(path: Path, suffix: str, *, accepted: tuple[str, ...] = ()) -> Path:
    """Return *path* with *suffix* appended unless it already ends in an accepted suffix.

    The existing suffix is kept rather than replaced, so a dotted stem such as
    ``run_0.5`` becomes ``run_0.5.chord`` instead of collapsing to ``run_0.chord``
    (which would silently overwrite other runs of a parameter sweep).
    Suffix comparison is case-insensitive.
    """
    allowed = {s.lower() for s in (suffix, *accepted)}
    if path.suffix.lower() in allowed:
        return path
    return path.with_name(path.name + suffix)


def _print_summary(period) -> None:
    """Print the period's header fields and one line per bar to stdout."""
    print()
    print(f" Key: {period.key}")
    print(f" Time: {period.time} subdivision={period.subdivision}")
    print(f" Style: {period.style} function={period.function}")
    print(f" Bars: {len(period.bars)}")
    print()
    for i, bar in enumerate(period.bars, 1):
        print(f" Bar {i:3d}: {' '.join(bar)}")


def main() -> None:
    """Generate one harmonic period from the command line and write it to disk.

    Writes a ``.chord`` file (and optionally a MIDI rendering), then prints a
    per-bar summary. Exits with status 1 if the checkpoint file is missing or
    a CUDA device is requested but unavailable, and with status 2 (argparse)
    on invalid option values.
    """
    ap = _build_parser()
    args = ap.parse_args()
    _validate_args(ap, args)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # is_file() rather than exists(): a directory would otherwise reach torch.load.
    if not args.checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr)
        sys.exit(1)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
        if device.startswith("cuda") and not torch.cuda.is_available():
            print(f"ERROR: device {device!r} requested but CUDA is not available",
                  file=sys.stderr)
            sys.exit(1)

    model = _load_model(args.checkpoint, device)

    target_key = f"{args.key}_{args.mode}"

    # A whitespace-only --prefix yields no tokens; treat it like "no prefix".
    prefix_tokens = args.prefix.split() if args.prefix else []
    prefix_chords: list[str] | None
    if prefix_tokens:
        prefix_chords = prefix_tokens
    elif not args.no_tonic_anchor:
        # Default: anchor to the tonic so generation stays in key.
        prefix_chords = [args.key + ("m" if args.mode == "minor" else "")]
    else:
        prefix_chords = None

    period = generate_period(
        model=model,
        mode=args.mode,
        time=args.time,
        subdivision=args.subdivision,
        style=args.style,
        function=args.function,
        key=target_key,
        prefix=prefix_chords,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        n_bars=args.bars,
        seed=args.seed,
        repetition_penalty=args.repetition_penalty,
    )

    # Give generated periods a readable title
    period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})")

    out_path = _with_extension(args.output, ".chord")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    write_chord_file(period, out_path)
    print(f"[generate] written -> {out_path}")

    if args.midi is not None:
        midi_path = _with_extension(args.midi, ".mid", accepted=(".midi",))
        midi_path.parent.mkdir(parents=True, exist_ok=True)
        chord_file_to_midi(out_path, midi_path, tempo=args.tempo)
        print(f"[generate] MIDI -> {midi_path}")

    _print_summary(period)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run only when executed as a script; main() returns None, so the exit status is 0.
    sys.exit(main())
|