Table of Contents
What is a Transformer?
Transformers (introduced in "Attention Is All You Need", Vaswani et al., 2017) have largely replaced recurrent networks such as RNNs and LSTMs for sequence tasks. Instead of processing tokens one by one, they look at the entire sequence at once using a mechanism called self-attention.
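Key Idea: Every token attends to every token in the sequence (including itself) at the same time. No position has to wait for the previous one to be processed, as it would in an RNN, so Transformers are both expressive and highly parallelizable.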
Setup & Imports
We'll build the Transformer using PyTorch. Install the required libraries if needed.
# Install dependencies
pip install torch numpy matplotlib
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fix the global PyTorch RNG so weight initialisation and the toy data used
# later in the tutorial are repeatable from run to run.
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
Input Embeddings
Each token (word or subword) is mapped to a dense vector of dimension d_model. Think of it as converting each token ID into a learned numeric representation. As in the original paper, the embedding vectors are multiplied by √d_model.
class InputEmbedding(nn.Module):
    """Map integer token ids to dense vectors of width ``d_model``.

    The looked-up vectors are multiplied by ``sqrt(d_model)``, as in the
    original Transformer paper.

    Args:
        vocab_size: Number of distinct token ids (rows in the lookup table).
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Learnable table with one row per token id: (vocab_size, d_model).
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Return ``embedding(x) * sqrt(d_model)``, shaped ``x.shape + (d_model,)``."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale


# Example: one sequence of four token ids.
vocab_size = 1000
d_model = 64
embed = InputEmbedding(vocab_size, d_model)
tokens = torch.tensor([[5, 13, 42, 7]])  # batch of 1, seq_len=4
output = embed(tokens)
print("Embedding shape:", output.shape)  # (1, 4, 64)
Positional Encoding
Since Transformers process all tokens simultaneously, they have no notion of order. Positional Encoding injects position information using sine and cosine functions at different frequencies.
Formula: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) | PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    ``PE[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``. The table is
    precomputed for ``max_len`` positions and registered as a buffer, so it
    is not trained but does move with the module on ``.to(...)`` and is
    stored in the ``state_dict``.

    Args:
        d_model: Embedding width. May be odd; the last column is then a
            sine column with no cosine partner.
        max_len: Longest sequence length the table supports.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the positional encoding matrix (max_len x d_model).
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # 1 / 10000^(2i/d_model), computed in log space; one entry per even
        # column, i.e. ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div_term)  # even columns -> sin
        # There are only floor(d_model / 2) odd columns. Trimming div_term
        # keeps an odd d_model from raising a shape-mismatch error; for an
        # even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])  # odd columns -> cos
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + PE[:seq_len])``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "of PositionalEncoding"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


pos_enc = PositionalEncoding(d_model=64)
out = pos_enc(output)
print("After positional encoding:", out.shape)  # (1, 4, 64)
Scaled Dot-Product Attention
This is the core of the Transformer. Given three matrices — Query (Q), Key (K), and Value (V) — attention computes how much each token should focus on every other token.
Formula: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        Q: Queries, shape (batch, heads, q_len, d_k).
        K: Keys, shape (batch, heads, k_len, d_k).
        V: Values, shape (batch, heads, k_len, d_v).
        mask: Optional tensor broadcastable to (batch, heads, q_len, k_len).
            Positions where ``mask == 0`` are excluded from attention.

    Returns:
        ``(output, attn_weights)`` with shapes (batch, heads, q_len, d_v) and
        (batch, heads, q_len, k_len). Each row of ``attn_weights`` sums to 1,
        except rows whose keys are all masked. Those rows are all zeros
        (so ``output`` is zero there) instead of NaN.
    """
    d_k = Q.size(-1)
    # Step 1: similarity scores, scaled by sqrt(d_k) -> (batch, heads, q_len, k_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: exclude masked positions (e.g. padding) from the softmax.
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))

    # Step 3: softmax over the key axis -> attention weights.
    attn_weights = F.softmax(scores, dim=-1)

    if mask is not None:
        # A row whose keys are all masked is a softmax over all -inf, which is
        # NaN and would propagate into every later computation. Zero such rows
        # instead. Reducing over the mask's own last axis is correct even when
        # that axis has size 1 and is broadcast across keys.
        fully_blocked = blocked.all(dim=-1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)

    # Step 4: weighted sum of values.
    output = torch.matmul(attn_weights, V)
    return output, attn_weights


# Quick test
batch, heads, seq_len, d_k = 1, 1, 4, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)
attn_out, weights = scaled_dot_product_attention(Q, K, V)
print("Attention output:", attn_out.shape)  # (1, 1, 4, 16)
print("Attention weights:", weights.shape)  # (1, 1, 4, 4)
print("Weights sum to 1:", weights[0, 0, 0].sum().item())
Attention output: torch.Size([1, 1, 4, 16])
Attention weights: torch.Size([1, 1, 4, 4])
Weights sum to 1: 1.0
Multi-Head Attention
Instead of one attention pass, we run h parallel attention heads, each learning different relationships. Their outputs are concatenated and projected back to d_model.
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, split into heads, attend, merge, project.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of parallel attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``. Without the check, ``split_heads`` would fail with a
        # confusing reshape error, or silently mis-split when the sizes
        # happen to line up.
        if d_model % num_heads != 0:
            raise ValueError(
                "d_model must be divisible by num_heads "
                f"(got d_model={d_model}, num_heads={num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # Four linear projections: Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from queries ``Q`` to keys ``K`` / values ``V``.

        Args:
            Q: Query inputs, shape (batch, q_len, d_model).
            K: Key inputs, shape (batch, k_len, d_model).
            V: Value inputs, shape (batch, k_len, d_model).
            mask: Optional mask broadcastable to
                (batch, num_heads, q_len, k_len). Entries equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model). The attention weights
            are not returned.
        """
        batch_size = Q.size(0)

        # Linear projections, then split the feature axis into heads
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)

        # Scaled dot-product attention on every head at once
        x, _attn = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous()
        x = x.view(batch_size, -1, self.d_model)

        # Final linear projection
        return self.W_o(x)


# Test multi-head attention
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
out = mha(x, x, x)  # self-attention: Q=K=V=x
print("MHA output:", out.shape)  # (2, 10, 64)
Feed-Forward Layer
After attention, each token's representation passes through a two-layer feed-forward network independently (same weights, applied position-wise). This adds non-linearity and capacity.
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    ``nn.Linear`` acts only on the last dimension, so the same two layers are
    applied independently at every sequence position.

    Args:
        d_model: Input and output width.
        d_ff: Hidden (expanded) width.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int = 256, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),  # widen: d_model -> d_ff
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),  # narrow: d_ff -> d_model
        ]
        # Parameters are exposed as net.0.* and net.3.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape (..., d_model)."""
        return self.net(x)


ff = FeedForward(d_model=64, d_ff=256)
out = ff(torch.randn(2, 10, 64))
print("FeedForward output:", out.shape)  # (2, 10, 64)
Encoder Block
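One encoder block = Multi-Head Attention + Feed-Forward, each wrapped with a residual connection and layer normalization. Residual connections give gradients a direct path through deep stacks, which mitigates vanishing gradients. Layer normalization stabilizes training. This implementation uses the original post-norm arrangement: LayerNorm(x + Dropout(Sublayer(x))).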
class EncoderBlock(nn.Module):
    """A single post-norm encoder layer: self-attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``::

        h   = norm1(x + dropout(attn(x, x, x, mask)))
        out = norm2(h + dropout(ff(h)))

    Args:
        d_model: Model width (input and output size).
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability on both residual branches; also passed
            to the feed-forward network.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int = 256, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, branch, norm):
        """Return ``norm(residual + dropout(branch))``."""
        return norm(residual + self.dropout(branch))

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x`` of shape (batch, seq_len, d_model).

        ``mask`` is forwarded unchanged to the self-attention sub-layer.
        """
        hidden = self._add_and_norm(x, self.attn(x, x, x, mask), self.norm1)
        return self._add_and_norm(hidden, self.ff(hidden), self.norm2)


encoder_block = EncoderBlock(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
out = encoder_block(x)
print("Encoder block output:", out.shape)  # (2, 10, 64)
Full Transformer Encoder
Stack N encoder blocks on top of each other to form the full encoder. The original paper uses N=6 with d_model=512, 8 heads, and d_ff=2048.
class TransformerEncoder(nn.Module):
    """Token-id encoder: embedding, sinusoidal positions, N encoder blocks, final LayerNorm.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Model width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked ``EncoderBlock`` layers.
        d_ff: Hidden width of each block's feed-forward sub-layer.
        max_len: Longest sequence the positional table supports.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size: int, d_model: int, num_heads: int,
                 num_layers: int, d_ff: int, max_len: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.embedding = InputEmbedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        # N identical (but independently initialised) encoder blocks
        blocks = (
            EncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode token ids ``x`` (batch, seq_len) into (batch, seq_len, d_model).

        ``mask`` is passed unchanged to every encoder block.
        """
        hidden = self.pos_enc(self.embedding(x))  # ids -> embeddings + positions
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)


# Instantiate a small transformer encoder
model = TransformerEncoder(
    vocab_size=5000,
    d_model=64,
    num_heads=8,
    num_layers=4,
    d_ff=256,
    dropout=0.1,
)
tokens = torch.randint(0, 5000, (2, 20))  # batch=2, seq_len=20
output = model(tokens)
print("Encoder output:", output.shape)  # (2, 20, 64)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # 520,064 with this configuration
Encoder output: torch.Size([2, 20, 64])
Total parameters: 520,064
Training a Simple Text Classifier
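Add a classification head on top of the encoder to train a sentiment or topic classifier. The head reads the encoder output at position 0 and treats it as a [CLS]-style summary of the sequence. A real tokenizer would prepend a dedicated [CLS] token so that this position is reserved for the summary; the random toy data below has no such token. Mean pooling over non-padding tokens is a common alternative.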
class TransformerClassifier(nn.Module):
    """Sequence classifier: ``TransformerEncoder`` plus an MLP head on position 0.

    The head reads only the encoder output at position 0, which plays the
    role of a ``[CLS]`` token. This is meaningful only when inputs start with
    a dedicated summary token. The random toy data below does not, so there
    position 0 is just an ordinary token.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Model width.
        num_heads: Attention heads per encoder block.
        num_layers: Number of encoder blocks.
        d_ff: Hidden width of each block's feed-forward sub-layer.
        num_classes: Number of output logits.
        dropout: Dropout probability (encoder and head).
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, num_classes, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, dropout=dropout
        )
        # Classification head: position-0 vector -> class logits
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )

    def forward(self, x, mask=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Token ids, shape (batch, seq_len).
            mask: Optional attention mask forwarded to the encoder, e.g. a
                padding mask with 0 at padded key positions. ``None`` (the
                default) attends to every position, as before.
        """
        enc_out = self.encoder(x, mask)  # (batch, seq, d_model)
        cls_tok = enc_out[:, 0, :]       # position 0 as the [CLS]-style summary
        return self.classifier(cls_tok)  # (batch, num_classes)


# ── Training loop ──────────────────────────────────────────
clf = TransformerClassifier(
    vocab_size=5000, d_model=64, num_heads=8,
    num_layers=2, d_ff=128, num_classes=2,
)
optimizer = torch.optim.Adam(clf.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Fake training data
X = torch.randint(0, 5000, (16, 20))  # 16 samples, seq_len=20
y = torch.randint(0, 2, (16,))        # binary labels

# Train for 5 epochs (one full-batch step per epoch)
clf.train()
for epoch in range(5):
    optimizer.zero_grad()
    logits = clf(X)
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
    # Training accuracy from the logits computed before this step's update.
    preds = logits.argmax(dim=-1)
    acc = (preds == y).float().mean().item()
    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f} | Acc: {acc:.2%}")
Epoch 2 | Loss: 0.6891 | Acc: 56.25%
Epoch 3 | Loss: 0.6624 | Acc: 62.50%
Epoch 4 | Loss: 0.6389 | Acc: 68.75%
Epoch 5 | Loss: 0.6201 | Acc: 75.00%
Next Steps: Replace random data with a real dataset (e.g., IMDB sentiment). Use a pre-trained tokenizer (HuggingFace). Scale up with more layers, larger d_model, and learning rate scheduling.