Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 09:00:00 +08:00
commit 955d5c6743
12 changed files with 6996 additions and 0 deletions


@@ -0,0 +1,12 @@
{
"name": "yzmir-neural-architectures",
"description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills",
"version": "1.0.1",
"author": {
"name": "tachyon-beep",
"url": "https://github.com/tachyon-beep"
},
"skills": [
"./skills"
]
}

README.md Normal file

@@ -0,0 +1,3 @@
# yzmir-neural-architectures
Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills

plugin.lock.json Normal file

@@ -0,0 +1,77 @@
{
"$schema": "internal://schemas/plugin.lock.v1.json",
"pluginId": "gh:tachyon-beep/skillpacks:plugins/yzmir-neural-architectures",
"normalized": {
"repo": null,
"ref": "refs/tags/v20251128.0",
"commit": "52681919f25de737a14d958d80b685171797dfdb",
"treeHash": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc",
"generatedAt": "2025-11-28T10:28:34.235937Z",
"toolVersion": "publish_plugins.py@0.2.0"
},
"origin": {
"remote": "git@github.com:zhongweili/42plugin-data.git",
"branch": "master",
"commit": "aa1497ed0949fd50e99e70d6324a29c5b34f9390",
"repoRoot": "/Users/zhongweili/projects/openmind/42plugin-data"
},
"manifest": {
"name": "yzmir-neural-architectures",
"description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills",
"version": "1.0.1"
},
"content": {
"files": [
{
"path": "README.md",
"sha256": "cf12418cc5cc226683c463dca6c54e98f97f2ddf57482d593839f430acc36b86"
},
{
"path": ".claude-plugin/plugin.json",
"sha256": "9893529f120b1582ad62d899531401988bdf89d6b221e057239e09399d47e255"
},
{
"path": "skills/using-neural-architectures/architecture-design-principles.md",
"sha256": "bd1cde3e2cf997b6e42dc52bc28638e644ee387c81b07464edc7fb4cf33d4329"
},
{
"path": "skills/using-neural-architectures/sequence-models-comparison.md",
"sha256": "5c041c08261cf04e47edd3ec092db04a8b4c725fb6bb91620df54005881cbb66"
},
{
"path": "skills/using-neural-architectures/normalization-techniques.md",
"sha256": "c87c82e926951a963d7909321c60a62b5cf6d09dbcbb576472eda46d72993766"
},
{
"path": "skills/using-neural-architectures/graph-neural-networks-basics.md",
"sha256": "85a968972f0c1c226d6bcb155b231560af0eab52326f419d308b1a6a22ebd7c5"
},
{
"path": "skills/using-neural-architectures/generative-model-families.md",
"sha256": "7c527707a0d3ecab99e83079880b1166c4310de62a743823ae747a03545180b3"
},
{
"path": "skills/using-neural-architectures/cnn-families-and-selection.md",
"sha256": "acce2a55a84012259119e70d3eb035aee08beb124d0df1aab009ae7fdd01ff51"
},
{
"path": "skills/using-neural-architectures/SKILL.md",
"sha256": "b8dc5fac750ca1740bd97591409701ca6162e0d5dc5faaf7a2cfcf7feef432b3"
},
{
"path": "skills/using-neural-architectures/attention-mechanisms-catalog.md",
"sha256": "3eab5267c7e5a1263d08be2dca4ef236ad66b554e08aec0e0cd85d5c1f8b5500"
},
{
"path": "skills/using-neural-architectures/transformer-architecture-deepdive.md",
"sha256": "1913626a82ff02f638d10aa00644fcc59fc9836329dc152a7becd9aa030e0937"
}
],
"dirSha256": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc"
},
"security": {
"scannedAt": null,
"scannerVersion": null,
"flags": []
}
}


@@ -0,0 +1,496 @@
---
name: using-neural-architectures
description: "Architecture selection router: CNNs, Transformers, RNNs, GANs, GNNs by data modality and constraints"
mode: true
pack: neural-architectures
faction: yzmir
---
# Using Neural Architectures: Architecture Selection Router
<CRITICAL_CONTEXT>
Architecture selection comes BEFORE training optimization. Wrong architecture = no amount of training will fix it.
This meta-skill routes you to the right architecture guidance based on:
- Data modality (images, sequences, graphs, etc.)
- Problem type (classification, generation, regression)
- Constraints (data size, compute, latency, interpretability)
Load this skill when architecture decisions are needed.
</CRITICAL_CONTEXT>
## When to Use This Skill
Use this skill when:
- ✅ Selecting an architecture for a new problem
- ✅ Comparing architecture families (CNN vs Transformer, RNN vs Transformer, etc.)
- ✅ Designing custom network topology
- ✅ Troubleshooting architectural instability (deep networks, gradient issues)
- ✅ Understanding when to use specialized architectures (GNNs, generative models)
DO NOT use for:
- ❌ Training/optimization issues (use training-optimization pack)
- ❌ PyTorch implementation details (use pytorch-engineering pack)
- ❌ Production deployment (use ml-production pack)
**When in doubt:** If choosing WHAT architecture → this skill. If training/deploying architecture → different pack.
---
## Core Routing Logic
### Step 1: Identify Data Modality
**Question to ask:** "What type of data are you working with?"
| Data Type | Route To | Why |
|-----------|----------|-----|
| Images (photos, medical scans, etc.) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | CNNs excel at spatial hierarchies |
| Sequences (time series, text, audio) | [sequence-models-comparison.md](sequence-models-comparison.md) | Temporal dependencies need sequential models |
| Graphs (social networks, molecules) | [graph-neural-networks-basics.md](graph-neural-networks-basics.md) | Graph structure requires GNNs |
| Generation task (create images, text) | [generative-model-families.md](generative-model-families.md) | Generative models are specialized |
| Multiple modalities (text + images) | [architecture-design-principles.md](architecture-design-principles.md) | Need custom design |
| Unclear / Generic | [architecture-design-principles.md](architecture-design-principles.md) | Start with fundamentals |
### Step 2: Check for Special Requirements
**If any of these apply, address FIRST:**
| Requirement | Route To | Priority |
|-------------|----------|----------|
| Deep network (> 20 layers) unstable | [normalization-techniques.md](normalization-techniques.md) | CRITICAL - fix before continuing |
| Need attention mechanisms | [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) | Specialized component |
| Custom architecture design | [architecture-design-principles.md](architecture-design-principles.md) | Foundation before specifics |
| Transformer-specific question | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | Specialized architecture |
### Step 3: Consider Problem Characteristics
**Clarify BEFORE routing:**
Ask:
- "How large is your dataset?" (Small < 10k, Medium 10k-1M, Large > 1M)
- "What are your computational constraints?" (Edge device, cloud, GPU availability)
- "What are your latency requirements?" (Real-time, batch, offline)
- "Do you need interpretability?" (Clinical, research, production)
These answers determine architecture appropriateness.
---
## Routing by Data Modality
### Images → CNN Families
**Symptoms triggering this route:**
- "classify images"
- "object detection"
- "semantic segmentation"
- "medical imaging"
- "computer vision"
**Route to:** See [cnn-families-and-selection.md](cnn-families-and-selection.md) for CNN architecture selection and comparison.
**When to route here:**
- ANY vision task (CNNs are default for spatial data)
- Even if considering Transformers, check CNN families first (often better with less data)
**Clarifying questions:**
- "Dataset size?" (< 10k → Start with proven CNNs, > 100k → Consider ViT)
- "Deployment target?" (Edge → EfficientNet, Cloud → Anything)
- "Task type?" (Classification → ResNet/EfficientNet, Detection → YOLO/Faster-RCNN)
---
### Sequences → Sequence Models Comparison
**Symptoms triggering this route:**
- "time series"
- "forecasting"
- "natural language" (NLP)
- "sequential data"
- "temporal patterns"
- "RNN vs LSTM vs Transformer"
**Route to:** See [sequence-models-comparison.md](sequence-models-comparison.md) for sequential model selection (RNN, LSTM, Transformer, TCN).
**When to route here:**
- ANY sequential data
- When user asks "RNN vs LSTM" (skill will present modern alternatives)
- Time-dependent patterns
**Clarifying questions:**
- "Sequence length?" (< 100 → RNN/LSTM/TCN, 100-1000 → Transformer, > 1000 → Sparse Transformers)
- "Latency requirements?" (Real-time → TCN/LSTM, Offline → Transformer)
- "Data volume?" (Small → Simpler models, Large → Transformers)
**CRITICAL:** Challenge "RNN vs LSTM" premise if they ask. Modern alternatives (Transformers, TCN) often better.
---
### Graphs → Graph Neural Networks
**Symptoms triggering this route:**
- "social network"
- "molecular structure"
- "knowledge graph"
- "graph data"
- "node classification"
- "link prediction"
- "graph embeddings"
**Route to:** See [graph-neural-networks-basics.md](graph-neural-networks-basics.md) for GNN architectures and graph learning.
**When to route here:**
- Data has explicit graph structure (nodes + edges)
- Relational information is important
- Network topology matters
**Red flag:** If treating graph as tabular data (extracting features and ignoring edges) → WRONG. Route to GNN skill.
---
### Generation → Generative Model Families
**Symptoms triggering this route:**
- "generate images"
- "synthesize data"
- "GAN vs VAE vs Diffusion"
- "image-to-image translation"
- "style transfer"
- "generative modeling"
**Route to:** See [generative-model-families.md](generative-model-families.md) for GANs, VAEs, and Diffusion models.
**When to route here:**
- Goal is to CREATE data, not classify/predict
- Need to sample from distribution
- Data augmentation through generation
**Clarifying questions:**
- "Use case?" (Real-time game → GAN, Art/research → Diffusion, Fast training → VAE)
- "Quality vs speed?" (Quality → Diffusion, Speed → GAN)
- "Controllability?" (Fine control → StyleGAN/Conditional models)
**CRITICAL:** Different generative models have VERY different trade-offs. Must clarify requirements.
---
## Routing by Architecture Component
### Attention Mechanisms
**Symptoms triggering this route:**
- "when to use attention"
- "self-attention vs cross-attention"
- "attention in CNNs"
- "attention bottleneck"
- "multi-head attention"
**Route to:** See [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) for attention mechanism selection and design.
**When to route here:**
- Designing custom architecture that might benefit from attention
- Understanding where attention helps vs hinders
- Comparing attention variants
**NOT for:** General Transformer questions → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) instead
---
### Transformer Deep Dive
**Symptoms triggering this route:**
- "how do transformers work"
- "Vision Transformer (ViT)"
- "BERT architecture"
- "positional encoding"
- "transformer blocks"
- "scaling transformers"
**Route to:** See [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) for Transformer internals and implementation.
**When to route here:**
- Implementing/customizing transformers
- Understanding transformer internals
- Debugging transformer-specific issues
**Cross-reference:**
- For sequence models generally → [sequence-models-comparison.md](sequence-models-comparison.md) (includes transformers in context)
- For LLMs specifically → `yzmir/llm-specialist/transformer-for-llms` (LLM-specific transformers)
---
### Normalization Techniques
**Symptoms triggering this route:**
- "gradient explosion"
- "training instability in deep network"
- "BatchNorm vs LayerNorm"
- "normalization layers"
- "50+ layer network won't train"
**Route to:** See [normalization-techniques.md](normalization-techniques.md) for deep network stability and normalization methods.
**When to route here:**
- Deep networks (> 20 layers) with training instability
- Choosing between normalization methods
- Architectural stability issues
**CRITICAL:** This is often the ROOT CAUSE of "training won't work" - fix architecture before blaming hyperparameters.
---
### Architecture Design Principles
**Symptoms triggering this route:**
- "how to design architecture"
- "architecture best practices"
- "when to use skip connections"
- "how deep should network be"
- "custom architecture for [novel task]"
- Unclear problem modality
**Route to:** See [architecture-design-principles.md](architecture-design-principles.md) for custom architecture design fundamentals.
**When to route here:**
- Designing custom architectures
- Novel problems without established architecture
- Understanding WHY architectures work
- User is unsure what modality/problem type they have
**This is the foundational skill** - route here if other specific skills don't match.
---
## Multi-Modal / Cross-Pack Routing
### When Problem Spans Multiple Modalities
**Example:** "Text + image classification" (multimodal)
**Route to BOTH:**
1. [sequence-models-comparison.md](sequence-models-comparison.md) (for text)
2. [cnn-families-and-selection.md](cnn-families-and-selection.md) (for images)
3. [architecture-design-principles.md](architecture-design-principles.md) (for fusion strategy)
**Order matters:** Understand individual modalities BEFORE fusion.
### When Architecture + Other Concerns
**Example:** "Select architecture AND optimize training"
**Route order:**
1. Architecture skill FIRST (this pack)
2. Training-optimization SECOND (after architecture chosen)
**Why:** Wrong architecture can't be fixed by better training.
**Example:** "Select architecture AND deploy efficiently"
**Route order:**
1. Architecture skill FIRST
2. ML-production SECOND (quantization, serving)
**Deployment constraints might influence architecture choice** - if so, note constraints during architecture selection.
---
## Common Routing Mistakes (DON'T DO THESE)
| Symptom | Wrong Route | Correct Route | Why |
|---------|-------------|---------------|-----|
| "My transformer won't train" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | training-optimization | Training issue, not architecture understanding |
| "Deploy image classifier" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | ml-production | Deployment, not selection |
| "ViT vs ResNet for medical imaging" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | Comparative selection, not single architecture detail |
| "Implement BatchNorm in PyTorch" | [normalization-techniques.md](normalization-techniques.md) | pytorch-engineering | Implementation, not architecture concept |
| "GAN won't converge" | [generative-model-families.md](generative-model-families.md) | training-optimization | Training stability, not architecture selection |
| "Which optimizer for CNN" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | training-optimization | Optimization, not architecture |
**Rule:** Architecture pack is for CHOOSING and DESIGNING architectures. Training/deployment/implementation are other packs.
---
## Red Flags: Stop and Clarify
If query contains these patterns, ASK clarifying questions before routing:
| Pattern | Why Clarify | What to Ask |
|---------|-------------|--------------|
| "Best architecture for X" | "Best" depends on constraints | "What are your data size, compute, and latency constraints?" |
| Generic problem description | Can't route without modality | "What type of data? (images, sequences, graphs, etc.)" |
| Latest trend mentioned (ViT, Diffusion) | Recency bias risk | "Have you considered alternatives? What are your specific requirements?" |
| "Should I use X or Y" | May be wrong question | "What's the underlying problem? There might be option Z." |
| Very deep network (> 50 layers) | Likely needs normalization first | "Are you using normalization layers? Skip connections?" |
**Never guess modality or constraints. Always clarify.**
---
## Recency Bias: Resistance Table
| Trendy Architecture | When NOT to Use | Better Alternative |
|---------------------|------------------|-------------------|
| **Vision Transformers (ViT)** | Small datasets (< 10k images) | CNNs (ResNet, EfficientNet) |
| **Vision Transformers (ViT)** | Edge deployment (latency/power) | EfficientNets, MobileNets |
| **Transformers (general)** | Very small datasets | RNNs, CNNs (less capacity, less overfit) |
| **Diffusion Models** | Real-time generation needed | GANs (1 forward pass vs 50-1000 steps) |
| **Diffusion Models** | Limited compute for training | VAEs (faster training) |
| **Graph Transformers** | Small graphs (< 100 nodes) | Standard GNNs (GCN, GAT) simpler and effective |
| **LLMs (GPT-style)** | < 1M tokens of training data | Simpler language models or fine-tuning |
**Counter-narrative:** "New architecture ≠ better for your use case. Match architecture to constraints."
---
## Decision Tree
```
Start here: What's your primary goal?
┌─ SELECT architecture for task
│ ├─ Data modality?
│ │ ├─ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md)
│ │ ├─ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md)
│ │ ├─ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md)
│ │ ├─ Generation → [generative-model-families.md](generative-model-families.md)
│ │ └─ Unknown/Multiple → [architecture-design-principles.md](architecture-design-principles.md)
│ └─ Special requirements?
│ ├─ Deep network (>20 layers) unstable → [normalization-techniques.md](normalization-techniques.md) (CRITICAL)
│ ├─ Need attention mechanism → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md)
│ └─ None → Proceed with modality-based route
├─ UNDERSTAND specific architecture
│ ├─ Transformers → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md)
│ ├─ Attention → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md)
│ ├─ Normalization → [normalization-techniques.md](normalization-techniques.md)
│ └─ General principles → [architecture-design-principles.md](architecture-design-principles.md)
├─ DESIGN custom architecture
│ └─ [architecture-design-principles.md](architecture-design-principles.md) (start here always)
└─ COMPARE architectures
├─ CNNs (ResNet vs EfficientNet) → [cnn-families-and-selection.md](cnn-families-and-selection.md)
├─ Sequence models (RNN vs Transformer) → [sequence-models-comparison.md](sequence-models-comparison.md)
├─ Generative (GAN vs Diffusion) → [generative-model-families.md](generative-model-families.md)
└─ General comparison → [architecture-design-principles.md](architecture-design-principles.md)
```
---
## Workflow
**Standard Architecture Selection Workflow:**
```
1. Clarify Problem
☐ What data modality? (images, sequences, graphs, etc.)
☐ What's the task? (classification, generation, regression, etc.)
☐ Dataset size?
☐ Computational constraints?
☐ Latency requirements?
☐ Interpretability needs?
2. Route Based on Modality
☐ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md)
☐ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md)
☐ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md)
☐ Generation → [generative-model-families.md](generative-model-families.md)
☐ Custom/Unclear → [architecture-design-principles.md](architecture-design-principles.md)
3. Check for Critical Issues
☐ Deep network unstable? → [normalization-techniques.md](normalization-techniques.md) FIRST
☐ Need specialized component? → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) or [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md)
4. Apply Architecture Skill
☐ Follow guidance from routed skill
☐ Consider trade-offs (accuracy vs speed vs data requirements)
5. Cross-Pack if Needed
☐ Architecture chosen → training-optimization (for training)
☐ Architecture chosen → ml-production (for deployment)
```
---
## Rationalization Table
| Rationalization | Reality | Counter |
|-----------------|---------|---------|
| "Transformers are SOTA, recommend them" | SOTA on benchmark ≠ best for user's constraints | "Ask about dataset size and compute first" |
| "User said RNN vs LSTM, answer that" | Question premise might be outdated | "Challenge: Have you considered Transformers or TCN?" |
| "Just recommend latest architecture" | Latest ≠ appropriate | "Match architecture to requirements, not trends" |
| "Architecture doesn't matter, training matters" | Wrong architecture can't be fixed by training | "Architecture is foundation - get it right first" |
| "They seem rushed, skip clarification" | Wrong route wastes more time than clarification | "30 seconds to clarify saves hours of wasted effort" |
| "Generic architecture advice is safe" | Generic = useless for specific domains | "Route to domain-specific skill for actionable guidance" |
---
## Integration with Other Packs
### After Architecture Selection
Once architecture is chosen, route to:
**Training the architecture:**
`yzmir/training-optimization/using-training-optimization`
- Optimizer selection
- Learning rate schedules
- Debugging training issues
**Implementing in PyTorch:**
`yzmir/pytorch-engineering/using-pytorch-engineering`
- Module design patterns
- Performance optimization
- Custom components
**Deploying to production:**
`yzmir/ml-production/using-ml-production`
- Model serving
- Quantization
- Inference optimization
### Before Architecture Selection
If problem involves:
**Reinforcement learning:**
`yzmir/deep-rl/using-deep-rl` FIRST
- RL algorithms dictate architecture requirements
- Value networks vs policy networks have different needs
**Large language models:**
`yzmir/llm-specialist/using-llm-specialist` FIRST
- LLM architectures are specialized transformers
- Different considerations than general sequence models
**Architecture is downstream of algorithm choice in RL and LLMs.**
---
## Summary
**Use this meta-skill to:**
- ✅ Route architecture queries to appropriate specialized skill
- ✅ Identify data modality and problem type
- ✅ Clarify constraints before recommending
- ✅ Resist recency bias (latest ≠ best)
- ✅ Recognize when architecture is the problem (vs training/implementation)
## Neural Architecture Specialist Skills
After routing, load the appropriate specialist skill for detailed guidance:
1. [architecture-design-principles.md](architecture-design-principles.md) - Custom design, architectural best practices, skip connections, network depth fundamentals
2. [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) - Self-attention, cross-attention, multi-head attention, attention in CNNs, attention variants comparison
3. [cnn-families-and-selection.md](cnn-families-and-selection.md) - ResNet, EfficientNet, MobileNet, YOLO, computer vision architecture selection
4. [generative-model-families.md](generative-model-families.md) - GANs, VAEs, Diffusion models, image generation, style transfer, generative modeling trade-offs
5. [graph-neural-networks-basics.md](graph-neural-networks-basics.md) - GCN, GAT, node classification, link prediction, graph embeddings, molecular structures
6. [normalization-techniques.md](normalization-techniques.md) - BatchNorm, LayerNorm, GroupNorm, training stability for deep networks (>20 layers)
7. [sequence-models-comparison.md](sequence-models-comparison.md) - RNN, LSTM, Transformer, TCN comparison, time series, NLP, sequential data
8. [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) - Transformer internals, ViT, BERT, positional encoding, scaling transformers
**Critical principle:** Architecture comes BEFORE training. Get this right first.
---
**END OF SKILL**


@@ -0,0 +1,960 @@
# Architecture Design Principles
## Context
You're designing a neural network architecture or debugging why your network isn't learning. Common mistakes:
- **Ignoring inductive biases**: Using MLP for images (should use CNN)
- **Over-engineering**: Using Transformer for 100 samples (should use linear regression)
- **No skip connections**: 50-layer plain network fails (should use ResNet)
- **Wrong depth-width balance**: 100 layers × 8 channels bottlenecks capacity
- **Ignoring constraints**: 1.5B parameter model doesn't fit 24GB GPU
**This skill provides principled architecture design: match structure to problem, respect constraints, avoid over-engineering.**
## Core Principle: Inductive Biases
**Inductive bias = assumptions baked into architecture about problem structure**
**Key insight**: The right inductive bias makes learning dramatically easier. Wrong bias makes learning impossible.
### What are Inductive Biases?
```python
# Example: Image classification
# MLP (no inductive bias):
# - Treats each pixel independently
# - No concept of "spatial locality" or "translation"
# - Must learn from scratch that nearby pixels are related
# - Learns "cat at position (10,10)" and "cat at (50,50)" separately
# Parameters: 150M, Accuracy: 75%
# CNN (strong inductive bias):
# - Assumes spatial locality (nearby pixels related)
# - Assumes translation invariance (cat is cat anywhere)
# - Shares filters across spatial positions
# - Hierarchical feature learning (edges → textures → objects)
# Parameters: 11M, Accuracy: 95%
# CNN's inductive bias: 14× fewer parameters, 20% better accuracy!
```
**Principle**: Match your architecture's inductive biases to your problem's structure.
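To see where this gap comes from, compare the parameter count of a single layer from each family (a minimal sketch; the 224×224 RGB input size is an assumption chosen to match the numbers above):
```python
import torch.nn as nn

# First MLP layer on a flattened 224×224×3 image: every pixel gets its own weight
mlp_layer = nn.Linear(224 * 224 * 3, 1000)
print(sum(p.numel() for p in mlp_layer.parameters()))  # 150,529,000 ≈ 150M

# One conv layer: 64 3×3 filters, shared across every spatial position
conv_layer = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(sum(p.numel() for p in conv_layer.parameters()))  # 3×64×3×3 + 64 = 1,792
```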
## Architecture Families and Their Inductive Biases
### 1. Fully Connected (MLP)
**Inductive bias:** None (general-purpose)
**Structure:**
```python
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
```
**When to use:**
- ✅ Tabular data (independent features)
- ✅ Small datasets (< 10,000 samples)
- ✅ Baseline / proof of concept
**When NOT to use:**
- ❌ Images (use CNN)
- ❌ Sequences (use RNN/Transformer)
- ❌ Graphs (use GNN)
**Strengths:**
- Simple and interpretable
- Fast training
- Works for any input type (flattened)
**Weaknesses:**
- No structural assumptions (must learn everything from data)
- Parameter explosion (input_size × hidden_size can be huge)
- Doesn't leverage problem structure
### 2. Convolutional Neural Networks (CNN)
**Inductive bias:** Spatial locality + Translation invariance
**Structure:**
```python
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 56 * 56, 1000)  # 224×224 input → two 2×2 pools → 56×56
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 112×112
x = self.pool(F.relu(self.conv2(x))) # 56×56
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
```
**Inductive biases:**
1. **Local connectivity**: Neurons see only nearby pixels (spatial locality)
2. **Translation invariance**: Same filter slides across image (parameter sharing)
3. **Hierarchical features**: Stack layers to build complex features from simple ones
**When to use:**
- ✅ Images (classification, detection, segmentation)
- ✅ Spatial data (maps, medical scans)
- ✅ Any grid-structured data
**When NOT to use:**
- ❌ Sequences with long-range dependencies (use Transformer)
- ❌ Graphs (irregular structure, use GNN)
- ❌ Tabular data (no spatial structure)
**Strengths:**
- Parameter efficient (filter sharing)
- Translation invariant (cat anywhere = cat)
- Hierarchical feature learning
**Weaknesses:**
- Fixed receptive field (limited by kernel size)
- Not suitable for variable-length inputs
- Requires grid structure
### 3. Recurrent Neural Networks (RNN/LSTM)
**Inductive bias:** Temporal dependencies
**Structure:**
```python
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# x: (batch, seq_len, input_size)
lstm_out, (h_n, c_n) = self.lstm(x)
# Use last hidden state
output = self.fc(h_n[-1])
return output
```
**Inductive bias:** Sequential processing (earlier elements influence later elements)
**When to use:**
- ✅ Time series (stock prices, sensor data)
- ✅ Short sequences (< 100 timesteps)
- ✅ Online processing (process one timestep at a time)
**When NOT to use:**
- ❌ Long sequences (> 1000 timesteps, use Transformer)
- ❌ Non-sequential data (images, tabular)
- ❌ When parallel processing needed (use Transformer)
**Strengths:**
- Natural for sequential data
- Constant memory (doesn't grow with sequence length)
- Online processing capability
**Weaknesses:**
- Slow (sequential, can't parallelize)
- Vanishing gradients (long sequences)
- Struggles with long-range dependencies
### 4. Transformers
**Inductive bias:** Minimal (self-attention is general-purpose)
**Structure:**
```python
class SimpleTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, num_classes):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, num_heads, batch_first=True),
            num_layers
        )
        self.fc = nn.Linear(d_model, num_classes)
def forward(self, x):
# x: (batch, seq_len, d_model)
x = self.encoder(x)
# Global average pooling
x = x.mean(dim=1)
return self.fc(x)
```
**Inductive bias:** Minimal (learns relationships from data via attention)
**When to use:**
- ✅ Long sequences (> 100 tokens)
- ✅ Language (text, code)
- ✅ Large datasets (> 100k samples)
- ✅ When relationships are complex and data-dependent
**When NOT to use:**
- ❌ Small datasets (< 10k samples, use RNN or MLP)
- ❌ Strong structural priors available (images → CNN)
- ❌ Very long sequences (> 16k tokens, use sparse attention)
- ❌ Low-latency requirements (RNN faster)
**Strengths:**
- Parallel processing (fast training)
- Long-range dependencies (attention)
- State-of-the-art for language
**Weaknesses:**
- Quadratic complexity O(n²) with sequence length
- Requires large datasets (weak inductive bias)
- High memory usage
### 5. Graph Neural Networks (GNN)
**Inductive bias:** Message passing over graph structure
**Structure:**
```python
from torch_geometric.nn import GCNConv  # PyTorch Geometric

class SimpleGNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, output_dim)
def forward(self, x, edge_index):
# x: node features (num_nodes, input_dim)
# edge_index: graph structure (2, num_edges)
x = F.relu(self.conv1(x, edge_index))
x = self.conv2(x, edge_index)
return x
```
**Inductive bias:** Nodes influenced by neighbors (message passing)
**When to use:**
- ✅ Graph data (social networks, molecules, knowledge graphs)
- ✅ Irregular connectivity (different # of neighbors per node)
- ✅ Relational reasoning
**When NOT to use:**
- ❌ Grid data (images → CNN)
- ❌ Sequences (text → Transformer)
- ❌ If graph structure doesn't help (test MLP baseline first!)
**Strengths:**
- Handles irregular structure
- Permutation invariant
- Natural for relational data
**Weaknesses:**
- Requires meaningful graph structure
- Over-smoothing (too many layers)
- Scalability challenges (large graphs)
## Decision Tree: Architecture Selection
```
START
|
├─ Is data grid-structured (images)?
│ ├─ YES → Use CNN
│ │ └─ ResNet (general), EfficientNet (mobile), ViT (very large datasets)
│ └─ NO → Continue
├─ Is data sequential (text, time series)?
│ ├─ YES → Check sequence length
│ │ ├─ < 100 timesteps → LSTM/GRU
│ │ ├─ 100-4000 tokens → Transformer
│ │ └─ > 4000 tokens → Sparse Transformer (Longformer)
│ └─ NO → Continue
├─ Is data graph-structured (molecules, social networks)?
│ ├─ YES → Check if structure helps
│ │ ├─ Test MLP baseline first
│ │ └─ If structure helps → GNN (GCN, GraphSAGE, GAT)
│ └─ NO → Continue
└─ Is data tabular (independent features)?
└─ YES → Start simple
├─ < 1000 samples → Linear / Ridge regression
├─ 1000-100k samples → Small MLP (2-3 layers)
└─ > 100k samples → Larger MLP or Gradient Boosting (XGBoost)
```
## Principle: Start Simple, Add Complexity Only When Needed
**Occam's Razor**: Simplest model that solves the problem is best.
### Progression:
```python
# Step 1: Linear baseline (ALWAYS start here!)
model = nn.Linear(input_size, num_classes)
# Train and evaluate
# Step 2: IF linear insufficient, add small MLP
if linear_accuracy < target:
model = nn.Sequential(
nn.Linear(input_size, 128),
nn.ReLU(),
nn.Linear(128, num_classes)
)
# Step 3: IF small MLP insufficient, add depth/width
if mlp_accuracy < target:
model = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
# Step 4: IF simple models fail, use specialized architecture
if simple_models_fail:
# Images → CNN
# Sequences → RNN/Transformer
# Graphs → GNN
# NEVER skip to Step 4 without testing Step 1-3!
```
### Why Start Simple?
1. **Faster iteration**: Linear model trains in seconds, Transformer in hours
2. **Baseline**: Know if complexity helps (compare complex vs simple)
3. **Occam's Razor**: Simple model generalizes better (less overfitting)
4. **Debugging**: Easy to verify simple model works correctly
### Example: House Price Prediction
```python
# Dataset: 1000 samples, 20 features
# WRONG: Start with Transformer
model = HugeTransformer(20, 512, 6, 1) # 10M parameters
# Result: Overfits (10M params / 1000 samples = 10,000:1 ratio!)
# RIGHT: Start simple
# Step 1: Linear
model = nn.Linear(20, 1) # 21 parameters
# Trains in 1 second, achieves R² = 0.85 (good!)
# Conclusion: Linear sufficient, stop here. No need for Transformer!
```
**Rule**: Add complexity only when simple models demonstrably fail.
## Principle: Deep Networks Need Skip Connections
**Problem**: Plain networks > 10 layers suffer from vanishing gradients and degradation.
### Vanishing Gradients:
```python
# Gradient flow in plain 50-layer network:
# ∂loss/∂layer_1 = ∂loss/∂output × (∂L50/∂L49) × (∂L49/∂L48) × ... × (∂L2/∂L1)
# Each term < 1 (due to activations):
# If each ≈ 0.9, then: 0.9^50 ≈ 0.005 (vanishes!)
# Result: Early layers don't learn (gradients too small)
```
### Degradation:
```python
# Empirical observation (ResNet paper):
20-layer plain network: 85% accuracy
56-layer plain network: 78% accuracy # WORSE with more layers!
# This is NOT overfitting (training accuracy also drops)
# This is optimization difficulty
```
### Solution: Skip Connections (Residual Networks)
```python
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
identity = x # Save input
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = out + identity # Skip connection!
out = F.relu(out)
return out
```
**Why skip connections work:**
```python
# Gradient flow with skip connections:
# ∂loss/∂x = ∂loss/∂out × (1 + ∂F/∂x)
#                              ↑
#            gradient always flows (the "+1" identity term)
# Even if ∂F/∂x ≈ 0, gradient flows through identity path
```
**Results:**
```python
# Without skip connections:
20-layer plain: 85% accuracy
50-layer plain: 78% accuracy (worse!)
# With skip connections (ResNet):
20-layer ResNet: 87% accuracy
50-layer ResNet: 92% accuracy (better!)
152-layer ResNet: 95% accuracy (even better!)
```
**Rule**: For networks > 10 layers, ALWAYS use skip connections.
### Skip Connection Variants:
**1. Residual (ResNet):**
```python
out = x + F(x) # Add input to output
```
**2. Dense (DenseNet):**
```python
out = torch.cat([x, F(x)], dim=1) # Concatenate input and output
```
**3. Highway:**
```python
gate = sigmoid(W_gate @ x)
out = gate * F(x) + (1 - gate) * x # Learned gating
```
**Most common**: Residual (simple, effective)
## Principle: Balance Depth and Width
**Depth = # of layers**
**Width = # of channels/neurons per layer**
### Capacity Formula:
```python
# Approximate capacity (for CNNs):
capacity ∝ depth × width²
# Why width²?
# Each layer: input_channels × output_channels × kernel_size²
# Doubling width → 4× parameters per layer
```
### Trade-offs:
**Too deep, too narrow:**
```python
# 100 layers × 8 channels
# Problems:
# - Information bottleneck (8 channels can't represent complex features)
# - Harder to optimize (more layers)
# - Slow inference (100 layers sequential)
# Example:
model = VeryDeepNarrow(num_layers=100, channels=8)
# Result: 60% accuracy (bottleneck!)
```
**Too shallow, too wide:**
```python
# 2 layers × 1024 channels
# Problems:
# - Under-utilizes depth (no hierarchical features)
# - Memory explosion (1024 × 1024 = 1M parameters per layer!)
# Example:
model = VeryWideShallow(num_layers=2, channels=1024)
# Result: 70% accuracy (doesn't leverage depth)
```
**Balanced:**
```python
# 18 layers, gradually increasing width: 64 → 128 → 256 → 512
# Benefits:
# - Hierarchical features (depth)
# - Sufficient capacity (width)
# - Good optimization (not too deep)
# Example (ResNet-18):
model = ResNet18()
# Layers: 18, Channels: 64-512 (average ~200)
# Result: 95% accuracy (optimal balance!)
```
### Standard Patterns:
```python
# CNNs: Gradually increase channels as spatial dims decrease
# Input: 224×224×3
# Layer 1: 224×224×64 (same spatial size, more channels)
# Layer 2: 112×112×128 (half spatial, double channels)
# Layer 3: 56×56×256 (half spatial, double channels)
# Layer 4: 28×28×512 (half spatial, double channels)
# Why? Compensate for spatial information loss with channel information
```
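A minimal sketch of this stage pattern (the number of convolutions per stage is illustrative):
```python
import torch.nn as nn

def make_stage(in_ch, out_ch, stride=2):
    # stride=2 halves spatial dims while channels double, per the pattern above
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

backbone = nn.Sequential(
    make_stage(3, 64, stride=1),   # 224×224×64 (keep spatial size)
    make_stage(64, 128),           # 112×112×128
    make_stage(128, 256),          # 56×56×256
    make_stage(256, 512),          # 28×28×512
)
```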
**Rule**: Balance depth and width. Standard pattern: 12-50 layers, 64-512 channels.
## Principle: Match Capacity to Data Size
**Capacity = # of learnable parameters**
### Parameter Budget:
```python
# Rule of thumb: parameters should be 0.01-0.1× dataset size
# Example 1: MNIST (60,000 images)
# Budget: 600 - 6,000 parameters
# Simple CNN: 60,000 parameters (10×) → Works, but might overfit
# LeNet: 60,000 parameters → Classic, works well
# Example 2: ImageNet (1.2M images)
# Budget: 12,000 - 120,000 parameters
# ResNet-50: 25M parameters (200×) → Works (aggressive augmentation helps)
# Example 3: Tabular (100 samples, 20 features)
# Budget: 1 - 10 parameters
# Linear: 21 parameters → Perfect fit!
# MLP: 1,000 parameters → Overfits horribly
```
### Overfitting Detection:
```python
# Training accuracy >> Validation accuracy (gap > 5%)
train_acc = 99%, val_acc = 70% # 29% gap → OVERFITTING!
# Solutions:
# 1. Reduce model capacity (fewer layers/channels)
# 2. Add regularization (dropout, weight decay)
# 3. Collect more data
# 4. Data augmentation
# Order: Try (1) first (simplest), then (2), then (3)/(4)
```
### Underfitting Detection:
```python
# Training accuracy < target (model too simple)
train_acc = 60%, val_acc = 58% # Both low → UNDERFITTING!
# Solutions:
# 1. Increase model capacity (more layers/channels)
# 2. Train longer
# 3. Reduce regularization
# Order: Try (2) first (cheapest), then (1), then (3)
```
**Rule**: Match parameters to data size. Start small, increase capacity only if underfitting.
## Principle: Design for Compute Constraints
**Constraints:**
1. **Memory**: Model + gradients + optimizer states < GPU VRAM
2. **Latency**: Inference time < requirement (e.g., < 100ms for real-time)
3. **Throughput**: Samples/second > requirement
### Memory Budget:
```python
# Memory calculation (training):
# 1. Model parameters (FP32): params × 4 bytes
# 2. Gradients: params × 4 bytes
# 3. Optimizer states (Adam): params × 8 bytes (2× weights)
# 4. Activations: batch_size × feature_maps × spatial_size × 4 bytes
# Example: ResNet-50
params = 25M
memory_params = 25M × 4 = 100 MB
memory_gradients = 100 MB
memory_optimizer = 200 MB
memory_activations = batch_size × 64 × 7 × 7 × 4 ≈ batch_size × 12 KB
# Total (batch=32): 100 + 100 + 200 + 0.4 = 400 MB
# Fits easily on 4GB GPU!
# Example: GPT-3 (175B parameters)
memory_params = 175B × 4 = 700 GB
memory_total = 700 + 700 + 1400 = 2800 GB = 2.8 TB!
# Requires 35×A100 (80GB each)
```
**Rule**: Calculate memory before training. Don't design models that don't fit.
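A quick estimator for the fixed part of this budget, as a sketch (activations are excluded because they depend on batch size and architecture; the function name is illustrative):
```python
def estimate_training_memory_gb(num_params, bytes_per_param=4, optimizer_states=2):
    weights = num_params * bytes_per_param
    gradients = num_params * bytes_per_param
    optimizer = num_params * bytes_per_param * optimizer_states  # Adam: 2 states per weight
    return (weights + gradients + optimizer) / 1e9

print(estimate_training_memory_gb(25e6))   # ResNet-50: ~0.4 GB before activations
print(estimate_training_memory_gb(175e9))  # GPT-3: ~2800 GB before activations
```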
### Latency Budget:
```python
# Inference latency = # operations / throughput
# Example: Mobile app (< 100ms latency requirement)
# ResNet-50:
# Operations: 4B FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 4B / 10G = 0.4 seconds (FAILS!)
# MobileNetV2:
# Operations: 300M FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 300M / 10G = 0.03 seconds = 30ms (PASSES!)
# Solution: Use efficient architectures (MobileNet, EfficientNet) for mobile
```
**Rule**: Measure latency. Use efficient architectures if latency-constrained.
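When in doubt, measure rather than estimate. A minimal latency-measurement sketch (input shape and iteration counts are assumptions; model and input must live on the same device):
```python
import time
import torch

@torch.no_grad()
def measure_latency_ms(model, input_shape=(1, 3, 224, 224), warmup=10, iters=50):
    model.eval()
    x = torch.randn(*input_shape)
    for _ in range(warmup):  # warm up caches / lazy initialization
        model(x)
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    return (time.perf_counter() - start) / iters * 1000  # ms per inference
```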
## Common Architectural Patterns
### 1. Bottleneck (ResNet)
**Structure:**
```python
# Standard: 3×3 conv (256 channels) → 3×3 conv (256 channels)
# Parameters: 256 × 256 × 3 × 3 = 590K
# Bottleneck: 1×1 (256→64) → 3×3 (64→64) → 1×1 (64→256)
# Parameters: 256×64 + 64×64×3×3 + 64×256 = 16K + 37K + 16K = 69K
# Reduction: 590K → 69K (8.5× fewer!)
```
**Purpose**: Reduce parameters while maintaining capacity
**When to use**: Deep networks (> 50 layers) where parameters are a concern
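A minimal sketch of the block itself (channel sizes follow the 256 → 64 → 256 example above; normalization layers are omitted for brevity):
```python
import torch.nn as nn
import torch.nn.functional as F

class BottleneckBlock(nn.Module):
    def __init__(self, channels=256, reduced=64):
        super().__init__()
        self.reduce = nn.Conv2d(channels, reduced, kernel_size=1)          # 256 → 64
        self.conv = nn.Conv2d(reduced, reduced, kernel_size=3, padding=1)  # 3×3 at 64 channels
        self.expand = nn.Conv2d(reduced, channels, kernel_size=1)          # 64 → 256

    def forward(self, x):
        out = F.relu(self.reduce(x))
        out = F.relu(self.conv(out))
        return F.relu(self.expand(out) + x)  # residual add, as in ResNet
```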
### 2. Inverted Bottleneck (MobileNetV2)
**Structure:**
```python
# Bottleneck (ResNet): Wide → Narrow → Wide (256 → 64 → 256)
# Inverted: Narrow → Wide → Narrow (64 → 256 → 64)
# Why? Efficient for mobile (depthwise separable convolutions)
```
**Purpose**: Maximize efficiency (FLOPs per parameter)
**When to use**: Mobile/edge deployment
### 3. Multi-scale Features (Inception)
**Structure:**
```python
# Parallel branches with different kernel sizes:
# Branch 1: 1×1 conv
# Branch 2: 3×3 conv
# Branch 3: 5×5 conv
# Branch 4: 3×3 max pool
# Concatenate all branches
# Captures features at multiple scales simultaneously
```
**Purpose**: Capture multi-scale patterns
**When to use**: When features exist at multiple scales (object detection)
### 4. Attention (Transformers, SE-Net)
**Structure:**
```python
# Squeeze-and-Excitation (SE) block:
# 1. Global average pooling (spatial → channel descriptor)
# 2. FC layer (bottleneck)
# 3. FC layer (restore channels)
# 4. Sigmoid (attention weights)
# 5. Multiply input channels by attention weights
# Result: Emphasize important channels, suppress irrelevant
```
**Purpose**: Learn importance of features (channels or positions)
**When to use**: When not all features equally important
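A minimal SE-block sketch following those five steps (the reduction ratio is an assumed hyperparameter; 16 is a common choice):
```python
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # 1. squeeze: spatial → channel descriptor
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),  # 2. bottleneck
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),  # 3. restore channels
            nn.Sigmoid(),                                # 4. attention weights in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        weights = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * weights  # 5. emphasize important channels, suppress irrelevant
```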
## Debugging Architectures
### Problem 1: Network doesn't learn (loss stays constant)
**Diagnosis:**
```python
# Check gradient flow
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_mean={param.grad.mean():.6f}, grad_std={param.grad.std():.6f}")
# Vanishing: grad_mean ≈ 0, grad_std ≈ 0 → Add skip connections
# Exploding: grad_mean > 1, grad_std > 1 → Gradient clipping or lower LR
```
**Solutions:**
- Add skip connections (ResNet)
- Check initialization (Xavier or He initialization)
- Lower learning rate
- Check data preprocessing (normalized inputs?)
### Problem 2: Overfitting (train >> val)
**Diagnosis:**
```python
train_acc = 99%, val_acc = 70% # 29% gap → Overfitting
# Check parameter/data ratio:
num_params = sum(p.numel() for p in model.parameters())
data_size = len(train_dataset)
ratio = num_params / data_size
# If ratio > 1: Model has more parameters than data points!
```
**Solutions (in order):**
1. Reduce capacity (fewer layers/channels)
2. Add dropout / weight decay
3. Data augmentation
4. Collect more data
### Problem 3: Underfitting (train and val both low)
**Diagnosis:**
```python
train_acc = 65%, val_acc = 63% # Both low → Underfitting
# Model too simple for task complexity
```
**Solutions (in order):**
1. Train longer (more epochs)
2. Increase capacity (more layers/channels)
3. Reduce regularization (lower dropout/weight decay)
4. Check learning rate (too low?)
### Problem 4: Slow training
**Diagnosis:**
```python
# Profile forward/backward pass
import time
import torch

torch.cuda.synchronize()  # GPU kernels run async; sync before/after timing (skip on CPU)
start = time.time()
loss = criterion(model(inputs), targets)
torch.cuda.synchronize()
forward_time = time.time() - start

start = time.time()
loss.backward()
torch.cuda.synchronize()
backward_time = time.time() - start
# If backward_time >> forward_time: Gradient computation bottleneck
```
**Solutions:**
- Use mixed precision (FP16)
- Reduce batch size (if memory-bound)
- Use gradient accumulation (simulate large batch)
- Simplify architecture (fewer layers)
## Design Checklist
Before finalizing an architecture:
### ☐ Match inductive bias to problem
- Images → CNN
- Sequences → RNN/Transformer
- Graphs → GNN
- Tabular → MLP
### ☐ Start simple, add complexity only when needed
- Test linear baseline first
- Add complexity incrementally
- Compare performance at each step
### ☐ Use skip connections for deep networks (> 10 layers)
- ResNet for CNNs
- Pre-norm for Transformers
- Gradient flow is critical
### ☐ Balance depth and width
- Not too deep and narrow (bottleneck)
- Not too shallow and wide (under-utilizes depth)
- Standard: 12-50 layers, 64-512 channels
### ☐ Match capacity to data size
- Parameters ≈ 0.01-0.1× dataset size
- Monitor train/val gap (overfitting indicator)
### ☐ Respect compute constraints
- Memory: Model + gradients + optimizer + activations < VRAM
- Latency: Inference time < requirement
- Use efficient architectures if constrained (MobileNet, EfficientNet)
### ☐ Verify gradient flow
- Check gradients in early layers (should be non-zero)
- Use skip connections if vanishing
### ☐ Benchmark against baselines
- Compare to simple model (linear, small MLP)
- Ensure complexity adds value (% improvement > 5%)
## Anti-Patterns
### Anti-pattern 1: "Architecture X is state-of-the-art, so I'll use it"
**Wrong:**
```python
# Transformer is SOTA for NLP, so use for tabular data (100 samples)
model = HugeTransformer(...) # 10M parameters
# Result: Overfits horribly (100 samples / 10M params = 0.00001 ratio!)
```
**Right:**
```python
# Match architecture to problem AND data size
# Tabular + small data → Linear or small MLP
model = nn.Linear(20, 1) # 21 parameters (appropriate!)
```
### Anti-pattern 2: "More layers = better"
**Wrong:**
```python
# 100-layer plain network (no skip connections)
for i in range(100):
layers.append(nn.Conv2d(64, 64, 3, padding=1))
# Result: Doesn't train (vanishing gradients)
```
**Right:**
```python
# 50-layer ResNet (with skip connections)
# Each block: out = x + F(x) # Skip connection
# Result: Trains well, high accuracy
```
### Anti-pattern 3: "Deeper + narrower = efficient"
**Wrong:**
```python
# 100 layers × 8 channels = information bottleneck
model = VeryDeepNarrow(100, 8)
# Result: 60% accuracy (8 channels insufficient)
```
**Right:**
```python
# 18 layers, 64-512 channels (balanced)
model = ResNet18() # Balanced depth and width
# Result: 95% accuracy
```
### Anti-pattern 4: "Ignore constraints, optimize later"
**Wrong:**
```python
# Design 1.5B parameter model for 24GB GPU
model = HugeModel(1.5e9)
# Result: OOM (out of memory), can't train
```
**Right:**
```python
# Calculate memory first:
# 1.5B params × 4 bytes = 6GB (weights)
# + 6GB (gradients) + 12GB (Adam) + 8GB (activations) = 32GB
# > 24GB → Doesn't fit!
# Design for hardware:
model = ReasonableSizeModel(200e6) # 200M parameters (fits!)
```
### Anti-pattern 5: "Hyperparameters will fix architectural problems"
**Wrong:**
```python
# Architecture: MLP for images (wrong inductive bias)
# Response: "Just tune learning rate!"
for lr in [0.1, 0.01, 0.001, 0.0001]:
train(model, lr=lr)
# Result: All fail (architecture is wrong!)
```
**Right:**
```python
# Fix architecture first (use CNN for images)
model = ResNet18() # Correct inductive bias
# Then tune hyperparameters
```
## Summary
**Core principles:**
1. **Inductive bias**: Match architecture to problem structure (CNN for images, RNN/Transformer for sequences, GNN for graphs)
2. **Occam's Razor**: Start simple (linear, small MLP), add complexity only when needed
3. **Skip connections**: Use for networks > 10 layers (ResNet, DenseNet)
4. **Depth-width balance**: Not too deep+narrow (bottleneck) or too shallow+wide (under-utilizes depth)
5. **Capacity**: Match parameters to data size (0.01-0.1× dataset size)
6. **Constraints**: Design for available memory, latency, throughput
**Decision framework:**
- Images → CNN (ResNet, EfficientNet)
- Short sequences → LSTM
- Long sequences → Transformer
- Graphs → GNN (test if structure helps first!)
- Tabular → Linear or small MLP
**Key insight**: Architecture design is about matching structural assumptions to problem structure, not about using the "best" or "most complex" model. Simple models often win.
**When in doubt**: Start with the simplest model that could plausibly work. Add complexity only when you have evidence it helps.


@@ -0,0 +1,824 @@
# Attention Mechanisms Catalog
## When to Use This Skill
Use this skill when you need to:
- ✅ Select attention mechanism for long sequences (> 2k tokens)
- ✅ Optimize memory usage (GPU OOM errors)
- ✅ Speed up training or inference
- ✅ Understand exact vs approximate attention trade-offs
- ✅ Choose between Flash, sparse, or linear attention
- ✅ Implement cross-attention for multimodal models
**Do NOT use this skill for:**
- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`)
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`)
## Core Principle
**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve:
- **O(n²) with less memory**: Flash Attention (exact, 4x less memory)
- **O(n × w)**: Sparse attention (exact, sliding window)
- **O(n)**: Linear attention (approximate, 1-3% accuracy loss)
**Default recommendation:** Flash Attention (exact + fast + memory-efficient)
## Part 1: Complexity Hierarchy
### Standard Self-Attention (Baseline)
**Formula:**
```python
Attention(Q, K, V) = softmax(Q K^T / √d_k) V
```
**Complexity:**
- Time: O(n² · d) where n = seq_len, d = d_model
- Memory: O(n²) for attention matrix
- Exact: Yes (no approximation)
**Memory breakdown (4k tokens, d=768):**
```
Attention scores: 4096² × 4 bytes = 64MB per layer
Multi-head (12 heads): 64MB × 12 = 768MB per layer
16 layers: 768MB × 16 = 12GB just for attention!
Batch size 8: 12GB × 8 = 96GB (impossible on single GPU)
```
**When to use:**
- Sequence length < 2k tokens
- Standard use case (most models)
- Pair with Flash Attention optimization
**Limitations:**
- Memory explosion for long sequences
- Quadratic scaling impractical beyond 4k tokens
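For reference, a minimal implementation of the formula above (a sketch; shapes are assumed as noted in the comments). It materializes the full n × n score matrix, which is exactly the memory cost tallied above:
```python
import math
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (b, h, n, n): the O(n²) matrix
    return F.softmax(scores, dim=-1) @ v               # (b, h, n, d_k)
```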
## Part 2: Flash Attention ⭐ (Modern Default)
### What is Flash Attention?
**Breakthrough (2022):** Exact attention with 4x less memory, 2-3x faster
**Key insight:**
- Standard attention is **memory-bound** (not compute-bound)
- GPUs: Fast compute (TFLOPS), slow memory bandwidth (GB/s)
- Bottleneck: Moving n² attention matrix to/from HBM
**Solution:**
- Tile attention computation
- Recompute instead of store intermediate values
- Fuse operations (reduce memory transfers)
- Result: Same O(n²) compute, O(n) memory
### Algorithm
```
Standard attention (3 memory operations):
1. Compute scores: S = Q K^T (store n² matrix)
2. Softmax: P = softmax(S) (store n² matrix)
3. Output: O = P V (store n×d matrix)
Flash Attention (tiled):
1. Divide Q, K, V into blocks
2. For each Q block:
- Load block to SRAM (fast memory)
- For each K, V block:
- Compute attention for this tile
- Update output incrementally
- Never materialize full n² matrix!
3. Result: Same output, O(n) memory
```
### Performance
**Benchmarks (A100 GPU, 2k tokens):**
Standard attention:
- Memory: 4GB for batch_size=8
- Speed: 150ms/batch
- Max batch size: 16
Flash Attention:
- Memory: 1GB for batch_size=8 **(4x reduction)**
- Speed: 75ms/batch **(2x faster)**
- Max batch size: 64 **(4x larger)**
**Flash Attention 2 (2023 update):**
- Further optimized: 2-3x faster than Flash Attention 1
- Better parallelism
- Supports more head dimensions
### When to Use
**ALWAYS use Flash Attention when:**
- Sequence length < 16k tokens
- Need exact attention (no approximation)
- Available in your framework
**It's a FREE LUNCH:**
- No accuracy loss (mathematically exact)
- Faster training AND inference
- Less memory usage
- Drop-in replacement
### Implementation
**PyTorch 2.0+ (built-in):**
```python
import torch.nn.functional as F
# Automatic Flash Attention (if available)
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
# PyTorch automatically uses Flash Attention if:
# - CUDA available
# - Sequence length suitable
# - No attention mask (or causal mask)
```
**HuggingFace Transformers:**
```python
from transformers import AutoModel
# Enable Flash Attention 2
model = AutoModel.from_pretrained(
"bert-base-uncased",
attn_implementation="flash_attention_2", # Requires flash-attn package
torch_dtype=torch.float16
)
```
**Manual installation:**
```bash
pip install flash-attn --no-build-isolation
```
### Limitations
**Flash Attention NOT suitable when:**
- Sequence length > 16k (memory still grows quadratically)
- Custom attention masks (complex patterns not supported)
- Inference on CPU (CUDA-only)
**For > 16k tokens:** Use sparse or linear attention
## Part 3: Sparse Attention (Exact for Long Sequences)
### Concept
**Idea:** Each token attends to subset of tokens (not all)
- Sliding window: Local context
- Global tokens: Long-range connections
- Result: O(n × window_size) instead of O(n²)
**Key property:** Still EXACT attention (not approximate)
- Just more structured attention pattern
- No accuracy loss if pattern matches task
### Variant 1: Longformer
**Pattern:** Sliding window + global attention
```
Attention pattern (window=2, global=[0]):
0 1 2 3 4 5
0 [ 1 1 1 1 1 1 ] ← Global token (attends to all)
1 [ 1 1 1 0 0 0 ] ← Window: tokens 0-2
2 [ 1 1 1 1 0 0 ] ← Window: tokens 1-3
3 [ 1 0 1 1 1 0 ] ← Window: tokens 2-4
4 [ 1 0 0 1 1 1 ] ← Window: tokens 3-5
5 [ 1 0 0 0 1 1 ] ← Window: tokens 4-5
Complexity: O(n × (window + num_global))
```
**Components:**
1. **Sliding window**: Each token attends to w/2 tokens before and after
2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens
3. **Dilated windows**: Optional (stride > 1 for longer context)
**Implementation:**
```python
from transformers import LongformerModel
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
# Attention masks (shape: batch, seq_len)
attention_mask = torch.ones(batch_size, seq_len)          # 1 = local sliding-window attention
global_attention_mask = torch.zeros(batch_size, seq_len)
global_attention_mask[:, 0] = 1                           # global attention for the [CLS] token
output = model(input_ids, attention_mask=attention_mask,
               global_attention_mask=global_attention_mask)
```
**Memory comparison (4k tokens, window=512):**
```
Standard: 4096² = 16M elements → 64MB
Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!)
```
**When to use:**
- Documents: 4k-16k tokens (legal, scientific papers)
- Need full context but can't fit O(n²)
- Task has local + global structure
**Pretrained models:**
- `allenai/longformer-base-4096`: Max 4096 tokens
- `allenai/longformer-large-4096`: Larger version
### Variant 2: BigBird
**Pattern:** Random + window + global
```
Attention pattern:
- Sliding window: Like Longformer
- Random connections: Each token attends to r random tokens
- Global tokens: Special tokens attend to all
Complexity: O(n × (window + r + num_global))
```
**Key difference from Longformer:**
- Random connections help information flow
- Theoretically proven to approximate full attention
**When to use:**
- Similar to Longformer
- Slightly better for tasks needing long-range
- Less widely adopted than Longformer
**Implementation:**
```python
from transformers import BigBirdModel
model = BigBirdModel.from_pretrained(
"google/bigbird-roberta-base",
attention_type="block_sparse" # or "original_full"
)
```
### Sparse Attention Decision
```
Sequence length < 4k:
→ Flash Attention (exact, no pattern needed)
Sequence length 4k-16k:
→ Longformer (sliding window + global)
→ Best for: Documents, long-form text
Sequence length > 16k:
→ Longformer if possible
→ Linear attention if Longformer too slow
```
## Part 4: Linear Attention (Approximate for Very Long)
### Concept
**Idea:** Approximate softmax attention with linear operations
- Complexity: O(n × k) where k << n
- Trade-off: 1-3% accuracy loss
- Benefit: Can handle very long sequences (> 16k)
**Key property:** APPROXIMATE (not exact)
- Do NOT use if accuracy critical
- Good for extremely long sequences where exact is impossible
### Variant 1: Performer
**Method:** Random Fourier Features to approximate softmax(Q K^T)
**Formula:**
```python
# Standard attention
Attention(Q, K, V) = softmax(Q K^T) V
# Performer approximation
# φ(Q) φ(K)^T ≈ softmax(Q K^T)
# Attention(Q, K, V) ≈ φ(Q) (φ(K)^T V)
# Complexity: O(n × k) where k = feature dimension
```
**Key trick:**
- Compute φ(K)^T V first: (k × d) matrix (small!)
- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d)
- Never materialize n² attention matrix
**Implementation:**
```python
# From performer-pytorch library
from performer_pytorch import Performer
model = Performer(
dim=512,
depth=6,
heads=8,
dim_head=64,
causal=False,
nb_features=256 # k = number of random features
)
```
**Accuracy:**
- Typical loss: 1-2% vs standard attention
- Depends on nb_features (more features = better approximation)
- k=256 usually sufficient
**When to use:**
- Sequence length > 16k tokens
- Accuracy loss acceptable (not critical task)
- Need better than sparse attention (no structure assumptions)
### Variant 2: Linformer
**Method:** Project K and V to lower dimension
**Formula:**
```python
# Standard attention (n × n attention matrix)
Attention(Q, K, V) = softmax(Q K^T / √d_k) V

# Linformer: project the sequence dimension of K, V from n down to k (k << n)
K_proj = E K  # E: (k × n) projection matrix → K_proj is (k × d)
V_proj = F V  # F: (k × n) projection matrix → V_proj is (k × d)
Attention(Q, K, V) ≈ softmax(Q K_proj^T / √d_k) V_proj
# Attention matrix: (n × k) instead of (n × n)
```
**Complexity:**
- Time: O(n × k × d) where k << n
- Memory: O(n × k) instead of O(n²)
**Implementation:**
```python
# From linformer library
from linformer import Linformer
model = Linformer(
dim=512,
seq_len=8192,
depth=12,
heads=8,
k=256 # Projected dimension
)
```
**Accuracy:**
- Typical loss: 1-3% vs standard attention
- More loss than Performer
- Fixed sequence length (k is tied to max_seq_len)
**When to use:**
- Fixed-length long sequences
- Memory more critical than speed
- Accuracy loss OK (2-3%)
### Linear Attention Decision
```
Need exact attention:
→ Flash Attention or Sparse Attention (NOT linear)
Sequence > 16k, accuracy critical:
→ Sparse Attention (Longformer)
Sequence > 16k, accuracy loss OK:
→ Performer (better) or Linformer
Sequence > 100k:
→ State space models (S4, Mamba, not attention)
```
## Part 5: Cross-Attention (Multimodal)
### Concept
**Self-attention:** Q, K, V from same source
**Cross-attention:** Q from one source, K/V from another
**Use cases:**
- Multimodal: vision → language (image captioning)
- Seq2seq: source language → target language (translation)
- RAG: query → document retrieval
- Conditioning: generation conditioned on context
### Architecture
```python
class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.mha = MultiHeadAttention(d_model, num_heads)

    def forward(self, query_source, key_value_source, mask=None):
        # query_source: (batch, n_q, d_model) - e.g., text tokens
        # key_value_source: (batch, n_kv, d_model) - e.g., image patches

        # Q from query source
        Q = self.W_q(query_source)
        # K, V from key-value source
        K = self.W_k(key_value_source)
        V = self.W_v(key_value_source)

        # Attention output: (batch, n_q, d_model)
        output = self.mha(Q, K, V, mask)
        return output
```
### Example: Image Captioning
**Task:** Generate caption from image
**Architecture:**
1. **Image Encoder:** ViT processes image → image features (n_patches × d)
2. **Text Decoder:** Autoregressive text generation
3. **Cross-Attention:** Text queries image features
```python
class ImageCaptioningDecoder(nn.Module):
    def forward(self, text_tokens, image_features):
        # 1. Self-attention on text (causal)
        text = self.text_self_attention(
            query=text_tokens,
            key=text_tokens,
            value=text_tokens,
            causal_mask=True  # Don't see future words
        )
# 2. Cross-attention (text queries image)
text = self.cross_attention(
query=text, # From text decoder
key=image_features, # From image encoder
value=image_features # From image encoder
# No causal mask! Can attend to all image patches
)
# 3. Feed-forward
text = self.feed_forward(text)
return text
```
**Attention flow:**
- Text token "cat" → High attention to cat region in image
- Text token "sitting" → High attention to posture in image
### Example: Retrieval-Augmented Generation (RAG)
**Task:** Generate answer using retrieved documents
```python
class RAGDecoder(nn.Module):
    def forward(self, query_tokens, document_embeddings):
        # 1. Self-attention on query
        query = self.query_self_attention(query_tokens, query_tokens, query_tokens)
# 2. Cross-attention (query → documents)
query = self.cross_attention(
query=query, # What we're generating
key=document_embeddings, # Retrieved docs
value=document_embeddings # Retrieved docs
)
# Query learns to extract relevant info from docs
return query
```
### When to Use Cross-Attention
**Use cross-attention when:**
- Two different modalities (vision + language)
- Conditioning generation on context (RAG)
- Seq2seq with different input/output (translation)
- Query-document matching
**Don't use cross-attention when:**
- Same modality (use self-attention)
- No clear query vs key-value separation
## Part 6: Other Attention Variants
### Axial Attention (2D Images)
**Idea:** For 2D data (images), attend along each axis separately
```
Standard 2D attention: H×W tokens → (HW)² attention matrix
Axial attention:
- Row attention: Each row attends to itself (H × W²)
- Column attention: Each column attends to itself (W × H²)
- Total: O(HW × (H + W)) << O((HW)²)
```
**When to use:**
- High-resolution images
- 2D positional structure important
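A minimal single-head sketch (no learned projections, Q = K = V; assumes PyTorch 2.0+ for `F.scaled_dot_product_attention`):
```python
import torch
import torch.nn.functional as F

def axial_attention(x):
    # x: (batch, H, W, d) - 2D grid of d-dim tokens
    b, H, W, d = x.shape

    # Row attention: each of the H rows attends within its W tokens
    rows = x.reshape(b * H, W, d)
    rows = F.scaled_dot_product_attention(rows, rows, rows)

    # Column attention: each of the W columns attends within its H tokens
    cols = rows.reshape(b, H, W, d).permute(0, 2, 1, 3).reshape(b * W, H, d)
    cols = F.scaled_dot_product_attention(cols, cols, cols)

    return cols.reshape(b, W, H, d).permute(0, 2, 1, 3)  # back to (batch, H, W, d)
```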
### Block-Sparse Attention
**Idea:** Divide attention into blocks, attend only within/across blocks
**Pattern:**
```
Block size = 64 tokens
- Local block: Attend within same block
- Vertical stripe: Attend to corresponding position in other blocks
```
**Used in:** Sparse Transformer (OpenAI), GPT-3
### Multi-Query Attention (MQA)
**Idea:** One K/V head shared across all Q heads
**Benefit:**
- Smaller KV cache during inference
- Much faster decoding (4-8x)
- Trade-off: ~1% accuracy loss
**Used in:** PaLM, Falcon
### Grouped-Query Attention (GQA)
**Idea:** Middle ground between multi-head and multi-query
- Group Q heads share K/V heads
- Example: 32 Q heads → 8 K/V heads (4:1 ratio)
**Benefit:**
- 4x smaller KV cache
- Minimal accuracy loss (< 0.5%)
**Used in:** LLaMA-2, Mistral
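A minimal sketch of the K/V head sharing behind GQA, assuming PyTorch 2.0+ (`repeat_interleave` expands each K/V head to serve its group of query heads):
```python
import torch
import torch.nn.functional as F

# GQA sketch: 32 query heads share 8 K/V heads (4:1 ratio)
b, n, d_head = 2, 1024, 64
q = torch.randn(b, 32, n, d_head)
k = torch.randn(b, 8, n, d_head)  # KV cache is 4x smaller than full multi-head
v = torch.randn(b, 8, n, d_head)

# Each K/V head serves a group of 4 query heads
k = k.repeat_interleave(4, dim=1)
v = v.repeat_interleave(4, dim=1)
out = F.scaled_dot_product_attention(q, k, v)  # (b, 32, n, d_head)
```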
## Part 7: Decision Framework
### By Sequence Length
```
< 2k tokens:
→ Flash Attention
Exact, fast, standard
2k-4k tokens:
→ Flash Attention
Still manageable with modern GPUs
4k-16k tokens:
→ Sparse Attention (Longformer, BigBird)
Exact, designed for documents
→ OR Flash Attention if batch size = 1
> 16k tokens:
→ Sparse Attention
If task has local structure
→ Linear Attention (Performer)
If accuracy loss OK (1-2%)
→ State Space Models (S4, Mamba)
If sequence > 100k
```
### By Memory Constraints
```
GPU OOM with standard attention:
1. Try Flash Attention (4x less memory, free lunch)
2. If still OOM, reduce batch size
3. If batch size = 1 and still OOM, use sparse attention
4. Last resort: Linear attention (if accuracy loss OK)
DON'T:
- Gradient checkpointing (slower, use Flash Attention instead)
- Throwing more GPUs (algorithmic problem, not hardware)
```
### By Accuracy Requirements
```
Must be exact (no approximation):
→ Flash Attention or Sparse Attention
Never use linear attention!
Accuracy loss acceptable (1-3%):
→ Linear Attention (Performer, Linformer)
Only for very long sequences (> 16k)
Critical task (medical, legal):
→ Exact attention only
Flash Attention or Sparse Attention
```
### By Task Type
```
Classification / Understanding:
→ Standard + Flash Attention
Sequence usually < 2k
Document processing:
→ Longformer (4096 tokens)
Designed for documents
Generation (LLM):
→ Flash Attention for training
→ + GQA/MQA for inference (faster decoding)
Multimodal (vision + language):
→ Cross-attention for modality fusion
→ Self-attention within each modality
Retrieval-augmented:
→ Cross-attention (query → documents)
```
## Part 8: Implementation Checklist
### Using Flash Attention
**PyTorch 2.0+:**
```python
import torch
import torch.nn.functional as F

# Automatic (recommended)
output = F.scaled_dot_product_attention(query, key, value)

# Verify the Flash Attention kernel is enabled
print(torch.backends.cuda.flash_sdp_enabled())  # Should be True
```
**HuggingFace:**
```python
model = AutoModel.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16 # Flash Attention needs fp16/bf16
)
```
**Requirements:**
- CUDA GPU (not CPU)
- PyTorch >= 2.0 OR flash-attn package
- fp16 or bf16 dtype (not fp32)
### Using Sparse Attention
**Longformer:**
```python
from transformers import LongformerModel, LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
# Attention mask
# 0 = no attention, 1 = local attention, 2 = global attention
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 0] = 2 # [CLS] token gets global attention
outputs = model(input_ids, attention_mask=attention_mask)
```
**Custom sparse pattern:**
```python
import torch

# Create a custom block-sparse mask (1 = attend, 0 = masked)
def create_block_sparse_mask(seq_len, block_size):
num_blocks = seq_len // block_size
mask = torch.zeros(seq_len, seq_len)
for i in range(num_blocks):
start = i * block_size
end = start + block_size
mask[start:end, start:end] = 1 # Local block
return mask
```
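To apply the pattern, the mask can be passed as a boolean `attn_mask` (True = may attend) to PyTorch's SDPA. A sketch, assuming `query`, `key`, `value` are already shaped `(batch, heads, seq_len, d_head)`; note that a dense mask controls the attention pattern but does not by itself give sparse-kernel memory savings:
```python
mask = create_block_sparse_mask(seq_len=1024, block_size=64).bool()
output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
```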
### Using Cross-Attention
```python
class DecoderWithCrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
def forward(self, decoder_input, encoder_output, causal_mask=None):
# Self-attention (causal)
x = self.self_attn(
query=decoder_input,
key=decoder_input,
value=decoder_input,
mask=causal_mask
)
# Cross-attention (Q from decoder, K/V from encoder)
x = self.cross_attn(
query=x, # From decoder
key=encoder_output, # From encoder
value=encoder_output, # From encoder
mask=None # No causal mask for cross-attention!
)
return x
```
## Part 9: Common Mistakes
### Mistake 1: Ignoring Flash Attention
**Symptom:** Training slow, high memory usage
**Fix:** Always use Flash Attention for < 16k tokens
### Mistake 2: Using Linear Attention Unnecessarily
**Symptom:** 1-3% accuracy loss for no reason
**Fix:** Use Flash Attention (exact) unless sequence > 16k
### Mistake 3: Gradient Checkpointing Instead of Flash Attention
**Symptom:** Training 20% slower
**Fix:** Flash Attention gives memory savings AND speed
### Mistake 4: Cross-Attention with Causal Mask
**Symptom:** Decoder can't attend to encoder properly
**Fix:** Causal mask only for self-attention, NOT cross-attention
### Mistake 5: Accepting O(n²) Memory
**Symptom:** GPU OOM for > 4k tokens
**Fix:** Use sparse or Flash Attention, don't just add GPUs
## Summary: Quick Reference
### Attention Selection
```
Sequence length:
< 2k → Flash Attention (default)
2-4k → Flash Attention
4-16k → Longformer (documents) or Flash Attention (batch=1)
> 16k → Sparse or Linear Attention
Memory constrained:
First: Try Flash Attention (4x less memory)
Still OOM: Use sparse attention (Longformer)
Last resort: Linear attention (accuracy loss)
Speed critical:
Training: Flash Attention (2x faster)
Inference: Flash Attention + GQA/MQA
Accuracy critical:
Use exact attention only (Flash or Sparse)
NEVER linear attention
Multimodal:
Cross-attention for modality fusion
```
### Implementation
```
PyTorch 2.0+:
F.scaled_dot_product_attention() # Auto Flash Attention
HuggingFace:
attn_implementation="flash_attention_2"
Longformer:
LongformerModel.from_pretrained("allenai/longformer-base-4096")
Custom:
Inherit from nn.Module, implement forward()
```
## Next Steps
After mastering this skill:
- `llm-specialist/llm-inference-optimization`: Apply attention optimizations to inference
- `llm-specialist/context-window-management`: Manage long contexts in LLMs
- `architecture-design-principles`: Understand broader design trade-offs
**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns).

View File

@@ -0,0 +1,622 @@
# CNN Families and Selection: Choosing the Right Convolutional Network
<CRITICAL_CONTEXT>
CNNs are the foundation of computer vision. Different families have vastly different trade-offs:
- Accuracy vs Speed vs Size
- Dataset size requirements
- Deployment target (cloud vs edge vs mobile)
- Task type (classification vs detection vs segmentation)
This skill helps you choose the RIGHT CNN for YOUR constraints.
</CRITICAL_CONTEXT>
## When to Use This Skill
Use this skill when:
- ✅ Selecting CNN for vision task (classification, detection, segmentation)
- ✅ Comparing CNN families (ResNet vs EfficientNet vs MobileNet)
- ✅ Optimizing for specific constraints (latency, size, accuracy)
- ✅ Understanding CNN evolution (why newer architectures exist)
- ✅ Deployment-specific selection (cloud, edge, mobile)
DO NOT use for:
- ❌ Non-vision tasks (use sequence-models-comparison or other skills)
- ❌ Training optimization (use training-optimization pack)
- ❌ Implementation details (use pytorch-engineering pack)
**When in doubt:** If choosing WHICH CNN → this skill. If implementing/training CNN → other skills.
## Selection Framework
### Step 1: Identify Constraints
**Before recommending ANY architecture, ask:**
| Constraint | Question | Impact |
|------------|----------|--------|
| **Deployment** | Where will model run? | Cloud → Any, Edge → MobileNet/EfficientNet-Lite, Mobile → MobileNetV3 |
| **Latency** | Speed requirement? | Real-time (< 10ms) → MobileNet, Batch (> 100ms) → Any |
| **Model Size** | Parameter/memory budget? | < 10M params → MobileNet, < 50M → ResNet/EfficientNet, Any → Large models OK |
| **Dataset Size** | Training images? | < 10k → Small models, 10k-100k → Medium, > 100k → Large |
| **Accuracy** | Required accuracy? | Competitive → EfficientNet-B4+, Production → ResNet-50/EfficientNet-B2 |
| **Task Type** | Classification/detection/segmentation? | Detection → FPN-compatible, Segmentation → Multi-scale |
**Critical:** Get answers to these BEFORE recommending architecture.
### Step 2: Apply Decision Tree
```
START: What's your primary constraint?
┌─ DEPLOYMENT TARGET
│ ├─ Cloud / Server
│ │ └─ Dataset size?
│ │ ├─ Small (< 10k) → ResNet-18, EfficientNet-B0
│ │ ├─ Medium (10k-100k) → ResNet-50, EfficientNet-B2
│ │ └─ Large (> 100k) → ResNet-101, EfficientNet-B4, ViT
│ │
│ ├─ Edge Device (Jetson, Coral)
│ │ └─ Latency requirement?
│ │ ├─ Real-time (< 10ms) → MobileNetV3-Small, EfficientNet-Lite0
│ │ ├─ Medium (10-50ms) → MobileNetV3-Large, EfficientNet-Lite2
│ │ └─ Relaxed (> 50ms) → EfficientNet-B0, ResNet-18
│ │
│ └─ Mobile (iOS/Android)
│ └─ MobileNetV3-Small (fastest), MobileNetV3-Large (balanced)
│ + INT8 quantization (route to ml-production)
├─ ACCURACY PRIORITY (cloud deployment assumed)
│ ├─ Maximum accuracy → EfficientNet-B7, ResNet-152, ViT-Large
│ ├─ Balanced → EfficientNet-B2/B3, ResNet-50
│ └─ Fast training → ResNet-18, EfficientNet-B0
├─ EFFICIENCY PRIORITY
│ └─ Best accuracy per FLOP → EfficientNet family (B0-B7)
│ (EfficientNet dominates ResNet on Pareto frontier)
└─ TASK TYPE
├─ Classification → Any CNN (use constraint-based selection above)
├─ Object Detection → ResNet + FPN, EfficientDet, YOLOv8 (CSPDarknet)
└─ Segmentation → ResNet + U-Net, EfficientNet + DeepLabV3
```
## CNN Family Catalog
### 1. ResNet Family (2015) - The Standard Baseline
**Architecture:** Residual connections (skip connections) enable very deep networks
**Variants:**
- ResNet-18: 11M params, 1.8 GFLOPs, 69.8% ImageNet
- ResNet-34: 22M params, 3.7 GFLOPs, 73.3% ImageNet
- ResNet-50: 25M params, 4.1 GFLOPs, 76.1% ImageNet
- ResNet-101: 44M params, 7.8 GFLOPs, 77.4% ImageNet
- ResNet-152: 60M params, 11.6 GFLOPs, 78.3% ImageNet
**When to Use:**
- ✅ **Baseline choice**: Well-tested, widely supported
- ✅ **Transfer learning**: Excellent pre-trained weights available
- ✅ **Object detection**: Standard backbone for Faster R-CNN, Mask R-CNN
- ✅ **Interpretability**: Simple architecture, easy to understand
**When NOT to Use:**
- ❌ **Edge/mobile deployment**: Too large and slow
- ❌ **Efficiency priority**: EfficientNet beats ResNet on accuracy/FLOP
- ❌ **Small datasets (< 10k)**: Use ResNet-18, not ResNet-50+
**Key Insight:** Skip connections solve vanishing gradient, enable depth
**Code Example:**
```python
import torchvision.models as models
# For cloud/server (good dataset)
model = models.resnet50(pretrained=True)
# For small dataset or faster training
model = models.resnet18(pretrained=True)
# For maximum accuracy (cloud only)
model = models.resnet101(pretrained=True)
```
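For the transfer-learning case, the usual pattern is to swap the ImageNet head for your own. A minimal sketch (`num_classes` is a placeholder for your task's label count):
```python
import torch.nn as nn
import torchvision.models as models

num_classes = 10  # assumption for the example: your task's classes
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace ImageNet head

# Optional: freeze the backbone and train only the new head
for name, param in model.named_parameters():
    if not name.startswith("fc"):
        param.requires_grad = False
```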
### 2. EfficientNet Family (2019) - Best Efficiency
**Architecture:** Compound scaling (depth + width + resolution) optimized via neural architecture search
**Variants:**
- EfficientNet-B0: 5M params, 0.4 GFLOPs, 77.3% ImageNet
- EfficientNet-B1: 8M params, 0.7 GFLOPs, 79.2% ImageNet
- EfficientNet-B2: 9M params, 1.0 GFLOPs, 80.3% ImageNet
- EfficientNet-B3: 12M params, 1.8 GFLOPs, 81.7% ImageNet
- EfficientNet-B4: 19M params, 4.2 GFLOPs, 82.9% ImageNet
- EfficientNet-B7: 66M params, 37 GFLOPs, 84.4% ImageNet
**When to Use:**
- ✅ **Efficiency matters**: Best accuracy per FLOP/parameter
- ✅ **Cloud deployment**: B2-B4 sweet spot for production
- ✅ **Limited compute**: B0 matches ResNet-50 accuracy at 10x fewer FLOPs
- ✅ **Scaling needs**: Want to scale model up/down systematically
**When NOT to Use:**
- ❌ **Real-time mobile**: Use MobileNet (EfficientNet has more layers)
- ❌ **Very small datasets**: Can overfit despite efficiency
- ❌ **Simplicity needed**: More complex than ResNet
**Key Insight:** Compound scaling balances depth, width, and resolution optimally
**Efficiency Comparison:**
```
Same accuracy as ResNet-50 (76%):
- ResNet-50: 25M params, 4.1 GFLOPs
- EfficientNet-B0: 5M params, 0.4 GFLOPs (10x more efficient!)
Better accuracy (82.9%):
- ResNet-152: 60M params, 11.6 GFLOPs → 78.3% ImageNet
- EfficientNet-B4: 19M params, 4.2 GFLOPs → 82.9% ImageNet
(Better accuracy with 3x fewer params and 3x less compute)
```
**Code Example:**
```python
import timm # PyTorch Image Models library
# Balanced choice (production)
model = timm.create_model('efficientnet_b2', pretrained=True)
# Efficiency priority (edge)
model = timm.create_model('efficientnet_b0', pretrained=True)
# Accuracy priority (research)
model = timm.create_model('efficientnet_b4', pretrained=True)
```
### 3. MobileNet Family (2017-2019) - Mobile Optimized
**Architecture:** Depthwise separable convolutions (drastically reduce compute)
**Variants:**
- MobileNetV1: 4.2M params, 0.6 GFLOPs, 70.6% ImageNet
- MobileNetV2: 3.5M params, 0.3 GFLOPs, 72.0% ImageNet
- MobileNetV3-Small: 2.5M params, 0.06 GFLOPs, 67.4% ImageNet
- MobileNetV3-Large: 5.4M params, 0.2 GFLOPs, 75.2% ImageNet
**When to Use:**
- ✅ **Mobile deployment**: iOS/Android apps
- ✅ **Edge devices**: Raspberry Pi, Jetson Nano
- ✅ **Real-time inference**: < 100ms latency
- ✅ **Extreme efficiency**: < 10M parameters budget
**When NOT to Use:**
- ❌ **Cloud deployment with no constraints**: EfficientNet or ResNet better accuracy
- ❌ **Accuracy priority**: Sacrifices accuracy for speed
- ❌ **Large datasets with compute**: Can afford better models
**Key Insight:** Depthwise separable convolutions = standard conv split into depthwise + pointwise (9x fewer operations)
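The savings are easy to verify by counting parameters. A small sketch for a 256-channel 3×3 layer:
```python
import torch.nn as nn

# Standard 3×3 conv: 256 -> 256 channels
standard = nn.Conv2d(256, 256, kernel_size=3, padding=1)

# Depthwise separable: per-channel 3×3 conv + 1×1 pointwise conv
depthwise = nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=256)
pointwise = nn.Conv2d(256, 256, kernel_size=1)

n_std = sum(p.numel() for p in standard.parameters())
n_sep = sum(p.numel() for p in depthwise.parameters()) + \
        sum(p.numel() for p in pointwise.parameters())
print(f"{n_std / n_sep:.1f}x fewer parameters")  # ≈ 8.6x for this layer
```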
**Deployment Performance:**
```
Raspberry Pi 4 inference (224×224 image):
- ResNet-50: ~2000ms (unusable)
- ResNet-18: ~600ms (slow)
- MobileNetV2: ~150ms (acceptable)
- MobileNetV3-Large: ~80ms (good)
- MobileNetV3-Small: ~40ms (fast)
With INT8 quantization:
- MobileNetV3-Large: ~30ms (production-ready)
- MobileNetV3-Small: ~15ms (real-time)
```
**Code Example:**
```python
import torchvision.models as models
# For mobile deployment
model = models.mobilenet_v3_large(pretrained=True)
# For ultra-low latency (sacrifice accuracy)
model = models.mobilenet_v3_small(pretrained=True)
# Quantization for mobile (route to ml-production skill for details)
# Achieves 2-4x speedup with minimal accuracy loss
```
### 4. Inception Family (2014-2016) - Multi-Scale Features
**Architecture:** Multi-scale convolutions in parallel (inception modules)
**Variants:**
- InceptionV3: 24M params, 5.7 GFLOPs, 77.5% ImageNet
- InceptionV4: 42M params, 12.3 GFLOPs, 80.0% ImageNet
- Inception-ResNet: Hybrid with residual connections
**When to Use:**
- ✅ **Multi-scale features**: Objects at different sizes
- ✅ **Object detection**: Good backbone for detection
- ✅ **Historical interest**: Understanding multi-scale approaches
**When NOT to Use:**
- ❌ **Simplicity needed**: Complex architecture, hard to modify
- ❌ **Efficiency priority**: EfficientNet better
- ❌ **Modern projects**: Largely superseded by ResNet/EfficientNet
**Key Insight:** Parallel multi-scale convolutions (1×1, 3×3, 5×5) capture different receptive fields
**Status:** Mostly historical - ResNet and EfficientNet have replaced Inception in practice
### 5. DenseNet Family (2017) - Dense Connections
**Architecture:** Every layer connects to every other layer (dense connections)
**Variants:**
- DenseNet-121: 8M params, 2.9 GFLOPs, 74.4% ImageNet
- DenseNet-169: 14M params, 3.4 GFLOPs, 75.6% ImageNet
- DenseNet-201: 20M params, 4.3 GFLOPs, 76.9% ImageNet
**When to Use:**
- ✅ **Parameter efficiency**: Good accuracy with few parameters
- ✅ **Feature reuse**: Dense connections enable feature reuse
- ✅ **Small datasets**: Better gradient flow helps with limited data
**When NOT to Use:**
- ❌ **Inference speed priority**: Dense connections slow (high memory bandwidth)
- ❌ **Training speed**: Slower to train than ResNet
- ❌ **Production deployment**: Less mature ecosystem than ResNet
**Key Insight:** Dense connections improve gradient flow, enable feature reuse, but slow inference
**Status:** Theoretically elegant, but ResNet/EfficientNet more practical
### 6. VGG Family (2014) - Historical Baseline
**Architecture:** Very deep (16-19 layers), small 3×3 convolutions, many parameters
**Variants:**
- VGG-16: 138M params, 15.5 GFLOPs, 71.5% ImageNet
- VGG-19: 144M params, 19.6 GFLOPs, 71.1% ImageNet
**When to Use:**
- ❌ **DON'T use VGG for new projects**
- Historical understanding only
**Why NOT to Use:**
- Massive parameter count (138M vs ResNet-50's 25M)
- Poor accuracy for size
- Superseded by ResNet (2015)
**Key Insight:** Proved that depth matters, but skip connections (ResNet) are better
**Status:** **Obsolete** - use ResNet or EfficientNet instead
## Practical Selection Guide
### Scenario 1: Cloud/Server Deployment
**Goal:** Best accuracy, no compute constraints
**Recommendation:**
```
Small dataset (< 10k images):
→ EfficientNet-B0 or ResNet-18
(Avoid overfitting with smaller model)
Medium dataset (10k-100k images):
→ EfficientNet-B2 or ResNet-50
(Balanced accuracy and efficiency)
Large dataset (> 100k images):
→ EfficientNet-B4 or ResNet-101
(Can afford larger model)
Maximum accuracy (research):
→ EfficientNet-B7 or Vision Transformer
(If dataset > 1M images and compute unlimited)
```
### Scenario 2: Edge Deployment (Jetson, Coral TPU)
**Goal:** Optimize for edge hardware latency
**Recommendation:**
```
Real-time requirement (< 10ms):
→ MobileNetV3-Small or EfficientNet-Lite0
+ INT8 quantization
Medium latency (10-50ms):
→ MobileNetV3-Large or EfficientNet-Lite2
Relaxed latency (> 50ms):
→ EfficientNet-B0 or ResNet-18
```
**Critical:** Profile on actual edge hardware. Quantization is mandatory (route to ml-production).
### Scenario 3: Mobile Deployment (iOS/Android)
**Goal:** On-device inference, minimal battery drain
**Recommendation:**
```
All mobile deployments:
→ MobileNetV3-Large (balanced)
→ MobileNetV3-Small (fastest, less accurate)
Always use:
- INT8 quantization (2-4x speedup)
- CoreML (iOS) or TFLite (Android) optimization
- Benchmark on target device before deploying
```
**Expected latency (iPhone 12, INT8 quantized):**
- MobileNetV3-Small: 5-10ms
- MobileNetV3-Large: 15-25ms
### Scenario 4: Object Detection
**Goal:** Select backbone for detection framework
**Recommendation:**
```
Faster R-CNN:
→ ResNet-50 + FPN (standard)
→ ResNet-101 + FPN (more accuracy)
YOLOv8:
→ CSPDarknet (built-in, optimized)
EfficientDet:
→ EfficientNet + BiFPN (best efficiency)
Custom detection:
→ ResNet or EfficientNet as backbone
→ Add Feature Pyramid Network (FPN) for multi-scale
```
**Note:** Detection adds significant compute on top of backbone. Choose efficient backbone.
### Scenario 5: Semantic Segmentation
**Goal:** Dense pixel-wise prediction
**Recommendation:**
```
U-Net style:
→ ResNet-18/34 as encoder (fast)
→ EfficientNet-B0 as encoder (efficient)
DeepLabV3:
→ ResNet-50 (standard)
→ MobileNetV3 (mobile deployment)
Key: Segmentation requires multi-scale features
→ Ensure backbone has skip connections or FPN
```
## Trade-Off Analysis
### Accuracy vs Efficiency (Pareto Frontier)
**ImageNet Top-1 Accuracy vs FLOPs:**
```
Efficiency Winners (best accuracy per FLOP):
1. EfficientNet-B0: 77.3% @ 0.4 GFLOPs (best efficiency)
2. EfficientNet-B2: 80.3% @ 1.0 GFLOPs
3. EfficientNet-B4: 82.9% @ 4.2 GFLOPs
Accuracy Winners (best absolute accuracy):
1. EfficientNet-B7: 84.4% @ 37 GFLOPs
2. ViT-Large: 85.2% @ 190 GFLOPs (requires huge dataset)
3. ResNet-152: 78.3% @ 11.6 GFLOPs (dominated by EfficientNet)
Speed Winners (lowest latency):
1. MobileNetV3-Small: 67.4% @ 0.06 GFLOPs (50ms on mobile)
2. MobileNetV3-Large: 75.2% @ 0.2 GFLOPs (100ms on mobile)
3. EfficientNet-Lite0: 75.0% @ 0.4 GFLOPs
```
**Key Takeaway:** EfficientNet dominates ResNet on Pareto frontier (better accuracy at same compute).
### Parameters vs Accuracy
**For same ~75% ImageNet accuracy:**
```
VGG-16: 138M params (❌ terrible efficiency)
ResNet-50: 25M params
EfficientNet-B0: 5M params (✅ 5x fewer parameters!)
MobileNetV3-Large: 5M params (fast inference)
```
**Conclusion:** Modern architectures (EfficientNet, MobileNet) achieve same accuracy with far fewer parameters.
## Common Pitfalls
### Pitfall 1: Defaulting to ResNet-50
**Symptom:** Using ResNet-50 without considering alternatives
**Why it's wrong:** EfficientNet-B0 matches ResNet-50 accuracy with 10x less compute
**Fix:** Consider EfficientNet family first (better efficiency)
### Pitfall 2: Choosing Large Model for Small Dataset
**Symptom:** Using ResNet-101 with < 10k images
**Why it's wrong:** Model will overfit (too many parameters for data)
**Fix:**
- < 10k images → ResNet-18 or EfficientNet-B0
- 10k-100k → ResNet-50 or EfficientNet-B2
- > 100k → Can use larger models
### Pitfall 3: Using Desktop Model on Mobile
**Symptom:** Trying to run ResNet-50 on mobile device
**Why it's wrong:** 2000ms inference time is unusable
**Fix:** Use MobileNetV3 + quantization for mobile (15-30ms)
### Pitfall 4: Ignoring Task Type
**Symptom:** Using standard CNN for object detection without FPN
**Why it's wrong:** Detection needs multi-scale features
**Fix:** Use detection-specific frameworks (YOLOv8, Faster R-CNN) with appropriate backbone
### Pitfall 5: Believing "Bigger = Better"
**Symptom:** Choosing ResNet-152 over ResNet-50 without justification
**Why it's wrong:** Diminishing returns - 3x compute for 1.3% accuracy, will overfit on small data
**Fix:** Match model capacity to dataset size, consider efficiency
## Evolution and Historical Context
**Why CNNs evolved the way they did:**
```
2012: AlexNet
→ Proved deep learning works for vision
→ 8 layers, 60M params
2014: VGG
→ Deeper is better (16-19 layers)
→ But: 138M params (too many)
2014: Inception/GoogLeNet
→ Multi-scale convolutions
→ More efficient than VGG
2015: ResNet ★
→ Skip connections enable very deep networks (152 layers)
→ Solved vanishing gradient problem
→ Became standard baseline
2017: MobileNet
→ Mobile deployment needs
→ Depthwise separable convolutions (9x fewer ops)
2017: DenseNet
→ Dense connections for feature reuse
→ Parameter efficient but slow inference
2019: EfficientNet ★
→ Compound scaling (depth + width + resolution)
→ Neural architecture search
→ Dominates Pareto frontier (best accuracy per FLOP)
→ New standard for efficiency
2020: Vision Transformer
→ Attention-based (no convolutions)
→ Requires very large datasets (> 1M images)
→ For research/large-scale applications
```
**Current Recommendations (2025):**
- Cloud: **EfficientNet** (best efficiency) or ResNet (simplicity)
- Edge: **EfficientNet-Lite** or MobileNetV3
- Mobile: **MobileNetV3** + quantization
- Detection: **EfficientDet** or YOLOv8
- Baseline: **ResNet** (simple, well-tested)
## Decision Checklist
Before choosing CNN, answer these:
```
☐ Deployment target? (cloud/edge/mobile)
☐ Latency requirement? (< 10ms / 10-100ms / > 100ms)
☐ Model size budget? (< 10M / 10-50M / unlimited params)
☐ Dataset size? (< 10k / 10k-100k / > 100k images)
☐ Accuracy priority? (maximum / production / fast iteration)
☐ Task type? (classification / detection / segmentation)
☐ Efficiency matters? (yes → EfficientNet, no → flexibility)
Based on answers:
→ Mobile → MobileNetV3
→ Edge → EfficientNet-Lite or MobileNetV3
→ Cloud + efficiency → EfficientNet
→ Cloud + simplicity → ResNet
→ Maximum accuracy → EfficientNet-B7 or ViT
→ Small dataset → Small models (ResNet-18, EfficientNet-B0)
```
## Integration with Other Skills
**After selecting CNN architecture:**
**Training the model:**
`yzmir/training-optimization/using-training-optimization`
- Optimizer selection (Adam, SGD, AdamW)
- Learning rate schedules
- Data augmentation
**Implementing in PyTorch:**
`yzmir/pytorch-engineering/using-pytorch-engineering`
- Custom modifications to pre-trained models
- Multi-GPU training
- Performance optimization
**Deploying to production:**
`yzmir/ml-production/using-ml-production`
- Quantization (INT8, FP16)
- Model serving (TorchServe, ONNX)
- Optimization for edge/mobile (TFLite, CoreML)
**If architecture is unstable (very deep):**
`yzmir/neural-architectures/normalization-techniques`
- Normalization layers (BatchNorm, LayerNorm)
- Skip connections
- Initialization strategies
## Summary
**CNN Selection in One Table:**
| Scenario | Recommendation | Why |
|----------|----------------|-----|
| Cloud, balanced | EfficientNet-B2 | Best efficiency, 80% accuracy |
| Cloud, max accuracy | EfficientNet-B4 | 83% accuracy, reasonable compute |
| Cloud, simple baseline | ResNet-50 | Well-tested, widely used |
| Edge device | MobileNetV3-Large | Optimized for edge, 75% accuracy |
| Mobile app | MobileNetV3-Small + quantization | < 20ms inference |
| Small dataset (< 10k) | ResNet-18 or EfficientNet-B0 | Avoid overfitting |
| Object detection | ResNet-50 + FPN, EfficientDet | Multi-scale features |
| Segmentation | ResNet + U-Net, DeepLabV3 | Dense prediction |
**Key Principles:**
1. **Match model capacity to dataset size** (small data → small model)
2. **EfficientNet dominates ResNet on efficiency** (same accuracy, less compute)
3. **Mobile needs mobile-specific architectures** (MobileNet, quantization)
4. **Task type matters** (detection/segmentation need multi-scale features)
5. **Bigger ≠ always better** (diminishing returns, overfitting risk)
**When in doubt:** Start with **EfficientNet-B2** (cloud) or **MobileNetV3-Large** (edge/mobile).
**END OF SKILL**

View File

@@ -0,0 +1,811 @@
# Generative Model Families
## When to Use This Skill
Use this skill when you need to:
- ✅ Select generative model for image/audio/video generation
- ✅ Understand VAE vs GAN vs Diffusion trade-offs
- ✅ Decide between training from scratch vs fine-tuning
- ✅ Address mode collapse in GANs
- ✅ Choose between quality, speed, and training stability
- ✅ Understand modern landscape (Stable Diffusion, StyleGAN, etc.)
**Do NOT use this skill for:**
- ❌ Text generation (use `llm-specialist` pack)
- ❌ Architecture implementation details (use model-specific docs)
- ❌ High-level architecture selection (use `using-neural-architectures`)
## Core Principle
**Generative models have fundamental trade-offs:**
- **Quality vs Stability**: GANs sharp but unstable, VAEs blurry but stable
- **Quality vs Speed**: Diffusion high-quality but slow, GANs fast
- **Explicitness vs Flexibility**: Autoregressive/Flow have likelihood, GANs don't
**Modern default (2025):** Diffusion models (best quality + stability)
## Part 1: Model Family Overview
### The Five Families
**1. VAE (Variational Autoencoder)**
- **Approach**: Learn latent space with encoder-decoder
- **Quality**: Blurry (6/10)
- **Training**: Very stable
- **Use**: Latent space exploration, NOT high-quality generation
**2. GAN (Generative Adversarial Network)**
- **Approach**: Adversarial game (generator vs discriminator)
- **Quality**: Sharp (9/10)
- **Training**: Unstable (adversarial dynamics)
- **Use**: High-quality generation, fast inference
**3. Diffusion Models**
- **Approach**: Iterative denoising
- **Quality**: Very sharp (9.5/10)
- **Training**: Stable
- **Use**: Modern default for high-quality generation
**4. Autoregressive Models**
- **Approach**: Sequential generation (pixel-by-pixel, token-by-token)
- **Quality**: Good (7-8/10)
- **Training**: Stable
- **Use**: Explicit likelihood, sequential data
**5. Flow Models**
- **Approach**: Invertible transformations
- **Quality**: Good (7-8/10)
- **Training**: Stable
- **Use**: Exact likelihood, invertibility needed
### Quick Comparison
| Model | Quality | Training Stability | Inference Speed | Mode Collapse | Likelihood |
|-------|---------|-------------------|----------------|---------------|------------|
| VAE | 6/10 (blurry) | 10/10 | Fast | No | Approximate |
| GAN | 9/10 | 3/10 | Fast | Yes | No |
| Diffusion | 9.5/10 | 9/10 | Slow | No | Approximate |
| Autoregressive | 7-8/10 | 9/10 | Very slow | No | Exact |
| Flow | 7-8/10 | 8/10 | Fast (both ways) | No | Exact |
## Part 2: VAE (Variational Autoencoder)
### Architecture
**Components:**
1. **Encoder**: x → z (image to latent)
2. **Latent space**: z ~ N(μ, σ²)
3. **Decoder**: z → x' (latent to reconstruction)
**Loss function:**
```python
# ELBO (Evidence Lower Bound)
loss = reconstruction_loss + KL_divergence
# Reconstruction: How well decoder reconstructs input
reconstruction_loss = MSE(x, x_reconstructed)
# KL: How close latent is to standard normal
KL_divergence = KL(q(z|x) || p(z))
```
### Why VAE is Blurry
**Problem**: MSE loss encourages pixel-wise averaging
**Example:**
- Dataset: Faces with both smiles and no smiles
- VAE learns: "Average face has half-smile blur"
- Result: Blurry, hedges between modes
**Mathematical reason:**
- MSE minimization = mean prediction
- Mean of sharp images = blurry image
### When to Use VAE
**Use VAE for:**
- Latent space exploration (interpolation, arithmetic)
- Anomaly detection (reconstruction error)
- Disentangled representations (β-VAE)
- Compression (lossy, with latent codes)
**DON'T use VAE for:**
- High-quality image generation (use GAN or Diffusion!)
- Sharp, realistic outputs
### Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self, latent_dim=128):
super().__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, 4, 2, 1), # 64x64 -> 32x32
nn.ReLU(),
nn.Conv2d(32, 64, 4, 2, 1), # 32x32 -> 16x16
nn.ReLU(),
nn.Conv2d(64, 128, 4, 2, 1), # 16x16 -> 8x8
nn.ReLU(),
nn.Flatten()
)
self.fc_mu = nn.Linear(128 * 8 * 8, latent_dim)
self.fc_logvar = nn.Linear(128 * 8 * 8, latent_dim)
# Decoder
self.fc_decode = nn.Linear(latent_dim, 128 * 8 * 8)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(32, 3, 4, 2, 1),
nn.Sigmoid()
)
def reparameterize(self, mu, logvar):
# Reparameterization trick: z = μ + σ * ε
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# Encode
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
# Sample latent
z = self.reparameterize(mu, logvar)
# Decode
h = self.fc_decode(z)
h = h.view(-1, 128, 8, 8)
x_recon = self.decoder(h)
return x_recon, mu, logvar
def loss_function(self, x, x_recon, mu, logvar):
# Reconstruction loss
recon_loss = F.mse_loss(x_recon, x, reduction='sum')
# KL divergence
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss
```
## Part 3: GAN (Generative Adversarial Network)
### Architecture
**Components:**
1. **Generator**: z → x (noise to image)
2. **Discriminator**: x → [0, 1] (image to real/fake probability)
**Adversarial Training:**
```python
# Discriminator loss: Classify real as real, fake as fake
D_loss = -log(D(x_real)) - log(1 - D(G(z)))
# Generator loss: Fool discriminator
G_loss = -log(D(G(z)))
# Minimax game:
min_G max_D V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]
```
### Training Instability
**Problem**: Adversarial dynamics are unstable
**Common issues:**
1. **Mode collapse**: Generator produces limited variety
2. **Non-convergence**: Oscillation, never settles
3. **Vanishing gradients**: Discriminator too strong, generator can't learn
4. **Hyperparameter sensitivity**: Learning rates critical
**Solutions:**
- Spectral normalization (StyleGAN2)
- Progressive growing (start low-res, increase)
- Minibatch discrimination (penalize lack of diversity)
- Wasserstein loss (WGAN, more stable)
### Mode Collapse
**What is it?**
- Generator produces subset of distribution
- Example: Face GAN only generates 10 face types
**Why it happens:**
- Generator exploits discriminator weaknesses
- Finds "easy" samples that fool discriminator
- Forgets other modes
**Detection:**
```python
# Check diversity: generate many samples, measure mean pairwise distance
z = torch.randn(1000, latent_dim)
samples = generator(z).flatten(1)        # (1000, pixels)
diversity = torch.pdist(samples).mean()  # mean pairwise L2 distance
if diversity < threshold:                # threshold: tuned per dataset
    print("Mode collapse detected!")
```
**Solutions:**
- Minibatch discrimination
- Unrolled GANs (slow but helps)
- Switch to diffusion (no mode collapse by design!)
### Modern GANs
**StyleGAN2 (2020):**
- State-of-the-art for faces
- Style-based generator
- Spectral normalization for stability
- Resolution: 1024×1024
**StyleGAN3 (2021):**
- Alias-free architecture
- Better animation/video
**When to use GAN:**
✅ Fast inference needed (50ms per image)
✅ Pretrained model available (StyleGAN2)
✅ Can tolerate training difficulty
❌ Training instability unacceptable
❌ Mode collapse problematic
❌ Starting from scratch (use diffusion instead)
### Implementation (Basic GAN)
```python
class Generator(nn.Module):
def __init__(self, latent_dim=100, img_channels=3):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128 * 8 * 8),
nn.ReLU(),
nn.Unflatten(1, (128, 8, 8)),
nn.ConvTranspose2d(128, 64, 4, 2, 1), # 8x8 -> 16x16
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, 2, 1), # 16x16 -> 32x32
nn.ReLU(),
nn.ConvTranspose2d(32, img_channels, 4, 2, 1), # 32x32 -> 64x64
nn.Tanh()
)
def forward(self, z):
return self.model(z)
class Discriminator(nn.Module):
def __init__(self, img_channels=3):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(img_channels, 32, 4, 2, 1), # 64x64 -> 32x32
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 4, 2, 1), # 32x32 -> 16x16
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, 2, 1), # 16x16 -> 8x8
nn.LeakyReLU(0.2),
nn.Flatten(),
nn.Linear(128 * 8 * 8, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# Training loop (sketch: generator, discriminator, optimizers built elsewhere)
for real_images in dataloader:
    noise = torch.randn(real_images.size(0), latent_dim)

    # Train discriminator: classify real as real, fake as fake
    optimizer_D.zero_grad()
    fake_images = generator(noise)
    D_real = discriminator(real_images)
    D_fake = discriminator(fake_images.detach())  # detach: don't backprop into G
    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))
    D_loss.backward()
    optimizer_D.step()

    # Train generator: fool the discriminator
    optimizer_G.zero_grad()
    D_fake = discriminator(fake_images)
    G_loss = -torch.mean(torch.log(D_fake))
    G_loss.backward()
    optimizer_G.step()
```
## Part 4: Diffusion Models (Modern Default)
### Architecture
**Concept**: Learn to reverse a diffusion (noising) process
**Forward process** (fixed):
```python
# Gradually add noise to image
x_0 (original) → x_1 → x_2 → ... → x_T (pure noise)

# At each step (variance-preserving):
x_t = √(1 - β_t) * x_{t-1} + √β_t * ε
    where ε ~ N(0, I), β_t = noise schedule
```
**Reverse process** (learned):
```python
# Model learns to denoise
x_T (noise) → x_{T-1} → ... → x_1 → x_0 (image)

# Model predicts the noise: ε_θ(x_t, t)
# Then (DDPM step, with α_t = 1 - β_t and ᾱ_t = ∏ α_s):
# x_{t-1} = (x_t - (1 - α_t) / √(1 - ᾱ_t) * ε_θ(x_t, t)) / √α_t
```
**Training:**
```python
# Simple loss: Predict the noise
loss = MSE(ε, ε_θ(x_t, t))
# x_t = noisy image at step t
# ε = actual noise added
# ε_θ(x_t, t) = model's noise prediction
```
### Why Diffusion is Excellent
**Advantages:**
1. **High quality**: State-of-the-art (better than GAN)
2. **Stable training**: Standard MSE loss (no adversarial dynamics)
3. **No mode collapse**: By design, covers full distribution
4. **Controllable**: Easy to add conditioning (text, class, etc.)
**Disadvantages:**
1. **Slow inference**: 50-1000 denoising steps (vs GAN's 1 step)
2. **Compute intensive**: T forward passes (T = 50-1000)
**Speed comparison:**
```
GAN: 1 forward pass = 50ms
Diffusion (T=50): 50 forward passes = 2.5 seconds
Diffusion (T=1000): 1000 forward passes = 50 seconds
```
**Speedup techniques:**
- DDIM (fewer steps, 10-50 instead of 1000; see the sketch after this list)
- DPM-Solver (fast sampler)
- Latent diffusion (Stable Diffusion, denoise in latent space)
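A sketch of the DDIM swap using the `diffusers` library (the model ID and defaults are assumptions; check the current diffusers docs):
```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap the default sampler for DDIM and use far fewer steps
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of a cat", num_inference_steps=50).images[0]  # vs 1000 DDPM steps
```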
### Modern Diffusion Models
**Stable Diffusion (2022+):**
- Latent diffusion (denoise in VAE latent space)
- Text conditioning (CLIP text encoder)
- Pretrained on billions of images
- Fine-tunable
**DALL-E 2 (2022):**
- Prior network (text → image embedding)
- Diffusion decoder (embedding → image)
**Imagen (2022, Google):**
- Text conditioning with T5 encoder
- Cascaded diffusion (64×64 → 256×256 → 1024×1024)
**When to use Diffusion:**
✅ High-quality generation (best quality)
✅ Stable training (standard loss)
✅ Diversity needed (no mode collapse)
✅ Conditioning (text-to-image, class-conditional)
❌ Need fast inference (< 1 second)
❌ Real-time generation
### Implementation (DDPM)
```python
class DiffusionModel(nn.Module):
def __init__(self, img_channels=3):
super().__init__()
# U-Net architecture
self.model = UNet(
in_channels=img_channels,
out_channels=img_channels,
time_embedding_dim=256
)
def forward(self, x_t, t):
# Predict noise ε at timestep t
return self.model(x_t, t)
# Noise schedules (defined once): betas, alphas = 1 - betas, alpha_bar = torch.cumprod(alphas, 0)

# Training
def train_step(model, x_0):
    batch_size = x_0.size(0)
    # Sample random timestep
    t = torch.randint(0, T, (batch_size,))
    # Sample noise
    ε = torch.randn_like(x_0)
    # Create noisy image x_t (ᾱ_t = cumulative product of α up to step t)
    ᾱ_t = alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(ᾱ_t) * x_0 + torch.sqrt(1 - ᾱ_t) * ε
    # Predict noise
    ε_pred = model(x_t, t)
    # Loss: MSE between actual and predicted noise
    loss = F.mse_loss(ε_pred, ε)
    return loss

# Sampling (generation)
@torch.no_grad()
def sample(model, shape):
    # Start from pure noise
    x_t = torch.randn(shape)
    # Iteratively denoise
    for t in reversed(range(T)):
        # Predict noise
        ε_pred = model(x_t, t)
        # Denoise one step
        α_t, ᾱ_t = alphas[t], alpha_bar[t]
        x_t = (x_t - (1 - α_t) / torch.sqrt(1 - ᾱ_t) * ε_pred) / torch.sqrt(α_t)
        # Add noise (except last step)
        if t > 0:
            x_t += torch.sqrt(betas[t]) * torch.randn_like(x_t)
    return x_t  # x_0 (generated image)
```
## Part 5: Autoregressive Models
### Concept
**Idea**: Model probability as product of conditionals
```
p(x) = p(x_1) * p(x_2|x_1) * p(x_3|x_1,x_2) * ... * p(x_n|x_1,...,x_{n-1})
```
**For images**: Generate pixel-by-pixel (or patch-by-patch)
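The sequential bottleneck is visible in the sampling loop itself. A sketch with a hypothetical causally-masked `model` that outputs logits over 256 pixel intensities:
```python
import torch

pixels = torch.zeros(1, 1024, dtype=torch.long)  # 32×32 image, flattened
for i in range(1024):
    logits = model(pixels)                       # (1, 1024, 256), causally masked
    probs = torch.softmax(logits[0, i], dim=-1)  # p(x_i | x_<i)
    pixels[0, i] = torch.multinomial(probs, 1)
# 1024 forward passes, one per pixel - generation cannot be parallelized
```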
**Architectures:**
- **PixelCNN**: Convolutional with masked kernels
- **PixelCNN++**: Improved with mixture of logistics
- **VQ-VAE + PixelCNN**: Two-stage (learn discrete codes, model codes)
- **ImageGPT**: GPT-style Transformer for images
### Advantages
**Explicit likelihood**: Can compute p(x) exactly
**Stable training**: Standard cross-entropy loss
**Theoretical guarantees**: Proper probability model
### Disadvantages
**Very slow generation**: Sequential (can't parallelize)
**Limited quality**: Worse than GAN/Diffusion for high-res
**Resolution scaling**: Impractical for 1024×1024 (1M pixels!)
**Speed comparison:**
```
GAN: Generate 1024×1024 in 50ms (parallel)
PixelCNN: Generate 32×32 in 5 seconds (sequential!)
ImageGPT: Generate 256×256 in 30 seconds
For 1024×1024: 1M pixels × 5ms/pixel = 83 minutes!
```
### When to Use
**Use autoregressive for:**
- Explicit likelihood needed (compression, evaluation)
- Small images (32×32, 64×64)
- Two-stage models (VQ-VAE + Transformer)
**Don't use for:**
- High-resolution images (too slow)
- Real-time generation
- Quality-critical applications (use diffusion)
### Modern Usage
**Two-stage approach (DALL-E, VQ-GAN):**
1. **Stage 1**: VQ-VAE learns discrete codes
- Image → 32×32 grid of codes (instead of 1M pixels)
2. **Stage 2**: Autoregressive model (Transformer) on codes
- Much faster (32×32 = 1024 codes, not 1M pixels)
## Part 6: Flow Models
### Concept
**Idea**: Invertible transformations
```
z ~ N(0, I) ←→ x ~ p_data
f: z → x (forward)
f⁻¹: x → z (inverse)
```
**Requirement**: f must be invertible and differentiable
**Advantage**: Exact likelihood via change-of-variables
```
log p(x) = log p(z) + log |det(∂f⁻¹/∂x)|
```
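A minimal sketch of the invertible building block (an affine coupling layer, the component RealNVP uses, described below; `dim` assumed even):
```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Predicts scale s and shift t for the second half from the first half
        self.net = nn.Sequential(
            nn.Linear(dim // 2, 128), nn.ReLU(),
            nn.Linear(128, dim)
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        s, t = self.net(x1).chunk(2, dim=-1)
        y2 = x2 * torch.exp(s) + t   # invertible affine transform
        log_det = s.sum(dim=-1)      # log |det Jacobian| for the likelihood
        return torch.cat([x1, y2], dim=-1), log_det

    def inverse(self, y):
        y1, y2 = y.chunk(2, dim=-1)
        s, t = self.net(y1).chunk(2, dim=-1)
        x2 = (y2 - t) * torch.exp(-s)  # exact inverse, no approximation
        return torch.cat([y1, x2], dim=-1)
```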
### Architectures
**RealNVP (2017):**
- Coupling layers (affine transformations)
- Invertible by design
**Glow (2018, OpenAI):**
- Actnorm, invertible 1×1 convolutions
- Multi-scale architecture
**When to use Flow:**
✅ Exact likelihood needed (better than VAE)
✅ Invertibility needed (both z→x and x→z)
✅ Stable training (standard loss)
❌ Architecture constraints (must be invertible)
❌ Quality not as good as GAN/Diffusion
### Modern Status
**Mostly superseded by Diffusion:**
- Diffusion has better quality
- Diffusion more flexible (no invertibility constraint)
- Flow models still used in specialized applications
## Part 7: Decision Framework
### By Primary Goal
```
Goal: High-quality images
→ Diffusion (modern default)
→ OR GAN if pretrained available
Goal: Fast inference
→ GAN (50ms per image)
→ Avoid Diffusion (too slow for real-time)
Goal: Training stability
→ Diffusion or VAE (standard loss)
→ Avoid GAN (adversarial training hard)
Goal: Latent space exploration
→ VAE (smooth interpolation)
→ Avoid GAN (no encoder)
Goal: Explicit likelihood
→ Autoregressive or Flow
→ For evaluation, compression
Goal: Diversity (no mode collapse)
→ Diffusion (by design)
→ OR VAE (stable)
→ Avoid GAN (mode collapse common)
```
### By Data Type
```
Images (high-quality):
→ Diffusion (Stable Diffusion)
→ OR GAN (StyleGAN2)
Images (small, 32×32):
→ Any model works
→ Try VAE first (simplest)
Audio waveforms:
→ WaveGAN
→ OR Diffusion (WaveGrad)
Video:
→ Video Diffusion (limited)
→ OR GAN (StyleGAN-V)
Text:
→ Autoregressive (GPT)
→ NOT VAE/GAN/Diffusion (discrete tokens)
```
### By Training Budget
```
Large budget (millions $, pretrain from scratch):
→ Diffusion (Stable Diffusion scale)
→ Billions of images, weeks on cluster
Medium budget (thousands $, train from scratch):
→ GAN or Diffusion
→ 10k-1M images, days on GPU
Small budget (hundreds $, fine-tune):
→ Fine-tune Stable Diffusion (LoRA)
→ 1k-10k images, hours on consumer GPU
Tiny budget (research, small scale):
→ VAE (simplest, most stable)
→ Few thousand images, CPU possible
```
### Modern Recommendations (2025)
**For new projects:**
1. **Default: Diffusion**
- Fine-tune Stable Diffusion or train from scratch
- Best quality + stability
2. **If need speed: GAN**
- Use pretrained StyleGAN2 if available
- Or train GAN (if can tolerate instability)
3. **If need latent space: VAE**
- For interpolation, not generation quality
**AVOID:**
- Training GAN from scratch (unless necessary)
- Using VAE for high-quality generation
- Autoregressive for high-res images
## Part 8: Training from Scratch vs Fine-Tuning
### Stable Diffusion Example
**Pretraining (what Stability AI did):**
- Dataset: LAION-5B (5 billion images)
- Compute: 150,000 A100 GPU hours
- Cost: ~$600,000
- Time: Weeks on massive cluster
- **DON'T DO THIS!**
**Fine-tuning (what users do):**
- Dataset: 10k-100k domain images
- Compute: 100-1000 GPU hours
- Cost: $100-1,000
- Time: Days on single A100
- **DO THIS!**
**LoRA (Low-Rank Adaptation):**
- Efficient fine-tuning (fewer parameters)
- Dataset: 1k-5k images
- Compute: 10-100 GPU hours
- Cost: $10-100
- Time: Hours on consumer GPU (RTX 3090)
- **Best for small budgets!**
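A sketch of adapter injection with the `peft` library; the `target_modules` names for the Stable Diffusion UNet attention projections are assumptions, so verify them against your diffusers version:
```python
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Inject low-rank adapters into the UNet attention projections
lora_config = LoraConfig(
    r=8,                                      # rank of the update matrices
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v"],  # assumed diffusers layer names
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()  # typically well under 1% of UNet weights
```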
### Decision
```
Have pretrained model in your domain:
→ Fine-tune (don't retrain!)
No pretrained model:
→ Train from scratch (small model)
→ OR find closest pretrained and fine-tune
Budget < $1000:
→ LoRA fine-tuning
→ OR train small model (64×64)
Budget < $100:
→ LoRA with free Colab
→ OR VAE from scratch (cheap)
```
## Part 9: Common Mistakes
### Mistake 1: VAE for High-Quality Generation
**Symptom:** Blurry outputs
**Fix:** Use GAN or Diffusion for quality
**VAE is for:** Latent space, not generation
### Mistake 2: Ignoring Mode Collapse
**Symptom:** GAN generates same images
**Fix:** Spectral norm, minibatch discrimination
**Better:** Switch to Diffusion (no mode collapse)
### Mistake 3: Training Stable Diffusion from Scratch
**Symptom:** Burning money, poor results
**Fix:** Fine-tune pretrained model
**Reality:** Pretraining costs $600k+
### Mistake 4: Slow Inference with Diffusion
**Symptom:** 50 seconds per image
**Fix:** Use DDIM (fewer steps, 10-50)
**OR:** Use GAN if speed critical
### Mistake 5: Wrong Loss for GAN
**Symptom:** Training diverges
**Fix:** Use Wasserstein loss (WGAN)
**OR:** Spectral normalization
**Better:** Switch to Diffusion (standard loss)
## Summary: Quick Reference
### Model Selection
```
High quality + stable training:
→ Diffusion (modern default)
Fast inference required:
→ GAN (if pretrained) or trained GAN
Latent space exploration:
→ VAE
Explicit likelihood:
→ Autoregressive or Flow
Small images (< 64×64):
→ Any model (start with VAE)
Large images (> 256×256):
→ Diffusion or GAN (avoid autoregressive)
```
### Quality Ranking
```
1. Diffusion (9.5/10)
2. GAN (9/10)
3. Autoregressive (7-8/10)
4. Flow (7-8/10)
5. VAE (6/10 - blurry)
```
### Training Stability Ranking
```
1. VAE (10/10)
2. Diffusion (9/10)
3. Autoregressive (9/10)
4. Flow (8/10)
5. GAN (3/10 - very unstable)
```
### Modern Stack (2025)
```
Image generation: Stable Diffusion (fine-tuned)
Fast inference: StyleGAN2 (if available)
Latent space: VAE
Research: Diffusion (easiest to train)
```
## Next Steps
After mastering this skill:
- `llm-specialist/llm-finetuning-strategies`: Apply to text generation
- `architecture-design-principles`: Understand design trade-offs
- `training-optimization`: Optimize training for your chosen model
**Remember:** Diffusion models dominate in 2025. Use them unless you have specific reason not to (speed, latent space, likelihood).

View File

@@ -0,0 +1,625 @@
# Graph Neural Networks Basics
## When to Use This Skill
Use this skill when you need to:
- ✅ Work with graph-structured data (molecules, social networks, citations)
- ✅ Understand why CNN/RNN don't work on graphs
- ✅ Learn message passing framework
- ✅ Choose between GCN, GraphSAGE, GAT
- ✅ Decide if GNN is appropriate (vs simple model)
- ✅ Implement permutation-invariant aggregations
**Do NOT use this skill for:**
- ❌ Sequential data (use RNN/Transformer)
- ❌ Grid data (use CNN)
- ❌ High-level architecture selection (use `using-neural-architectures`)
## Core Principle
**Graphs have irregular structure.** CNN (grid) and RNN (sequence) don't work.
**GNN solution:** Message passing
- Nodes aggregate information from neighbors
- Multiple layers = multi-hop neighborhoods
- Permutation invariant (order doesn't matter)
**Critical question:** Does graph structure actually help? (Test: Compare with/without edges)
## Part 1: Why GNN (Not CNN/RNN)
### Problem: Graph Structure
**Graph components:**
- **Nodes**: Entities (atoms, users, papers)
- **Edges**: Relationships (bonds, friendships, citations)
- **Features**: Node/edge attributes
**Key property:** Irregular structure
- Each node has variable number of neighbors
- No fixed spatial arrangement
- Permutation invariant (node order doesn't matter)
### Why CNN Doesn't Work
**CNN assumption:** Regular grid structure
**Example:** Image (2D grid)
```
Every pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■] ← Center pixel has 8 neighbors (fixed!)
[■][■][■]
CNN kernel: 3×3 (fixed size, fixed positions)
```
**Graph reality:** Irregular neighborhoods
```
Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor (H)
No fixed kernel size or position!
```
**CNN limitations:**
- Requires fixed-size neighborhoods → Graphs have variable-size
- Assumes spatial locality → Graphs have arbitrary connectivity
- Depends on node ordering → Should be permutation invariant
### Why RNN Doesn't Work
**RNN assumption:** Sequential structure
**Example:** Text (1D sequence)
```
"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies
```
**Graph reality:** No inherent sequence
```
Social network:
A — B — C
| |
D ——————E
What's the "sequence"? A→B→C? A→D→E? No natural ordering!
```
**RNN limitations:**
- Requires sequential order → Graphs have no natural order
- Processes one element at a time → Graphs have parallel connections
- Order-dependent → Should be permutation invariant
### GNN Solution
**Key innovation:** Message passing on graph structure
- Operate directly on nodes and edges
- Variable-size neighborhoods (handled naturally)
- Permutation invariant aggregations
## Part 2: Message Passing Framework
### Core Mechanism
**Message passing in 3 steps:**
**1. Aggregate neighbor messages**
```python
# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages) # e.g., mean, sum, max
```
**2. Update node representation**
```python
# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated) # e.g., neural network
```
**3. Repeat for L layers**
- Layer 1: Node sees 1-hop neighbors
- Layer 2: Node sees 2-hop neighbors
- Layer L: Node sees L-hop neighborhood
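Putting the three steps together, one message-passing layer can be a few lines. A sketch with a dense adjacency matrix and mean aggregation (real libraries use sparse edge lists):
```python
import torch
import torch.nn as nn

class SimpleMessagePassing(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.update = nn.Linear(2 * in_dim, out_dim)

    def forward(self, h, adj):
        # h: (N, in_dim) node features; adj: (N, N) 0/1 adjacency matrix
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        agg = adj @ h / deg  # mean over neighbors (permutation invariant)
        return torch.relu(self.update(torch.cat([h, agg], dim=-1)))
```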
### Concrete Example: Social Network
**Task:** Predict user interests
**Graph:**
```
B (sports)
|
A ---+--- C (cooking)
|
D (music)
```
**Layer 1: 1-hop neighbors**
```python
# Node A aggregates from direct friends
h_A_layer1 = update(
h_A,
aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests
```
**Layer 2: 2-hop neighbors (friends of friends)**
```python
# B's friends: E, F
# C's friends: G, H
# D's friends: I
h_A_layer2 = update(
h_A_layer1,
aggregate([h_B', h_C', h_D']) # h_B' includes E, F
)
# Now h_A includes friends-of-friends!
```
**Key insight:** More layers = larger receptive field (L-hop neighborhood)
### Permutation Invariance
**Critical property:** Same graph → same output (regardless of node ordering)
**Example:**
```python
Graph: A-B, B-C
Node list 1: [A, B, C]
Node list 2: [C, B, A]
Output MUST be identical! (Same graph, different ordering)
```
**Invariant aggregations:**
- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])`
- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])`
- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])`
**NOT invariant:**
- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])`
- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]`
**Implementation:**
```python
# CORRECT: Permutation invariant
def aggregate(neighbor_features):
return torch.mean(neighbor_features, dim=0)
# WRONG: Order-dependent!
def aggregate(neighbor_features):
return LSTM(neighbor_features) # Output depends on order
```
## Part 3: GNN Architectures
### Architecture 1: GCN (Graph Convolutional Network)
**Key idea:** Spectral convolution on graphs (simplified)
**Formula:**
```python
h_i^(l+1) = σ( Σ_{j ∈ N(i)} W^(l) h_j^(l) / √(|N(i)| · |N(j)|) )

# Normalize by degree: divide by √(deg(i) · deg(j))
```
**Aggregation:** Weighted mean (degree-normalized)
**Properties:**
- Transductive (needs full graph at training)
- Computationally efficient
- Good baseline
**When to use:**
- Full graph available at training time
- Starting point (simplest GNN)
- Small to medium graphs
**Implementation:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# x: Node features (N, in_channels)
# edge_index: Graph connectivity (2, E)
# Layer 1
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
# Layer 2
x = self.conv2(x, edge_index)
return x
```
### Architecture 2: GraphSAGE
**Key idea:** Sample and aggregate (inductive learning)
**Formula:**
```python
# Sample fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)
# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})
# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])
```
**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance)
**Key innovation:** Sampling
- Sample fixed number of neighbors (e.g., 10)
- Makes computation tractable for large graphs
- Enables inductive learning (generalizes to unseen nodes)
**When to use:**
- Large graphs (millions of nodes)
- Need inductive capability (new nodes appear)
- Training on subset, testing on full graph
**Implementation:**
```python
from torch_geometric.nn import SAGEConv
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x
```
### Architecture 3: GAT (Graph Attention Network)
**Key idea:** Learn attention weights for neighbors
**Formula:**
```python
# Attention scores
α_ij = attention(h_i, h_j) # How important is neighbor j to node i?
# Normalize (softmax)
α_ij = softmax_j(α_ij)
# Weighted aggregation
h_i^(l+1) = σ( Σ_{j∈N(i)} α_ij W h_j^(l) )
```
**Key innovation:** Learned neighbor importance
- Not all neighbors equally important
- Attention mechanism decides weights
- Multi-head attention (like Transformer)
**When to use:**
- Neighbors have varying importance
- Need interpretability (attention weights)
- Have sufficient data (attention needs more data to learn)
**Implementation:**
```python
from torch_geometric.nn import GATConv
class GAT(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
super().__init__()
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x
```
### Architecture Comparison
| Feature | GCN | GraphSAGE | GAT |
|---------|-----|-----------|-----|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |
### Decision Tree
```
Starting out / Small graph:
→ GCN (simplest baseline)
Large graph (millions of nodes):
→ GraphSAGE (sampling enables scalability)
Need inductive learning (new nodes):
→ GraphSAGE or GAT
Neighbors have different importance:
→ GAT (attention learns importance)
Need interpretability:
→ GAT (attention weights explain predictions)
Production deployment:
→ GraphSAGE (most robust and scalable)
```
## Part 4: When NOT to Use GNN
### Critical Question
**Does graph structure actually help?**
**Test:** Compare model with and without edges
```python
# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)
# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)
# Decision:
if gnn_accuracy - mlp_accuracy < 0.02:  # less than 2 percentage points
print("Graph structure doesn't help much")
print("Use simpler model (MLP or XGBoost)")
else:
print("Graph structure adds value")
print("Use GNN")
```
### Scenarios Where GNN Doesn't Help
**1. Node features dominate**
```
User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth complexity!)
```
**2. Sparse graphs**
```
Graph with 1000 nodes, 100 edges (0.01% density):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to MLP
```
**3. Random graph structure**
```
If edges are random (no homophily):
- Neighbor labels uncorrelated
- Aggregation adds noise
- Simple model better
```
### When GNN DOES Help
**Molecular property prediction**
- Structure is PRIMARY signal
- Atom types + bonds determine properties
- GNN: Huge improvement over fingerprints
**Citation networks**
- Paper quality correlated with neighbors
- "You are what you cite"
- Clear homophily
**Social recommendation**
- Friends have similar preferences
- Graph structure informative
- GNN: Moderate to large improvement
**Knowledge graphs**
- Entities connected by relations
- Multi-hop reasoning valuable
- GNN captures complex patterns
### Decision Framework
```
1. Start simple:
- Try MLP or XGBoost on node features
- Establish baseline performance
2. Check graph structure value:
- Does edge information correlate with target?
- Is there homophily (similar nodes connected)?
- Test: Remove edges, compare performance
3. Use GNN if:
- Graph structure adds >2-5% accuracy
- Structure is interpretable (not random)
- Have enough nodes for GNN to learn
4. Stick with simple if:
- Node features alone sufficient
- Graph structure weak/random
- Small dataset (< 1000 nodes)
```
## Part 5: Practical Implementation
### Using PyTorch Geometric
**Installation:**
```bash
pip install torch-geometric
```
**Basic workflow:**
```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# 1. Create graph data
x = torch.tensor([[feature1], [feature2], ...]) # Node features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) # Edges (COO format)
y = torch.tensor([label1, label2, ...]) # Node labels
data = Data(x=x, edge_index=edge_index, y=y)
# 2. Define model
class GNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(in_features, 64)
self.conv2 = GCNConv(64, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 3. Train
model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(200):
model.train()
optimizer.zero_grad()
out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
```
### Edge Index Format
**COO (Coordinate) format:**
```python
# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
[0, 1, 2], # Source nodes
[1, 2, 0] # Target nodes
])
# For undirected graph, include both directions:
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 0], # Source
[1, 0, 2, 1, 0, 2] # Target
])
```
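Rather than writing both directions by hand, PyTorch Geometric ships a utility for this; a minimal sketch:
```python
import torch
from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
edge_index = to_undirected(edge_index)  # adds reverse edges, removes duplicates
```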
### Mini-batching Graphs
**Problem:** Graphs have different sizes
**Solution:** Batch graphs as one large disconnected graph
```python
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG versions
# Create dataset
dataset = [Data(...), Data(...), ...]
# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
# batch contains multiple graphs as one large graph
# batch.batch: Indicator which nodes belong to which graph
out = model(batch.x, batch.edge_index)
```
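For graph-level predictions (e.g., molecular property classification), `batch.batch` is what lets you pool node embeddings back into one vector per graph. A minimal sketch, assuming `model` returns per-node embeddings:
```python
from torch_geometric.nn import global_mean_pool

node_emb = model(batch.x, batch.edge_index)          # (total_nodes, hidden)
graph_emb = global_mean_pool(node_emb, batch.batch)  # (num_graphs, hidden)
```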
## Part 6: Common Mistakes
### Mistake 1: LSTM Aggregation
**Symptom:** Different outputs for same graph with reordered nodes
**Fix:** Use mean/sum/max aggregation (permutation invariant)
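A quick sanity check for this mistake: permute the node ordering and confirm the outputs match up to the same permutation. A sketch assuming a model with a `(x, edge_index)` forward like the GCN in Part 3 and a `data` object as in Part 5 (call `model.eval()` first so dropout doesn't add noise):
```python
import torch

model.eval()
perm = torch.randperm(data.num_nodes)
inv = torch.empty_like(perm)
inv[perm] = torch.arange(data.num_nodes)

out = model(data.x, data.edge_index)
out_perm = model(data.x[perm], inv[data.edge_index])  # remap edge indices

assert torch.allclose(out_perm, out[perm], atol=1e-5)  # fails for LSTM aggregation
```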
### Mistake 2: Forgetting Edge Direction
**Symptom:** Information flows wrong way
**Fix:** For undirected graphs, add edges in both directions
### Mistake 3: Too Many Layers
**Symptom:** Performance degrades, over-smoothing
**Fix:** Use 2-3 layers (most graphs have small diameter)
**Explanation:** Too many layers → all nodes converge to same representation
### Mistake 4: Not Testing Simple Baseline
**Symptom:** Complex GNN with minimal improvement
**Fix:** Always test MLP on node features first
### Mistake 5: Using GNN on Euclidean Data
**Symptom:** CNN/RNN would work better
**Fix:** Use GNN only for irregular graph structure (not grids/sequences)
## Part 7: Summary
### Quick Reference
**When to use GNN:**
- Graph-structured data (molecules, social networks, citations)
- Irregular neighborhoods (not grid/sequence)
- Graph structure informative (test this!)
**Architecture selection:**
```
Start: GCN (simplest)
Large graph: GraphSAGE (scalable)
Inductive learning: GraphSAGE or GAT
Neighbor importance: GAT (attention)
```
**Key principles:**
- Message passing: Aggregate neighbors + Update node
- Permutation invariance: Use mean/sum/max (not LSTM)
- Test baseline: MLP first, GNN if structure helps
- Layers: 2-3 sufficient (more = over-smoothing)
**Implementation:**
- PyTorch Geometric: Standard library
- COO format: Edge index as 2×E tensor
- Batching: Merge graphs into one large graph
## Next Steps
After mastering this skill:
- `transformer-architecture-deepdive`: Understand attention (used in GAT)
- `architecture-design-principles`: Design principles for graph architectures
- Advanced GNNs: Graph Transformers, Equivariant GNNs
**Remember:** Not all graph data needs GNN. Test if graph structure actually helps! (Compare with MLP baseline)

View File

@@ -0,0 +1,915 @@
# Normalization Techniques
## Context
You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering:
- **Batch size dependency**: BatchNorm fails with small batches (< 8)
- **Architecture mismatch**: BatchNorm breaks RNNs/Transformers (use LayerNorm)
- **Task-specific needs**: Style transfer needs InstanceNorm, not BatchNorm
- **Modern alternatives**: RMSNorm simpler and faster than LayerNorm for LLMs
**This skill prevents normalization cargo-culting and provides architecture-specific selection.**
## Why Normalization Matters
**Problem: Internal Covariate Shift**
During training, layer input distributions shift as previous layers update. This causes:
- Vanishing/exploding gradients (deep networks)
- Slow convergence (small learning rates required)
- Training instability (loss spikes)
**Solution: Normalization**
Normalize activations to have stable statistics (mean=0, std=1). Benefits:
- **10x faster convergence**: Can use larger learning rates
- **Better generalization**: Regularization effect
- **Enables deep networks**: 50+ layers without gradient issues
- **Less sensitive to initialization**: Weights can start further from optimal
**Key insight**: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize.
## Normalization Families
### 1. Batch Normalization (BatchNorm)
**What it does:**
Normalizes across the batch dimension for each channel/feature.
**Formula:**
```
Given input x with shape (B, C, H, W): # Batch, Channel, Height, Width
For each channel c:
μ_c = mean(x[:, c, :, :]) # Mean over batch + spatial dims
σ_c = std(x[:, c, :, :]) # Std over batch + spatial dims
x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε)
# Learnable scale and shift
y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c
```
**When to use:**
- ✅ CNNs for classification (ResNet, EfficientNet)
- ✅ Large batch sizes (≥ 16)
- ✅ IID data (image classification, object detection)
**When NOT to use:**
- ❌ Small batch sizes (< 8): Noisy statistics cause training failure
- ❌ RNNs/LSTMs: Breaks temporal dependencies
- ❌ Transformers: Batch dependency problematic for variable-length sequences
- ❌ Style transfer: Batch statistics erase style information
**Batch size dependency:**
```python
batch_size = 32: # ✓ Works well (stable statistics)
batch_size = 16: # ✓ Acceptable
batch_size = 8: # ✓ Marginal (consider GroupNorm)
batch_size = 4: # ✗ Unstable (use GroupNorm)
batch_size = 2: # ✗ FAILS! (noisy statistics)
batch_size = 1: # ✗ Undefined (no batch to normalize over!)
```
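The batch_size=1 failure is easy to see directly; for fully-connected inputs PyTorch refuses outright (a minimal repro):
```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(128)
bn.train()
try:
    bn(torch.randn(1, 128))  # a single sample: no batch to normalize over
except ValueError as e:
    print(e)  # "Expected more than 1 value per channel when training, ..."
```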
**PyTorch example:**
```python
import torch.nn as nn
# For Conv2d
bn = nn.BatchNorm2d(num_features=64) # 64 channels
x = torch.randn(32, 64, 28, 28) # Batch=32, Channels=64
y = bn(x)
# For Linear
bn = nn.BatchNorm1d(num_features=128) # 128 features
x = torch.randn(32, 128) # Batch=32, Features=128
y = bn(x)
```
**Inference mode:**
```python
# Training: Uses batch statistics
model.train()
y = bn(x) # Normalizes using current batch mean/std
# Inference: Uses running statistics (accumulated during training)
model.eval()
y = bn(x) # Normalizes using running_mean/running_std
```
### 2. Layer Normalization (LayerNorm)
**What it does:**
Normalizes across the feature dimension for each sample independently.
**Formula:**
```
Given input x with shape (B, C): # Batch, Features
For each sample b:
μ_b = mean(x[b, :]) # Mean over features
σ_b = std(x[b, :]) # Std over features
x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε)
# Learnable scale and shift
y[b, :] = γ * x_norm[b, :] + β
```
**When to use:**
- ✅ Transformers (BERT, GPT, T5)
- ✅ RNNs/LSTMs (maintains temporal independence)
- ✅ Small batch sizes (batch-independent!)
- ✅ Variable-length sequences
- ✅ Reinforcement learning (batch_size=1 common)
**Advantages over BatchNorm:**
- ✅ **Batch-independent**: Works with batch_size=1
- ✅ **No running statistics**: Inference = training (no mode switching)
- ✅ **Sequence-friendly**: Doesn't mix information across timesteps
**PyTorch example:**
```python
import torch.nn as nn
# For Transformer
ln = nn.LayerNorm(normalized_shape=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch=32, SeqLen=128, d_model=512
y = ln(x) # Normalizes last dimension independently per (batch, position)
# For RNN hidden states
ln = nn.LayerNorm(normalized_shape=256) # hidden_size=256
h = torch.randn(32, 256) # Batch=32, Hidden=256
h_norm = ln(h)
```
**Key difference from BatchNorm:**
```python
# BatchNorm: Normalizes across batch dimension
# Given (B=32, C=64, H=28, W=28)
# Computes 64 means/stds (one per channel, across batch + spatial)
# LayerNorm: Normalizes across feature dimension
# Given (B=32, L=128, D=512)
# Computes 32×128 means/stds (one per (batch, position), across features)
```
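You can verify this per-position behavior directly; each `(batch, position)` slice comes out with mean ≈ 0 and std ≈ 1:
```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(512)
x = torch.randn(32, 128, 512)
y = ln(x)
print(y.mean(dim=-1).abs().max())            # ≈ 0 for every (batch, position)
print(y.std(dim=-1, unbiased=False).mean())  # ≈ 1 (LayerNorm uses biased variance)
```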
### 3. Group Normalization (GroupNorm)
**What it does:**
Normalizes channels in groups, batch-independent.
**Formula:**
```
Given input x with shape (B, C, H, W):
Divide C channels into G groups (C must be divisible by G)
For each sample b and group g:
channels = x[b, g*(C/G):(g+1)*(C/G), :, :] # Channels in group g
μ_{b,g} = mean(channels) # Mean over channels in group + spatial
σ_{b,g} = std(channels) # Std over channels in group + spatial
x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε)
```
**When to use:**
- ✅ Small batch sizes (< 8)
- ✅ CNNs with batch_size=1 (style transfer, RL)
- ✅ Object detection/segmentation (often use small batches)
- ✅ When BatchNorm unstable but want spatial normalization
**Group size selection:**
```python
# num_groups trade-off:
num_groups = 1: # = LayerNorm (all channels together)
num_groups = C: # = InstanceNorm (each channel separate)
num_groups = 32: # Standard choice (good balance)
# Rule: C must be divisible by num_groups
channels = 64, num_groups = 32: # ✓ 64/32 = 2 channels per group
channels = 64, num_groups = 16: # ✓ 64/16 = 4 channels per group
channels = 64, num_groups = 30: # ✗ 64/30 not integer!
```
**PyTorch example:**
```python
import torch.nn as nn
# For small batch CNN
gn = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 28, 28) # Batch=2 (small!)
y = gn(x) # Works well even with batch=2
# Compare performance:
batch_sizes = [1, 2, 4, 8, 16, 32]
bn = nn.BatchNorm2d(64)
gn = nn.GroupNorm(32, 64)
for bs in batch_sizes:
x = torch.randn(bs, 64, 28, 28)
# BatchNorm gets more stable with larger batch
# GroupNorm consistent across all batch sizes
```
**Empirical results (Wu & He, 2018):**
```
ImageNet classification with ResNet-50:
batch_size = 32: BatchNorm = 76.5%, GroupNorm = 76.3% (tie)
batch_size = 8: BatchNorm = 75.8%, GroupNorm = 76.1% (GroupNorm wins!)
batch_size = 2: BatchNorm = 72.1%, GroupNorm = 75.3% (GroupNorm wins!)
```
### 4. Instance Normalization (InstanceNorm)
**What it does:**
Normalizes each sample and channel independently (no batch mixing).
**Formula:**
```
Given input x with shape (B, C, H, W):
For each sample b and channel c:
μ_{b,c} = mean(x[b, c, :, :]) # Mean over spatial dimensions only
σ_{b,c} = std(x[b, c, :, :]) # Std over spatial dimensions only
x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε)
```
**When to use:**
- ✅ Style transfer (neural style, CycleGAN, pix2pix)
- ✅ Image-to-image translation
- ✅ When batch/channel mixing destroys information
**Why for style transfer:**
```python
# Style transfer goal: Transfer style while preserving content
# BatchNorm: Mixes statistics across batch (erases individual style!)
# InstanceNorm: Per-image statistics (preserves each image's style)
# Example: Neural style transfer
content_image = load_image("photo.jpg")
style_image = load_image("starry_night.jpg")
# With BatchNorm: Output loses content image's unique characteristics
# With InstanceNorm: Content characteristics preserved, style applied
```
**PyTorch example:**
```python
import torch.nn as nn
# For style transfer generator
in_norm = nn.InstanceNorm2d(num_features=64)
x = torch.randn(1, 64, 256, 256) # Single image
y = in_norm(x) # Normalizes each channel independently
# CycleGAN generator architecture
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
self.in1 = nn.InstanceNorm2d(64) # NOT BatchNorm!
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.in1(x) # Per-image normalization
x = self.relu(x)
return x
```
**Relation to GroupNorm:**
```python
# InstanceNorm is GroupNorm with num_groups = num_channels
InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C)
```
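This identity is easy to confirm numerically (with affine transforms disabled so both modules are pure normalization):
```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16)
inorm = nn.InstanceNorm2d(8, affine=False)
gnorm = nn.GroupNorm(num_groups=8, num_channels=8, affine=False)
print(torch.allclose(inorm(x), gnorm(x), atol=1e-5))  # True
```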
### 5. RMS Normalization (RMSNorm)
**What it does:**
Simplified LayerNorm that only rescales (no recentering), faster and simpler.
**Formula:**
```
Given input x:
# LayerNorm (2 steps):
x_centered = x - mean(x) # 1. Center
x_norm = x_centered / std(x) # 2. Scale
# RMSNorm (1 step):
rms = sqrt(mean(x²)) # Root Mean Square
x_norm = x / rms # Only scale, no centering
```
**When to use:**
- ✅ Modern LLMs (LLaMA, Mistral, Gemma)
- ✅ When speed matters (15-20% faster than LayerNorm)
- ✅ Large Transformer models (billions of parameters)
**Advantages:**
- ✅ **Simpler**: One operation instead of two
- ✅ **Faster**: ~15-20% speedup over LayerNorm
- ✅ **Numerically stable**: No subtraction (avoids catastrophic cancellation)
- ✅ **Same performance**: Empirically matches LayerNorm quality
**PyTorch implementation:**
```python
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize
x_norm = x / rms
# Scale (learnable)
return self.weight * x_norm
# Usage in Transformer
rms = RMSNorm(dim=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch, SeqLen, d_model
y = rms(x)
```
**Speed comparison (LLaMA-7B, A100 GPU):**
```
LayerNorm: 1000 tokens/sec
RMSNorm: 1180 tokens/sec # 18% faster!
# For large models, this adds up:
# 1 million tokens: 180 seconds saved
```
**Modern LLM adoption:**
```python
# LLaMA (Meta, 2023): RMSNorm
# Mistral (Mistral AI, 2023): RMSNorm
# Gemma (Google, 2024): RMSNorm
# PaLM (Google, 2022): RMSNorm
# Older models:
# GPT-2/3 (OpenAI): LayerNorm
# BERT (Google, 2018): LayerNorm
```
## Architecture-Specific Selection
### CNN (Convolutional Neural Networks)
**Default: BatchNorm**
```python
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels) # After conv
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # Normalize
x = self.relu(x)
return x
```
**Exception: Small batch sizes**
```python
# If batch_size < 8, use GroupNorm instead
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.norm = nn.GroupNorm(32, out_channels) # GroupNorm for small batches
self.relu = nn.ReLU(inplace=True)
```
**Exception: Style transfer**
```python
# Use InstanceNorm for style transfer
class StyleConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.norm = nn.InstanceNorm2d(out_channels) # Per-image normalization
self.relu = nn.ReLU(inplace=True)
```
### RNN / LSTM (Recurrent Neural Networks)
**Default: LayerNorm**
```python
import torch.nn as nn
class NormalizedLSTM(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.ln = nn.LayerNorm(hidden_size) # Normalize hidden states
def forward(self, x):
# x: (batch, seq_len, input_size)
output, (h_n, c_n) = self.lstm(x)
# output: (batch, seq_len, hidden_size)
# Normalize each timestep's output
output_norm = self.ln(output) # Applies independently per timestep
return output_norm, (h_n, c_n)
```
**Why NOT BatchNorm:**
```python
# BatchNorm in RNN mixes information across timesteps!
# Given (batch=32, seq_len=100, hidden=256)
# BatchNorm would compute:
# mean/std over (batch × seq_len) = 3200 values
# This mixes t=0 with t=99 (destroys temporal structure!)
# LayerNorm computes:
# mean/std over hidden_size = 256 values per (batch, timestep)
# Each timestep normalized independently (preserves temporal structure)
```
**Layer-wise normalization in stacked RNN:**
```python
class StackedNormalizedLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
in_size = input_size if i == 0 else hidden_size
self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True))
self.layers.append(nn.LayerNorm(hidden_size)) # After each LSTM layer
def forward(self, x):
for lstm, ln in zip(self.layers[::2], self.layers[1::2]):
x, _ = lstm(x)
x = ln(x) # Normalize between layers
return x
```
### Transformer
**Default: LayerNorm (or RMSNorm for modern/large models)**
**Two placement options: Pre-norm vs Post-norm**
**Post-norm (original Transformer, "Attention is All You Need"):**
```python
class TransformerLayerPostNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x):
# Post-norm: Apply normalization AFTER residual
x = self.ln1(x + self.attn(x, x, x)[0]) # Normalize after adding
x = self.ln2(x + self.ffn(x)) # Normalize after adding
return x
```
**Pre-norm (modern, more stable):**
```python
class TransformerLayerPreNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
    def forward(self, x):
        # Pre-norm: Apply normalization BEFORE each sublayer
        h = self.ln1(x)
        x = x + self.attn(h, h, h)[0]   # Normalize before attention
        x = x + self.ffn(self.ln2(x))   # Normalize before FFN
        return x
```
**Pre-norm vs Post-norm comparison:**
```python
# Post-norm (original):
# - Less stable (requires careful initialization + warmup)
# - Slightly better performance IF training succeeds
# - Hard to train deep models (> 12 layers)
# Pre-norm (modern):
# - More stable (easier to train deep models)
# - Standard for large models (GPT-3: 96 layers!)
# - Recommended default
# Empirical: BERT, original Transformer (post-norm, ≤ 24 layers)
#            GPT-2, GPT-3, T5, LLaMA (pre-norm, scales to 96+ layers)
```
**Using RMSNorm instead:**
```python
class TransformerLayerRMSNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.rms1 = RMSNorm(d_model) # 15-20% faster than LayerNorm
self.rms2 = RMSNorm(d_model)
    def forward(self, x):
        # Pre-norm with RMSNorm (LLaMA style)
        h = self.rms1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.rms2(x))
        return x
```
### GAN (Generative Adversarial Network)
**Generator: InstanceNorm or no normalization**
```python
class Generator(nn.Module):
def __init__(self):
super().__init__()
# Use InstanceNorm for image-to-image translation
self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
self.in1 = nn.InstanceNorm2d(64) # Per-image normalization
def forward(self, x):
x = self.conv1(x)
x = self.in1(x) # Preserves per-image characteristics
return x
```
**Discriminator: No normalization or LayerNorm**
```python
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
# Often no normalization (BatchNorm can hurt GAN training)
self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
# No normalization here
def forward(self, x):
x = self.conv1(x)
# Directly to activation (no norm)
return x
```
**Why avoid BatchNorm in GANs:**
```python
# BatchNorm in discriminator:
# - Mixes real and fake samples in batch
# - Leaks information (discriminator can detect batch composition)
# - Hurts training stability
# Recommendation:
# Generator: InstanceNorm (for image translation) or no norm
# Discriminator: No normalization or LayerNorm
```
## Decision Framework
### Step 1: Check batch size
```python
if batch_size >= 8:
consider_batchnorm = True
else:
use_groupnorm_or_layernorm = True # BatchNorm will be unstable
```
### Step 2: Check architecture
```python
if architecture == "CNN":
if batch_size >= 8:
use_batchnorm()
else:
use_groupnorm(num_groups=32)
# Exception: Style transfer
if task == "style_transfer":
use_instancenorm()
elif architecture in ["RNN", "LSTM", "GRU"]:
use_layernorm() # NEVER BatchNorm!
elif architecture == "Transformer":
if model_size == "large": # > 1B parameters
use_rmsnorm() # 15-20% faster
else:
use_layernorm()
# Placement: Pre-norm (more stable)
use_prenorm_placement()
elif architecture == "GAN":
if component == "generator":
if task == "image_translation":
use_instancenorm()
else:
use_no_norm() # Or InstanceNorm
elif component == "discriminator":
use_no_norm() # Or LayerNorm
```
### Step 3: Verify placement
```python
# CNNs: After convolution, before activation
x = conv(x)
x = norm(x) # Here!
x = relu(x)
# RNNs: After LSTM, normalize hidden states
output, (h, c) = lstm(x)
output = norm(output) # Here!
# Transformers: Pre-norm (modern) or post-norm (original)
# Pre-norm (recommended):
x = x + sublayer(norm(x)) # Normalize before sublayer
# Post-norm (original):
x = norm(x + sublayer(x)) # Normalize after residual
```
## Implementation Checklist
### Before adding normalization:
1. **Check batch size**: If < 8, avoid BatchNorm
2. **Check architecture**: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm
3. **Check task**: Style transfer→InstanceNorm
4. **Verify placement**: After conv/linear, before activation (CNNs)
5. **Test training stability**: Loss should decrease smoothly
### During training:
6. **Monitor running statistics** (BatchNorm): Check running_mean/running_var are updating
7. **Test inference mode**: Verify model.eval() uses running stats correctly
8. **Check gradient flow**: Normalization should help, not hurt gradients
### If training is unstable:
9. **Try different normalization**: BatchNorm→GroupNorm, LayerNorm→RMSNorm
10. **Try pre-norm** (Transformers): More stable than post-norm
11. **Reduce learning rate**: Normalization allows larger LR, but start conservatively
## Common Mistakes
### Mistake 1: BatchNorm with small batches
```python
# WRONG: BatchNorm with batch_size=2
model = ResNet50(norm_layer=nn.BatchNorm2d)
dataloader = DataLoader(dataset, batch_size=2) # Too small!
# RIGHT: GroupNorm for small batches
model = ResNet50(norm_layer=lambda channels: nn.GroupNorm(32, channels))
dataloader = DataLoader(dataset, batch_size=2) # Works!
```
### Mistake 2: BatchNorm in RNN
```python
# WRONG: BatchNorm in LSTM
class BadLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(100, 256)
self.bn = nn.BatchNorm1d(256) # WRONG! Mixes timesteps
def forward(self, x):
output, _ = self.lstm(x)
output = output.permute(0, 2, 1) # (B, H, T)
output = self.bn(output) # Mixes timesteps!
return output
# RIGHT: LayerNorm in LSTM
class GoodLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(100, 256)
self.ln = nn.LayerNorm(256) # Per-timestep normalization
def forward(self, x):
output, _ = self.lstm(x)
output = self.ln(output) # Independent per timestep
return output
```
### Mistake 3: Forgetting model.eval()
```python
# WRONG: Using training mode during inference
model.train() # BatchNorm uses batch statistics
predictions = model(test_data) # Batch statistics from test data (leakage!)
# RIGHT: Use eval mode during inference
model.eval() # BatchNorm uses running statistics
with torch.no_grad():
predictions = model(test_data) # Uses accumulated running stats
```
### Mistake 4: Post-norm for deep Transformers
```python
# WRONG: Post-norm for 24-layer Transformer (unstable!)
class DeepTransformerPostNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayerPostNorm(512, 8) for _ in range(24)
]) # Hard to train!
# RIGHT: Pre-norm for deep Transformers
class DeepTransformerPreNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayerPreNorm(512, 8) for _ in range(24)
]) # Stable training!
```
### Mistake 5: Wrong normalization for style transfer
```python
# WRONG: BatchNorm for style transfer
class StyleGenerator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 7, padding=3)
self.norm = nn.BatchNorm2d(64) # WRONG! Mixes styles across batch
# RIGHT: InstanceNorm for style transfer
class StyleGenerator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 7, padding=3)
self.norm = nn.InstanceNorm2d(64) # Per-image normalization
```
## Performance Impact
### Training speed:
```python
# Without normalization: 100 epochs to converge
# With normalization: 10 epochs to converge (10x faster!)
# Reason: Larger learning rates possible
lr_no_norm = 0.001 # Must be small (unstable otherwise)
lr_with_norm = 0.01 # Can be 10x larger (normalization stabilizes)
```
### Inference speed:
```python
# Normalization overhead (relative to no normalization):
BatchNorm: +2% (minimal, cached running stats)
LayerNorm: +3-5% (compute mean/std per forward pass)
RMSNorm: +2-3% (faster than LayerNorm)
GroupNorm: +5-8% (more computation than BatchNorm)
InstanceNorm: +3-5% (similar to LayerNorm)
# For most models: Overhead is negligible compared to conv/linear layers
```
### Memory usage:
```python
# Normalization memory (per layer):
BatchNorm: 2 × num_channels (running_mean, running_var) + 2 × num_channels (γ, β)
LayerNorm: 2 × normalized_shape (γ, β)
RMSNorm: 1 × normalized_shape (γ only, no β)
# Example: 512 channels
BatchNorm: 4 × 512 = 2048 parameters
LayerNorm: 2 × 512 = 1024 parameters
RMSNorm: 1 × 512 = 512 parameters # Most efficient!
```
## When NOT to Normalize
**Case 1: Final output layer**
```python
# Don't normalize final predictions
class Classifier(nn.Module):
def __init__(self):
super().__init__()
self.backbone = ResNet50() # Normalization inside
self.fc = nn.Linear(2048, 1000)
# NO normalization here! (final logits should be unnormalized)
def forward(self, x):
x = self.backbone(x)
x = self.fc(x) # Raw logits
return x # Don't normalize!
```
**Case 2: Very small networks**
```python
# Single-layer network: Normalization overkill
class TinyNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(784, 10) # MNIST classifier
# No normalization needed (network too simple)
def forward(self, x):
return self.fc(x)
```
**Case 3: When debugging**
```python
# Remove normalization to isolate issues
# If training fails with normalization, try without to check if:
# - Initialization is correct
# - Loss function is correct
# - Data is correctly preprocessed
```
## Modern Recommendations (2025)
### CNNs:
- **Default**: BatchNorm (if batch_size ≥ 8)
- **Small batches**: GroupNorm (num_groups=32)
- **Style transfer**: InstanceNorm
### RNNs/LSTMs:
- **Default**: LayerNorm
- **Never**: BatchNorm (breaks temporal structure)
### Transformers:
- **Small models** (< 1B): LayerNorm + pre-norm
- **Large models** (≥ 1B): RMSNorm + pre-norm (15-20% faster)
- **Avoid**: Post-norm for deep models (> 12 layers)
### GANs:
- **Generator**: InstanceNorm (image translation) or no norm
- **Discriminator**: No normalization or LayerNorm
- **Avoid**: BatchNorm (leaks information)
### Emerging:
- **RMSNorm adoption increasing**: LLaMA, Mistral, Gemma all use RMSNorm
- **Pre-norm becoming standard**: More stable for deep networks
- **GroupNorm gaining traction**: Object detection, small-batch training
## Summary
**Normalization is mandatory for modern deep learning.** The question is which normalization, not whether to normalize.
**Quick decision tree:**
1. **Batch size ≥ 8?** → Consider BatchNorm (CNNs)
2. **Batch size < 8?** → Use GroupNorm (CNNs) or LayerNorm (all)
3. **RNN/LSTM?** → LayerNorm (never BatchNorm!)
4. **Transformer?** → LayerNorm or RMSNorm with pre-norm
5. **Style transfer?** → InstanceNorm
6. **GAN?** → InstanceNorm (generator) or no norm (discriminator)
**Modern defaults:**
- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8)
- RNNs: LayerNorm
- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models)
- GANs: InstanceNorm (generator), no norm (discriminator)
**Key insight**: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere"—it fails for small batches, RNNs, Transformers, and style transfer.

View File

@@ -0,0 +1,714 @@
# Sequence Models Comparison: Choosing the Right Architecture for Sequential Data
<CRITICAL_CONTEXT>
Sequence modeling has evolved rapidly:
- 2014-2017: LSTM/GRU dominated
- 2017+: Transformers revolutionized the field
- 2018+: TCN emerged as efficient alternative
- 2021+: Sparse Transformers for very long sequences
- 2022+: State Space Models (S4) for extreme lengths
Don't default to LSTM (outdated) or blindly use Transformers (not always appropriate).
Match architecture to your sequence characteristics.
</CRITICAL_CONTEXT>
## When to Use This Skill
Use this skill when:
- ✅ Selecting model for sequential/temporal data
- ✅ Comparing RNN vs LSTM vs Transformer
- ✅ Deciding on sequence architecture for time series, text, audio
- ✅ Understanding modern alternatives to LSTM
- ✅ Optimizing for sequence length, speed, or accuracy
DO NOT use for:
- ❌ Vision tasks (use cnn-families-and-selection)
- ❌ Graph-structured data (use graph-neural-networks-basics)
- ❌ LLM-specific questions (use llm-specialist pack)
**When in doubt:** If data is sequential/temporal → this skill.
## Selection Framework
### Step 1: Identify Key Characteristics
**Before recommending, ask:**
| Characteristic | Question | Impact |
|----------------|----------|--------|
| **Sequence Length** | Typical length? | Short (< 100) → LSTM/CNN, Medium (100-1k) → Transformer, Long (> 1k) → Sparse Transformer/S4 |
| **Data Type** | Language, time series, audio? | Language → Transformer, Time series → TCN/Transformer, Audio → Specialized |
| **Data Volume** | Training examples? | Small (< 10k) → LSTM/TCN, Large (> 100k) → Transformer |
| **Latency** | Real-time needed? | Yes → TCN/LSTM, No → Transformer |
| **Deployment** | Cloud/edge/mobile? | Edge → TCN/LSTM, Cloud → Any |
### Step 2: Apply Decision Tree
```
START: What's your primary constraint?
┌─ SEQUENCE LENGTH
│ ├─ Short (< 100 steps)
│ │ ├─ Language → BiLSTM or small Transformer
│ │ └─ Time series → TCN or LSTM
│ │
│ ├─ Medium (100-1000 steps)
│ │ ├─ Language → Transformer (BERT-style)
│ │ └─ Time series → Transformer or TCN
│ │
│ ├─ Long (1000-10000 steps)
│ │ ├─ Sparse Transformer (Longformer, BigBird)
│ │ └─ Hierarchical models
│ │
│ └─ Very Long (> 10000 steps)
│ └─ State Space Models (S4)
├─ DATA TYPE
│ ├─ Natural Language
│ │ ├─ < 50k data → BiLSTM or DistilBERT
│ │ └─ > 50k data → Transformer (BERT, RoBERTa)
│ │
│ ├─ Time Series
│ │ ├─ Fast training → TCN
│ │ ├─ Long sequences → Transformer
│ │ └─ Multivariate → Transformer with cross-series attention
│ │
│ └─ Audio
│ ├─ Waveform → WaveNet (TCN-based)
│ └─ Spectrograms → CNN + Transformer
└─ COMPUTATIONAL CONSTRAINT
├─ Edge device → TCN or small LSTM
├─ Real-time latency → TCN (parallel inference)
└─ Cloud, no constraint → Transformer
```
## Architecture Catalog
### 1. RNN (Recurrent Neural Networks) - Legacy Foundation
**Architecture:** Basic recurrent cell with hidden state
**Status:** **OUTDATED** - don't use for new projects
**Why it existed:**
- First neural approach to sequences
- Hidden state captures temporal information
- Theoretically can model any sequence
**Why it failed:**
- Vanishing gradient (can't learn long dependencies)
- Very slow training (sequential processing)
- Replaced by LSTM in 2014
**When to mention:**
- Historical context only
- Teaching purposes
- Never recommend for production
**Key Insight:** Proved neural nets could handle sequences, but impractical due to vanishing gradients
### 2. LSTM (Long Short-Term Memory) - Legacy Standard
**Architecture:** Gated recurrent cell (forget, input, output gates)
**Complexity:** O(n) memory, sequential processing
**Strengths:**
- Solves vanishing gradient (gates maintain long-term info)
- Works well for short-medium sequences (< 500 steps)
- Small datasets (< 10k examples)
- Low memory footprint
**Weaknesses:**
- Sequential processing (slow training, can't parallelize)
- Still struggles with very long sequences (> 1000 steps)
- Slow inference (especially bidirectional)
- Superseded by Transformers for most language tasks
**When to Use:**
- ✅ Small datasets (< 10k examples)
- ✅ Short sequences (< 100 steps)
- ✅ Edge deployment (low memory)
- ✅ Baseline comparison
**When NOT to Use:**
- ❌ Large datasets (Transformer better)
- ❌ Long sequences (> 500 steps)
- ❌ Modern NLP (Transformer standard)
- ❌ Fast training needed (TCN better)
**Code Example:**
```python
class SeqLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size,
num_layers=2,
batch_first=True,
bidirectional=True)
self.fc = nn.Linear(hidden_size * 2, num_classes)
    def forward(self, x):
        # x: (batch, seq_len, features)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # For a bidirectional LSTM, lstm_out[:, -1, :] mixes the forward
        # direction's LAST step with the backward direction's FIRST-step
        # summary; concatenate the final hidden states from h_n instead
        out = self.fc(torch.cat([h_n[-2], h_n[-1]], dim=-1))
        return out
```
**Status:** Legacy but still useful for specific cases (small data, edge deployment)
### 3. GRU (Gated Recurrent Unit) - Simplified LSTM
**Architecture:** Simplified gating (2 gates instead of 3)
**Advantages over LSTM:**
- Fewer parameters (faster training)
- Similar performance in many tasks
- Lower memory
**Disadvantages:**
- Still sequential (same as LSTM)
- No major advantage over LSTM in practice
- Also superseded by Transformers
**When to Use:**
- Same as LSTM, but prefer LSTM for slightly better performance
- Use if computational savings matter
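If you do reach for GRU, it is a near drop-in swap for LSTM in PyTorch; the only interface difference is the missing cell state:
```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=100, hidden_size=256, num_layers=2, batch_first=True)
x = torch.randn(32, 50, 100)  # (batch, seq_len, features)
output, h_n = gru(x)          # no cell state c_n, unlike LSTM
```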
**Status:** Rarely recommended - if using recurrent, prefer LSTM or move to Transformer/TCN
### 4. Transformer - Modern Standard
**Architecture:** Self-attention mechanism, parallel processing
**Complexity:**
- Memory: O(n²) for sequence length n
- Compute: O(n²d) where d is embedding dimension
**Strengths:**
- ✅ Parallel processing (fast training)
- ✅ Captures long-range dependencies (better than LSTM)
- ✅ State-of-the-art for language (BERT, GPT)
- ✅ Pre-trained models available
- ✅ Scales with data (more data = better performance)
**Weaknesses:**
- ❌ Quadratic memory (struggles with sequences > 1000)
- ❌ Needs more data than LSTM (> 10k examples)
- ❌ Slower inference than TCN
- ❌ Harder to interpret than RNN
**When to Use:**
- ✅ **Natural language** (current standard)
- ✅ Medium sequences (100-1000 tokens)
- ✅ Large datasets (> 50k examples)
- ✅ Pre-training available (BERT, GPT)
- ✅ Accuracy priority
**When NOT to Use:**
- ❌ Short sequences (< 50 tokens) - LSTM/CNN competitive, simpler
- ❌ Very long sequences (> 2000) - quadratic memory explodes
- ❌ Small datasets (< 10k) - will overfit
- ❌ Edge deployment - large model size
**Memory Analysis:**
```python
# Standard Transformer attention
# For sequence length n=1000, batch_size=32, embedding_dim=512:
attention_weights = softmax(Q @ K.transpose(-2, -1) / sqrt(d)) # Shape: (32, 1000, 1000)
# Memory: 32 * 1000 * 1000 * 4 bytes = 128 MB just for attention!
# For n=5000:
# Memory: 32 * 5000 * 5000 * 4 bytes = 3.2 GB per batch!
# → Impossible on most GPUs
```
**Code Example:**
```python
from transformers import BertModel, BertTokenizer
# Pre-trained BERT for text classification
class TransformerClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier = nn.Linear(768, num_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids,
attention_mask=attention_mask)
# Use [CLS] token representation
pooled = outputs.pooler_output
return self.classifier(pooled)
# Fine-tuning
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TransformerClassifier(num_classes=2)
```
**Status:** **Current standard for NLP**, competitive for time series with large data
### 5. TCN (Temporal Convolutional Network) - Underrated Alternative
**Architecture:** 1D convolutions with dilated causal convolutions
**Complexity:** O(n) memory, fully parallel processing
**Strengths:**
- ✅ **Parallel training** (much faster than LSTM)
- ✅ **Parallel inference** (faster than LSTM/Transformer)
- ✅ Linear memory (no quadratic blow-up)
- ✅ Large receptive field (dilation)
- ✅ Works well for time series
- ✅ Simple architecture
**Weaknesses:**
- ❌ Less popular (fewer pre-trained models)
- ❌ Not standard for language (Transformer dominates)
- ❌ Fixed receptive field (vs adaptive attention)
**When to Use:**
- ✅ **Time series forecasting** (often BETTER than LSTM)
- ✅ **Fast training needed** (2-3x faster than LSTM)
- ✅ **Fast inference** (real-time applications)
- ✅ Long sequences (linear memory)
- ✅ Audio processing (WaveNet is TCN-based)
**When NOT to Use:**
- ❌ Natural language with pre-training available (use Transformer)
- ❌ Need very large receptive field (Transformer better)
**Performance Comparison:**
```
Time series forecasting (1000-step sequences):
Training speed:
- LSTM: 100% (baseline, sequential)
- TCN: 35% (2.8x faster, parallel)
- Transformer: 45% (2.2x faster)
Inference speed:
- LSTM: 100% (sequential)
- TCN: 20% (5x faster, parallel)
- Transformer: 60% (1.7x faster)
Accuracy (similar across all three):
- LSTM: Baseline
- TCN: Equal or slightly better
- Transformer: Equal or slightly better (needs more data)
Conclusion: TCN wins on speed, matches accuracy
```
**Code Example:**
```python
class Chomp1d(nn.Module):
    """Trim trailing padding so the dilated convolution stays causal."""
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size] if self.chomp_size > 0 else x

class TCN(nn.Module):
    def __init__(self, input_channels, num_channels, kernel_size=3):
        super().__init__()
        layers = []
        for i, out_channels in enumerate(num_channels):
            dilation_size = 2 ** i
            in_channels = input_channels if i == 0 else num_channels[i-1]
            padding = (kernel_size - 1) * dilation_size
            # Causal dilated convolution: pad symmetrically, then chomp the
            # right side so no position can see the future
            layers += [
                nn.Conv1d(in_channels, out_channels, kernel_size,
                          padding=padding, dilation=dilation_size),
                Chomp1d(padding),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
# Usage for time series
# Input: (batch, channels, sequence_length)
model = TCN(input_channels=1, num_channels=[64, 128, 256])
```
**Key Insight:** Dilated convolutions create exponentially large receptive field (2^k) while maintaining linear memory
**Status:** **Excellent for time series**, underrated, should be considered before LSTM
### 6. Sparse Transformers - Long Sequence Specialists
**Architecture:** Modified attention patterns to reduce complexity
**Variants:**
- **Longformer**: Local + global attention
- **BigBird**: Random + local + global attention
- **Linformer**: Low-rank projection of keys/values
- **Performer**: Kernel approximation of attention
**Complexity:** O(n log n) or O(n) depending on variant
**When to Use:**
- ✅ **Long sequences** (1000-10000 tokens)
- ✅ Document processing (multi-page documents)
- ✅ Long-context language modeling
- ✅ When standard Transformer runs out of memory
**Trade-offs:**
- Slightly lower accuracy than full attention (approximation)
- More complex implementation
- Fewer pre-trained models
**Example Use Cases:**
- Legal document analysis (10k+ tokens)
- Scientific paper understanding
- Long-form text generation
- Time series with thousands of steps
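Getting started is straightforward via Hugging Face Transformers (a minimal sketch; `long_text` is a placeholder for your document string):
```python
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer(long_text, return_tensors="pt",
                   truncation=True, max_length=4096)
outputs = model(**inputs)  # sliding-window attention keeps memory near-linear
```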
**Status:** Specialized for long sequences, active research area
### 7. State Space Models (S4) - Cutting Edge
**Architecture:** Structured state space with efficient recurrence
**Complexity:** O(n log n) training, O(n) inference
**Strengths:**
- ✅ **Very long sequences** (10k-100k steps)
- ✅ Linear inference complexity
- ✅ Strong theoretical foundations
- ✅ Handles continuous-time sequences
**Weaknesses:**
- ❌ Newer (less mature ecosystem)
- ❌ Complex mathematics
- ❌ Fewer pre-trained models
- ❌ Harder to implement
**When to Use:**
- ✅ Extremely long sequences (> 10k steps)
- ✅ Audio (raw waveforms, 16kHz sampling)
- ✅ Medical signals (ECG, EEG)
- ✅ Research applications
**Status:** **Cutting edge** (2022+), promising for very long sequences
## Practical Selection Guide
### Scenario 1: Natural Language Processing
**Short text (< 50 tokens, e.g., tweets, titles):**
```
Small dataset (< 10k):
→ BiLSTM or 1D CNN (simple, effective)
Large dataset (> 10k):
→ DistilBERT (smaller Transformer, 40M params)
→ Or BiLSTM if latency critical
```
**Medium text (50-512 tokens, e.g., reviews, articles):**
```
Standard approach:
→ BERT, RoBERTa, or similar (110M params)
→ Fine-tune on task-specific data
Small dataset:
→ DistilBERT (66M params, faster, similar accuracy)
```
**Long documents (> 512 tokens):**
```
→ Longformer (4096 tokens max)
→ BigBird (4096 tokens max)
→ Hierarchical: Process in chunks, aggregate
```
### Scenario 2: Time Series Forecasting
**Short sequences (< 100 steps):**
```
Fast training:
→ TCN (2-3x faster than LSTM)
Small dataset:
→ LSTM or simple models (ARIMA, Prophet)
Baseline:
→ LSTM (well-tested)
```
**Medium sequences (100-1000 steps):**
```
Best accuracy:
→ Transformer (if data > 50k examples)
Fast training/inference:
→ TCN (parallel processing)
Multivariate:
→ Transformer with cross-series attention
```
**Long sequences (> 1000 steps):**
```
→ Sparse Transformer (Informer for time series)
→ Hierarchical models (chunk + aggregate)
→ State Space Models (S4)
```
### Scenario 3: Audio Processing
**Waveform (raw audio, 16kHz):**
```
→ WaveNet (TCN-based)
→ State Space Models (S4)
```
**Spectrograms (mel-spectrograms):**
```
→ CNN + BiLSTM (traditional)
→ CNN + Transformer (modern)
```
**Speech recognition:**
```
→ Transformer (Wav2Vec 2.0, Whisper)
→ Pre-trained models available
```
## Trade-Off Analysis
### Speed Comparison
**Training speed (1000-step sequences):**
```
LSTM: 100% (baseline, sequential)
GRU: 75% (simpler gates)
TCN: 35% (2.8x faster, parallel)
Transformer: 45% (2.2x faster, parallel)
Conclusion: TCN fastest for training
```
**Inference speed:**
```
LSTM: 100% (sequential)
BiLSTM: 200% (2x passes)
TCN: 20% (5x faster, parallel)
Transformer: 60% (faster, but attention overhead)
Conclusion: TCN fastest for inference
```
### Memory Comparison
**Sequence length n=1000, batch=32:**
```
LSTM: ~500 MB (linear in n)
Transformer: ~2 GB (quadratic in n)
TCN: ~400 MB (linear in n)
Sparse Transformer: ~800 MB (n log n)
For n=5000:
LSTM: ~2 GB
Transformer: OUT OF MEMORY (50 GB needed!)
TCN: ~2 GB
Sparse Transformer: ~4 GB
```
### Accuracy vs Data Size
**Small dataset (< 10k examples):**
```
LSTM: ★★★★☆ (works well with little data)
Transformer: ★★☆☆☆ (overfits, needs more data)
TCN: ★★★★☆ (similar to LSTM)
Winner: LSTM or TCN
```
**Large dataset (> 100k examples):**
```
LSTM: ★★★☆☆ (good but plateaus)
Transformer: ★★★★★ (best, scales with data)
TCN: ★★★★☆ (competitive)
Winner: Transformer
```
## Common Pitfalls
### Pitfall 1: Using LSTM in 2025 Without Considering Modern Alternatives
**Symptom:** Defaulting to LSTM for all sequence tasks
**Why it's wrong:** Transformers (language) and TCN (time series) often better
**Fix:** Consider Transformer for language, TCN for time series, LSTM for small data/edge only
### Pitfall 2: Using Standard Transformer for Very Long Sequences
**Symptom:** Running out of memory on sequences > 1000 tokens
**Why it's wrong:** O(n²) memory explodes
**Fix:** Use Sparse Transformer (Longformer, BigBird) or hierarchical approach
### Pitfall 3: Not Trying TCN for Time Series
**Symptom:** Struggling with slow LSTM training
**Why it's wrong:** TCN is 2-3x faster, often more accurate
**Fix:** Try TCN before optimizing LSTM
### Pitfall 4: Using Transformer for Small Datasets
**Symptom:** Transformer overfits on < 10k examples
**Why it's wrong:** Transformers need large datasets to work well
**Fix:** Use LSTM or TCN for small datasets, or use pre-trained Transformer
### Pitfall 5: Ignoring Sequence Length Constraints
**Symptom:** Choosing architecture without considering typical sequence length
**Why it's wrong:** Architecture effectiveness varies dramatically with length
**Fix:** Match architecture to sequence length (short → LSTM/CNN, long → Sparse Transformer)
## Evolution Timeline
**Understanding why architectures evolved:**
```
2010-2013: Basic RNN
→ Vanishing gradient problem
→ Can't learn long dependencies
2014: LSTM becomes the standard (invented 1997 by Hochreiter & Schmidhuber)
→ Gates solve vanishing gradient
→ Became standard for sequences
2014: GRU
→ Simplified LSTM
→ Similar performance, fewer parameters
2017: Transformer (Attention Is All You Need)
→ Self-attention replaces recurrence
→ Parallel processing (fast training)
→ Revolutionized NLP
2018: TCN (Temporal Convolutional Networks)
→ Dilated convolutions for sequences
→ Often better than LSTM for time series
→ Underrated alternative
2020: Sparse Transformers
→ Reduce quadratic complexity
→ Enable longer sequences
2021: State Space Models (S4)
→ Very long sequences (10k-100k)
→ Theoretical foundations
→ Cutting edge research
Current (2025):
- NLP: Transformer standard (BERT, GPT)
- Time Series: TCN or Transformer
- Audio: Specialized (WaveNet, Transformer)
- Edge: LSTM or TCN (low memory)
```
## Decision Checklist
Before choosing sequence model:
```
☐ Sequence length? (< 100 / 100-1k / > 1k)
☐ Data type? (language / time series / audio / other)
☐ Dataset size? (< 10k / 10k-100k / > 100k)
☐ Latency requirement? (real-time / batch / offline)
☐ Deployment target? (cloud / edge / mobile)
☐ Pre-trained models available? (yes / no)
☐ Training speed critical? (yes / no)
Based on answers:
→ Language + large data → Transformer
→ Language + small data → BiLSTM or DistilBERT
→ Time series + speed → TCN
→ Time series + accuracy + large data → Transformer
→ Very long sequences → Sparse Transformer or S4
→ Edge deployment → TCN or LSTM
→ Real-time latency → TCN
```
## Integration with Other Skills
**For language-specific questions:**
`yzmir/llm-specialist/using-llm-specialist`
- LLM-specific Transformers (GPT, BERT variants)
- Fine-tuning strategies
- Prompt engineering
**For Transformer internals:**
`yzmir/neural-architectures/transformer-architecture-deepdive`
- Attention mechanisms
- Positional encoding
- Transformer variants
**After selecting architecture:**
`yzmir/training-optimization/using-training-optimization`
- Optimizer selection
- Learning rate schedules
- Handling sequence-specific training issues
## Summary
**Quick Reference Table:**
| Use Case | Best Choice | Alternative | Avoid |
|----------|-------------|-------------|-------|
| Short text (< 50 tokens) | BiLSTM, DistilBERT | 1D CNN | Full BERT (overkill) |
| Long text (> 512 tokens) | Longformer, BigBird | Hierarchical | Standard BERT (memory) |
| Time series (< 1k steps) | TCN, Transformer | LSTM | Basic RNN |
| Time series (> 1k steps) | Sparse Transformer, S4 | Hierarchical | Standard Transformer |
| Small dataset (< 10k) | LSTM, TCN | Simple models | Transformer (overfits) |
| Large dataset (> 100k) | Transformer | TCN | LSTM (plateaus) |
| Edge deployment | TCN, LSTM | Quantized Transformer | Large Transformer |
| Real-time inference | TCN | Small LSTM | BiLSTM, Transformer |
**Key Principles:**
1. **Don't default to LSTM** (outdated for most tasks)
2. **Transformer for language** (current standard, if data sufficient)
3. **TCN for time series** (fast, effective, underrated)
4. **Match to sequence length** (short → LSTM/CNN, long → Sparse Transformer)
5. **Consider modern alternatives** (don't stop at LSTM vs Transformer)
**END OF SKILL**

View File

@@ -0,0 +1,937 @@
# Transformer Architecture Deep Dive
## When to Use This Skill
Use this skill when you need to:
- ✅ Implement a Transformer from scratch
- ✅ Understand HOW and WHY self-attention works
- ✅ Choose between encoder, decoder, or encoder-decoder architectures
- ✅ Decide if Vision Transformer (ViT) is appropriate for your vision task
- ✅ Understand modern variants (RoPE, ALiBi, GQA, MQA)
- ✅ Debug Transformer implementation issues
- ✅ Optimize Transformer performance
**Do NOT use this skill for:**
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ Attention mechanism comparison (use `attention-mechanisms-catalog`)
- ❌ LLM-specific topics like prompt engineering (use `llm-specialist` pack)
## Core Principle
**Transformers are NOT magic.** They are:
1. Self-attention mechanism (information retrieval)
2. + Position encoding (break permutation invariance)
3. + Residual connections + Layer norm (training stability)
4. + Feed-forward networks (non-linearity)
Understanding the mechanism beats cargo-culting implementations.
## Part 1: Self-Attention Mechanism Explained
### The Information Retrieval Analogy
**Self-attention = Querying a database:**
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "What information do I have?"
**Process:**
1. Compare your query with all keys (compute similarity)
2. Weight values by similarity
3. Return weighted sum of values
**Example:** Sentence: "The cat sat on the mat"
Token "sat" (verb):
- High attention to "cat" (subject) → Learns verb-subject relationship
- High attention to "mat" (object) → Learns verb-object relationship
- Low attention to "the", "on" (function words)
### Mathematical Breakdown
**Given input X:** (batch, seq_len, d_model)
**Step 1: Project to Q, K, V**
```python
Q = X @ W_Q # (batch, seq_len, d_k)
K = X @ W_K # (batch, seq_len, d_k)
V = X @ W_V # (batch, seq_len, d_v)
# Typically: d_k = d_v = d_model / num_heads
```
**Step 2: Compute attention scores** (similarity)
```python
scores = Q @ K.transpose(-2, -1) # (batch, seq_len, seq_len)
# scores[i, j] = similarity between query_i and key_j
```
**Geometric interpretation:**
- Dot product measures vector alignment
- q · k = ||q|| ||k|| cos(θ)
- Similar vectors → Large dot product → High attention
- Orthogonal vectors → Zero dot product → No attention
**Step 3: Scale by √d_k** (CRITICAL!)
```python
scores = scores / math.sqrt(d_k)
```
**WHY scaling?**
- Dot products grow with dimension: Var(q · k) = d_k
- Example: d_k=64 → Random dot products ~ ±64
- Large scores → Softmax saturates → Gradients vanish
- Scaling: Keep scores ~ O(1) regardless of dimension
**Without scaling:** Softmax([30, 25, 20]) ≈ [0.99, 0.01, 0.00] (saturated!)
**With scaling:** Softmax([3, 2.5, 2]) ≈ [0.50, 0.30, 0.20] (healthy gradients)
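A quick simulation makes the effect visible. Unscaled dot products of d_k-dimensional unit-variance vectors have std ≈ √d_k, which pushes softmax toward one-hot:
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
scores = torch.randn(5) * d_k ** 0.5          # simulate unscaled dot products
print(F.softmax(scores, dim=0))               # near one-hot (saturated)
print(F.softmax(scores / d_k ** 0.5, dim=0))  # spread out (healthy gradients)
```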
**Step 4: Softmax to get attention weights**
```python
attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len)
# Each row sums to 1 (probability distribution)
# attn_weights[i, j] = "how much token i attends to token j"
```
**Step 5: Weight values**
```python
output = attn_weights @ V # (batch, seq_len, d_v)
# Each token's output = weighted average of all values
```
**Complete formula:**
```python
Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V
```
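PyTorch ≥ 2.0 ships a fused kernel for exactly this formula, worth using in practice instead of hand-rolled matmuls:
```python
import torch
import torch.nn.functional as F

Q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, d_k)
K = torch.randn(2, 8, 128, 64)
V = torch.randn(2, 8, 128, 64)
out = F.scaled_dot_product_attention(Q, K, V)  # pass is_causal=True for decoders
```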
### Why Three Matrices (Q, K, V)?
**Could we use just one?** Attention(X, X, X)
**Yes, but Q/K/V separation enables:**
1. **Asymmetry**: Query can differ from key (search ≠ database)
2. **Decoupling**: What you search for (Q@K) ≠ what you retrieve (V)
3. **Cross-attention**: Q from one source, K/V from another
- Example: Decoder queries encoder (translation)
**Modern optimization:** Multi-Query Attention (MQA), Grouped-Query Attention (GQA)
- Share K/V across heads (fewer parameters, faster inference)
### Implementation Example
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, d_model, d_k=None):
super().__init__()
self.d_k = d_k or d_model
self.W_q = nn.Linear(d_model, self.d_k)
self.W_k = nn.Linear(d_model, self.d_k)
self.W_v = nn.Linear(d_model, self.d_k)
def forward(self, x, mask=None):
# x: (batch, seq_len, d_model)
Q = self.W_q(x) # (batch, seq_len, d_k)
K = self.W_k(x) # (batch, seq_len, d_k)
V = self.W_v(x) # (batch, seq_len, d_k)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: (batch, seq_len, seq_len)
# Apply mask if provided (for causal attention)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Attention weights
attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len)
# Weighted sum of values
output = torch.matmul(attn_weights, V) # (batch, seq_len, d_k)
return output, attn_weights
```
**Complexity:** O(n² · d) where n = seq_len, d = d_model
- **Quadratic in sequence length** (bottleneck for long sequences)
- For n=1000, d=512: 1000² × 512 = 512M operations
## Part 2: Multi-Head Attention
### Why Multiple Heads?
**Single-head attention** learns one attention pattern.
**Multi-head attention** learns multiple parallel patterns:
- Head 1: Syntactic relationships (subject-verb)
- Head 2: Semantic similarity
- Head 3: Positional proximity
- Head 4: Long-range dependencies
**Analogy:** Ensemble of attention functions, each specializing in different patterns.
### Head Dimension Calculation
**CRITICAL CONSTRAINT:** num_heads must divide d_model evenly!
```python
d_model = 512
num_heads = 8
d_k = d_model // num_heads # 512 / 8 = 64
# Each head operates in d_k dimensions
# Concatenate all heads → back to d_model dimensions
```
**Common configurations:**
- BERT-base: d_model=768, heads=12, d_k=64
- GPT-2: d_model=768, heads=12, d_k=64
- GPT-3 175B: d_model=12288, heads=96, d_k=128
- LLaMA-2 70B: d_model=8192, heads=64, d_k=128
**Rule of thumb:** d_k (head dimension) should be 64-128
- Too small (d_k < 32): Limited representational capacity
- Too large (d_k > 256): Redundant, wasteful
### Multi-Head Implementation
```python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Single linear layers for all heads (more efficient)
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model) # Output projection
def split_heads(self, x):
# x: (batch, seq_len, d_model)
batch_size, seq_len, d_model = x.size()
# Reshape to (batch, seq_len, num_heads, d_k)
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
# Transpose to (batch, num_heads, seq_len, d_k)
return x.transpose(1, 2)
def forward(self, x, mask=None):
batch_size = x.size(0)
# Linear projections
Q = self.W_q(x) # (batch, seq_len, d_model)
K = self.W_k(x)
V = self.W_v(x)
# Split into multiple heads
Q = self.split_heads(Q) # (batch, num_heads, seq_len, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
# Weighted sum
attn_output = torch.matmul(attn_weights, V)
# attn_output: (batch, num_heads, seq_len, d_k)
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous()
# (batch, seq_len, num_heads, d_k)
attn_output = attn_output.view(batch_size, -1, self.d_model)
# (batch, seq_len, d_model)
# Final linear projection
output = self.W_o(attn_output)
return output, attn_weights
```
### Modern Variants: GQA and MQA
**Problem:** K/V caching during inference is memory-intensive
- At LLaMA-2 70B scale (d_model=8192, 80 layers), full MHA would cache 2 (K + V) × 8192 × 80 ≈ 1.3M values per token!
**Solution 1: Multi-Query Attention (MQA)**
- **One** K/V head shared across **all** Q heads
- Benefit: Dramatically faster inference (smaller KV cache)
- Trade-off: ~1-2% accuracy loss
```python
# MQA: Single K/V projection
self.W_k = nn.Linear(d_model, d_k) # Not d_model!
self.W_v = nn.Linear(d_model, d_k)
self.W_q = nn.Linear(d_model, d_model) # Multiple Q heads
```
**Solution 2: Grouped-Query Attention (GQA)**
- Middle ground: Group multiple Q heads per K/V head
- Example: 32 Q heads → 8 K/V heads (4 Q per K/V)
- Benefit: 4x smaller KV cache, minimal accuracy loss
**Used in:** LLaMA-2, Mistral, Mixtral
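A minimal sketch of the GQA grouping step (hypothetical helper; names and shapes are illustrative):
```python
import torch

def grouped_query_attention(Q, K, V):
    # Q: (batch, num_q_heads, seq, d_k); K, V: (batch, num_kv_heads, seq, d_k)
    # num_q_heads must be a multiple of num_kv_heads
    num_q_heads, num_kv_heads = Q.shape[1], K.shape[1]
    group_size = num_q_heads // num_kv_heads
    K = K.repeat_interleave(group_size, dim=1)  # each K/V head serves its group of Q heads
    V = V.repeat_interleave(group_size, dim=1)
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ V

# MQA is the special case num_kv_heads = 1; full MHA is num_kv_heads = num_q_heads
out = grouped_query_attention(
    torch.randn(2, 32, 16, 128),  # 32 query heads
    torch.randn(2, 8, 16, 128),   # 8 K/V heads -> groups of 4
    torch.randn(2, 8, 16, 128),
)
```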
## Part 3: Position Encoding
### Why Position Encoding?
**Problem:** Self-attention is **permutation-equivariant**
- Shuffling the input tokens just shuffles the outputs identically: "cat sat mat" vs "mat cat sat"
- No inherent notion of position or order!
**Solution:** Add position information to embeddings
### Strategy 1: Sinusoidal Position Encoding (Original)
**Formula:**
```python
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
```
**Implementation:**
```python
def sinusoidal_position_encoding(seq_len, d_model):
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
# Usage: Add to input embeddings
x = token_embeddings + positional_encoding
```
**Properties:**
- Deterministic (no learned parameters)
- Extrapolates to unseen lengths (geometric properties)
- Relative positions: PE(pos+k) is linear function of PE(pos)
**When to use:** Variable-length sequences in NLP
### Strategy 2: Learned Position Embeddings
```python
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
# Usage
positions = torch.arange(seq_len, device=x.device)
x = token_embeddings + self.pos_embedding(positions)
```
**Properties:**
- Learnable (adapts to data)
- Cannot extrapolate beyond max_seq_len
**When to use:**
- Fixed-length sequences
- Vision Transformers (image patches)
- When training data covers all positions
### Strategy 3: Rotary Position Embeddings (RoPE) ⭐
**Modern approach (2021+):** Rotate Q and K in complex plane
**Key advantages:**
- Encodes **relative** positions naturally
- Better long-range decay properties
- No addition to embeddings (applied in attention)
**Used in:** GPT-NeoX, PaLM, LLaMA, LLaMA-2, Mistral
```python
def apply_rotary_pos_emb(x, cos, sin):
# x: (batch, num_heads, seq_len, d_k)
# Split into even/odd
x1, x2 = x[..., ::2], x[..., 1::2]
# Rotate
return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```
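The `cos`/`sin` tables above come from per-pair rotation frequencies; a minimal sketch of the standard construction (base 10000, as in the RoPE paper):
```python
import torch

def rope_cos_sin(seq_len, d_k, base=10000.0):
    # One frequency per (even, odd) dimension pair
    inv_freq = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))  # (d_k/2,)
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, d_k/2): position x frequency
    # Broadcasts against x1/x2 of shape (batch, num_heads, seq_len, d_k/2)
    return angles.cos(), angles.sin()
```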
### Strategy 4: ALiBi (Attention with Linear Biases) ⭐
**Simplest modern approach:** Add bias to attention scores (no embeddings!)
```python
# Bias: -(i - j) for each past position j <= i ("." = future, handled by causal mask)
# [[ 0,  .,  .,  .],
#  [-1,  0,  .,  .],
#  [-2, -1,  0,  .],
#  [-3, -2, -1,  0]]
scores = Q @ K^T / sqrt(d_k) + m * alibi_bias  # m = per-head slope
```
**Key advantages:**
- **Best extrapolation** to longer sequences
- No positional embeddings (simpler)
- Per-head slopes (different decay rates)
**Used in:** BLOOM
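A minimal sketch of the bias construction (the slope formula is exact for power-of-two head counts, per the ALiBi paper):
```python
import torch

def alibi_bias(seq_len, num_heads):
    # Per-head slopes: 2^(-8/n), 2^(-16/n), ..., a geometric sequence
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).clamp(min=0)  # i - j for past positions, 0 elsewhere
    return -slopes[:, None, None] * distance  # (num_heads, seq_len, seq_len), added to scores
```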
### Position Encoding Selection Guide
| Use Case | Recommended | Why |
|----------|-------------|-----|
| NLP (variable length) | RoPE or ALiBi | Better extrapolation |
| NLP (fixed length) | Learned embeddings | Adapts to data |
| Vision (ViT) | 2D learned embeddings | Spatial structure |
| Long sequences (>2k) | ALiBi | Best extrapolation |
| Legacy/compatibility | Sinusoidal | Original Transformer |
**Modern trend (2023+):** RoPE and ALiBi dominate over sinusoidal
## Part 4: Architecture Variants
### Variant 1: Encoder-Only (Bidirectional)
**Architecture:**
- Self-attention: Each token attends to **ALL** tokens (past + future)
- No masking (bidirectional context)
**Examples:** BERT, RoBERTa, ELECTRA, DeBERTa
**Use cases:**
- Text classification
- Named entity recognition
- Question answering (extract span from context)
- Sentence embeddings
**Key property:** Sees full context → Good for **understanding**
**Implementation:**
```python
class TransformerEncoder(nn.Module):
def __init__(self, d_model, num_heads, d_ff, num_layers):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask) # No causal mask!
return x
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# Self-attention + residual + norm
attn_output, _ = self.self_attn(x, mask)
x = self.norm1(x + attn_output)
# Feed-forward + residual + norm
ff_output = self.feed_forward(x)
x = self.norm2(x + ff_output)
return x
```
### Variant 2: Decoder-Only (Autoregressive)
**Architecture:**
- Self-attention with **causal masking**
- Each token attends ONLY to past tokens (not future)
**Causal mask (lower triangular):**
```python
# mask[i, j] = 1 if j <= i else 0
[[1, 0, 0, 0], # Token 0 sees only itself
[1, 1, 0, 0], # Token 1 sees tokens 0-1
[1, 1, 1, 0], # Token 2 sees tokens 0-2
[1, 1, 1, 1]] # Token 3 sees all
```
**Examples:** GPT, GPT-2, GPT-3, GPT-4, LLaMA, Mistral
**Use cases:**
- Text generation
- Language modeling
- Code generation
- Autoregressive prediction
**Key property:** Generates sequentially → Good for **generation**
**Implementation:**
```python
def create_causal_mask(seq_len, device):
# Lower triangular matrix
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask
class TransformerDecoder(nn.Module):
def __init__(self, d_model, num_heads, d_ff, num_layers):
super().__init__()
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
def forward(self, x):
seq_len = x.size(1)
causal_mask = create_causal_mask(seq_len, x.device)
for layer in self.layers:
x = layer(x, causal_mask) # Apply causal mask!
return x
```
**Modern trend (2023+):** Decoder-only architectures dominate
- Can do both generation AND understanding (via prompting)
- Simpler than encoder-decoder (no cross-attention)
- Scales better to massive sizes
### Variant 3: Encoder-Decoder (Seq2Seq)
**Architecture:**
- **Encoder**: Bidirectional self-attention (understands input)
- **Decoder**: Causal self-attention (generates output)
- **Cross-attention**: Decoder queries encoder outputs
**Cross-attention mechanism:**
```python
# Q from decoder, K and V from encoder
Q = decoder_hidden @ W_q
K = encoder_output @ W_k
V = encoder_output @ W_v
cross_attn = softmax(Q K^T / sqrt(d_k)) V
```
**Examples:** T5, BART, mT5, original Transformer (2017)
**Use cases:**
- Translation (input ≠ output language)
- Summarization (long input → short output)
- Question answering (generate answer, not extract)
**When to use:** Input and output are fundamentally different
**Implementation:**
```python
class EncoderDecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads) # NEW!
        self.feed_forward = FeedForward(d_model, d_ff)  # Linear → ReLU → Linear, as in EncoderLayer
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, decoder_input, encoder_output, causal_mask=None):
# 1. Self-attention (causal)
self_attn_out, _ = self.self_attn(decoder_input, causal_mask)
x = self.norm1(decoder_input + self_attn_out)
        # 2. Cross-attention (Q from decoder, K/V from encoder)
        # Assumes an MHA variant with separate query/key/value inputs (sketch below)
cross_attn_out, _ = self.cross_attn.forward_cross(
query=x,
key=encoder_output,
value=encoder_output
)
x = self.norm2(x + cross_attn_out)
# 3. Feed-forward
ff_out = self.feed_forward(x)
x = self.norm3(x + ff_out)
return x
```
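The `forward_cross` call assumes an MHA variant that accepts separate query/key/value tensors; a minimal sketch, subclassing the MultiHeadAttention from Part 2:
```python
class CrossMultiHeadAttention(MultiHeadAttention):
    def forward_cross(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Q from the decoder input; K/V from the encoder output
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        out = torch.matmul(attn_weights, V).transpose(1, 2).contiguous()
        out = out.view(batch_size, -1, self.d_model)
        return self.W_o(out), attn_weights
```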
### Architecture Selection Guide
| Task | Architecture | Why |
|------|--------------|-----|
| Classification | Encoder-only | Need full bidirectional context |
| Text generation | Decoder-only | Autoregressive generation |
| Translation | Encoder-decoder or Decoder-only | Different languages, or use prompting |
| Summarization | Encoder-decoder or Decoder-only | Length mismatch, or use prompting |
| Q&A (extract) | Encoder-only | Find span in context |
| Q&A (generate) | Decoder-only | Generate freeform answer |
**2023+ trend:** Decoder-only can do everything via prompting (but less parameter-efficient for some tasks)
## Part 5: Vision Transformers (ViT)
### From Images to Sequences
**Key insight:** Treat image as sequence of patches
**Process:**
1. Split image into patches (e.g., 16×16 pixels)
2. Flatten each patch → 1D vector
3. Linear projection → token embeddings
4. Add 2D positional embeddings
5. Prepend [CLS] token (for classification)
6. Feed to Transformer encoder
**Example:** 224×224 image, 16×16 patches
- Number of patches: (224/16)² = 196
- Each patch: 16 × 16 × 3 = 768 dimensions
- Transformer input: 197 tokens (196 patches + 1 [CLS])
### ViT Implementation
```python
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3,
d_model=768, num_heads=12, num_layers=12, num_classes=1000):
super().__init__()
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
patch_dim = in_channels * patch_size ** 2
# Patch embedding (linear projection of flattened patches)
self.patch_embed = nn.Linear(patch_dim, d_model)
# [CLS] token (learnable)
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
# Position embeddings (learnable)
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
# Transformer encoder
self.encoder = TransformerEncoder(d_model, num_heads,
d_ff=4*d_model, num_layers=num_layers)
# Classification head
self.head = nn.Linear(d_model, num_classes)
def forward(self, x):
# x: (batch, channels, height, width)
batch_size = x.size(0)
        # Divide into patches and flatten
        x = x.unfold(2, self.patch_size, self.patch_size)
        x = x.unfold(3, self.patch_size, self.patch_size)
        # (batch, channels, num_patches_h, num_patches_w, patch_size, patch_size)
        # Bring the patch grid up front so each row holds one whole patch
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, -1, self.patch_size ** 2 * 3)
        # (batch, num_patches, patch_dim)
# Linear projection
x = self.patch_embed(x) # (batch, num_patches, d_model)
# Prepend [CLS] token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # (batch, num_patches+1, d_model)
# Add positional embeddings
x = x + self.pos_embed
# Transformer encoder
x = self.encoder(x)
# Classification: Use [CLS] token
cls_output = x[:, 0] # (batch, d_model)
logits = self.head(cls_output)
return logits
```
### ViT vs CNN: Critical Differences
**1. Inductive Bias**
| Property | CNN | ViT |
|----------|-----|-----|
| Locality | Strong (conv kernel) | Weak (global attention) |
| Translation invariance | Strong (weight sharing) | Weak (position embeddings) |
| Hierarchy | Strong (pooling layers) | None (flat patches) |
**Implication:** CNN has strong priors, ViT learns from data
**2. Data Requirements**
| Dataset Size | CNN | ViT (from scratch) | ViT (pretrained) |
|--------------|-----|-------------------|------------------|
| Small (< 100k) | ✅ Good | ❌ Fails | ✅ Good |
| Medium (100k-1M) | ✅ Excellent | ⚠️ Poor | ✅ Good |
| Large (> 1M) | ✅ Excellent | ⚠️ OK | ✅ Excellent |
| Huge (> 100M) | ✅ Excellent | ✅ SOTA | N/A |
**Key finding:** ViT needs 100M+ images to train from scratch
- Original ViT: Trained on JFT-300M (300 million images)
- Without massive data, ViT underperforms CNNs significantly
**3. Computational Cost**
**Example: 224×224 images**
| Model | Parameters | GFLOPs | Inference (GPU) |
|-------|-----------|--------|-----------------|
| ResNet-50 | 25M | 4.1 | ~30ms |
| EfficientNet-B0 | 5M | 0.4 | ~10ms |
| ViT-B/16 | 86M | 17.6 | ~100ms |
**Implication:** ViT-B/16 costs roughly 40x the FLOPs of EfficientNet-B0!
### When to Use ViT
**Use ViT when:**
- Large dataset (> 1M images) OR using pretrained weights
- Computational cost acceptable (cloud, large GPU)
- Best possible accuracy needed
- Can fine-tune from ImageNet-21k checkpoint
**Use CNN when:**
- Small/medium dataset (< 1M images) and training from scratch
- Limited compute/memory
- Edge deployment (mobile, embedded)
- Need architectural inductive biases
### Hybrid Approaches (2022-2023)
**ConvNeXt:** CNN with ViT design choices
- Matches ViT accuracy with CNN efficiency
- Works better on small datasets
**Swin Transformer:** Hierarchical ViT with local windows
- Shifted windows for efficiency
- O(n) complexity instead of O(n²)
- Better for dense prediction (segmentation)
**CoAtNet:** Mix conv layers (early) + Transformer layers (late)
- Gets both inductive bias and global attention
## Part 6: Implementation Checklist
### Critical Details
**1. Layer Norm Placement**
**Post-norm (original):**
```python
x = x + self_attn(x)
x = layer_norm(x)
```
**Pre-norm (modern, recommended):**
```python
x = x + self_attn(layer_norm(x))
```
**Why pre-norm?** More stable training, less sensitive to learning rate
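Putting it together, a minimal pre-norm block sketch (reusing the MultiHeadAttention class from Part 2):
```python
class PreNormBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(self.norm1(x), mask)  # norm BEFORE the sub-layer
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x
```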
**2. Attention Dropout**
Apply dropout to **attention weights**, not Q/K/V!
```python
attn_weights = F.softmax(scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=0.1, training=self.training) # HERE!
output = torch.matmul(attn_weights, V)
```
**3. Feed-Forward Dimension**
Typically: d_ff = 4 × d_model
- BERT: d_model=768, d_ff=3072
- GPT-2: d_model=768, d_ff=3072
**4. Residual Connections**
ALWAYS use residual connections (essential for training)!
```python
x = x + self_attn(x) # Residual
x = x + feed_forward(x) # Residual
```
**5. Initialization**
Use Xavier/Glorot initialization for attention weights:
```python
nn.init.xavier_uniform_(self.W_q.weight)
nn.init.xavier_uniform_(self.W_k.weight)
nn.init.xavier_uniform_(self.W_v.weight)
```
## Part 7: When NOT to Use Transformers
### Limitation 1: Small Datasets
**Problem:** Transformers have weak inductive bias (learn from data)
**Impact:**
- ViT: Fails on < 100k images without pretraining
- NLP: BERT needs 100M+ tokens for pretraining
**Solution:** Use models with stronger priors (CNN for vision, smaller models for text)
### Limitation 2: Long Sequences
**Problem:** O(n²) memory complexity
**Impact:**
- Standard Transformer: n=10k → 100M attention scores
- GPU memory: 10k² × 4 bytes = 400MB per sample!
**Solution:**
- Sparse attention (Longformer, BigBird)
- Linear attention (Linformer, Performer)
- Flash Attention (memory-efficient kernel; see the sketch after this list)
- State space models (S4, Mamba)
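As a practical note, PyTorch 2.x exposes fused kernels through `torch.nn.functional.scaled_dot_product_attention`, which avoids materializing the full n×n score matrix when a Flash/memory-efficient backend is available (shapes below are illustrative):
```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, num_heads, seq_len, d_k); half precision on GPU enables Flash kernels
q = torch.randn(1, 8, 4096, 64, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # no n x n matrix in memory
```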
### Limitation 3: Edge Deployment
**Problem:** Large model size, high latency
**Impact:**
- ViT-B: 86M parameters, ~100ms inference
- Mobile/embedded: Need < 10M parameters, < 50ms
**Solution:** Efficient CNNs (MobileNet, EfficientNet) or distilled models
### Limitation 4: Real-Time Processing
**Problem:** Sequential generation in decoder (cannot parallelize at inference)
**Impact:** GPT-style models generate one token at a time
**Solution:** Non-autoregressive models, speculative decoding, or smaller models
## Part 8: Common Mistakes
### Mistake 1: Forgetting Causal Mask
**Symptom:** Decoder "cheats" by seeing future tokens
**Fix:** Always apply causal mask to decoder self-attention!
```python
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
```
### Mistake 2: Wrong Dimension for Multi-Head
**Symptom:** Runtime error or dimension mismatch
**Fix:** Ensure d_model % num_heads == 0
```python
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
```
### Mistake 3: Forgetting Position Encoding
**Symptom:** Model ignores word order
**Fix:** Always add position information!
```python
x = token_embeddings + positional_encoding
```
### Mistake 4: Wrong Softmax Dimension
**Symptom:** Attention weights don't sum to 1 per query
**Fix:** Softmax over last dimension (keys)
```python
attn_weights = F.softmax(scores, dim=-1) # Sum over keys for each query
```
### Mistake 5: No Residual Connections
**Symptom:** Training diverges or converges very slowly
**Fix:** Always add residual connections!
```python
x = x + self_attn(x)
x = x + feed_forward(x)
```
## Summary: Quick Reference
### Architecture Selection
```
Classification/Understanding → Encoder-only (BERT-style)
Generation/Autoregressive → Decoder-only (GPT-style)
Seq2Seq (input ≠ output) → Encoder-decoder (T5-style) or Decoder-only with prompting
```
### Position Encoding Selection
```
NLP (variable length) → RoPE or ALiBi
NLP (fixed length) → Learned embeddings
Vision (ViT) → 2D learned embeddings
Long sequences (> 2k) → ALiBi (best extrapolation)
```
### Multi-Head Configuration
```
Small models (d_model < 512): 4-8 heads
Medium models (d_model 512-1024): 8-12 heads
Large models (d_model > 1024): 12-32 heads
Rule: d_k (head dimension) should be 64-128
```
### ViT vs CNN
```
ViT: Large dataset (> 1M) OR pretrained weights
CNN: Small dataset (< 1M) OR edge deployment
```
### Implementation Essentials
```
✅ Pre-norm (more stable than post-norm)
✅ Residual connections (essential!)
✅ Causal mask for decoder
✅ Attention dropout (on weights, not Q/K/V)
✅ d_ff = 4 × d_model (feed-forward dimension)
✅ Check: d_model % num_heads == 0
```
## Next Steps
After mastering this skill:
- `attention-mechanisms-catalog`: Explore attention variants (sparse, linear, Flash)
- `llm-specialist/llm-finetuning-strategies`: Apply to language models
- `architecture-design-principles`: Understand design trade-offs
**Remember:** Transformers are NOT magic. Understanding the mechanism (information retrieval via Q/K/V) beats cargo-culting implementations.