diff --git a/README.md b/README.md index c59d575..3cfdcbb 100644 --- a/README.md +++ b/README.md @@ -254,39 +254,39 @@ python scripts/prepare_data.py \ ### 7.1 Предобучение -Обучение базовой модели на конвертированном корпусе McGill Billboard: - ```bash -python scripts/train.py \ - --data-dir data/processed/mcgill/ \ - --output checkpoints/pretrained.pt \ - --epochs 50 \ - --batch-size 32 \ - --lr 3e-4 \ - --warmup-steps 200 \ - --seed 42 +python scripts/pretrain.py ``` -По окончании обучения в директории `checkpoints/` появятся: сам чекпоинт, -лог обучения в формате CSV и график кривых train/val loss. +Обучает на корпусе McGill (`data/processed/mcgill/`). Выводит оценку времени +выполнения и по окончании сохраняет: + +| Файл | Описание | +| ----------------------------------- | ----------------------------- | +| `checkpoints/pretrained.pt` | лучший чекпоинт (по val loss) | +| `checkpoints/pretrained.log.csv` | метрики по эпохам | +| `checkpoints/pretrained_curves.png` | график кривых train/val loss | + +Если обучение было прервано, график и отчёт можно построить заново без +повторного обучения: + +```bash +python scripts/pretrain.py --skip-training +``` ### 7.2 Дообучение на собственном корпусе ```bash -python scripts/train.py \ - --init-from checkpoints/pretrained.pt \ - --data-dir data/processed/user/ \ - --output checkpoints/finetuned.pt \ - --epochs 15 \ - --batch-size 16 \ - --lr 1e-5 \ - --warmup-steps 20 \ - --seed 42 +python scripts/train.py ``` -Существенно более низкая скорость обучения (на два порядка меньше, чем на -предобучении) и небольшое число эпох предотвращают катастрофическое забывание -закономерностей, выученных на этапе предобучения. +Загружает `checkpoints/pretrained.pt` и дообучает на собственном корпусе +(`data/processed/user/`). Сохраняет `checkpoints/finetuned.pt` и аналогичный +набор артефактов (`finetuned.log.csv`, `finetuned_curves.png`). 
+ +Существенно более низкая скорость обучения (lr=1e-5 против 3e-4) и небольшое +число эпох (15) предотвращают катастрофическое забывание закономерностей, +выученных на этапе предобучения. ## 8. Оценка результатов diff --git a/scripts/run_pretrain.py b/scripts/pretrain.py similarity index 95% rename from scripts/run_pretrain.py rename to scripts/pretrain.py index 44b1777..1399420 100644 --- a/scripts/run_pretrain.py +++ b/scripts/pretrain.py @@ -1,12 +1,11 @@ -"""Run full pre-training on the McGill corpus, then plot loss curves and -print a short diagnostic report. +"""Pre-train ChordTransformer on the McGill Billboard corpus. Usage: # Full run (training + plot + report) - python scripts/run_pretrain.py + python scripts/pretrain.py # Skip training if a checkpoint already exists; only re-plot and report - python scripts/run_pretrain.py --skip-training + python scripts/pretrain.py --skip-training Outputs written: checkpoints/pretrained.pt best checkpoint @@ -38,7 +37,7 @@ from src.tokenizer import TOKEN_TO_ID # Paths # --------------------------------------------------------------------------- -DATA_DIR = Path("data/processed") +DATA_DIR = Path("data/processed/mcgill") CHECKPOINT = Path("checkpoints/pretrained.pt") LOG_CSV = Path("checkpoints/pretrained.log.csv") CURVES_PNG = Path("checkpoints/pretrained_curves.png") @@ -185,7 +184,7 @@ def main() -> None: if not args.skip_training: if not DATA_DIR.exists(): print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr) - print("Run prepare_data.py first.", file=sys.stderr) + print("Run: python scripts/prepare_data.py --input-dir data/raw_external/mcgill_chord --output-dir data/processed/mcgill", file=sys.stderr) sys.exit(1) import pathlib n_train = len(list((DATA_DIR / "train").glob("*.pt"))) diff --git a/scripts/train.py b/scripts/train.py index 1b9f76d..5e48fab 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,119 +1,213 @@ -"""CLI entry point for pre-training and fine-tuning ChordTransformer. 
+"""Fine-tune ChordTransformer on the personal (user) chord corpus. -Usage (pre-training): - python scripts/train.py \\ - --data-dir data/processed/pretrain \\ - --output checkpoints/pretrained \\ - --epochs 50 --batch-size 32 --lr 3e-4 +Requires a pre-trained checkpoint produced by scripts/pretrain.py. -Usage (fine-tuning): - python scripts/train.py \\ - --data-dir data/processed/finetune \\ - --init-from checkpoints/pretrained.pt \\ - --output checkpoints/finetuned \\ - --epochs 15 --lr 1e-5 +Usage: + # Full run (fine-tuning + plot + report) + python scripts/train.py -The script saves: - .pt best checkpoint (lowest val loss) - .log.csv per-epoch metrics + # Skip training; re-plot and report from existing CSV + python scripts/train.py --skip-training + +Outputs written: + checkpoints/finetuned.pt best checkpoint + checkpoints/finetuned.log.csv per-epoch metrics + checkpoints/finetuned_curves.png train/val loss plot """ from __future__ import annotations import argparse +import csv import logging +import math import sys from pathlib import Path +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import torch + sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) -from src.train import TrainConfig, train # noqa: E402 +from src.model import ChordTransformer +from src.train import TrainConfig, train +from src.tokenizer import TOKEN_TO_ID + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +DATA_DIR = Path("data/processed/user") +INIT_FROM = Path("checkpoints/pretrained.pt") +CHECKPOINT = Path("checkpoints/finetuned.pt") +LOG_CSV = Path("checkpoints/finetuned.log.csv") +CURVES_PNG = Path("checkpoints/finetuned_curves.png") + +# --------------------------------------------------------------------------- +# Training config +# --------------------------------------------------------------------------- + +TRAIN_CFG = 
def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Render train/val loss and validation perplexity curves to a PNG.

    Reads the ``epoch``, ``train_loss``, ``val_loss`` and ``val_ppl`` columns
    of *log_csv* and draws a two-panel figure: losses on the left, validation
    perplexity on the right.  A dashed vertical line marks the epoch with the
    lowest validation loss in both panels.

    If the log holds no data rows, a notice is printed and no file is written.

    Args:
        log_csv: Per-epoch metrics CSV with a header row.
        out_png: Destination image path; missing parent directories are created.

    Raises:
        KeyError: If a required column is absent from the CSV.
        ValueError: If a cell cannot be parsed as a number.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))

    # A header-only or zero-byte log has nothing to draw, and min() below
    # would raise ValueError on an empty list.  Return early so the caller
    # can still reach its own empty-log handling.
    if not rows:
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(r["epoch"]) for r in rows]
    train_losses = [float(r["train_loss"]) for r in rows]
    val_losses = [float(r["val_loss"]) for r in rows]
    val_ppls = [float(r["val_ppl"]) for r in rows]

    # First epoch that reaches the minimum validation loss.
    best_epoch = epochs[val_losses.index(min(val_losses))]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
    try:
        ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
        ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
        ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                    label=f"best epoch {best_epoch}")
        ax1.set_xlabel("epoch")
        ax1.set_ylabel("cross-entropy loss")
        ax1.set_title("Fine-tuning loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
        ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
        ax2.set_xlabel("epoch")
        ax2.set_ylabel("perplexity")
        ax2.set_title("Val perplexity")
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"[plot] saved → {out_png}")
def print_report(log_csv: Path, checkpoint: Path) -> None:
    """Print a summary of a fine-tuning run from its per-epoch log CSV.

    The report lists the number of epochs run, the epoch with the lowest
    validation loss, the first epoch whose validation loss is within 1 % of
    that minimum, the best validation loss and perplexity, the final train
    loss, and a per-epoch table with the best epoch marked.  When
    *checkpoint* exists, the model is rebuilt from its ``model_config`` entry
    and its unique parameter count is included.

    If the log holds no data rows, a notice is printed and nothing else.

    Args:
        log_csv: Per-epoch metrics CSV with ``epoch``, ``train_loss``,
            ``val_loss``, ``val_ppl`` and ``lr`` columns.
        checkpoint: Checkpoint path; only read when it exists.

    Raises:
        KeyError: If a required column is absent from the CSV.
        ValueError: If a cell cannot be parsed as a number.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))

    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    val_losses = [float(r["val_loss"]) for r in rows]
    best_loss = min(val_losses)
    best_row = rows[val_losses.index(best_loss)]

    # First epoch whose validation loss is within 1 % of the best one.
    conv_epoch = next(
        (
            int(r["epoch"])
            for r, loss in zip(rows, val_losses)
            if loss <= best_loss * 1.01
        ),
        int(best_row["epoch"]),
    )

    n_params = None
    if checkpoint.exists():
        # Only the config is needed, so always map to CPU: a checkpoint saved
        # on a GPU machine must still be reportable on a CPU-only one.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        # Count every underlying storage once.  Module.parameters() already
        # yields a shared Parameter object a single time, so subtracting the
        # embedding size on top of that would undercount properly tied
        # weights; keying on data_ptr() also covers storage-sharing ties.
        unique_sizes = {p.data_ptr(): p.numel() for p in model.parameters()}
        n_params = sum(unique_sizes.values())

    print()
    print("=" * 52)
    print(" FINE-TUNING REPORT")
    print("=" * 52)
    print(f" Total epochs run : {len(rows)}")
    print(f" Best epoch (val loss) : {best_row['epoch']}")
    print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
    print(f" Best val loss : {best_loss:.4f}")
    print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
    print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
    if n_params is not None:
        print(f" Unique parameters : {n_params:,}")
    print(f" Checkpoint : {checkpoint}")
    print(f" Log CSV : {log_csv}")
    print("=" * 52)
    print()

    print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
    print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
    for r in rows:
        marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
        print(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    print()
def main() -> None:
    """Fine-tune (unless ``--skip-training`` is given), then plot and report.

    In the training path the pre-trained checkpoint and the data directory
    must both exist; otherwise an error is written to stderr and the process
    exits with status 1.  A rough time estimate is printed before training
    starts.  With ``--skip-training`` only the existing log CSV is required.
    Both paths finish by writing the loss curves and printing the report.
    """

    def abort(*lines: str) -> None:
        # Emit every line on stderr, then stop with a failure status.
        for line in lines:
            print(line, file=sys.stderr)
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training; re-plot and report from existing CSV.",
    )
    opts = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if opts.skip_training:
        if not LOG_CSV.exists():
            abort(f"ERROR: log CSV not found: {LOG_CSV}")
        print(f"[skip-training] using existing log: {LOG_CSV}")
    else:
        if not INIT_FROM.exists():
            abort(
                f"ERROR: pre-trained checkpoint not found: {INIT_FROM}",
                "Run python scripts/pretrain.py first.",
            )
        if not DATA_DIR.exists():
            abort(
                f"ERROR: data directory not found: {DATA_DIR}",
                "Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user",
            )

        n_train = len(list((DATA_DIR / "train").glob("*.pt")))
        # Ceiling division: a partial final batch still counts as a batch.
        whole, leftover = divmod(n_train, TRAIN_CFG.batch_size)
        n_batches = whole + (1 if leftover else 0)
        # Flat 1.5 s per batch, used for the estimate on either device.
        est_epoch_s = n_batches * 1.5
        device_label = "GPU" if torch.cuda.is_available() else "CPU"
        print(
            f"[train] {n_train} train files, {n_batches} batches/epoch\n"
            f"[train] estimated time on {device_label}: "
            f"~{est_epoch_s/60:.0f} min/epoch, "
            f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
            f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        train(TRAIN_CFG)

    plot_curves(LOG_CSV, CURVES_PNG)
    print_report(LOG_CSV, CHECKPOINT)


if __name__ == "__main__":
    main()