Initial commit
12
.claude-plugin/plugin.json
Normal file
@@ -0,0 +1,12 @@
{
  "name": "yzmir-neural-architectures",
  "description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills",
  "version": "1.0.1",
  "author": {
    "name": "tachyon-beep",
    "url": "https://github.com/tachyon-beep"
  },
  "skills": [
    "./skills"
  ]
}
3
README.md
Normal file
@@ -0,0 +1,3 @@
# yzmir-neural-architectures

Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills
77
plugin.lock.json
Normal file
@@ -0,0 +1,77 @@
{
  "$schema": "internal://schemas/plugin.lock.v1.json",
  "pluginId": "gh:tachyon-beep/skillpacks:plugins/yzmir-neural-architectures",
  "normalized": {
    "repo": null,
    "ref": "refs/tags/v20251128.0",
    "commit": "52681919f25de737a14d958d80b685171797dfdb",
    "treeHash": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc",
    "generatedAt": "2025-11-28T10:28:34.235937Z",
    "toolVersion": "publish_plugins.py@0.2.0"
  },
  "origin": {
    "remote": "git@github.com:zhongweili/42plugin-data.git",
    "branch": "master",
    "commit": "aa1497ed0949fd50e99e70d6324a29c5b34f9390",
    "repoRoot": "/Users/zhongweili/projects/openmind/42plugin-data"
  },
  "manifest": {
    "name": "yzmir-neural-architectures",
    "description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills",
    "version": "1.0.1"
  },
  "content": {
    "files": [
      {
        "path": "README.md",
        "sha256": "cf12418cc5cc226683c463dca6c54e98f97f2ddf57482d593839f430acc36b86"
      },
      {
        "path": ".claude-plugin/plugin.json",
        "sha256": "9893529f120b1582ad62d899531401988bdf89d6b221e057239e09399d47e255"
      },
      {
        "path": "skills/using-neural-architectures/architecture-design-principles.md",
        "sha256": "bd1cde3e2cf997b6e42dc52bc28638e644ee387c81b07464edc7fb4cf33d4329"
      },
      {
        "path": "skills/using-neural-architectures/sequence-models-comparison.md",
        "sha256": "5c041c08261cf04e47edd3ec092db04a8b4c725fb6bb91620df54005881cbb66"
      },
      {
        "path": "skills/using-neural-architectures/normalization-techniques.md",
        "sha256": "c87c82e926951a963d7909321c60a62b5cf6d09dbcbb576472eda46d72993766"
      },
      {
        "path": "skills/using-neural-architectures/graph-neural-networks-basics.md",
        "sha256": "85a968972f0c1c226d6bcb155b231560af0eab52326f419d308b1a6a22ebd7c5"
      },
      {
        "path": "skills/using-neural-architectures/generative-model-families.md",
        "sha256": "7c527707a0d3ecab99e83079880b1166c4310de62a743823ae747a03545180b3"
      },
      {
        "path": "skills/using-neural-architectures/cnn-families-and-selection.md",
        "sha256": "acce2a55a84012259119e70d3eb035aee08beb124d0df1aab009ae7fdd01ff51"
      },
      {
        "path": "skills/using-neural-architectures/SKILL.md",
        "sha256": "b8dc5fac750ca1740bd97591409701ca6162e0d5dc5faaf7a2cfcf7feef432b3"
      },
      {
        "path": "skills/using-neural-architectures/attention-mechanisms-catalog.md",
        "sha256": "3eab5267c7e5a1263d08be2dca4ef236ad66b554e08aec0e0cd85d5c1f8b5500"
      },
      {
        "path": "skills/using-neural-architectures/transformer-architecture-deepdive.md",
        "sha256": "1913626a82ff02f638d10aa00644fcc59fc9836329dc152a7becd9aa030e0937"
      }
    ],
    "dirSha256": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc"
  },
  "security": {
    "scannedAt": null,
    "scannerVersion": null,
    "flags": []
  }
}
496
skills/using-neural-architectures/SKILL.md
Normal file
@@ -0,0 +1,496 @@
---
name: using-neural-architectures
description: "Architecture selection router: CNNs, Transformers, RNNs, GANs, GNNs by data modality and constraints"
mode: true
pack: neural-architectures
faction: yzmir
---

# Using Neural Architectures: Architecture Selection Router

<CRITICAL_CONTEXT>
Architecture selection comes BEFORE training optimization. Wrong architecture = no amount of training will fix it.

This meta-skill routes you to the right architecture guidance based on:
- Data modality (images, sequences, graphs, etc.)
- Problem type (classification, generation, regression)
- Constraints (data size, compute, latency, interpretability)

Load this skill when architecture decisions are needed.
</CRITICAL_CONTEXT>

## When to Use This Skill

Use this skill when:
- ✅ Selecting an architecture for a new problem
- ✅ Comparing architecture families (CNN vs Transformer, RNN vs Transformer, etc.)
- ✅ Designing custom network topology
- ✅ Troubleshooting architectural instability (deep networks, gradient issues)
- ✅ Understanding when to use specialized architectures (GNNs, generative models)

DO NOT use for:
- ❌ Training/optimization issues (use training-optimization pack)
- ❌ PyTorch implementation details (use pytorch-engineering pack)
- ❌ Production deployment (use ml-production pack)

**When in doubt:** If choosing WHAT architecture → this skill. If training/deploying architecture → different pack.

---

## Core Routing Logic

### Step 1: Identify Data Modality

**Question to ask:** "What type of data are you working with?"

| Data Type | Route To | Why |
|-----------|----------|-----|
| Images (photos, medical scans, etc.) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | CNNs excel at spatial hierarchies |
| Sequences (time series, text, audio) | [sequence-models-comparison.md](sequence-models-comparison.md) | Temporal dependencies need sequential models |
| Graphs (social networks, molecules) | [graph-neural-networks-basics.md](graph-neural-networks-basics.md) | Graph structure requires GNNs |
| Generation task (create images, text) | [generative-model-families.md](generative-model-families.md) | Generative models are specialized |
| Multiple modalities (text + images) | [architecture-design-principles.md](architecture-design-principles.md) | Need custom design |
| Unclear / Generic | [architecture-design-principles.md](architecture-design-principles.md) | Start with fundamentals |

### Step 2: Check for Special Requirements

**If any of these apply, address FIRST:**

| Requirement | Route To | Priority |
|-------------|----------|----------|
| Deep network (> 20 layers) unstable | [normalization-techniques.md](normalization-techniques.md) | CRITICAL - fix before continuing |
| Need attention mechanisms | [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) | Specialized component |
| Custom architecture design | [architecture-design-principles.md](architecture-design-principles.md) | Foundation before specifics |
| Transformer-specific question | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | Specialized architecture |

### Step 3: Consider Problem Characteristics

**Clarify BEFORE routing:**

Ask:
- "How large is your dataset?" (Small < 10k, Medium 10k-1M, Large > 1M)
- "What are your computational constraints?" (Edge device, cloud, GPU availability)
- "What are your latency requirements?" (Real-time, batch, offline)
- "Do you need interpretability?" (Clinical, research, production)

These answers determine architecture appropriateness.

---
## Routing by Data Modality

### Images → CNN Families

**Symptoms triggering this route:**
- "classify images"
- "object detection"
- "semantic segmentation"
- "medical imaging"
- "computer vision"

**Route to:** See [cnn-families-and-selection.md](cnn-families-and-selection.md) for CNN architecture selection and comparison.

**When to route here:**
- ANY vision task (CNNs are default for spatial data)
- Even if considering Transformers, check CNN families first (often better with less data)

**Clarifying questions:**
- "Dataset size?" (< 10k → Start with proven CNNs, > 100k → Consider ViT)
- "Deployment target?" (Edge → EfficientNet, Cloud → Anything)
- "Task type?" (Classification → ResNet/EfficientNet, Detection → YOLO/Faster-RCNN)

---

### Sequences → Sequence Models Comparison

**Symptoms triggering this route:**
- "time series"
- "forecasting"
- "natural language" (NLP)
- "sequential data"
- "temporal patterns"
- "RNN vs LSTM vs Transformer"

**Route to:** See [sequence-models-comparison.md](sequence-models-comparison.md) for sequential model selection (RNN, LSTM, Transformer, TCN).

**When to route here:**
- ANY sequential data
- When user asks "RNN vs LSTM" (skill will present modern alternatives)
- Time-dependent patterns

**Clarifying questions:**
- "Sequence length?" (< 100 → RNN/LSTM/TCN, 100-1000 → Transformer, > 1000 → Sparse Transformers)
- "Latency requirements?" (Real-time → TCN/LSTM, Offline → Transformer)
- "Data volume?" (Small → Simpler models, Large → Transformers)

**CRITICAL:** Challenge "RNN vs LSTM" premise if they ask. Modern alternatives (Transformers, TCN) often better.

---

### Graphs → Graph Neural Networks

**Symptoms triggering this route:**
- "social network"
- "molecular structure"
- "knowledge graph"
- "graph data"
- "node classification"
- "link prediction"
- "graph embeddings"

**Route to:** See [graph-neural-networks-basics.md](graph-neural-networks-basics.md) for GNN architectures and graph learning.

**When to route here:**
- Data has explicit graph structure (nodes + edges)
- Relational information is important
- Network topology matters

**Red flag:** If treating graph as tabular data (extracting features and ignoring edges) → WRONG. Route to GNN skill.

---
### Generation → Generative Model Families

**Symptoms triggering this route:**
- "generate images"
- "synthesize data"
- "GAN vs VAE vs Diffusion"
- "image-to-image translation"
- "style transfer"
- "generative modeling"

**Route to:** See [generative-model-families.md](generative-model-families.md) for GANs, VAEs, and Diffusion models.

**When to route here:**
- Goal is to CREATE data, not classify/predict
- Need to sample from distribution
- Data augmentation through generation

**Clarifying questions:**
- "Use case?" (Real-time game → GAN, Art/research → Diffusion, Fast training → VAE)
- "Quality vs speed?" (Quality → Diffusion, Speed → GAN)
- "Controllability?" (Fine control → StyleGAN/Conditional models)

**CRITICAL:** Different generative models have VERY different trade-offs. Must clarify requirements.

---

## Routing by Architecture Component

### Attention Mechanisms

**Symptoms triggering this route:**
- "when to use attention"
- "self-attention vs cross-attention"
- "attention in CNNs"
- "attention bottleneck"
- "multi-head attention"

**Route to:** See [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) for attention mechanism selection and design.

**When to route here:**
- Designing custom architecture that might benefit from attention
- Understanding where attention helps vs hinders
- Comparing attention variants

**NOT for:** General Transformer questions → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) instead

---

### Transformer Deep Dive

**Symptoms triggering this route:**
- "how do transformers work"
- "Vision Transformer (ViT)"
- "BERT architecture"
- "positional encoding"
- "transformer blocks"
- "scaling transformers"

**Route to:** See [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) for Transformer internals and implementation.

**When to route here:**
- Implementing/customizing transformers
- Understanding transformer internals
- Debugging transformer-specific issues

**Cross-reference:**
- For sequence models generally → [sequence-models-comparison.md](sequence-models-comparison.md) (includes transformers in context)
- For LLMs specifically → `yzmir/llm-specialist/transformer-for-llms` (LLM-specific transformers)

---

### Normalization Techniques

**Symptoms triggering this route:**
- "gradient explosion"
- "training instability in deep network"
- "BatchNorm vs LayerNorm"
- "normalization layers"
- "50+ layer network won't train"

**Route to:** See [normalization-techniques.md](normalization-techniques.md) for deep network stability and normalization methods.

**When to route here:**
- Deep networks (> 20 layers) with training instability
- Choosing between normalization methods
- Architectural stability issues

**CRITICAL:** This is often the ROOT CAUSE of "training won't work" - fix architecture before blaming hyperparameters.

---

### Architecture Design Principles

**Symptoms triggering this route:**
- "how to design architecture"
- "architecture best practices"
- "when to use skip connections"
- "how deep should network be"
- "custom architecture for [novel task]"
- Unclear problem modality

**Route to:** See [architecture-design-principles.md](architecture-design-principles.md) for custom architecture design fundamentals.

**When to route here:**
- Designing custom architectures
- Novel problems without established architecture
- Understanding WHY architectures work
- User is unsure what modality/problem type they have

**This is the foundational skill** - route here if other specific skills don't match.

---
## Multi-Modal / Cross-Pack Routing

### When Problem Spans Multiple Modalities

**Example:** "Text + image classification" (multimodal)

**Route to BOTH:**
1. [sequence-models-comparison.md](sequence-models-comparison.md) (for text)
2. [cnn-families-and-selection.md](cnn-families-and-selection.md) (for images)
3. [architecture-design-principles.md](architecture-design-principles.md) (for fusion strategy)

**Order matters:** Understand individual modalities BEFORE fusion.

### When Architecture + Other Concerns

**Example:** "Select architecture AND optimize training"

**Route order:**
1. Architecture skill FIRST (this pack)
2. Training-optimization SECOND (after architecture chosen)

**Why:** Wrong architecture can't be fixed by better training.

**Example:** "Select architecture AND deploy efficiently"

**Route order:**
1. Architecture skill FIRST
2. ML-production SECOND (quantization, serving)

**Deployment constraints might influence architecture choice** - if so, note constraints during architecture selection.

---

## Common Routing Mistakes (DON'T DO THESE)

| Symptom | Wrong Route | Correct Route | Why |
|---------|-------------|---------------|-----|
| "My transformer won't train" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | training-optimization | Training issue, not architecture understanding |
| "Deploy image classifier" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | ml-production | Deployment, not selection |
| "ViT vs ResNet for medical imaging" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | Comparative selection, not single architecture detail |
| "Implement BatchNorm in PyTorch" | [normalization-techniques.md](normalization-techniques.md) | pytorch-engineering | Implementation, not architecture concept |
| "GAN won't converge" | [generative-model-families.md](generative-model-families.md) | training-optimization | Training stability, not architecture selection |
| "Which optimizer for CNN" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | training-optimization | Optimization, not architecture |

**Rule:** Architecture pack is for CHOOSING and DESIGNING architectures. Training/deployment/implementation are other packs.

---

## Red Flags: Stop and Clarify

If query contains these patterns, ASK clarifying questions before routing:

| Pattern | Why Clarify | What to Ask |
|---------|-------------|--------------|
| "Best architecture for X" | "Best" depends on constraints | "What are your data size, compute, and latency constraints?" |
| Generic problem description | Can't route without modality | "What type of data? (images, sequences, graphs, etc.)" |
| Latest trend mentioned (ViT, Diffusion) | Recency bias risk | "Have you considered alternatives? What are your specific requirements?" |
| "Should I use X or Y" | May be wrong question | "What's the underlying problem? There might be option Z." |
| Very deep network (> 50 layers) | Likely needs normalization first | "Are you using normalization layers? Skip connections?" |

**Never guess modality or constraints. Always clarify.**

---

## Recency Bias: Resistance Table

| Trendy Architecture | When NOT to Use | Better Alternative |
|---------------------|------------------|-------------------|
| **Vision Transformers (ViT)** | Small datasets (< 10k images) | CNNs (ResNet, EfficientNet) |
| **Vision Transformers (ViT)** | Edge deployment (latency/power) | EfficientNets, MobileNets |
| **Transformers (general)** | Very small datasets | RNNs, CNNs (less capacity, less overfit) |
| **Diffusion Models** | Real-time generation needed | GANs (1 forward pass vs 50-1000 steps) |
| **Diffusion Models** | Limited compute for training | VAEs (faster training) |
| **Graph Transformers** | Small graphs (< 100 nodes) | Standard GNNs (GCN, GAT), simpler and effective |
| **LLMs (GPT-style)** | < 1M tokens of training data | Simpler language models or fine-tuning |

**Counter-narrative:** "New architecture ≠ better for your use case. Match architecture to constraints."

---
## Decision Tree

```
Start here: What's your primary goal?

┌─ SELECT architecture for task
│  ├─ Data modality?
│  │  ├─ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md)
│  │  ├─ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md)
│  │  ├─ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md)
│  │  ├─ Generation → [generative-model-families.md](generative-model-families.md)
│  │  └─ Unknown/Multiple → [architecture-design-principles.md](architecture-design-principles.md)
│  └─ Special requirements?
│     ├─ Deep network (>20 layers) unstable → [normalization-techniques.md](normalization-techniques.md) (CRITICAL)
│     ├─ Need attention mechanism → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md)
│     └─ None → Proceed with modality-based route
│
├─ UNDERSTAND specific architecture
│  ├─ Transformers → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md)
│  ├─ Attention → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md)
│  ├─ Normalization → [normalization-techniques.md](normalization-techniques.md)
│  └─ General principles → [architecture-design-principles.md](architecture-design-principles.md)
│
├─ DESIGN custom architecture
│  └─ [architecture-design-principles.md](architecture-design-principles.md) (start here always)
│
└─ COMPARE architectures
   ├─ CNNs (ResNet vs EfficientNet) → [cnn-families-and-selection.md](cnn-families-and-selection.md)
   ├─ Sequence models (RNN vs Transformer) → [sequence-models-comparison.md](sequence-models-comparison.md)
   ├─ Generative (GAN vs Diffusion) → [generative-model-families.md](generative-model-families.md)
   └─ General comparison → [architecture-design-principles.md](architecture-design-principles.md)
```

---

## Workflow

**Standard Architecture Selection Workflow:**

```
1. Clarify Problem
   ☐ What data modality? (images, sequences, graphs, etc.)
   ☐ What's the task? (classification, generation, regression, etc.)
   ☐ Dataset size?
   ☐ Computational constraints?
   ☐ Latency requirements?
   ☐ Interpretability needs?

2. Route Based on Modality
   ☐ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md)
   ☐ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md)
   ☐ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md)
   ☐ Generation → [generative-model-families.md](generative-model-families.md)
   ☐ Custom/Unclear → [architecture-design-principles.md](architecture-design-principles.md)

3. Check for Critical Issues
   ☐ Deep network unstable? → [normalization-techniques.md](normalization-techniques.md) FIRST
   ☐ Need specialized component? → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) or [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md)

4. Apply Architecture Skill
   ☐ Follow guidance from routed skill
   ☐ Consider trade-offs (accuracy vs speed vs data requirements)

5. Cross-Pack if Needed
   ☐ Architecture chosen → training-optimization (for training)
   ☐ Architecture chosen → ml-production (for deployment)
```

---

## Rationalization Table

| Rationalization | Reality | Counter |
|-----------------|---------|---------|
| "Transformers are SOTA, recommend them" | SOTA on benchmark ≠ best for user's constraints | "Ask about dataset size and compute first" |
| "User said RNN vs LSTM, answer that" | Question premise might be outdated | "Challenge: Have you considered Transformers or TCN?" |
| "Just recommend latest architecture" | Latest ≠ appropriate | "Match architecture to requirements, not trends" |
| "Architecture doesn't matter, training matters" | Wrong architecture can't be fixed by training | "Architecture is foundation - get it right first" |
| "They seem rushed, skip clarification" | Wrong route wastes more time than clarification | "30 seconds to clarify saves hours of wasted effort" |
| "Generic architecture advice is safe" | Generic = useless for specific domains | "Route to domain-specific skill for actionable guidance" |

---

## Integration with Other Packs

### After Architecture Selection

Once architecture is chosen, route to:

**Training the architecture:**
→ `yzmir/training-optimization/using-training-optimization`
- Optimizer selection
- Learning rate schedules
- Debugging training issues

**Implementing in PyTorch:**
→ `yzmir/pytorch-engineering/using-pytorch-engineering`
- Module design patterns
- Performance optimization
- Custom components

**Deploying to production:**
→ `yzmir/ml-production/using-ml-production`
- Model serving
- Quantization
- Inference optimization

### Before Architecture Selection

If problem involves:

**Reinforcement learning:**
→ `yzmir/deep-rl/using-deep-rl` FIRST
- RL algorithms dictate architecture requirements
- Value networks vs policy networks have different needs

**Large language models:**
→ `yzmir/llm-specialist/using-llm-specialist` FIRST
- LLM architectures are specialized transformers
- Different considerations than general sequence models

**Architecture is downstream of algorithm choice in RL and LLMs.**

---

## Summary

**Use this meta-skill to:**
- ✅ Route architecture queries to appropriate specialized skill
- ✅ Identify data modality and problem type
- ✅ Clarify constraints before recommending
- ✅ Resist recency bias (latest ≠ best)
- ✅ Recognize when architecture is the problem (vs training/implementation)

## Neural Architecture Specialist Skills

After routing, load the appropriate specialist skill for detailed guidance:

1. [architecture-design-principles.md](architecture-design-principles.md) - Custom design, architectural best practices, skip connections, network depth fundamentals
2. [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) - Self-attention, cross-attention, multi-head attention, attention in CNNs, attention variants comparison
3. [cnn-families-and-selection.md](cnn-families-and-selection.md) - ResNet, EfficientNet, MobileNet, YOLO, computer vision architecture selection
4. [generative-model-families.md](generative-model-families.md) - GANs, VAEs, Diffusion models, image generation, style transfer, generative modeling trade-offs
5. [graph-neural-networks-basics.md](graph-neural-networks-basics.md) - GCN, GAT, node classification, link prediction, graph embeddings, molecular structures
6. [normalization-techniques.md](normalization-techniques.md) - BatchNorm, LayerNorm, GroupNorm, training stability for deep networks (>20 layers)
7. [sequence-models-comparison.md](sequence-models-comparison.md) - RNN, LSTM, Transformer, TCN comparison, time series, NLP, sequential data
8. [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) - Transformer internals, ViT, BERT, positional encoding, scaling transformers

**Critical principle:** Architecture comes BEFORE training. Get this right first.

---

**END OF SKILL**
960
skills/using-neural-architectures/architecture-design-principles.md
Normal file
@@ -0,0 +1,960 @@
# Architecture Design Principles

## Context

You're designing a neural network architecture or debugging why your network isn't learning. Common mistakes:
- **Ignoring inductive biases**: Using MLP for images (should use CNN)
- **Over-engineering**: Using Transformer for 100 samples (should use linear regression)
- **No skip connections**: 50-layer plain network fails (should use ResNet)
- **Wrong depth-width balance**: 100 layers × 8 channels bottlenecks capacity
- **Ignoring constraints**: 1.5B parameter model doesn't fit 24GB GPU

**This skill provides principled architecture design: match structure to problem, respect constraints, avoid over-engineering.**


## Core Principle: Inductive Biases

**Inductive bias = assumptions baked into architecture about problem structure**

**Key insight**: The right inductive bias makes learning dramatically easier. Wrong bias makes learning impossible.

### What are Inductive Biases?

```python
# Example: Image classification

# MLP (no inductive bias):
# - Treats each pixel independently
# - No concept of "spatial locality" or "translation"
# - Must learn from scratch that nearby pixels are related
# - Learns "cat at position (10,10)" and "cat at (50,50)" separately
# Parameters: 150M, Accuracy: 75%

# CNN (strong inductive bias):
# - Assumes spatial locality (nearby pixels related)
# - Assumes translation invariance (cat is cat anywhere)
# - Shares filters across spatial positions
# - Hierarchical feature learning (edges → textures → objects)
# Parameters: 11M, Accuracy: 95%

# CNN's inductive bias: 14× fewer parameters, 20% better accuracy!
```

**Principle**: Match your architecture's inductive biases to your problem's structure.


## Architecture Families and Their Inductive Biases

### 1. Fully Connected (MLP)

**Inductive bias:** None (general-purpose)

**Structure:**
```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```

**When to use:**
- ✅ Tabular data (independent features)
- ✅ Small datasets (< 10,000 samples)
- ✅ Baseline / proof of concept

**When NOT to use:**
- ❌ Images (use CNN)
- ❌ Sequences (use RNN/Transformer)
- ❌ Graphs (use GNN)

**Strengths:**
- Simple and interpretable
- Fast training
- Works for any input type (flattened)

**Weaknesses:**
- No structural assumptions (must learn everything from data)
- Parameter explosion (input_size × hidden_size can be huge)
- Doesn't leverage problem structure
### 2. Convolutional Neural Networks (CNN)

**Inductive bias:** Spatial locality + Translation invariance

**Structure:**
```python
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 56 * 56, 1000)  # assumes 224×224 input

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 112×112
        x = self.pool(F.relu(self.conv2(x)))  # 56×56
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
```

**Inductive biases:**
1. **Local connectivity**: Neurons see only nearby pixels (spatial locality)
2. **Translation invariance**: Same filter slides across image (parameter sharing)
3. **Hierarchical features**: Stack layers to build complex features from simple ones

**When to use:**
- ✅ Images (classification, detection, segmentation)
- ✅ Spatial data (maps, medical scans)
- ✅ Any grid-structured data

**When NOT to use:**
- ❌ Sequences with long-range dependencies (use Transformer)
- ❌ Graphs (irregular structure, use GNN)
- ❌ Tabular data (no spatial structure)

**Strengths:**
- Parameter efficient (filter sharing)
- Translation invariant (cat anywhere = cat)
- Hierarchical feature learning

**Weaknesses:**
- Fixed receptive field (limited by kernel size)
- Not suitable for variable-length inputs
- Requires grid structure
### 3. Recurrent Neural Networks (RNN/LSTM)

**Inductive bias:** Temporal dependencies

**Structure:**
```python
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # Use last hidden state
        output = self.fc(h_n[-1])
        return output
```

**Inductive bias:** Sequential processing (earlier elements influence later elements)

**When to use:**
- ✅ Time series (stock prices, sensor data)
- ✅ Short sequences (< 100 timesteps)
- ✅ Online processing (process one timestep at a time)

**When NOT to use:**
- ❌ Long sequences (> 1000 timesteps, use Transformer)
- ❌ Non-sequential data (images, tabular)
- ❌ When parallel processing needed (use Transformer)

**Strengths:**
- Natural for sequential data
- Constant memory (doesn't grow with sequence length)
- Online processing capability

**Weaknesses:**
- Slow (sequential, can't parallelize)
- Vanishing gradients (long sequences)
- Struggles with long-range dependencies
### 4. Transformers

**Inductive bias:** Minimal (self-attention is general-purpose)

**Structure:**
```python
import torch.nn as nn

class SimpleTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, num_classes):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, num_heads, batch_first=True),
            num_layers
        )
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.encoder(x)
        # Global average pooling
        x = x.mean(dim=1)
        return self.fc(x)
```

**Inductive bias:** Minimal (learns relationships from data via attention)

**When to use:**
- ✅ Long sequences (> 100 tokens)
- ✅ Language (text, code)
- ✅ Large datasets (> 100k samples)
- ✅ When relationships are complex and data-dependent

**When NOT to use:**
- ❌ Small datasets (< 10k samples, use RNN or MLP)
- ❌ Strong structural priors available (images → CNN)
- ❌ Very long sequences (> 16k tokens, use sparse attention)
- ❌ Low-latency requirements (RNN faster)

**Strengths:**
- Parallel processing (fast training)
- Long-range dependencies (attention)
- State-of-the-art for language

**Weaknesses:**
- Quadratic complexity O(n²) with sequence length
- Requires large datasets (weak inductive bias)
- High memory usage
### 5. Graph Neural Networks (GNN)

**Inductive bias:** Message passing over graph structure

**Structure:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # requires torch-geometric

class SimpleGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        # x: node features (num_nodes, input_dim)
        # edge_index: graph structure (2, num_edges)
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
```

**Inductive bias:** Nodes influenced by neighbors (message passing)

**When to use:**
- ✅ Graph data (social networks, molecules, knowledge graphs)
- ✅ Irregular connectivity (different # of neighbors per node)
- ✅ Relational reasoning

**When NOT to use:**
- ❌ Grid data (images → CNN)
- ❌ Sequences (text → Transformer)
- ❌ If graph structure doesn't help (test MLP baseline first!)

**Strengths:**
- Handles irregular structure
- Permutation invariant
- Natural for relational data

**Weaknesses:**
- Requires meaningful graph structure
- Over-smoothing (too many layers)
- Scalability challenges (large graphs)
## Decision Tree: Architecture Selection

```
START
|
├─ Is data grid-structured (images)?
│   ├─ YES → Use CNN
│   │   └─ ResNet (general), EfficientNet (mobile), ViT (very large datasets)
│   └─ NO → Continue
│
├─ Is data sequential (text, time series)?
│   ├─ YES → Check sequence length
│   │   ├─ < 100 timesteps → LSTM/GRU
│   │   ├─ 100-4000 tokens → Transformer
│   │   └─ > 4000 tokens → Sparse Transformer (Longformer)
│   └─ NO → Continue
│
├─ Is data graph-structured (molecules, social networks)?
│   ├─ YES → Check if structure helps
│   │   ├─ Test MLP baseline first
│   │   └─ If structure helps → GNN (GCN, GraphSAGE, GAT)
│   └─ NO → Continue
│
└─ Is data tabular (independent features)?
    └─ YES → Start simple
        ├─ < 1000 samples → Linear / Ridge regression
        ├─ 1000-100k samples → Small MLP (2-3 layers)
        └─ > 100k samples → Larger MLP or Gradient Boosting (XGBoost)
```


## Principle: Start Simple, Add Complexity Only When Needed

**Occam's Razor**: Simplest model that solves the problem is best.

### Progression:

```python
import torch.nn as nn

# Step 1: Linear baseline (ALWAYS start here!)
model = nn.Linear(input_size, num_classes)
# Train and evaluate

# Step 2: IF linear insufficient, add small MLP
if linear_accuracy < target:
    model = nn.Sequential(
        nn.Linear(input_size, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes)
    )

# Step 3: IF small MLP insufficient, add depth/width
if mlp_accuracy < target:
    model = nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )

# Step 4: IF simple models fail, use specialized architecture
if simple_models_fail:
    # Images → CNN
    # Sequences → RNN/Transformer
    # Graphs → GNN
    pass

# NEVER skip to Step 4 without testing Steps 1-3!
```
### Why Start Simple?

1. **Faster iteration**: Linear model trains in seconds, Transformer in hours
2. **Baseline**: Know if complexity helps (compare complex vs simple)
3. **Occam's Razor**: Simple model generalizes better (less overfitting)
4. **Debugging**: Easy to verify simple model works correctly

### Example: House Price Prediction

```python
# Dataset: 1000 samples, 20 features

# WRONG: Start with Transformer
model = HugeTransformer(20, 512, 6, 1)  # 10M parameters (illustrative)
# Result: Overfits (10M params / 1000 samples = 10,000:1 ratio!)

# RIGHT: Start simple
# Step 1: Linear
model = nn.Linear(20, 1)  # 21 parameters
# Trains in 1 second, achieves R² = 0.85 (good!)

# Conclusion: Linear sufficient, stop here. No need for Transformer!
```

**Rule**: Add complexity only when simple models demonstrably fail.
## Principle: Deep Networks Need Skip Connections

**Problem**: Plain networks > 10 layers suffer from vanishing gradients and degradation.

### Vanishing Gradients:

```python
# Gradient flow in plain 50-layer network:
gradient_layer_1 = gradient_output × (∂L50/∂L49) × (∂L49/∂L48) × ... × (∂L2/∂L1)

# Each term < 1 (due to activations):
# If each ≈ 0.9, then: 0.9^50 ≈ 0.005 (vanishes!)

# Result: Early layers don't learn (gradients too small)
```

### Degradation:

```python
# Empirical observation (ResNet paper):
20-layer plain network: 85% accuracy
56-layer plain network: 78% accuracy  # WORSE with more layers!

# This is NOT overfitting (training accuracy also drops)
# This is optimization difficulty
```
### Solution: Skip Connections (Residual Networks)

```python
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x  # Save input

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity  # Skip connection!
        out = F.relu(out)
        return out
```

**Why skip connections work:**
```python
# Gradient flow with skip connections:
∂loss/∂x = ∂loss/∂out × (1 + ∂F/∂x)
#                        ↑
#                        Always flows! ("+1" term)

# Even if ∂F/∂x ≈ 0, gradient flows through identity path
```

**Results:**
```python
# Without skip connections:
20-layer plain: 85% accuracy
50-layer plain: 78% accuracy (worse!)

# With skip connections (ResNet):
20-layer ResNet: 87% accuracy
50-layer ResNet: 92% accuracy (better!)
152-layer ResNet: 95% accuracy (even better!)
```

**Rule**: For networks > 10 layers, ALWAYS use skip connections.
### Skip Connection Variants:

**1. Residual (ResNet):**
```python
out = x + F(x)  # Add input to output
```

**2. Dense (DenseNet):**
```python
out = torch.cat([x, F(x)], dim=1)  # Concatenate input and output
```

**3. Highway:**
```python
gate = sigmoid(W_gate @ x)
out = gate * F(x) + (1 - gate) * x  # Learned gating
```

**Most common**: Residual (simple, effective)
## Principle: Balance Depth and Width

**Depth = # of layers**
**Width = # of channels/neurons per layer**

### Capacity Formula:

```python
# Approximate capacity (for CNNs):
capacity ≈ depth × width²

# Why width²?
# Each layer: input_channels × output_channels × kernel_size²
# Doubling width → 4× parameters per layer
```

### Trade-offs:

**Too deep, too narrow:**
```python
# 100 layers × 8 channels
# Problems:
# - Information bottleneck (8 channels can't represent complex features)
# - Harder to optimize (more layers)
# - Slow inference (100 layers sequential)

# Example:
model = VeryDeepNarrow(num_layers=100, channels=8)
# Result: 60% accuracy (bottleneck!)
```

**Too shallow, too wide:**
```python
# 2 layers × 1024 channels
# Problems:
# - Under-utilizes depth (no hierarchical features)
# - Memory explosion (1024 × 1024 = 1M parameters per layer!)

# Example:
model = VeryWideShallow(num_layers=2, channels=1024)
# Result: 70% accuracy (doesn't leverage depth)
```

**Balanced:**
```python
# 18 layers, gradually increasing width: 64 → 128 → 256 → 512
# Benefits:
# - Hierarchical features (depth)
# - Sufficient capacity (width)
# - Good optimization (not too deep)

# Example (ResNet-18):
model = ResNet18()
# Layers: 18, Channels: 64-512 (average ~200)
# Result: 95% accuracy (optimal balance!)
```

### Standard Patterns:

```python
# CNNs: Gradually increase channels as spatial dims decrease
# Input:   224×224×3
# Layer 1: 224×224×64  (same spatial size, more channels)
# Layer 2: 112×112×128 (half spatial, double channels)
# Layer 3: 56×56×256   (half spatial, double channels)
# Layer 4: 28×28×512   (half spatial, double channels)

# Why? Compensate for spatial information loss with channel information
```

**Rule**: Balance depth and width. Standard pattern: 12-50 layers, 64-512 channels.
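The halve-spatial/double-channels pattern above is easy to express directly. A minimal sketch (the `stage` helper and the specific widths are illustrative, not prescriptive):

```python
import torch.nn as nn

def stage(in_ch, out_ch):
    # Halve spatial size (stride 2) while doubling channels
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )

# 112×112×64 → 56×56×128 → 28×28×256 → 14×14×512
backbone = nn.Sequential(stage(64, 128), stage(128, 256), stage(256, 512))
```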
## Principle: Match Capacity to Data Size

**Capacity = # of learnable parameters**

### Parameter Budget:

```python
# Rule of thumb: parameters should be 0.01-0.1× dataset size

# Example 1: MNIST (60,000 images)
# Budget: 600 - 6,000 parameters
# Simple CNN: 60,000 parameters (10× the budget) → Works, but might overfit
# LeNet: 60,000 parameters → Classic, works well

# Example 2: ImageNet (1.2M images)
# Budget: 12,000 - 120,000 parameters
# ResNet-50: 25M parameters (200×) → Works (aggressive augmentation helps)

# Example 3: Tabular (100 samples, 20 features)
# Budget: 1 - 10 parameters
# Linear: 21 parameters → Perfect fit!
# MLP: 1,000 parameters → Overfits horribly
```

### Overfitting Detection:

```python
# Training accuracy >> Validation accuracy (gap > 5%)
train_acc = 99%, val_acc = 70%  # 29% gap → OVERFITTING!

# Solutions:
# 1. Reduce model capacity (fewer layers/channels)
# 2. Add regularization (dropout, weight decay)
# 3. Collect more data
# 4. Data augmentation

# Order: Try (1) first (simplest), then (2), then (3)/(4)
```
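If reducing capacity alone isn't enough, solution (2) is a few lines in PyTorch. A sketch (layer sizes, `p=0.5`, and the 0.01 weight decay are illustrative defaults, not tuned values):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(20, 128),
    nn.ReLU(),
    nn.Dropout(p=0.5),  # randomly zeroes activations during training
    nn.Linear(128, 1),
)
# Weight decay (L2 penalty) is applied through the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
```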
### Underfitting Detection:

```python
# Training accuracy < target (model too simple)
train_acc = 60%, val_acc = 58%  # Both low → UNDERFITTING!

# Solutions:
# 1. Increase model capacity (more layers/channels)
# 2. Train longer
# 3. Reduce regularization

# Order: Try (2) first (cheapest), then (1), then (3)
```

**Rule**: Match parameters to data size. Start small, increase capacity only if underfitting.
## Principle: Design for Compute Constraints

**Constraints:**
1. **Memory**: Model + gradients + optimizer states < GPU VRAM
2. **Latency**: Inference time < requirement (e.g., < 100ms for real-time)
3. **Throughput**: Samples/second > requirement

### Memory Budget:

```python
# Memory calculation (training):
# 1. Model parameters (FP32): params × 4 bytes
# 2. Gradients: params × 4 bytes
# 3. Optimizer states (Adam): params × 8 bytes (2× weights)
# 4. Activations: batch_size × feature_maps × spatial_size × 4 bytes

# Example: ResNet-50
params = 25M
memory_params = 25M × 4 = 100 MB
memory_gradients = 100 MB
memory_optimizer = 200 MB
memory_activations = batch_size × 64 × 7×7 × 4 ≈ batch_size × 12 KB
# (final feature map only; summed across ALL layers, activations are far
# larger in practice, so leave generous headroom)

# Total (batch=32): 100 + 100 + 200 + 0.4 = 400 MB
# Fits easily on 4GB GPU (before the full activation cost)

# Example: GPT-3 (175B parameters)
memory_params = 175B × 4 = 700 GB
memory_total = 700 + 700 + 1400 = 2800 GB = 2.8 TB!
# Requires 35× A100 (80GB each)
```

**Rule**: Calculate memory before training. Don't design models that don't fit.
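A rough estimator for the arithmetic above (assumes FP32 weights and Adam; activations are deliberately excluded, so treat the result as a floor, not a ceiling):

```python
def training_memory_mb(num_params, bytes_per_param=4):
    """Weights + gradients + two Adam moment buffers, in MB.
    Excludes activations and allocator overhead, which can dominate."""
    weights = num_params * bytes_per_param
    gradients = num_params * bytes_per_param
    adam_states = 2 * num_params * bytes_per_param
    return (weights + gradients + adam_states) / 1e6

print(training_memory_mb(25e6))   # ResNet-50: ~400 MB (before activations)
print(training_memory_mb(175e9))  # GPT-3: ~2.8 million MB = 2.8 TB
```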
### Latency Budget:

```python
# Inference latency = # operations / throughput

# Example: Mobile app (< 100ms latency requirement)

# ResNet-50:
# Operations: 4B FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 4B / 10G = 0.4 seconds (FAILS!)

# MobileNetV2:
# Operations: 300M FLOPs
# Mobile CPU: 10 GFLOPS
# Latency: 300M / 10G = 0.03 seconds = 30ms (PASSES!)

# Solution: Use efficient architectures (MobileNet, EfficientNet) for mobile
```

**Rule**: Measure latency. Use efficient architectures if latency-constrained.
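FLOP counts only estimate; measuring is straightforward. A CPU-timing sketch (shapes and iteration counts are illustrative; for GPU timing, add `torch.cuda.synchronize()` before reading the clock):

```python
import time
import torch

@torch.no_grad()
def latency_ms(model, input_shape=(1, 3, 224, 224), warmup=10, iters=50):
    model.eval()
    x = torch.randn(*input_shape)
    for _ in range(warmup):   # warmup amortizes one-time costs
        model(x)
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    return (time.perf_counter() - start) / iters * 1000
```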
## Common Architectural Patterns

### 1. Bottleneck (ResNet)

**Structure:**
```python
# Standard: 3×3 conv (256 channels) → 3×3 conv (256 channels)
# Parameters: 2 × (256 × 256 × 3 × 3) ≈ 1.18M

# Bottleneck: 1×1 (256→64) → 3×3 (64→64) → 1×1 (64→256)
# Parameters: 256×64 + 64×64×3×3 + 64×256 = 16K + 37K + 16K = 69K
# Reduction: 1.18M → 69K (~17× fewer!)
```

**Purpose**: Reduce parameters while maintaining capacity

**When to use**: Deep networks (> 50 layers) where parameters are a concern
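A minimal sketch of the 1×1 → 3×3 → 1×1 block counted above (channel sizes follow the example; real ResNet bottlenecks also add BatchNorm):

```python
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck(nn.Module):
    def __init__(self, channels=256, reduced=64):
        super().__init__()
        self.reduce = nn.Conv2d(channels, reduced, kernel_size=1)   # 256 → 64
        self.conv = nn.Conv2d(reduced, reduced, kernel_size=3, padding=1)
        self.restore = nn.Conv2d(reduced, channels, kernel_size=1)  # 64 → 256

    def forward(self, x):
        out = F.relu(self.reduce(x))
        out = F.relu(self.conv(out))
        return F.relu(self.restore(out) + x)  # widths match, so the residual add works

# ~69K weights (ignoring biases), matching the arithmetic above
print(sum(p.numel() for p in Bottleneck().parameters()))
```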
### 2. Inverted Bottleneck (MobileNetV2)

**Structure:**
```python
# Bottleneck (ResNet): Wide → Narrow → Wide (256 → 64 → 256)
# Inverted: Narrow → Wide → Narrow (64 → 256 → 64)

# Why? Efficient for mobile (depthwise separable convolutions)
```

**Purpose**: Maximize efficiency (FLOPs per parameter)

**When to use**: Mobile/edge deployment
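A sketch of the narrow → wide → narrow pattern with a depthwise 3×3 in the wide middle (MobileNetV2-style; the expansion factor of 4 is illustrative):

```python
import torch.nn as nn

class InvertedBottleneck(nn.Module):
    def __init__(self, channels=64, expansion=4):
        super().__init__()
        hidden = channels * expansion
        self.expand = nn.Conv2d(channels, hidden, kernel_size=1)       # 64 → 256
        self.depthwise = nn.Conv2d(hidden, hidden, kernel_size=3,
                                   padding=1, groups=hidden)           # one filter per channel
        self.project = nn.Conv2d(hidden, channels, kernel_size=1)      # 256 → 64
        self.act = nn.ReLU6()

    def forward(self, x):
        out = self.act(self.expand(x))
        out = self.act(self.depthwise(out))
        return self.project(out) + x  # linear projection, residual connection
```

`groups=hidden` is what makes the middle convolution depthwise, which is where the FLOP savings come from.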
### 3. Multi-scale Features (Inception)

**Structure:**
```python
# Parallel branches with different kernel sizes:
# Branch 1: 1×1 conv
# Branch 2: 3×3 conv
# Branch 3: 5×5 conv
# Branch 4: 3×3 max pool
# Concatenate all branches

# Captures features at multiple scales simultaneously
```

**Purpose**: Capture multi-scale patterns

**When to use**: When features exist at multiple scales (object detection)
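A minimal sketch of the parallel-branch idea (branch widths of 32 are arbitrary; GoogLeNet's actual modules also insert 1×1 reductions before the larger kernels):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniInception(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.b1 = nn.Conv2d(in_ch, 32, kernel_size=1)
        self.b2 = nn.Conv2d(in_ch, 32, kernel_size=3, padding=1)
        self.b3 = nn.Conv2d(in_ch, 32, kernel_size=5, padding=2)
        self.pool_proj = nn.Conv2d(in_ch, 32, kernel_size=1)

    def forward(self, x):
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branches = [self.b1(x), self.b2(x), self.b3(x), self.pool_proj(pooled)]
        # Every branch preserves spatial size, so channel concat is valid
        return torch.cat([F.relu(b) for b in branches], dim=1)  # 4 × 32 = 128 channels
```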
### 4. Attention (Transformers, SE-Net)

**Structure:**
```python
# Squeeze-and-Excitation (SE) block:
# 1. Global average pooling (spatial → channel descriptor)
# 2. FC layer (bottleneck)
# 3. FC layer (restore channels)
# 4. Sigmoid (attention weights)
# 5. Multiply input channels by attention weights

# Result: Emphasize important channels, suppress irrelevant
```

**Purpose**: Learn importance of features (channels or positions)

**When to use**: When not all features equally important
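The five steps above, written out as a module (the reduction ratio of 16 is a common default, not a requirement):

```python
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)  # (B, C, H, W) → (B, C, 1, 1)
        self.excite = nn.Sequential(
            nn.Linear(channels, channels // reduction),  # bottleneck
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),  # restore
            nn.Sigmoid(),  # per-channel attention weights in [0, 1]
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        weights = self.excite(self.squeeze(x).view(b, c)).view(b, c, 1, 1)
        return x * weights  # reweight channels by learned importance
```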
## Debugging Architectures

### Problem 1: Network doesn't learn (loss stays constant)

**Diagnosis:**
```python
# Check gradient flow
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_mean={param.grad.mean():.6f}, grad_std={param.grad.std():.6f}")

# Vanishing: grad_mean ≈ 0, grad_std ≈ 0 → Add skip connections
# Exploding: grad_mean > 1, grad_std > 1 → Gradient clipping or lower LR
```

**Solutions:**
- Add skip connections (ResNet)
- Check initialization (Xavier or He initialization)
- Lower learning rate
- Check data preprocessing (normalized inputs?)

### Problem 2: Overfitting (train >> val)

**Diagnosis:**
```python
train_acc, val_acc = 0.99, 0.70  # 29-point gap → overfitting

# Check parameter/data ratio:
num_params = sum(p.numel() for p in model.parameters())
data_size = len(train_dataset)
ratio = num_params / data_size

# If ratio > 1: the model has more parameters than data points!
```

**Solutions (in order):**
1. Reduce capacity (fewer layers/channels)
2. Add dropout / weight decay
3. Data augmentation
4. Collect more data

### Problem 3: Underfitting (train and val both low)

**Diagnosis:**
```python
train_acc, val_acc = 0.65, 0.63  # Both low → underfitting

# Model too simple for the task complexity
```

**Solutions (in order):**
1. Train longer (more epochs)
2. Increase capacity (more layers/channels)
3. Reduce regularization (lower dropout/weight decay)
4. Check learning rate (too low?)

### Problem 4: Slow training

**Diagnosis:**
```python
# Profile forward/backward pass
import time

start = time.time()
loss = criterion(model(inputs), targets)
forward_time = time.time() - start

start = time.time()
loss.backward()
backward_time = time.time() - start

# Note: on GPU, call torch.cuda.synchronize() before reading the timers,
# since CUDA kernels run asynchronously.

# If backward_time >> forward_time: gradient computation is the bottleneck
```

**Solutions:**
- Use mixed precision (FP16)
- Reduce batch size (if memory-bound)
- Use gradient accumulation (simulate a large batch)
- Simplify architecture (fewer layers)
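
A minimal mixed-precision training step sketch (using `torch.cuda.amp`; `model`, `optimizer`, `criterion`, and `loader` are assumed to already exist):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # run forward pass in FP16 where safe
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()            # scale loss to avoid FP16 underflow
    scaler.step(optimizer)
    scaler.update()
```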

## Design Checklist

Before finalizing an architecture:

### ☐ Match inductive bias to problem
- Images → CNN
- Sequences → RNN/Transformer
- Graphs → GNN
- Tabular → MLP

### ☐ Start simple, add complexity only when needed
- Test linear baseline first
- Add complexity incrementally
- Compare performance at each step

### ☐ Use skip connections for deep networks (> 10 layers)
- ResNet for CNNs
- Pre-norm for Transformers
- Gradient flow is critical

### ☐ Balance depth and width
- Not too deep and narrow (bottleneck)
- Not too shallow and wide (under-utilizes depth)
- Standard: 12-50 layers, 64-512 channels

### ☐ Match capacity to data size
- Parameters ≈ 0.01-0.1× dataset size
- Monitor train/val gap (overfitting indicator)

### ☐ Respect compute constraints
- Memory: Model + gradients + optimizer + activations < VRAM
- Latency: Inference time < requirement
- Use efficient architectures if constrained (MobileNet, EfficientNet)

### ☐ Verify gradient flow
- Check gradients in early layers (should be non-zero)
- Use skip connections if vanishing

### ☐ Benchmark against baselines
- Compare to a simple model (linear, small MLP)
- Ensure complexity adds value (> 5% improvement)

## Anti-Patterns

### Anti-pattern 1: "Architecture X is state-of-the-art, so I'll use it"

**Wrong:**
```python
# Transformer is SOTA for NLP, so use it for tabular data (100 samples)
model = HugeTransformer(...)  # 10M parameters
# Result: Overfits horribly (100 samples / 10M params = 0.00001 ratio!)
```

**Right:**
```python
# Match architecture to problem AND data size
# Tabular + small data → Linear or small MLP
model = nn.Linear(20, 1)  # 21 parameters (appropriate!)
```

### Anti-pattern 2: "More layers = better"

**Wrong:**
```python
# 100-layer plain network (no skip connections)
for i in range(100):
    layers.append(nn.Conv2d(64, 64, 3, padding=1))
# Result: Doesn't train (vanishing gradients)
```

**Right:**
```python
# 50-layer ResNet (with skip connections)
# Each block: out = x + F(x)  # Skip connection
# Result: Trains well, high accuracy
```

### Anti-pattern 3: "Deeper + narrower = efficient"

**Wrong:**
```python
# 100 layers × 8 channels = information bottleneck
model = VeryDeepNarrow(100, 8)
# Result: 60% accuracy (8 channels insufficient)
```

**Right:**
```python
# 18 layers, 64-512 channels (balanced)
model = ResNet18()  # Balanced depth and width
# Result: 95% accuracy
```

### Anti-pattern 4: "Ignore constraints, optimize later"

**Wrong:**
```python
# Design a 1.5B parameter model for a 24GB GPU
model = HugeModel(1.5e9)
# Result: OOM (out of memory), can't train
```

**Right:**
```python
# Calculate memory first:
# 1.5B params × 4 bytes = 6GB (weights)
# + 6GB (gradients) + 12GB (Adam) + 8GB (activations) = 32GB
# > 24GB → Doesn't fit!

# Design for the hardware:
model = ReasonableSizeModel(200e6)  # 200M parameters (fits!)
```

### Anti-pattern 5: "Hyperparameters will fix architectural problems"

**Wrong:**
```python
# Architecture: MLP for images (wrong inductive bias)
# Response: "Just tune the learning rate!"
for lr in [0.1, 0.01, 0.001, 0.0001]:
    train(model, lr=lr)
# Result: All fail (the architecture is wrong!)
```

**Right:**
```python
# Fix the architecture first (use a CNN for images)
model = ResNet18()  # Correct inductive bias
# Then tune hyperparameters
```

## Summary

**Core principles:**

1. **Inductive bias**: Match architecture to problem structure (CNN for images, RNN/Transformer for sequences, GNN for graphs)

2. **Occam's Razor**: Start simple (linear, small MLP), add complexity only when needed

3. **Skip connections**: Use for networks > 10 layers (ResNet, DenseNet)

4. **Depth-width balance**: Not too deep+narrow (bottleneck) or too shallow+wide (under-utilizes depth)

5. **Capacity**: Match parameters to data size (0.01-0.1× dataset size)

6. **Constraints**: Design for available memory, latency, throughput

**Decision framework:**
- Images → CNN (ResNet, EfficientNet)
- Short sequences → LSTM
- Long sequences → Transformer
- Graphs → GNN (test if structure helps first!)
- Tabular → Linear or small MLP

**Key insight**: Architecture design is about matching structural assumptions to problem structure, not about using the "best" or "most complex" model. Simple models often win.

**When in doubt**: Start with the simplest model that could plausibly work. Add complexity only when you have evidence it helps.

@@ -0,0 +1,824 @@

# Attention Mechanisms Catalog

## When to Use This Skill

Use this skill when you need to:
- ✅ Select an attention mechanism for long sequences (> 2k tokens)
- ✅ Optimize memory usage (GPU OOM errors)
- ✅ Speed up training or inference
- ✅ Understand exact vs approximate attention trade-offs
- ✅ Choose between Flash, sparse, or linear attention
- ✅ Implement cross-attention for multimodal models

**Do NOT use this skill for:**
- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`)
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`)

## Core Principle

**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve:
- **O(n²) with less memory**: Flash Attention (exact, 4x less memory)
- **O(n × w)**: Sparse attention (exact, sliding window of width w)
- **O(n)**: Linear attention (approximate, 1-3% accuracy loss)

**Default recommendation:** Flash Attention (exact + fast + memory-efficient)

## Part 1: Complexity Hierarchy

### Standard Self-Attention (Baseline)

**Formula:**
```python
Attention(Q, K, V) = softmax(Q K^T / √d_k) V
```

**Complexity:**
- Time: O(n² · d) where n = seq_len, d = d_model
- Memory: O(n²) for the attention matrix
- Exact: Yes (no approximation)

**Memory breakdown (4k tokens, d=768):**
```
Attention scores: 4096² × 4 bytes = 64MB per head
Multi-head (12 heads): 64MB × 12 = 768MB per layer
16 layers: 768MB × 16 = 12GB just for attention!
Batch size 8: 12GB × 8 = 96GB (impossible on a single GPU)
```

**When to use:**
- Sequence length < 2k tokens
- Standard use case (most models)
- Pair with the Flash Attention optimization

**Limitations:**
- Memory explosion for long sequences
- Quadratic scaling impractical beyond 4k tokens
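
A minimal sketch of the baseline, to make the O(n²) matrix explicit (shapes only; no masking or multi-head splitting):

```python
import math
import torch

def attention(Q, K, V):
    # Q, K, V: (batch, n, d_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (batch, n, n) — the O(n²) matrix
    return torch.softmax(scores, dim=-1) @ V                  # (batch, n, d_k)
```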

## Part 2: Flash Attention ⭐ (Modern Default)

### What is Flash Attention?

**Breakthrough (2022):** Exact attention with 4x less memory, 2-3x faster

**Key insight:**
- Standard attention is **memory-bound** (not compute-bound)
- GPUs: fast compute (TFLOPS), slow memory bandwidth (GB/s)
- Bottleneck: moving the n² attention matrix to/from HBM

**Solution:**
- Tile the attention computation
- Recompute instead of storing intermediate values
- Fuse operations (reduce memory transfers)
- Result: same O(n²) compute, O(n) memory

### Algorithm

```
Standard attention (3 memory operations):
1. Compute scores: S = Q K^T   (store n² matrix)
2. Softmax: P = softmax(S)     (store n² matrix)
3. Output: O = P V             (store n×d matrix)

Flash Attention (tiled):
1. Divide Q, K, V into blocks
2. For each Q block:
   - Load block into SRAM (fast on-chip memory)
   - For each K, V block:
     - Compute attention for this tile
     - Update output incrementally
   - Never materialize the full n² matrix!
3. Result: same output, O(n) memory
```

### Performance

**Benchmarks (A100 GPU, 2k tokens):**

Standard attention:
- Memory: 4GB for batch_size=8
- Speed: 150ms/batch
- Max batch size: 16

Flash Attention:
- Memory: 1GB for batch_size=8 **(4x reduction)**
- Speed: 75ms/batch **(2x faster)**
- Max batch size: 64 **(4x larger)**

**Flash Attention 2 (2023 update):**
- Further optimized: 2-3x faster than Flash Attention 1
- Better parallelism
- Supports more head dimensions

### When to Use

✅ **ALWAYS use Flash Attention when:**
- Sequence length < 16k tokens
- Need exact attention (no approximation)
- Available in your framework

**It's a FREE LUNCH:**
- No accuracy loss (mathematically exact)
- Faster training AND inference
- Less memory usage
- Drop-in replacement

### Implementation

**PyTorch 2.0+ (built-in):**
```python
import torch.nn.functional as F

# Automatic Flash Attention (if available)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False
)
# PyTorch automatically uses Flash Attention if:
# - CUDA is available
# - Sequence length is suitable
# - No attention mask (or a causal mask)
```

**HuggingFace Transformers:**
```python
import torch
from transformers import AutoModel

# Enable Flash Attention 2
model = AutoModel.from_pretrained(
    "bert-base-uncased",
    attn_implementation="flash_attention_2",  # Requires the flash-attn package
    torch_dtype=torch.float16
)
```

**Manual installation:**
```bash
pip install flash-attn --no-build-isolation
```

### Limitations

❌ **Flash Attention NOT suitable when:**
- Sequence length > 16k (memory is O(n), but compute still grows quadratically)
- Custom attention masks (complex patterns not supported)
- Inference on CPU (CUDA-only)

**For > 16k tokens:** Use sparse or linear attention

## Part 3: Sparse Attention (Exact for Long Sequences)

### Concept

**Idea:** Each token attends to a subset of tokens (not all)
- Sliding window: local context
- Global tokens: long-range connections
- Result: O(n × window_size) instead of O(n²)

**Key property:** Still EXACT attention (not approximate)
- Just a more structured attention pattern
- No accuracy loss if the pattern matches the task

### Variant 1: Longformer

**Pattern:** Sliding window + global attention

```
Attention pattern (window=2, global=[0]):
     0 1 2 3 4 5
0 [ 1 1 1 1 1 1 ]  ← Global token (attends to all)
1 [ 1 1 1 0 0 0 ]  ← Window: tokens 0-2
2 [ 1 1 1 1 0 0 ]  ← Window: tokens 1-3
3 [ 1 0 1 1 1 0 ]  ← Window: tokens 2-4
4 [ 1 0 0 1 1 1 ]  ← Window: tokens 3-5
5 [ 1 0 0 0 1 1 ]  ← Window: tokens 4-5

Complexity: O(n × (window + num_global))
```

**Components:**
1. **Sliding window**: Each token attends to w/2 tokens before and after
2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens
3. **Dilated windows**: Optional (stride > 1 for longer context)

**Implementation:**
```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

# Attention mask (shape: batch, seq_len)
# Longformer convention: 0 = no attention, 1 = local, 2 = global
attention_mask = torch.ones(batch_size, seq_len)  # batch_size, seq_len assumed defined
attention_mask[:, 0] = 2  # Global attention for the [CLS] token

output = model(input_ids, attention_mask=attention_mask)
```

**Memory comparison (4k tokens, window=512):**
```
Standard:   4096² = 16M elements → 64MB
Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!)
```

**When to use:**
- Documents: 4k-16k tokens (legal, scientific papers)
- Need full context but can't fit O(n²)
- Task has local + global structure

**Pretrained models:**
- `allenai/longformer-base-4096`: max 4096 tokens
- `allenai/longformer-large-4096`: larger version

### Variant 2: BigBird

**Pattern:** Random + window + global

```
Attention pattern:
- Sliding window: like Longformer
- Random connections: each token attends to r random tokens
- Global tokens: special tokens attend to all

Complexity: O(n × (window + r + num_global))
```

**Key difference from Longformer:**
- Random connections help information flow
- Theoretically proven to approximate full attention

**When to use:**
- Similar to Longformer
- Slightly better for tasks needing long-range interactions
- Less widely adopted than Longformer

**Implementation:**
```python
from transformers import BigBirdModel

model = BigBirdModel.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="block_sparse"  # or "original_full"
)
```

### Sparse Attention Decision

```
Sequence length < 4k:
  → Flash Attention (exact, no pattern needed)

Sequence length 4k-16k:
  → Longformer (sliding window + global)
  → Best for: documents, long-form text

Sequence length > 16k:
  → Longformer if possible
  → Linear attention if Longformer is too slow
```

## Part 4: Linear Attention (Approximate, for Very Long Sequences)

### Concept

**Idea:** Approximate softmax attention with linear operations
- Complexity: O(n × k) where k << n
- Trade-off: 1-3% accuracy loss
- Benefit: can handle very long sequences (> 16k)

**Key property:** APPROXIMATE (not exact)
- Do NOT use if accuracy is critical
- Good for extremely long sequences where exact attention is impossible

### Variant 1: Performer

**Method:** Random Fourier Features to approximate softmax(Q K^T)

**Formula:**
```python
# Standard attention
Attention(Q, K, V) = softmax(Q K^T) V

# Performer approximation: a random feature map φ such that
φ(Q) φ(K)^T ≈ softmax(Q K^T)
Attention(Q, K, V) ≈ φ(Q) (φ(K)^T V)

# Complexity: O(n × k) where k = feature dimension
```

**Key trick:**
- Compute φ(K)^T V first: a (k × d) matrix (small!)
- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d)
- Never materialize the n² attention matrix

**Implementation:**
```python
# From the performer-pytorch library
from performer_pytorch import Performer

model = Performer(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    causal=False,
    nb_features=256  # k = number of random features
)
```

**Accuracy:**
- Typical loss: 1-2% vs standard attention
- Depends on nb_features (more features = better approximation)
- k=256 is usually sufficient

**When to use:**
- Sequence length > 16k tokens
- Accuracy loss acceptable (not a critical task)
- No structural assumptions wanted (unlike sparse attention patterns)
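
The computation-order trick generalizes beyond Performer's random features. A minimal sketch with a simple positive feature map (elu+1, borrowed from the "Transformers are RNNs" line of linear attention — an assumption here, not Performer's exact φ):

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Q, K, V: (batch, n, d). phi must be positive; Performer would use
    # random features here instead of elu+1.
    Qp, Kp = F.elu(Q) + 1, F.elu(K) + 1
    kv = Kp.transpose(-2, -1) @ V           # (batch, d, d_v) — small, no n×n matrix
    z = Qp @ Kp.sum(dim=1).unsqueeze(-1)    # (batch, n, 1) softmax-style normalizer
    return (Qp @ kv) / z                    # O(n·k·d) total
```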

### Variant 2: Linformer

**Method:** Project K and V to a lower dimension

**Formula:**
```python
# Standard attention (n × n attention matrix)
Attention(Q, K, V) = softmax(Q K^T / √d_k) V

# Linformer (project K, V from n × d down to k × d, where k << n)
K_proj = E K  # E: (k × n) projection matrix
V_proj = F V  # F: (k × n) projection matrix

Attention(Q, K, V) ≈ softmax(Q K_proj^T / √d_k) V_proj
# Attention matrix: (n × k) instead of (n × n)
```

**Complexity:**
- Time: O(n × k × d) where k << n
- Memory: O(n × k) instead of O(n²)

**Implementation:**
```python
# From the linformer library
from linformer import Linformer

model = Linformer(
    dim=512,
    seq_len=8192,
    depth=12,
    heads=8,
    k=256  # projected dimension
)
```

**Accuracy:**
- Typical loss: 1-3% vs standard attention
- More loss than Performer
- Fixed sequence length (k is tied to max_seq_len)

**When to use:**
- Fixed-length long sequences
- Memory more critical than speed
- Accuracy loss OK (2-3%)

### Linear Attention Decision

```
Need exact attention:
  → Flash Attention or Sparse Attention (NOT linear)

Sequence > 16k, accuracy critical:
  → Sparse Attention (Longformer)

Sequence > 16k, accuracy loss OK:
  → Performer (better) or Linformer

Sequence > 100k:
  → State space models (S4, Mamba — not attention)
```

## Part 5: Cross-Attention (Multimodal)

### Concept

**Self-attention:** Q, K, V from the same source
**Cross-attention:** Q from one source, K/V from another

**Use cases:**
- Multimodal: vision → language (image captioning)
- Seq2seq: source language → target language (translation)
- RAG: query → document retrieval
- Conditioning: generation conditioned on context

### Architecture

```python
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, query_source, key_value_source, mask=None):
        # query_source: (batch, n_q, d_model) - e.g., text tokens
        # key_value_source: (batch, n_kv, d_model) - e.g., image patches

        # Q from the query source
        Q = self.W_q(query_source)

        # K, V from the key-value source
        K = self.W_k(key_value_source)
        V = self.W_v(key_value_source)

        # Attention output: (batch, n_q, d_model)
        # (single-head for brevity; real implementations split into heads)
        output = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
        return output
```

### Example: Image Captioning

**Task:** Generate a caption from an image

**Architecture:**
1. **Image Encoder:** ViT processes the image → image features (n_patches × d)
2. **Text Decoder:** Autoregressive text generation
3. **Cross-Attention:** Text queries the image features

```python
class ImageCaptioningDecoder(nn.Module):
    # (Submodules assumed defined in __init__; pseudocode for the data flow.)
    def forward(self, text_tokens, image_features):
        text = text_tokens

        # 1. Self-attention on text (causal)
        text = self.text_self_attention(
            query=text,
            key=text,
            value=text,
            causal_mask=True  # Don't see future words
        )

        # 2. Cross-attention (text queries image)
        text = self.cross_attention(
            query=text,            # From text decoder
            key=image_features,    # From image encoder
            value=image_features   # From image encoder
            # No causal mask! Can attend to all image patches
        )

        # 3. Feed-forward
        text = self.feed_forward(text)

        return text
```

**Attention flow:**
- Text token "cat" → high attention to the cat region in the image
- Text token "sitting" → high attention to posture in the image

### Example: Retrieval-Augmented Generation (RAG)

**Task:** Generate an answer using retrieved documents

```python
class RAGDecoder(nn.Module):
    # (Submodules assumed defined in __init__; pseudocode for the data flow.)
    def forward(self, query_tokens, document_embeddings):
        # 1. Self-attention on the query
        query = self.query_self_attention(query_tokens, query_tokens, query_tokens)

        # 2. Cross-attention (query → documents)
        query = self.cross_attention(
            query=query,                 # What we're generating
            key=document_embeddings,     # Retrieved docs
            value=document_embeddings    # Retrieved docs
        )

        # The query learns to extract relevant info from the docs

        return query
```

### When to Use Cross-Attention

✅ **Use cross-attention when:**
- Two different modalities (vision + language)
- Conditioning generation on context (RAG)
- Seq2seq with different input/output (translation)
- Query-document matching

❌ **Don't use cross-attention when:**
- Same modality (use self-attention)
- No clear query vs key-value separation

## Part 6: Other Attention Variants

### Axial Attention (2D Images)

**Idea:** For 2D data (images), attend along each axis separately

```
Standard 2D attention: H×W tokens → (HW)² attention matrix
Axial attention:
- Row attention: each row attends to itself (H × W²)
- Column attention: each column attends to itself (W × H²)
- Total: O(HW × (H + W)) << O((HW)²)
```

**When to use:**
- High-resolution images
- 2D positional structure important

### Block-Sparse Attention

**Idea:** Divide attention into blocks, attend only within/across blocks

**Pattern:**
```
Block size = 64 tokens
- Local block: attend within the same block
- Vertical stripe: attend to the corresponding position in other blocks
```

**Used in:** Sparse Transformer (OpenAI), GPT-3

### Multi-Query Attention (MQA)

**Idea:** One K/V head shared across all Q heads

**Benefit:**
- Smaller KV cache during inference
- Much faster decoding (4-8x)
- Trade-off: ~1% accuracy loss

**Used in:** PaLM, Falcon

### Grouped-Query Attention (GQA)

**Idea:** Middle ground between multi-head and multi-query
- Groups of Q heads share K/V heads
- Example: 32 Q heads → 8 K/V heads (4:1 ratio)

**Benefit:**
- 4x smaller KV cache
- Minimal accuracy loss (< 0.5%)

**Used in:** LLaMA-2, Mistral
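
To see why MQA/GQA matter for inference, here is a rough KV-cache size sketch (the layer count, head count, context length, and head dimension below are illustrative, not from any specific model):

```python
# KV cache per sequence: K and V, per layer, per KV head (fp16 = 2 bytes).
def kv_cache_gb(n_layers, n_kv_heads, seq_len, d_head, bytes_per=2):
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per / 1e9

print(kv_cache_gb(32, 32, 4096, 128))  # MHA, 32 KV heads: ~2.1 GB per sequence
print(kv_cache_gb(32, 8, 4096, 128))   # GQA, 8 KV heads:  ~0.5 GB (4x smaller)
print(kv_cache_gb(32, 1, 4096, 128))   # MQA, 1 KV head:   ~0.07 GB
```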

## Part 7: Decision Framework

### By Sequence Length

```
< 2k tokens:
  → Flash Attention
    Exact, fast, standard

2k-4k tokens:
  → Flash Attention
    Still manageable with modern GPUs

4k-16k tokens:
  → Sparse Attention (Longformer, BigBird)
    Exact, designed for documents
  → OR Flash Attention if batch size = 1

> 16k tokens:
  → Sparse Attention
    If the task has local structure
  → Linear Attention (Performer)
    If accuracy loss OK (1-2%)
  → State Space Models (S4, Mamba)
    If sequence > 100k
```

### By Memory Constraints

```
GPU OOM with standard attention:
1. Try Flash Attention (4x less memory, free lunch)
2. If still OOM, reduce batch size
3. If batch size = 1 and still OOM, use sparse attention
4. Last resort: linear attention (if accuracy loss OK)

DON'T:
- Use gradient checkpointing first (slower; use Flash Attention instead)
- Throw more GPUs at it (algorithmic problem, not hardware)
```

### By Accuracy Requirements

```
Must be exact (no approximation):
  → Flash Attention or Sparse Attention
    Never use linear attention!

Accuracy loss acceptable (1-3%):
  → Linear Attention (Performer, Linformer)
    Only for very long sequences (> 16k)

Critical task (medical, legal):
  → Exact attention only
    Flash Attention or Sparse Attention
```

### By Task Type

```
Classification / Understanding:
  → Standard + Flash Attention
    Sequence usually < 2k

Document processing:
  → Longformer (4096 tokens)
    Designed for documents

Generation (LLM):
  → Flash Attention for training
  → + GQA/MQA for inference (faster decoding)

Multimodal (vision + language):
  → Cross-attention for modality fusion
  → Self-attention within each modality

Retrieval-augmented:
  → Cross-attention (query → documents)
```

## Part 8: Implementation Checklist

### Using Flash Attention

**PyTorch 2.0+:**
```python
import torch.nn.functional as F

# Automatic (recommended)
output = F.scaled_dot_product_attention(query, key, value)

# Verify Flash Attention is used
import torch.backends.cuda
print(torch.backends.cuda.flash_sdp_enabled())  # Should be True
```

**HuggingFace:**
```python
model = AutoModel.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16  # Flash Attention needs fp16/bf16
)
```

**Requirements:**
- CUDA GPU (not CPU)
- PyTorch >= 2.0 OR the flash-attn package
- fp16 or bf16 dtype (not fp32)

### Using Sparse Attention

**Longformer:**
```python
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

# Attention mask
# 0 = no attention, 1 = local attention, 2 = global attention
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 0] = 2  # [CLS] token gets global attention

outputs = model(input_ids, attention_mask=attention_mask)
```

**Custom sparse pattern:**
```python
import torch

# Create a custom block-sparse mask
def create_block_sparse_mask(seq_len, block_size):
    num_blocks = seq_len // block_size
    mask = torch.zeros(seq_len, seq_len)

    for i in range(num_blocks):
        start = i * block_size
        end = start + block_size
        mask[start:end, start:end] = 1  # Local block

    return mask
```

### Using Cross-Attention

```python
import torch.nn as nn

class DecoderWithCrossAttention(nn.Module):
    # MultiHeadAttention: any implementation with this query/key/value/mask
    # signature (assumed, e.g. a wrapper around nn.MultiheadAttention).
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        # Self-attention (causal)
        x = self.self_attn(
            query=decoder_input,
            key=decoder_input,
            value=decoder_input,
            mask=causal_mask
        )

        # Cross-attention (Q from decoder, K/V from encoder)
        x = self.cross_attn(
            query=x,                # From decoder
            key=encoder_output,     # From encoder
            value=encoder_output,   # From encoder
            mask=None               # No causal mask for cross-attention!
        )

        return x
```

## Part 9: Common Mistakes

### Mistake 1: Ignoring Flash Attention

**Symptom:** Training slow, high memory usage
**Fix:** Always use Flash Attention for < 16k tokens

### Mistake 2: Using Linear Attention Unnecessarily

**Symptom:** 1-3% accuracy loss for no reason
**Fix:** Use Flash Attention (exact) unless sequence > 16k

### Mistake 3: Gradient Checkpointing Instead of Flash Attention

**Symptom:** Training 20% slower
**Fix:** Flash Attention gives memory savings AND speed

### Mistake 4: Cross-Attention with Causal Mask

**Symptom:** Decoder can't attend to the encoder properly
**Fix:** Causal mask only for self-attention, NOT cross-attention

### Mistake 5: Accepting O(n²) Memory

**Symptom:** GPU OOM for > 4k tokens
**Fix:** Use sparse or Flash Attention; don't just add GPUs

## Summary: Quick Reference

### Attention Selection

```
Sequence length:
  < 2k   → Flash Attention (default)
  2-4k   → Flash Attention
  4-16k  → Longformer (documents) or Flash Attention (batch=1)
  > 16k  → Sparse or Linear Attention

Memory constrained:
  First: try Flash Attention (4x less memory)
  Still OOM: use sparse attention (Longformer)
  Last resort: linear attention (accuracy loss)

Speed critical:
  Training: Flash Attention (2x faster)
  Inference: Flash Attention + GQA/MQA

Accuracy critical:
  Use exact attention only (Flash or Sparse)
  NEVER linear attention

Multimodal:
  Cross-attention for modality fusion
```

### Implementation

```
PyTorch 2.0+:
  F.scaled_dot_product_attention()  # Auto Flash Attention

HuggingFace:
  attn_implementation="flash_attention_2"

Longformer:
  LongformerModel.from_pretrained("allenai/longformer-base-4096")

Custom:
  Inherit from nn.Module, implement forward()
```

## Next Steps

After mastering this skill:
- `llm-specialist/llm-inference-optimization`: apply attention optimizations to inference
- `llm-specialist/context-window-management`: manage long contexts in LLMs
- `architecture-design-principles`: understand broader design trade-offs

**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns).

622
skills/using-neural-architectures/cnn-families-and-selection.md
Normal file
@@ -0,0 +1,622 @@

# CNN Families and Selection: Choosing the Right Convolutional Network

<CRITICAL_CONTEXT>
CNNs are the foundation of computer vision. Different families have vastly different trade-offs:
- Accuracy vs speed vs size
- Dataset size requirements
- Deployment target (cloud vs edge vs mobile)
- Task type (classification vs detection vs segmentation)

This skill helps you choose the RIGHT CNN for YOUR constraints.
</CRITICAL_CONTEXT>

## When to Use This Skill

Use this skill when:
- ✅ Selecting a CNN for a vision task (classification, detection, segmentation)
- ✅ Comparing CNN families (ResNet vs EfficientNet vs MobileNet)
- ✅ Optimizing for specific constraints (latency, size, accuracy)
- ✅ Understanding CNN evolution (why newer architectures exist)
- ✅ Deployment-specific selection (cloud, edge, mobile)

DO NOT use for:
- ❌ Non-vision tasks (use sequence-models-comparison or other skills)
- ❌ Training optimization (use the training-optimization pack)
- ❌ Implementation details (use the pytorch-engineering pack)

**When in doubt:** If choosing WHICH CNN → this skill. If implementing/training a CNN → other skills.

## Selection Framework

### Step 1: Identify Constraints

**Before recommending ANY architecture, ask:**

| Constraint | Question | Impact |
|------------|----------|--------|
| **Deployment** | Where will the model run? | Cloud → Any, Edge → MobileNet/EfficientNet-Lite, Mobile → MobileNetV3 |
| **Latency** | Speed requirement? | Real-time (< 10ms) → MobileNet, Batch (> 100ms) → Any |
| **Model Size** | Parameter/memory budget? | < 10M params → MobileNet, < 50M → ResNet/EfficientNet, Any → large models OK |
| **Dataset Size** | Training images? | < 10k → small models, 10k-100k → medium, > 100k → large |
| **Accuracy** | Required accuracy? | Competitive → EfficientNet-B4+, Production → ResNet-50/EfficientNet-B2 |
| **Task Type** | Classification/detection/segmentation? | Detection → FPN-compatible, Segmentation → multi-scale |

**Critical:** Get answers to these BEFORE recommending an architecture.

### Step 2: Apply Decision Tree

```
START: What's your primary constraint?

┌─ DEPLOYMENT TARGET
│  ├─ Cloud / Server
│  │  └─ Dataset size?
│  │     ├─ Small (< 10k) → ResNet-18, EfficientNet-B0
│  │     ├─ Medium (10k-100k) → ResNet-50, EfficientNet-B2
│  │     └─ Large (> 100k) → ResNet-101, EfficientNet-B4, ViT
│  │
│  ├─ Edge Device (Jetson, Coral)
│  │  └─ Latency requirement?
│  │     ├─ Real-time (< 10ms) → MobileNetV3-Small, EfficientNet-Lite0
│  │     ├─ Medium (10-50ms) → MobileNetV3-Large, EfficientNet-Lite2
│  │     └─ Relaxed (> 50ms) → EfficientNet-B0, ResNet-18
│  │
│  └─ Mobile (iOS/Android)
│     └─ MobileNetV3-Small (fastest), MobileNetV3-Large (balanced)
│        + INT8 quantization (route to ml-production)
│
├─ ACCURACY PRIORITY (cloud deployment assumed)
│  ├─ Maximum accuracy → EfficientNet-B7, ResNet-152, ViT-Large
│  ├─ Balanced → EfficientNet-B2/B3, ResNet-50
│  └─ Fast training → ResNet-18, EfficientNet-B0
│
├─ EFFICIENCY PRIORITY
│  └─ Best accuracy per FLOP → EfficientNet family (B0-B7)
│     (EfficientNet dominates ResNet on the Pareto frontier)
│
└─ TASK TYPE
   ├─ Classification → Any CNN (use constraint-based selection above)
   ├─ Object Detection → ResNet + FPN, EfficientDet, YOLOv8 (CSPDarknet)
   └─ Segmentation → ResNet + U-Net, EfficientNet + DeepLabV3
```

## CNN Family Catalog

### 1. ResNet Family (2015) - The Standard Baseline

**Architecture:** Residual connections (skip connections) enable very deep networks

**Variants:**
- ResNet-18: 11M params, 1.8 GFLOPs, 69.8% ImageNet
- ResNet-34: 22M params, 3.7 GFLOPs, 73.3% ImageNet
- ResNet-50: 25M params, 4.1 GFLOPs, 76.1% ImageNet
- ResNet-101: 44M params, 7.8 GFLOPs, 77.4% ImageNet
- ResNet-152: 60M params, 11.6 GFLOPs, 78.3% ImageNet

**When to Use:**
- ✅ **Baseline choice**: Well-tested, widely supported
- ✅ **Transfer learning**: Excellent pre-trained weights available
- ✅ **Object detection**: Standard backbone for Faster R-CNN, Mask R-CNN
- ✅ **Interpretability**: Simple architecture, easy to understand

**When NOT to Use:**
- ❌ **Edge/mobile deployment**: Too large and slow
- ❌ **Efficiency priority**: EfficientNet beats ResNet on accuracy/FLOP
- ❌ **Small datasets (< 10k)**: Use ResNet-18, not ResNet-50+

**Key Insight:** Skip connections solve the vanishing gradient problem and enable depth

**Code Example:**
```python
import torchvision.models as models

# For cloud/server (good dataset)
model = models.resnet50(pretrained=True)

# For a small dataset or faster training
model = models.resnet18(pretrained=True)

# For maximum accuracy (cloud only)
model = models.resnet101(pretrained=True)
```

### 2. EfficientNet Family (2019) - Best Efficiency

**Architecture:** Compound scaling (depth + width + resolution) optimized via neural architecture search

**Variants:**
- EfficientNet-B0: 5M params, 0.4 GFLOPs, 77.3% ImageNet
- EfficientNet-B1: 8M params, 0.7 GFLOPs, 79.2% ImageNet
- EfficientNet-B2: 9M params, 1.0 GFLOPs, 80.3% ImageNet
- EfficientNet-B3: 12M params, 1.8 GFLOPs, 81.7% ImageNet
- EfficientNet-B4: 19M params, 4.2 GFLOPs, 82.9% ImageNet
- EfficientNet-B7: 66M params, 37 GFLOPs, 84.4% ImageNet

**When to Use:**
- ✅ **Efficiency matters**: Best accuracy per FLOP/parameter
- ✅ **Cloud deployment**: B2-B4 is the sweet spot for production
- ✅ **Limited compute**: B0 matches ResNet-50 accuracy at 10x fewer FLOPs
- ✅ **Scaling needs**: Want to scale the model up/down systematically

**When NOT to Use:**
- ❌ **Real-time mobile**: Use MobileNet (EfficientNet has more layers)
- ❌ **Very small datasets**: Can overfit despite efficiency
- ❌ **Simplicity needed**: More complex than ResNet

**Key Insight:** Compound scaling balances depth, width, and resolution optimally

**Efficiency Comparison:**
```
Same accuracy as ResNet-50 (76%):
- ResNet-50: 25M params, 4.1 GFLOPs
- EfficientNet-B0: 5M params, 0.4 GFLOPs (10x more efficient!)

Better accuracy (82.9%):
- ResNet-152: 60M params, 11.6 GFLOPs → 78.3% ImageNet
- EfficientNet-B4: 19M params, 4.2 GFLOPs → 82.9% ImageNet
  (Better accuracy with 3x fewer params and 3x less compute)
```

**Code Example:**
```python
import timm  # PyTorch Image Models library

# Balanced choice (production)
model = timm.create_model('efficientnet_b2', pretrained=True)

# Efficiency priority (edge)
model = timm.create_model('efficientnet_b0', pretrained=True)

# Accuracy priority (research)
model = timm.create_model('efficientnet_b4', pretrained=True)
```

### 3. MobileNet Family (2017-2019) - Mobile Optimized

**Architecture:** Depthwise separable convolutions (drastically reduce compute)

**Variants:**
- MobileNetV1: 4.2M params, 0.6 GFLOPs, 70.6% ImageNet
- MobileNetV2: 3.5M params, 0.3 GFLOPs, 72.0% ImageNet
- MobileNetV3-Small: 2.5M params, 0.06 GFLOPs, 67.4% ImageNet
- MobileNetV3-Large: 5.4M params, 0.2 GFLOPs, 75.2% ImageNet

**When to Use:**
- ✅ **Mobile deployment**: iOS/Android apps
- ✅ **Edge devices**: Raspberry Pi, Jetson Nano
- ✅ **Real-time inference**: < 100ms latency
- ✅ **Extreme efficiency**: < 10M parameter budget

**When NOT to Use:**
- ❌ **Cloud deployment with no constraints**: EfficientNet or ResNet gives better accuracy
- ❌ **Accuracy priority**: Sacrifices accuracy for speed
- ❌ **Large datasets with compute**: Can afford better models

**Key Insight:** Depthwise separable convolutions = standard conv split into depthwise + pointwise (9x fewer operations)

**Deployment Performance:**
```
Raspberry Pi 4 inference (224×224 image):
- ResNet-50: ~2000ms (unusable)
- ResNet-18: ~600ms (slow)
- MobileNetV2: ~150ms (acceptable)
- MobileNetV3-Large: ~80ms (good)
- MobileNetV3-Small: ~40ms (fast)

With INT8 quantization:
- MobileNetV3-Large: ~30ms (production-ready)
- MobileNetV3-Small: ~15ms (real-time)
```

**Code Example:**
```python
import torchvision.models as models

# For mobile deployment
model = models.mobilenet_v3_large(pretrained=True)

# For ultra-low latency (sacrifice accuracy)
model = models.mobilenet_v3_small(pretrained=True)

# Quantization for mobile (route to the ml-production skill for details)
# Achieves 2-4x speedup with minimal accuracy loss
```
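
To make the depthwise + pointwise factorization concrete, here is a minimal sketch (ours, not the torchvision implementation; norm and activation layers omitted):

```python
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3,
                                   padding=1, groups=in_ch)       # one 3x3 filter per channel
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)  # 1x1 channel mixing

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# Parameter count vs a standard 3x3 conv (in_ch = out_ch = 256):
#   standard:  256*256*9       ≈ 590K
#   separable: 256*9 + 256*256 ≈ 68K (≈9x fewer)
```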

### 4. Inception Family (2014-2016) - Multi-Scale Features

**Architecture:** Multi-scale convolutions in parallel (inception modules)

**Variants:**
- InceptionV3: 24M params, 5.7 GFLOPs, 77.5% ImageNet
- InceptionV4: 42M params, 12.3 GFLOPs, 80.0% ImageNet
- Inception-ResNet: hybrid with residual connections

**When to Use:**
- ✅ **Multi-scale features**: Objects at different sizes
- ✅ **Object detection**: Good backbone for detection
- ✅ **Historical interest**: Understanding multi-scale approaches

**When NOT to Use:**
- ❌ **Simplicity needed**: Complex architecture, hard to modify
- ❌ **Efficiency priority**: EfficientNet is better
- ❌ **Modern projects**: Largely superseded by ResNet/EfficientNet

**Key Insight:** Parallel multi-scale convolutions (1×1, 3×3, 5×5) capture different receptive fields

**Status:** Mostly historical - ResNet and EfficientNet have replaced Inception in practice

### 5. DenseNet Family (2017) - Dense Connections

**Architecture:** Every layer connects to every other layer (dense connections)

**Variants:**
- DenseNet-121: 8M params, 2.9 GFLOPs, 74.4% ImageNet
- DenseNet-169: 14M params, 3.4 GFLOPs, 75.6% ImageNet
- DenseNet-201: 20M params, 4.3 GFLOPs, 76.9% ImageNet

**When to Use:**
- ✅ **Parameter efficiency**: Good accuracy with few parameters
- ✅ **Feature reuse**: Dense connections enable feature reuse
- ✅ **Small datasets**: Better gradient flow helps with limited data

**When NOT to Use:**
- ❌ **Inference speed priority**: Dense connections are slow (high memory bandwidth)
- ❌ **Training speed**: Slower to train than ResNet
- ❌ **Production deployment**: Less mature ecosystem than ResNet

**Key Insight:** Dense connections improve gradient flow and enable feature reuse, but slow inference

**Status:** Theoretically elegant, but ResNet/EfficientNet are more practical

### 6. VGG Family (2014) - Historical Baseline

**Architecture:** Very deep (16-19 layers), small 3×3 convolutions, many parameters

**Variants:**
- VGG-16: 138M params, 15.5 GFLOPs, 71.5% ImageNet
- VGG-19: 144M params, 19.6 GFLOPs, 71.1% ImageNet

**When to Use:**
- ❌ **DON'T use VGG for new projects**
- Historical understanding only

**Why NOT to Use:**
- Massive parameter count (138M vs ResNet-50's 25M)
- Poor accuracy for its size
- Superseded by ResNet (2015)

**Key Insight:** Proved that depth matters, but skip connections (ResNet) are better

**Status:** **Obsolete** - use ResNet or EfficientNet instead

## Practical Selection Guide

### Scenario 1: Cloud/Server Deployment

**Goal:** Best accuracy, no compute constraints

**Recommendation:**
```
Small dataset (< 10k images):
  → EfficientNet-B0 or ResNet-18
    (Avoid overfitting with a smaller model)

Medium dataset (10k-100k images):
  → EfficientNet-B2 or ResNet-50
    (Balanced accuracy and efficiency)

Large dataset (> 100k images):
  → EfficientNet-B4 or ResNet-101
    (Can afford a larger model)

Maximum accuracy (research):
  → EfficientNet-B7 or Vision Transformer
    (If dataset > 1M images and compute unlimited)
```

### Scenario 2: Edge Deployment (Jetson, Coral TPU)

**Goal:** Optimize for edge hardware latency

**Recommendation:**
```
Real-time requirement (< 10ms):
  → MobileNetV3-Small or EfficientNet-Lite0
    + INT8 quantization

Medium latency (10-50ms):
  → MobileNetV3-Large or EfficientNet-Lite2

Relaxed latency (> 50ms):
  → EfficientNet-B0 or ResNet-18
```

**Critical:** Profile on the actual edge hardware. Quantization is mandatory (route to ml-production).

### Scenario 3: Mobile Deployment (iOS/Android)

**Goal:** On-device inference, minimal battery drain

**Recommendation:**
```
All mobile deployments:
  → MobileNetV3-Large (balanced)
  → MobileNetV3-Small (fastest, less accurate)

Always use:
- INT8 quantization (2-4x speedup)
- CoreML (iOS) or TFLite (Android) optimization
- Benchmark on the target device before deploying
```

**Expected latency (iPhone 12, INT8 quantized):**
- MobileNetV3-Small: 5-10ms
- MobileNetV3-Large: 15-25ms

### Scenario 4: Object Detection

**Goal:** Select a backbone for a detection framework

**Recommendation:**
```
Faster R-CNN:
  → ResNet-50 + FPN (standard)
  → ResNet-101 + FPN (more accuracy)

YOLOv8:
  → CSPDarknet (built-in, optimized)

EfficientDet:
  → EfficientNet + BiFPN (best efficiency)

Custom detection:
  → ResNet or EfficientNet as backbone
  → Add a Feature Pyramid Network (FPN) for multi-scale features
```

**Note:** Detection adds significant compute on top of the backbone. Choose an efficient backbone.

### Scenario 5: Semantic Segmentation

**Goal:** Dense pixel-wise prediction

**Recommendation:**
```
U-Net style:
  → ResNet-18/34 as encoder (fast)
  → EfficientNet-B0 as encoder (efficient)

DeepLabV3:
  → ResNet-50 (standard)
  → MobileNetV3 (mobile deployment)

Key: Segmentation requires multi-scale features
  → Ensure the backbone has skip connections or an FPN
```

## Trade-Off Analysis

### Accuracy vs Efficiency (Pareto Frontier)

**ImageNet Top-1 Accuracy vs FLOPs:**

```
Efficiency Winners (best accuracy per FLOP):
1. EfficientNet-B0: 77.3% @ 0.4 GFLOPs (best efficiency)
2. EfficientNet-B2: 80.3% @ 1.0 GFLOPs
3. EfficientNet-B4: 82.9% @ 4.2 GFLOPs

Accuracy Winners (best absolute accuracy):
1. ViT-Large: 85.2% @ 190 GFLOPs (requires a huge dataset)
2. EfficientNet-B7: 84.4% @ 37 GFLOPs
3. ResNet-152: 78.3% @ 11.6 GFLOPs (dominated by EfficientNet)

Speed Winners (lowest latency):
1. MobileNetV3-Small: 67.4% @ 0.06 GFLOPs (~50ms on mobile)
2. MobileNetV3-Large: 75.2% @ 0.2 GFLOPs (~100ms on mobile)
3. EfficientNet-Lite0: 75.0% @ 0.4 GFLOPs
```

**Key Takeaway:** EfficientNet dominates ResNet on the Pareto frontier (better accuracy at the same compute).

### Parameters vs Accuracy

**For the same ~75% ImageNet accuracy:**
```
VGG-16:            138M params (❌ terrible efficiency)
ResNet-50:         25M params
EfficientNet-B0:   5M params (✅ 5x fewer parameters!)
MobileNetV3-Large: 5M params (fast inference)
```

**Conclusion:** Modern architectures (EfficientNet, MobileNet) achieve the same accuracy with far fewer parameters.

## Common Pitfalls

### Pitfall 1: Defaulting to ResNet-50
**Symptom:** Using ResNet-50 without considering alternatives

**Why it's wrong:** EfficientNet-B0 matches ResNet-50 accuracy with 10x less compute

**Fix:** Consider the EfficientNet family first (better efficiency)

### Pitfall 2: Choosing a Large Model for a Small Dataset
**Symptom:** Using ResNet-101 with < 10k images

**Why it's wrong:** The model will overfit (too many parameters for the data)

**Fix:**
- < 10k images → ResNet-18 or EfficientNet-B0
- 10k-100k → ResNet-50 or EfficientNet-B2
- > 100k → Can use larger models

### Pitfall 3: Using a Desktop Model on Mobile
**Symptom:** Trying to run ResNet-50 on a mobile device

**Why it's wrong:** A ~2000ms inference time is unusable

**Fix:** Use MobileNetV3 + quantization for mobile (15-30ms)

### Pitfall 4: Ignoring Task Type
**Symptom:** Using a standard CNN for object detection without an FPN

**Why it's wrong:** Detection needs multi-scale features

**Fix:** Use detection-specific frameworks (YOLOv8, Faster R-CNN) with an appropriate backbone

### Pitfall 5: Believing "Bigger = Better"
**Symptom:** Choosing ResNet-152 over ResNet-50 without justification

**Why it's wrong:** Diminishing returns - roughly 3x the compute for ~2% more accuracy, and it will overfit on small data

**Fix:** Match model capacity to dataset size; consider efficiency
|
||||||
|
|
||||||
|
|
||||||
|
## Evolution and Historical Context

**Why CNNs evolved the way they did:**

```
2012: AlexNet
  → Proved deep learning works for vision
  → 8 layers, 60M params

2014: VGG
  → Deeper is better (16-19 layers)
  → But: 138M params (too many)

2014: Inception/GoogLeNet
  → Multi-scale convolutions
  → More efficient than VGG

2015: ResNet ★
  → Skip connections enable very deep networks (152 layers)
  → Solved vanishing gradient problem
  → Became standard baseline

2017: MobileNet
  → Mobile deployment needs
  → Depthwise separable convolutions (9x fewer ops)

2017: DenseNet
  → Dense connections for feature reuse
  → Parameter efficient but slow inference

2019: EfficientNet ★
  → Compound scaling (depth + width + resolution)
  → Neural architecture search
  → Dominates Pareto frontier (best accuracy per FLOP)
  → New standard for efficiency

2020: Vision Transformer
  → Attention-based (no convolutions)
  → Requires very large datasets (> 1M images)
  → For research/large-scale applications
```

**Current Recommendations (2025):**
- Cloud: **EfficientNet** (best efficiency) or ResNet (simplicity)
- Edge: **EfficientNet-Lite** or MobileNetV3
- Mobile: **MobileNetV3** + quantization
- Detection: **EfficientDet** or YOLOv8
- Baseline: **ResNet** (simple, well-tested)
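Loading these defaults is one line each with torchvision; a minimal sketch (the `weights="DEFAULT"` string assumes torchvision 0.13 or newer):

```python
from torchvision.models import efficientnet_b2, mobilenet_v3_large, resnet50

cloud_model = efficientnet_b2(weights="DEFAULT")    # cloud default: best efficiency
edge_model = mobilenet_v3_large(weights="DEFAULT")  # edge/mobile default
baseline = resnet50(weights="DEFAULT")              # simple, well-tested baseline
```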
## Decision Checklist

Before choosing a CNN, answer these:

```
☐ Deployment target? (cloud/edge/mobile)
☐ Latency requirement? (< 10ms / 10-100ms / > 100ms)
☐ Model size budget? (< 10M / 10-50M / unlimited params)
☐ Dataset size? (< 10k / 10k-100k / > 100k images)
☐ Accuracy priority? (maximum / production / fast iteration)
☐ Task type? (classification / detection / segmentation)
☐ Efficiency matters? (yes → EfficientNet, no → flexibility)

Based on answers:
  → Mobile → MobileNetV3
  → Edge → EfficientNet-Lite or MobileNetV3
  → Cloud + efficiency → EfficientNet
  → Cloud + simplicity → ResNet
  → Maximum accuracy → EfficientNet-B7 or ViT
  → Small dataset → Small models (ResNet-18, EfficientNet-B0)
```

## Integration with Other Skills

**After selecting CNN architecture:**

**Training the model:**
→ `yzmir/training-optimization/using-training-optimization`
- Optimizer selection (Adam, SGD, AdamW)
- Learning rate schedules
- Data augmentation

**Implementing in PyTorch:**
→ `yzmir/pytorch-engineering/using-pytorch-engineering`
- Custom modifications to pre-trained models
- Multi-GPU training
- Performance optimization

**Deploying to production:**
→ `yzmir/ml-production/using-ml-production`
- Quantization (INT8, FP16)
- Model serving (TorchServe, ONNX)
- Optimization for edge/mobile (TFLite, CoreML)

**If architecture is unstable (very deep):**
→ `yzmir/neural-architectures/normalization-techniques`
- Normalization layers (BatchNorm, LayerNorm)
- Skip connections
- Initialization strategies

## Summary

**CNN Selection in One Table:**

| Scenario | Recommendation | Why |
|----------|----------------|-----|
| Cloud, balanced | EfficientNet-B2 | Best efficiency, 80% accuracy |
| Cloud, max accuracy | EfficientNet-B4 | 83% accuracy, reasonable compute |
| Cloud, simple baseline | ResNet-50 | Well-tested, widely used |
| Edge device | MobileNetV3-Large | Optimized for edge, 75% accuracy |
| Mobile app | MobileNetV3-Small + quantization | < 20ms inference |
| Small dataset (< 10k) | ResNet-18 or EfficientNet-B0 | Avoid overfitting |
| Object detection | ResNet-50 + FPN, EfficientDet | Multi-scale features |
| Segmentation | ResNet + U-Net, DeepLabV3 | Dense prediction |

**Key Principles:**
1. **Match model capacity to dataset size** (small data → small model)
2. **EfficientNet dominates ResNet on efficiency** (same accuracy, less compute)
3. **Mobile needs mobile-specific architectures** (MobileNet, quantization)
4. **Task type matters** (detection/segmentation need multi-scale features)
5. **Bigger ≠ always better** (diminishing returns, overfitting risk)

**When in doubt:** Start with **EfficientNet-B2** (cloud) or **MobileNetV3-Large** (edge/mobile).

**END OF SKILL**
811
skills/using-neural-architectures/generative-model-families.md
Normal file
@@ -0,0 +1,811 @@
# Generative Model Families

## When to Use This Skill

Use this skill when you need to:
- ✅ Select generative model for image/audio/video generation
- ✅ Understand VAE vs GAN vs Diffusion trade-offs
- ✅ Decide between training from scratch vs fine-tuning
- ✅ Address mode collapse in GANs
- ✅ Choose between quality, speed, and training stability
- ✅ Understand modern landscape (Stable Diffusion, StyleGAN, etc.)

**Do NOT use this skill for:**
- ❌ Text generation (use `llm-specialist` pack)
- ❌ Architecture implementation details (use model-specific docs)
- ❌ High-level architecture selection (use `using-neural-architectures`)

## Core Principle

**Generative models have fundamental trade-offs:**
- **Quality vs Stability**: GANs sharp but unstable, VAEs blurry but stable
- **Quality vs Speed**: Diffusion high-quality but slow, GANs fast
- **Explicitness vs Flexibility**: Autoregressive/Flow have likelihood, GANs don't

**Modern default (2025):** Diffusion models (best quality + stability)

## Part 1: Model Family Overview

### The Five Families

**1. VAE (Variational Autoencoder)**
- **Approach**: Learn latent space with encoder-decoder
- **Quality**: Blurry (6/10)
- **Training**: Very stable
- **Use**: Latent space exploration, NOT high-quality generation

**2. GAN (Generative Adversarial Network)**
- **Approach**: Adversarial game (generator vs discriminator)
- **Quality**: Sharp (9/10)
- **Training**: Unstable (adversarial dynamics)
- **Use**: High-quality generation, fast inference

**3. Diffusion Models**
- **Approach**: Iterative denoising
- **Quality**: Very sharp (9.5/10)
- **Training**: Stable
- **Use**: Modern default for high-quality generation

**4. Autoregressive Models**
- **Approach**: Sequential generation (pixel-by-pixel, token-by-token)
- **Quality**: Good (7-8/10)
- **Training**: Stable
- **Use**: Explicit likelihood, sequential data

**5. Flow Models**
- **Approach**: Invertible transformations
- **Quality**: Good (7-8/10)
- **Training**: Stable
- **Use**: Exact likelihood, invertibility needed

### Quick Comparison

| Model | Quality | Training Stability | Inference Speed | Mode Collapse | Likelihood |
|-------|---------|-------------------|----------------|---------------|------------|
| VAE | 6/10 (blurry) | 10/10 | Fast | No | Approximate |
| GAN | 9/10 | 3/10 | Fast | Yes | No |
| Diffusion | 9.5/10 | 9/10 | Slow | No | Approximate |
| Autoregressive | 7-8/10 | 9/10 | Very slow | No | Exact |
| Flow | 7-8/10 | 8/10 | Fast (both ways) | No | Exact |
## Part 2: VAE (Variational Autoencoder)

### Architecture

**Components:**
1. **Encoder**: x → z (image to latent)
2. **Latent space**: z ~ N(μ, σ²)
3. **Decoder**: z → x' (latent to reconstruction)

**Loss function:**
```python
# ELBO (Evidence Lower Bound), written as a loss to minimize
loss = reconstruction_loss + KL_divergence

# Reconstruction: How well decoder reconstructs input
reconstruction_loss = MSE(x, x_reconstructed)

# KL: How close latent is to standard normal
KL_divergence = KL(q(z|x) || p(z))
```

### Why VAE is Blurry

**Problem**: MSE loss encourages pixel-wise averaging

**Example:**
- Dataset: Faces with both smiles and no smiles
- VAE learns: "Average face has half-smile blur"
- Result: Blurry, hedges between modes

**Mathematical reason:**
- MSE minimization = mean prediction
- Mean of sharp images = blurry image

### When to Use VAE

✅ **Use VAE for:**
- Latent space exploration (interpolation, arithmetic)
- Anomaly detection (reconstruction error)
- Disentangled representations (β-VAE)
- Compression (lossy, with latent codes)

❌ **DON'T use VAE for:**
- High-quality image generation (use GAN or Diffusion!)
- Sharp, realistic outputs

### Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),    # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),   # 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # 16x16 -> 8x8
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(128 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(128 * 8 * 8, latent_dim)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 128 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = μ + σ * ε
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        # Sample latent
        z = self.reparameterize(mu, logvar)

        # Decode
        h = self.fc_decode(z)
        h = h.view(-1, 128, 8, 8)
        x_recon = self.decoder(h)

        return x_recon, mu, logvar

    def loss_function(self, x, x_recon, mu, logvar):
        # Reconstruction loss
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')

        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + kl_loss
```
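Usage sketch for the class above (assumes the imports from the implementation have run): new images come from decoding latents drawn from the prior. The 64×64 output size follows the layer comments above.

```python
vae = VAE(latent_dim=128)
vae.eval()

# Sample from the prior p(z) = N(0, I) and decode.
z = torch.randn(16, 128)
h = vae.fc_decode(z).view(-1, 128, 8, 8)
samples = vae.decoder(h)  # (16, 3, 64, 64), values in [0, 1] from the Sigmoid
```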
## Part 3: GAN (Generative Adversarial Network)

### Architecture

**Components:**
1. **Generator**: z → x (noise to image)
2. **Discriminator**: x → [0, 1] (image to real/fake probability)

**Adversarial Training:**
```python
# Discriminator loss: Classify real as real, fake as fake
D_loss = -log(D(x_real)) - log(1 - D(G(z)))

# Generator loss: Fool discriminator
G_loss = -log(D(G(z)))

# Minimax game:
min_G max_D V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]
```

### Training Instability

**Problem**: Adversarial dynamics are unstable

**Common issues:**
1. **Mode collapse**: Generator produces limited variety
2. **Non-convergence**: Oscillation, never settles
3. **Vanishing gradients**: Discriminator too strong, generator can't learn
4. **Hyperparameter sensitivity**: Learning rates critical

**Solutions:**
- Spectral normalization (StyleGAN2; see the sketch below)
- Progressive growing (start low-res, increase)
- Minibatch discrimination (penalize lack of diversity)
- Wasserstein loss (WGAN, more stable)
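Spectral normalization is one line per layer in PyTorch. A minimal sketch of a spectrally-normalized discriminator (layer sizes are illustrative, assuming 64×64 inputs):

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Each wrapped layer has its weight rescaled so its largest singular value
# stays near 1, which bounds the discriminator's Lipschitz constant.
disc = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, 4, 2, 1)),    # 64x64 -> 32x32
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),  # 32x32 -> 16x16
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    spectral_norm(nn.Linear(128 * 16 * 16, 1)),
)
```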
### Mode Collapse

**What is it?**
- Generator produces subset of distribution
- Example: Face GAN only generates 10 face types

**Why it happens:**
- Generator exploits discriminator weaknesses
- Finds "easy" samples that fool discriminator
- Forgets other modes

**Detection (pseudocode):**
```python
# Check diversity: Generate many samples, then compare them pairwise.
# (generator.generate and compute_pairwise_distance are placeholders.)
samples = generator.generate(n=1000)
diversity = compute_pairwise_distance(samples)
if diversity < threshold:
    print("Mode collapse detected!")
```

**Solutions:**
- Minibatch discrimination
- Unrolled GANs (slow but helps)
- Switch to diffusion (no mode collapse by design!)

### Modern GANs

**StyleGAN2 (2020):**
- State-of-the-art for faces
- Style-based generator
- Spectral normalization for stability
- Resolution: 1024×1024

**StyleGAN3 (2021):**
- Alias-free architecture
- Better animation/video

**When to use GAN:**
✅ Fast inference needed (50ms per image)
✅ Pretrained model available (StyleGAN2)
✅ Can tolerate training difficulty

❌ Training instability unacceptable
❌ Mode collapse problematic
❌ Starting from scratch (use diffusion instead)

### Implementation (Basic GAN)

```python
class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),           # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),            # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, img_channels, 4, 2, 1),  # 32x32 -> 64x64
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 32, 4, 2, 1),  # 64x64 -> 32x32
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1),            # 32x32 -> 16x16
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),           # 16x16 -> 8x8
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Training loop (assumes optimizer_D, optimizer_G, latent_dim are defined)
for real_images in dataloader:
    noise = torch.randn(real_images.size(0), latent_dim)

    # Train discriminator
    optimizer_D.zero_grad()
    fake_images = generator(noise)
    D_real = discriminator(real_images)
    D_fake = discriminator(fake_images.detach())  # detach: don't update G here

    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))
    D_loss.backward()
    optimizer_D.step()

    # Train generator
    optimizer_G.zero_grad()
    D_fake = discriminator(fake_images)
    G_loss = -torch.mean(torch.log(D_fake))
    G_loss.backward()
    optimizer_G.step()
```
## Part 4: Diffusion Models (Modern Default)

### Architecture

**Concept**: Learn to reverse a diffusion (noising) process

**Forward process** (fixed):
```python
# Gradually add noise to image
x_0 (original) → x_1 → x_2 → ... → x_T (pure noise)

# At each step:
x_t = √(1 - β_t) * x_{t-1} + √β_t * ε
where ε ~ N(0, I), β_t = noise schedule
```

**Reverse process** (learned):
```python
# Model learns to denoise
x_T (noise) → x_{T-1} → ... → x_1 → x_0 (image)

# Model predicts: ε_θ(x_t, t)
# Then: x_{t-1} = (x_t - √β_t * ε_θ(x_t, t)) / √(1 - β_t)
```

**Training:**
```python
# Simple loss: Predict the noise
loss = MSE(ε, ε_θ(x_t, t))

# x_t = noisy image at step t
# ε = actual noise added
# ε_θ(x_t, t) = model's noise prediction
```

### Why Diffusion is Excellent

**Advantages:**
1. **High quality**: State-of-the-art (better than GAN)
2. **Stable training**: Standard MSE loss (no adversarial dynamics)
3. **No mode collapse**: By design, covers full distribution
4. **Controllable**: Easy to add conditioning (text, class, etc.)

**Disadvantages:**
1. **Slow inference**: 50-1000 denoising steps (vs GAN's 1 step)
2. **Compute intensive**: T forward passes (T = 50-1000)

**Speed comparison:**
```
GAN:                1 forward pass      = 50ms
Diffusion (T=50):   50 forward passes   = 2.5 seconds
Diffusion (T=1000): 1000 forward passes = 50 seconds
```

**Speedup techniques:**
- DDIM (fewer steps, 10-50 instead of 1000)
- DPM-Solver (fast sampler)
- Latent diffusion (Stable Diffusion, denoise in latent space)
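As a concrete illustration of the DDIM speedup, a minimal sketch with Hugging Face `diffusers` (model id and scheduler-swap pattern reflect common diffusers usage; verify names against your installed version):

```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap in a DDIM scheduler so ~25-50 steps suffice instead of ~1000.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

image = pipe("a watercolor fox in a forest", num_inference_steps=25).images[0]
image.save("fox.png")
```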
### Modern Diffusion Models

**Stable Diffusion (2022+):**
- Latent diffusion (denoise in VAE latent space)
- Text conditioning (CLIP text encoder)
- Pretrained on billions of images
- Fine-tunable

**DALL-E 2 (2022):**
- Prior network (text → image embedding)
- Diffusion decoder (embedding → image)

**Imagen (2022, Google):**
- Text conditioning with T5 encoder
- Cascaded diffusion (64×64 → 256×256 → 1024×1024)

**When to use Diffusion:**
✅ High-quality generation (best quality)
✅ Stable training (standard loss)
✅ Diversity needed (no mode collapse)
✅ Conditioning (text-to-image, class-conditional)

❌ Need fast inference (< 1 second)
❌ Real-time generation

### Implementation (DDPM)

```python
class DiffusionModel(nn.Module):
    def __init__(self, img_channels=3):
        super().__init__()
        # U-Net architecture (UNet is assumed defined elsewhere)
        self.model = UNet(
            in_channels=img_channels,
            out_channels=img_channels,
            time_embedding_dim=256
        )

    def forward(self, x_t, t):
        # Predict noise ε at timestep t
        return self.model(x_t, t)

# Training (alpha_schedule holds cumulative products ᾱ_t = ∏_{s≤t} (1 - β_s))
def train_step(model, x_0):
    # Sample random timestep
    t = torch.randint(0, T, (batch_size,))

    # Sample noise
    ε = torch.randn_like(x_0)

    # Create noisy image x_t
    ᾱ_t = alpha_schedule[t]
    x_t = torch.sqrt(ᾱ_t) * x_0 + torch.sqrt(1 - ᾱ_t) * ε

    # Predict noise
    ε_pred = model(x_t, t)

    # Loss: MSE between actual and predicted noise
    loss = F.mse_loss(ε_pred, ε)
    return loss

# Sampling (generation); alphas, betas, alpha_schedule assumed precomputed
@torch.no_grad()
def sample(model, shape):
    # Start from pure noise
    x_t = torch.randn(shape)

    # Iteratively denoise
    for t in reversed(range(T)):
        # Predict noise
        ε_pred = model(x_t, t)

        # Denoise one step (α_t = 1 - β_t, ᾱ_t = cumulative product)
        α_t = alphas[t]
        ᾱ_t = alpha_schedule[t]
        x_t = (x_t - (1 - α_t) / torch.sqrt(1 - ᾱ_t) * ε_pred) / torch.sqrt(α_t)

        # Add noise (except last step)
        if t > 0:
            x_t += torch.sqrt(betas[t]) * torch.randn_like(x_t)

    return x_t  # x_0 (generated image)
```
## Part 5: Autoregressive Models

### Concept

**Idea**: Model probability as product of conditionals
```
p(x) = p(x_1) * p(x_2|x_1) * p(x_3|x_1,x_2) * ... * p(x_n|x_1,...,x_{n-1})
```

**For images**: Generate pixel-by-pixel (or patch-by-patch)

**Architectures:**
- **PixelCNN**: Convolutional with masked kernels
- **PixelCNN++**: Improved with mixture of logistics
- **VQ-VAE + PixelCNN**: Two-stage (learn discrete codes, model codes)
- **ImageGPT**: GPT-style Transformer for images

### Advantages

✅ **Explicit likelihood**: Can compute p(x) exactly
✅ **Stable training**: Standard cross-entropy loss
✅ **Theoretical guarantees**: Proper probability model

### Disadvantages

❌ **Very slow generation**: Sequential (can't parallelize)
❌ **Limited quality**: Worse than GAN/Diffusion for high-res
❌ **Resolution scaling**: Impractical for 1024×1024 (1M pixels!)

**Speed comparison:**
```
GAN:      Generate 1024×1024 in 50ms (parallel)
PixelCNN: Generate 32×32 in 5 seconds (sequential!)
ImageGPT: Generate 256×256 in 30 seconds

For 1024×1024: 1M pixels × 5ms/pixel = 83 minutes!
```

### When to Use

✅ **Use autoregressive for:**
- Explicit likelihood needed (compression, evaluation)
- Small images (32×32, 64×64)
- Two-stage models (VQ-VAE + Transformer)

❌ **Don't use for:**
- High-resolution images (too slow)
- Real-time generation
- Quality-critical applications (use diffusion)

### Modern Usage

**Two-stage approach (DALL-E, VQ-GAN):**
1. **Stage 1**: VQ-VAE learns discrete codes
   - Image → 32×32 grid of codes (instead of 1M pixels)
2. **Stage 2**: Autoregressive model (Transformer) on codes
   - Much faster (32×32 = 1024 codes, not 1M pixels)
## Part 6: Flow Models

### Concept

**Idea**: Invertible transformations
```
z ~ N(0, I) ←→ x ~ p_data

f:   z → x (forward)
f⁻¹: x → z (inverse)
```

**Requirement**: f must be invertible and differentiable

**Advantage**: Exact likelihood via change-of-variables
```
log p(x) = log p(z) + log |det(∂f⁻¹/∂x)|
```

### Architectures

**RealNVP (2017):**
- Coupling layers (affine transformations)
- Invertible by design

**Glow (2018, OpenAI):**
- Actnorm, invertible 1×1 convolutions
- Multi-scale architecture

**When to use Flow:**
✅ Exact likelihood needed (better than VAE)
✅ Invertibility needed (both z→x and x→z)
✅ Stable training (standard loss)

❌ Architecture constraints (must be invertible)
❌ Quality not as good as GAN/Diffusion

### Modern Status

**Mostly superseded by Diffusion:**
- Diffusion has better quality
- Diffusion more flexible (no invertibility constraint)
- Flow models still used in specialized applications
## Part 7: Decision Framework

### By Primary Goal

```
Goal: High-quality images
  → Diffusion (modern default)
  → OR GAN if pretrained available

Goal: Fast inference
  → GAN (50ms per image)
  → Avoid Diffusion (too slow for real-time)

Goal: Training stability
  → Diffusion or VAE (standard loss)
  → Avoid GAN (adversarial training hard)

Goal: Latent space exploration
  → VAE (smooth interpolation)
  → Avoid GAN (no encoder)

Goal: Explicit likelihood
  → Autoregressive or Flow
  → For evaluation, compression

Goal: Diversity (no mode collapse)
  → Diffusion (by design)
  → OR VAE (stable)
  → Avoid GAN (mode collapse common)
```

### By Data Type

```
Images (high-quality):
  → Diffusion (Stable Diffusion)
  → OR GAN (StyleGAN2)

Images (small, 32×32):
  → Any model works
  → Try VAE first (simplest)

Audio waveforms:
  → WaveGAN
  → OR Diffusion (WaveGrad)

Video:
  → Video Diffusion (limited)
  → OR GAN (StyleGAN-V)

Text:
  → Autoregressive (GPT)
  → NOT VAE/GAN/Diffusion (discrete tokens)
```

### By Training Budget

```
Large budget (millions $, pretrain from scratch):
  → Diffusion (Stable Diffusion scale)
  → Billions of images, weeks on cluster

Medium budget (thousands $, train from scratch):
  → GAN or Diffusion
  → 10k-1M images, days on GPU

Small budget (hundreds $, fine-tune):
  → Fine-tune Stable Diffusion (LoRA)
  → 1k-10k images, hours on consumer GPU

Tiny budget (research, small scale):
  → VAE (simplest, most stable)
  → Few thousand images, CPU possible
```

### Modern Recommendations (2025)

**For new projects:**
1. **Default: Diffusion**
   - Fine-tune Stable Diffusion or train from scratch
   - Best quality + stability

2. **If need speed: GAN**
   - Use pretrained StyleGAN2 if available
   - Or train GAN (if can tolerate instability)

3. **If need latent space: VAE**
   - For interpolation, not generation quality

**AVOID:**
- Training GAN from scratch (unless necessary)
- Using VAE for high-quality generation
- Autoregressive for high-res images
## Part 8: Training from Scratch vs Fine-Tuning

### Stable Diffusion Example

**Pretraining (what Stability AI did):**
- Dataset: LAION-5B (5 billion images)
- Compute: 150,000 A100 GPU hours
- Cost: ~$600,000
- Time: Weeks on massive cluster
- **DON'T DO THIS!**

**Fine-tuning (what users do):**
- Dataset: 10k-100k domain images
- Compute: 100-1000 GPU hours
- Cost: $100-1,000
- Time: Days on single A100
- **DO THIS!**

**LoRA (Low-Rank Adaptation):**
- Efficient fine-tuning (fewer parameters)
- Dataset: 1k-5k images
- Compute: 10-100 GPU hours
- Cost: $10-100
- Time: Hours on consumer GPU (RTX 3090)
- **Best for small budgets!** (see the sketch below)
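Once a LoRA adapter has been trained, applying it is a few lines with `diffusers` (the checkpoint path is hypothetical; `load_lora_weights` exists in recent diffusers releases, but verify against your version):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Attach a trained LoRA adapter on top of the frozen base weights.
pipe.load_lora_weights("./my-domain-lora")  # hypothetical local checkpoint

image = pipe("a product shot in my brand style", num_inference_steps=30).images[0]
```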
### Decision

```
Have pretrained model in your domain:
  → Fine-tune (don't retrain!)

No pretrained model:
  → Train from scratch (small model)
  → OR find closest pretrained and fine-tune

Budget < $1000:
  → LoRA fine-tuning
  → OR train small model (64×64)

Budget < $100:
  → LoRA with free Colab
  → OR VAE from scratch (cheap)
```

## Part 9: Common Mistakes

### Mistake 1: VAE for High-Quality Generation

**Symptom:** Blurry outputs
**Fix:** Use GAN or Diffusion for quality
**VAE is for:** Latent space, not generation

### Mistake 2: Ignoring Mode Collapse

**Symptom:** GAN generates same images
**Fix:** Spectral norm, minibatch discrimination
**Better:** Switch to Diffusion (no mode collapse)

### Mistake 3: Training Stable Diffusion from Scratch

**Symptom:** Burning money, poor results
**Fix:** Fine-tune pretrained model
**Reality:** Pretraining costs $600k+

### Mistake 4: Slow Inference with Diffusion

**Symptom:** 50 seconds per image
**Fix:** Use DDIM (fewer steps, 10-50)
**OR:** Use GAN if speed critical

### Mistake 5: Wrong Loss for GAN

**Symptom:** Training diverges
**Fix:** Use Wasserstein loss (WGAN)
**OR:** Spectral normalization
**Better:** Switch to Diffusion (standard loss)

## Summary: Quick Reference

### Model Selection

```
High quality + stable training:
  → Diffusion (modern default)

Fast inference required:
  → GAN (if pretrained) or trained GAN

Latent space exploration:
  → VAE

Explicit likelihood:
  → Autoregressive or Flow

Small images (< 64×64):
  → Any model (start with VAE)

Large images (> 256×256):
  → Diffusion or GAN (avoid autoregressive)
```

### Quality Ranking

```
1. Diffusion (9.5/10)
2. GAN (9/10)
3. Autoregressive (7-8/10)
4. Flow (7-8/10)
5. VAE (6/10 - blurry)
```

### Training Stability Ranking

```
1. VAE (10/10)
2. Diffusion (9/10)
3. Autoregressive (9/10)
4. Flow (8/10)
5. GAN (3/10 - very unstable)
```

### Modern Stack (2025)

```
Image generation: Stable Diffusion (fine-tuned)
Fast inference:   StyleGAN2 (if available)
Latent space:     VAE
Research:         Diffusion (easiest to train)
```

## Next Steps

After mastering this skill:
- `llm-specialist/llm-finetuning-strategies`: Apply to text generation
- `architecture-design-principles`: Understand design trade-offs
- `training-optimization`: Optimize training for your chosen model

**Remember:** Diffusion models dominate in 2025. Use them unless you have specific reason not to (speed, latent space, likelihood).
625
skills/using-neural-architectures/graph-neural-networks-basics.md
Normal file
@@ -0,0 +1,625 @@
# Graph Neural Networks Basics

## When to Use This Skill

Use this skill when you need to:
- ✅ Work with graph-structured data (molecules, social networks, citations)
- ✅ Understand why CNN/RNN don't work on graphs
- ✅ Learn message passing framework
- ✅ Choose between GCN, GraphSAGE, GAT
- ✅ Decide if GNN is appropriate (vs simple model)
- ✅ Implement permutation-invariant aggregations

**Do NOT use this skill for:**
- ❌ Sequential data (use RNN/Transformer)
- ❌ Grid data (use CNN)
- ❌ High-level architecture selection (use `using-neural-architectures`)

## Core Principle

**Graphs have irregular structure.** CNN (grid) and RNN (sequence) don't work.

**GNN solution:** Message passing
- Nodes aggregate information from neighbors
- Multiple layers = multi-hop neighborhoods
- Permutation invariant (order doesn't matter)

**Critical question:** Does graph structure actually help? (Test: Compare with/without edges)
## Part 1: Why GNN (Not CNN/RNN)

### Problem: Graph Structure

**Graph components:**
- **Nodes**: Entities (atoms, users, papers)
- **Edges**: Relationships (bonds, friendships, citations)
- **Features**: Node/edge attributes

**Key property:** Irregular structure
- Each node has variable number of neighbors
- No fixed spatial arrangement
- Permutation invariant (node order doesn't matter)

### Why CNN Doesn't Work

**CNN assumption:** Regular grid structure

**Example:** Image (2D grid)
```
Every interior pixel has exactly 8 neighbors:
[■][■][■]
[■][X][■]  ← Center pixel has 8 neighbors (fixed!)
[■][■][■]

CNN kernel: 3×3 (fixed size, fixed positions)
```

**Graph reality:** Irregular neighborhoods
```
Node A: 2 neighbors (H, C)
Node B: 4 neighbors (C, C, C, H)
Node C: 1 neighbor (H)

No fixed kernel size or position!
```

**CNN limitations:**
- Requires fixed-size neighborhoods → Graphs have variable-size
- Assumes spatial locality → Graphs have arbitrary connectivity
- Depends on node ordering → Should be permutation invariant

### Why RNN Doesn't Work

**RNN assumption:** Sequential structure

**Example:** Text (1D sequence)
```
"The cat sat" → [The] → [cat] → [sat]
Clear sequential order, temporal dependencies
```

**Graph reality:** No inherent sequence
```
Social network:
A — B — C
|       |
D ——————E

What's the "sequence"? A→B→C? A→D→E? No natural ordering!
```

**RNN limitations:**
- Requires sequential order → Graphs have no natural order
- Processes one element at a time → Graphs have parallel connections
- Order-dependent → Should be permutation invariant

### GNN Solution

**Key innovation:** Message passing on graph structure
- Operate directly on nodes and edges
- Variable-size neighborhoods (handled naturally)
- Permutation invariant aggregations
## Part 2: Message Passing Framework

### Core Mechanism

**Message passing in 3 steps:**

**1. Aggregate neighbor messages**
```python
# Node i aggregates from neighbors N(i)
messages = [h_j for j in neighbors(i)]
aggregated = aggregate(messages)  # e.g., mean, sum, max
```

**2. Update node representation**
```python
# Combine own features with aggregated messages
h_i_new = update(h_i_old, aggregated)  # e.g., neural network
```

**3. Repeat for L layers**
- Layer 1: Node sees 1-hop neighbors
- Layer 2: Node sees 2-hop neighbors
- Layer L: Node sees L-hop neighborhood

### Concrete Example: Social Network

**Task:** Predict user interests

**Graph:**
```
     B (sports)
     |
A ---+--- C (cooking)
     |
     D (music)
```

**Layer 1: 1-hop neighbors**
```python
# Node A aggregates from direct friends
h_A_layer1 = update(
    h_A,
    aggregate([h_B, h_C, h_D])
)
# Now h_A includes friend interests
```

**Layer 2: 2-hop neighbors (friends of friends)**
```python
# B's friends: E, F
# C's friends: G, H
# D's friends: I

h_A_layer2 = update(
    h_A_layer1,
    aggregate([h_B', h_C', h_D'])  # h_B' includes E, F
)
# Now h_A includes friends-of-friends!
```

**Key insight:** More layers = larger receptive field (L-hop neighborhood)

### Permutation Invariance

**Critical property:** Same graph → same output (regardless of node ordering)

**Example:**
```python
Graph: A-B, B-C

Node list 1: [A, B, C]
Node list 2: [C, B, A]

Output MUST be identical! (Same graph, different ordering)
```

**Invariant aggregations:**
- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])`
- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])`
- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])`

**NOT invariant:**
- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])`
- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]`

**Implementation:**
```python
# CORRECT: Permutation invariant
def aggregate(neighbor_features):
    return torch.mean(neighbor_features, dim=0)

# WRONG: Order-dependent!
def aggregate(neighbor_features):
    return LSTM(neighbor_features)  # Output depends on order
```
## Part 3: GNN Architectures

### Architecture 1: GCN (Graph Convolutional Network)

**Key idea:** Spectral convolution on graphs (simplified)

**Formula:**
```python
h_i^(l+1) = σ(∑_{j∈N(i)} W^(l) h_j^(l) / √(|N(i)| |N(j)|))

# Normalize by degree (√(deg(i) * deg(j)))
```

**Aggregation:** Weighted mean (degree-normalized)

**Properties:**
- Transductive (needs full graph at training)
- Computationally efficient
- Good baseline

**When to use:**
- Full graph available at training time
- Starting point (simplest GNN)
- Small to medium graphs

**Implementation:**
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: Node features (N, in_channels)
        # edge_index: Graph connectivity (2, E)

        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        # Layer 2
        x = self.conv2(x, edge_index)

        return x
```

### Architecture 2: GraphSAGE

**Key idea:** Sample and aggregate (inductive learning)

**Formula:**
```python
# Sample fixed-size neighborhood
neighbors_sampled = sample(neighbors(i), k=10)

# Aggregate
h_N = aggregate({h_j for j in neighbors_sampled})

# Concatenate and transform
h_i^(l+1) = σ(W^(l) [h_i^(l); h_N])
```

**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance)

**Key innovation:** Sampling
- Sample fixed number of neighbors (e.g., 10)
- Makes computation tractable for large graphs
- Enables inductive learning (generalizes to unseen nodes)

**When to use:**
- Large graphs (millions of nodes)
- Need inductive capability (new nodes appear)
- Training on subset, testing on full graph

**Implementation:**
```python
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```

### Architecture 3: GAT (Graph Attention Network)

**Key idea:** Learn attention weights for neighbors

**Formula:**
```python
# Attention scores
α_ij = attention(h_i, h_j)  # How important is neighbor j to node i?

# Normalize (softmax)
α_ij = softmax_j(α_ij)

# Weighted aggregation
h_i^(l+1) = σ(∑_{j∈N(i)} α_ij W h_j^(l))
```

**Key innovation:** Learned neighbor importance
- Not all neighbors equally important
- Attention mechanism decides weights
- Multi-head attention (like Transformer)

**When to use:**
- Neighbors have varying importance
- Need interpretability (attention weights)
- Have sufficient data (attention needs more data to learn)

**Implementation:**
```python
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```

### Architecture Comparison

| Feature | GCN | GraphSAGE | GAT |
|---------|-----|-----------|-----|
| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted |
| Neighbor weighting | Fixed (by degree) | Equal | Learned |
| Inductive | No | Yes | Yes |
| Scalability | Medium | High (sampling) | Medium |
| Interpretability | Low | Low | High (attention) |
| Complexity | Low | Medium | High |

### Decision Tree

```
Starting out / Small graph:
  → GCN (simplest baseline)

Large graph (millions of nodes):
  → GraphSAGE (sampling enables scalability)

Need inductive learning (new nodes):
  → GraphSAGE or GAT

Neighbors have different importance:
  → GAT (attention learns importance)

Need interpretability:
  → GAT (attention weights explain predictions)

Production deployment:
  → GraphSAGE (most robust and scalable)
```
## Part 4: When NOT to Use GNN

### Critical Question

**Does graph structure actually help?**

**Test:** Compare model with and without edges
```python
# Baseline: MLP on node features only
mlp_accuracy = train_mlp(node_features, labels)

# GNN: Use node features + graph structure
gnn_accuracy = train_gnn(node_features, edges, labels)

# Decision (accuracies as fractions, so 0.02 = 2 percentage points):
if gnn_accuracy - mlp_accuracy < 0.02:
    print("Graph structure doesn't help much")
    print("Use simpler model (MLP or XGBoost)")
else:
    print("Graph structure adds value")
    print("Use GNN")
```

### Scenarios Where GNN Doesn't Help

**1. Node features dominate**
```
User churn prediction:
- Node features: Usage hours, demographics, subscription → Highly predictive
- Graph edges: Sparse user interactions → Weak signal
- Result: MLP 85%, GNN 86% (not worth complexity!)
```

**2. Sparse graphs**
```
Graph with 1000 nodes, 100 edges (~0.02% density):
- Most nodes have 0-1 neighbors
- No information to aggregate
- GNN reduces to MLP
```

**3. Random graph structure**
```
If edges are random (no homophily):
- Neighbor labels uncorrelated
- Aggregation adds noise
- Simple model better
```

### When GNN DOES Help

✅ **Molecular property prediction**
- Structure is PRIMARY signal
- Atom types + bonds determine properties
- GNN: Huge improvement over fingerprints

✅ **Citation networks**
- Paper quality correlated with neighbors
- "You are what you cite"
- Clear homophily

✅ **Social recommendation**
- Friends have similar preferences
- Graph structure informative
- GNN: Moderate to large improvement

✅ **Knowledge graphs**
- Entities connected by relations
- Multi-hop reasoning valuable
- GNN captures complex patterns

### Decision Framework

```
1. Start simple:
   - Try MLP or XGBoost on node features
   - Establish baseline performance

2. Check graph structure value:
   - Does edge information correlate with target?
   - Is there homophily (similar nodes connected)?
   - Test: Remove edges, compare performance

3. Use GNN if:
   - Graph structure adds >2-5% accuracy
   - Structure is interpretable (not random)
   - Have enough nodes for GNN to learn

4. Stick with simple if:
   - Node features alone sufficient
   - Graph structure weak/random
   - Small dataset (< 1000 nodes)
```
## Part 5: Practical Implementation

### Using PyTorch Geometric

**Installation:**
```bash
pip install torch-geometric
```

**Basic workflow:**
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# 1. Create graph data
x = torch.tensor([[feature1], [feature2], ...])    # Node features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Edges (COO format)
y = torch.tensor([label1, label2, ...])            # Node labels

data = Data(x=x, edge_index=edge_index, y=y)

# 2. Define model (in_features, num_classes assumed defined for your data)
class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(in_features, 64)
        self.conv2 = GCNConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 3. Train (train_mask marks which nodes carry training labels)
model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
```

### Edge Index Format

**COO (Coordinate) format:**
```python
# Edge list: (0→1), (1→2), (2→0)
edge_index = torch.tensor([
    [0, 1, 2],  # Source nodes
    [1, 2, 0]   # Target nodes
])

# For undirected graph, include both directions:
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 0],  # Source
    [1, 0, 2, 1, 0, 2]   # Target
])
```
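Rather than writing reverse edges by hand, PyTorch Geometric ships a helper for this; a small sketch:

```python
import torch
from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_index = to_undirected(edge_index)  # adds reverse edges, deduplicated
```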
### Mini-batching Graphs

**Problem:** Graphs have different sizes

**Solution:** Batch graphs as one large disconnected graph
```python
from torch_geometric.data import DataLoader

# Create dataset
dataset = [Data(...), Data(...), ...]

# DataLoader handles batching
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # batch contains multiple graphs as one large graph
    # batch.batch: Indicator which nodes belong to which graph
    out = model(batch.x, batch.edge_index)
```
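For graph-level prediction, the `batch.batch` vector feeds a pooling readout; a minimal sketch with PyTorch Geometric's `global_mean_pool`, assuming `out` holds per-node embeddings from the loop above:

```python
from torch_geometric.nn import global_mean_pool

# Collapse node embeddings into one vector per graph:
# batch.batch maps each node to the index of the graph it belongs to.
graph_embeddings = global_mean_pool(out, batch.batch)  # (num_graphs, hidden_dim)
```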
## Part 6: Common Mistakes

### Mistake 1: LSTM Aggregation

**Symptom:** Different outputs for same graph with reordered nodes
**Fix:** Use mean/sum/max aggregation (permutation invariant)

### Mistake 2: Forgetting Edge Direction

**Symptom:** Information flows wrong way
**Fix:** For undirected graphs, add edges in both directions

### Mistake 3: Too Many Layers

**Symptom:** Performance degrades, over-smoothing
**Fix:** Use 2-3 layers (most graphs have small diameter)
**Explanation:** Too many layers → all nodes converge to same representation

### Mistake 4: Not Testing Simple Baseline

**Symptom:** Complex GNN with minimal improvement
**Fix:** Always test MLP on node features first

### Mistake 5: Using GNN on Euclidean Data

**Symptom:** CNN/RNN would work better
**Fix:** Use GNN only for irregular graph structure (not grids/sequences)

## Part 7: Summary

### Quick Reference

**When to use GNN:**
- Graph-structured data (molecules, social networks, citations)
- Irregular neighborhoods (not grid/sequence)
- Graph structure informative (test this!)

**Architecture selection:**
```
Start:                GCN (simplest)
Large graph:          GraphSAGE (scalable)
Inductive learning:   GraphSAGE or GAT
Neighbor importance:  GAT (attention)
```

**Key principles:**
- Message passing: Aggregate neighbors + Update node
- Permutation invariance: Use mean/sum/max (not LSTM)
- Test baseline: MLP first, GNN if structure helps
- Layers: 2-3 sufficient (more = over-smoothing)

**Implementation:**
- PyTorch Geometric: Standard library
- COO format: Edge index as 2×E tensor
- Batching: Merge graphs into one large graph

## Next Steps

After mastering this skill:
- `transformer-architecture-deepdive`: Understand attention (used in GAT)
- `architecture-design-principles`: Design principles for graph architectures
- Advanced GNNs: Graph Transformers, Equivariant GNNs

**Remember:** Not all graph data needs GNN. Test if graph structure actually helps! (Compare with MLP baseline)
915 skills/using-neural-architectures/normalization-techniques.md Normal file
@@ -0,0 +1,915 @@
# Normalization Techniques

## Context

You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering:
- **Batch size dependency**: BatchNorm fails with small batches (< 8)
- **Architecture mismatch**: BatchNorm breaks RNNs/Transformers (use LayerNorm)
- **Task-specific needs**: Style transfer needs InstanceNorm, not BatchNorm
- **Modern alternatives**: RMSNorm simpler and faster than LayerNorm for LLMs

**This skill prevents normalization cargo-culting and provides architecture-specific selection.**

## Why Normalization Matters

**Problem: Internal Covariate Shift**
During training, layer input distributions shift as previous layers update. This causes:
- Vanishing/exploding gradients (deep networks)
- Slow convergence (small learning rates required)
- Training instability (loss spikes)

**Solution: Normalization**
Normalize activations to have stable statistics (mean=0, std=1). Benefits:
- **10x faster convergence**: Can use larger learning rates
- **Better generalization**: Regularization effect
- **Enables deep networks**: 50+ layers without gradient issues
- **Less sensitive to initialization**: Weights can start further from optimal

**Key insight**: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize.

## Normalization Families

### 1. Batch Normalization (BatchNorm)

**What it does:**
Normalizes across the batch dimension for each channel/feature.

**Formula:**
```
Given input x with shape (B, C, H, W):  # Batch, Channel, Height, Width

For each channel c:
    μ_c = mean(x[:, c, :, :])  # Mean over batch + spatial dims
    σ_c = std(x[:, c, :, :])   # Std over batch + spatial dims

    x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε)

    # Learnable scale and shift
    y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c
```
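
As a quick sanity check of the formula, the per-channel statistics can be computed by hand and compared against the module output (a minimal sketch; the affine γ/β transform is disabled so only the raw normalization is compared):

```python
import torch
import torch.nn as nn

x = torch.randn(32, 64, 28, 28)
bn = nn.BatchNorm2d(64, affine=False)
bn.train()
y = bn(x)

# Manual normalization: one mean/std per channel over batch + spatial dims
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mu) / torch.sqrt(var + bn.eps)

print(torch.allclose(y, y_manual, atol=1e-5))  # True
```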

**When to use:**
- ✅ CNNs for classification (ResNet, EfficientNet)
- ✅ Large batch sizes (≥ 16)
- ✅ IID data (image classification, object detection)

**When NOT to use:**
- ❌ Small batch sizes (< 8): Noisy statistics cause training failure
- ❌ RNNs/LSTMs: Breaks temporal dependencies
- ❌ Transformers: Batch dependency problematic for variable-length sequences
- ❌ Style transfer: Batch statistics erase style information

**Batch size dependency:**
```python
batch_size = 32   # ✓ Works well (stable statistics)
batch_size = 16   # ✓ Acceptable
batch_size = 8    # ✓ Marginal (consider GroupNorm)
batch_size = 4    # ✗ Unstable (use GroupNorm)
batch_size = 2    # ✗ FAILS! (noisy statistics)
batch_size = 1    # ✗ Undefined (no batch to normalize over!)
```
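
To see the failure mode concretely: PyTorch refuses to compute batch statistics when each channel has only a single value (a minimal sketch; with larger spatial dimensions batch_size=1 technically runs, but the statistics remain unreliable):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)
bn.train()

try:
    bn(torch.randn(1, 64, 1, 1))  # Only one value per channel
except ValueError as err:
    print(err)  # PyTorch raises: needs >1 value per channel in training mode
```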

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For Conv2d
bn = nn.BatchNorm2d(num_features=64)  # 64 channels
x = torch.randn(32, 64, 28, 28)       # Batch=32, Channels=64
y = bn(x)

# For Linear
bn = nn.BatchNorm1d(num_features=128)  # 128 features
x = torch.randn(32, 128)               # Batch=32, Features=128
y = bn(x)
```

**Inference mode:**
```python
# Training: Uses batch statistics
model.train()
y = bn(x)  # Normalizes using current batch mean/std

# Inference: Uses running statistics (accumulated during training)
model.eval()
y = bn(x)  # Normalizes using running_mean/running_var
```

### 2. Layer Normalization (LayerNorm)

**What it does:**
Normalizes across the feature dimension for each sample independently.

**Formula:**
```
Given input x with shape (B, C):  # Batch, Features

For each sample b:
    μ_b = mean(x[b, :])  # Mean over features
    σ_b = std(x[b, :])   # Std over features

    x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε)

    # Learnable scale and shift
    y[b, :] = γ * x_norm[b, :] + β
```

**When to use:**
- ✅ Transformers (BERT, GPT, T5)
- ✅ RNNs/LSTMs (maintains temporal independence)
- ✅ Small batch sizes (batch-independent!)
- ✅ Variable-length sequences
- ✅ Reinforcement learning (batch_size=1 common)

**Advantages over BatchNorm:**
- ✅ **Batch-independent**: Works with batch_size=1
- ✅ **No running statistics**: Inference = training (no mode switching)
- ✅ **Sequence-friendly**: Doesn't mix information across timesteps

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For Transformer
ln = nn.LayerNorm(normalized_shape=512)  # d_model=512
x = torch.randn(32, 128, 512)  # Batch=32, SeqLen=128, d_model=512
y = ln(x)  # Normalizes last dimension independently per (batch, position)

# For RNN hidden states
ln = nn.LayerNorm(normalized_shape=256)  # hidden_size=256
h = torch.randn(32, 256)  # Batch=32, Hidden=256
h_norm = ln(h)
```

**Key difference from BatchNorm:**
```python
# BatchNorm: Normalizes across batch dimension
# Given (B=32, C=64, H=28, W=28)
# Computes 64 means/stds (one per channel, across batch + spatial)

# LayerNorm: Normalizes across feature dimension
# Given (B=32, L=128, D=512)
# Computes 32×128 means/stds (one per (batch, position), across features)
```
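
The per-position statistics can also be verified by hand (a minimal sketch with the affine transform disabled, mirroring the BatchNorm check above):

```python
import torch
import torch.nn as nn

x = torch.randn(32, 128, 512)
ln = nn.LayerNorm(512, elementwise_affine=False)

mu = x.mean(dim=-1, keepdim=True)                  # One mean per (batch, position)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # One variance per (batch, position)
y_manual = (x - mu) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), y_manual, atol=1e-5))  # True
```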

### 3. Group Normalization (GroupNorm)

**What it does:**
Normalizes channels in groups, batch-independent.

**Formula:**
```
Given input x with shape (B, C, H, W):
Divide C channels into G groups (C must be divisible by G)

For each sample b and group g:
    channels = x[b, g*(C/G):(g+1)*(C/G), :, :]  # Channels in group g
    μ_{b,g} = mean(channels)  # Mean over channels in group + spatial
    σ_{b,g} = std(channels)   # Std over channels in group + spatial

    x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε)
```

**When to use:**
- ✅ Small batch sizes (< 8)
- ✅ CNNs with batch_size=1 (style transfer, RL)
- ✅ Object detection/segmentation (often use small batches)
- ✅ When BatchNorm unstable but want spatial normalization

**Group size selection:**
```python
# num_groups trade-off:
num_groups = 1    # = LayerNorm (all channels together)
num_groups = C    # = InstanceNorm (each channel separate)
num_groups = 32   # Standard choice (good balance)

# Rule: C must be divisible by num_groups
# channels = 64, num_groups = 32:  ✓ 64/32 = 2 channels per group
# channels = 64, num_groups = 16:  ✓ 64/16 = 4 channels per group
# channels = 64, num_groups = 30:  ✗ 64/30 not integer!
```

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For small batch CNN
gn = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 28, 28)  # Batch=2 (small!)
y = gn(x)  # Works well even with batch=2

# Compare behavior across batch sizes:
batch_sizes = [1, 2, 4, 8, 16, 32]

bn = nn.BatchNorm2d(64)
gn = nn.GroupNorm(32, 64)

for bs in batch_sizes:
    x = torch.randn(bs, 64, 28, 28)
    # BatchNorm statistics get noisier as the batch shrinks;
    # GroupNorm is consistent across all batch sizes
    y_bn, y_gn = bn(x), gn(x)
```

**Empirical results (Wu & He, 2018):**
```
ImageNet classification with ResNet-50:
batch_size = 32: BatchNorm = 76.5%, GroupNorm = 76.3% (tie)
batch_size = 8:  BatchNorm = 75.8%, GroupNorm = 76.1% (GroupNorm wins!)
batch_size = 2:  BatchNorm = 72.1%, GroupNorm = 75.3% (GroupNorm wins!)
```

### 4. Instance Normalization (InstanceNorm)

**What it does:**
Normalizes each sample and channel independently (no batch mixing).

**Formula:**
```
Given input x with shape (B, C, H, W):

For each sample b and channel c:
    μ_{b,c} = mean(x[b, c, :, :])  # Mean over spatial dimensions only
    σ_{b,c} = std(x[b, c, :, :])   # Std over spatial dimensions only

    x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε)
```

**When to use:**
- ✅ Style transfer (neural style, CycleGAN, pix2pix)
- ✅ Image-to-image translation
- ✅ When batch/channel mixing destroys information

**Why for style transfer:**
```python
# Style transfer goal: Transfer style while preserving content
# BatchNorm: Mixes statistics across batch (erases individual style!)
# InstanceNorm: Per-image statistics (preserves each image's style)

# Example: Neural style transfer
content_image = load_image("photo.jpg")
style_image = load_image("starry_night.jpg")

# With BatchNorm: Output loses content image's unique characteristics
# With InstanceNorm: Content characteristics preserved, style applied
```

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For style transfer generator
in_norm = nn.InstanceNorm2d(num_features=64)
x = torch.randn(1, 64, 256, 256)  # Single image
y = in_norm(x)  # Normalizes each channel independently

# CycleGAN generator architecture
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64)  # NOT BatchNorm!
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x)  # Per-image normalization
        x = self.relu(x)
        return x
```

**Relation to GroupNorm:**
```python
# InstanceNorm is GroupNorm with num_groups = num_channels
InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C)
```
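
This equivalence can be checked numerically (a small sketch; affine parameters are off by default in `InstanceNorm2d` and disabled explicitly in `GroupNorm` for a clean comparison):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16)
inorm = nn.InstanceNorm2d(8)  # affine=False by default
gnorm = nn.GroupNorm(num_groups=8, num_channels=8, affine=False)

print(torch.allclose(inorm(x), gnorm(x), atol=1e-5))  # True
```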

### 5. RMS Normalization (RMSNorm)

**What it does:**
Simplified LayerNorm that only rescales (no recentering), faster and simpler.

**Formula:**
```
Given input x:

# LayerNorm (2 steps):
x_centered = x - mean(x)       # 1. Center
x_norm = x_centered / std(x)   # 2. Scale

# RMSNorm (1 step):
rms = sqrt(mean(x²))  # Root Mean Square
x_norm = x / rms      # Only scale, no centering
```

**When to use:**
- ✅ Modern LLMs (LLaMA, Mistral, Gemma)
- ✅ When speed matters (15-20% faster than LayerNorm)
- ✅ Large Transformer models (billions of parameters)

**Advantages:**
- ✅ **Simpler**: One operation instead of two
- ✅ **Faster**: ~15-20% speedup over LayerNorm
- ✅ **Numerically stable**: No subtraction (avoids catastrophic cancellation)
- ✅ **Same performance**: Empirically matches LayerNorm quality

**PyTorch implementation:**
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # Normalize
        x_norm = x / rms

        # Scale (learnable)
        return self.weight * x_norm

# Usage in Transformer
rms = RMSNorm(dim=512)  # d_model=512
x = torch.randn(32, 128, 512)  # Batch, SeqLen, d_model
y = rms(x)
```

**Speed comparison (LLaMA-7B, A100 GPU):**
```
LayerNorm: 1000 tokens/sec
RMSNorm:   1180 tokens/sec  # 18% faster!

# For large models, this adds up:
# 1 million tokens: ~150 seconds saved (1000 s → ~847 s)
```

**Modern LLM adoption:**
```python
# LLaMA (Meta, 2023): RMSNorm
# Mistral (Mistral AI, 2023): RMSNorm
# Gemma (Google, 2024): RMSNorm
# PaLM (Google, 2022): RMSNorm

# Older models:
# GPT-2/3 (OpenAI): LayerNorm
# BERT (Google, 2018): LayerNorm
```

## Architecture-Specific Selection

### CNN (Convolutional Neural Networks)

**Default: BatchNorm**
```python
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)  # After conv
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)  # Normalize
        x = self.relu(x)
        return x
```

**Exception: Small batch sizes**
```python
# If batch_size < 8, use GroupNorm instead
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(32, out_channels)  # GroupNorm for small batches
        self.relu = nn.ReLU(inplace=True)
```

**Exception: Style transfer**
```python
# Use InstanceNorm for style transfer
class StyleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)  # Per-image normalization
        self.relu = nn.ReLU(inplace=True)
```

### RNN / LSTM (Recurrent Neural Networks)

**Default: LayerNorm**
```python
import torch.nn as nn

class NormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.ln = nn.LayerNorm(hidden_size)  # Normalize hidden states

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        output, (h_n, c_n) = self.lstm(x)
        # output: (batch, seq_len, hidden_size)

        # Normalize each timestep's output
        output_norm = self.ln(output)  # Applies independently per timestep
        return output_norm, (h_n, c_n)
```

**Why NOT BatchNorm:**
```python
# BatchNorm in RNN mixes information across timesteps!
# Given (batch=32, seq_len=100, hidden=256)

# BatchNorm would compute:
# mean/std over (batch × seq_len) = 3200 values
# This mixes t=0 with t=99 (destroys temporal structure!)

# LayerNorm computes:
# mean/std over hidden_size = 256 values per (batch, timestep)
# Each timestep normalized independently (preserves temporal structure)
```

**Layer-wise normalization in stacked RNN:**
```python
class StackedNormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            in_size = input_size if i == 0 else hidden_size
            self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True))
            self.layers.append(nn.LayerNorm(hidden_size))  # After each LSTM layer

    def forward(self, x):
        for lstm, ln in zip(self.layers[::2], self.layers[1::2]):
            x, _ = lstm(x)
            x = ln(x)  # Normalize between layers
        return x
```

### Transformer

**Default: LayerNorm (or RMSNorm for modern/large models)**

**Two placement options: Pre-norm vs Post-norm**

**Post-norm (original Transformer, "Attention is All You Need"):**
```python
class TransformerLayerPostNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Post-norm: Apply normalization AFTER residual
        x = self.ln1(x + self.attn(x, x, x)[0])  # Normalize after adding
        x = self.ln2(x + self.ffn(x))            # Normalize after adding
        return x
```

**Pre-norm (modern, more stable):**
```python
class TransformerLayerPreNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm: Apply normalization BEFORE each sublayer
        h = self.ln1(x)                 # Normalize once before attention
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.ln2(x))   # Normalize before FFN
        return x
```

**Pre-norm vs Post-norm comparison:**
```python
# Post-norm (original):
# - Less stable (requires careful initialization + warmup)
# - Slightly better performance IF training succeeds
# - Hard to train deep models (> 12 layers)

# Pre-norm (modern):
# - More stable (easier to train deep models)
# - Standard for large models (GPT-3: 96 layers!)
# - Recommended default

# Empirical: BERT (post-norm, up to 24 layers with careful warmup)
#            GPT-2, GPT-3, T5, LLaMA (pre-norm, scales to deep stacks)
```

**Using RMSNorm instead:**
```python
class TransformerLayerRMSNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.rms1 = RMSNorm(d_model)  # 15-20% faster than LayerNorm
        self.rms2 = RMSNorm(d_model)

    def forward(self, x):
        # Pre-norm with RMSNorm (LLaMA style)
        h = self.rms1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.rms2(x))
        return x
```

### GAN (Generative Adversarial Network)

**Generator: InstanceNorm or no normalization**
```python
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Use InstanceNorm for image-to-image translation
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64)  # Per-image normalization

    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x)  # Preserves per-image characteristics
        return x
```

**Discriminator: No normalization or LayerNorm**
```python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Often no normalization (BatchNorm can hurt GAN training)
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        # No normalization here

    def forward(self, x):
        x = self.conv1(x)
        # Directly to activation (no norm)
        return x
```

**Why avoid BatchNorm in GANs:**
```python
# BatchNorm in discriminator:
# - Mixes real and fake samples in batch
# - Leaks information (discriminator can detect batch composition)
# - Hurts training stability

# Recommendation:
# Generator: InstanceNorm (for image translation) or no norm
# Discriminator: No normalization or LayerNorm
```

## Decision Framework

### Step 1: Check batch size

```python
if batch_size >= 8:
    consider_batchnorm = True
else:
    use_groupnorm_or_layernorm = True  # BatchNorm will be unstable
```

### Step 2: Check architecture

```python
if architecture == "CNN":
    if batch_size >= 8:
        use_batchnorm()
    else:
        use_groupnorm(num_groups=32)

    # Exception: Style transfer
    if task == "style_transfer":
        use_instancenorm()

elif architecture in ["RNN", "LSTM", "GRU"]:
    use_layernorm()  # NEVER BatchNorm!

elif architecture == "Transformer":
    if model_size == "large":  # > 1B parameters
        use_rmsnorm()  # 15-20% faster
    else:
        use_layernorm()

    # Placement: Pre-norm (more stable)
    use_prenorm_placement()

elif architecture == "GAN":
    if component == "generator":
        if task == "image_translation":
            use_instancenorm()
        else:
            use_no_norm()  # Or InstanceNorm
    elif component == "discriminator":
        use_no_norm()  # Or LayerNorm
```
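
The same logic condensed into a small runnable helper (illustrative only; the thresholds and return strings simply follow the tree above):

```python
def pick_norm(architecture: str, batch_size: int = 32,
              task: str = "", params: int = 0) -> str:
    """Suggest a normalization layer per the decision tree above."""
    if architecture == "CNN":
        if task == "style_transfer":
            return "InstanceNorm2d"
        return "BatchNorm2d" if batch_size >= 8 else "GroupNorm(32, C)"
    if architecture in ("RNN", "LSTM", "GRU"):
        return "LayerNorm"  # Never BatchNorm
    if architecture == "Transformer":
        return "RMSNorm + pre-norm" if params >= 1_000_000_000 else "LayerNorm + pre-norm"
    if architecture == "GAN":
        return "InstanceNorm2d (generator) / no norm (discriminator)"
    return "LayerNorm"  # Safe batch-independent default

print(pick_norm("CNN", batch_size=4))                  # GroupNorm(32, C)
print(pick_norm("Transformer", params=7_000_000_000))  # RMSNorm + pre-norm
```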

### Step 3: Verify placement

```python
# CNNs: After convolution, before activation
x = conv(x)
x = norm(x)  # Here!
x = relu(x)

# RNNs: After LSTM, normalize hidden states
output, (h, c) = lstm(x)
output = norm(output)  # Here!

# Transformers: Pre-norm (modern) or post-norm (original)
# Pre-norm (recommended):
x = x + sublayer(norm(x))  # Normalize before sublayer

# Post-norm (original):
x = norm(x + sublayer(x))  # Normalize after residual
```

## Implementation Checklist

### Before adding normalization:

1. ☐ **Check batch size**: If < 8, avoid BatchNorm
2. ☐ **Check architecture**: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm
3. ☐ **Check task**: Style transfer→InstanceNorm
4. ☐ **Verify placement**: After conv/linear, before activation (CNNs)
5. ☐ **Test training stability**: Loss should decrease smoothly

### During training:

6. ☐ **Monitor running statistics** (BatchNorm): Check running_mean/running_var are updating (see the sketch after this list)
7. ☐ **Test inference mode**: Verify model.eval() uses running stats correctly
8. ☐ **Check gradient flow**: Normalization should help, not hurt gradients
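
A small sketch for item 6, confirming that BatchNorm's running statistics actually move during training (`running_mean` and `running_var` are standard PyTorch buffers):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)
before = bn.running_mean.clone()

bn.train()
for _ in range(10):
    bn(torch.randn(32, 64, 28, 28))  # Each forward pass updates running stats

print(torch.allclose(before, bn.running_mean))  # False: statistics have updated
```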

### If training is unstable:

9. ☐ **Try different normalization**: BatchNorm→GroupNorm, LayerNorm→RMSNorm
10. ☐ **Try pre-norm** (Transformers): More stable than post-norm
11. ☐ **Reduce learning rate**: Normalization allows larger LR, but start conservatively

## Common Mistakes

### Mistake 1: BatchNorm with small batches

```python
# WRONG: BatchNorm with batch_size=2
model = ResNet50(norm_layer=nn.BatchNorm2d)
dataloader = DataLoader(dataset, batch_size=2)  # Too small!

# RIGHT: GroupNorm for small batches
model = ResNet50(norm_layer=lambda channels: nn.GroupNorm(32, channels))
dataloader = DataLoader(dataset, batch_size=2)  # Works!
```

### Mistake 2: BatchNorm in RNN

```python
# WRONG: BatchNorm in LSTM
class BadLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.bn = nn.BatchNorm1d(256)  # WRONG! Mixes timesteps

    def forward(self, x):
        output, _ = self.lstm(x)
        output = output.permute(0, 2, 1)  # (B, H, T)
        output = self.bn(output)  # Mixes timesteps!
        return output

# RIGHT: LayerNorm in LSTM
class GoodLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.ln = nn.LayerNorm(256)  # Per-timestep normalization

    def forward(self, x):
        output, _ = self.lstm(x)
        output = self.ln(output)  # Independent per timestep
        return output
```

### Mistake 3: Forgetting model.eval()

```python
# WRONG: Using training mode during inference
model.train()  # BatchNorm uses batch statistics
predictions = model(test_data)  # Batch statistics from test data (leakage!)

# RIGHT: Use eval mode during inference
model.eval()  # BatchNorm uses running statistics
with torch.no_grad():
    predictions = model(test_data)  # Uses accumulated running stats
```

### Mistake 4: Post-norm for deep Transformers

```python
# WRONG: Post-norm for 24-layer Transformer (unstable!)
class DeepTransformerPostNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPostNorm(512, 8) for _ in range(24)
        ])  # Hard to train!

# RIGHT: Pre-norm for deep Transformers
class DeepTransformerPreNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPreNorm(512, 8) for _ in range(24)
        ])  # Stable training!
```

### Mistake 5: Wrong normalization for style transfer

```python
# WRONG: BatchNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.BatchNorm2d(64)  # WRONG! Mixes styles across batch

# RIGHT: InstanceNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.InstanceNorm2d(64)  # Per-image normalization
```

## Performance Impact

### Training speed:

```python
# Without normalization: 100 epochs to converge
# With normalization: 10 epochs to converge (10x faster!)

# Reason: Larger learning rates possible
lr_no_norm = 0.001   # Must be small (unstable otherwise)
lr_with_norm = 0.01  # Can be 10x larger (normalization stabilizes)
```

### Inference speed:

```python
# Normalization overhead (relative to no normalization):
# BatchNorm:    +2% (minimal, cached running stats)
# LayerNorm:    +3-5% (compute mean/std per forward pass)
# RMSNorm:      +2-3% (faster than LayerNorm)
# GroupNorm:    +5-8% (more computation than BatchNorm)
# InstanceNorm: +3-5% (similar to LayerNorm)

# For most models: Overhead is negligible compared to conv/linear layers
```

### Memory usage:

```python
# Normalization memory (per layer):
# BatchNorm: 2 × num_channels (running_mean, running_var buffers) + 2 × num_channels (γ, β)
# LayerNorm: 2 × normalized_shape (γ, β)
# RMSNorm:   1 × normalized_shape (γ only, no β)

# Example: 512 channels
# BatchNorm: 4 × 512 = 2048 stored values (half learnable, half buffers)
# LayerNorm: 2 × 512 = 1024 parameters
# RMSNorm:   1 × 512 = 512 parameters  # Most efficient!
```
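
These counts can be confirmed directly; note that PyTorch stores γ/β as learnable parameters and the running statistics as buffers (a quick sketch):

```python
import torch.nn as nn

bn = nn.BatchNorm1d(512)
ln = nn.LayerNorm(512)

print(sum(p.numel() for p in bn.parameters()))  # 1024 learnable (γ, β)
print(sum(b.numel() for b in bn.buffers()))     # 1025 buffer values (running stats + counter)
print(sum(p.numel() for p in ln.parameters()))  # 1024 learnable (γ, β)
```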

## When NOT to Normalize

**Case 1: Final output layer**
```python
# Don't normalize final predictions
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = ResNet50()  # Normalization inside
        self.fc = nn.Linear(2048, 1000)
        # NO normalization here! (final logits should be unnormalized)

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)  # Raw logits
        return x  # Don't normalize!
```

**Case 2: Very small networks**
```python
# Single-layer network: Normalization overkill
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)  # MNIST classifier
        # No normalization needed (network too simple)

    def forward(self, x):
        return self.fc(x)
```

**Case 3: When debugging**
```python
# Remove normalization to isolate issues
# If training fails with normalization, try without to check if:
# - Initialization is correct
# - Loss function is correct
# - Data is correctly preprocessed
```

## Modern Recommendations (2025)

### CNNs:
- **Default**: BatchNorm (if batch_size ≥ 8)
- **Small batches**: GroupNorm (num_groups=32)
- **Style transfer**: InstanceNorm

### RNNs/LSTMs:
- **Default**: LayerNorm
- **Never**: BatchNorm (breaks temporal structure)

### Transformers:
- **Small models** (< 1B): LayerNorm + pre-norm
- **Large models** (≥ 1B): RMSNorm + pre-norm (15-20% faster)
- **Avoid**: Post-norm for deep models (> 12 layers)

### GANs:
- **Generator**: InstanceNorm (image translation) or no norm
- **Discriminator**: No normalization or LayerNorm
- **Avoid**: BatchNorm (leaks information)

### Emerging:
- **RMSNorm adoption increasing**: LLaMA, Mistral, Gemma all use RMSNorm
- **Pre-norm becoming standard**: More stable for deep networks
- **GroupNorm gaining traction**: Object detection, small-batch training

## Summary

**Normalization is mandatory for modern deep learning.** The question is which normalization, not whether to normalize.

**Quick decision tree:**
1. **Batch size ≥ 8?** → Consider BatchNorm (CNNs)
2. **Batch size < 8?** → Use GroupNorm (CNNs) or LayerNorm (all)
3. **RNN/LSTM?** → LayerNorm (never BatchNorm!)
4. **Transformer?** → LayerNorm or RMSNorm with pre-norm
5. **Style transfer?** → InstanceNorm
6. **GAN?** → InstanceNorm (generator) or no norm (discriminator)

**Modern defaults:**
- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8)
- RNNs: LayerNorm
- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models)
- GANs: InstanceNorm (generator), no norm (discriminator)

**Key insight**: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere": it fails for small batches, RNNs, Transformers, and style transfer.
714 skills/using-neural-architectures/sequence-models-comparison.md Normal file
@@ -0,0 +1,714 @@
# Sequence Models Comparison: Choosing the Right Architecture for Sequential Data

<CRITICAL_CONTEXT>
Sequence modeling has evolved rapidly:
- 2014-2017: LSTM/GRU dominated
- 2017+: Transformers revolutionized the field
- 2018+: TCN emerged as efficient alternative
- 2021+: Sparse Transformers for very long sequences
- 2022+: State Space Models (S4) for extreme lengths

Don't default to LSTM (outdated) or blindly use Transformers (not always appropriate).
Match architecture to your sequence characteristics.
</CRITICAL_CONTEXT>

## When to Use This Skill

Use this skill when:
- ✅ Selecting model for sequential/temporal data
- ✅ Comparing RNN vs LSTM vs Transformer
- ✅ Deciding on sequence architecture for time series, text, audio
- ✅ Understanding modern alternatives to LSTM
- ✅ Optimizing for sequence length, speed, or accuracy

DO NOT use for:
- ❌ Vision tasks (use cnn-families-and-selection)
- ❌ Graph-structured data (use graph-neural-networks-basics)
- ❌ LLM-specific questions (use llm-specialist pack)

**When in doubt:** If data is sequential/temporal → this skill.

## Selection Framework

### Step 1: Identify Key Characteristics

**Before recommending, ask:**

| Characteristic | Question | Impact |
|----------------|----------|--------|
| **Sequence Length** | Typical length? | Short (< 100) → LSTM/CNN, Medium (100-1k) → Transformer, Long (> 1k) → Sparse Transformer/S4 |
| **Data Type** | Language, time series, audio? | Language → Transformer, Time series → TCN/Transformer, Audio → Specialized |
| **Data Volume** | Training examples? | Small (< 10k) → LSTM/TCN, Large (> 100k) → Transformer |
| **Latency** | Real-time needed? | Yes → TCN/LSTM, No → Transformer |
| **Deployment** | Cloud/edge/mobile? | Edge → TCN/LSTM, Cloud → Any |

### Step 2: Apply Decision Tree

```
START: What's your primary constraint?

┌─ SEQUENCE LENGTH
│   ├─ Short (< 100 steps)
│   │   ├─ Language → BiLSTM or small Transformer
│   │   └─ Time series → TCN or LSTM
│   │
│   ├─ Medium (100-1000 steps)
│   │   ├─ Language → Transformer (BERT-style)
│   │   └─ Time series → Transformer or TCN
│   │
│   ├─ Long (1000-10000 steps)
│   │   ├─ Sparse Transformer (Longformer, BigBird)
│   │   └─ Hierarchical models
│   │
│   └─ Very Long (> 10000 steps)
│       └─ State Space Models (S4)
│
├─ DATA TYPE
│   ├─ Natural Language
│   │   ├─ < 50k data → BiLSTM or DistilBERT
│   │   └─ > 50k data → Transformer (BERT, RoBERTa)
│   │
│   ├─ Time Series
│   │   ├─ Fast training → TCN
│   │   ├─ Long sequences → Transformer
│   │   └─ Multivariate → Transformer with cross-series attention
│   │
│   └─ Audio
│       ├─ Waveform → WaveNet (TCN-based)
│       └─ Spectrograms → CNN + Transformer
│
└─ COMPUTATIONAL CONSTRAINT
    ├─ Edge device → TCN or small LSTM
    ├─ Real-time latency → TCN (parallel inference)
    └─ Cloud, no constraint → Transformer
```

## Architecture Catalog

### 1. RNN (Recurrent Neural Networks) - Legacy Foundation

**Architecture:** Basic recurrent cell with hidden state

**Status:** **OUTDATED** - don't use for new projects

**Why it existed:**
- First neural approach to sequences
- Hidden state captures temporal information
- Theoretically can model any sequence

**Why it failed:**
- Vanishing gradient (can't learn long dependencies)
- Very slow training (sequential processing)
- Replaced by LSTM in 2014

**When to mention:**
- Historical context only
- Teaching purposes
- Never recommend for production

**Key Insight:** Proved neural nets could handle sequences, but impractical due to vanishing gradients

### 2. LSTM (Long Short-Term Memory) - Legacy Standard

**Architecture:** Gated recurrent cell (forget, input, output gates)

**Complexity:** O(n) memory, sequential processing

**Strengths:**
- Solves vanishing gradient (gates maintain long-term info)
- Works well for short-medium sequences (< 500 steps)
- Small datasets (< 10k examples)
- Low memory footprint

**Weaknesses:**
- Sequential processing (slow training, can't parallelize)
- Still struggles with very long sequences (> 1000 steps)
- Slow inference (especially bidirectional)
- Superseded by Transformers for most language tasks

**When to Use:**
- ✅ Small datasets (< 10k examples)
- ✅ Short sequences (< 100 steps)
- ✅ Edge deployment (low memory)
- ✅ Baseline comparison

**When NOT to Use:**
- ❌ Large datasets (Transformer better)
- ❌ Long sequences (> 500 steps)
- ❌ Modern NLP (Transformer standard)
- ❌ Fast training needed (TCN better)

**Code Example:**
```python
import torch.nn as nn

class SeqLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size,
                            num_layers=2,
                            batch_first=True,
                            bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, features)
        lstm_out, _ = self.lstm(x)
        # Use last timestep
        out = self.fc(lstm_out[:, -1, :])
        return out
```

**Status:** Legacy but still useful for specific cases (small data, edge deployment)

### 3. GRU (Gated Recurrent Unit) - Simplified LSTM

**Architecture:** Simplified gating (2 gates instead of 3)

**Advantages over LSTM:**
- Fewer parameters (faster training)
- Similar performance in many tasks
- Lower memory

**Disadvantages:**
- Still sequential (same as LSTM)
- No major advantage over LSTM in practice
- Also superseded by Transformers

**When to Use:**
- Same as LSTM, but prefer LSTM for slightly better performance
- Use GRU if the computational savings matter

**Status:** Rarely recommended - if using recurrent, prefer LSTM or move to Transformer/TCN

### 4. Transformer - Modern Standard

**Architecture:** Self-attention mechanism, parallel processing

**Complexity:**
- Memory: O(n²) for sequence length n
- Compute: O(n²d) where d is embedding dimension

**Strengths:**
- ✅ Parallel processing (fast training)
- ✅ Captures long-range dependencies (better than LSTM)
- ✅ State-of-the-art for language (BERT, GPT)
- ✅ Pre-trained models available
- ✅ Scales with data (more data = better performance)

**Weaknesses:**
- ❌ Quadratic memory (struggles with sequences > 1000)
- ❌ Needs more data than LSTM (> 10k examples)
- ❌ Slower inference than TCN
- ❌ Harder to interpret than RNN

**When to Use:**
- ✅ **Natural language** (current standard)
- ✅ Medium sequences (100-1000 tokens)
- ✅ Large datasets (> 50k examples)
- ✅ Pre-training available (BERT, GPT)
- ✅ Accuracy priority

**When NOT to Use:**
- ❌ Short sequences (< 50 tokens) - LSTM/CNN competitive, simpler
- ❌ Very long sequences (> 2000) - quadratic memory explodes
- ❌ Small datasets (< 10k) - will overfit
- ❌ Edge deployment - large model size

**Memory Analysis:**
```python
# Standard Transformer attention

# For sequence length n=1000, batch_size=32, embedding_dim=512:
attention_weights = softmax(Q @ K.T / sqrt(d))  # Shape: (32, 1000, 1000)
# Memory: 32 * 1000 * 1000 * 4 bytes = 128 MB just for attention!

# For n=5000:
# Memory: 32 * 5000 * 5000 * 4 bytes = 3.2 GB per batch!
# → Impossible on most GPUs
```
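
The same arithmetic as a tiny helper, handy for budgeting before committing to an architecture (a sketch; it counts one attention-weight tensor only, ignoring heads, layers, and activations):

```python
def attention_memory_mb(batch: int, seq_len: int, bytes_per_el: int = 4) -> float:
    """Memory for one (batch, n, n) attention-weight tensor, in MB."""
    return batch * seq_len ** 2 * bytes_per_el / 1e6

print(attention_memory_mb(32, 1000))  # 128.0 MB
print(attention_memory_mb(32, 5000))  # 3200.0 MB
```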

**Code Example:**
```python
import torch.nn as nn
from transformers import BertModel, BertTokenizer

# Pre-trained BERT for text classification
class TransformerClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # Use [CLS] token representation
        pooled = outputs.pooler_output
        return self.classifier(pooled)

# Fine-tuning
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TransformerClassifier(num_classes=2)
```

**Status:** **Current standard for NLP**, competitive for time series with large data

### 5. TCN (Temporal Convolutional Network) - Underrated Alternative

**Architecture:** 1D convolutions with dilated causal convolutions

**Complexity:** O(n) memory, fully parallel processing

**Strengths:**
- ✅ **Parallel training** (much faster than LSTM)
- ✅ **Parallel inference** (faster than LSTM/Transformer)
- ✅ Linear memory (no quadratic blow-up)
- ✅ Large receptive field (dilation)
- ✅ Works well for time series
- ✅ Simple architecture

**Weaknesses:**
- ❌ Less popular (fewer pre-trained models)
- ❌ Not standard for language (Transformer dominates)
- ❌ Fixed receptive field (vs adaptive attention)

**When to Use:**
- ✅ **Time series forecasting** (often BETTER than LSTM)
- ✅ **Fast training needed** (2-3x faster than LSTM)
- ✅ **Fast inference** (real-time applications)
- ✅ Long sequences (linear memory)
- ✅ Audio processing (WaveNet is TCN-based)

**When NOT to Use:**
- ❌ Natural language with pre-training available (use Transformer)
- ❌ Need very large receptive field (Transformer better)

**Performance Comparison:**
```
Time series forecasting (1000-step sequences):

Training speed:
- LSTM: 100% (baseline, sequential)
- TCN: 35% (2.8x faster, parallel)
- Transformer: 45% (2.2x faster)

Inference speed:
- LSTM: 100% (sequential)
- TCN: 20% (5x faster, parallel)
- Transformer: 60% (1.7x faster)

Accuracy (similar across all three):
- LSTM: Baseline
- TCN: Equal or slightly better
- Transformer: Equal or slightly better (needs more data)

Conclusion: TCN wins on speed, matches accuracy
```

**Code Example:**
```python
import torch.nn as nn

class Chomp1d(nn.Module):
    """Trim trailing padding so the convolution stays causal."""
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TCN(nn.Module):
    def __init__(self, input_channels, num_channels, kernel_size=3):
        super().__init__()

        layers = []
        num_levels = len(num_channels)

        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = input_channels if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size

            # Causal dilated convolution: pad, then chomp the right side
            # so outputs never see future timesteps
            layers.append(
                nn.Conv1d(in_channels, out_channels, kernel_size,
                          padding=padding,
                          dilation=dilation_size)
            )
            layers.append(Chomp1d(padding))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Usage for time series
# Input: (batch, channels, sequence_length)
model = TCN(input_channels=1, num_channels=[64, 128, 256])
```

**Key Insight:** Dilated convolutions create a receptive field that grows exponentially with depth (≈2^L for L levels) while maintaining linear memory
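
For instance, the receptive field of the stack above can be computed directly (a quick check under the doubling-dilation scheme; each level adds (kernel_size - 1) · dilation timesteps):

```python
def tcn_receptive_field(num_levels: int, kernel_size: int = 3) -> int:
    """Receptive field of stacked causal convs with dilation 2**i."""
    rf = 1
    for i in range(num_levels):
        rf += (kernel_size - 1) * (2 ** i)
    return rf

print(tcn_receptive_field(3))  # 15 timesteps for the 3-level example above
print(tcn_receptive_field(8))  # 511 timesteps with 8 levels
```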

**Status:** **Excellent for time series**, underrated, should be considered before LSTM

### 6. Sparse Transformers - Long Sequence Specialists

**Architecture:** Modified attention patterns to reduce complexity

**Variants:**
- **Longformer**: Local + global attention
- **BigBird**: Random + local + global attention
- **Linformer**: Low-rank projection of keys/values
- **Performer**: Kernel approximation of attention

**Complexity:** O(n log n) or O(n) depending on variant

**When to Use:**
- ✅ **Long sequences** (1000-10000 tokens)
- ✅ Document processing (multi-page documents)
- ✅ Long-context language modeling
- ✅ When standard Transformer runs out of memory

**Trade-offs:**
- Slightly lower accuracy than full attention (approximation)
- More complex implementation
- Fewer pre-trained models

**Example Use Cases:**
- Legal document analysis (10k+ tokens)
- Scientific paper understanding
- Long-form text generation
- Time series with thousands of steps

**Status:** Specialized for long sequences, active research area

### 7. State Space Models (S4) - Cutting Edge

**Architecture:** Structured state space with efficient recurrence

**Complexity:** O(n log n) training, O(n) inference

**Strengths:**
- ✅ **Very long sequences** (10k-100k steps)
- ✅ Linear inference complexity
- ✅ Strong theoretical foundations
- ✅ Handles continuous-time sequences

**Weaknesses:**
- ❌ Newer (less mature ecosystem)
- ❌ Complex mathematics
- ❌ Fewer pre-trained models
- ❌ Harder to implement

**When to Use:**
- ✅ Extremely long sequences (> 10k steps)
- ✅ Audio (raw waveforms, 16kHz sampling)
- ✅ Medical signals (ECG, EEG)
- ✅ Research applications

**Status:** **Cutting edge** (2022+), promising for very long sequences

## Practical Selection Guide

### Scenario 1: Natural Language Processing

**Short text (< 50 tokens, e.g., tweets, titles):**
```
Small dataset (< 10k):
→ BiLSTM or 1D CNN (simple, effective)

Large dataset (> 10k):
→ DistilBERT (smaller Transformer, 66M params)
→ Or BiLSTM if latency critical
```

**Medium text (50-512 tokens, e.g., reviews, articles):**
```
Standard approach:
→ BERT, RoBERTa, or similar (110M params)
→ Fine-tune on task-specific data

Small dataset:
→ DistilBERT (66M params, faster, similar accuracy)
```

**Long documents (> 512 tokens):**
```
→ Longformer (4096 tokens max)
→ BigBird (4096 tokens max)
→ Hierarchical: Process in chunks, aggregate
```

### Scenario 2: Time Series Forecasting

**Short sequences (< 100 steps):**
```
Fast training:
→ TCN (2-3x faster than LSTM)

Small dataset:
→ LSTM or simple models (ARIMA, Prophet)

Baseline:
→ LSTM (well-tested)
```

**Medium sequences (100-1000 steps):**
```
Best accuracy:
→ Transformer (if data > 50k examples)

Fast training/inference:
→ TCN (parallel processing)

Multivariate:
→ Transformer with cross-series attention
```

**Long sequences (> 1000 steps):**
```
→ Sparse Transformer (Informer for time series)
→ Hierarchical models (chunk + aggregate)
→ State Space Models (S4)
```

### Scenario 3: Audio Processing

**Waveform (raw audio, 16kHz):**
```
→ WaveNet (TCN-based)
→ State Space Models (S4)
```

**Spectrograms (mel-spectrograms):**
```
→ CNN + BiLSTM (traditional)
→ CNN + Transformer (modern)
```

**Speech recognition:**
```
→ Transformer (Wav2Vec 2.0, Whisper)
→ Pre-trained models available
```

## Trade-Off Analysis

### Speed Comparison

**Training speed (1000-step sequences):**
```
LSTM: 100% (baseline, sequential)
GRU: 75% (simpler gates)
TCN: 35% (2.8x faster, parallel)
Transformer: 45% (2.2x faster, parallel)

Conclusion: TCN fastest for training
```

**Inference speed:**
```
LSTM: 100% (sequential)
BiLSTM: 200% (2x passes)
TCN: 20% (5x faster, parallel)
Transformer: 60% (faster, but attention overhead)

Conclusion: TCN fastest for inference
```

### Memory Comparison

**Sequence length n=1000, batch=32:**
```
LSTM: ~500 MB (linear in n)
Transformer: ~2 GB (quadratic in n)
TCN: ~400 MB (linear in n)
Sparse Transformer: ~800 MB (n log n)

For n=5000:
LSTM: ~2 GB
Transformer: OUT OF MEMORY (50 GB needed!)
TCN: ~2 GB
Sparse Transformer: ~4 GB
```

### Accuracy vs Data Size

**Small dataset (< 10k examples):**
```
LSTM: ★★★★☆ (works well with little data)
Transformer: ★★☆☆☆ (overfits, needs more data)
TCN: ★★★★☆ (similar to LSTM)

Winner: LSTM or TCN
```

**Large dataset (> 100k examples):**
```
LSTM: ★★★☆☆ (good but plateaus)
Transformer: ★★★★★ (best, scales with data)
TCN: ★★★★☆ (competitive)

Winner: Transformer
```

## Common Pitfalls

### Pitfall 1: Using LSTM in 2025 Without Considering Modern Alternatives
**Symptom:** Defaulting to LSTM for all sequence tasks

**Why it's wrong:** Transformers (language) and TCN (time series) often better

**Fix:** Consider Transformer for language, TCN for time series, LSTM for small data/edge only

### Pitfall 2: Using Standard Transformer for Very Long Sequences
**Symptom:** Running out of memory on sequences > 1000 tokens

**Why it's wrong:** O(n²) memory explodes

**Fix:** Use Sparse Transformer (Longformer, BigBird) or hierarchical approach

### Pitfall 3: Not Trying TCN for Time Series
**Symptom:** Struggling with slow LSTM training

**Why it's wrong:** TCN is 2-3x faster, often more accurate

**Fix:** Try TCN before optimizing LSTM

### Pitfall 4: Using Transformer for Small Datasets
**Symptom:** Transformer overfits on < 10k examples

**Why it's wrong:** Transformers need large datasets to work well

**Fix:** Use LSTM or TCN for small datasets, or use pre-trained Transformer

### Pitfall 5: Ignoring Sequence Length Constraints
**Symptom:** Choosing architecture without considering typical sequence length

**Why it's wrong:** Architecture effectiveness varies dramatically with length

**Fix:** Match architecture to sequence length (short → LSTM/CNN, long → Sparse Transformer)
|
## Evolution Timeline
|
||||||
|
|
||||||
|
**Understanding why architectures evolved:**
|
||||||
|
|
||||||
|
```
|
||||||
|
2010-2013: Basic RNN
|
||||||
|
→ Vanishing gradient problem
|
||||||
|
→ Can't learn long dependencies
|
||||||
|
|
||||||
|
2014: LSTM (Hochreiter & Schmidhuber)
|
||||||
|
→ Gates solve vanishing gradient
|
||||||
|
→ Became standard for sequences
|
||||||
|
|
||||||
|
2014: GRU
|
||||||
|
→ Simplified LSTM
|
||||||
|
→ Similar performance, fewer parameters
|
||||||
|
|
||||||
|
2017: Transformer (Attention Is All You Need)
|
||||||
|
→ Self-attention replaces recurrence
|
||||||
|
→ Parallel processing (fast training)
|
||||||
|
→ Revolutionized NLP
|
||||||
|
|
||||||
|
2018: TCN (Temporal Convolutional Networks)
|
||||||
|
→ Dilated convolutions for sequences
|
||||||
|
→ Often better than LSTM for time series
|
||||||
|
→ Underrated alternative
|
||||||
|
|
||||||
|
2020: Sparse Transformers
|
||||||
|
→ Reduce quadratic complexity
|
||||||
|
→ Enable longer sequences
|
||||||
|
|
||||||
|
2021: State Space Models (S4)
|
||||||
|
→ Very long sequences (10k-100k)
|
||||||
|
→ Theoretical foundations
|
||||||
|
→ Cutting edge research
|
||||||
|
|
||||||
|
Current (2025):
|
||||||
|
- NLP: Transformer standard (BERT, GPT)
|
||||||
|
- Time Series: TCN or Transformer
|
||||||
|
- Audio: Specialized (WaveNet, Transformer)
|
||||||
|
- Edge: LSTM or TCN (low memory)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Decision Checklist
|
||||||
|
|
||||||
|
Before choosing sequence model:
|
||||||
|
|
||||||
|
```
|
||||||
|
☐ Sequence length? (< 100 / 100-1k / > 1k)
|
||||||
|
☐ Data type? (language / time series / audio / other)
|
||||||
|
☐ Dataset size? (< 10k / 10k-100k / > 100k)
|
||||||
|
☐ Latency requirement? (real-time / batch / offline)
|
||||||
|
☐ Deployment target? (cloud / edge / mobile)
|
||||||
|
☐ Pre-trained models available? (yes / no)
|
||||||
|
☐ Training speed critical? (yes / no)
|
||||||
|
|
||||||
|
Based on answers:
|
||||||
|
→ Language + large data → Transformer
|
||||||
|
→ Language + small data → BiLSTM or DistilBERT
|
||||||
|
→ Time series + speed → TCN
|
||||||
|
→ Time series + accuracy + large data → Transformer
|
||||||
|
→ Very long sequences → Sparse Transformer or S4
|
||||||
|
→ Edge deployment → TCN or LSTM
|
||||||
|
→ Real-time latency → TCN
|
||||||
|
```
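
The same logic can be encoded as a small helper. A minimal sketch; the function name, thresholds, and return strings are illustrative, not canonical:

```python
def choose_sequence_model(task: str, seq_len: int, dataset_size: int,
                          edge: bool = False, realtime: bool = False) -> str:
    """Rough encoding of the checklist above. Thresholds are heuristics."""
    if seq_len > 1000:
        return "Sparse Transformer or S4"
    if edge:
        return "TCN or LSTM"
    if realtime:
        return "TCN"
    if task == "language":
        return "Transformer" if dataset_size > 100_000 else "BiLSTM or DistilBERT"
    if task == "time_series":
        return "Transformer" if dataset_size > 100_000 else "TCN"
    return "TCN"  # reasonable default for other sequence data

print(choose_sequence_model("language", 200, 500_000))  # Transformer
```
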
## Integration with Other Skills

**For language-specific questions:**
→ `yzmir/llm-specialist/using-llm-specialist`
- LLM-specific Transformers (GPT, BERT variants)
- Fine-tuning strategies
- Prompt engineering

**For Transformer internals:**
→ `yzmir/neural-architectures/transformer-architecture-deepdive`
- Attention mechanisms
- Positional encoding
- Transformer variants

**After selecting architecture:**
→ `yzmir/training-optimization/using-training-optimization`
- Optimizer selection
- Learning rate schedules
- Handling sequence-specific training issues

## Summary

**Quick Reference Table:**

| Use Case | Best Choice | Alternative | Avoid |
|----------|-------------|-------------|-------|
| Short text (< 50 tokens) | BiLSTM, DistilBERT | 1D CNN | Full BERT (overkill) |
| Long text (> 512 tokens) | Longformer, BigBird | Hierarchical | Standard BERT (memory) |
| Time series (< 1k steps) | TCN, Transformer | LSTM | Basic RNN |
| Time series (> 1k steps) | Sparse Transformer, S4 | Hierarchical | Standard Transformer |
| Small dataset (< 10k) | LSTM, TCN | Simple models | Transformer (overfits) |
| Large dataset (> 100k) | Transformer | TCN | LSTM (plateaus) |
| Edge deployment | TCN, LSTM | Quantized Transformer | Large Transformer |
| Real-time inference | TCN | Small LSTM | BiLSTM, Transformer |

**Key Principles:**
1. **Don't default to LSTM** (outdated for most tasks)
2. **Transformer for language** (current standard, if data is sufficient)
3. **TCN for time series** (fast, effective, underrated)
4. **Match to sequence length** (short → LSTM/CNN, long → Sparse Transformer)
5. **Consider modern alternatives** (don't stop at LSTM vs Transformer)

**END OF SKILL**

# Transformer Architecture Deep Dive

## When to Use This Skill

Use this skill when you need to:
- ✅ Implement a Transformer from scratch
- ✅ Understand HOW and WHY self-attention works
- ✅ Choose between encoder, decoder, or encoder-decoder architectures
- ✅ Decide if Vision Transformer (ViT) is appropriate for your vision task
- ✅ Understand modern variants (RoPE, ALiBi, GQA, MQA)
- ✅ Debug Transformer implementation issues
- ✅ Optimize Transformer performance

**Do NOT use this skill for:**
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ Attention mechanism comparison (use `attention-mechanisms-catalog`)
- ❌ LLM-specific topics like prompt engineering (use `llm-specialist` pack)

## Core Principle

**Transformers are NOT magic.** They are:
1. Self-attention mechanism (information retrieval)
2. + Position encoding (break permutation invariance)
3. + Residual connections + Layer norm (training stability)
4. + Feed-forward networks (non-linearity)

Understanding the mechanism beats cargo-culting implementations.

## Part 1: Self-Attention Mechanism Explained

### The Information Retrieval Analogy

**Self-attention = Querying a database:**
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "What information do I have?"

**Process:**
1. Compare your query with all keys (compute similarity)
2. Weight values by similarity
3. Return weighted sum of values

**Example:** Sentence: "The cat sat on the mat"
Token "sat" (verb):
- High attention to "cat" (subject) → Learns verb-subject relationship
- High attention to "mat" (object) → Learns verb-object relationship
- Low attention to "the", "on" (function words)

### Mathematical Breakdown

**Given input X:** (batch, seq_len, d_model)

**Step 1: Project to Q, K, V**
```python
Q = X @ W_Q  # (batch, seq_len, d_k)
K = X @ W_K  # (batch, seq_len, d_k)
V = X @ W_V  # (batch, seq_len, d_v)

# Typically: d_k = d_v = d_model / num_heads
```

**Step 2: Compute attention scores** (similarity)
```python
scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
# scores[i, j] = similarity between query_i and key_j
```

**Geometric interpretation:**
- Dot product measures vector alignment
- q · k = ||q|| ||k|| cos(θ)
- Similar vectors → Large dot product → High attention
- Orthogonal vectors → Zero dot product → No attention

**Step 3: Scale by √d_k** (CRITICAL!)
```python
scores = scores / math.sqrt(d_k)
```

**WHY scaling?**
- Dot products grow with dimension: for unit-variance components, Var(q · k) = d_k
- Example: d_k=64 → dot-product standard deviation is √64 = 8, so unscaled scores in the tens are routine
- Large scores → Softmax saturates → Gradients vanish
- Scaling: keeps scores ~ O(1) regardless of dimension

**Without scaling:** Softmax([30, 25, 20]) ≈ [0.99, 0.01, 0.00] (saturated!)
**With scaling:** Softmax([3, 2.5, 2]) ≈ [0.50, 0.30, 0.20] (healthy gradients)
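
A quick numeric check of the saturation effect (a throwaway sketch; the vectors are random, so exact values vary run to run):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 64
q, k = torch.randn(5, d_k), torch.randn(5, d_k)

raw = q @ k.T                       # entry std grows like sqrt(d_k) ≈ 8
scaled = raw / d_k ** 0.5           # back to ~unit scale

print(raw.std(), scaled.std())      # ~8 vs ~1
print(F.softmax(raw[0], dim=-1))    # often near one-hot (saturated)
print(F.softmax(scaled[0], dim=-1)) # smoother distribution, healthier gradients
```
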
**Step 4: Softmax to get attention weights**
```python
attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
# Each row sums to 1 (probability distribution)
# attn_weights[i, j] = "how much token i attends to token j"
```

**Step 5: Weight values**
```python
output = attn_weights @ V  # (batch, seq_len, d_v)
# Each token's output = weighted average of all values
```

**Complete formula:**
```
Attention(Q, K, V) = softmax(Q K^T / √d_k) V
```

### Why Three Matrices (Q, K, V)?

**Could we use just one?** Attention(X, X, X)
**Yes, but Q/K/V separation enables:**

1. **Asymmetry**: Query can differ from key (search ≠ database)
2. **Decoupling**: What you search for (Q@K) ≠ what you retrieve (V)
3. **Cross-attention**: Q from one source, K/V from another
   - Example: Decoder queries encoder (translation)

**Modern optimization:** Multi-Query Attention (MQA), Grouped-Query Attention (GQA)
- Share K/V across heads (fewer parameters, faster inference)

### Implementation Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k=None):
        super().__init__()
        self.d_k = d_k or d_model

        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_k)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (batch, seq_len, seq_len)

        # Apply mask if provided (for causal attention)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)

        # Weighted sum of values
        output = torch.matmul(attn_weights, V)  # (batch, seq_len, d_k)

        return output, attn_weights
```

**Complexity:** O(n² · d) where n = seq_len, d = d_model
- **Quadratic in sequence length** (bottleneck for long sequences)
- For n=1000, d=512: 1000² × 512 = 512M operations
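
To see the shapes end to end, a quick usage sketch of the `SelfAttention` module above (toy sizes):

```python
attn = SelfAttention(d_model=512)
x = torch.randn(2, 10, 512)       # (batch=2, seq_len=10, d_model=512)
out, weights = attn(x)
print(out.shape)                  # torch.Size([2, 10, 512])
print(weights.shape)              # torch.Size([2, 10, 10])
print(weights.sum(dim=-1))        # every row of attention weights sums to 1
```
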
## Part 2: Multi-Head Attention

### Why Multiple Heads?

**Single-head attention** learns one attention pattern.
**Multi-head attention** learns multiple parallel patterns:
- Head 1: Syntactic relationships (subject-verb)
- Head 2: Semantic similarity
- Head 3: Positional proximity
- Head 4: Long-range dependencies

**Analogy:** An ensemble of attention functions, each specializing in different patterns.

### Head Dimension Calculation

**CRITICAL CONSTRAINT:** num_heads must divide d_model evenly!

```python
d_model = 512
num_heads = 8
d_k = d_model // num_heads  # 512 / 8 = 64

# Each head operates in d_k dimensions
# Concatenate all heads → back to d_model dimensions
```

**Common configurations:**
- BERT-base: d_model=768, heads=12, d_k=64
- GPT-2: d_model=768, heads=12, d_k=64
- GPT-3 175B: d_model=12288, heads=96, d_k=128
- LLaMA-2 70B: d_model=8192, heads=64, d_k=128

**Rule of thumb:** d_k (head dimension) should be 64-128
- Too small (d_k < 32): Limited representational capacity
- Too large (d_k > 256): Redundant, wasteful

### Multi-Head Implementation

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Single linear layers for all heads (more efficient)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # Output projection

    def split_heads(self, x):
        # x: (batch, seq_len, d_model)
        batch_size, seq_len, d_model = x.size()
        # Reshape to (batch, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        # Transpose to (batch, num_heads, seq_len, d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size = x.size(0)

        # Linear projections
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum
        attn_output = torch.matmul(attn_weights, V)
        # attn_output: (batch, num_heads, seq_len, d_k)

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        # (batch, seq_len, num_heads, d_k)
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # (batch, seq_len, d_model)

        # Final linear projection
        output = self.W_o(attn_output)

        return output, attn_weights
```

### Modern Variants: GQA and MQA

**Problem:** K/V caching during inference is memory-intensive
- LLaMA-2 70B with full multi-head K/V: 2 (K + V) × 80 layers × 8192 dims ≈ 1.3M values cached per token!

**Solution 1: Multi-Query Attention (MQA)**
- **One** K/V head shared across **all** Q heads
- Benefit: Dramatically faster inference (smaller KV cache)
- Trade-off: ~1-2% accuracy loss

```python
# MQA: Single K/V projection
self.W_k = nn.Linear(d_model, d_k)  # Not d_model!
self.W_v = nn.Linear(d_model, d_k)
self.W_q = nn.Linear(d_model, d_model)  # Multiple Q heads
```

**Solution 2: Grouped-Query Attention (GQA)**
- Middle ground: Group multiple Q heads per K/V head (a code sketch follows this list)
- Example: 32 Q heads → 8 K/V heads (4 Q per K/V)
- Benefit: 4x smaller KV cache, minimal accuracy loss
**Used in:** LLaMA-2, Mistral, Mixtral
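
A minimal GQA sketch (illustrative, not any particular library's implementation; it reuses the imports from the examples above, and `repeat_interleave` expands each K/V head to serve its group of query heads):

```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_q_heads, num_kv_heads):
        super().__init__()
        assert d_model % num_q_heads == 0 and num_q_heads % num_kv_heads == 0
        self.d_k = d_model // num_q_heads
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_k)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)  # smaller K/V projections
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.num_q_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)

        group = self.num_q_heads // self.num_kv_heads
        k = k.repeat_interleave(group, dim=1)  # each K/V head serves `group` Q heads
        v = v.repeat_interleave(group, dim=1)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        out = F.softmax(scores, dim=-1) @ v
        out = out.transpose(1, 2).reshape(B, T, -1)
        return self.W_o(out)
```

Only the smaller K/V tensors need caching at inference time, which is exactly where the memory savings come from.
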
## Part 3: Position Encoding

### Why Position Encoding?

**Problem:** Self-attention is **permutation-invariant**
- Attention("cat sat mat") = Attention("mat cat sat")
- No inherent notion of position or order!

**Solution:** Add position information to embeddings

### Strategy 1: Sinusoidal Position Encoding (Original)

**Formula:**
```
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
```

**Implementation:**
```python
def sinusoidal_position_encoding(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                         (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# Usage: Add to input embeddings
x = token_embeddings + positional_encoding
```

**Properties:**
- Deterministic (no learned parameters)
- Extrapolates to unseen lengths (geometric properties)
- Relative positions: PE(pos+k) is a linear function of PE(pos)

**When to use:** Variable-length sequences in NLP

### Strategy 2: Learned Position Embeddings

```python
self.pos_embedding = nn.Embedding(max_seq_len, d_model)

# Usage
positions = torch.arange(seq_len, device=x.device)
x = token_embeddings + self.pos_embedding(positions)
```

**Properties:**
- Learnable (adapts to data)
- Cannot extrapolate beyond max_seq_len

**When to use:**
- Fixed-length sequences
- Vision Transformers (image patches)
- When training data covers all positions

### Strategy 3: Rotary Position Embeddings (RoPE) ⭐

**Modern approach (2021+):** Rotate Q and K in the complex plane

**Key advantages:**
- Encodes **relative** positions naturally
- Better long-range decay properties
- No addition to embeddings (applied inside attention)

**Used in:** GPT-NeoX, PaLM, LLaMA, LLaMA-2, Mistral

```python
def apply_rotary_pos_emb(x, cos, sin):
    # x: (batch, num_heads, seq_len, d_k)
    # Split features into even/odd pairs
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Rotate each pair by its position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```
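
The `cos`/`sin` tables come from position-dependent angles at geometrically spaced frequencies. A minimal builder consistent with the pairing convention above (the function name and `base` default are conventional assumptions, not a specific library's API):

```python
def build_rope_cache(seq_len, d_k, base=10000.0):
    # One frequency per feature pair, geometrically spaced as in sinusoidal PE
    inv_freq = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, d_k/2)
    return angles.cos(), angles.sin()          # broadcast against (batch, heads, seq, d_k/2)

cos, sin = build_rope_cache(seq_len=128, d_k=64)
# q = apply_rotary_pos_emb(q, cos, sin); k = apply_rotary_pos_emb(k, cos, sin)
```
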
### Strategy 4: ALiBi (Attention with Linear Biases) ⭐

**Simplest modern approach:** Add a bias to the attention scores (no embeddings!)

```python
# Bias matrix: -1 × distance to each past key (future positions are
# handled by the causal mask); e.g. for seq_len=4:
# [[ 0,  ·,  ·,  ·],
#  [-1,  0,  ·,  ·],
#  [-2, -1,  0,  ·],
#  [-3, -2, -1,  0]]

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) + alibi_bias
```
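
A minimal builder for that bias (a sketch following the ALiBi paper's slope scheme 2^(−8(h+1)/num_heads); the causal mask itself is applied separately):

```python
def build_alibi_bias(num_heads, seq_len):
    # Per-head slopes: geometric sequence, steeper heads decay faster
    slopes = torch.tensor([2 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).clamp(min=0).float()  # (i - j) for past keys
    return -slopes[:, None, None] * distance  # (num_heads, seq_len, seq_len)

alibi_bias = build_alibi_bias(num_heads=8, seq_len=4)
print(alibi_bias[0])  # head 0: zeros on the diagonal, increasingly negative into the past
```
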
**Key advantages:**
- **Best extrapolation** to longer sequences
- No positional embeddings (simpler)
- Per-head slopes (different decay rates)

**Used in:** BLOOM

### Position Encoding Selection Guide

| Use Case | Recommended | Why |
|----------|-------------|-----|
| NLP (variable length) | RoPE or ALiBi | Better extrapolation |
| NLP (fixed length) | Learned embeddings | Adapts to data |
| Vision (ViT) | 2D learned embeddings | Spatial structure |
| Long sequences (>2k) | ALiBi | Best extrapolation |
| Legacy/compatibility | Sinusoidal | Original Transformer |

**Modern trend (2023+):** RoPE and ALiBi dominate over sinusoidal

## Part 4: Architecture Variants

### Variant 1: Encoder-Only (Bidirectional)

**Architecture:**
- Self-attention: Each token attends to **ALL** tokens (past + future)
- No masking (bidirectional context)

**Examples:** BERT, RoBERTa, ELECTRA, DeBERTa

**Use cases:**
- Text classification
- Named entity recognition
- Question answering (extract span from context)
- Sentence embeddings

**Key property:** Sees full context → Good for **understanding**

**Implementation:**
```python
class TransformerEncoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)  # No causal mask!
        return x

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention + residual + norm
        attn_output, _ = self.self_attn(x, mask)
        x = self.norm1(x + attn_output)

        # Feed-forward + residual + norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)

        return x
```

### Variant 2: Decoder-Only (Autoregressive)

**Architecture:**
- Self-attention with **causal masking**
- Each token attends ONLY to past tokens (not future)

**Causal mask (lower triangular):**
```python
# mask[i, j] = 1 if j <= i else 0
[[1, 0, 0, 0],  # Token 0 sees only itself
 [1, 1, 0, 0],  # Token 1 sees tokens 0-1
 [1, 1, 1, 0],  # Token 2 sees tokens 0-2
 [1, 1, 1, 1]]  # Token 3 sees all
```

**Examples:** GPT, GPT-2, GPT-3, GPT-4, LLaMA, Mistral

**Use cases:**
- Text generation
- Language modeling
- Code generation
- Autoregressive prediction

**Key property:** Generates sequentially → Good for **generation**

**Implementation:**
```python
def create_causal_mask(seq_len, device):
    # Lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask

class TransformerDecoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        seq_len = x.size(1)
        causal_mask = create_causal_mask(seq_len, x.device)

        for layer in self.layers:
            x = layer(x, causal_mask)  # Apply causal mask!
        return x
```

**Modern trend (2023+):** Decoder-only architectures dominate
- Can do both generation AND understanding (via prompting)
- Simpler than encoder-decoder (no cross-attention)
- Scales better to massive sizes

### Variant 3: Encoder-Decoder (Seq2Seq)

**Architecture:**
- **Encoder**: Bidirectional self-attention (understands input)
- **Decoder**: Causal self-attention (generates output)
- **Cross-attention**: Decoder queries encoder outputs

**Cross-attention mechanism:**
```python
# Q from the decoder, K and V from the encoder
Q = decoder_hidden @ W_q
K = encoder_output @ W_k
V = encoder_output @ W_v

cross_attn = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1) @ V
```
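
As a runnable counterpart to that pseudocode, a minimal single-head cross-attention module (a sketch; the `EncoderDecoderLayer` below assumes something like this behind its cross-attention call):

```python
class CrossAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_k = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        # query: decoder states (batch, tgt_len, d_model)
        # key/value: encoder output (batch, src_len, d_model)
        Q, K, V = self.W_q(query), self.W_k(key), self.W_v(value)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        return F.softmax(scores, dim=-1) @ V  # (batch, tgt_len, d_model)
```
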
**Examples:** T5, BART, mT5, original Transformer (2017)

**Use cases:**
- Translation (input ≠ output language)
- Summarization (long input → short output)
- Question answering (generate answer, not extract)

**When to use:** Input and output are fundamentally different

**Implementation:**
```python
class EncoderDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  # NEW!
        self.feed_forward = FeedForward(d_model, d_ff)  # e.g. the Linear→ReLU→Linear block from EncoderLayer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        # 1. Self-attention (causal)
        self_attn_out, _ = self.self_attn(decoder_input, causal_mask)
        x = self.norm1(decoder_input + self_attn_out)

        # 2. Cross-attention (Q from decoder, K/V from encoder); assumes
        #    MultiHeadAttention is extended with a cross-attention forward
        #    taking separate query/key/value (see CrossAttention above)
        cross_attn_out, _ = self.cross_attn.forward_cross(
            query=x,
            key=encoder_output,
            value=encoder_output
        )
        x = self.norm2(x + cross_attn_out)

        # 3. Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + ff_out)

        return x
```

### Architecture Selection Guide

| Task | Architecture | Why |
|------|--------------|-----|
| Classification | Encoder-only | Need full bidirectional context |
| Text generation | Decoder-only | Autoregressive generation |
| Translation | Encoder-decoder or Decoder-only | Different languages, or use prompting |
| Summarization | Encoder-decoder or Decoder-only | Length mismatch, or use prompting |
| Q&A (extract) | Encoder-only | Find span in context |
| Q&A (generate) | Decoder-only | Generate freeform answer |

**2023+ trend:** Decoder-only can do everything via prompting (but less parameter-efficient for some tasks)

## Part 5: Vision Transformers (ViT)

### From Images to Sequences

**Key insight:** Treat an image as a sequence of patches

**Process:**
1. Split image into patches (e.g., 16×16 pixels)
2. Flatten each patch → 1D vector
3. Linear projection → token embeddings
4. Add 2D positional embeddings
5. Prepend [CLS] token (for classification)
6. Feed to Transformer encoder

**Example:** 224×224 image, 16×16 patches
- Number of patches: (224/16)² = 196
- Each patch: 16 × 16 × 3 = 768 dimensions
- Transformer input: 197 tokens (196 patches + 1 [CLS])
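
These counts are quick to verify (a tiny check; the variable names are just illustrative):

```python
img_size, patch_size, channels = 224, 16, 3
num_patches = (img_size // patch_size) ** 2  # 196
patch_dim = channels * patch_size ** 2       # 768
num_tokens = num_patches + 1                 # 197, including [CLS]
```
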
### ViT Implementation

```python
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 d_model=768, num_heads=12, num_layers=12, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size ** 2

        # Patch embedding (linear projection of flattened patches)
        self.patch_embed = nn.Linear(patch_dim, d_model)

        # [CLS] token (learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Position embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        # Transformer encoder
        self.encoder = TransformerEncoder(d_model, num_heads,
                                          d_ff=4*d_model, num_layers=num_layers)

        # Classification head
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, channels, height, width)
        batch_size = x.size(0)

        # Divide into patches
        x = x.unfold(2, self.patch_size, self.patch_size)
        x = x.unfold(3, self.patch_size, self.patch_size)
        # (batch, channels, num_patches_h, num_patches_w, patch_size, patch_size)

        # Bring channels next to the per-patch pixels before flattening,
        # so each token holds one spatial patch across all channels
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, -1, self.patch_size ** 2 * 3)
        # (batch, num_patches, patch_dim)

        # Linear projection
        x = self.patch_embed(x)  # (batch, num_patches, d_model)

        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, num_patches+1, d_model)

        # Add positional embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.encoder(x)

        # Classification: Use [CLS] token
        cls_output = x[:, 0]  # (batch, d_model)
        logits = self.head(cls_output)

        return logits
```

### ViT vs CNN: Critical Differences

**1. Inductive Bias**

| Property | CNN | ViT |
|----------|-----|-----|
| Locality | Strong (conv kernel) | Weak (global attention) |
| Translation invariance | Strong (weight sharing) | Weak (position embeddings) |
| Hierarchy | Strong (pooling layers) | None (flat patches) |

**Implication:** CNNs have strong priors; ViT learns them from data

**2. Data Requirements**

| Dataset Size | CNN | ViT (from scratch) | ViT (pretrained) |
|--------------|-----|-------------------|------------------|
| Small (< 100k) | ✅ Good | ❌ Fails | ✅ Good |
| Medium (100k-1M) | ✅ Excellent | ⚠️ Poor | ✅ Good |
| Large (> 1M) | ✅ Excellent | ⚠️ OK | ✅ Excellent |
| Huge (> 100M) | ✅ Excellent | ✅ SOTA | N/A |

**Key finding:** ViT needs 100M+ images to train from scratch
- Original ViT: Trained on JFT-300M (300 million images)
- Without massive data, ViT underperforms CNNs significantly

**3. Computational Cost**

**Example: 224×224 images**

| Model | Parameters | GFLOPs | Inference (GPU) |
|-------|-----------|--------|-----------------|
| ResNet-50 | 25M | 4.1 | ~30ms |
| EfficientNet-B0 | 5M | 0.4 | ~10ms |
| ViT-B/16 | 86M | 17.6 | ~100ms |

**Implication:** ViT-B/16 costs roughly 40x more FLOPs than EfficientNet-B0!

### When to Use ViT

✅ **Use ViT when:**
- Large dataset (> 1M images) OR using pretrained weights
- Computational cost acceptable (cloud, large GPU)
- Best possible accuracy needed
- Can fine-tune from an ImageNet-21k checkpoint

❌ **Use CNN when:**
- Small/medium dataset (< 1M images) and training from scratch
- Limited compute/memory
- Edge deployment (mobile, embedded)
- Need architectural inductive biases

### Hybrid Approaches (2022-2023)

**ConvNeXt:** CNN with ViT design choices
- Matches ViT accuracy with CNN efficiency
- Works better on small datasets

**Swin Transformer:** Hierarchical ViT with local windows
- Shifted windows for efficiency
- O(n) complexity instead of O(n²)
- Better for dense prediction (segmentation)

**CoAtNet:** Mix of conv layers (early) + Transformer layers (late)
- Gets both inductive bias and global attention

## Part 6: Implementation Checklist

### Critical Details

**1. Layer Norm Placement**

**Post-norm (original):**
```python
x = x + self_attn(x)
x = layer_norm(x)
```

**Pre-norm (modern, recommended):**
```python
x = x + self_attn(layer_norm(x))
```
**Why pre-norm?** More stable training, less sensitive to learning rate
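
Put together, a pre-norm encoder block looks like this (a minimal sketch reusing the `MultiHeadAttention` module from Part 2):

```python
class PreNormEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn_out, _ = self.self_attn(self.norm1(x), mask)  # norm BEFORE the sublayer
        x = x + attn_out                                   # residual on the raw stream
        x = x + self.feed_forward(self.norm2(x))
        return x
```
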
**2. Attention Dropout**

Apply dropout to the **attention weights**, not to Q/K/V!

```python
attn_weights = F.softmax(scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=0.1, training=self.training)  # HERE!
output = torch.matmul(attn_weights, V)
```

**3. Feed-Forward Dimension**

Typically: d_ff = 4 × d_model
- BERT: d_model=768, d_ff=3072
- GPT-2: d_model=768, d_ff=3072

**4. Residual Connections**

ALWAYS use residual connections (essential for training)!

```python
x = x + self_attn(x)      # Residual
x = x + feed_forward(x)   # Residual
```

**5. Initialization**

Use Xavier/Glorot initialization for attention weights:
```python
nn.init.xavier_uniform_(self.W_q.weight)
nn.init.xavier_uniform_(self.W_k.weight)
nn.init.xavier_uniform_(self.W_v.weight)
```

## Part 7: When NOT to Use Transformers

### Limitation 1: Small Datasets

**Problem:** Transformers have weak inductive biases (almost everything is learned from data)

**Impact:**
- ViT: Fails on < 100k images without pretraining
- NLP: BERT needs 100M+ tokens for pretraining

**Solution:** Use models with stronger priors (CNNs for vision, smaller models for text)

### Limitation 2: Long Sequences

**Problem:** O(n²) memory complexity

**Impact:**
- Standard Transformer: n=10k → 100M attention scores
- GPU memory: 10k² × 4 bytes = 400MB per sample!

**Solution:**
- Sparse attention (Longformer, BigBird)
- Linear attention (Linformer, Performer)
- Flash Attention (memory-efficient kernel)
- State space models (S4, Mamba)

### Limitation 3: Edge Deployment

**Problem:** Large model size, high latency

**Impact:**
- ViT-B: 86M parameters, ~100ms inference
- Mobile/embedded: Need < 10M parameters, < 50ms

**Solution:** Efficient CNNs (MobileNet, EfficientNet) or distilled models

### Limitation 4: Real-Time Processing

**Problem:** Sequential generation in the decoder (cannot be parallelized at inference)

**Impact:** GPT-style models generate one token at a time

**Solution:** Non-autoregressive models, speculative decoding, or smaller models

## Part 8: Common Mistakes

### Mistake 1: Forgetting the Causal Mask

**Symptom:** Decoder "cheats" by seeing future tokens

**Fix:** Always apply a causal mask to decoder self-attention!

```python
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
```

### Mistake 2: Wrong Dimension for Multi-Head

**Symptom:** Runtime error or dimension mismatch

**Fix:** Ensure d_model % num_heads == 0

```python
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
```

### Mistake 3: Forgetting Position Encoding

**Symptom:** Model ignores word order

**Fix:** Always add position information!

```python
x = token_embeddings + positional_encoding
```

### Mistake 4: Wrong Softmax Dimension

**Symptom:** Attention weights don't sum to 1 per query

**Fix:** Softmax over the last dimension (keys)

```python
attn_weights = F.softmax(scores, dim=-1)  # Sum over keys for each query
```

### Mistake 5: No Residual Connections

**Symptom:** Training diverges or converges very slowly

**Fix:** Always add residual connections!

```python
x = x + self_attn(x)
x = x + feed_forward(x)
```

## Summary: Quick Reference

### Architecture Selection

```
Classification/Understanding → Encoder-only (BERT-style)
Generation/Autoregressive → Decoder-only (GPT-style)
Seq2Seq (input ≠ output) → Encoder-decoder (T5-style) or Decoder-only with prompting
```

### Position Encoding Selection

```
NLP (variable length) → RoPE or ALiBi
NLP (fixed length) → Learned embeddings
Vision (ViT) → 2D learned embeddings
Long sequences (> 2k) → ALiBi (best extrapolation)
```

### Multi-Head Configuration

```
Small models (d_model < 512): 4-8 heads
Medium models (d_model 512-1024): 8-12 heads
Large models (d_model > 1024): 12-32 heads
Rule: d_k (head dimension) should be 64-128
```

### ViT vs CNN

```
ViT: Large dataset (> 1M) OR pretrained weights
CNN: Small dataset (< 1M) OR edge deployment
```

### Implementation Essentials

```
✅ Pre-norm (more stable than post-norm)
✅ Residual connections (essential!)
✅ Causal mask for decoder
✅ Attention dropout (on weights, not Q/K/V)
✅ d_ff = 4 × d_model (feed-forward dimension)
✅ Check: d_model % num_heads == 0
```

## Next Steps

After mastering this skill:
- `attention-mechanisms-catalog`: Explore attention variants (sparse, linear, Flash)
- `llm-specialist/llm-finetuning-strategies`: Apply to language models
- `architecture-design-principles`: Understand design trade-offs

**Remember:** Transformers are NOT magic. Understanding the mechanism (information retrieval via Q/K/V) beats cargo-culting implementations.