feat: add dataset, prepare_data pipeline and fix McGill converter
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout split; logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real McGill v2 format (2-column annotation, each bar in its own pipe group, interval bass notation, e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,222 @@
|
||||
"""Tokenize .chord files into .pt tensors for model training.
|
||||
|
||||
Usage:
|
||||
python scripts/prepare_data.py --input-dir data/raw_user \\
|
||||
--output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42]
|
||||
|
||||
Arguments:
|
||||
--input-dir Root directory to search recursively for .chord files.
|
||||
--output-dir Output directory. Subdirs train/, val/, holdout/ are created.
|
||||
--split-ratios Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1.
|
||||
--seed Random seed for reproducible shuffling. Default: 42.
|
||||
--log-level Logging verbosity. Default: INFO.
|
||||
|
||||
Files found under any "holdout" directory within --input-dir are written to
|
||||
<output-dir>/holdout/ and never participate in the train/val split.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# Allow running as a script from the project root without installing the package.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from src.tokenizer import parse_chord_file, tokenize_period # noqa: E402
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_holdout(path: Path, input_dir: Path) -> bool:
    """Report whether *path*, taken relative to *input_dir*, has a 'holdout' component.

    A path that is not located inside *input_dir* is never treated as holdout.
    """
    try:
        relative = path.relative_to(input_dir)
    except ValueError:
        # relative_to() raises when path does not start with input_dir.
        return False
    else:
        return any(component == "holdout" for component in relative.parts)
|
||||
|
||||
|
||||
def _parse_ratios(s: str) -> tuple[float, float]:
    """Parse a ``"TRAIN/VAL"`` ratio string into a ``(train, val)`` tuple.

    Args:
        s: Two numbers separated by a single slash, e.g. ``"0.9/0.1"``.

    Returns:
        The train and validation ratios as floats.

    Raises:
        argparse.ArgumentTypeError: If *s* does not have exactly two parts,
            a part is not a number, a value lies outside ``[0, 1]`` (this
            also rejects NaN and infinities), or the values do not sum to
            1.0 within a 1e-6 tolerance.
    """
    parts = s.split("/")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            f"split-ratios must be TRAIN/VAL format, got {s!r}"
        )

    try:
        train_r, val_r = float(parts[0]), float(parts[1])
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"split-ratios values must be floats, got {s!r}"
        ) from exc

    # The range is tested positively because every comparison with NaN is
    # False: "nan/nan" would otherwise slip past the sum check below, and a
    # pair such as "1.5/-0.5" sums to 1.0 while being meaningless as a split.
    if not (0.0 <= train_r <= 1.0 and 0.0 <= val_r <= 1.0):
        raise argparse.ArgumentTypeError(
            f"split-ratios values must each be between 0 and 1, got {s!r}"
        )

    total = train_r + val_r
    if abs(total - 1.0) > 1e-6:
        raise argparse.ArgumentTypeError(
            f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}"
        )
    return train_r, val_r
|
||||
|
||||
|
||||
def _process_file(path: Path) -> dict | None:
    """Tokenize a single .chord file into a ``{"tokens", "meta"}`` record.

    ``tokens`` is a ``torch.long`` tensor of the token ids and ``meta`` holds
    the period's title, key, style, function and time plus the source path
    and the token count.

    Any exception raised while parsing, tokenizing or building the record is
    logged as a warning and reported as ``None`` instead of propagating.
    """
    try:
        parsed = parse_chord_file(path)
        token_ids = tokenize_period(parsed)
        record = {
            "tokens": torch.tensor(token_ids, dtype=torch.long),
            "meta": {
                "title": parsed.title,
                "key": parsed.key,
                "style": parsed.style,
                "function": parsed.function,
                "time": parsed.time,
                "source_file": str(path),
                "n_tokens": len(token_ids),
            },
        }
    except Exception as err:
        # Deliberate best-effort: report the file and let the caller move on.
        log.warning("Skipping %s: %s", path, err)
        return None
    return record
|
||||
|
||||
|
||||
def _save(data: dict, out_dir: Path, stem: str) -> None:
    """Serialize *data* with ``torch.save`` to ``<out_dir>/<stem>.pt``.

    A file already present at the destination is replaced, after a warning
    naming it has been logged.
    """
    destination = out_dir.joinpath(f"{stem}.pt")
    already_present = destination.exists()
    if already_present:
        log.warning("Overwriting existing output file: %s", destination)
    torch.save(data, destination)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _process_split(paths: list[Path], out_dir: Path) -> list[dict]:
    """Tokenize every file in *paths* and write each result into *out_dir*.

    Files for which ``_process_file`` returns ``None`` are skipped; the
    returned list holds only the records that were written.
    """
    records: list[dict] = []
    for path in paths:
        data = _process_file(path)
        if data is None:
            continue
        records.append(data)
        # NOTE: the output name is the input stem only, so two inputs with the
        # same stem that land in the same split write to the same .pt file.
        _save(data, out_dir, path.stem)
    return records


def _log_distribution(heading: str, counts: Counter[str]) -> None:
    """Log *heading*, then one ``label count`` line per entry, ordered by label."""
    log.info(heading)
    # Sorting on str(label) keeps the ordering of all-string labels unchanged
    # while not raising TypeError should a label be None.
    for label, count in sorted(counts.items(), key=lambda item: str(item[0])):
        log.info(" %-16s %d", label, count)


def main(argv: list[str] | None = None) -> None:
    """Run the .chord -> .pt preparation pipeline from the command line.

    Finds every ``*.chord`` file under ``--input-dir``, writes files that sit
    under a ``holdout`` directory to ``<output-dir>/holdout``, shuffles the
    remaining files with ``--seed`` and splits them into ``train`` and
    ``val`` according to ``--split-ratios``, then logs summary statistics.

    Args:
        argv: Command-line arguments without the program name. ``None``
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 1 when the input directory does not exist,
            and with argparse's status 2 for invalid arguments, including a
            malformed ``--split-ratios`` value.
    """
    parser = argparse.ArgumentParser(
        description="Tokenize .chord files into .pt tensors for model training.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--input-dir", required=True, type=Path,
        help="Root directory containing .chord files (searched recursively).",
    )
    parser.add_argument(
        "--output-dir", required=True, type=Path,
        help="Output directory; train/, val/, holdout/ subdirs are created.",
    )
    # _parse_ratios raises argparse.ArgumentTypeError, which argparse only
    # turns into a clean usage error when the function is used as ``type=``.
    # A string default is run through ``type`` as well.
    parser.add_argument(
        "--split-ratios", default="0.9/0.1", type=_parse_ratios,
        help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducible shuffling. Default: 42.",
    )
    parser.add_argument(
        "--log-level", default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity. Default: INFO.",
    )
    args = parser.parse_args(argv)

    logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s")

    # Only the train share is needed: validation receives whatever is left.
    train_ratio, _ = args.split_ratios

    input_dir: Path = args.input_dir.resolve()
    output_dir: Path = args.output_dir.resolve()

    if not input_dir.exists():
        log.error("Input directory does not exist: %s", input_dir)
        sys.exit(1)

    for subdir in ("train", "val", "holdout"):
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    # Sorted so that the seeded shuffle below starts from a stable order.
    all_files = sorted(input_dir.rglob("*.chord"))
    if not all_files:
        log.warning("No .chord files found in %s", input_dir)
        return

    # Single pass: each file goes to exactly one of the two lists.
    holdout_files: list[Path] = []
    regular_files: list[Path] = []
    for found in all_files:
        bucket = holdout_files if _is_holdout(found, input_dir) else regular_files
        bucket.append(found)

    log.info(
        "Found %d .chord files total (%d holdout, %d regular)",
        len(all_files), len(holdout_files), len(regular_files),
    )

    # --- Holdout ---
    holdout_records = _process_split(holdout_files, output_dir / "holdout")

    # --- Train / val split ---
    # A private Random instance gives the same permutation as seeding the
    # module-level generator, without altering global RNG state.
    shuffled = list(regular_files)
    random.Random(args.seed).shuffle(shuffled)

    n_train = round(len(shuffled) * train_ratio)
    train_records = _process_split(shuffled[:n_train], output_dir / "train")
    val_records = _process_split(shuffled[n_train:], output_dir / "val")

    # --- Stats ---
    all_records = train_records + val_records + holdout_records
    if not all_records:
        log.warning("No files were successfully processed.")
        return

    token_lengths = [r["meta"]["n_tokens"] for r in all_records]
    style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records)
    function_counts: Counter[str] = Counter(r["meta"]["function"] for r in all_records)

    log.info("--- Processing summary ---")
    log.info("Total processed: %d (train=%d, val=%d, holdout=%d)",
             len(all_records), len(train_records), len(val_records), len(holdout_records))

    skipped = len(all_files) - len(all_records)
    if skipped:
        log.warning("Skipped due to errors: %d", skipped)

    log.info("Token lengths: mean=%.1f, max=%d",
             sum(token_lengths) / len(token_lengths), max(token_lengths))

    _log_distribution("Style distribution:", style_counts)
    _log_distribution("Function distribution:", function_counts)
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: hand the command-line arguments (program name
    # excluded) to main(), which is what argparse reads by default anyway.
    main(sys.argv[1:])
|
||||
Reference in New Issue
Block a user