84ba7b4743
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate - scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout split, logs token length stats and style/function distributions - src/external_converters/mcgill_to_chord.py: rewrite parser for real McGill v2 format (2-column annotation, each bar in its own pipe group, interval bass notation e.g. /5 and /b3) - .gitignore: exclude data/processed/train, val, holdout subdirectories - tests: 37 new tests for ChordDataset and converter (260 total, all pass) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
223 lines
7.3 KiB
Python
223 lines
7.3 KiB
Python
"""Tokenize .chord files into .pt tensors for model training.
|
|
|
|
Usage:
|
|
python scripts/prepare_data.py --input-dir data/raw_user \\
|
|
--output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42]
|
|
|
|
Arguments:
|
|
--input-dir Root directory to search recursively for .chord files.
|
|
--output-dir Output directory. Subdirs train/, val/, holdout/ are created.
|
|
--split-ratios Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1.
|
|
--seed Random seed for reproducible shuffling. Default: 42.
|
|
--log-level Logging verbosity. Default: INFO.
|
|
|
|
Files found under any "holdout" directory within --input-dir are written to
|
|
<output-dir>/holdout/ and never participate in the train/val split.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import logging
|
|
import random
|
|
import sys
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
# Allow running as a script from the project root without installing the package.
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from src.tokenizer import parse_chord_file, tokenize_period # noqa: E402
|
|
|
|
# Module-level logger; level and format are set by logging.basicConfig() in main().
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _is_holdout(path: Path, input_dir: Path) -> bool:
|
|
"""True when the path lives under a 'holdout' sub-directory of input_dir."""
|
|
try:
|
|
rel = path.relative_to(input_dir)
|
|
except ValueError:
|
|
return False
|
|
return "holdout" in rel.parts
|
|
|
|
|
|
def _parse_ratios(s: str) -> tuple[float, float]:
|
|
parts = s.split("/")
|
|
if len(parts) != 2:
|
|
raise argparse.ArgumentTypeError(
|
|
f"split-ratios must be TRAIN/VAL format, got {s!r}"
|
|
)
|
|
try:
|
|
train_r, val_r = float(parts[0]), float(parts[1])
|
|
except ValueError:
|
|
raise argparse.ArgumentTypeError(
|
|
f"split-ratios values must be floats, got {s!r}"
|
|
)
|
|
total = train_r + val_r
|
|
if abs(total - 1.0) > 1e-6:
|
|
raise argparse.ArgumentTypeError(
|
|
f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}"
|
|
)
|
|
return train_r, val_r
|
|
|
|
|
|
def _process_file(path: Path) -> dict | None:
    """Turn one .chord file into a training record, or None if that fails.

    On success the record has two keys:
      * ``"tokens"`` -- a ``torch.long`` tensor built from the ids returned by
        ``tokenize_period``;
      * ``"meta"``   -- title/key/style/function/time copied from the parsed
        period, plus the source path (as str) and the number of token ids.

    Any ``Exception`` raised while parsing, tokenizing or building the record is
    logged as a warning and converted into a ``None`` return, so a single bad
    file does not abort a batch run.
    """
    try:
        period = parse_chord_file(path)
        token_ids = tokenize_period(period)
        record = {
            "tokens": torch.tensor(token_ids, dtype=torch.long),
            "meta": {
                "title": period.title,
                "key": period.key,
                "style": period.style,
                "function": period.function,
                "time": period.time,
                "source_file": str(path),
                "n_tokens": len(token_ids),
            },
        }
    except Exception as exc:
        log.warning("Skipping %s: %s", path, exc)
        return None
    return record
|
|
|
|
|
|
def _save(data: dict, out_dir: Path, stem: str) -> None:
|
|
out_path = out_dir / f"{stem}.pt"
|
|
if out_path.exists():
|
|
log.warning("Overwriting existing output file: %s", out_path)
|
|
torch.save(data, out_path)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _unique_stem(stem: str, taken: set[str]) -> str:
    """Return *stem*, or ``<stem>__<n>`` for the smallest n >= 1 not in *taken*.

    The chosen name is added to *taken* so that later calls treat it as used.
    """
    candidate = stem
    suffix = 1
    while candidate in taken:
        candidate = f"{stem}__{suffix}"
        suffix += 1
    taken.add(candidate)
    return candidate


def _process_split(paths: list[Path], out_dir: Path) -> list[dict]:
    """Tokenize each of *paths* and save every success as a ``.pt`` file in *out_dir*.

    Files that fail are skipped (``_process_file`` logs them). Input files that
    share a stem, e.g. ``a/song.chord`` and ``b/song.chord``, get distinct
    output names instead of silently overwriting each other within one run.

    Returns:
        The saved records, in the order of *paths*.
    """
    records: list[dict] = []
    taken: set[str] = set()
    for path in paths:
        data = _process_file(path)
        if data is None:
            continue
        out_stem = _unique_stem(path.stem, taken)
        if out_stem != path.stem:
            log.warning(
                "Output name %s.pt already used in this run; saving %s as %s.pt",
                path.stem, path, out_stem,
            )
        records.append(data)
        _save(data, out_dir, out_stem)
    return records


def _log_summary(
    n_found: int,
    train_records: list[dict],
    val_records: list[dict],
    holdout_records: list[dict],
) -> None:
    """Log split sizes, skip count, token-length stats and style/function counts."""
    all_records = train_records + val_records + holdout_records
    if not all_records:
        log.warning("No files were successfully processed.")
        return

    token_lengths = [r["meta"]["n_tokens"] for r in all_records]
    style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records)
    function_counts: Counter[str] = Counter(r["meta"]["function"] for r in all_records)

    log.info("--- Processing summary ---")
    log.info("Total processed: %d (train=%d, val=%d, holdout=%d)",
             len(all_records), len(train_records), len(val_records), len(holdout_records))

    skipped = n_found - len(all_records)
    if skipped:
        log.warning("Skipped due to errors: %d", skipped)

    log.info("Token lengths: mean=%.1f, max=%d",
             sum(token_lengths) / len(token_lengths), max(token_lengths))

    # Sort on str(value): these metadata fields come straight from the parser
    # and may not all be str (e.g. None for a missing field -- not verifiable
    # here). Sorting mixed types raises TypeError at the very end of a run.
    # For all-str values the order is unchanged.
    for heading, counts in (("Style distribution:", style_counts),
                            ("Function distribution:", function_counts)):
        log.info(heading)
        for value, count in sorted(counts.items(), key=lambda kv: str(kv[0])):
            log.info(" %-16s %d", value, count)


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: tokenize .chord files into train/val/holdout .pt files.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (handy for
            tests). ``None`` means use the real command line.

    Files under a ``holdout`` directory of ``--input-dir`` go to
    ``<output-dir>/holdout``. The rest are shuffled with ``--seed`` and split
    into ``train``/``val`` according to ``--split-ratios``.

    Exits with status 1 if ``--input-dir`` is not an existing directory. Exits
    through argparse's usage error for malformed arguments, including an
    invalid ``--split-ratios``.
    """
    parser = argparse.ArgumentParser(
        description="Tokenize .chord files into .pt tensors for model training.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--input-dir", required=True, type=Path,
        help="Root directory containing .chord files (searched recursively).",
    )
    parser.add_argument(
        "--output-dir", required=True, type=Path,
        help="Output directory; train/, val/, holdout/ subdirs are created.",
    )
    # type=_parse_ratios lets argparse validate the value (the string default
    # included) and report ArgumentTypeError as a normal usage error. Parsing
    # it after parse_args() surfaced the error as a raw traceback instead.
    parser.add_argument(
        "--split-ratios", default="0.9/0.1", type=_parse_ratios,
        help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducible shuffling. Default: 42.",
    )
    parser.add_argument(
        "--log-level", default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity. Default: INFO.",
    )
    args = parser.parse_args(argv)

    logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s")

    train_ratio, _val_ratio = args.split_ratios

    input_dir: Path = args.input_dir.resolve()
    output_dir: Path = args.output_dir.resolve()

    # is_dir() rather than exists(): a regular file passed here would otherwise
    # be reported as "No .chord files found" instead of as a bad argument.
    if not input_dir.is_dir():
        log.error("Input directory does not exist or is not a directory: %s", input_dir)
        sys.exit(1)

    split_dirs = {name: output_dir / name for name in ("train", "val", "holdout")}
    for split_dir in split_dirs.values():
        split_dir.mkdir(parents=True, exist_ok=True)

    # Outputs from earlier runs are never deleted. If the seed, ratios or inputs
    # changed since then, a stale copy of a song can stay in train/ while this
    # run writes it to val/ (train/val leakage), so make that visible.
    n_existing = sum(1 for d in split_dirs.values() for _ in d.glob("*.pt"))
    if n_existing:
        log.warning(
            "%d .pt files already exist under %s and are not removed; "
            "use an empty --output-dir to avoid mixing runs across splits.",
            n_existing, output_dir,
        )

    all_files = sorted(p for p in input_dir.rglob("*.chord") if p.is_file())
    if not all_files:
        log.warning("No .chord files found in %s", input_dir)
        return

    holdout_files: list[Path] = []
    regular_files: list[Path] = []
    for path in all_files:
        if _is_holdout(path, input_dir):
            holdout_files.append(path)
        else:
            regular_files.append(path)

    log.info(
        "Found %d .chord files total (%d holdout, %d regular)",
        len(all_files), len(holdout_files), len(regular_files),
    )

    # --- Holdout ---
    holdout_records = _process_split(holdout_files, split_dirs["holdout"])

    # --- Train / val split ---
    # A private RNG keeps the shuffle reproducible without reseeding the global
    # `random` module for callers that invoke main() in-process.
    # random.Random(seed).shuffle yields the same order as seed() + shuffle().
    rng = random.Random(args.seed)
    shuffled = list(regular_files)
    rng.shuffle(shuffled)

    n_train = round(len(shuffled) * train_ratio)
    train_records = _process_split(shuffled[:n_train], split_dirs["train"])
    val_records = _process_split(shuffled[n_train:], split_dirs["val"])

    # --- Stats ---
    _log_summary(len(all_files), train_records, val_records, holdout_records)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: argv=None makes argparse read the real sys.argv[1:].
    main(None)
|