Files
hamori/scripts/train.py
T
H1K0 2a3eb1783a fix: fine-tune config and generator improvements
scripts/train.py: fix max_seq_len 256→320 (must match pretrained checkpoint);
increase epochs 15→50 and patience 5→10 to give the small corpus enough
gradient steps; reduce warmup 20→10 (was 22% of total steps).

scripts/generate.py: default to prepending the tonic chord when --prefix is
not given; add --no-tonic-anchor to opt out.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 10:15:48 +03:00

228 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Fine-tune ChordTransformer on the personal (user) chord corpus.
Requires a pre-trained checkpoint produced by scripts/pretrain.py.
Usage:
# Full run (fine-tuning + plot + report)
python scripts/train.py
# Skip training; re-plot and report from existing CSV
python scripts/train.py --skip-training
Outputs written:
checkpoints/finetuned.pt best checkpoint
checkpoints/finetuned.log.csv per-epoch metrics
checkpoints/finetuned_curves.png train/val loss plot
checkpoints/finetuned.report.txt plain-text summary report
"""
from __future__ import annotations
import argparse
import csv
import logging
import math
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.model import ChordTransformer
from src.train import TrainConfig, train
from src.tokenizer import TOKEN_TO_ID
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# All paths are relative, so they resolve against the current working directory.
DATA_DIR = Path("data/processed/user")  # main() counts train/*.pt files under this directory
INIT_FROM = Path("checkpoints/pretrained.pt")  # must exist before training; main() exits otherwise
CHECKPOINT = Path("checkpoints/finetuned.pt")  # passed to TrainConfig as `output`
LOG_CSV = Path("checkpoints/finetuned.log.csv")  # per-epoch metrics read by plot_curves/write_report
CURVES_PNG = Path("checkpoints/finetuned_curves.png")  # figure saved by plot_curves
REPORT_TXT = Path("checkpoints/finetuned.report.txt")  # text summary saved by write_report
# ---------------------------------------------------------------------------
# Training config
# ---------------------------------------------------------------------------
# Fixed configuration for the fine-tuning run; main() passes it to train().
TRAIN_CFG = TrainConfig(
    data_dir=DATA_DIR,
    output=CHECKPOINT,
    # main() refuses to train unless this checkpoint file exists.
    init_from=INIT_FROM,
    # Small corpus (~45 train files) → ~6 batches/epoch.
    # 50 epochs × 6 = ~300 gradient steps; patience=10 gives a 60-step window.
    epochs=50,
    batch_size=8,
    lr=1e-5,
    warmup_steps=10,
    patience=10,
    seed=42,
    # NOTE(review): "auto" presumably lets train() pick the device — confirm in src.train.
    device="auto",
    # Must match pretrained checkpoint (max_seq_len=320).
    max_seq_len=320,
)
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Plot loss and perplexity curves from the per-epoch metrics CSV.

    Reads the ``epoch``, ``train_loss``, ``val_loss`` and ``val_ppl`` columns
    of *log_csv* and saves a two-panel figure to *out_png*: train/val loss on
    the left, validation perplexity on the right.  A dashed vertical line
    marks the epoch with the lowest validation loss (the first one on ties).

    If the CSV contains no data rows, a notice is printed and nothing is
    written, mirroring how ``write_report`` treats an empty log.

    Args:
        log_csv: Path of the metrics CSV to read.
        out_png: Destination image path; parent directories are created.

    Raises:
        FileNotFoundError: If *log_csv* does not exist.
        KeyError, ValueError: If a required column is missing or non-numeric.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))
    if not rows:
        # min() on an empty sequence would raise; there is nothing to draw.
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(r["epoch"]) for r in rows]
    train_losses = [float(r["train_loss"]) for r in rows]
    val_losses = [float(r["val_loss"]) for r in rows]
    val_ppls = [float(r["val_ppl"]) for r in rows]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))

    # Left panel: training vs. validation loss.
    ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
    ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
    best_epoch = epochs[val_losses.index(min(val_losses))]
    ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                label=f"best epoch {best_epoch}")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("cross-entropy loss")
    ax1.set_title("Fine-tuning loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: validation perplexity with the same best-epoch marker.
    ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
    ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("perplexity")
    ax2.set_title("Val perplexity")
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=150)
    plt.close(fig)
    print(f"[plot] saved -> {out_png}")
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
    """Write a plain-text summary of a fine-tuning run.

    The report contains headline numbers (best epoch, convergence epoch,
    best validation loss/perplexity, final train loss), optionally the
    model's parameter count, and a per-epoch table in which the best epoch
    is flagged.

    If the CSV contains no data rows, a notice is printed and no report is
    written.

    Args:
        log_csv: Metrics CSV with ``epoch``, ``train_loss``, ``val_loss``,
            ``val_ppl`` and ``lr`` columns.
        checkpoint: Checkpoint file; when it exists, its ``model_config``
            entry is used to instantiate the model and count parameters.
        report_path: Destination text file; parent directories are created.

    Raises:
        FileNotFoundError: If *log_csv* does not exist.
        KeyError, ValueError: If a required column is missing or non-numeric.
    """
    with open(log_csv, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))
    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    val_losses = [float(r["val_loss"]) for r in rows]
    best_idx = val_losses.index(min(val_losses))
    best_row = rows[best_idx]
    best_loss = val_losses[best_idx]
    best_epoch = int(best_row["epoch"])
    # "Convergence" = first epoch whose val loss is within 1 % of the best.
    conv_epoch = next(
        (int(r["epoch"]) for r, loss in zip(rows, val_losses)
         if loss <= best_loss * 1.01),
        best_epoch,
    )

    n_params = None
    if checkpoint.exists():
        # map_location="cpu" lets the report be regenerated on a machine
        # without the device the checkpoint was saved from.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        tied = model.token_emb.weight.numel()
        # NOTE(review): nn.Module.parameters() yields a shared Parameter only
        # once, so if the embedding is tied by sharing the same Parameter
        # object this subtraction undercounts — confirm against ChordTransformer.
        n_params = sum(p.numel() for p in model.parameters()) - tied

    lines = [
        "",
        "=" * 52,
        " FINE-TUNING REPORT",
        "=" * 52,
        f" Total epochs run : {len(rows)}",
        f" Best epoch (val loss) : {best_row['epoch']}",
        f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
        f" Best val loss : {best_loss:.4f}",
        f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
        f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
    ]
    if n_params is not None:
        lines.append(f" Unique parameters : {n_params:,}")
    lines += [
        f" Checkpoint : {checkpoint}",
        f" Log CSV : {log_csv}",
        "=" * 52,
        "",
        f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
        f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
    ]
    for r in rows:
        # Flag the best-epoch row so it stands out in the table.
        marker = "  <- best" if int(r["epoch"]) == best_epoch else ""
        lines.append(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    lines.append("")

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"[report] saved -> {report_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point: fine-tune unless skipped, then plot and report.

    Without ``--skip-training`` the pre-trained checkpoint and the data
    directory must exist; a rough time estimate is printed and ``train`` is
    called with ``TRAIN_CFG``.  With ``--skip-training`` only the existing
    log CSV is used.  In both modes the curves figure and the text report
    are then generated from the log CSV.

    Exits with status 1 (message on stderr) when a required input is
    missing, including a log CSV that is absent after training.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--skip-training", action="store_true",
                    help="Skip training; re-plot and report from existing CSV.")
    args = ap.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if not args.skip_training:
        if not INIT_FROM.exists():
            print(f"ERROR: pre-trained checkpoint not found: {INIT_FROM}", file=sys.stderr)
            print("Run python scripts/pretrain.py first.", file=sys.stderr)
            sys.exit(1)
        if not DATA_DIR.exists():
            print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
            print("Run: python scripts/prepare_data.py --input-dir data/raw_user --output-dir data/processed/user", file=sys.stderr)
            sys.exit(1)

        n_train = len(list((DATA_DIR / "train").glob("*.pt")))
        # Ceiling division: a partial final batch still counts as a batch.
        n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size
        # Hard-coded rough figure of 1.5 s per batch; informational only.
        est_epoch_s = n_batches * 1.5
        device_label = "GPU" if torch.cuda.is_available() else "CPU"
        print(
            f"[train] {n_train} train files, {n_batches} batches/epoch\n"
            f"[train] estimated time on {device_label}: "
            f"~{est_epoch_s/60:.0f} min/epoch, "
            f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
            f"[train] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        train(TRAIN_CFG)
        # Fail with a clear message instead of a traceback from plot_curves
        # if training did not leave the expected metrics log behind.
        if not LOG_CSV.exists():
            print(f"ERROR: training finished but log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
    else:
        if not LOG_CSV.exists():
            print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
        print(f"[skip-training] using existing log: {LOG_CSV}")

    plot_curves(LOG_CSV, CURVES_PNG)
    write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()