fix: fine-tune config and generator improvements
scripts/train.py: fix max_seq_len 256→320 (must match pretrained checkpoint); increase epochs 15→50 and patience 5→10 to give the small corpus enough gradient steps; reduce warmup 20→10 (was 22% of total steps). scripts/generate.py: default to prepending the tonic chord when --prefix is not given; add --no-tonic-anchor to opt out. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+10
-1
@@ -80,6 +80,8 @@ def main() -> None:
|
|||||||
ap.add_argument("--prefix", default=None,
|
ap.add_argument("--prefix", default=None,
|
||||||
help='Space-separated chord symbols in the requested key, '
|
help='Space-separated chord symbols in the requested key, '
|
||||||
'e.g. "Cmaj7 Am7". Used as generation context.')
|
'e.g. "Cmaj7 Am7". Used as generation context.')
|
||||||
|
ap.add_argument("--no-tonic-anchor", action="store_true", dest="no_tonic_anchor",
|
||||||
|
help="Do not prepend the tonic chord when --prefix is not given.")
|
||||||
ap.add_argument("--temperature", type=float, default=1.0,
|
ap.add_argument("--temperature", type=float, default=1.0,
|
||||||
help="Sampling temperature (default: 1.0).")
|
help="Sampling temperature (default: 1.0).")
|
||||||
ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
|
ap.add_argument("--top-p", type=float, default=0.9, dest="top_p",
|
||||||
@@ -112,7 +114,14 @@ def main() -> None:
|
|||||||
model = _load_model(args.checkpoint, device)
|
model = _load_model(args.checkpoint, device)
|
||||||
|
|
||||||
target_key = f"{args.key}_{args.mode}"
|
target_key = f"{args.key}_{args.mode}"
|
||||||
prefix_chords = args.prefix.split() if args.prefix else None
|
|
||||||
|
if args.prefix:
|
||||||
|
prefix_chords = args.prefix.split()
|
||||||
|
elif not args.no_tonic_anchor:
|
||||||
|
# Default: anchor to tonic so generation stays in key.
|
||||||
|
prefix_chords = [args.key + ("m" if args.mode == "minor" else "")]
|
||||||
|
else:
|
||||||
|
prefix_chords = None
|
||||||
|
|
||||||
period = generate_period(
|
period = generate_period(
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
+7
-3
@@ -54,13 +54,17 @@ TRAIN_CFG = TrainConfig(
|
|||||||
data_dir=DATA_DIR,
|
data_dir=DATA_DIR,
|
||||||
output=CHECKPOINT,
|
output=CHECKPOINT,
|
||||||
init_from=INIT_FROM,
|
init_from=INIT_FROM,
|
||||||
epochs=15,
|
# Small corpus (~45 train files) → ~6 batches/epoch.
|
||||||
|
# 50 epochs × 6 = ~300 gradient steps; patience=10 gives a 60-step window.
|
||||||
|
epochs=50,
|
||||||
batch_size=8,
|
batch_size=8,
|
||||||
lr=1e-5,
|
lr=1e-5,
|
||||||
warmup_steps=20,
|
warmup_steps=10,
|
||||||
|
patience=10,
|
||||||
seed=42,
|
seed=42,
|
||||||
device="auto",
|
device="auto",
|
||||||
max_seq_len=256,
|
# Must match pretrained checkpoint (max_seq_len=320).
|
||||||
|
max_seq_len=320,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user