"""CLI entry point for pre-training and fine-tuning ChordTransformer.

Usage (pre-training):

    python scripts/train.py \\
        --data-dir data/processed/pretrain \\
        --output checkpoints/pretrained \\
        --epochs 50 --batch-size 32 --lr 3e-4

Usage (fine-tuning):

    python scripts/train.py \\
        --data-dir data/processed/finetune \\
        --init-from checkpoints/pretrained.pt \\
        --output checkpoints/finetuned \\
        --epochs 15 --lr 1e-5

The script saves:

    <output>.pt       best checkpoint (lowest val loss)
    <output>.log.csv  per-epoch metrics
"""

from __future__ import annotations

import argparse
import logging
import sys
from collections.abc import Sequence
from pathlib import Path

# Put the repository root (the parent of this file's directory) first on
# sys.path so that the ``src`` package is importable when this file is run
# directly as ``python scripts/train.py``.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.train import TrainConfig, train  # noqa: E402


def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``; passing an explicit
            list lets tests and other scripts drive the CLI programmatically.

    Returns:
        The parsed argument namespace.

    Exits via ``parser.error`` (status 2, with a usage message) on invalid
    arguments, including a ``--data-dir`` that is not an existing directory
    or an ``--init-from`` path that does not exist.
    """
    p = argparse.ArgumentParser(
        description="Train or fine-tune ChordTransformer on tokenized chord data.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # I/O
    io = p.add_argument_group("I/O")
    io.add_argument(
        "--data-dir",
        required=True,
        type=Path,
        help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
    )
    io.add_argument(
        "--output",
        required=True,
        type=Path,
        help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
    )
    io.add_argument(
        "--init-from",
        type=Path,
        default=None,
        help="Checkpoint to load weights from before training (fine-tuning mode).",
    )

    # Training
    tr = p.add_argument_group("Training")
    tr.add_argument("--epochs", type=int, default=30)
    tr.add_argument("--batch-size", type=int, default=16)
    tr.add_argument("--lr", type=float, default=3e-4)
    tr.add_argument("--warmup-steps", type=int, default=200)
    tr.add_argument("--weight-decay", type=float, default=0.1)
    tr.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Early-stopping patience (epochs without val improvement).",
    )
    tr.add_argument("--seed", type=int, default=42)
    tr.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Compute device. 'auto' selects cuda when available.",
    )

    # Architecture (ignored when --init-from is given)
    arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
    arch.add_argument("--d-model", type=int, default=192)
    arch.add_argument("--n-layers", type=int, default=3)
    arch.add_argument("--n-heads", type=int, default=6)
    arch.add_argument("--d-ff", type=int, default=768)
    arch.add_argument("--dropout", type=float, default=0.1)
    arch.add_argument("--max-seq-len", type=int, default=512)

    # Logging
    p.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )

    args = p.parse_args(argv)

    # Validate input paths up front so a typo is reported together with the
    # usage message before any training setup runs.
    if not args.data_dir.is_dir():
        p.error(f"--data-dir {args.data_dir} is not an existing directory")
    if args.init_from is not None and not args.init_from.exists():
        p.error(f"--init-from {args.init_from} does not exist")

    return args


def main(argv: Sequence[str] | None = None) -> None:
    """Parse CLI arguments, configure logging, run training, report the result.

    Args:
        argv: Forwarded to :func:`_parse_args`; ``None`` (the default) means
            the arguments are read from ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    logging.basicConfig(
        # basicConfig accepts level names such as "INFO" as well as ints.
        level=args.log_level,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    cfg = TrainConfig(
        data_dir=args.data_dir,
        output=args.output,
        init_from=args.init_from,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        seed=args.seed,
        device=args.device,
        patience=args.patience,
        max_seq_len=args.max_seq_len,
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        d_ff=args.d_ff,
        dropout=args.dropout,
    )

    checkpoint = train(cfg)
    print(f"best checkpoint: {checkpoint}")


if __name__ == "__main__":
    main()