From 2a3eb1783a0abd86eb72edf8ddd56a87fbf570f6 Mon Sep 17 00:00:00 2001 From: Masahiko AMANO Date: Thu, 21 May 2026 10:15:48 +0300 Subject: [PATCH] fix: fine-tune config and generator improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit scripts/train.py: fix max_seq_len 256→320 (must match pretrained checkpoint); increase epochs 15→50 and patience 5→10 to give the small corpus enough gradient steps; reduce warmup 20→10 (was 22% of total steps). scripts/generate.py: default to prepending the tonic chord when --prefix is not given; add --no-tonic-anchor to opt out. Co-Authored-By: Claude Sonnet 4.6 --- scripts/generate.py | 11 ++++++++++- scripts/train.py | 10 +++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/scripts/generate.py b/scripts/generate.py index adee989..22e39aa 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -80,6 +80,8 @@ def main() -> None: ap.add_argument("--prefix", default=None, help='Space-separated chord symbols in the requested key, ' 'e.g. "Cmaj7 Am7". Used as generation context.') + ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor", + help="Do not prepend the tonic chord when --prefix is not given.") ap.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature (default: 1.0).") ap.add_argument("--top-p", type=float, default=0.9, dest="top_p", @@ -112,7 +114,14 @@ def main() -> None: model = _load_model(args.checkpoint, device) target_key = f"{args.key}_{args.mode}" - prefix_chords = args.prefix.split() if args.prefix else None + + if args.prefix: + prefix_chords = args.prefix.split() + elif not args.no_tonic_anchor: + # Default: anchor to tonic so generation stays in key. + prefix_chords = [args.key + ("m" if args.mode == "minor" else "")] + else: + prefix_chords = None period = generate_period( model=model, diff --git a/scripts/train.py b/scripts/train.py index ed729a0..e267927 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -54,13 +54,17 @@ TRAIN_CFG = TrainConfig( data_dir=DATA_DIR, output=CHECKPOINT, init_from=INIT_FROM, - epochs=15, + # Small corpus (~45 train files) → ~6 batches/epoch. + # 50 epochs × 6 = ~300 gradient steps; patience=10 gives a 60-step window. + epochs=50, batch_size=8, lr=1e-5, - warmup_steps=20, + warmup_steps=10, + patience=10, seed=42, device="auto", - max_seq_len=256, + # Must match pretrained checkpoint (max_seq_len=320). + max_seq_len=320, )