feat: implement training loop and CLI (src/train.py, scripts/train.py)

AdamW + cosine-with-warmup schedule, PAD-ignoring cross-entropy, per-epoch
CSV logging, best-val-loss checkpointing, early stopping (patience=5).
Same script handles both pre-training and fine-tuning via --init-from.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 11:15:39 +03:00
parent 10229be042
commit 733e1fde1f
2 changed files with 460 additions and 0 deletions
+120
View File
@@ -0,0 +1,120 @@
"""CLI entry point for pre-training and fine-tuning ChordTransformer.
Usage (pre-training):
python scripts/train.py \\
--data-dir data/processed/pretrain \\
--output checkpoints/pretrained \\
--epochs 50 --batch-size 32 --lr 3e-4
Usage (fine-tuning):
python scripts/train.py \\
--data-dir data/processed/finetune \\
--init-from checkpoints/pretrained.pt \\
--output checkpoints/finetuned \\
--epochs 15 --lr 1e-5
The script saves:
<output>.pt best checkpoint (lowest val loss)
<output>.log.csv per-epoch metrics
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.train import TrainConfig, train # noqa: E402
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Train or fine-tune ChordTransformer on tokenized chord data.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# I/O
io = p.add_argument_group("I/O")
io.add_argument(
"--data-dir", required=True, type=Path,
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
)
io.add_argument(
"--output", required=True, type=Path,
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
)
io.add_argument(
"--init-from", type=Path, default=None,
help="Checkpoint to load weights from before training (fine-tuning mode).",
)
# Training
tr = p.add_argument_group("Training")
tr.add_argument("--epochs", type=int, default=30)
tr.add_argument("--batch-size", type=int, default=16)
tr.add_argument("--lr", type=float, default=3e-4)
tr.add_argument("--warmup-steps", type=int, default=200)
tr.add_argument("--weight-decay", type=float, default=0.1)
tr.add_argument("--patience", type=int, default=5,
help="Early-stopping patience (epochs without val improvement).")
tr.add_argument("--seed", type=int, default=42)
tr.add_argument(
"--device", default="auto", choices=["auto", "cpu", "cuda"],
help="Compute device. 'auto' selects cuda when available.",
)
# Architecture (ignored when --init-from is given)
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
arch.add_argument("--d-model", type=int, default=192)
arch.add_argument("--n-layers", type=int, default=3)
arch.add_argument("--n-heads", type=int, default=6)
arch.add_argument("--d-ff", type=int, default=768)
arch.add_argument("--dropout", type=float, default=0.1)
arch.add_argument("--max-seq-len", type=int, default=512)
# Logging
p.add_argument(
"--log-level", default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
)
return p.parse_args()
def main() -> None:
args = _parse_args()
logging.basicConfig(
level=args.log_level,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
)
cfg = TrainConfig(
data_dir=args.data_dir,
output=args.output,
init_from=args.init_from,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
warmup_steps=args.warmup_steps,
weight_decay=args.weight_decay,
seed=args.seed,
device=args.device,
patience=args.patience,
max_seq_len=args.max_seq_len,
d_model=args.d_model,
n_layers=args.n_layers,
n_heads=args.n_heads,
d_ff=args.d_ff,
dropout=args.dropout,
)
checkpoint = train(cfg)
print(f"best checkpoint: {checkpoint}")
if __name__ == "__main__":
main()