feat: add run_pretrain.py; fix output-path naming and max_seq_len
- scripts/run_pretrain.py: single-command pre-training runner with timing estimate, loss-curve plot (matplotlib), and per-epoch report. Sets max_seq_len=256 (McGill sequences max out at 195 tokens, ~4x faster attention than the 512 default). - src/train.py: normalise --output so pretrained.pt and pretrained both produce pretrained.pt + pretrained.log.csv (not pretrained.pt.log.csv). Serialize Path fields as strings in checkpoint to satisfy weights_only. - requirements.txt: drop unused pandas/music21, add mido (pretty_midi dep). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+4
-2
@@ -267,8 +267,10 @@ def train(cfg: TrainConfig) -> Path:
|
||||
# ------------------------------------------------------------------
|
||||
output_path = Path(cfg.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint_path = output_path.with_suffix(".pt")
|
||||
log_csv_path = Path(str(output_path) + ".log.csv")
|
||||
# Normalise: --output may be given with or without the .pt extension.
|
||||
stem = output_path.with_suffix("") if output_path.suffix == ".pt" else output_path
|
||||
checkpoint_path = stem.with_suffix(".pt")
|
||||
log_csv_path = Path(str(stem) + ".log.csv")
|
||||
|
||||
with open(log_csv_path, "w", newline="") as fh:
|
||||
writer = csv.writer(fh)
|
||||
|
||||
Reference in New Issue
Block a user