10229be042
Adds src/model.py with a weight-tied autoregressive transformer and tests/test_model.py with shape, weight-tying, and causal-masking checks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
"""Tests for src/model.py — ChordTransformer."""
|
|
|
|
import torch
|
|
import pytest
|
|
|
|
from src.model import ChordTransformer
|
|
|
|
|
|
# Size of the token vocabulary shared by every model and input tensor below;
# token ids in the tests are drawn from [0, VOCAB_SIZE).
# NOTE(review): 85 presumably mirrors the chord tokenizer's vocabulary — confirm.
VOCAB_SIZE = 85

# Keyword arguments for a small ChordTransformer, reused by each test via
# ``ChordTransformer(**SMALL)``.
SMALL = {
    "vocab_size": VOCAB_SIZE,
    "d_model": 32,
    "n_layers": 2,
    "n_heads": 4,
    "d_ff": 64,
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Smoke tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_output_shape_no_mask():
    """Without an attention mask, logits have shape (batch, seq_len, VOCAB_SIZE)."""
    batch, length = 2, 16

    net = ChordTransformer(**SMALL)
    net.eval()
    tokens = torch.randint(0, VOCAB_SIZE, (batch, length))

    with torch.no_grad():
        out = net(tokens)

    assert out.shape == (batch, length, VOCAB_SIZE)
|
|
|
|
|
|
def test_output_shape_with_padding_mask():
    """A right-padded batch yields full-shape, finite logits at every real position.

    Batch item 0 has its last five positions marked as padding (mask value 0);
    the other items are fully real (mask value 1).
    """
    # Seed so that both the weight init and the inputs are reproducible when
    # this test fails.
    torch.manual_seed(0)
    model = ChordTransformer(**SMALL)
    model.eval()

    batch, seq_len, n_real = 3, 20, 15
    ids = torch.randint(0, VOCAB_SIZE, (batch, seq_len))
    mask = torch.ones(batch, seq_len, dtype=torch.long)
    mask[0, n_real:] = 0  # last 5 positions are padding for batch item 0

    with torch.no_grad():
        logits = model(ids, attention_mask=mask)

    assert logits.shape == (batch, seq_len, VOCAB_SIZE)

    # Combining padding and causal masks is a common source of NaNs (e.g. a
    # softmax over a fully masked row). Only real (mask == 1) positions are
    # checked, because outputs at padded positions may legitimately be arbitrary.
    real_positions = mask.bool()
    assert torch.isfinite(logits[real_positions]).all(), (
        "non-finite logits at non-padded positions"
    )
|
|
|
|
|
|
def test_single_token_sequence():
    """The degenerate one-token input still produces a (1, 1, VOCAB_SIZE) output."""
    net = ChordTransformer(**SMALL)
    net.eval()
    lone_token = torch.randint(0, VOCAB_SIZE, (1, 1))

    with torch.no_grad():
        out = net(lone_token)

    expected = (1, 1, VOCAB_SIZE)
    assert out.shape == expected
|
|
|
|
|
|
def test_max_seq_len_raises():
    """Inputs longer than ``max_seq_len`` are rejected; inputs of exactly that length are not.

    Checking both sides of the boundary catches an off-by-one in the model's
    length guard that the over-length case alone would miss.
    """
    max_len = 8
    model = ChordTransformer(**SMALL, max_seq_len=max_len)
    model.eval()

    # Exactly max_seq_len tokens does not "exceed" the limit and must be accepted.
    at_limit = torch.randint(0, VOCAB_SIZE, (1, max_len))
    with torch.no_grad():
        logits = model(at_limit)
    assert logits.shape == (1, max_len, VOCAB_SIZE)

    # One token over the limit must raise.
    over_limit = torch.randint(0, VOCAB_SIZE, (1, max_len + 1))
    with pytest.raises(ValueError, match="exceeds max_seq_len"):
        model(over_limit)
|
|
|
|
|
|
def test_weight_tying():
    """The output projection reuses the token-embedding weight Parameter itself."""
    net = ChordTransformer(**SMALL)
    embedding_weight = net.token_emb.weight
    head_weight = net.lm_head.weight
    assert head_weight is embedding_weight
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Causal masking
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_causal_masking_future_tokens_do_not_affect_past_logits():
    """Changing tokens at positions [k:] must not alter logits at positions [:k].

    Every split point k is checked, so a leak between any pair of positions
    is caught, not only a leak across a single midpoint. As a positive
    control, the logits at position k itself must change. Otherwise a model
    that ignored its input would pass the causality check vacuously.
    """
    torch.manual_seed(0)
    model = ChordTransformer(**SMALL)
    model.eval()

    seq_len = 12
    base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len))

    with torch.no_grad():
        logits_base = model(base_ids)

    for k in range(seq_len):
        modified_ids = base_ids.clone()
        # Shift every future token id by one (mod vocab) so each replacement
        # is guaranteed to differ. A fresh random draw could repeat the
        # original tokens and make the check vacuous.
        modified_ids[:, k:] = (base_ids[:, k:] + 1) % VOCAB_SIZE

        with torch.no_grad():
            logits_mod = model(modified_ids)

        # Logits before k must be bit-exact (no stochastic ops in eval; same
        # input shape as the baseline run).
        # NOTE(review): if this proves flaky on an accelerator backend, relax
        # to a tight torch.testing.assert_close tolerance.
        assert torch.equal(logits_base[:, :k, :], logits_mod[:, :k, :]), (
            f"Causal masking violated: changing tokens at [{k}:] affected past logits"
        )
        assert not torch.equal(logits_base[:, k, :], logits_mod[:, k, :]), (
            f"Changing the token at position {k} had no effect on its own logits"
        )
|