feat: implement ChordTransformer (pre-norm decoder-only transformer)

Adds src/model.py with a weight-tied autoregressive transformer and
tests/test_model.py with shape, weight-tying, and causal-masking checks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 11:09:11 +03:00
parent 0712eec578
commit 10229be042
2 changed files with 243 additions and 0 deletions
+87
View File
@@ -0,0 +1,87 @@
"""Tests for src/model.py — ChordTransformer."""
import torch
import pytest
from src.model import ChordTransformer
VOCAB_SIZE = 85
SMALL = dict(vocab_size=VOCAB_SIZE, d_model=32, n_layers=2, n_heads=4, d_ff=64)
# ---------------------------------------------------------------------------
# Smoke tests
# ---------------------------------------------------------------------------
def test_output_shape_no_mask():
model = ChordTransformer(**SMALL)
model.eval()
ids = torch.randint(0, VOCAB_SIZE, (2, 16))
with torch.no_grad():
logits = model(ids)
assert logits.shape == (2, 16, VOCAB_SIZE)
def test_output_shape_with_padding_mask():
model = ChordTransformer(**SMALL)
model.eval()
ids = torch.randint(0, VOCAB_SIZE, (3, 20))
mask = torch.ones(3, 20, dtype=torch.long)
mask[0, 15:] = 0 # last 5 positions are padding for batch item 0
with torch.no_grad():
logits = model(ids, attention_mask=mask)
assert logits.shape == (3, 20, VOCAB_SIZE)
def test_single_token_sequence():
model = ChordTransformer(**SMALL)
model.eval()
ids = torch.randint(0, VOCAB_SIZE, (1, 1))
with torch.no_grad():
logits = model(ids)
assert logits.shape == (1, 1, VOCAB_SIZE)
def test_max_seq_len_raises():
model = ChordTransformer(**SMALL, max_seq_len=8)
ids = torch.randint(0, VOCAB_SIZE, (1, 9))
with pytest.raises(ValueError, match="exceeds max_seq_len"):
model(ids)
def test_weight_tying():
model = ChordTransformer(**SMALL)
assert model.lm_head.weight is model.token_emb.weight
# ---------------------------------------------------------------------------
# Causal masking
# ---------------------------------------------------------------------------
def test_causal_masking_future_tokens_do_not_affect_past_logits():
"""Changing tokens at positions [k:] must not alter logits at positions [:k]."""
torch.manual_seed(0)
model = ChordTransformer(**SMALL)
model.eval()
seq_len = 12
base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len))
with torch.no_grad():
logits_base = model(base_ids)
# Mutate the second half of the sequence
modified_ids = base_ids.clone()
modified_ids[:, seq_len // 2 :] = torch.randint(0, VOCAB_SIZE, (1, seq_len // 2))
with torch.no_grad():
logits_mod = model(modified_ids)
# Logits for the first half must be bit-exact (no stochastic ops in eval)
assert torch.equal(
logits_base[:, : seq_len // 2, :],
logits_mod[:, : seq_len // 2, :],
), "Causal masking violated: future tokens affected past logits"