feat: implement ChordTransformer (pre-norm decoder-only transformer)
Adds src/model.py with a weight-tied autoregressive transformer and tests/test_model.py with shape, weight-tying, and causal-masking checks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+156
@@ -0,0 +1,156 @@
|
||||
"""Small decoder-only transformer for harmonic period generation.
|
||||
|
||||
Architecture: pre-norm, causal self-attention, weight-tied embeddings.
|
||||
|
||||
Usage:
|
||||
from src.model import ChordTransformer
|
||||
model = ChordTransformer(vocab_size=85)
|
||||
logits = model(input_ids) # [batch, seq_len, vocab_size]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.attn = nn.MultiheadAttention(
|
||||
d_model, n_heads, dropout=dropout, batch_first=True
|
||||
)
|
||||
self.attn_drop = nn.Dropout(dropout)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.ff = nn.Sequential(
|
||||
nn.Linear(d_model, d_ff),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(d_ff, d_model),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
causal_mask: torch.Tensor,
|
||||
key_padding_mask: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
normed = self.norm1(x)
|
||||
attn_out, _ = self.attn(
|
||||
normed,
|
||||
normed,
|
||||
normed,
|
||||
attn_mask=causal_mask,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
x = x + self.attn_drop(attn_out)
|
||||
x = x + self.ff(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class ChordTransformer(nn.Module):
|
||||
"""Autoregressive transformer for chord sequence modelling.
|
||||
|
||||
Args:
|
||||
vocab_size: Number of tokens in the vocabulary.
|
||||
d_model: Embedding / hidden dimension.
|
||||
n_layers: Number of transformer blocks.
|
||||
n_heads: Number of attention heads (must divide d_model evenly).
|
||||
d_ff: Feed-forward inner dimension.
|
||||
max_seq_len: Maximum sequence length (sets positional embedding size).
|
||||
dropout: Dropout probability applied throughout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int = 192,
|
||||
n_layers: int = 3,
|
||||
n_heads: int = 6,
|
||||
d_ff: int = 768,
|
||||
max_seq_len: int = 512,
|
||||
dropout: float = 0.1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.token_emb = nn.Embedding(vocab_size, d_model)
|
||||
self.pos_emb = nn.Embedding(max_seq_len, d_model)
|
||||
self.emb_drop = nn.Dropout(dropout)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[_TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
# Output projection — weight tied with token embedding
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
|
||||
self.lm_head.weight = self.token_emb.weight
|
||||
|
||||
self._init_weights()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Weight initialisation (GPT-style)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_weights(self) -> None:
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
nn.init.normal_(self.pos_emb.weight, std=0.02)
|
||||
for block in self.blocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=0.02)
|
||||
nn.init.zeros_(block.attn.in_proj_bias)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=0.02)
|
||||
nn.init.zeros_(block.attn.out_proj.bias)
|
||||
for layer in block.ff:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.normal_(layer.weight, std=0.02)
|
||||
nn.init.zeros_(layer.bias)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward pass
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Compute next-token logits for every position.
|
||||
|
||||
Args:
|
||||
input_ids: Long tensor of shape [batch, seq_len].
|
||||
attention_mask: Optional boolean/int tensor [batch, seq_len];
|
||||
1 = attend, 0 = ignore (padding). When None, all positions
|
||||
are attended.
|
||||
|
||||
Returns:
|
||||
Float tensor of shape [batch, seq_len, vocab_size].
|
||||
"""
|
||||
B, T = input_ids.shape
|
||||
if T > self.max_seq_len:
|
||||
raise ValueError(
|
||||
f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
|
||||
)
|
||||
|
||||
positions = torch.arange(T, device=input_ids.device)
|
||||
x = self.emb_drop(self.token_emb(input_ids) + self.pos_emb(positions))
|
||||
|
||||
# Upper-triangular True = blocked: token i cannot attend to token j > i
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
|
||||
key_padding_mask: torch.Tensor | None = None
|
||||
if attention_mask is not None:
|
||||
# nn.MultiheadAttention expects True = ignore
|
||||
key_padding_mask = ~attention_mask.bool()
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, causal_mask, key_padding_mask)
|
||||
|
||||
x = self.norm(x)
|
||||
return self.lm_head(x)
|
||||
Reference in New Issue
Block a user