"""PyTorch Dataset for tokenized .chord period files.

Public API:
    ChordDataset — Dataset that loads pre-tokenized .pt files from a directory.
"""

from __future__ import annotations

import logging
from pathlib import Path

import torch
from torch.utils.data import Dataset

from src.tokenizer import TOKEN_TO_ID

log = logging.getLogger(__name__)

# Token id used to right-pad sequences shorter than ``max_length``.
# NOTE(review): the key literal below is empty in this copy of the file; the
# padding token's name (e.g. an angle-bracketed pad token) may have been
# stripped — confirm it matches the tokenizer's vocabulary.
_PAD_ID: int = TOKEN_TO_ID[""]


class ChordDataset(Dataset):
    """Dataset over a directory of tokenized .pt period files.

    Each .pt file must be a dict ``{"tokens": LongTensor, "meta": dict}``.
    ``__getitem__`` returns a fixed-length tensor: the token sequence is
    truncated to *max_length* if too long, or right-padded with the padding
    token id (``_PAD_ID``) if too short.

    Args:
        data_dir: Directory containing .pt files (non-recursive). A string
            path is accepted and converted to :class:`~pathlib.Path`.
        max_length: Fixed output sequence length. Must be >= 1. Default 512.

    Raises:
        ValueError: If *max_length* is less than 1.
    """

    def __init__(self, data_dir: str | Path, max_length: int = 512) -> None:
        # A non-positive length would make the slice in __getitem__ return an
        # empty tensor (zero) or silently drop trailing tokens (negative)
        # instead of producing a fixed-length output.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self._max_length = max_length
        # Sorted so the index -> file mapping is deterministic across runs.
        self._files: list[Path] = sorted(Path(data_dir).glob("*.pt"))
        if not self._files:
            log.warning("ChordDataset: no .pt files found in %s", data_dir)

    def __len__(self) -> int:
        """Return the number of .pt files found when the dataset was built."""
        return len(self._files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load file *idx* and return its tokens truncated/padded to ``max_length``.

        Args:
            idx: Index into the sorted list of .pt files.

        Returns:
            A tensor of length ``max_length`` with the same dtype and device
            as the stored ``"tokens"`` tensor. The stored tensor is expected
            to be 1-D — TODO confirm against the tokenizer's output.
        """
        # weights_only=True restricts unpickling to tensors and primitive
        # containers, so a crafted .pt file cannot run arbitrary code on load.
        data = torch.load(self._files[idx], weights_only=True)
        tokens: torch.Tensor = data["tokens"]
        length = tokens.shape[0]
        if length >= self._max_length:
            return tokens[: self._max_length]
        # new_full matches tokens' dtype *and* device, so torch.cat works even
        # when the file was saved from a non-CPU tensor.
        pad = tokens.new_full((self._max_length - length,), _PAD_ID)
        return torch.cat([tokens, pad])