feat: implement ChordTransformer (pre-norm decoder-only transformer)
Adds src/model.py with a weight-tied autoregressive transformer and tests/test_model.py with output-shape, max-sequence-length, weight-tying, and causal-masking checks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
"""Tests for src/model.py — ChordTransformer."""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from src.model import ChordTransformer
|
||||
|
||||
|
||||
# Vocabulary size shared by the test model configuration and by the random
# token-id sampling in the tests (ids are drawn from [0, VOCAB_SIZE)).
VOCAB_SIZE = 85
|
||||
# Constructor keyword arguments for the model instance reused by every test.
SMALL = {
    "vocab_size": VOCAB_SIZE,
    "d_model": 32,
    "n_layers": 2,
    "n_heads": 4,
    "d_ff": 64,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_output_shape_no_mask():
    """A forward pass without a mask returns logits shaped (batch, seq, vocab)."""
    batch_size, seq_len = 2, 16

    net = ChordTransformer(**SMALL)
    net.eval()

    token_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))
    with torch.no_grad():
        out = net(token_ids)

    assert out.shape == (batch_size, seq_len, VOCAB_SIZE)
|
||||
|
||||
|
||||
def test_output_shape_with_padding_mask():
    """Passing an attention mask with padded positions keeps the logits shape."""
    batch_size, seq_len, n_pad = 3, 20, 5

    net = ChordTransformer(**SMALL)
    net.eval()

    token_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))

    # Start from an all-ones mask, then zero out the trailing ``n_pad``
    # positions of the first batch item to mark them as padding.
    pad_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    pad_mask[0, seq_len - n_pad:] = 0

    with torch.no_grad():
        out = net(token_ids, attention_mask=pad_mask)

    assert out.shape == (batch_size, seq_len, VOCAB_SIZE)
|
||||
|
||||
|
||||
def test_single_token_sequence():
    """The model accepts a batch of one sequence containing a single token."""
    input_shape = (1, 1)

    net = ChordTransformer(**SMALL)
    net.eval()

    token_ids = torch.randint(0, VOCAB_SIZE, input_shape)
    with torch.no_grad():
        out = net(token_ids)

    assert out.shape == (*input_shape, VOCAB_SIZE)
|
||||
|
||||
|
||||
def test_max_seq_len_raises():
    """An input one token longer than ``max_seq_len`` raises ``ValueError``."""
    limit = 8
    net = ChordTransformer(**SMALL, max_seq_len=limit)

    # One token past the configured limit.
    too_long = torch.randint(0, VOCAB_SIZE, (1, limit + 1))

    with pytest.raises(ValueError, match="exceeds max_seq_len"):
        net(too_long)
|
||||
|
||||
|
||||
def test_weight_tying():
    """The output head and the token embedding share one weight object."""
    net = ChordTransformer(**SMALL)

    head_weight = net.lm_head.weight
    embedding_weight = net.token_emb.weight

    # Identity, not equality: tying means the very same parameter is reused.
    assert head_weight is embedding_weight
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Causal masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_causal_masking_future_tokens_do_not_affect_past_logits():
    """Changing tokens at positions [k:] must not alter logits at positions [:k].

    Every token from the split point onward is replaced with a different id,
    then the logits of the untouched prefix are compared between the two
    forward passes.
    """
    torch.manual_seed(0)
    model = ChordTransformer(**SMALL)
    model.eval()

    seq_len = 12
    split = seq_len // 2
    base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len))

    # Shift each future token by one (mod vocab) instead of re-sampling it:
    # a random re-draw can hand back the original id and leave that position
    # unperturbed, and a fixed-size re-draw only fits the slice when seq_len
    # is even. The shift always changes the id and works for any length.
    modified_ids = base_ids.clone()
    modified_ids[:, split:] = (base_ids[:, split:] + 1) % VOCAB_SIZE

    with torch.no_grad():
        logits_base = model(base_ids)
        logits_mod = model(modified_ids)

    # The prefix logits are required to be bit-exact; this relies on the
    # model having no stochastic ops active in eval mode.
    assert torch.equal(
        logits_base[:, :split, :],
        logits_mod[:, :split, :],
    ), "Causal masking violated: future tokens affected past logits"
|
||||
Reference in New Issue
Block a user