From 10229be0426c2c1d0e96a64a00a7dd90e964ee50 Mon Sep 17 00:00:00 2001
From: Masahiko AMANO
Date: Wed, 20 May 2026 11:09:11 +0300
Subject: [PATCH] feat: implement ChordTransformer (pre-norm decoder-only transformer)

Adds src/model.py with a weight-tied autoregressive transformer and
tests/test_model.py with shape, weight-tying, and causal-masking checks.

Co-Authored-By: Claude Sonnet 4.6
---
Review notes (this section sits after the "---" separator, so it is not
part of the commit message and is dropped when the patch is applied):

  * This copy of the patch had lost its line breaks and indentation. They
    are restored here, one "+" line per added source line; the restored
    files have 156 and 87 lines, matching the hunk headers and diffstat.
    The text of the added lines is otherwise unchanged. Intra-line spacing
    (e.g. before inline comments) was reconstructed, not recovered.
  * The "From:" and "Co-Authored-By:" lines carry no e-mail address in this
    copy. NOTE(review): `git am` normally needs an address on "From:" --
    restore it before applying.
  * ChordTransformer.forward unpacks `B, T = input_ids.shape` but never
    uses `B`; consider `_, T = ...` in a follow-up.
  * The padding test only masks trailing positions. If a row's position 0
    is ever marked as padding (left-padded batch), causal_mask plus
    key_padding_mask block every key for that query. NOTE(review): this
    presumably produces NaN in the attention softmax -- confirm against
    nn.MultiheadAttention before relying on left padding.

 src/model.py        | 156 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_model.py |  87 ++++++++++++++++++++++++
 2 files changed, 243 insertions(+)
 create mode 100644 src/model.py
 create mode 100644 tests/test_model.py

diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..2102ad4
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,156 @@
+"""Small decoder-only transformer for harmonic period generation.
+
+Architecture: pre-norm, causal self-attention, weight-tied embeddings.
+
+Usage:
+    from src.model import ChordTransformer
+    model = ChordTransformer(vocab_size=85)
+    logits = model(input_ids)  # [batch, seq_len, vocab_size]
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class _TransformerBlock(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(d_model)
+        self.attn = nn.MultiheadAttention(
+            d_model, n_heads, dropout=dropout, batch_first=True
+        )
+        self.attn_drop = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, d_ff),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_ff, d_model),
+            nn.Dropout(dropout),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        causal_mask: torch.Tensor,
+        key_padding_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        normed = self.norm1(x)
+        attn_out, _ = self.attn(
+            normed,
+            normed,
+            normed,
+            attn_mask=causal_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )
+        x = x + self.attn_drop(attn_out)
+        x = x + self.ff(self.norm2(x))
+        return x
+
+
+class ChordTransformer(nn.Module):
+    """Autoregressive transformer for chord sequence modelling.
+
+    Args:
+        vocab_size: Number of tokens in the vocabulary.
+        d_model: Embedding / hidden dimension.
+        n_layers: Number of transformer blocks.
+        n_heads: Number of attention heads (must divide d_model evenly).
+        d_ff: Feed-forward inner dimension.
+        max_seq_len: Maximum sequence length (sets positional embedding size).
+        dropout: Dropout probability applied throughout.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int = 192,
+        n_layers: int = 3,
+        n_heads: int = 6,
+        d_ff: int = 768,
+        max_seq_len: int = 512,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+        self.d_model = d_model
+        self.max_seq_len = max_seq_len
+
+        self.token_emb = nn.Embedding(vocab_size, d_model)
+        self.pos_emb = nn.Embedding(max_seq_len, d_model)
+        self.emb_drop = nn.Dropout(dropout)
+
+        self.blocks = nn.ModuleList(
+            [_TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
+        )
+        self.norm = nn.LayerNorm(d_model)
+
+        # Output projection — weight tied with token embedding
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+        self.lm_head.weight = self.token_emb.weight
+
+        self._init_weights()
+
+    # ------------------------------------------------------------------
+    # Weight initialisation (GPT-style)
+    # ------------------------------------------------------------------
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(self.token_emb.weight, std=0.02)
+        nn.init.normal_(self.pos_emb.weight, std=0.02)
+        for block in self.blocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=0.02)
+            nn.init.zeros_(block.attn.in_proj_bias)
+            nn.init.normal_(block.attn.out_proj.weight, std=0.02)
+            nn.init.zeros_(block.attn.out_proj.bias)
+            for layer in block.ff:
+                if isinstance(layer, nn.Linear):
+                    nn.init.normal_(layer.weight, std=0.02)
+                    nn.init.zeros_(layer.bias)
+
+    # ------------------------------------------------------------------
+    # Forward pass
+    # ------------------------------------------------------------------
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Compute next-token logits for every position.
+
+        Args:
+            input_ids: Long tensor of shape [batch, seq_len].
+            attention_mask: Optional boolean/int tensor [batch, seq_len];
+                1 = attend, 0 = ignore (padding). When None, all positions
+                are attended.
+
+        Returns:
+            Float tensor of shape [batch, seq_len, vocab_size].
+        """
+        B, T = input_ids.shape
+        if T > self.max_seq_len:
+            raise ValueError(
+                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
+            )
+
+        positions = torch.arange(T, device=input_ids.device)
+        x = self.emb_drop(self.token_emb(input_ids) + self.pos_emb(positions))
+
+        # Upper-triangular True = blocked: token i cannot attend to token j > i
+        causal_mask = torch.triu(
+            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
+        )
+
+        key_padding_mask: torch.Tensor | None = None
+        if attention_mask is not None:
+            # nn.MultiheadAttention expects True = ignore
+            key_padding_mask = ~attention_mask.bool()
+
+        for block in self.blocks:
+            x = block(x, causal_mask, key_padding_mask)
+
+        x = self.norm(x)
+        return self.lm_head(x)
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..da0f451
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,87 @@
+"""Tests for src/model.py — ChordTransformer."""
+
+import torch
+import pytest
+
+from src.model import ChordTransformer
+
+
+VOCAB_SIZE = 85
+SMALL = dict(vocab_size=VOCAB_SIZE, d_model=32, n_layers=2, n_heads=4, d_ff=64)
+
+
+# ---------------------------------------------------------------------------
+# Smoke tests
+# ---------------------------------------------------------------------------
+
+
+def test_output_shape_no_mask():
+    model = ChordTransformer(**SMALL)
+    model.eval()
+    ids = torch.randint(0, VOCAB_SIZE, (2, 16))
+    with torch.no_grad():
+        logits = model(ids)
+    assert logits.shape == (2, 16, VOCAB_SIZE)
+
+
+def test_output_shape_with_padding_mask():
+    model = ChordTransformer(**SMALL)
+    model.eval()
+    ids = torch.randint(0, VOCAB_SIZE, (3, 20))
+    mask = torch.ones(3, 20, dtype=torch.long)
+    mask[0, 15:] = 0  # last 5 positions are padding for batch item 0
+    with torch.no_grad():
+        logits = model(ids, attention_mask=mask)
+    assert logits.shape == (3, 20, VOCAB_SIZE)
+
+
+def test_single_token_sequence():
+    model = ChordTransformer(**SMALL)
+    model.eval()
+    ids = torch.randint(0, VOCAB_SIZE, (1, 1))
+    with torch.no_grad():
+        logits = model(ids)
+    assert logits.shape == (1, 1, VOCAB_SIZE)
+
+
+def test_max_seq_len_raises():
+    model = ChordTransformer(**SMALL, max_seq_len=8)
+    ids = torch.randint(0, VOCAB_SIZE, (1, 9))
+    with pytest.raises(ValueError, match="exceeds max_seq_len"):
+        model(ids)
+
+
+def test_weight_tying():
+    model = ChordTransformer(**SMALL)
+    assert model.lm_head.weight is model.token_emb.weight
+
+
+# ---------------------------------------------------------------------------
+# Causal masking
+# ---------------------------------------------------------------------------
+
+
+def test_causal_masking_future_tokens_do_not_affect_past_logits():
+    """Changing tokens at positions [k:] must not alter logits at positions [:k]."""
+    torch.manual_seed(0)
+    model = ChordTransformer(**SMALL)
+    model.eval()
+
+    seq_len = 12
+    base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len))
+
+    with torch.no_grad():
+        logits_base = model(base_ids)
+
+    # Mutate the second half of the sequence
+    modified_ids = base_ids.clone()
+    modified_ids[:, seq_len // 2 :] = torch.randint(0, VOCAB_SIZE, (1, seq_len // 2))
+
+    with torch.no_grad():
+        logits_mod = model(modified_ids)
+
+    # Logits for the first half must be bit-exact (no stochastic ops in eval)
+    assert torch.equal(
+        logits_base[:, : seq_len // 2, :],
+        logits_mod[:, : seq_len // 2, :],
+    ), "Causal masking violated: future tokens affected past logits"