feat: implement ChordTransformer (pre-norm decoder-only transformer)
Adds src/model.py with a weight-tied autoregressive transformer and tests/test_model.py with shape, weight-tying, and causal-masking checks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+156
@@ -0,0 +1,156 @@
|
|||||||
|
"""Small decoder-only transformer for harmonic period generation.
|
||||||
|
|
||||||
|
Architecture: pre-norm, causal self-attention, weight-tied embeddings.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from src.model import ChordTransformer
|
||||||
|
model = ChordTransformer(vocab_size=85)
|
||||||
|
logits = model(input_ids) # [batch, seq_len, vocab_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class _TransformerBlock(nn.Module):
    """One pre-norm transformer layer: self-attention then feed-forward.

    Each sub-layer normalises its input first and adds its (dropped-out)
    output back onto the residual stream. Tensors are batch-first, i.e.
    ``[batch, seq_len, d_model]``.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used in attention, on the attention
            output, and inside the feed-forward sub-layer.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
        super().__init__()

        # --- attention sub-layer ---
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attn_drop = nn.Dropout(p=dropout)

        # --- feed-forward sub-layer ---
        self.norm2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        inner_drop = nn.Dropout(p=dropout)
        contract = nn.Linear(d_ff, d_model)
        outer_drop = nn.Dropout(p=dropout)
        self.ff = nn.Sequential(expand, activation, inner_drop, contract, outer_drop)

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to ``x``.

        Args:
            x: Input activations.
            causal_mask: Passed to the attention module as ``attn_mask``.
            key_padding_mask: Passed through to the attention module;
                may be ``None``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Self-attention: query, key and value are all the normalised input.
        attn_in = self.norm1(x)
        attended = self.attn(
            query=attn_in,
            key=attn_in,
            value=attn_in,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=causal_mask,
        )[0]
        after_attn = x + self.attn_drop(attended)

        # Feed-forward on the normalised stream, added back as a residual.
        return after_attn + self.ff(self.norm2(after_attn))
|
||||||
|
|
||||||
|
|
||||||
|
class ChordTransformer(nn.Module):
    """Autoregressive transformer for chord sequence modelling.

    Token and learned positional embeddings are summed, passed through
    ``n_layers`` transformer blocks under a causal mask, layer-normalised,
    and projected to vocabulary logits by a linear head that shares its
    weight tensor with the token embedding.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        d_model: Embedding / hidden dimension.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads (must divide d_model evenly).
        d_ff: Feed-forward inner dimension.
        max_seq_len: Maximum sequence length (sets positional embedding size).
        dropout: Dropout probability applied throughout.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 192,
        n_layers: int = 3,
        n_heads: int = 6,
        d_ff: int = 768,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.emb_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [_TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)

        # Output projection — weight tied with token embedding: both modules
        # hold the very same Parameter object.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        self._init_weights()

    # ------------------------------------------------------------------
    # Weight initialisation (GPT-style)
    # ------------------------------------------------------------------

    def _init_weights(self) -> None:
        """Re-initialise embeddings, attention and feed-forward parameters.

        All weight matrices are drawn from N(0, 0.02^2) and all biases are
        zeroed. ``lm_head`` is not touched separately because its weight is
        the token-embedding weight. LayerNorm parameters keep their defaults.
        """
        nn.init.normal_(self.token_emb.weight, std=0.02)
        nn.init.normal_(self.pos_emb.weight, std=0.02)
        for block in self.blocks:
            nn.init.normal_(block.attn.in_proj_weight, std=0.02)
            nn.init.zeros_(block.attn.in_proj_bias)
            nn.init.normal_(block.attn.out_proj.weight, std=0.02)
            nn.init.zeros_(block.attn.out_proj.bias)
            for layer in block.ff:
                if isinstance(layer, nn.Linear):
                    nn.init.normal_(layer.weight, std=0.02)
                    nn.init.zeros_(layer.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute next-token logits for every position.

        Args:
            input_ids: Long tensor of shape [batch, seq_len].
            attention_mask: Optional boolean/int tensor [batch, seq_len];
                1 = attend, 0 = ignore (padding). When None, all positions
                are attended.

        Returns:
            Float tensor of shape [batch, seq_len, vocab_size].

        Raises:
            ValueError: If ``input_ids`` is not 2-D, if ``seq_len`` exceeds
                ``max_seq_len``, or if ``attention_mask`` does not have the
                same shape as ``input_ids``.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must have shape [batch, seq_len], got {tuple(input_ids.shape)}"
            )
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        # Fail early with a clear message instead of deep inside attention.
        if attention_mask is not None and attention_mask.shape != input_ids.shape:
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not match "
                f"input_ids shape {tuple(input_ids.shape)}"
            )

        device = input_ids.device
        positions = torch.arange(seq_len, device=device)
        # pos_emb(positions) is [seq_len, d_model] and broadcasts over batch.
        x = self.emb_drop(self.token_emb(input_ids) + self.pos_emb(positions))

        # Upper-triangular True = blocked: token i cannot attend to token j > i
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )

        key_padding_mask: torch.Tensor | None = None
        if attention_mask is not None:
            # nn.MultiheadAttention expects True = ignore, and a mask on the
            # same device as the activations.
            # NOTE(review): with left-padded batches a padded query position
            # can have every key blocked (causal + padding) — confirm how
            # nn.MultiheadAttention behaves for such rows before relying on it.
            key_padding_mask = ~attention_mask.to(device).bool()

        for block in self.blocks:
            x = block(x, causal_mask, key_padding_mask)

        x = self.norm(x)
        return self.lm_head(x)
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
"""Tests for src/model.py — ChordTransformer."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.model import ChordTransformer
|
||||||
|
|
||||||
|
|
||||||
|
# Vocabulary size shared by every test in this module.
VOCAB_SIZE = 85

# Deliberately tiny hyper-parameters so the tests run quickly.
SMALL = {
    "vocab_size": VOCAB_SIZE,
    "d_model": 32,
    "n_layers": 2,
    "n_heads": 4,
    "d_ff": 64,
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Smoke tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_shape_no_mask():
    """Without a mask, logits have shape [batch, seq_len, vocab_size]."""
    batch, length = 2, 16
    net = ChordTransformer(**SMALL).eval()
    tokens = torch.randint(0, VOCAB_SIZE, (batch, length))

    with torch.no_grad():
        out = net(tokens)

    assert out.shape == (batch, length, VOCAB_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_shape_with_padding_mask():
    """A padding mask must not change the shape of the returned logits."""
    batch, length = 3, 20
    net = ChordTransformer(**SMALL).eval()
    tokens = torch.randint(0, VOCAB_SIZE, (batch, length))

    # Everything is attended except the last 5 positions of batch item 0.
    pad_mask = torch.ones(batch, length, dtype=torch.long)
    pad_mask[0, 15:] = 0

    with torch.no_grad():
        out = net(tokens, attention_mask=pad_mask)

    assert out.shape == (batch, length, VOCAB_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_token_sequence():
    """A length-1 sequence yields logits of shape [1, 1, vocab_size]."""
    net = ChordTransformer(**SMALL).eval()
    lone_token = torch.randint(0, VOCAB_SIZE, (1, 1))

    with torch.no_grad():
        out = net(lone_token)

    assert out.shape == (1, 1, VOCAB_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_seq_len_raises():
    """Inputs one token longer than max_seq_len are rejected with ValueError."""
    limit = 8
    net = ChordTransformer(max_seq_len=limit, **SMALL)
    too_long = torch.randint(0, VOCAB_SIZE, (1, limit + 1))

    with pytest.raises(ValueError, match="exceeds max_seq_len"):
        net(too_long)
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_tying():
    """The output head and the token embedding share one Parameter object."""
    net = ChordTransformer(**SMALL)
    head_weight = net.lm_head.weight
    embedding_weight = net.token_emb.weight
    assert head_weight is embedding_weight
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Causal masking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_causal_masking_future_tokens_do_not_affect_past_logits():
    """Changing tokens at positions [k:] must not alter logits at positions [:k]."""
    torch.manual_seed(0)
    model = ChordTransformer(**SMALL)
    model.eval()

    seq_len = 12
    split = seq_len // 2
    base_ids = torch.randint(0, VOCAB_SIZE, (1, seq_len))

    # Mutate the second half of the sequence. Shifting every id by one
    # (mod vocab) guarantees each future token really changes; a fresh random
    # draw could reproduce the original ids and make the test vacuous.
    modified_ids = base_ids.clone()
    modified_ids[:, split:] = (base_ids[:, split:] + 1) % VOCAB_SIZE

    with torch.no_grad():
        logits_base = model(base_ids)
        logits_mod = model(modified_ids)

    # Sanity check: the mutation must be visible somewhere, otherwise the
    # equality assertion below proves nothing.
    assert not torch.equal(
        logits_base[:, split:, :],
        logits_mod[:, split:, :],
    ), "Mutating future tokens had no effect on future logits"

    # Logits for the first half must be bit-exact (no stochastic ops in eval)
    assert torch.equal(
        logits_base[:, :split, :],
        logits_mod[:, :split, :],
    ), "Causal masking violated: future tokens affected past logits"
|
||||||
Reference in New Issue
Block a user