Files
hamori/scripts/prepare_data.py
T
H1K0 84ba7b4743 feat: add dataset, prepare_data pipeline and fix McGill converter
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout
  split, logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real
  McGill v2 format (2-column annotation, each bar in its own pipe group,
  interval bass notation e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 18:09:46 +03:00

223 lines
7.3 KiB
Python

"""Tokenize .chord files into .pt tensors for model training.
Usage:
    python scripts/prepare_data.py --input-dir data/raw_user \\
        --output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42]

Arguments:
    --input-dir     Root directory to search recursively for .chord files.
    --output-dir    Output directory. Subdirs train/, val/, holdout/ are created.
    --split-ratios  Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1.
    --seed          Random seed for reproducible shuffling. Default: 42.
    --log-level     Logging verbosity. Default: INFO.

Files found under any "holdout" directory within --input-dir are written to
<output-dir>/holdout/ and never participate in the train/val split.
"""
from __future__ import annotations
import argparse
import logging
import random
import sys
from collections import Counter
from pathlib import Path
import torch
# Allow running as a script from the project root without installing the package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.tokenizer import parse_chord_file, tokenize_period # noqa: E402
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_holdout(path: Path, input_dir: Path) -> bool:
    """Return True when *path* lies inside a directory named "holdout" below *input_dir*.

    Only the directory components of the path relative to *input_dir* are
    examined, so an entry whose own name is "holdout" is not treated as
    holdout data. Paths that are not located below *input_dir* are never
    holdout.

    Args:
        path: Path of the file to classify.
        input_dir: Root directory the relative position is computed against.

    Returns:
        True if any parent directory between *input_dir* and the file is
        named exactly "holdout", otherwise False.
    """
    try:
        rel = path.relative_to(input_dir)
    except ValueError:
        # relative_to() raises ValueError when path is not below input_dir.
        return False
    # rel.parent drops the final component (the file name itself), leaving
    # only the directories the file lives in.
    return "holdout" in rel.parent.parts
def _parse_ratios(s: str) -> tuple[float, float]:
    """Parse a "TRAIN/VAL" ratio string such as "0.9/0.1" into two floats.

    Args:
        s: Two float literals separated by a single "/".

    Returns:
        The pair ``(train_ratio, val_ratio)``.

    Raises:
        argparse.ArgumentTypeError: If *s* does not contain exactly two
            "/"-separated parts, a part is not a float, a ratio lies outside
            the range 0..1 (this includes NaN and infinities), or the two
            ratios do not sum to 1.0 within a tolerance of 1e-6.
    """
    parts = s.split("/")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            f"split-ratios must be TRAIN/VAL format, got {s!r}"
        )
    try:
        train_r, val_r = float(parts[0]), float(parts[1])
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"split-ratios values must be floats, got {s!r}"
        ) from exc
    # float() accepts "nan" and "inf". NaN fails every comparison, so it would
    # slip through the sum check below; the range check rejects it, together
    # with negative or >1 pairs such as "1.5/-0.5" that still sum to 1.0.
    if not (0.0 <= train_r <= 1.0 and 0.0 <= val_r <= 1.0):
        raise argparse.ArgumentTypeError(
            f"split-ratios values must each be between 0 and 1, got {s!r}"
        )
    total = train_r + val_r
    if abs(total - 1.0) > 1e-6:
        raise argparse.ArgumentTypeError(
            f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}"
        )
    return train_r, val_r
def _process_file(path: Path) -> dict | None:
    """Turn one .chord file into a training record.

    The file is parsed with ``parse_chord_file`` and the result is tokenized
    with ``tokenize_period``. On success the returned dict has two keys:
    "tokens", a ``torch.long`` tensor built from the token ids, and "meta",
    a dict carrying the parsed title, key, style, function and time together
    with the source path (as a string) and the number of tokens.

    Any exception raised along the way is logged as a warning and the file
    is skipped by returning None.
    """
    try:
        parsed = parse_chord_file(path)
        token_ids = tokenize_period(parsed)
        record = {
            "tokens": torch.tensor(token_ids, dtype=torch.long),
            "meta": {
                "title": parsed.title,
                "key": parsed.key,
                "style": parsed.style,
                "function": parsed.function,
                "time": parsed.time,
                "source_file": str(path),
                "n_tokens": len(token_ids),
            },
        }
    except Exception as err:
        # Best-effort pipeline: one bad file must not abort the whole run.
        log.warning("Skipping %s: %s", path, err)
        return None
    return record
def _save(data: dict, out_dir: Path, stem: str) -> None:
    """Write *data* with ``torch.save`` to ``<out_dir>/<stem>.pt``.

    If a file already exists at the destination, a warning is logged before
    it is written again.
    """
    destination = out_dir.joinpath(stem + ".pt")
    already_present = destination.exists()
    if already_present:
        log.warning("Overwriting existing output file: %s", destination)
    torch.save(data, destination)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the data preparation script."""
    parser = argparse.ArgumentParser(
        description="Tokenize .chord files into .pt tensors for model training.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--input-dir", required=True, type=Path,
        help="Root directory containing .chord files (searched recursively).",
    )
    parser.add_argument(
        "--output-dir", required=True, type=Path,
        help="Output directory; train/, val/, holdout/ subdirs are created.",
    )
    # _parse_ratios raises argparse.ArgumentTypeError, which argparse only
    # turns into a clean usage error when the function is used as ``type=``.
    # argparse also runs string defaults through ``type``, so the default is
    # validated and converted the same way as user input.
    parser.add_argument(
        "--split-ratios", default="0.9/0.1", type=_parse_ratios,
        help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducible shuffling. Default: 42.",
    )
    parser.add_argument(
        "--log-level", default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity. Default: INFO.",
    )
    return parser


def _process_group(paths: list[Path], out_dir: Path) -> list[dict]:
    """Tokenize each file in *paths*, save the successes to *out_dir*, return their records.

    Output files are named after ``path.stem`` only.
    NOTE(review): inputs are collected recursively, so two files with the same
    stem in different sub-directories map to the same output name; ``_save``
    logs a warning and the later file wins.
    """
    records: list[dict] = []
    for path in paths:
        data = _process_file(path)
        if data is not None:
            records.append(data)
            _save(data, out_dir, path.stem)
    return records


def _log_summary(
    n_found: int,
    train_records: list[dict],
    val_records: list[dict],
    holdout_records: list[dict],
) -> None:
    """Log per-split counts, skipped-file count, token lengths and style/function tallies."""
    all_records = train_records + val_records + holdout_records
    if not all_records:
        log.warning("No files were successfully processed.")
        return

    token_lengths = [r["meta"]["n_tokens"] for r in all_records]
    style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records)
    function_counts: Counter[str] = Counter(r["meta"]["function"] for r in all_records)

    log.info("--- Processing summary ---")
    log.info("Total processed: %d (train=%d, val=%d, holdout=%d)",
             len(all_records), len(train_records), len(val_records), len(holdout_records))
    # Every discovered file is either recorded or was skipped by _process_file.
    skipped = n_found - len(all_records)
    if skipped:
        log.warning("Skipped due to errors: %d", skipped)
    log.info("Token lengths: mean=%.1f, max=%d",
             sum(token_lengths) / len(token_lengths), max(token_lengths))
    log.info("Style distribution:")
    for style, count in sorted(style_counts.items()):
        log.info(" %-16s %d", style, count)
    log.info("Function distribution:")
    for func, count in sorted(function_counts.items()):
        log.info(" %-16s %d", func, count)


def main(argv: list[str] | None = None) -> None:
    """Tokenize all .chord files under --input-dir into train/val/holdout .pt files.

    Files located under a "holdout" directory are written to
    ``<output-dir>/holdout``. The remaining files are shuffled with the given
    seed and split into ``train`` and ``val`` according to --split-ratios.
    The split is decided before tokenization, so files that fail to process
    reduce the size of whichever split they were assigned to.

    Args:
        argv: Command-line arguments; when None, argparse reads ``sys.argv``.

    Exits with status 1 when the input directory does not exist. Invalid
    arguments (including a malformed --split-ratios value) are reported by
    argparse as a usage error.
    """
    args = _build_parser().parse_args(argv)
    logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s")
    train_ratio, _val_ratio = args.split_ratios

    input_dir: Path = args.input_dir.resolve()
    output_dir: Path = args.output_dir.resolve()
    if not input_dir.exists():
        log.error("Input directory does not exist: %s", input_dir)
        sys.exit(1)
    for subdir in ("train", "val", "holdout"):
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    # Sorting makes the seeded shuffle below independent of filesystem order.
    all_files = sorted(input_dir.rglob("*.chord"))
    if not all_files:
        log.warning("No .chord files found in %s", input_dir)
        return

    # Single pass: classify each file once instead of once per list.
    holdout_files: list[Path] = []
    regular_files: list[Path] = []
    for path in all_files:
        (holdout_files if _is_holdout(path, input_dir) else regular_files).append(path)
    log.info(
        "Found %d .chord files total (%d holdout, %d regular)",
        len(all_files), len(holdout_files), len(regular_files),
    )

    # --- Holdout ---
    holdout_records = _process_group(holdout_files, output_dir / "holdout")

    # --- Train / val split ---
    # A private Random instance seeded with the same value produces the same
    # permutation as random.seed() + random.shuffle(), without touching the
    # process-wide generator state.
    shuffled = list(regular_files)
    random.Random(args.seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train_ratio)
    train_records = _process_group(shuffled[:n_train], output_dir / "train")
    val_records = _process_group(shuffled[n_train:], output_dir / "val")

    # --- Stats ---
    _log_summary(len(all_files), train_records, val_records, holdout_records)


if __name__ == "__main__":
    main()