feat: write training report to file instead of stdout
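pretrain.py now writes its report to checkpoints/pretrained.report.txt, and train.py writes its report to checkpoints/finetuned.report.txt. Instead of the full report, only a single line, "[report] saved -> <path>", is printed to stdout.

Also replace the Unicode arrow character (→) in console messages with the ASCII "->", because it cannot be encoded by the Windows cp1251 console.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>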
This commit is contained in:
+33
-28
@@ -41,6 +41,7 @@ DATA_DIR = Path("data/processed/mcgill")
|
|||||||
CHECKPOINT = Path("checkpoints/pretrained.pt")
|
CHECKPOINT = Path("checkpoints/pretrained.pt")
|
||||||
LOG_CSV = Path("checkpoints/pretrained.log.csv")
|
LOG_CSV = Path("checkpoints/pretrained.log.csv")
|
||||||
CURVES_PNG = Path("checkpoints/pretrained_curves.png")
|
CURVES_PNG = Path("checkpoints/pretrained_curves.png")
|
||||||
|
REPORT_TXT = Path("checkpoints/pretrained.report.txt")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Training config (mirrors the requested CLI invocation)
|
# Training config (mirrors the requested CLI invocation)
|
||||||
@@ -98,14 +99,14 @@ def plot_curves(log_csv: Path, out_png: Path) -> None:
|
|||||||
out_png.parent.mkdir(parents=True, exist_ok=True)
|
out_png.parent.mkdir(parents=True, exist_ok=True)
|
||||||
fig.savefig(out_png, dpi=150)
|
fig.savefig(out_png, dpi=150)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
print(f"[plot] saved → {out_png}")
|
print(f"[plot] saved -> {out_png}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Report
|
# Report
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def print_report(log_csv: Path, checkpoint: Path) -> None:
|
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
|
||||||
rows = []
|
rows = []
|
||||||
with open(log_csv, newline="") as fh:
|
with open(log_csv, newline="") as fh:
|
||||||
rows = list(csv.DictReader(fh))
|
rows = list(csv.DictReader(fh))
|
||||||
@@ -118,50 +119,54 @@ def print_report(log_csv: Path, checkpoint: Path) -> None:
|
|||||||
best_idx = val_losses.index(min(val_losses))
|
best_idx = val_losses.index(min(val_losses))
|
||||||
best_row = rows[best_idx]
|
best_row = rows[best_idx]
|
||||||
|
|
||||||
# Convergence heuristic: first epoch where val loss is within 1 % of best
|
|
||||||
best_loss = float(best_row["val_loss"])
|
best_loss = float(best_row["val_loss"])
|
||||||
conv_epoch = next(
|
conv_epoch = next(
|
||||||
(int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01),
|
(int(r["epoch"]) for r in rows if float(r["val_loss"]) <= best_loss * 1.01),
|
||||||
int(best_row["epoch"]),
|
int(best_row["epoch"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parameter count from checkpoint
|
|
||||||
n_params = None
|
n_params = None
|
||||||
if checkpoint.exists():
|
if checkpoint.exists():
|
||||||
ckpt = torch.load(checkpoint, weights_only=True)
|
ckpt = torch.load(checkpoint, weights_only=True)
|
||||||
mcfg = ckpt["model_config"]
|
model = ChordTransformer(**ckpt["model_config"])
|
||||||
model = ChordTransformer(**mcfg)
|
|
||||||
tied = model.token_emb.weight.numel()
|
tied = model.token_emb.weight.numel()
|
||||||
n_params = sum(p.numel() for p in model.parameters()) - tied
|
n_params = sum(p.numel() for p in model.parameters()) - tied
|
||||||
|
|
||||||
print()
|
lines = []
|
||||||
print("=" * 52)
|
lines += [
|
||||||
print(" PRE-TRAINING REPORT")
|
"",
|
||||||
print("=" * 52)
|
"=" * 52,
|
||||||
print(f" Total epochs run : {len(rows)}")
|
" PRE-TRAINING REPORT",
|
||||||
print(f" Best epoch (val loss) : {best_row['epoch']}")
|
"=" * 52,
|
||||||
print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
|
f" Total epochs run : {len(rows)}",
|
||||||
print(f" Best val loss : {best_loss:.4f}")
|
f" Best epoch (val loss) : {best_row['epoch']}",
|
||||||
print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
|
f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
|
||||||
print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
|
f" Best val loss : {best_loss:.4f}",
|
||||||
|
f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
|
||||||
|
f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
|
||||||
|
]
|
||||||
if n_params is not None:
|
if n_params is not None:
|
||||||
print(f" Unique parameters : {n_params:,}")
|
lines.append(f" Unique parameters : {n_params:,}")
|
||||||
print(f" Checkpoint : {checkpoint}")
|
lines += [
|
||||||
print(f" Log CSV : {log_csv}")
|
f" Checkpoint : {checkpoint}",
|
||||||
print("=" * 52)
|
f" Log CSV : {log_csv}",
|
||||||
print()
|
"=" * 52,
|
||||||
|
"",
|
||||||
# Full epoch table for copy-paste into the report
|
f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
|
||||||
print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
|
f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
|
||||||
print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
|
]
|
||||||
for r in rows:
|
for r in rows:
|
||||||
marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
|
marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
|
||||||
print(
|
lines.append(
|
||||||
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
|
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
|
||||||
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
|
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
|
||||||
f" {float(r['lr']):>10.2e}{marker}"
|
f" {float(r['lr']):>10.2e}{marker}"
|
||||||
)
|
)
|
||||||
print()
|
lines.append("")
|
||||||
|
|
||||||
|
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
report_path.write_text("\n".join(lines), encoding="utf-8")
|
||||||
|
print(f"[report] saved -> {report_path}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -207,7 +212,7 @@ def main() -> None:
|
|||||||
print(f"[skip-training] using existing log: {LOG_CSV}")
|
print(f"[skip-training] using existing log: {LOG_CSV}")
|
||||||
|
|
||||||
plot_curves(LOG_CSV, CURVES_PNG)
|
plot_curves(LOG_CSV, CURVES_PNG)
|
||||||
print_report(LOG_CSV, CHECKPOINT)
|
write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+32
-23
@@ -44,6 +44,7 @@ INIT_FROM = Path("checkpoints/pretrained.pt")
|
|||||||
CHECKPOINT = Path("checkpoints/finetuned.pt")
|
CHECKPOINT = Path("checkpoints/finetuned.pt")
|
||||||
LOG_CSV = Path("checkpoints/finetuned.log.csv")
|
LOG_CSV = Path("checkpoints/finetuned.log.csv")
|
||||||
CURVES_PNG = Path("checkpoints/finetuned_curves.png")
|
CURVES_PNG = Path("checkpoints/finetuned_curves.png")
|
||||||
|
REPORT_TXT = Path("checkpoints/finetuned.report.txt")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Training config
|
# Training config
|
||||||
@@ -100,14 +101,14 @@ def plot_curves(log_csv: Path, out_png: Path) -> None:
|
|||||||
out_png.parent.mkdir(parents=True, exist_ok=True)
|
out_png.parent.mkdir(parents=True, exist_ok=True)
|
||||||
fig.savefig(out_png, dpi=150)
|
fig.savefig(out_png, dpi=150)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
print(f"[plot] saved → {out_png}")
|
print(f"[plot] saved -> {out_png}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Report
|
# Report
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def print_report(log_csv: Path, checkpoint: Path) -> None:
|
def write_report(log_csv: Path, checkpoint: Path, report_path: Path) -> None:
|
||||||
rows = []
|
rows = []
|
||||||
with open(log_csv, newline="") as fh:
|
with open(log_csv, newline="") as fh:
|
||||||
rows = list(csv.DictReader(fh))
|
rows = list(csv.DictReader(fh))
|
||||||
@@ -133,33 +134,41 @@ def print_report(log_csv: Path, checkpoint: Path) -> None:
|
|||||||
tied = model.token_emb.weight.numel()
|
tied = model.token_emb.weight.numel()
|
||||||
n_params = sum(p.numel() for p in model.parameters()) - tied
|
n_params = sum(p.numel() for p in model.parameters()) - tied
|
||||||
|
|
||||||
print()
|
lines = []
|
||||||
print("=" * 52)
|
lines += [
|
||||||
print(" FINE-TUNING REPORT")
|
"",
|
||||||
print("=" * 52)
|
"=" * 52,
|
||||||
print(f" Total epochs run : {len(rows)}")
|
" FINE-TUNING REPORT",
|
||||||
print(f" Best epoch (val loss) : {best_row['epoch']}")
|
"=" * 52,
|
||||||
print(f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)")
|
f" Total epochs run : {len(rows)}",
|
||||||
print(f" Best val loss : {best_loss:.4f}")
|
f" Best epoch (val loss) : {best_row['epoch']}",
|
||||||
print(f" Best val perplexity : {float(best_row['val_ppl']):.2f}")
|
f" Convergence epoch : {conv_epoch} (val ≤ best+1 %)",
|
||||||
print(f" Final train loss : {float(rows[-1]['train_loss']):.4f}")
|
f" Best val loss : {best_loss:.4f}",
|
||||||
|
f" Best val perplexity : {float(best_row['val_ppl']):.2f}",
|
||||||
|
f" Final train loss : {float(rows[-1]['train_loss']):.4f}",
|
||||||
|
]
|
||||||
if n_params is not None:
|
if n_params is not None:
|
||||||
print(f" Unique parameters : {n_params:,}")
|
lines.append(f" Unique parameters : {n_params:,}")
|
||||||
print(f" Checkpoint : {checkpoint}")
|
lines += [
|
||||||
print(f" Log CSV : {log_csv}")
|
f" Checkpoint : {checkpoint}",
|
||||||
print("=" * 52)
|
f" Log CSV : {log_csv}",
|
||||||
print()
|
"=" * 52,
|
||||||
|
"",
|
||||||
print(f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}")
|
f" {'epoch':>5} {'train':>8} {'val':>8} {'ppl':>7} {'lr':>10}",
|
||||||
print(f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}")
|
f" {'-'*5} {'-'*8} {'-'*8} {'-'*7} {'-'*10}",
|
||||||
|
]
|
||||||
for r in rows:
|
for r in rows:
|
||||||
marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
|
marker = " ←" if int(r["epoch"]) == int(best_row["epoch"]) else ""
|
||||||
print(
|
lines.append(
|
||||||
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
|
f" {int(r['epoch']):>5} {float(r['train_loss']):>8.4f}"
|
||||||
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
|
f" {float(r['val_loss']):>8.4f} {float(r['val_ppl']):>7.2f}"
|
||||||
f" {float(r['lr']):>10.2e}{marker}"
|
f" {float(r['lr']):>10.2e}{marker}"
|
||||||
)
|
)
|
||||||
print()
|
lines.append("")
|
||||||
|
|
||||||
|
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
report_path.write_text("\n".join(lines), encoding="utf-8")
|
||||||
|
print(f"[report] saved -> {report_path}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -207,7 +216,7 @@ def main() -> None:
|
|||||||
print(f"[skip-training] using existing log: {LOG_CSV}")
|
print(f"[skip-training] using existing log: {LOG_CSV}")
|
||||||
|
|
||||||
plot_curves(LOG_CSV, CURVES_PNG)
|
plot_curves(LOG_CSV, CURVES_PNG)
|
||||||
print_report(LOG_CSV, CHECKPOINT)
|
write_report(LOG_CSV, CHECKPOINT, REPORT_TXT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user