diff --git a/scripts/generate.py b/scripts/generate.py index adee989..22e39aa 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -80,6 +80,8 @@ def main() -> None: ap.add_argument("--prefix", default=None, help='Space-separated chord symbols in the requested key, ' 'e.g. "Cmaj7 Am7". Used as generation context.') + ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor", + help="Do not prepend the tonic chord when --prefix is not given.") ap.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature (default: 1.0).") ap.add_argument("--top-p", type=float, default=0.9, dest="top_p", @@ -112,7 +114,14 @@ def main() -> None: model = _load_model(args.checkpoint, device) target_key = f"{args.key}_{args.mode}" - prefix_chords = args.prefix.split() if args.prefix else None + + if args.prefix: + prefix_chords = args.prefix.split() + elif not args.no_tonic_anchor: + # Default: anchor to tonic so generation stays in key. + prefix_chords = [args.key + ("m" if args.mode == "minor" else "")] + else: + prefix_chords = None period = generate_period( model=model, diff --git a/scripts/train.py b/scripts/train.py index ed729a0..e267927 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -54,13 +54,17 @@ TRAIN_CFG = TrainConfig( data_dir=DATA_DIR, output=CHECKPOINT, init_from=INIT_FROM, - epochs=15, + # Small corpus (~45 train files) → ~6 batches/epoch. + # 50 epochs × 6 = ~300 gradient steps; patience=10 gives a 60-step window. + epochs=50, batch_size=8, lr=1e-5, - warmup_steps=20, + warmup_steps=10, + patience=10, seed=42, device="auto", - max_seq_len=256, + # Must match pretrained checkpoint (max_seq_len=320). + max_seq_len=320, )