Files
H1K0 4aead2ea20 feat: remove BAR token; bump spec to v2.3; fix max_seq_len
Bar boundaries are now implicit — the detokenizer counts positions per bar
using TIME × SUB, and the generator gates EOS to bar boundaries only.
Removing the deterministic BAR token reduces vocab size from 85 to 84 and
lets the model focus on meaningful predictions.

- src/tokenizer.py: drop BAR from VOCAB (85→84); replace BAR-based
  detokenize_to_period with position-counting logic; add write_chord_file;
  fix _tokens_to_symbol for add9/m(add9) qualities
- tests/test_tokenizer.py: update vocab-size assertions to 84, structural
  token test, remove bar-count test, add test_no_bar_token_in_vocab
- docs/chord_format_spec.md: bump to v2.3; document BAR removal in §5.2,
  §5.3, §5.4, §5.5, §5.6, §6.2, and changelog
- CLAUDE.md: remove stale BAR reference, update vocab size to 84
- scripts/pretrain.py: raise max_seq_len 256→320 to cover regenerated
  McGill data (mean=83, max=283 tokens with BAR-free tokenizer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 13:56:34 +03:00

220 lines
7.8 KiB
Python

"""Pre-train ChordTransformer on the McGill Billboard corpus.

Usage:
    # Full run (training + plot + report)
    python scripts/pretrain.py

    # Skip training; only re-plot and report from the existing log CSV
    python scripts/pretrain.py --skip-training

Outputs written:
    checkpoints/pretrained.pt            best checkpoint
    checkpoints/pretrained.log.csv       per-epoch metrics
    checkpoints/pretrained_curves.png    train/val loss plot
    checkpoints/pretrained.report.txt    plain-text summary report
"""
from __future__ import annotations
import argparse
import csv
import logging
import math
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg") # headless — no display required
import matplotlib.pyplot as plt
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.model import ChordTransformer
from src.train import TrainConfig, train
from src.tokenizer import TOKEN_TO_ID
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# All paths are relative, so they resolve against the current working
# directory (the usage examples run the script as `python scripts/pretrain.py`).
# Processed corpus; main() counts the *.pt files under its "train" subdirectory.
DATA_DIR = Path("data/processed/mcgill")
# Best checkpoint; handed to TrainConfig as `output` and read by write_report().
CHECKPOINT = Path("checkpoints/pretrained.pt")
# Per-epoch metrics consumed by plot_curves() and write_report().
# NOTE(review): presumably written by train() next to the checkpoint — confirm.
LOG_CSV = Path("checkpoints/pretrained.log.csv")
# Figure produced by plot_curves().
CURVES_PNG = Path("checkpoints/pretrained_curves.png")
# Plain-text summary produced by write_report().
REPORT_TXT = Path("checkpoints/pretrained.report.txt")
# ---------------------------------------------------------------------------
# Training config (mirrors the requested CLI invocation)
# ---------------------------------------------------------------------------
# Single module-level configuration object passed unchanged to train() in main().
TRAIN_CFG = TrainConfig(
    data_dir=DATA_DIR,
    output=CHECKPOINT,
    epochs=50,
    batch_size=32,
    lr=3e-4,
    warmup_steps=200,
    seed=42,
    device="auto",
    # Regenerated McGill sequences: mean=83, max=283 (BAR-free tokenizer).
    # 320 covers the full distribution with headroom; still ~2.5x cheaper than 512.
    max_seq_len=320,
)
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_curves(log_csv: Path, out_png: Path) -> None:
    """Plot train/val loss and validation perplexity from the per-epoch log.

    Reads the ``epoch``, ``train_loss``, ``val_loss`` and ``val_ppl`` columns
    of *log_csv* and saves a two-panel figure to *out_png*, creating the
    parent directory if needed.  The epoch with the lowest validation loss
    is marked with a dashed vertical line on both panels.

    If the CSV contains no data rows, a message is printed and no figure is
    created.

    Args:
        log_csv: CSV file with one row per epoch.
        out_png: Destination image path.
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))
    if not rows:
        # min() over an empty list raises ValueError; report and bail out
        # before any figure is allocated.
        print("[plot] log CSV is empty — nothing to plot")
        return

    epochs = [int(r["epoch"]) for r in rows]
    train_losses = [float(r["train_loss"]) for r in rows]
    val_losses = [float(r["val_loss"]) for r in rows]
    val_ppls = [float(r["val_ppl"]) for r in rows]
    # First epoch that reaches the minimum validation loss.
    best_epoch = epochs[val_losses.index(min(val_losses))]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
    try:
        # Left panel: training vs. validation loss.
        ax1.plot(epochs, train_losses, label="train loss", linewidth=1.5)
        ax1.plot(epochs, val_losses, label="val loss", linewidth=1.5)
        ax1.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8,
                    label=f"best epoch {best_epoch}")
        ax1.set_xlabel("epoch")
        ax1.set_ylabel("cross-entropy loss")
        ax1.set_title("Pre-training loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Right panel: validation perplexity.
        ax2.plot(epochs, val_ppls, color="tab:orange", linewidth=1.5)
        ax2.axvline(best_epoch, color="grey", linestyle="--", linewidth=0.8)
        ax2.set_xlabel("epoch")
        ax2.set_ylabel("perplexity")
        ax2.set_title("Val perplexity")
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"[plot] saved -> {out_png}")
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
    """Write a plain-text summary of the pre-training run to *report_path*.

    The report lists the number of logged epochs, the best epoch by
    validation loss, the first epoch whose validation loss is within 1 % of
    the best, the best validation loss/perplexity, the final training loss
    and a per-epoch table in which the best row is flagged.  When
    *checkpoint* exists, a parameter count derived from its
    ``model_config`` is included as well.

    If the CSV contains no data rows, a message is printed and no report is
    written.

    Args:
        log_csv: CSV with ``epoch``, ``train_loss``, ``val_loss``,
            ``val_ppl`` and ``lr`` columns, one row per epoch.
        checkpoint: Checkpoint file; skipped silently when missing.
        report_path: Destination text file (parent directory is created).
    """
    with open(log_csv, newline="") as fh:
        rows = list(csv.DictReader(fh))
    if not rows:
        print("[report] log CSV is empty — nothing to report")
        return

    val_losses = [float(r["val_loss"]) for r in rows]
    best_idx = val_losses.index(min(val_losses))
    best_row = rows[best_idx]
    best_loss = val_losses[best_idx]
    # "Convergence" = first epoch whose val loss is within 1 % of the best.
    conv_epoch = next(
        (int(r["epoch"]) for r, loss in zip(rows, val_losses)
         if loss <= best_loss * 1.01),
        int(best_row["epoch"]),
    )

    n_params = None
    if checkpoint.exists():
        # map_location="cpu" lets a checkpoint saved on a GPU be inspected on
        # a machine without one; only the stored config is needed here.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=True)
        model = ChordTransformer(**ckpt["model_config"])
        # NOTE(review): nn.Module.parameters() already yields a shared
        # Parameter only once. If the output head reuses token_emb.weight,
        # this subtraction removes the embedding from the count entirely —
        # confirm against ChordTransformer.
        tied = model.token_emb.weight.numel()
        n_params = sum(p.numel() for p in model.parameters()) - tied

    lines = [
        "",
        "=" * 52,
        " PRE-TRAINING REPORT",
        "=" * 52,
        f" Total epochs run : {len(rows)}",
        f" Best epoch (val loss) : {best_row['epoch']}",
        f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
        f" Best val loss : {best_loss:.4f}",
        f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
        f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
    ]
    if n_params is not None:
        lines.append(f" Unique parameters : {n_params:,}")
    lines += [
        f" Checkpoint : {checkpoint}",
        f" Log CSV : {log_csv}",
        "=" * 52,
        "",
        f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
        f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
    ]
    for idx, r in enumerate(rows):
        # Flag the row with the lowest validation loss; previously both
        # branches of this conditional were empty so nothing was marked.
        marker = " <- best" if idx == best_idx else ""
        lines.append(
            f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
            f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
            f" {float(r['lr']):>10.2e}{marker}"
        )
    lines.append("")

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"[report] saved -> {report_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Run pre-training (unless skipped), then regenerate the plot and report.

    Without ``--skip-training`` the data directory must exist; a rough
    wall-clock estimate is printed and ``train(TRAIN_CFG)`` is invoked.
    With ``--skip-training`` the existing log CSV must exist.  In both cases
    the loss curves and the text report are then produced from ``LOG_CSV``.

    Exits with status 1 when the required data directory or log CSV is
    missing.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--skip-training", action="store_true",
                    help="Skip training; re-plot and report from existing CSV.")
    args = ap.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    if not args.skip_training:
        if not DATA_DIR.exists():
            print(f"ERROR: data directory not found: {DATA_DIR}", file=sys.stderr)
            print("Run: python scripts/prepare_data.py --input-dir data/raw_external/mcgill_chord --output-dir data/processed/mcgill", file=sys.stderr)
            sys.exit(1)

        n_train = sum(1 for _ in (DATA_DIR / "train").glob("*.pt"))
        # Ceiling division: a final partial batch still counts as a batch.
        n_batches = (n_train + TRAIN_CFG.batch_size - 1) // TRAIN_CFG.batch_size
        # Rough estimate: ~1.5 s/batch on CPU with seq_len≈196, faster on GPU.
        # The figure is not rescaled for the detected device; the label below
        # only tells the reader which hardware was found.
        est_epoch_s = n_batches * 1.5
        device_label = "GPU" if torch.cuda.is_available() else "CPU"
        print(
            f"[run_pretrain] {n_train} train files, {n_batches} batches/epoch\n"
            f"[run_pretrain] estimated time on {device_label}: "
            f"~{est_epoch_s/60:.0f} min/epoch, "
            f"~{TRAIN_CFG.epochs * est_epoch_s / 3600:.1f} h total\n"
            f"[run_pretrain] (early stopping with patience={TRAIN_CFG.patience} may reduce this)\n"
        )
        train(TRAIN_CFG)
    else:
        if not LOG_CSV.exists():
            print(f"ERROR: log CSV not found: {LOG_CSV}", file=sys.stderr)
            sys.exit(1)
        print(f"[skip-training] using existing log: {LOG_CSV}")

    plot_curves(LOG_CSV, CURVES_PNG)
    write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)


if __name__ == "__main__":
    main()