feat: implement ChordTransformer (pre-norm decoder-only transformer)

Adds src/model.py with a weight-tied autoregressive transformer and
tests/test_model.py with shape, weight-tying, and causal-masking checks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 11:09:11 +03:00
parent 0712eec578
commit 10229be042
2 changed files with 243 additions and 0 deletions
+156
View File
@@ -0,0 +1,156 @@
"""Small decoder-only transformer for harmonic period generation.

Architecture: pre-norm, causal self-attention, weight-tied embeddings.

Usage:
    from src.model import ChordTransformer

    model = ChordTransformer(vocab_size=85)
    logits = model(input_ids)  # [batch, seq_len, vocab_size]
"""
from __future__ import annotations
import torch
import torch.nn as nn
class _TransformerBlock(nn.Module):
    """One pre-norm transformer layer: self-attention, then feed-forward.

    Each sub-layer applies LayerNorm to its input, and the sub-layer output
    (after dropout) is added back onto the residual stream.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability, used inside attention, on the attention
            output, and after each linear layer of the feed-forward sub-layer.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
        super().__init__()

        # --- attention sub-layer ---
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attn_drop = nn.Dropout(p=dropout)

        # --- feed-forward sub-layer: expand -> GELU -> contract ---
        self.norm2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(p=dropout),
            contract,
            nn.Dropout(p=dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the layer on a batch-first activation tensor.

        Args:
            x: Input activations, batch-first, with last dimension d_model.
            causal_mask: Forwarded to the attention module as ``attn_mask``.
            key_padding_mask: Forwarded unchanged to the attention module;
                may be None.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention: the normalised input serves as query, key and value.
        attn_in = self.norm1(x)
        attended = self.attn(
            query=attn_in,
            key=attn_in,
            value=attn_in,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=causal_mask,
        )[0]
        after_attn = x + self.attn_drop(attended)

        # Feed-forward on the normalised post-attention stream.
        return after_attn + self.ff(self.norm2(after_attn))
class ChordTransformer(nn.Module):
    """Autoregressive transformer for chord sequence modelling.

    Token and learned positional embeddings are summed, passed through a
    stack of transformer blocks and a final LayerNorm, then projected to
    vocabulary logits with a head whose weight is shared with the token
    embedding.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        d_model: Embedding / hidden dimension.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads (must divide d_model evenly).
        d_ff: Feed-forward inner dimension.
        max_seq_len: Maximum sequence length (sets positional embedding size).
        dropout: Dropout probability applied throughout.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 192,
        n_layers: int = 3,
        n_heads: int = 6,
        d_ff: int = 768,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.emb_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [_TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)

        # Output projection — weight tied with token embedding: both modules
        # hold the same Parameter object, so it is trained and stored once.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        self._init_weights()

    # ------------------------------------------------------------------
    # Weight initialisation (GPT-style)
    # ------------------------------------------------------------------
    def _init_weights(self) -> None:
        """Re-initialise embeddings and block weights.

        Embedding and linear/attention weights are drawn from N(0, 0.02^2);
        attention and feed-forward biases are zeroed. LayerNorm parameters
        keep their PyTorch defaults. Because the output head shares its
        weight with the token embedding, initialising the embedding also
        initialises the head.
        """
        nn.init.normal_(self.token_emb.weight, std=0.02)
        nn.init.normal_(self.pos_emb.weight, std=0.02)
        for block in self.blocks:
            nn.init.normal_(block.attn.in_proj_weight, std=0.02)
            nn.init.zeros_(block.attn.in_proj_bias)
            nn.init.normal_(block.attn.out_proj.weight, std=0.02)
            nn.init.zeros_(block.attn.out_proj.bias)
            for layer in block.ff:
                if isinstance(layer, nn.Linear):
                    nn.init.normal_(layer.weight, std=0.02)
                    nn.init.zeros_(layer.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute next-token logits for every position.

        Args:
            input_ids: Long tensor of shape [batch, seq_len].
            attention_mask: Optional boolean/int tensor [batch, seq_len];
                1 = attend, 0 = ignore (padding). When None, all positions
                are attended.

        Returns:
            Float tensor of shape [batch, seq_len, vocab_size].

        Raises:
            ValueError: If ``input_ids`` is not 2-D, if seq_len exceeds
                ``max_seq_len``, or if ``attention_mask`` does not have the
                same shape as ``input_ids``.

        NOTE(review): with left-padded batches the leading padding queries
        have every key masked (causal mask plus padding mask), which can
        yield NaNs in attention on some PyTorch versions. Right padding
        appears to be the expected layout — confirm against the data
        pipeline.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                "input_ids must have shape [batch, seq_len], "
                f"got {tuple(input_ids.shape)}"
            )
        seq_len = input_ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        # Fail early with a clear message instead of an opaque shape
        # assertion from inside nn.MultiheadAttention.
        if attention_mask is not None and attention_mask.shape != input_ids.shape:
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not "
                f"match input_ids shape {tuple(input_ids.shape)}"
            )

        # Positional embeddings are [seq_len, d_model] and broadcast over batch.
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.emb_drop(self.token_emb(input_ids) + self.pos_emb(positions))

        # Upper-triangular True = blocked: token i cannot attend to token j > i
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
            diagonal=1,
        )

        key_padding_mask: torch.Tensor | None = None
        if attention_mask is not None:
            # nn.MultiheadAttention expects True = ignore
            key_padding_mask = ~attention_mask.bool()

        for block in self.blocks:
            x = block(x, causal_mask, key_padding_mask)

        return self.lm_head(self.norm(x))