refactor: split training scripts into pretrain.py and train.py

- scripts/run_pretrain.py -> scripts/pretrain.py: pre-trains on McGill
  corpus (data/processed/mcgill/), saves checkpoints/pretrained.pt.
- scripts/train.py: rewritten as high-level fine-tune wrapper; loads
  pretrained.pt, trains on data/processed/user/, saves finetuned.pt.
  Both scripts include timing estimate, loss-curve plot, per-epoch report,
  and --skip-training flag.
- README: updated section 7 to reflect new script names and separate
  data directories.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 12:35:23 +03:00
parent 65c3f6bf7c
commit 632407ebef
3 changed files with 209 additions and 116 deletions
+24 -24
View File
@@ -254,39 +254,39 @@ python scripts/prepare_data.py \
### 7.1 Предобучение ### 7.1 Предобучение
Обучение базовой модели на конвертированном корпусе McGill Billboard:
```bash ```bash
python scripts/train.py \ python scripts/pretrain.py
--data-dir data/processed/mcgill/ \
--output checkpoints/pretrained.pt \
--epochs 50 \
--batch-size 32 \
--lr 3e-4 \
--warmup-steps 200 \
--seed 42
``` ```
По окончании обучения в директории `checkpoints/` появятся: сам чекпоинт, Обучает на корпусе McGill (`data/processed/mcgill/`). Выводит оценку времени
лог обучения в формате CSV и график кривых train/val loss. выполнения и по окончании сохраняет:
| Файл | Описание |
| ----------------------------------- | ----------------------------- |
| `checkpoints/pretrained.pt` | лучший чекпоинт (по val loss) |
| `checkpoints/pretrained.log.csv` | метрики по эпохам |
| `checkpoints/pretrained_curves.png` | график кривых train/val loss |
Если обучение было прервано, повторно построить график и отчёт без
повторного обучения:
```bash
python scripts/pretrain.py --skip-training
```
### 7.2 Дообучение на собственном корпусе ### 7.2 Дообучение на собственном корпусе
```bash ```bash
python scripts/train.py \ python scripts/train.py
--init-from checkpoints/pretrained.pt \
--data-dir data/processed/user/ \
--output checkpoints/finetuned.pt \
--epochs 15 \
--batch-size 16 \
--lr 1e-5 \
--warmup-steps 20 \
--seed 42
``` ```
Существенно более низкая скорость обучения (на два порядка меньше, чем на Загружает `checkpoints/pretrained.pt` и дообучает на собственном корпусе
предобучении) и небольшое число эпох предотвращают катастрофическое забывание (`data/processed/user/`). Сохраняет `checkpoints/finetuned.pt` и аналогичный
закономерностей, выученных на этапе предобучения. набор артефактов (`finetuned.log.csv`, `finetuned_curves.png`).
Существенно более низкая скорость обучения (lr=1e-5 против 3e-4) и небольшое
число эпох (15) предотвращают катастрофическое забывание закономерностей,
выученных на этапе предобучения.
## 8. Оценка результатов ## 8. Оценка результатов
@@ -1,12 +1,11 @@
"""Run full pre-training on the McGill corpus, then plot loss curves and """Pre-train ChordTransformer on the McGill Billboard corpus.
print a short diagnostic report.
Usage: Usage:
# Full run (training + plot + report) # Full run (training + plot + report)
python scripts/run_pretrain.py python scripts/pretrain.py
# Skip training if a checkpoint already exists; only re-plot and report # Skip training if a checkpoint already exists; only re-plot and report
python scripts/run_pretrain.py --skip-training python scripts/pretrain.py --skip-training
Outputs written: Outputs written:
checkpoints/pretrained.pt best checkpoint checkpoints/pretrained.pt best checkpoint
@@ -38,7 +37,7 @@ from src.tokenizer import TOKEN_TO_ID
# Paths # Paths
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
DATA_DIR = Path("data/processed") DATA_DIR = Path("data/processed/mcgill")
CHECKPOINT = Path("checkpoints/pretrained.pt") CHECKPOINT = Path("checkpoints/pretrained.pt")
LOG_CSV = Path("checkpoints/pretrained.log.csv") LOG_CSV = Path("checkpoints/pretrained.log.csv")
CURVES_PNG = Path("checkpoints/pretrained_curves.png") CURVES_PNG = Path("checkpoints/pretrained_curves.png")
@@ -185,7 +184,7 @@ def main() -> None:
if not args.skip_training: if not args.skip_training:
if not DATA_DIR.exists(): if not DATA_DIR.exists():
print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr) print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
print("Run prepare_data.py first.", file=sys.stderr) print("Run: python scripts/prepare_data.py --input-dir data/raw_external/mcgill_chord --output-dir data/processed/mcgill", file=sys.stderr)
sys.exit(1) sys.exit(1)
import pathlib import pathlib
n_train = len(list((DATA_DIR / "train").glob("*.pt"))) n_train = len(list((DATA_DIR / "train").glob("*.pt")))
+179 -85
View File
@@ -1,119 +1,213 @@
"""CLI entry point for pre-training and fine-tuning ChordTransformer. """Fine-tune ChordTransformer on the personal (user) chord corpus.
Usage (pre-training): Requires a pre-trained checkpoint produced by scripts/pretrain.py.
python scripts/train.py \\
--data-dir data/processed/pretrain \\
--output checkpoints/pretrained \\
--epochs 50 --batch-size 32 --lr 3e-4
Usage (fine-tuning): Usage:
python scripts/train.py \\ # Full run (fine-tuning + plot + report)
--data-dir data/processed/finetune \\ python scripts/train.py
--init-from checkpoints/pretrained.pt \\
--output checkpoints/finetuned \\
--epochs 15 --lr 1e-5
The script saves: # Skip training; re-plot and report from existing CSV
<output>.pt best checkpoint (lowest val loss) python scripts/train.py --skip-training
<output>.log.csv per-epoch metrics
Outputs written:
checkpoints/finetuned.pt best checkpoint
checkpoints/finetuned.log.csv per-epoch metrics
checkpoints/finetuned_curves.png train/val loss plot
""" """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import csv
import logging import logging
import math
import sys import sys
from pathlib import Path from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.train import TrainConfig, train # noqa: E402 from src.model import ChordTransformer
from src.train import TrainConfig, train
from src.tokenizer import TOKEN_TO_ID
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
DATA_DIR = Path("data/processed/user")
INIT_FROM = Path("checkpoints/pretrained.pt")
CHECKPOINT = Path("checkpoints/finetuned.pt")
LOG_CSV = Path("checkpoints/finetuned.log.csv")
CURVES_PNG = Path("checkpoints/finetuned_curves.png")
# ---------------------------------------------------------------------------
# Training config
# ---------------------------------------------------------------------------
TRAIN_CFG = TrainConfig(
data_dir=DATA_DIR,
output=CHECKPOINT,
init_from=INIT_FROM,
epochs=15,
batch_size=8,
lr=1e-5,
warmup_steps=20,
seed=42,
device="auto",
max_seq_len=256,
)
def _parse_args() -> argparse.Namespace: # ---------------------------------------------------------------------------
p = argparse.ArgumentParser( # Plotting
description="Train or fine-tune ChordTransformer on tokenized chord data.", # ---------------------------------------------------------------------------
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
def plot_curves(log_csv: Path, out_png: Path) -> None:
epochs, train_losses, val_losses, val_ppls = [], [], [], []
with open(log_csv, newline="") as fh:
for row in csv.DictReader(fh):
epochs.append(int(row["epoch"]))
train_losses.append(float(row["train_loss"]))
val_losses.append(float(row["val_loss"]))
val_ppls.append(float(row["val_ppl"]))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
best_epoch = epochs[val_losses.index(min(val_losses))]
ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
label=f"best epoch {best_epoch}")
ax1.set_xlabel("epoch")
ax1.set_ylabel("cross-entropy loss")
ax1.set_title("Fine-tuning loss")
ax1.legend()
ax1.grid(True, alpha=0.3)
ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
ax2.set_xlabel("epoch")
ax2.set_ylabel("perplexity")
ax2.set_title("Val perplexity")
ax2.grid(True, alpha=0.3)
fig.tight_layout()
out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=150)
plt.close(fig)
print(f"[plot] saved → {out_png}")
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def print_report(log_csv: Path, checkpoint: Path) -> None:
rows = []
with open(log_csv, newline="") as fh:
rows = list(csv.DictReader(fh))
if not rows:
print("[report] log CSV is empty — nothing to report")
return
val_losses = [float(r["val_loss"]) for r in rows]
best_idx = val_losses.index(min(val_losses))
best_row = rows[best_idx]
best_loss = float(best_row["val_loss"])
conv_epoch = next(
(int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01),
int(best_row["epoch"]),
) )
# I/O n_params = None
io = p.add_argument_group("I/O") if checkpoint.exists():
io.add_argument( ckpt = torch.load(checkpoint, weights_only=True)
"--data-dir", required=True, type=Path, model = ChordTransformer(**ckpt["model_config"])
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).", tied = model.token_emb.weight.numel()
) n_params = sum(p.numel() for p in model.parameters()) - tied
io.add_argument(
"--output", required=True, type=Path,
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
)
io.add_argument(
"--init-from", type=Path, default=None,
help="Checkpoint to load weights from before training (fine-tuning mode).",
)
# Training print()
tr = p.add_argument_group("Training") print("=" * 52)
tr.add_argument("--epochs", type=int, default=30) print(" FINE-TUNING REPORT")
tr.add_argument("--batch-size", type=int, default=16) print("=" * 52)
tr.add_argument("--lr", type=float, default=3e-4) print(f" Total epochs run : {len(rows)}")
tr.add_argument("--warmup-steps", type=int, default=200) print(f" Best epoch (val loss) : {best_row['epoch']}")
tr.add_argument("--weight-decay", type=float, default=0.1) print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
tr.add_argument("--patience", type=int, default=5, print(f" Best val loss : {best_loss:.4f}")
help="Early-stopping patience (epochs without val improvement).") print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
tr.add_argument("--seed", type=int, default=42) print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
tr.add_argument( if n_params is not None:
"--device", default="auto", choices=["auto", "cpu", "cuda"], print(f" Unique parameters : {n_params:,}")
help="Compute device. 'auto' selects cuda when available.", print(f" Checkpoint : {checkpoint}")
print(f" Log CSV : {log_csv}")
print("=" * 52)
print()
print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
for r in rows:
marker = "" if int(r["epoch"]) == int(best_row["epoch"]) else ""
print(
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
f" {float(r['lr']):>10.2e}{marker}"
) )
print()
# Architecture (ignored when --init-from is given)
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
arch.add_argument("--d-model", type=int, default=192)
arch.add_argument("--n-layers", type=int, default=3)
arch.add_argument("--n-heads", type=int, default=6)
arch.add_argument("--d-ff", type=int, default=768)
arch.add_argument("--dropout", type=float, default=0.1)
arch.add_argument("--max-seq-len", type=int, default=512)
# Logging
p.add_argument(
"--log-level", default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
)
return p.parse_args()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None: def main() -> None:
args = _parse_args() ap = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
ap.add_argument("--skip-training", action="store_true",
help="Skip training; re-plot and report from existing CSV.")
args = ap.parse_args()
logging.basicConfig( logging.basicConfig(
level=args.log_level, level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s", format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S", datefmt="%H:%M:%S",
) )
cfg = TrainConfig( if not args.skip_training:
data_dir=args.data_dir, if not INIT_FROM.exists():
output=args.output, print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr)
init_from=args.init_from, print("Run python scripts/pretrain.py first.", file=sys.stderr)
epochs=args.epochs, sys.exit(1)
batch_size=args.batch_size, if not DATA_DIR.exists():
lr=args.lr, print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
warmup_steps=args.warmup_steps, print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr)
weight_decay=args.weight_decay, sys.exit(1)
seed=args.seed, n_train = len(list((DATA_DIR / "train").glob("*.pt")))
device=args.device, n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size
patience=args.patience, est_epoch_s = n_batches * 1.5
max_seq_len=args.max_seq_len, device_label = "GPU" if torch.cuda.is_available() else "CPU"
d_model=args.d_model, print(
n_layers=args.n_layers, f"[train] {n_train} train files, {n_batches} batches/epoch\n"
n_heads=args.n_heads, f"[train] estimated time on {device_label}: "
d_ff=args.d_ff, f"~{est_epoch_s/60:.0f} min/epoch, "
dropout=args.dropout, f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
) )
train(TRAIN_CFG)
else:
if not LOG_CSV.exists():
print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
sys.exit(1)
print(f"[skip-training] using existing log: {LOG_CSV}")
checkpoint = train(cfg) plot_curves(LOG_CSV, CURVES_PNG)
print(f"best checkpoint: {checkpoint}") print_report(LOG_CSV, CHECKPOINT)
if __name__ == "__main__": if __name__ == "__main__":