"""Tokenize .chord files into .pt tensors for model training.

Usage:
    python scripts/prepare_data.py --input-dir data/raw_user \\
        --output-dir data/processed [--split-ratios 0.9/0.1] [--seed 42]
        [--log-level INFO]

Arguments:
    --input-dir     Root directory to search recursively for .chord files.
    --output-dir    Output directory. Subdirs train/, val/, holdout/ are created.
    --split-ratios  Train/val ratio as "TRAIN/VAL", e.g. "0.8/0.2". Default: 0.9/0.1.
    --seed          Random seed for reproducible shuffling. Default: 42.
    --log-level     Logging verbosity. Default: INFO.

Files found under any "holdout" directory within --input-dir are written to
<output-dir>/holdout/ and never participate in the train/val split.
"""

from __future__ import annotations

import argparse
import logging
import random
import sys
from collections import Counter
from pathlib import Path

import torch

# Allow running as a script from the project root without installing the package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.tokenizer import parse_chord_file, tokenize_period  # noqa: E402

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _is_holdout(path: Path, input_dir: Path) -> bool:
    """Tell whether *path* sits below a component named "holdout".

    Args:
        path: Path to classify.
        input_dir: Root the path is expected to live under.

    Returns:
        True when any component of *path* relative to *input_dir* is exactly
        "holdout" (case-sensitive). False otherwise, including when *path* is
        not located under *input_dir* at all.
    """
    try:
        rel = path.relative_to(input_dir)
    except ValueError:
        # Not under input_dir, so it cannot be in one of its holdout folders.
        return False
    return "holdout" in rel.parts


def _parse_ratios(s: str) -> tuple[float, float]:
    """Parse a "TRAIN/VAL" split specification such as "0.9/0.1".

    Args:
        s: Two float literals separated by a single "/".

    Returns:
        The ``(train_ratio, val_ratio)`` pair.

    Raises:
        argparse.ArgumentTypeError: If *s* does not have exactly two parts, a
            part is not a float, a value lies outside [0, 1] (this also
            rejects NaN and infinities), or the two values do not sum to 1.0
            within 1e-6.
    """
    parts = s.split("/")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            f"split-ratios must be TRAIN/VAL format, got {s!r}"
        )

    try:
        train_r, val_r = float(parts[0]), float(parts[1])
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"split-ratios values must be floats, got {s!r}"
        ) from exc

    # The range test must come before the sum test: "nan" parses as a float
    # and every comparison against NaN is False, so a bare
    # ``abs(total - 1.0) > 1e-6`` check would let it through. It also rejects
    # pairs like 1.5/-0.5 that happen to sum to 1.
    if not (0.0 <= train_r <= 1.0 and 0.0 <= val_r <= 1.0):
        raise argparse.ArgumentTypeError(
            f"split-ratios values must each be between 0 and 1, got {s!r}"
        )

    total = train_r + val_r
    if abs(total - 1.0) > 1e-6:
        raise argparse.ArgumentTypeError(
            f"split-ratios must sum to 1.0, got {train_r}+{val_r}={total:.6f}"
        )
    return train_r, val_r


def _process_file(path: Path) -> dict | None: """Parse and tokenize one .chord file. Returns None on any error.""" try: period = parse_chord_file(path) ids = tokenize_period(period) tokens = torch.tensor(ids, dtype=torch.long) meta = { "title": period.title, "key": period.key, "style": period.style, "function": period.function, "time": period.time, "source_file": str(path), "n_tokens": len(ids), } return {"tokens": tokens, "meta": meta} except Exception as exc: log.warning("Skipping %s: %s", path, exc) return None def _save(data: dict, out_dir: Path, stem: str) -> None: out_path = out_dir / f"{stem}.pt" if out_path.exists(): log.warning("Overwriting existing output file: %s", out_path) torch.save(data, out_path) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(argv: list[str] | None = None) -> None: parser = argparse.ArgumentParser( description="Tokenize .chord files into .pt tensors for model training.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) parser.add_argument( "--input-dir", required=True, type=Path, help="Root directory containing .chord files (searched recursively).", ) parser.add_argument( "--output-dir", required=True, type=Path, help="Output directory; train/, val/, holdout/ subdirs are created.", ) parser.add_argument( "--split-ratios", default="0.9/0.1", help="Train/val split, e.g. '0.8/0.2'. Must sum to 1.0. Default: 0.9/0.1.", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for reproducible shuffling. Default: 42.", ) parser.add_argument( "--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Logging verbosity. 
Default: INFO.", ) args = parser.parse_args(argv) logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s") train_ratio, _val_ratio = _parse_ratios(args.split_ratios) input_dir: Path = args.input_dir.resolve() output_dir: Path = args.output_dir.resolve() if not input_dir.exists(): log.error("Input directory does not exist: %s", input_dir) sys.exit(1) for subdir in ("train", "val", "holdout"): (output_dir / subdir).mkdir(parents=True, exist_ok=True) all_files = sorted(input_dir.rglob("*.chord")) if not all_files: log.warning("No .chord files found in %s", input_dir) return holdout_files = [f for f in all_files if _is_holdout(f, input_dir)] regular_files = [f for f in all_files if not _is_holdout(f, input_dir)] log.info( "Found %d .chord files total (%d holdout, %d regular)", len(all_files), len(holdout_files), len(regular_files), ) # --- Holdout --- holdout_records: list[dict] = [] for path in holdout_files: data = _process_file(path) if data is not None: holdout_records.append(data) _save(data, output_dir / "holdout", path.stem) # --- Train / val split --- random.seed(args.seed) shuffled = list(regular_files) random.shuffle(shuffled) n_train = round(len(shuffled) * train_ratio) train_paths = shuffled[:n_train] val_paths = shuffled[n_train:] train_records: list[dict] = [] for path in train_paths: data = _process_file(path) if data is not None: train_records.append(data) _save(data, output_dir / "train", path.stem) val_records: list[dict] = [] for path in val_paths: data = _process_file(path) if data is not None: val_records.append(data) _save(data, output_dir / "val", path.stem) # --- Stats --- all_records = train_records + val_records + holdout_records if not all_records: log.warning("No files were successfully processed.") return token_lengths = [r["meta"]["n_tokens"] for r in all_records] style_counts: Counter[str] = Counter(r["meta"]["style"] for r in all_records) function_counts: Counter[str] = Counter(r["meta"]["function"] 
for r in all_records) log.info("--- Processing summary ---") log.info("Total processed: %d (train=%d, val=%d, holdout=%d)", len(all_records), len(train_records), len(val_records), len(holdout_records)) skipped = len(all_files) - len(all_records) if skipped: log.warning("Skipped due to errors: %d", skipped) log.info("Token lengths: mean=%.1f, max=%d", sum(token_lengths) / len(token_lengths), max(token_lengths)) log.info("Style distribution:") for style, count in sorted(style_counts.items()): log.info(" %-16s %d", style, count) log.info("Function distribution:") for func, count in sorted(function_counts.items()): log.info(" %-16s %d", func, count) if __name__ == "__main__": main()