Files
hamori/src/dataset.py
T
H1K0 84ba7b4743 feat: add dataset, prepare_data pipeline and fix McGill converter
- src/dataset.py: ChordDataset wrapping .pt files with pad/truncate
- scripts/prepare_data.py: tokenize .chord to .pt with train/val/holdout
  split, logs token length stats and style/function distributions
- src/external_converters/mcgill_to_chord.py: rewrite parser for real
  McGill v2 format (2-column annotation, each bar in its own pipe group,
  interval bass notation e.g. /5 and /b3)
- .gitignore: exclude data/processed/train, val, holdout subdirectories
- tests: 37 new tests for ChordDataset and converter (260 total, all pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 18:09:46 +03:00

53 lines
1.6 KiB
Python

"""PyTorch Dataset for tokenized .chord period files.

Public API:
    ChordDataset — Dataset that loads pre-tokenized .pt files from a
    directory and returns fixed-length (padded/truncated) token tensors.
"""
from __future__ import annotations
import logging
from pathlib import Path
import torch
from torch.utils.data import Dataset
from src.tokenizer import TOKEN_TO_ID
# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(name=__name__)

# Vocabulary id of the "<PAD>" token; ChordDataset right-pads short
# sequences with it. Resolved once at import time.
_PAD_ID: int = TOKEN_TO_ID["<PAD>"]
class ChordDataset(Dataset):
    """Dataset over a directory of tokenized .pt period files.

    Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``
    whose ``tokens`` entry is a 1-D tensor of token ids.
    ``__getitem__`` returns a fixed-length 1-D tensor: the token sequence is
    truncated to *max_length* if too long, or right-padded with ``<PAD>`` if
    short. The dtype and device of the stored ``tokens`` tensor are preserved.

    Files are discovered once, at construction time, with a non-recursive
    ``*.pt`` glob and kept in sorted order so indices are stable.

    Args:
        data_dir: Directory containing .pt files (non-recursive). A ``str``
            is accepted too and converted to :class:`~pathlib.Path`.
        max_length: Fixed output sequence length. Must be >= 1. Default 512.

    Raises:
        ValueError: If *max_length* is less than 1.
    """

    def __init__(self, data_dir: str | Path, max_length: int = 512) -> None:
        # A non-positive length would yield empty tensors (0) or negative
        # slicing / negative pad sizes (< 0); reject it up front.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self._max_length = max_length
        self._files: list[Path] = sorted(Path(data_dir).glob("*.pt"))
        if not self._files:
            log.warning("ChordDataset: no .pt files found in %s", data_dir)

    def __len__(self) -> int:
        """Return the number of .pt files discovered in *data_dir*."""
        return len(self._files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load file *idx* and return its tokens padded/truncated to max_length.

        Args:
            idx: Index into the sorted list of discovered .pt files.

        Returns:
            A 1-D tensor of exactly ``max_length`` token ids.

        Raises:
            ValueError: If the stored ``tokens`` tensor is not 1-D.
        """
        path = self._files[idx]
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a crafted .pt file cannot execute arbitrary code.
        data = torch.load(path, weights_only=True)
        tokens: torch.Tensor = data["tokens"]
        if tokens.dim() != 1:
            raise ValueError(
                f"{path}: expected a 1-D 'tokens' tensor, "
                f"got shape {tuple(tokens.shape)}"
            )
        length = tokens.shape[0]
        if length >= self._max_length:
            return tokens[: self._max_length]
        # Create the padding on the same device/dtype as the tokens so that
        # torch.cat never mixes devices (e.g. tensors saved from a GPU).
        pad = torch.full(
            (self._max_length - length,),
            _PAD_ID,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        return torch.cat([tokens, pad])