Transformer Architecture Deep Dive

When to Use This Skill

Use this skill when you need to:

  • Implement a Transformer from scratch
  • Understand HOW and WHY self-attention works
  • Choose between encoder, decoder, or encoder-decoder architectures
  • Decide if Vision Transformer (ViT) is appropriate for your vision task
  • Understand modern variants (RoPE, ALiBi, GQA, MQA)
  • Debug Transformer implementation issues
  • Optimize Transformer performance

Do NOT use this skill for:

  • High-level architecture selection (use using-neural-architectures)
  • Attention mechanism comparison (use attention-mechanisms-catalog)
  • LLM-specific topics like prompt engineering (use llm-specialist pack)

Core Principle

Transformers are NOT magic. They are:

  1. Self-attention mechanism (information retrieval)
  2. Position encoding (break permutation invariance)
  3. Residual connections + Layer norm (training stability)
  4. Feed-forward networks (non-linearity)

Understanding the mechanism beats cargo-culting implementations.

Part 1: Self-Attention Mechanism Explained

The Information Retrieval Analogy

Self-attention = Querying a database:

  • Query (Q): "What am I looking for?"
  • Key (K): "What can I be matched on?"
  • Value (V): "What content do I return when matched?"

Process:

  1. Compare your query with all keys (compute similarity)
  2. Weight values by similarity
  3. Return weighted sum of values

Example: Sentence: "The cat sat on the mat". For the token "sat" (verb):

  • High attention to "cat" (subject) → Learns verb-subject relationship
  • High attention to "mat" (object) → Learns verb-object relationship
  • Low attention to "the", "on" (function words)

Mathematical Breakdown

Given input X: (batch, seq_len, d_model)

Step 1: Project to Q, K, V

Q = X @ W_Q  # (batch, seq_len, d_k)
K = X @ W_K  # (batch, seq_len, d_k)
V = X @ W_V  # (batch, seq_len, d_v)

# Typically: d_k = d_v = d_model / num_heads

Step 2: Compute attention scores (similarity)

scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
# scores[i, j] = similarity between query_i and key_j

Geometric interpretation:

  • Dot product measures vector alignment
  • q · k = ||q|| ||k|| cos(θ)
  • Similar vectors → Large dot product → High attention
  • Orthogonal vectors → Zero dot product → No attention

Step 3: Scale by √d_k (CRITICAL!)

scores = scores / math.sqrt(d_k)

WHY scaling?

  • Dot products grow with dimension: Var(q · k) = d_k
  • Example: d_k=64 → random dot products have standard deviation √64 = 8, growing with d_k
  • Large scores → Softmax saturates → Gradients vanish
  • Scaling: Keep scores ~ O(1) regardless of dimension

Without scaling: Softmax([30, 25, 20]) ≈ [0.99, 0.01, 0.00] (saturated!)
With scaling: Softmax([3, 2.5, 2]) ≈ [0.50, 0.30, 0.20] (healthy gradients)
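
To see this concretely, here is a small standalone check (not part of the reference implementation below; the numbers are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 64
q = torch.randn(d_k)
keys = torch.randn(10, d_k)

raw_scores = keys @ q                    # std grows like sqrt(d_k) ≈ 8
scaled_scores = raw_scores / d_k ** 0.5  # std back to ~1

print(raw_scores.std())                  # roughly 8
print(F.softmax(raw_scores, dim=-1))     # typically very peaked (near-saturated)
print(F.softmax(scaled_scores, dim=-1))  # noticeably smoother distribution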

Step 4: Softmax to get attention weights

attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
# Each row sums to 1 (probability distribution)
# attn_weights[i, j] = "how much token i attends to token j"

Step 5: Weight values

output = attn_weights @ V  # (batch, seq_len, d_v)
# Each token's output = weighted average of all values

Complete formula:

Attention(Q, K, V) = softmax(Q K^T / √d_k) V

Why Three Matrices (Q, K, V)?

Could we use a single matrix, i.e., Attention(X, X, X)? Yes, but the Q/K/V separation enables:

  1. Asymmetry: Query can differ from key (search ≠ database)
  2. Decoupling: What you search for (Q@K) ≠ what you retrieve (V)
  3. Cross-attention: Q from one source, K/V from another
    • Example: Decoder queries encoder (translation)

Modern optimization: Multi-Query Attention (MQA), Grouped-Query Attention (GQA)

  • Share K/V across heads (fewer parameters, faster inference)

Implementation Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k=None):
        super().__init__()
        self.d_k = d_k or d_model

        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_k)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (batch, seq_len, seq_len)

        # Apply mask if provided (for causal attention)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)

        # Weighted sum of values
        output = torch.matmul(attn_weights, V)  # (batch, seq_len, d_k)

        return output, attn_weights

Complexity: O(n² · d) where n = seq_len, d = d_model

  • Quadratic in sequence length (bottleneck for long sequences)
  • For n=1000, d=512: 1000² × 512 = 512M operations

Part 2: Multi-Head Attention

Why Multiple Heads?

Single-head attention learns one attention pattern. Multi-head attention learns multiple parallel patterns:

  • Head 1: Syntactic relationships (subject-verb)
  • Head 2: Semantic similarity
  • Head 3: Positional proximity
  • Head 4: Long-range dependencies

Analogy: Ensemble of attention functions, each specializing in different patterns.

Head Dimension Calculation

CRITICAL CONSTRAINT: num_heads must divide d_model evenly!

d_model = 512
num_heads = 8
d_k = d_model // num_heads  # 512 / 8 = 64

# Each head operates in d_k dimensions
# Concatenate all heads → back to d_model dimensions

Common configurations:

  • BERT-base: d_model=768, heads=12, d_k=64
  • GPT-2: d_model=768, heads=12, d_k=64
  • GPT-3 175B: d_model=12288, heads=96, d_k=128
  • LLaMA-2 70B: d_model=8192, heads=64, d_k=128

Rule of thumb: d_k (head dimension) should be 64-128

  • Too small (d_k < 32): Limited representational capacity
  • Too large (d_k > 256): Redundant, wasteful

Multi-Head Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Single linear layers for all heads (more efficient)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # Output projection

    def split_heads(self, x):
        # x: (batch, seq_len, d_model)
        batch_size, seq_len, d_model = x.size()
        # Reshape to (batch, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        # Transpose to (batch, num_heads, seq_len, d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size = x.size(0)

        # Linear projections
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum
        attn_output = torch.matmul(attn_weights, V)
        # attn_output: (batch, num_heads, seq_len, d_k)

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        # (batch, seq_len, num_heads, d_k)
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # (batch, seq_len, d_model)

        # Final linear projection
        output = self.W_o(attn_output)

        return output, attn_weights

Modern Variants: GQA and MQA

Problem: K/V caching during inference is memory-intensive

  • LLaMA-2 70B (with full multi-head attention): 80 layers × 2 (K + V) × 8192 dims ≈ 1.3M values cached per token (~2.6 MB in fp16)!

Solution 1: Multi-Query Attention (MQA)

  • One K/V head shared across all Q heads
  • Benefit: Dramatically faster inference (smaller KV cache)
  • Trade-off: ~1-2% accuracy loss
# MQA: Single K/V projection
self.W_k = nn.Linear(d_model, d_k)  # Not d_model!
self.W_v = nn.Linear(d_model, d_k)
self.W_q = nn.Linear(d_model, d_model)  # Multiple Q heads

Solution 2: Grouped-Query Attention (GQA)

  • Middle ground: Group multiple Q heads per K/V head
  • Example: 32 Q heads → 8 K/V heads (4 Q per K/V)
  • Benefit: 4x smaller KV cache, minimal accuracy loss

Used in: LLaMA-2, Mistral, Mixtral
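
A minimal GQA sketch (illustrative head counts and names, not any specific model's code): K/V are projected to fewer heads, then each K/V head is repeated across its group of query heads before standard attention.

import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """Illustrative GQA block: num_q_heads query heads share num_kv_heads K/V heads."""
    def __init__(self, d_model=512, num_q_heads=8, num_kv_heads=2):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.d_k = d_model // num_q_heads
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_q_heads // num_kv_heads
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_k)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)  # smaller K/V projections → smaller KV cache
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(num_q_heads * self.d_k, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.num_q_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Repeat each K/V head so every query head in its group attends to the same K/V
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.W_o(out)

Setting num_kv_heads=1 recovers MQA; setting num_kv_heads=num_q_heads recovers standard multi-head attention.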

Part 3: Position Encoding

Why Position Encoding?

Problem: Self-attention has no inherent notion of order (it is permutation-equivariant)

  • Attention("cat sat mat") and Attention("mat cat sat") produce the same token representations, just reordered
  • Without position information, word order is invisible to the model

Solution: Add position information to embeddings

Strategy 1: Sinusoidal Position Encoding (Original)

Formula:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Implementation:

def sinusoidal_position_encoding(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                         (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# Usage: Add to input embeddings
x = token_embeddings + positional_encoding

Properties:

  • Deterministic (no learned parameters)
  • Extrapolates to unseen lengths (geometric properties)
  • Relative positions: PE(pos+k) is linear function of PE(pos)

When to use: Variable-length sequences in NLP

Strategy 2: Learned Position Embeddings

self.pos_embedding = nn.Embedding(max_seq_len, d_model)

# Usage
positions = torch.arange(seq_len, device=x.device)
x = token_embeddings + self.pos_embedding(positions)

Properties:

  • Learnable (adapts to data)
  • Cannot extrapolate beyond max_seq_len

When to use:

  • Fixed-length sequences
  • Vision Transformers (image patches)
  • When training data covers all positions

Strategy 3: Rotary Position Embeddings (RoPE)

Modern approach (2021+): Rotate Q and K in complex plane

Key advantages:

  • Encodes relative positions naturally
  • Better long-range decay properties
  • No addition to embeddings (applied in attention)

Used in: GPT-NeoX, PaLM, LLaMA, LLaMA-2, Mistral

def apply_rotary_pos_emb(x, cos, sin):
    # x: (batch, num_heads, seq_len, d_k)
    # Split into even/odd
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Rotate
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
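
For completeness, a sketch of how the cos/sin tables used above could be built (rope_cos_sin is an illustrative helper, not a library function; real implementations precompute and cache these):

import torch

def rope_cos_sin(seq_len, d_k, base=10000.0):
    # One rotation frequency per pair of dimensions (same spirit as sinusoidal PE)
    inv_freq = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, d_k/2)
    # Add singleton batch/head dims so they broadcast against the (batch, num_heads, seq_len, d_k/2) halves
    return angles.cos()[None, None], angles.sin()[None, None]

# Usage: rotate Q and K (not V) before the attention dot product
# cos, sin = rope_cos_sin(seq_len, d_k)
# Q = apply_rotary_pos_emb(Q, cos, sin)
# K = apply_rotary_pos_emb(K, cos, sin)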

Strategy 4: ALiBi (Attention with Linear Biases)

Simplest modern approach: Add bias to attention scores (no embeddings!)

# Bias matrix: -slope * distance to the key (one slope per head; slope = 1 shown)
# Causal form (row i = query, column j = key; future positions masked):
# [[  0,   -,   -,   -],
#  [ -1,   0,   -,   -],
#  [ -2,  -1,   0,   -],
#  [ -3,  -2,  -1,   0]]

scores = Q @ K^T / √d_k + alibi_bias

Key advantages:

  • Best extrapolation to longer sequences
  • No positional embeddings (simpler)
  • Per-head slopes (different decay rates)

Used in: BLOOM
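
A sketch of how the causal ALiBi bias could be constructed (helper name and shapes are illustrative; the slopes follow the paper's geometric scheme for power-of-two head counts):

import torch

def build_alibi_bias(seq_len, num_heads):
    # Per-head slopes: 1/2, 1/4, ..., 1/256 for 8 heads (geometric sequence)
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).clamp(min=0)  # how far key j lies behind query i
    bias = -slopes[:, None, None] * distance               # (num_heads, seq_len, seq_len)
    return bias

# scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) + build_alibi_bias(seq_len, num_heads)
# The causal mask is still applied on top to hide future positions.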

Position Encoding Selection Guide

Use Case              | Recommended           | Why
NLP (variable length) | RoPE or ALiBi         | Better extrapolation
NLP (fixed length)    | Learned embeddings    | Adapts to data
Vision (ViT)          | 2D learned embeddings | Spatial structure
Long sequences (>2k)  | ALiBi                 | Best extrapolation
Legacy/compatibility  | Sinusoidal            | Original Transformer

Modern trend (2023+): RoPE and ALiBi dominate over sinusoidal

Part 4: Architecture Variants

Variant 1: Encoder-Only (Bidirectional)

Architecture:

  • Self-attention: Each token attends to ALL tokens (past + future)
  • No masking (bidirectional context)

Examples: BERT, RoBERTa, ELECTRA, DeBERTa

Use cases:

  • Text classification
  • Named entity recognition
  • Question answering (extract span from context)
  • Sentence embeddings

Key property: Sees full context → Good for understanding

Implementation:

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)  # No causal mask!
        return x

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention + residual + norm
        attn_output, _ = self.self_attn(x, mask)
        x = self.norm1(x + attn_output)

        # Feed-forward + residual + norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)

        return x

Variant 2: Decoder-Only (Autoregressive)

Architecture:

  • Self-attention with causal masking
  • Each token attends ONLY to past tokens (not future)

Causal mask (lower triangular):

# mask[i, j] = 1 if j <= i else 0
[[1, 0, 0, 0],   # Token 0 sees only itself
 [1, 1, 0, 0],   # Token 1 sees tokens 0-1
 [1, 1, 1, 0],   # Token 2 sees tokens 0-2
 [1, 1, 1, 1]]   # Token 3 sees all

Examples: GPT, GPT-2, GPT-3, GPT-4, LLaMA, Mistral

Use cases:

  • Text generation
  • Language modeling
  • Code generation
  • Autoregressive prediction

Key property: Generates sequentially → Good for generation

Implementation:

def create_causal_mask(seq_len, device):
    # Lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask

class TransformerDecoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        seq_len = x.size(1)
        causal_mask = create_causal_mask(seq_len, x.device)

        for layer in self.layers:
            x = layer(x, causal_mask)  # Apply causal mask!
        return x
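
DecoderLayer is referenced above but not defined; one way to fill it in, mirroring the EncoderLayer from Variant 1 (the causal mask is the only difference in how it is used):

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, causal_mask):
        # Masked self-attention + residual + norm
        attn_output, _ = self.self_attn(x, causal_mask)
        x = self.norm1(x + attn_output)

        # Feed-forward + residual + norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)

        return x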

Modern trend (2023+): Decoder-only architectures dominate

  • Can do both generation AND understanding (via prompting)
  • Simpler than encoder-decoder (no cross-attention)
  • Scales better to massive sizes

Variant 3: Encoder-Decoder (Seq2Seq)

Architecture:

  • Encoder: Bidirectional self-attention (understands input)
  • Decoder: Causal self-attention (generates output)
  • Cross-attention: Decoder queries encoder outputs

Cross-attention mechanism:

# Q from decoder, K and V from encoder
Q = decoder_hidden @ W_q
K = encoder_output @ W_k
V = encoder_output @ W_v

cross_attn = softmax(Q K^T / √d_k) V

Examples: T5, BART, mT5, original Transformer (2017)

Use cases:

  • Translation (input ≠ output language)
  • Summarization (long input → short output)
  • Question answering (generate answer, not extract)

When to use: Input and output are fundamentally different

Implementation:

class EncoderDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  # NEW!
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        # 1. Self-attention (causal)
        self_attn_out, _ = self.self_attn(decoder_input, causal_mask)
        x = self.norm1(decoder_input + self_attn_out)

        # 2. Cross-attention (Q from decoder, K/V from encoder)
        cross_attn_out, _ = self.cross_attn.forward_cross(
            query=x,
            key=encoder_output,
            value=encoder_output
        )
        x = self.norm2(x + cross_attn_out)

        # 3. Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + ff_out)

        return x
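
Note that the MultiHeadAttention class from Part 2 takes a single input x, so forward_cross above assumes a variant that accepts separate query and key/value tensors. A sketch of that generalization (class and method names are placeholders used in this document):

class CrossMultiHeadAttention(MultiHeadAttention):
    def forward_cross(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Q from the decoder stream, K/V from the encoder output
        Q = self.split_heads(self.W_q(query))   # (batch, heads, tgt_len, d_k)
        K = self.split_heads(self.W_k(key))     # (batch, heads, src_len, d_k)
        V = self.split_heads(self.W_v(value))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)

        attn_output = torch.matmul(attn_weights, V)            # (batch, heads, tgt_len, d_k)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        return self.W_o(attn_output), attn_weights

With this, self.cross_attn in EncoderDecoderLayer would be a CrossMultiHeadAttention instance, and FeedForward stands for the same Linear → ReLU → Linear block used in EncoderLayer.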

Architecture Selection Guide

Task            | Architecture                    | Why
Classification  | Encoder-only                    | Need full bidirectional context
Text generation | Decoder-only                    | Autoregressive generation
Translation     | Encoder-decoder or Decoder-only | Different languages, or use prompting
Summarization   | Encoder-decoder or Decoder-only | Length mismatch, or use prompting
Q&A (extract)   | Encoder-only                    | Find span in context
Q&A (generate)  | Decoder-only                    | Generate freeform answer

2023+ trend: Decoder-only can do everything via prompting (but less parameter-efficient for some tasks)

Part 5: Vision Transformers (ViT)

From Images to Sequences

Key insight: Treat image as sequence of patches

Process:

  1. Split image into patches (e.g., 16×16 pixels)
  2. Flatten each patch → 1D vector
  3. Linear projection → token embeddings
  4. Add 2D positional embeddings
  5. Prepend [CLS] token (for classification)
  6. Feed to Transformer encoder

Example: 224×224 image, 16×16 patches

  • Number of patches: (224/16)² = 196
  • Each patch: 16 × 16 × 3 = 768 dimensions
  • Transformer input: 197 tokens (196 patches + 1 [CLS])

ViT Implementation

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 d_model=768, num_heads=12, num_layers=12, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size ** 2

        # Patch embedding (linear projection of flattened patches)
        self.patch_embed = nn.Linear(patch_dim, d_model)

        # [CLS] token (learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Position embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        # Transformer encoder
        self.encoder = TransformerEncoder(d_model, num_heads,
                                         d_ff=4*d_model, num_layers=num_layers)

        # Classification head
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, channels, height, width)
        batch_size = x.size(0)

        # Divide into patches
        x = x.unfold(2, self.patch_size, self.patch_size)
        x = x.unfold(3, self.patch_size, self.patch_size)
        # (batch, channels, num_patches_h, num_patches_w, patch_size, patch_size)

        # Group each patch's channels and pixels together, then flatten
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, -1, self.patch_size ** 2 * 3)
        # (batch, num_patches, patch_dim)

        # Linear projection
        x = self.patch_embed(x)  # (batch, num_patches, d_model)

        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, num_patches+1, d_model)

        # Add positional embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.encoder(x)

        # Classification: Use [CLS] token
        cls_output = x[:, 0]  # (batch, d_model)
        logits = self.head(cls_output)

        return logits

ViT vs CNN: Critical Differences

1. Inductive Bias

Property               | CNN                     | ViT
Locality               | Strong (conv kernel)    | Weak (global attention)
Translation invariance | Strong (weight sharing) | Weak (position embeddings)
Hierarchy              | Strong (pooling layers) | None (flat patches)

Implication: CNN has strong priors, ViT learns from data

2. Data Requirements

Dataset Size     | CNN       | ViT (from scratch) | ViT (pretrained)
Small (< 100k)   | Good      | Fails              | Good
Medium (100k-1M) | Excellent | ⚠️ Poor            | Good
Large (> 1M)     | Excellent | ⚠️ OK              | Excellent
Huge (> 100M)    | Excellent | SOTA               | N/A

Key finding: ViT needs on the order of 100M+ images to beat CNNs when trained from scratch

  • Original ViT: Trained on JFT-300M (300 million images)
  • Without massive data, ViT underperforms CNNs significantly

3. Computational Cost

Example: 224×224 images

Model           | Parameters | GFLOPs | Inference (GPU)
ResNet-50       | 25M        | 4.1    | ~30ms
EfficientNet-B0 | 5M         | 0.4    | ~10ms
ViT-B/16        | 86M        | 17.6   | ~100ms

Implication: ViT-B/16 costs roughly 40x the FLOPs of EfficientNet-B0!

When to Use ViT

Use ViT when:

  • Large dataset (> 1M images) OR using pretrained weights
  • Computational cost acceptable (cloud, large GPU)
  • Best possible accuracy needed
  • Can fine-tune from ImageNet-21k checkpoint

Use CNN when:

  • Small/medium dataset (< 1M images) and training from scratch
  • Limited compute/memory
  • Edge deployment (mobile, embedded)
  • Need architectural inductive biases

Hybrid Approaches (2022-2023)

ConvNeXt: CNN with ViT design choices

  • Matches ViT accuracy with CNN efficiency
  • Works better on small datasets

Swin Transformer: Hierarchical ViT with local windows

  • Shifted windows for efficiency
  • O(n) complexity instead of O(n²)
  • Better for dense prediction (segmentation)

CoAtNet: Mix conv layers (early) + Transformer layers (late)

  • Gets both inductive bias and global attention

Part 6: Implementation Checklist

Critical Details

1. Layer Norm Placement

Post-norm (original):

x = x + self_attn(x)
x = layer_norm(x)

Pre-norm (modern, recommended):

x = x + self_attn(layer_norm(x))

Why pre-norm? More stable training, less sensitive to learning rate
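
A compact pre-norm encoder block for reference (same components as the EncoderLayer above; only the norm placement changes):

class PreNormEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Norm is applied to the branch input; the residual path stays untouched
        attn_output, _ = self.self_attn(self.norm1(x), mask)
        x = x + attn_output
        x = x + self.feed_forward(self.norm2(x))
        return x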

2. Attention Dropout

Apply dropout to attention weights, not Q/K/V!

attn_weights = F.softmax(scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=0.1, training=self.training)  # HERE!
output = torch.matmul(attn_weights, V)

3. Feed-Forward Dimension

Typically: d_ff = 4 × d_model

  • BERT: d_model=768, d_ff=3072
  • GPT-2: d_model=768, d_ff=3072

4. Residual Connections

ALWAYS use residual connections (essential for training)!

x = x + self_attn(x)  # Residual
x = x + feed_forward(x)  # Residual

5. Initialization

Use Xavier/Glorot initialization for attention weights:

nn.init.xavier_uniform_(self.W_q.weight)
nn.init.xavier_uniform_(self.W_k.weight)
nn.init.xavier_uniform_(self.W_v.weight)

Part 7: When NOT to Use Transformers

Limitation 1: Small Datasets

Problem: Transformers have weak inductive bias (learn from data)

Impact:

  • ViT: Fails on < 100k images without pretraining
  • NLP: BERT-style pretraining uses billions of words (~3.3B for the original BERT)

Solution: Use models with stronger priors (CNN for vision, smaller models for text)

Limitation 2: Long Sequences

Problem: O(n²) memory complexity

Impact:

  • Standard Transformer: n=10k → 100M attention scores
  • GPU memory: 10k² × 4 bytes = 400MB for a single attention matrix (per head, per layer, fp32)!

Solution:

  • Sparse attention (Longformer, BigBird)
  • Linear attention (Linformer, Performer)
  • Flash Attention (memory-efficient kernel)
  • State space models (S4, Mamba)

Limitation 3: Edge Deployment

Problem: Large model size, high latency

Impact:

  • ViT-B: 86M parameters, ~100ms inference
  • Mobile/embedded: Need < 10M parameters, < 50ms

Solution: Efficient CNNs (MobileNet, EfficientNet) or distilled models

Limitation 4: Real-Time Processing

Problem: Sequential generation in decoder (cannot parallelize at inference)

Impact: GPT-style models generate one token at a time

Solution: Non-autoregressive models, speculative decoding, or smaller models

Part 8: Common Mistakes

Mistake 1: Forgetting Causal Mask

Symptom: Decoder "cheats" by seeing future tokens

Fix: Always apply causal mask to decoder self-attention!

causal_mask = torch.tril(torch.ones(seq_len, seq_len))

Mistake 2: Wrong Dimension for Multi-Head

Symptom: Runtime error or dimension mismatch

Fix: Ensure d_model % num_heads == 0

assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

Mistake 3: Forgetting Position Encoding

Symptom: Model ignores word order

Fix: Always add position information!

x = token_embeddings + positional_encoding

Mistake 4: Wrong Softmax Dimension

Symptom: Attention weights don't sum to 1 per query

Fix: Softmax over last dimension (keys)

attn_weights = F.softmax(scores, dim=-1)  # Sum over keys for each query

Mistake 5: No Residual Connections

Symptom: Training diverges or converges very slowly

Fix: Always add residual connections!

x = x + self_attn(x)
x = x + feed_forward(x)

Summary: Quick Reference

Architecture Selection

Classification/Understanding → Encoder-only (BERT-style)
Generation/Autoregressive → Decoder-only (GPT-style)
Seq2Seq (input ≠ output) → Encoder-decoder (T5-style) or Decoder-only with prompting

Position Encoding Selection

NLP (variable length) → RoPE or ALiBi
NLP (fixed length) → Learned embeddings
Vision (ViT) → 2D learned embeddings
Long sequences (> 2k) → ALiBi (best extrapolation)

Multi-Head Configuration

Small models (d_model < 512): 4-8 heads
Medium models (d_model 512-1024): 8-12 heads
Large models (d_model > 1024): 12-32 heads
Rule: d_k (head dimension) should be 64-128

ViT vs CNN

ViT: Large dataset (> 1M) OR pretrained weights
CNN: Small dataset (< 1M) OR edge deployment

Implementation Essentials

✅ Pre-norm (more stable than post-norm)
✅ Residual connections (essential!)
✅ Causal mask for decoder
✅ Attention dropout (on weights, not Q/K/V)
✅ d_ff = 4 × d_model (feed-forward dimension)
✅ Check: d_model % num_heads == 0

Next Steps

After mastering this skill:

  • attention-mechanisms-catalog: Explore attention variants (sparse, linear, Flash)
  • llm-specialist/llm-finetuning-strategies: Apply to language models
  • architecture-design-principles: Understand design trade-offs

Remember: Transformers are NOT magic. Understanding the mechanism (information retrieval via Q/K/V) beats cargo-culting implementations.