From 03b464973a3891e190bcdbb50c238809d5814cfd Mon Sep 17 00:00:00 2001 From: Masahiko AMANO Date: Wed, 20 May 2026 12:40:44 +0300 Subject: [PATCH] feat: write training report to file instead of stdout pretrain.py -> checkpoints/pretrained.report.txt train.py -> checkpoints/finetuned.report.txt Single-line [report] saved -> printed to stdout instead. Also fix arrow character incompatible with Windows cp1251 console. Co-Authored-By: Claude Sonnet 4.6 --- scripts/pretrain.py | 61 ++++++++++++++++++++++++--------------------- scripts/train.py | 55 +++++++++++++++++++++++----------------- 2 files changed, 65 insertions(+), 51 deletions(-) diff --git a/scripts/pretrain.py b/scripts/pretrain.py index 1399420..ab07cea 100644 --- a/scripts/pretrain.py +++ b/scripts/pretrain.py @@ -41,6 +41,7 @@ DATA_DIR = Path("data/processed/mcgill") CHECKPOINT = Path("checkpoints/pretrained.pt") LOG_CSV = Path("checkpoints/pretrained.log.csv") CURVES_PNG = Path("checkpoints/pretrained_curves.png") +REPORT_TXT = Path("checkpoints/pretrained.report.txt") # --------------------------------------------------------------------------- # Training config (mirrors the requested CLI invocation) @@ -98,14 +99,14 @@ def plot_curves(log_csv: Path, out_png: Path) -> None: out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=150) plt.close(fig) - print(f"[plot] saved → {out_png}") + print(f"[plot] saved -> {out_png}") # --------------------------------------------------------------------------- # Report # --------------------------------------------------------------------------- -def print_report(log_csv: Path, checkpoint: Path) -> None: +def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None: rows = [] with open(log_csv, newline="") as fh: rows = list(csv.DictReader(fh)) @@ -118,50 +119,54 @@ def print_report(log_csv: Path, checkpoint: Path) -> None: best_idx = val_losses.index(min(val_losses)) best_row = rows[best_idx] - # Convergence heuristic: first epoch where val loss is within 1 % of best best_loss = float(best_row["val_loss"]) conv_epoch = next( (int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01), int(best_row["epoch"]), ) - # Parameter count from checkpoint n_params = None if checkpoint.exists(): ckpt = torch.load(checkpoint, weights_only=True) - mcfg = ckpt["model_config"] - model = ChordTransformer(**mcfg) + model = ChordTransformer(**ckpt["model_config"]) tied = model.token_emb.weight.numel() n_params = sum(p.numel() for p in model.parameters()) - tied - print() - print("=" * 52) - print(" PRE-TRAINING REPORT") - print("=" * 52) - print(f" Total epochs run : {len(rows)}") - print(f" Best epoch (val loss) : {best_row['epoch']}") - print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)") - print(f" Best val loss : {best_loss:.4f}") - print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}") - print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}") + lines = [] + lines += [ + "", + "=" * 52, + " PRE-TRAINING REPORT", + "=" * 52, + f" Total epochs run : {len(rows)}", + f" Best epoch (val loss) : {best_row['epoch']}", + f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)", + f" Best val loss : {best_loss:.4f}", + f" Best val perplexity : {float(best_row['val_ppl']):.2f}", + f" Final train loss : {float(rows[-1]['train_loss']):.4f}", + ] if n_params is not None: - print(f" Unique parameters : {n_params:,}") - print(f" Checkpoint : {checkpoint}") - print(f" Log CSV : {log_csv}") - print("=" * 52) - print() - - # Full epoch table for copy-paste into the report - print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}") - print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}") + lines.append(f" Unique parameters : {n_params:,}") + lines += [ + f" Checkpoint : {checkpoint}", + f" Log CSV : {log_csv}", + "=" * 52, + "", + f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}", + f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}", + ] for r in rows: marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else "" - print( + lines.append( f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}" f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}" f" {float(r['lr']):>10.2e}{marker}" ) - print() + lines.append("") + + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text("\n".join(lines), encoding="utf-8") + print(f"[report] saved -> {report_path}") # --------------------------------------------------------------------------- @@ -207,7 +212,7 @@ def main() -> None: print(f"[skip-training] using existing log: {LOG_CSV}") plot_curves(LOG_CSV, CURVES_PNG) - print_report(LOG_CSV, CHECKPOINT) + write_report(LOG_CSV, CHECKPOINT, REPORT_TXT) if __name__ == "__main__": diff --git a/scripts/train.py b/scripts/train.py index 5e48fab..ed729a0 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -44,6 +44,7 @@ INIT_FROM = Path("checkpoints/pretrained.pt") CHECKPOINT = Path("checkpoints/finetuned.pt") LOG_CSV = Path("checkpoints/finetuned.log.csv") CURVES_PNG = Path("checkpoints/finetuned_curves.png") +REPORT_TXT = Path("checkpoints/finetuned.report.txt") # --------------------------------------------------------------------------- # Training config @@ -100,14 +101,14 @@ def plot_curves(log_csv: Path, out_png: Path) -> None: out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=150) plt.close(fig) - print(f"[plot] saved → {out_png}") + print(f"[plot] saved -> {out_png}") # --------------------------------------------------------------------------- # Report # --------------------------------------------------------------------------- -def print_report(log_csv: Path, checkpoint: Path) -> None: +def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None: rows = [] with open(log_csv, newline="") as fh: rows = list(csv.DictReader(fh)) @@ -133,33 +134,41 @@ def print_report(log_csv: Path, checkpoint: Path) -> None: tied = model.token_emb.weight.numel() n_params = sum(p.numel() for p in model.parameters()) - tied - print() - print("=" * 52) - print(" FINE-TUNING REPORT") - print("=" * 52) - print(f" Total epochs run : {len(rows)}") - print(f" Best epoch (val loss) : {best_row['epoch']}") - print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)") - print(f" Best val loss : {best_loss:.4f}") - print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}") - print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}") + lines = [] + lines += [ + "", + "=" * 52, + " FINE-TUNING REPORT", + "=" * 52, + f" Total epochs run : {len(rows)}", + f" Best epoch (val loss) : {best_row['epoch']}", + f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)", + f" Best val loss : {best_loss:.4f}", + f" Best val perplexity : {float(best_row['val_ppl']):.2f}", + f" Final train loss : {float(rows[-1]['train_loss']):.4f}", + ] if n_params is not None: - print(f" Unique parameters : {n_params:,}") - print(f" Checkpoint : {checkpoint}") - print(f" Log CSV : {log_csv}") - print("=" * 52) - print() - - print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}") - print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}") + lines.append(f" Unique parameters : {n_params:,}") + lines += [ + f" Checkpoint : {checkpoint}", + f" Log CSV : {log_csv}", + "=" * 52, + "", + f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}", + f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}", + ] for r in rows: marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else "" - print( + lines.append( f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}" f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}" f" {float(r['lr']):>10.2e}{marker}" ) - print() + lines.append("") + + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text("\n".join(lines), encoding="utf-8") + print(f"[report] saved -> {report_path}") # --------------------------------------------------------------------------- @@ -207,7 +216,7 @@ def main() -> None: print(f"[skip-training] using existing log: {LOG_CSV}") plot_curves(LOG_CSV, CURVES_PNG) - print_report(LOG_CSV, CHECKPOINT) + write_report(LOG_CSV, CHECKPOINT, REPORT_TXT) if __name__ == "__main__":