refactor: split training scripts into pretrain.py and train.py
- scripts/run_pretrain.py -> scripts/pretrain.py: pre-trains on McGill corpus (data/processed/mcgill/), saves checkpoints/pretrained.pt. - scripts/train.py: rewritten as high-level fine-tune wrapper; loads pretrained.pt, trains on data/processed/user/, saves finetuned.pt. Both scripts include timing estimate, loss-curve plot, per-epoch report, and --skip-training flag. - README: updated section 7 to reflect new script names and separate data directories. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -254,39 +254,39 @@ python scripts/prepare_data.py \
|
|||||||
|
|
||||||
### 7.1 Предобучение
|
### 7.1 Предобучение
|
||||||
|
|
||||||
Обучение базовой модели на конвертированном корпусе McGill Billboard:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/train.py \
|
python scripts/pretrain.py
|
||||||
--data-dir data/processed/mcgill/ \
|
|
||||||
--output checkpoints/pretrained.pt \
|
|
||||||
--epochs 50 \
|
|
||||||
--batch-size 32 \
|
|
||||||
--lr 3e-4 \
|
|
||||||
--warmup-steps 200 \
|
|
||||||
--seed 42
|
|
||||||
```
|
```
|
||||||
|
|
||||||
По окончании обучения в директории `checkpoints/` появятся: сам чекпоинт,
|
Обучает на корпусе McGill (`data/processed/mcgill/`). Выводит оценку времени
|
||||||
лог обучения в формате CSV и график кривых train/val loss.
|
выполнения и по окончании сохраняет:
|
||||||
|
|
||||||
|
| Файл | Описание |
|
||||||
|
| ----------------------------------- | ----------------------------- |
|
||||||
|
| `checkpoints/pretrained.pt` | лучший чекпоинт (по val loss) |
|
||||||
|
| `checkpoints/pretrained.log.csv` | метрики по эпохам |
|
||||||
|
| `checkpoints/pretrained_curves.png` | график кривых train/val loss |
|
||||||
|
|
||||||
|
Если обучение было прервано, повторно построить график и отчёт без
|
||||||
|
повторного обучения:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/pretrain.py --skip-training
|
||||||
|
```
|
||||||
|
|
||||||
### 7.2 Дообучение на собственном корпусе
|
### 7.2 Дообучение на собственном корпусе
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/train.py \
|
python scripts/train.py
|
||||||
--init-from checkpoints/pretrained.pt \
|
|
||||||
--data-dir data/processed/user/ \
|
|
||||||
--output checkpoints/finetuned.pt \
|
|
||||||
--epochs 15 \
|
|
||||||
--batch-size 16 \
|
|
||||||
--lr 1e-5 \
|
|
||||||
--warmup-steps 20 \
|
|
||||||
--seed 42
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Существенно более низкая скорость обучения (на два порядка меньше, чем на
|
Загружает `checkpoints/pretrained.pt` и дообучает на собственном корпусе
|
||||||
предобучении) и небольшое число эпох предотвращают катастрофическое забывание
|
(`data/processed/user/`). Сохраняет `checkpoints/finetuned.pt` и аналогичный
|
||||||
закономерностей, выученных на этапе предобучения.
|
набор артефактов (`finetuned.log.csv`, `finetuned_curves.png`).
|
||||||
|
|
||||||
|
Существенно более низкая скорость обучения (lr=1e-5 против 3e-4) и небольшое
|
||||||
|
число эпох (15) предотвращают катастрофическое забывание закономерностей,
|
||||||
|
выученных на этапе предобучения.
|
||||||
|
|
||||||
## 8. Оценка результатов
|
## 8. Оценка результатов
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
"""Run full pre-training on the McGill corpus, then plot loss curves and
|
"""Pre-train ChordTransformer on the McGill Billboard corpus.
|
||||||
print a short diagnostic report.
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# Full run (training + plot + report)
|
# Full run (training + plot + report)
|
||||||
python scripts/run_pretrain.py
|
python scripts/pretrain.py
|
||||||
|
|
||||||
# Skip training if a checkpoint already exists; only re-plot and report
|
# Skip training if a checkpoint already exists; only re-plot and report
|
||||||
python scripts/run_pretrain.py --skip-training
|
python scripts/pretrain.py --skip-training
|
||||||
|
|
||||||
Outputs written:
|
Outputs written:
|
||||||
checkpoints/pretrained.pt best checkpoint
|
checkpoints/pretrained.pt best checkpoint
|
||||||
@@ -38,7 +37,7 @@ from src.tokenizer import TOKEN_TO_ID
|
|||||||
# Paths
|
# Paths
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
DATA_DIR = Path("data/processed")
|
DATA_DIR = Path("data/processed/mcgill")
|
||||||
CHECKPOINT = Path("checkpoints/pretrained.pt")
|
CHECKPOINT = Path("checkpoints/pretrained.pt")
|
||||||
LOG_CSV = Path("checkpoints/pretrained.log.csv")
|
LOG_CSV = Path("checkpoints/pretrained.log.csv")
|
||||||
CURVES_PNG = Path("checkpoints/pretrained_curves.png")
|
CURVES_PNG = Path("checkpoints/pretrained_curves.png")
|
||||||
@@ -185,7 +184,7 @@ def main() -> None:
|
|||||||
if not args.skip_training:
|
if not args.skip_training:
|
||||||
if not DATA_DIR.exists():
|
if not DATA_DIR.exists():
|
||||||
print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
|
print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
|
||||||
print("Run prepare_data.py first.", file=sys.stderr)
|
print("Run: python scripts/prepare_data.py --input-dir data/raw_external/mcgill_chord --output-dir data/processed/mcgill", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
import pathlib
|
import pathlib
|
||||||
n_train = len(list((DATA_DIR / "train").glob("*.pt")))
|
n_train = len(list((DATA_DIR / "train").glob("*.pt")))
|
||||||
+180
-86
@@ -1,119 +1,213 @@
|
|||||||
"""CLI entry point for pre-training and fine-tuning ChordTransformer.
|
"""Fine-tune ChordTransformer on the personal (user) chord corpus.
|
||||||
|
|
||||||
Usage (pre-training):
|
Requires a pre-trained checkpoint produced by scripts/pretrain.py.
|
||||||
python scripts/train.py \\
|
|
||||||
--data-dir data/processed/pretrain \\
|
|
||||||
--output checkpoints/pretrained \\
|
|
||||||
--epochs 50 --batch-size 32 --lr 3e-4
|
|
||||||
|
|
||||||
Usage (fine-tuning):
|
Usage:
|
||||||
python scripts/train.py \\
|
# Full run (fine-tuning + plot + report)
|
||||||
--data-dir data/processed/finetune \\
|
python scripts/train.py
|
||||||
--init-from checkpoints/pretrained.pt \\
|
|
||||||
--output checkpoints/finetuned \\
|
|
||||||
--epochs 15 --lr 1e-5
|
|
||||||
|
|
||||||
The script saves:
|
# Skip training; re-plot and report from existing CSV
|
||||||
<output>.pt best checkpoint (lowest val loss)
|
python scripts/train.py --skip-training
|
||||||
<output>.log.csv per-epoch metrics
|
|
||||||
|
Outputs written:
|
||||||
|
checkpoints/finetuned.pt best checkpoint
|
||||||
|
checkpoints/finetuned.log.csv per-epoch metrics
|
||||||
|
checkpoints/finetuned_curves.png train/val loss plot
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
from src.train import TrainConfig, train # noqa: E402
|
from src.model import ChordTransformer
|
||||||
|
from src.train import TrainConfig, train
|
||||||
|
from src.tokenizer import TOKEN_TO_ID
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
DATA_DIR = Path("data/processed/user")
|
||||||
|
INIT_FROM = Path("checkpoints/pretrained.pt")
|
||||||
|
CHECKPOINT = Path("checkpoints/finetuned.pt")
|
||||||
|
LOG_CSV = Path("checkpoints/finetuned.log.csv")
|
||||||
|
CURVES_PNG = Path("checkpoints/finetuned_curves.png")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Training config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
TRAIN_CFG = TrainConfig(
|
||||||
|
data_dir=DATA_DIR,
|
||||||
|
output=CHECKPOINT,
|
||||||
|
init_from=INIT_FROM,
|
||||||
|
epochs=15,
|
||||||
|
batch_size=8,
|
||||||
|
lr=1e-5,
|
||||||
|
warmup_steps=20,
|
||||||
|
seed=42,
|
||||||
|
device="auto",
|
||||||
|
max_seq_len=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> argparse.Namespace:
|
# ---------------------------------------------------------------------------
|
||||||
p = argparse.ArgumentParser(
|
# Plotting
|
||||||
description="Train or fine-tune ChordTransformer on tokenized chord data.",
|
# ---------------------------------------------------------------------------
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
||||||
|
def plot_curves(log_csv: Path, out_png: Path) -> None:
|
||||||
|
epochs, train_losses, val_losses, val_ppls = [], [], [], []
|
||||||
|
with open(log_csv, newline="") as fh:
|
||||||
|
for row in csv.DictReader(fh):
|
||||||
|
epochs.append(int(row["epoch"]))
|
||||||
|
train_losses.append(float(row["train_loss"]))
|
||||||
|
val_losses.append(float(row["val_loss"]))
|
||||||
|
val_ppls.append(float(row["val_ppl"]))
|
||||||
|
|
||||||
|
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
|
||||||
|
|
||||||
|
ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
|
||||||
|
ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
|
||||||
|
best_epoch = epochs[val_losses.index(min(val_losses))]
|
||||||
|
ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
|
||||||
|
label=f"best epoch {best_epoch}")
|
||||||
|
ax1.set_xlabel("epoch")
|
||||||
|
ax1.set_ylabel("cross-entropy loss")
|
||||||
|
ax1.set_title("Fine-tuning loss")
|
||||||
|
ax1.legend()
|
||||||
|
ax1.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
|
||||||
|
ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
|
||||||
|
ax2.set_xlabel("epoch")
|
||||||
|
ax2.set_ylabel("perplexity")
|
||||||
|
ax2.set_title("Val perplexity")
|
||||||
|
ax2.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
out_png.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_png, dpi=150)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[plot] saved → {out_png}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def print_report(log_csv: Path, checkpoint: Path) -> None:
|
||||||
|
rows = []
|
||||||
|
with open(log_csv, newline="") as fh:
|
||||||
|
rows = list(csv.DictReader(fh))
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
print("[report] log CSV is empty — nothing to report")
|
||||||
|
return
|
||||||
|
|
||||||
|
val_losses = [float(r["val_loss"]) for r in rows]
|
||||||
|
best_idx = val_losses.index(min(val_losses))
|
||||||
|
best_row = rows[best_idx]
|
||||||
|
|
||||||
|
best_loss = float(best_row["val_loss"])
|
||||||
|
conv_epoch = next(
|
||||||
|
(int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01),
|
||||||
|
int(best_row["epoch"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# I/O
|
n_params = None
|
||||||
io = p.add_argument_group("I/O")
|
if checkpoint.exists():
|
||||||
io.add_argument(
|
ckpt = torch.load(checkpoint, weights_only=True)
|
||||||
"--data-dir", required=True, type=Path,
|
model = ChordTransformer(**ckpt["model_config"])
|
||||||
help="Directory with train/ and val/ sub-directories (output of prepare_data.py).",
|
tied = model.token_emb.weight.numel()
|
||||||
)
|
n_params = sum(p.numel() for p in model.parameters()) - tied
|
||||||
io.add_argument(
|
|
||||||
"--output", required=True, type=Path,
|
|
||||||
help="Output path prefix; .pt checkpoint and .log.csv are appended automatically.",
|
|
||||||
)
|
|
||||||
io.add_argument(
|
|
||||||
"--init-from", type=Path, default=None,
|
|
||||||
help="Checkpoint to load weights from before training (fine-tuning mode).",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training
|
print()
|
||||||
tr = p.add_argument_group("Training")
|
print("=" * 52)
|
||||||
tr.add_argument("--epochs", type=int, default=30)
|
print(" FINE-TUNING REPORT")
|
||||||
tr.add_argument("--batch-size", type=int, default=16)
|
print("=" * 52)
|
||||||
tr.add_argument("--lr", type=float, default=3e-4)
|
print(f" Total epochs run : {len(rows)}")
|
||||||
tr.add_argument("--warmup-steps", type=int, default=200)
|
print(f" Best epoch (val loss) : {best_row['epoch']}")
|
||||||
tr.add_argument("--weight-decay", type=float, default=0.1)
|
print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
|
||||||
tr.add_argument("--patience", type=int, default=5,
|
print(f" Best val loss : {best_loss:.4f}")
|
||||||
help="Early-stopping patience (epochs without val improvement).")
|
print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
|
||||||
tr.add_argument("--seed", type=int, default=42)
|
print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
|
||||||
tr.add_argument(
|
if n_params is not None:
|
||||||
"--device", default="auto", choices=["auto", "cpu", "cuda"],
|
print(f" Unique parameters : {n_params:,}")
|
||||||
help="Compute device. 'auto' selects cuda when available.",
|
print(f" Checkpoint : {checkpoint}")
|
||||||
)
|
print(f" Log CSV : {log_csv}")
|
||||||
|
print("=" * 52)
|
||||||
|
print()
|
||||||
|
|
||||||
# Architecture (ignored when --init-from is given)
|
print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
|
||||||
arch = p.add_argument_group("Architecture (ignored when --init-from is set)")
|
print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
|
||||||
arch.add_argument("--d-model", type=int, default=192)
|
for r in rows:
|
||||||
arch.add_argument("--n-layers", type=int, default=3)
|
marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
|
||||||
arch.add_argument("--n-heads", type=int, default=6)
|
print(
|
||||||
arch.add_argument("--d-ff", type=int, default=768)
|
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
|
||||||
arch.add_argument("--dropout", type=float, default=0.1)
|
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
|
||||||
arch.add_argument("--max-seq-len", type=int, default=512)
|
f" {float(r['lr']):>10.2e}{marker}"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
# Logging
|
|
||||||
p.add_argument(
|
|
||||||
"--log-level", default="INFO",
|
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return p.parse_args()
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = _parse_args()
|
ap = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
ap.add_argument("--skip-training", action="store_true",
|
||||||
|
help="Skip training; re-plot and report from existing CSV.")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=args.log_level,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(message)s",
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
datefmt="%H:%M:%S",
|
datefmt="%H:%M:%S",
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg = TrainConfig(
|
if not args.skip_training:
|
||||||
data_dir=args.data_dir,
|
if not INIT_FROM.exists():
|
||||||
output=args.output,
|
print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr)
|
||||||
init_from=args.init_from,
|
print("Run python scripts/pretrain.py first.", file=sys.stderr)
|
||||||
epochs=args.epochs,
|
sys.exit(1)
|
||||||
batch_size=args.batch_size,
|
if not DATA_DIR.exists():
|
||||||
lr=args.lr,
|
print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
|
||||||
warmup_steps=args.warmup_steps,
|
print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr)
|
||||||
weight_decay=args.weight_decay,
|
sys.exit(1)
|
||||||
seed=args.seed,
|
n_train = len(list((DATA_DIR / "train").glob("*.pt")))
|
||||||
device=args.device,
|
n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size
|
||||||
patience=args.patience,
|
est_epoch_s = n_batches * 1.5
|
||||||
max_seq_len=args.max_seq_len,
|
device_label = "GPU" if torch.cuda.is_available() else "CPU"
|
||||||
d_model=args.d_model,
|
print(
|
||||||
n_layers=args.n_layers,
|
f"[train] {n_train} train files, {n_batches} batches/epoch\n"
|
||||||
n_heads=args.n_heads,
|
f"[train] estimated time on {device_label}: "
|
||||||
d_ff=args.d_ff,
|
f"~{est_epoch_s/60:.0f} min/epoch, "
|
||||||
dropout=args.dropout,
|
f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
|
||||||
)
|
f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
|
||||||
|
)
|
||||||
|
train(TRAIN_CFG)
|
||||||
|
else:
|
||||||
|
if not LOG_CSV.exists():
|
||||||
|
print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
print(f"[skip-training] using existing log: {LOG_CSV}")
|
||||||
|
|
||||||
checkpoint = train(cfg)
|
plot_curves(LOG_CSV, CURVES_PNG)
|
||||||
print(f"best checkpoint: {checkpoint}")
|
print_report(LOG_CSV, CHECKPOINT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user