Files
H1K0 555023532f scripts: update fine-tune defaults to lr=3e-5, epochs=30
Matches the configuration that produced finetuned.pt (val ppl 2.15,
best epoch 20, early stopped at 30).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 21:18:37 +03:00

228 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Fine-tune ChordTransformer on the personal (user) chord corpus.
Requires a pre-trained checkpoint produced by scripts/pretrain.py.

Usage:
    # Full run (fine-tuning + plot + report)
    python scripts/train.py

    # Skip training; re-plot and report from existing CSV
    python scripts/train.py --skip-training

Outputs written:
    checkpoints/finetuned.pt          best checkpoint
    checkpoints/finetuned.log.csv     per-epoch metrics
    checkpoints/finetuned_curves.png  train/val loss plot
    checkpoints/finetuned.report.txt  plain-text summary report
"""
from __future__ import annotations
import argparse
import csv
import logging
import math
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.model import ChordTransformer
from src.train import TrainConfig, train
from src.tokenizer import TOKEN_TO_ID
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# The pre-trained input and every fine-tuning artefact share one directory,
# so the individual paths are derived from a single base.
_CHECKPOINT_DIR = Path("checkpoints")

DATA_DIR = Path("data", "processed", "user")
INIT_FROM = _CHECKPOINT_DIR / "pretrained.pt"
CHECKPOINT = _CHECKPOINT_DIR / "finetuned.pt"
LOG_CSV = _CHECKPOINT_DIR / "finetuned.log.csv"
CURVES_PNG = _CHECKPOINT_DIR / "finetuned_curves.png"
REPORT_TXT = _CHECKPOINT_DIR / "finetuned.report.txt"
# ---------------------------------------------------------------------------
# Training config
# ---------------------------------------------------------------------------
TRAIN_CFG = TrainConfig(
    # --- inputs / outputs --------------------------------------------------
    data_dir=DATA_DIR,
    init_from=INIT_FROM,
    output=CHECKPOINT,
    # --- model shape -------------------------------------------------------
    # Has to agree with the pre-trained checkpoint (max_seq_len=320).
    max_seq_len=320,
    # --- optimisation schedule ---------------------------------------------
    # The corpus is small (~45 train files), i.e. ~6 batches per epoch at
    # batch_size=8. 30 epochs therefore amount to ~180 gradient steps, and
    # patience=10 corresponds to a window of ~60 steps.
    batch_size=8,
    epochs=30,
    lr=3e-5,
    warmup_steps=10,
    patience=10,
    # --- runtime -----------------------------------------------------------
    seed=42,
    device="auto",
)
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Plot train/val loss and validation perplexity from the per-epoch log.

    Reads the ``epoch``, ``train_loss``, ``val_loss`` and ``val_ppl`` columns
    of *log_csv* and saves a two-panel figure to *out_png* (parent directories
    are created as needed). The epoch with the lowest validation loss is
    marked with a dashed vertical line in both panels.

    If the CSV has no data rows there is nothing to plot: a message is printed
    and no file is written, matching how ``write_report`` treats an empty log.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))

    # Guard before touching matplotlib: min() over an empty loss list would
    # otherwise raise ValueError.
    if not rows:
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(r["epoch"]) for r in rows]
    train_losses = [float(r["train_loss"]) for r in rows]
    val_losses = [float(r["val_loss"]) for r in rows]
    val_ppls = [float(r["val_ppl"]) for r in rows]

    # First epoch that reaches the minimum validation loss.
    best_epoch = epochs[val_losses.index(min(val_losses))]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
    try:
        # Left panel: training vs. validation loss.
        ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
        ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
        ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                    label=f"best epoch {best_epoch}")
        ax1.set_xlabel("epoch")
        ax1.set_ylabel("cross-entropy loss")
        ax1.set_title("Fine-tuning loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Right panel: validation perplexity.
        ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
        ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
        ax2.set_xlabel("epoch")
        ax2.set_ylabel("perplexity")
        ax2.set_title("Val perplexity")
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    print(f"[plot] saved -> {out_png}")
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
    """Write a plain-text summary of a fine-tuning run to *report_path*.

    The report is built from the per-epoch rows of *log_csv* (columns
    ``epoch``, ``train_loss``, ``val_loss``, ``val_ppl`` and ``lr``): a header
    with the best epoch, a "convergence epoch" (the first epoch whose
    validation loss is within 1 % of the best), the best validation
    loss/perplexity and the final training loss, followed by a table with one
    line per epoch in which the best epoch is flagged.

    If *checkpoint* exists, its ``model_config`` entry is used to instantiate
    ``ChordTransformer`` and a parameter count is added to the header.
    When the CSV has no data rows a message is printed and nothing is written.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))
    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    # min() keeps the first row on ties, i.e. the earliest best epoch.
    best_row = min(rows, key=lambda r: float(r["val_loss"]))
    best_epoch = int(best_row["epoch"])
    best_loss = float(best_row["val_loss"])

    # First epoch whose validation loss is already within 1 % of the best.
    conv_epoch = next(
        (int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01),
        best_epoch,
    )

    n_params = None
    if checkpoint.exists():
        # map_location="cpu" lets a checkpoint saved on GPU be inspected on a
        # CPU-only machine (e.g. when re-running with --skip-training).
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        tied = model.token_emb.weight.numel()
        # NOTE(review): nn.Module.parameters() yields a shared Parameter only
        # once. If ChordTransformer ties its output projection to token_emb by
        # sharing the same Parameter object, this subtraction removes the
        # embedding from the count entirely — confirm how the weights are tied.
        n_params = sum(p.numel() for p in model.parameters()) - tied

    lines = [
        "",
        "=" * 52,
        " FINE-TUNING REPORT",
        "=" * 52,
        f" Total epochs run : {len(rows)}",
        f" Best epoch (val loss) : {best_row['epoch']}",
        f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
        f" Best val loss : {best_loss:.4f}",
        f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
        f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
    ]
    if n_params is not None:
        lines.append(f" Unique parameters : {n_params:,}")
    lines += [
        f" Checkpoint : {checkpoint}",
        f" Log CSV : {log_csv}",
        "=" * 52,
        "",
        f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
        f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
    ]
    for r in rows:
        # Flag the best epoch; previously both branches produced "" so the
        # best row was indistinguishable from the others.
        marker = " <- best" if int(r["epoch"]) == best_epoch else ""
        lines.append(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    lines.append("")

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"[report] saved -> {report_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point: optionally fine-tune, then plot and report.

    Without ``--skip-training`` the pre-trained checkpoint and the data
    directory must both exist (otherwise an error is printed to stderr and the
    process exits with status 1); a rough time estimate is printed and
    ``train(TRAIN_CFG)`` is run. With ``--skip-training`` only the existing
    log CSV is required. In both modes the curves and the report are then
    regenerated from the log.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training; re-plot and report from existing CSV.",
    )
    opts = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if opts.skip_training:
        # Only the metrics log is needed to rebuild the plot and the report.
        if not LOG_CSV.exists():
            print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
        print(f"[skip-training] using existing log: {LOG_CSV}")
    else:
        # Preconditions, checked in order: (path, error message, remedy hint).
        preconditions = (
            (
                INIT_FROM,
                f"ERROR: pre-trained checkpoint not found: {INIT_FROM}",
                "Run python scripts/pretrain.py first.",
            ),
            (
                DATA_DIR,
                f"ERROR: data directory not found: {DATA_DIR}",
                "Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user",
            ),
        )
        for required_path, problem, hint in preconditions:
            if required_path.exists():
                continue
            print(problem, file=sys.stderr)
            print(hint, file=sys.stderr)
            sys.exit(1)

        # Rough wall-clock estimate: ceil(files / batch_size) batches per
        # epoch at a flat 1.5 s per batch.
        train_file_count = sum(1 for _ in (DATA_DIR / "train").glob("*.pt"))
        per_batch = TRAIN_CFG.batch_size
        batches_per_epoch = (train_file_count + per_batch - 1) // per_batch
        seconds_per_epoch = batches_per_epoch * 1.5
        minutes_per_epoch = seconds_per_epoch / 60
        hours_total = TRAIN_CFG.epochs * seconds_per_epoch / 3600
        hardware = "GPU" if torch.cuda.is_available() else "CPU"

        summary = (
            f"[train] {train_file_count} train files, {batches_per_epoch} batches/epoch\n"
            f"[train] estimated time on {hardware}: "
            f"~{minutes_per_epoch:.0f} min/epoch, "
            f"~{hours_total:.1f} h total\n"
            f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        print(summary)
        train(TRAIN_CFG)

    plot_curves(LOG_CSV, CURVES_PNG)
    write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()