733e1fde1f
AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch CSV logging, best-val-loss checkpointing, early stopping (patience=5). Same script handles both pre-training and fine-tuning via --init-from. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
"""CLI entry point for pre-training and fine-tuning ChordTransformer.
|
|
|
|
Usage (pre-training):
|
|
python scripts/train.py \\
|
|
--data-dir data/processed/pretrain \\
|
|
--output checkpoints/pretrained \\
|
|
--epochs 50 --batch-size 32 --lr 3e-4
|
|
|
|
Usage (fine-tuning):
|
|
python scripts/train.py \\
|
|
--data-dir data/processed/finetune \\
|
|
--init-from checkpoints/pretrained.pt \\
|
|
--output checkpoints/finetuned \\
|
|
--epochs 15 --lr 1e-5
|
|
|
|
The script saves:
|
|
<output>.pt best checkpoint (lowest val loss)
|
|
<output>.log.csv per-epoch metrics
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Put the directory one level above this script's folder (presumably the
# repository root, given the `python scripts/train.py` usage) at the very
# front of the import path, so `src.train` resolves without installing the
# project as a package. Slice assignment at [:0] is a front insertion.
sys.path[:0] = [str(Path(__file__).resolve().parent.parent)]
|
|
|
|
from src.train import TrainConfig, train # noqa: E402
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
|
|
p = argparse.ArgumentParser(
|
|
description="Train or fine-tune ChordTransformer on tokenized chord data.",
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
)
|
|
|
|
# I/O
|
|
io = p.add_argument_group("I/O")
|
|
io.add_argument(
|
|
"--data-dir", required=True, type=Path,
|
|
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
|
|
)
|
|
io.add_argument(
|
|
"--output", required=True, type=Path,
|
|
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
|
|
)
|
|
io.add_argument(
|
|
"--init-from", type=Path, default=None,
|
|
help="Checkpoint to load weights from before training (fine-tuning mode).",
|
|
)
|
|
|
|
# Training
|
|
tr = p.add_argument_group("Training")
|
|
tr.add_argument("--epochs", type=int, default=30)
|
|
tr.add_argument("--batch-size", type=int, default=16)
|
|
tr.add_argument("--lr", type=float, default=3e-4)
|
|
tr.add_argument("--warmup-steps", type=int, default=200)
|
|
tr.add_argument("--weight-decay", type=float, default=0.1)
|
|
tr.add_argument("--patience", type=int, default=5,
|
|
help="Early-stopping patience (epochs without val improvement).")
|
|
tr.add_argument("--seed", type=int, default=42)
|
|
tr.add_argument(
|
|
"--device", default="auto", choices=["auto", "cpu", "cuda"],
|
|
help="Compute device. 'auto' selects cuda when available.",
|
|
)
|
|
|
|
# Architecture (ignored when --init-from is given)
|
|
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
|
|
arch.add_argument("--d-model", type=int, default=192)
|
|
arch.add_argument("--n-layers", type=int, default=3)
|
|
arch.add_argument("--n-heads", type=int, default=6)
|
|
arch.add_argument("--d-ff", type=int, default=768)
|
|
arch.add_argument("--dropout", type=float, default=0.1)
|
|
arch.add_argument("--max-seq-len", type=int, default=512)
|
|
|
|
# Logging
|
|
p.add_argument(
|
|
"--log-level", default="INFO",
|
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
)
|
|
|
|
return p.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Run the training CLI end to end.

    Parses the command line, configures root logging from ``--log-level``,
    translates the parsed flags into a ``TrainConfig``, hands it to ``train``
    and prints whatever ``train`` returns as the best checkpoint (presumably
    the saved ``.pt`` path; confirm against ``src.train``).
    """
    opts = _parse_args()

    logging.basicConfig(
        level=opts.log_level,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    # Every flag except --log-level is forwarded to TrainConfig under the
    # same name. Architecture fields are forwarded even in fine-tuning mode;
    # the CLI help states they are ignored when --init-from is set.
    config = TrainConfig(
        # I/O
        data_dir=opts.data_dir,
        output=opts.output,
        init_from=opts.init_from,
        # Optimisation and run control
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        warmup_steps=opts.warmup_steps,
        weight_decay=opts.weight_decay,
        seed=opts.seed,
        device=opts.device,
        patience=opts.patience,
        # Architecture
        max_seq_len=opts.max_seq_len,
        d_model=opts.d_model,
        n_layers=opts.n_layers,
        n_heads=opts.n_heads,
        d_ff=opts.d_ff,
        dropout=opts.dropout,
    )

    best_checkpoint = train(config)
    print(f"best checkpoint: {best_checkpoint}")
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script
# (e.g. `python scripts/train.py ...`), not when it is imported.
if __name__ == "__main__":
    main()
|