"""Tests for src/model.py — ChordTransformer.""" import torch import pytest from src.model import ChordTransformer VOCAB_SIZE = 85 SMALL = dict(vocab_size=VOCAB_SIZE, d_model=32, n_layers=2, n_heads=4, d_ff=64) # --------------------------------------------------------------------------- # Smoke tests # --------------------------------------------------------------------------- def test_output_shape_no_mask(): model = ChordTransformer(**SMALL) model.eval() ids = torch.randint(0, VOCAB_SIZE, (2, 16)) with torch.no_grad(): logits = model(ids) assert logits.shape == (2, 16, VOCAB_SIZE) def test_output_shape_with_padding_mask(): model = ChordTransformer(**SMALL) model.eval() ids = torch.randint(0, VOCAB_SIZE, (3, 20)) mask = torch.ones(3, 20, dtype=torch.long) mask[0, 15:] = 0 # last 5 positions are padding for batch item 0 with torch.no_grad(): logits = model(ids, attention_mask=mask) assert logits.shape == (3, 20, VOCAB_SIZE) def test_single_token_sequence(): model = ChordTransformer(**SMALL) model.eval() ids = torch.randint(0, VOCAB_SIZE, (1, 1)) with torch.no_grad(): logits = model(ids) assert logits.shape == (1, 1, VOCAB_SIZE) def test_max_seq_len_raises(): model = ChordTransformer(**SMALL, max_seq_len=8) ids = torch.randint(0, VOCAB_SIZE, (1, 9)) with pytest.raises(ValueError, match="exceeds max_seq_len"): model(ids) def test_weight_tying(): model = ChordTransformer(**SMALL) assert model.lm_head.weight is model.token_emb.weight # --------------------------------------------------------------------------- # Causal masking # --------------------------------------------------------------------------- def test_causal_masking_future_tokens_do_not_affect_past_logits(): """Changing tokens at positions [k:] must not alter logits at positions [:k].""" torch.manual_seed(0) model = ChordTransformer(**SMALL) model.eval() seq_len = 12 base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len)) with torch.no_grad(): logits_base = model(base_ids) # Mutate the second half of the sequence modified_ids = base_ids.clone() modified_ids[:, seq_len // 2 :] = torch.randint(0, VOCAB_SIZE, (1, seq_len // 2)) with torch.no_grad(): logits_mod = model(modified_ids) # Logits for the first half must be bit-exact (no stochastic ops in eval) assert torch.equal( logits_base[:, : seq_len // 2, :], logits_mod[:, : seq_len // 2, :], ), "Causal masking violated: future tokens affected past logits"