2a3eb1783a
scripts/train.py: fix max_seq_len 256→320 (must match pretrained checkpoint); increase epochs 15→50 and patience 5→10 to give the small corpus enough gradient steps; reduce warmup 20→10 (was 22% of total steps). scripts/generate.py: default to prepending the tonic chord when --prefix is not given; add --no-tonic-anchor to opt out. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
228 lines
7.9 KiB
Python
228 lines
7.9 KiB
Python
"""Fine-tune ChordTransformer on the personal (user) chord corpus.
|
||
|
||
Requires a pre-trained checkpoint produced by scripts/pretrain.py.
|
||
|
||
Usage:
|
||
# Full run (fine-tuning + plot + report)
|
||
python scripts/train.py
|
||
|
||
# Skip training; re-plot and report from existing CSV
|
||
python scripts/train.py --skip-training
|
||
|
||
Outputs written:
|
||
checkpoints/finetuned.pt best checkpoint
|
||
checkpoints/finetuned.log.csv per-epoch metrics
|
||
checkpoints/finetuned_curves.png train/val loss plot
checkpoints/finetuned.report.txt plain-text summary report
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import logging
|
||
import math
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import torch
|
||
|
||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||
|
||
from src.model import ChordTransformer
|
||
from src.train import TrainConfig, train
|
||
from src.tokenizer import TOKEN_TO_ID
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Paths
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Processed personal corpus used for fine-tuning.
DATA_DIR = Path("data") / "processed" / "user"

# Every checkpoint-related artefact lives in one directory.
_CKPT_DIR = Path("checkpoints")

# Weights to start from (input) and the fine-tuned result (output).
INIT_FROM = _CKPT_DIR / "pretrained.pt"
CHECKPOINT = _CKPT_DIR / "finetuned.pt"

# Per-epoch metrics log plus the plot and text report derived from it.
LOG_CSV = _CKPT_DIR / "finetuned.log.csv"
CURVES_PNG = _CKPT_DIR / "finetuned_curves.png"
REPORT_TXT = _CKPT_DIR / "finetuned.report.txt"
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Training config
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Hyper-parameters for fine-tuning on the small personal corpus.
TRAIN_CFG = TrainConfig(
    # Inputs and outputs.
    data_dir=DATA_DIR,
    init_from=INIT_FROM,
    output=CHECKPOINT,
    # Sequence length must equal the pretrained checkpoint's value (320).
    max_seq_len=320,
    # The corpus is small: ~45 train files at batch size 8 is ~6 batches per
    # epoch, so 50 epochs is ~300 gradient steps and a patience of 10 epochs
    # corresponds to a window of ~60 steps.
    batch_size=8,
    epochs=50,
    patience=10,
    # Optimiser schedule.
    lr=1e-5,
    warmup_steps=10,
    # Reproducibility and device selection.
    seed=42,
    device="auto",
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Plotting
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Plot loss and perplexity curves from the per-epoch metrics log.

    Reads the ``epoch``, ``train_loss``, ``val_loss`` and ``val_ppl`` columns
    of *log_csv* and saves a two-panel figure to *out_png*: train/val loss on
    the left and validation perplexity on the right. Both panels mark the
    epoch with the lowest validation loss using a dashed vertical line.

    If the log contains no data rows, a message is printed and no figure is
    written (mirrors the empty-log handling in ``write_report``).

    Args:
        log_csv: CSV file with one row per epoch.
        out_png: Destination image path; missing parent directories are
            created.

    Raises:
        OSError: *log_csv* cannot be opened.
        KeyError: A required column is missing from the log.
        ValueError: A cell cannot be parsed as a number.
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))

    # Without this guard min() below would raise ValueError on an empty list.
    if not rows:
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(row["epoch"]) for row in rows]
    train_losses = [float(row["train_loss"]) for row in rows]
    val_losses = [float(row["val_loss"]) for row in rows]
    val_ppls = [float(row["val_ppl"]) for row in rows]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))

    # Left panel: training vs. validation loss.
    ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
    ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
    # First epoch that reaches the minimum validation loss.
    best_epoch = epochs[val_losses.index(min(val_losses))]
    ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                label=f"best epoch {best_epoch}")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("cross-entropy loss")
    ax1.set_title("Fine-tuning loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: validation perplexity with the same best-epoch marker.
    ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
    ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("perplexity")
    ax2.set_title("Val perplexity")
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=150)
    # Close explicitly so the figure is not kept alive by pyplot.
    plt.close(fig)
    print(f"[plot] saved -> {out_png}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Report
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
    """Write a plain-text summary of a fine-tuning run.

    The report contains headline figures (epochs run, best epoch, the first
    epoch whose validation loss is within 1 % of the best, best validation
    loss and perplexity, final training loss) followed by a per-epoch table
    in which the best epoch is flagged with an arrow.

    If *checkpoint* exists, the model is rebuilt from its stored
    ``model_config`` to add a parameter count. If the log has no data rows, a
    message is printed and nothing is written.

    Args:
        log_csv: Per-epoch metrics CSV with ``epoch``, ``train_loss``,
            ``val_loss``, ``val_ppl`` and ``lr`` columns.
        checkpoint: Fine-tuned checkpoint; optional, only used for the
            parameter count.
        report_path: Destination text file (written as UTF-8); missing parent
            directories are created.

    Raises:
        OSError: *log_csv* cannot be opened.
        KeyError: A required column is missing from the log.
        ValueError: A cell cannot be parsed as a number.
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))

    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    val_losses = [float(r["val_loss"]) for r in rows]
    best_loss = min(val_losses)
    # index() returns the first occurrence, i.e. the earliest best epoch.
    best_row = rows[val_losses.index(best_loss)]
    best_epoch = int(best_row["epoch"])

    # "Convergence" = first epoch whose val loss is within 1 % of the best.
    threshold = best_loss * 1.01
    conv_epoch = next(
        (int(r["epoch"]) for r, loss in zip(rows, val_losses) if loss <= threshold),
        best_epoch,
    )

    n_params = None
    if checkpoint.exists():
        # Load onto the CPU: only tensor sizes are needed, and a checkpoint
        # saved on a GPU must still be readable on a machine without one.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        # NOTE(review): the embedding size is subtracted once, presumably to
        # avoid double-counting weights tied with the output projection.
        # nn.Module.parameters() already de-duplicates shared tensors, so
        # confirm this against ChordTransformer.
        tied = model.token_emb.weight.numel()
        n_params = sum(p.numel() for p in model.parameters()) - tied

    lines = [
        "",
        "=" * 52,
        " FINE-TUNING REPORT",
        "=" * 52,
        f" Total epochs run : {len(rows)}",
        f" Best epoch (val loss) : {best_row['epoch']}",
        f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
        f" Best val loss : {best_loss:.4f}",
        f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
        f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
    ]
    if n_params is not None:
        lines.append(f" Unique parameters : {n_params:,}")
    lines += [
        f" Checkpoint : {checkpoint}",
        f" Log CSV : {log_csv}",
        "=" * 52,
        "",
        f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
        f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
    ]
    for r in rows:
        marker = " ←" if int(r["epoch"]) == best_epoch else ""
        lines.append(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    lines.append("")

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"[report] saved -> {report_path}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main() -> None:
    """Command-line entry point: fine-tune, then plot and report.

    Without flags, checks that the pretrained checkpoint and the data
    directory exist, prints a rough time estimate and runs training with
    ``TRAIN_CFG``. With ``--skip-training``, only checks that the metrics log
    exists. In both cases the loss curves and the text report are then
    regenerated from the log. Exits with status 1 when a required input is
    missing.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training; re-plot and report from existing CSV.",
    )
    opts = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if opts.skip_training:
        # Re-use an earlier run: the metrics log must already exist.
        if not LOG_CSV.exists():
            print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
        print(f"[skip-training] using existing log: {LOG_CSV}")
    else:
        # Fail early with an actionable hint when an input is missing.
        if not INIT_FROM.exists():
            print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr)
            print("Run python scripts/pretrain.py first.", file=sys.stderr)
            sys.exit(1)
        if not DATA_DIR.exists():
            print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
            print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr)
            sys.exit(1)

        # Count training files and round up to whole batches per epoch.
        n_train = sum(1 for _ in (DATA_DIR / "train").glob("*.pt"))
        whole, leftover = divmod(n_train, TRAIN_CFG.batch_size)
        n_batches = whole + (1 if leftover else 0)

        # Rough estimate only: assumes about 1.5 seconds per batch.
        est_epoch_s = n_batches * 1.5
        minutes_per_epoch = est_epoch_s / 60
        hours_total = TRAIN_CFG.epochs * est_epoch_s / 3600
        device_label = "CPU"
        if torch.cuda.is_available():
            device_label = "GPU"
        print(
            f"[train] {n_train} train files, {n_batches} batches/epoch\n"
            f"[train] estimated time on {device_label}: "
            f"~{minutes_per_epoch:.0f} min/epoch, "
            f"~{hours_total:.1f} h total\n"
            f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        train(TRAIN_CFG)

    # Both modes finish by regenerating the artefacts derived from the log.
    plot_curves(LOG_CSV, CURVES_PNG)
    write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|