feat: add dataset, prepare_data pipeline and fix McGill converter
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate - scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout split, logs token length stats and style/function distributions - src/external_converters/mcgill_to_chord.py: rewrite parser for real McGill v2 format (2-column annotation, each bar in its own pipe group, interval bass notation e.g. /5 and /b3) - .gitignore: exclude data/processed/train, val, holdout subdirectories - tests: 37 new tests for ChordDataset and converter (260 total, all pass) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
"""PyTorch Dataset for tokenized .chord period files.
|
||||
|
||||
Public API:
|
||||
ChordDataset — Dataset that loads pre-tokenized .pt files from a directory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from src.tokenizer import TOKEN_TO_ID
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
|
||||
|
||||
|
||||
class ChordDataset(Dataset):
|
||||
"""Dataset over a directory of tokenized .pt period files.
|
||||
|
||||
Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``.
|
||||
``__getitem__`` returns a fixed-length LongTensor: the token sequence is
|
||||
truncated to *max_length* if too long, or right-padded with <PAD> if short.
|
||||
|
||||
Args:
|
||||
data_dir: Directory containing .pt files (non-recursive).
|
||||
max_length: Fixed output sequence length. Default 512.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Path, max_length: int = 512) -> None:
|
||||
self._max_length = max_length
|
||||
self._files: list[Path] = sorted(Path(data_dir).glob("*.pt"))
|
||||
if not self._files:
|
||||
log.warning("ChordDataset: no .pt files found in %s", data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._files)
|
||||
|
||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||
data = torch.load(self._files[idx], weights_only=True)
|
||||
tokens: torch.Tensor = data["tokens"]
|
||||
|
||||
length = tokens.shape[0]
|
||||
if length >= self._max_length:
|
||||
return tokens[: self._max_length]
|
||||
|
||||
pad = torch.full((self._max_length - length,), _PAD_ID, dtype=tokens.dtype)
|
||||
return torch.cat([tokens, pad])
|
||||
Reference in New Issue
Block a user