Files
hamori/scripts/generate.py
T
H1K0 f00a6c1b3a feat: add bigram repetition penalty to generate_period
Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events.
At each FREE position, subtracts penalty * count(prev→root) from ROOT logits,
capped at 3.0 to prevent NC/HOLD flooding at extreme values.

Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0
(aggressive). Default 0.0 keeps backward compatibility.

Added --repetition-penalty flag to scripts/generate.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 15:19:24 +03:00

180 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Generate a harmonic period using a trained ChordTransformer.
Usage:
python scripts/generate.py \\
--checkpoint checkpoints/finetuned.pt \\
--mode major --key F# \\
--style H1K0 --function chorus \\
--time 4/4 --subdivision 4 \\
--output out.chord \\
[--midi out.mid] \\
[--prefix "Cmaj7 Am7"] \\
[--temperature 1.0] [--top-p 0.9] \\
[--max-tokens 300] [--seed 42] \\
[--tempo 90] \\
[--bars 8] [--repetition-penalty 0.5] \\
[--no-tonic-anchor] [--device auto]
Outputs:
<output> generated .chord file in the requested key
<output>.mid (if --midi) MIDI rendering of the period
"""
from __future__ import annotations
import argparse
import logging
import sys
from dataclasses import replace
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.generate import generate_period
from src.midi_export import chord_file_to_midi
from src.model import ChordTransformer
from src.tokenizer import write_chord_file
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_model(checkpoint: Path, device: str) -> ChordTransformer:
    """Rebuild a ChordTransformer from a checkpoint file, ready for inference.

    The checkpoint is read with ``weights_only=True`` and mapped onto
    *device*.  It must provide two entries: ``"model_config"`` (keyword
    arguments for the ``ChordTransformer`` constructor) and
    ``"model_state"`` (the state passed to ``load_state_dict``).

    The returned model has been moved to *device* and put in eval mode.
    """
    payload = torch.load(checkpoint, map_location=device, weights_only=True)

    # Build the architecture first, then pour the saved weights into it.
    hyperparams = payload["model_config"]
    network = ChordTransformer(**hyperparams)
    weights = payload["model_state"]
    network.load_state_dict(weights)

    network.to(device)
    network.eval()
    return network
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every command-line flag."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--checkpoint", type=Path, required=True,
                    help="Path to .pt checkpoint (pretrained or finetuned).")
    ap.add_argument("--mode", choices=["major", "minor"], required=True,
                    help="Tonal mode.")
    ap.add_argument("--key", required=True,
                    help="Root note of the output key, e.g. F#, Bb, C.")
    ap.add_argument("--style", default="H1K0",
                    help="Style tag (default: H1K0).")
    ap.add_argument("--function", default="unspecified",
                    help="Section label: verse, chorus, bridge, ... (default: unspecified).")
    ap.add_argument("--time", default="4/4",
                    help="Time signature (default: 4/4).")
    ap.add_argument("--subdivision", type=int, default=4, choices=[4, 8],
                    help="Positions per beat unit (default: 4).")
    ap.add_argument("--output", type=Path, required=True,
                    help="Output file path. Extension .chord is appended if missing.")
    ap.add_argument("--midi", type=Path, default=None,
                    help="Optional output MIDI file path.")
    ap.add_argument("--prefix", default=None,
                    help='Space-separated chord symbols in the requested key, '
                         'e.g. "Cmaj7 . Am7 .". Use "." for held positions '
                         'and "NC" for no-chord positions.')
    ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
                    help="Do not prepend the tonic chord when --prefix is not given.")
    ap.add_argument("--temperature", type=float, default=1.0,
                    help="Sampling temperature (default: 1.0).")
    ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
                    help="Nucleus sampling cutoff (default: 0.9).")
    ap.add_argument("--max-tokens", type=int, default=300, dest="max_tokens",
                    help="Hard cap on generated tokens (default: 300).")
    ap.add_argument("--bars", type=int, default=None,
                    help="Stop after this many complete bars (default: let the model decide).")
    ap.add_argument("--repetition-penalty", type=float, default=0.0,
                    dest="repetition_penalty",
                    help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
                         "from ROOT logits at each chord-change position. "
                         "0.0 = disabled (default). Suggested: 0.5-1.5.")
    ap.add_argument("--seed", type=int, default=None,
                    help="Random seed for reproducibility.")
    ap.add_argument("--tempo", type=int, default=90,
                    help="MIDI playback tempo in BPM (default: 90).")
    ap.add_argument("--device", default="auto",
                    help="Compute device: cpu, cuda, or auto (default: auto).")
    return ap


def main() -> None:
    """Command-line entry point: generate one period and write it to disk.

    Steps, in order:

    1. Parse the command-line flags and configure logging.
    2. Load the checkpoint onto the selected device.
    3. Call ``generate_period`` with the requested conditioning and
       sampling settings.
    4. Write the ``.chord`` file and, when ``--midi`` is given, a MIDI
       rendering of that file.
    5. Print a short summary of the generated bars to stdout.

    Exits with status 1 (message on stderr) when the checkpoint file does
    not exist.
    """
    args = _build_parser().parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if not args.checkpoint.exists():
        print(f"ERROR: checkpoint not found: {args.checkpoint}", file=sys.stderr)
        sys.exit(1)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    model = _load_model(args.checkpoint, device)
    target_key = f"{args.key}_{args.mode}"

    if args.prefix:
        prefix_chords = args.prefix.split()
    elif not args.no_tonic_anchor:
        # Default: anchor to tonic so generation stays in key.
        prefix_chords = [args.key + ("m" if args.mode == "minor" else "")]
    else:
        prefix_chords = None

    period = generate_period(
        model=model,
        mode=args.mode,
        time=args.time,
        subdivision=args.subdivision,
        style=args.style,
        function=args.function,
        key=target_key,
        prefix=prefix_chords,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        n_bars=args.bars,
        seed=args.seed,
        repetition_penalty=args.repetition_penalty,
    )

    # Give generated periods a readable title
    period = replace(period, title=f"Generated ({args.key} {args.mode}, {args.function})")

    # Ensure .chord extension.  The extension is appended, as the --output
    # help promises; with_suffix() would instead overwrite whatever follows
    # the last dot (e.g. "take1.5" -> "take1.chord").
    out_path = args.output
    if out_path.suffix != ".chord":
        out_path = out_path.with_name(out_path.name + ".chord")
    write_chord_file(period, out_path)
    print(f"[generate] written -> {out_path}")

    if args.midi:
        midi_path = args.midi if args.midi.suffix == ".mid" else args.midi.with_suffix(".mid")
        chord_file_to_midi(out_path, midi_path, tempo=args.tempo)
        print(f"[generate] MIDI -> {midi_path}")

    # Quick summary to stdout
    print()
    print(f" Key: {period.key}")
    print(f" Time: {period.time} subdivision={period.subdivision}")
    print(f" Style: {period.style} function={period.function}")
    print(f" Bars: {len(period.bars)}")
    print()
    for i, bar in enumerate(period.bars, 1):
        print(f" Bar {i:3d}: {' '.join(bar)}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()