84ba7b4743
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate - scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout split, logs token length stats and style/function distributions - src/external_converters/mcgill_to_chord.py: rewrite parser for real McGill v2 format (2-column annotation, each bar in its own pipe group, interval bass notation e.g. /5 and /b3) - .gitignore: exclude data/processed/train, val, holdout subdirectories - tests: 37 new tests for ChordDataset and converter (260 total, all pass) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
53 lines
1.6 KiB
Python
53 lines
1.6 KiB
Python
"""PyTorch Dataset for tokenized .chord period files.
|
|
|
|
Public API:
|
|
ChordDataset — Dataset that loads pre-tokenized .pt files from a directory.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from src.tokenizer import TOKEN_TO_ID
|
|
|
|
log = logging.getLogger(__name__)

# Token id used by ChordDataset.__getitem__ to right-pad sequences shorter
# than max_length. Resolved once at import time from the tokenizer vocabulary.
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
|
|
|
|
|
|
class ChordDataset(Dataset):
    """Dataset over a directory of tokenized .pt period files.

    Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``.
    ``__getitem__`` returns a fixed-length 1-D tensor: the token sequence is
    truncated to *max_length* if too long, or right-padded with <PAD> if short.
    Files are discovered once at construction time and loaded lazily, one per
    ``__getitem__`` call.

    Args:
        data_dir: Directory containing .pt files (non-recursive). Accepts a
            ``str`` or ``Path``.
        max_length: Fixed output sequence length. Must be >= 1. Default 512.

    Raises:
        ValueError: If *max_length* is less than 1.
    """

    def __init__(self, data_dir: str | Path, max_length: int = 512) -> None:
        # A non-positive length would silently produce empty (0) or wrongly
        # sliced (negative) outputs in __getitem__, so reject it up front.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self._max_length = max_length
        # Sorted so the index -> file mapping is deterministic; raw glob
        # order depends on the filesystem.
        self._files: list[Path] = sorted(Path(data_dir).glob("*.pt"))
        if not self._files:
            log.warning("ChordDataset: no .pt files found in %s", data_dir)

    def __len__(self) -> int:
        """Return the number of .pt files discovered at construction."""
        return len(self._files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load file *idx* and return its tokens truncated/padded to max_length.

        The returned tensor keeps the dtype and device of the stored
        ``tokens`` tensor.
        """
        # weights_only=True limits unpickling to tensors and plain containers,
        # so loading a .pt file cannot execute arbitrary code.
        data = torch.load(self._files[idx], weights_only=True)
        tokens: torch.Tensor = data["tokens"]

        length = tokens.shape[0]
        if length >= self._max_length:
            return tokens[: self._max_length]

        # Create the padding on the same device as tokens: torch.load restores
        # tensors to their saved device, and torch.cat rejects mixed devices.
        pad = torch.full(
            (self._max_length - length,),
            _PAD_ID,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        return torch.cat([tokens, pad])
|