From 955d5c6743c203021be86386ca4ea63230e60550 Mon Sep 17 00:00:00 2001 From: Zhongwei Li Date: Sun, 30 Nov 2025 09:00:00 +0800 Subject: [PATCH] Initial commit --- .claude-plugin/plugin.json | 12 + README.md | 3 + plugin.lock.json | 77 ++ skills/using-neural-architectures/SKILL.md | 496 +++++++++ .../architecture-design-principles.md | 960 ++++++++++++++++++ .../attention-mechanisms-catalog.md | 824 +++++++++++++++ .../cnn-families-and-selection.md | 622 ++++++++++++ .../generative-model-families.md | 811 +++++++++++++++ .../graph-neural-networks-basics.md | 625 ++++++++++++ .../normalization-techniques.md | 915 +++++++++++++++++ .../sequence-models-comparison.md | 714 +++++++++++++ .../transformer-architecture-deepdive.md | 937 +++++++++++++++++ 12 files changed, 6996 insertions(+) create mode 100644 .claude-plugin/plugin.json create mode 100644 README.md create mode 100644 plugin.lock.json create mode 100644 skills/using-neural-architectures/SKILL.md create mode 100644 skills/using-neural-architectures/architecture-design-principles.md create mode 100644 skills/using-neural-architectures/attention-mechanisms-catalog.md create mode 100644 skills/using-neural-architectures/cnn-families-and-selection.md create mode 100644 skills/using-neural-architectures/generative-model-families.md create mode 100644 skills/using-neural-architectures/graph-neural-networks-basics.md create mode 100644 skills/using-neural-architectures/normalization-techniques.md create mode 100644 skills/using-neural-architectures/sequence-models-comparison.md create mode 100644 skills/using-neural-architectures/transformer-architecture-deepdive.md diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json new file mode 100644 index 0000000..4d6349b --- /dev/null +++ b/.claude-plugin/plugin.json @@ -0,0 +1,12 @@ +{ + "name": "yzmir-neural-architectures", + "description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills", + "version": "1.0.1", + "author": { + "name": "tachyon-beep", + "url": "https://github.com/tachyon-beep" + }, + "skills": [ + "./skills" + ] +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..fb420ee --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# yzmir-neural-architectures + +Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills diff --git a/plugin.lock.json b/plugin.lock.json new file mode 100644 index 0000000..c798630 --- /dev/null +++ b/plugin.lock.json @@ -0,0 +1,77 @@ +{ + "$schema": "internal://schemas/plugin.lock.v1.json", + "pluginId": "gh:tachyon-beep/skillpacks:plugins/yzmir-neural-architectures", + "normalized": { + "repo": null, + "ref": "refs/tags/v20251128.0", + "commit": "52681919f25de737a14d958d80b685171797dfdb", + "treeHash": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc", + "generatedAt": "2025-11-28T10:28:34.235937Z", + "toolVersion": "publish_plugins.py@0.2.0" + }, + "origin": { + "remote": "git@github.com:zhongweili/42plugin-data.git", + "branch": "master", + "commit": "aa1497ed0949fd50e99e70d6324a29c5b34f9390", + "repoRoot": "/Users/zhongweili/projects/openmind/42plugin-data" + }, + "manifest": { + "name": "yzmir-neural-architectures", + "description": "Neural architectures - CNNs, Transformers, RNNs, selection guidance - 9 skills", + "version": "1.0.1" + }, + "content": { + "files": [ + { + "path": "README.md", + "sha256": "cf12418cc5cc226683c463dca6c54e98f97f2ddf57482d593839f430acc36b86" + }, + { + "path": ".claude-plugin/plugin.json", + "sha256": 
"9893529f120b1582ad62d899531401988bdf89d6b221e057239e09399d47e255" + }, + { + "path": "skills/using-neural-architectures/architecture-design-principles.md", + "sha256": "bd1cde3e2cf997b6e42dc52bc28638e644ee387c81b07464edc7fb4cf33d4329" + }, + { + "path": "skills/using-neural-architectures/sequence-models-comparison.md", + "sha256": "5c041c08261cf04e47edd3ec092db04a8b4c725fb6bb91620df54005881cbb66" + }, + { + "path": "skills/using-neural-architectures/normalization-techniques.md", + "sha256": "c87c82e926951a963d7909321c60a62b5cf6d09dbcbb576472eda46d72993766" + }, + { + "path": "skills/using-neural-architectures/graph-neural-networks-basics.md", + "sha256": "85a968972f0c1c226d6bcb155b231560af0eab52326f419d308b1a6a22ebd7c5" + }, + { + "path": "skills/using-neural-architectures/generative-model-families.md", + "sha256": "7c527707a0d3ecab99e83079880b1166c4310de62a743823ae747a03545180b3" + }, + { + "path": "skills/using-neural-architectures/cnn-families-and-selection.md", + "sha256": "acce2a55a84012259119e70d3eb035aee08beb124d0df1aab009ae7fdd01ff51" + }, + { + "path": "skills/using-neural-architectures/SKILL.md", + "sha256": "b8dc5fac750ca1740bd97591409701ca6162e0d5dc5faaf7a2cfcf7feef432b3" + }, + { + "path": "skills/using-neural-architectures/attention-mechanisms-catalog.md", + "sha256": "3eab5267c7e5a1263d08be2dca4ef236ad66b554e08aec0e0cd85d5c1f8b5500" + }, + { + "path": "skills/using-neural-architectures/transformer-architecture-deepdive.md", + "sha256": "1913626a82ff02f638d10aa00644fcc59fc9836329dc152a7becd9aa030e0937" + } + ], + "dirSha256": "fb6fcb8194fb2cba8912c22cc1b73554c78da695cf2417a0331443a4d8e719bc" + }, + "security": { + "scannedAt": null, + "scannerVersion": null, + "flags": [] + } +} \ No newline at end of file diff --git a/skills/using-neural-architectures/SKILL.md b/skills/using-neural-architectures/SKILL.md new file mode 100644 index 0000000..189f7ee --- /dev/null +++ b/skills/using-neural-architectures/SKILL.md @@ -0,0 +1,496 @@ +--- +name: using-neural-architectures +description: Architecture selection router: CNNs, Transformers, RNNs, GANs, GNNs by data modality and constraints +mode: true +pack: neural-architectures +faction: yzmir +--- + +# Using Neural Architectures: Architecture Selection Router + + +Architecture selection comes BEFORE training optimization. Wrong architecture = no amount of training will fix it. + +This meta-skill routes you to the right architecture guidance based on: +- Data modality (images, sequences, graphs, etc.) +- Problem type (classification, generation, regression) +- Constraints (data size, compute, latency, interpretability) + +Load this skill when architecture decisions are needed. + + +## When to Use This Skill + +Use this skill when: +- ✅ Selecting an architecture for a new problem +- ✅ Comparing architecture families (CNN vs Transformer, RNN vs Transformer, etc.) +- ✅ Designing custom network topology +- ✅ Troubleshooting architectural instability (deep networks, gradient issues) +- ✅ Understanding when to use specialized architectures (GNNs, generative models) + +DO NOT use for: +- ❌ Training/optimization issues (use training-optimization pack) +- ❌ PyTorch implementation details (use pytorch-engineering pack) +- ❌ Production deployment (use ml-production pack) + +**When in doubt:** If choosing WHAT architecture → this skill. If training/deploying architecture → different pack. + +--- + +## Core Routing Logic + +### Step 1: Identify Data Modality + +**Question to ask:** "What type of data are you working with?" 
+ +| Data Type | Route To | Why | +|-----------|----------|-----| +| Images (photos, medical scans, etc.) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | CNNs excel at spatial hierarchies | +| Sequences (time series, text, audio) | [sequence-models-comparison.md](sequence-models-comparison.md) | Temporal dependencies need sequential models | +| Graphs (social networks, molecules) | [graph-neural-networks-basics.md](graph-neural-networks-basics.md) | Graph structure requires GNNs | +| Generation task (create images, text) | [generative-model-families.md](generative-model-families.md) | Generative models are specialized | +| Multiple modalities (text + images) | [architecture-design-principles.md](architecture-design-principles.md) | Need custom design | +| Unclear / Generic | [architecture-design-principles.md](architecture-design-principles.md) | Start with fundamentals | + +### Step 2: Check for Special Requirements + +**If any of these apply, address FIRST:** + +| Requirement | Route To | Priority | +|-------------|----------|----------| +| Deep network (> 20 layers) unstable | [normalization-techniques.md](normalization-techniques.md) | CRITICAL - fix before continuing | +| Need attention mechanisms | [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) | Specialized component | +| Custom architecture design | [architecture-design-principles.md](architecture-design-principles.md) | Foundation before specifics | +| Transformer-specific question | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | Specialized architecture | + +### Step 3: Consider Problem Characteristics + +**Clarify BEFORE routing:** + +Ask: +- "How large is your dataset?" (Small < 10k, Medium 10k-1M, Large > 1M) +- "What are your computational constraints?" (Edge device, cloud, GPU availability) +- "What are your latency requirements?" (Real-time, batch, offline) +- "Do you need interpretability?" (Clinical, research, production) + +These answers determine architecture appropriateness. + +--- + +## Routing by Data Modality + +### Images → CNN Families + +**Symptoms triggering this route:** +- "classify images" +- "object detection" +- "semantic segmentation" +- "medical imaging" +- "computer vision" + +**Route to:** See [cnn-families-and-selection.md](cnn-families-and-selection.md) for CNN architecture selection and comparison. + +**When to route here:** +- ANY vision task (CNNs are default for spatial data) +- Even if considering Transformers, check CNN families first (often better with less data) + +**Clarifying questions:** +- "Dataset size?" (< 10k → Start with proven CNNs, > 100k → Consider ViT) +- "Deployment target?" (Edge → EfficientNet, Cloud → Anything) +- "Task type?" (Classification → ResNet/EfficientNet, Detection → YOLO/Faster-RCNN) + +--- + +### Sequences → Sequence Models Comparison + +**Symptoms triggering this route:** +- "time series" +- "forecasting" +- "natural language" (NLP) +- "sequential data" +- "temporal patterns" +- "RNN vs LSTM vs Transformer" + +**Route to:** See [sequence-models-comparison.md](sequence-models-comparison.md) for sequential model selection (RNN, LSTM, Transformer, TCN). + +**When to route here:** +- ANY sequential data +- When user asks "RNN vs LSTM" (skill will present modern alternatives) +- Time-dependent patterns + +**Clarifying questions:** +- "Sequence length?" (< 100 → RNN/LSTM/TCN, 100-1000 → Transformer, > 1000 → Sparse Transformers) +- "Latency requirements?" 
(Real-time → TCN/LSTM, Offline → Transformer) +- "Data volume?" (Small → Simpler models, Large → Transformers) + +**CRITICAL:** Challenge "RNN vs LSTM" premise if they ask. Modern alternatives (Transformers, TCN) often better. + +--- + +### Graphs → Graph Neural Networks + +**Symptoms triggering this route:** +- "social network" +- "molecular structure" +- "knowledge graph" +- "graph data" +- "node classification" +- "link prediction" +- "graph embeddings" + +**Route to:** See [graph-neural-networks-basics.md](graph-neural-networks-basics.md) for GNN architectures and graph learning. + +**When to route here:** +- Data has explicit graph structure (nodes + edges) +- Relational information is important +- Network topology matters + +**Red flag:** If treating graph as tabular data (extracting features and ignoring edges) → WRONG. Route to GNN skill. + +--- + +### Generation → Generative Model Families + +**Symptoms triggering this route:** +- "generate images" +- "synthesize data" +- "GAN vs VAE vs Diffusion" +- "image-to-image translation" +- "style transfer" +- "generative modeling" + +**Route to:** See [generative-model-families.md](generative-model-families.md) for GANs, VAEs, and Diffusion models. + +**When to route here:** +- Goal is to CREATE data, not classify/predict +- Need to sample from distribution +- Data augmentation through generation + +**Clarifying questions:** +- "Use case?" (Real-time game → GAN, Art/research → Diffusion, Fast training → VAE) +- "Quality vs speed?" (Quality → Diffusion, Speed → GAN) +- "Controllability?" (Fine control → StyleGAN/Conditional models) + +**CRITICAL:** Different generative models have VERY different trade-offs. Must clarify requirements. + +--- + +## Routing by Architecture Component + +### Attention Mechanisms + +**Symptoms triggering this route:** +- "when to use attention" +- "self-attention vs cross-attention" +- "attention in CNNs" +- "attention bottleneck" +- "multi-head attention" + +**Route to:** See [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) for attention mechanism selection and design. + +**When to route here:** +- Designing custom architecture that might benefit from attention +- Understanding where attention helps vs hinders +- Comparing attention variants + +**NOT for:** General Transformer questions → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) instead + +--- + +### Transformer Deep Dive + +**Symptoms triggering this route:** +- "how do transformers work" +- "Vision Transformer (ViT)" +- "BERT architecture" +- "positional encoding" +- "transformer blocks" +- "scaling transformers" + +**Route to:** See [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) for Transformer internals and implementation. 
+ +**When to route here:** +- Implementing/customizing transformers +- Understanding transformer internals +- Debugging transformer-specific issues + +**Cross-reference:** +- For sequence models generally → [sequence-models-comparison.md](sequence-models-comparison.md) (includes transformers in context) +- For LLMs specifically → `yzmir/llm-specialist/transformer-for-llms` (LLM-specific transformers) + +--- + +### Normalization Techniques + +**Symptoms triggering this route:** +- "gradient explosion" +- "training instability in deep network" +- "BatchNorm vs LayerNorm" +- "normalization layers" +- "50+ layer network won't train" + +**Route to:** See [normalization-techniques.md](normalization-techniques.md) for deep network stability and normalization methods. + +**When to route here:** +- Deep networks (> 20 layers) with training instability +- Choosing between normalization methods +- Architectural stability issues + +**CRITICAL:** This is often the ROOT CAUSE of "training won't work" - fix architecture before blaming hyperparameters. + +--- + +### Architecture Design Principles + +**Symptoms triggering this route:** +- "how to design architecture" +- "architecture best practices" +- "when to use skip connections" +- "how deep should network be" +- "custom architecture for [novel task]" +- Unclear problem modality + +**Route to:** See [architecture-design-principles.md](architecture-design-principles.md) for custom architecture design fundamentals. + +**When to route here:** +- Designing custom architectures +- Novel problems without established architecture +- Understanding WHY architectures work +- User is unsure what modality/problem type they have + +**This is the foundational skill** - route here if other specific skills don't match. + +--- + +## Multi-Modal / Cross-Pack Routing + +### When Problem Spans Multiple Modalities + +**Example:** "Text + image classification" (multimodal) + +**Route to BOTH:** +1. [sequence-models-comparison.md](sequence-models-comparison.md) (for text) +2. [cnn-families-and-selection.md](cnn-families-and-selection.md) (for images) +3. [architecture-design-principles.md](architecture-design-principles.md) (for fusion strategy) + +**Order matters:** Understand individual modalities BEFORE fusion. + +### When Architecture + Other Concerns + +**Example:** "Select architecture AND optimize training" + +**Route order:** +1. Architecture skill FIRST (this pack) +2. Training-optimization SECOND (after architecture chosen) + +**Why:** Wrong architecture can't be fixed by better training. + +**Example:** "Select architecture AND deploy efficiently" + +**Route order:** +1. Architecture skill FIRST +2. ML-production SECOND (quantization, serving) + +**Deployment constraints might influence architecture choice** - if so, note constraints during architecture selection. 
+ +--- + +## Common Routing Mistakes (DON'T DO THESE) + +| Symptom | Wrong Route | Correct Route | Why | +|---------|-------------|---------------|-----| +| "My transformer won't train" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | training-optimization | Training issue, not architecture understanding | +| "Deploy image classifier" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | ml-production | Deployment, not selection | +| "ViT vs ResNet for medical imaging" | [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) | [cnn-families-and-selection.md](cnn-families-and-selection.md) | Comparative selection, not single architecture detail | +| "Implement BatchNorm in PyTorch" | [normalization-techniques.md](normalization-techniques.md) | pytorch-engineering | Implementation, not architecture concept | +| "GAN won't converge" | [generative-model-families.md](generative-model-families.md) | training-optimization | Training stability, not architecture selection | +| "Which optimizer for CNN" | [cnn-families-and-selection.md](cnn-families-and-selection.md) | training-optimization | Optimization, not architecture | + +**Rule:** Architecture pack is for CHOOSING and DESIGNING architectures. Training/deployment/implementation are other packs. + +--- + +## Red Flags: Stop and Clarify + +If query contains these patterns, ASK clarifying questions before routing: + +| Pattern | Why Clarify | What to Ask | +|---------|-------------|--------------| +| "Best architecture for X" | "Best" depends on constraints | "What are your data size, compute, and latency constraints?" | +| Generic problem description | Can't route without modality | "What type of data? (images, sequences, graphs, etc.)" | +| Latest trend mentioned (ViT, Diffusion) | Recency bias risk | "Have you considered alternatives? What are your specific requirements?" | +| "Should I use X or Y" | May be wrong question | "What's the underlying problem? There might be option Z." | +| Very deep network (> 50 layers) | Likely needs normalization first | "Are you using normalization layers? Skip connections?" | + +**Never guess modality or constraints. Always clarify.** + +--- + +## Recency Bias: Resistance Table + +| Trendy Architecture | When NOT to Use | Better Alternative | +|---------------------|------------------|-------------------| +| **Vision Transformers (ViT)** | Small datasets (< 10k images) | CNNs (ResNet, EfficientNet) | +| **Vision Transformers (ViT)** | Edge deployment (latency/power) | EfficientNets, MobileNets | +| **Transformers (general)** | Very small datasets | RNNs, CNNs (less capacity, less overfit) | +| **Diffusion Models** | Real-time generation needed | GANs (1 forward pass vs 50-1000 steps) | +| **Diffusion Models** | Limited compute for training | VAEs (faster training) | +| **Graph Transformers** | Small graphs (< 100 nodes) | Standard GNNs (GCN, GAT) simpler and effective | +| **LLMs (GPT-style)** | < 1M tokens of training data | Simpler language models or fine-tuning | + +**Counter-narrative:** "New architecture ≠ better for your use case. Match architecture to constraints." + +--- + +## Decision Tree + +``` +Start here: What's your primary goal? + +┌─ SELECT architecture for task +│ ├─ Data modality? 
+│ │ ├─ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md) +│ │ ├─ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md) +│ │ ├─ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md) +│ │ ├─ Generation → [generative-model-families.md](generative-model-families.md) +│ │ └─ Unknown/Multiple → [architecture-design-principles.md](architecture-design-principles.md) +│ └─ Special requirements? +│ ├─ Deep network (>20 layers) unstable → [normalization-techniques.md](normalization-techniques.md) (CRITICAL) +│ ├─ Need attention mechanism → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) +│ └─ None → Proceed with modality-based route +│ +├─ UNDERSTAND specific architecture +│ ├─ Transformers → [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) +│ ├─ Attention → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) +│ ├─ Normalization → [normalization-techniques.md](normalization-techniques.md) +│ └─ General principles → [architecture-design-principles.md](architecture-design-principles.md) +│ +├─ DESIGN custom architecture +│ └─ [architecture-design-principles.md](architecture-design-principles.md) (start here always) +│ +└─ COMPARE architectures + ├─ CNNs (ResNet vs EfficientNet) → [cnn-families-and-selection.md](cnn-families-and-selection.md) + ├─ Sequence models (RNN vs Transformer) → [sequence-models-comparison.md](sequence-models-comparison.md) + ├─ Generative (GAN vs Diffusion) → [generative-model-families.md](generative-model-families.md) + └─ General comparison → [architecture-design-principles.md](architecture-design-principles.md) +``` + +--- + +## Workflow + +**Standard Architecture Selection Workflow:** + +``` +1. Clarify Problem + ☐ What data modality? (images, sequences, graphs, etc.) + ☐ What's the task? (classification, generation, regression, etc.) + ☐ Dataset size? + ☐ Computational constraints? + ☐ Latency requirements? + ☐ Interpretability needs? + +2. Route Based on Modality + ☐ Images → [cnn-families-and-selection.md](cnn-families-and-selection.md) + ☐ Sequences → [sequence-models-comparison.md](sequence-models-comparison.md) + ☐ Graphs → [graph-neural-networks-basics.md](graph-neural-networks-basics.md) + ☐ Generation → [generative-model-families.md](generative-model-families.md) + ☐ Custom/Unclear → [architecture-design-principles.md](architecture-design-principles.md) + +3. Check for Critical Issues + ☐ Deep network unstable? → [normalization-techniques.md](normalization-techniques.md) FIRST + ☐ Need specialized component? → [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) or [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) + +4. Apply Architecture Skill + ☐ Follow guidance from routed skill + ☐ Consider trade-offs (accuracy vs speed vs data requirements) + +5. Cross-Pack if Needed + ☐ Architecture chosen → training-optimization (for training) + ☐ Architecture chosen → ml-production (for deployment) +``` + +--- + +## Rationalization Table + +| Rationalization | Reality | Counter | +|-----------------|---------|---------| +| "Transformers are SOTA, recommend them" | SOTA on benchmark ≠ best for user's constraints | "Ask about dataset size and compute first" | +| "User said RNN vs LSTM, answer that" | Question premise might be outdated | "Challenge: Have you considered Transformers or TCN?" 
| +| "Just recommend latest architecture" | Latest ≠ appropriate | "Match architecture to requirements, not trends" | +| "Architecture doesn't matter, training matters" | Wrong architecture can't be fixed by training | "Architecture is foundation - get it right first" | +| "They seem rushed, skip clarification" | Wrong route wastes more time than clarification | "30 seconds to clarify saves hours of wasted effort" | +| "Generic architecture advice is safe" | Generic = useless for specific domains | "Route to domain-specific skill for actionable guidance" | + +--- + +## Integration with Other Packs + +### After Architecture Selection + +Once architecture is chosen, route to: + +**Training the architecture:** +→ `yzmir/training-optimization/using-training-optimization` +- Optimizer selection +- Learning rate schedules +- Debugging training issues + +**Implementing in PyTorch:** +→ `yzmir/pytorch-engineering/using-pytorch-engineering` +- Module design patterns +- Performance optimization +- Custom components + +**Deploying to production:** +→ `yzmir/ml-production/using-ml-production` +- Model serving +- Quantization +- Inference optimization + +### Before Architecture Selection + +If problem involves: + +**Reinforcement learning:** +→ `yzmir/deep-rl/using-deep-rl` FIRST +- RL algorithms dictate architecture requirements +- Value networks vs policy networks have different needs + +**Large language models:** +→ `yzmir/llm-specialist/using-llm-specialist` FIRST +- LLM architectures are specialized transformers +- Different considerations than general sequence models + +**Architecture is downstream of algorithm choice in RL and LLMs.** + +--- + +## Summary + +**Use this meta-skill to:** +- ✅ Route architecture queries to appropriate specialized skill +- ✅ Identify data modality and problem type +- ✅ Clarify constraints before recommending +- ✅ Resist recency bias (latest ≠ best) +- ✅ Recognize when architecture is the problem (vs training/implementation) + +## Neural Architecture Specialist Skills + +After routing, load the appropriate specialist skill for detailed guidance: + +1. [architecture-design-principles.md](architecture-design-principles.md) - Custom design, architectural best practices, skip connections, network depth fundamentals +2. [attention-mechanisms-catalog.md](attention-mechanisms-catalog.md) - Self-attention, cross-attention, multi-head attention, attention in CNNs, attention variants comparison +3. [cnn-families-and-selection.md](cnn-families-and-selection.md) - ResNet, EfficientNet, MobileNet, YOLO, computer vision architecture selection +4. [generative-model-families.md](generative-model-families.md) - GANs, VAEs, Diffusion models, image generation, style transfer, generative modeling trade-offs +5. [graph-neural-networks-basics.md](graph-neural-networks-basics.md) - GCN, GAT, node classification, link prediction, graph embeddings, molecular structures +6. [normalization-techniques.md](normalization-techniques.md) - BatchNorm, LayerNorm, GroupNorm, training stability for deep networks (>20 layers) +7. [sequence-models-comparison.md](sequence-models-comparison.md) - RNN, LSTM, Transformer, TCN comparison, time series, NLP, sequential data +8. [transformer-architecture-deepdive.md](transformer-architecture-deepdive.md) - Transformer internals, ViT, BERT, positional encoding, scaling transformers + +**Critical principle:** Architecture comes BEFORE training. Get this right first. 
+ +--- + +**END OF SKILL** diff --git a/skills/using-neural-architectures/architecture-design-principles.md b/skills/using-neural-architectures/architecture-design-principles.md new file mode 100644 index 0000000..0510e3a --- /dev/null +++ b/skills/using-neural-architectures/architecture-design-principles.md @@ -0,0 +1,960 @@ + +# Architecture Design Principles + +## Context + +You're designing a neural network architecture or debugging why your network isn't learning. Common mistakes: +- **Ignoring inductive biases**: Using MLP for images (should use CNN) +- **Over-engineering**: Using Transformer for 100 samples (should use linear regression) +- **No skip connections**: 50-layer plain network fails (should use ResNet) +- **Wrong depth-width balance**: 100 layers × 8 channels bottlenecks capacity +- **Ignoring constraints**: 1.5B parameter model doesn't fit 24GB GPU + +**This skill provides principled architecture design: match structure to problem, respect constraints, avoid over-engineering.** + + +## Core Principle: Inductive Biases + +**Inductive bias = assumptions baked into architecture about problem structure** + +**Key insight**: The right inductive bias makes learning dramatically easier. Wrong bias makes learning impossible. + +### What are Inductive Biases? + +```python +# Example: Image classification + +# MLP (no inductive bias): +# - Treats each pixel independently +# - No concept of "spatial locality" or "translation" +# - Must learn from scratch that nearby pixels are related +# - Learns "cat at position (10,10)" and "cat at (50,50)" separately +# Parameters: 150M, Accuracy: 75% + +# CNN (strong inductive bias): +# - Assumes spatial locality (nearby pixels related) +# - Assumes translation invariance (cat is cat anywhere) +# - Shares filters across spatial positions +# - Hierarchical feature learning (edges → textures → objects) +# Parameters: 11M, Accuracy: 95% + +# CNN's inductive bias: 14× fewer parameters, 20% better accuracy! +``` + +**Principle**: Match your architecture's inductive biases to your problem's structure. + + +## Architecture Families and Their Inductive Biases + +### 1. Fully Connected (MLP) + +**Inductive bias:** None (general-purpose) + +**Structure:** +```python +class MLP(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, num_classes) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x +``` + +**When to use:** +- ✅ Tabular data (independent features) +- ✅ Small datasets (< 10,000 samples) +- ✅ Baseline / proof of concept + +**When NOT to use:** +- ❌ Images (use CNN) +- ❌ Sequences (use RNN/Transformer) +- ❌ Graphs (use GNN) + +**Strengths:** +- Simple and interpretable +- Fast training +- Works for any input type (flattened) + +**Weaknesses:** +- No structural assumptions (must learn everything from data) +- Parameter explosion (input_size × hidden_size can be huge) +- Doesn't leverage problem structure + + +### 2. 
Convolutional Neural Networks (CNN) + +**Inductive bias:** Spatial locality + Translation invariance + +**Structure:** +```python +class SimpleCNN(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + self.pool = nn.MaxPool2d(2, 2) + self.fc = nn.Linear(128 * 7 * 7, 1000) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) # 112×112 + x = self.pool(F.relu(self.conv2(x))) # 56×56 + x = x.view(x.size(0), -1) + x = self.fc(x) + return x +``` + +**Inductive biases:** +1. **Local connectivity**: Neurons see only nearby pixels (spatial locality) +2. **Translation invariance**: Same filter slides across image (parameter sharing) +3. **Hierarchical features**: Stack layers to build complex features from simple ones + +**When to use:** +- ✅ Images (classification, detection, segmentation) +- ✅ Spatial data (maps, medical scans) +- ✅ Any grid-structured data + +**When NOT to use:** +- ❌ Sequences with long-range dependencies (use Transformer) +- ❌ Graphs (irregular structure, use GNN) +- ❌ Tabular data (no spatial structure) + +**Strengths:** +- Parameter efficient (filter sharing) +- Translation invariant (cat anywhere = cat) +- Hierarchical feature learning + +**Weaknesses:** +- Fixed receptive field (limited by kernel size) +- Not suitable for variable-length inputs +- Requires grid structure + + +### 3. Recurrent Neural Networks (RNN/LSTM) + +**Inductive bias:** Temporal dependencies + +**Structure:** +```python +class LSTMModel(nn.Module): + def __init__(self, input_size, hidden_size, num_layers): + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + self.fc = nn.Linear(hidden_size, num_classes) + + def forward(self, x): + # x: (batch, seq_len, input_size) + lstm_out, (h_n, c_n) = self.lstm(x) + # Use last hidden state + output = self.fc(h_n[-1]) + return output +``` + +**Inductive bias:** Sequential processing (earlier elements influence later elements) + +**When to use:** +- ✅ Time series (stock prices, sensor data) +- ✅ Short sequences (< 100 timesteps) +- ✅ Online processing (process one timestep at a time) + +**When NOT to use:** +- ❌ Long sequences (> 1000 timesteps, use Transformer) +- ❌ Non-sequential data (images, tabular) +- ❌ When parallel processing needed (use Transformer) + +**Strengths:** +- Natural for sequential data +- Constant memory (doesn't grow with sequence length) +- Online processing capability + +**Weaknesses:** +- Slow (sequential, can't parallelize) +- Vanishing gradients (long sequences) +- Struggles with long-range dependencies + + +### 4. 
Transformers + +**Inductive bias:** Minimal (self-attention is general-purpose) + +**Structure:** +```python +class SimpleTransformer(nn.Module): + def __init__(self, d_model, num_heads, num_layers): + super().__init__() + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model, num_heads), + num_layers + ) + self.fc = nn.Linear(d_model, num_classes) + + def forward(self, x): + # x: (batch, seq_len, d_model) + x = self.encoder(x) + # Global average pooling + x = x.mean(dim=1) + return self.fc(x) +``` + +**Inductive bias:** Minimal (learns relationships from data via attention) + +**When to use:** +- ✅ Long sequences (> 100 tokens) +- ✅ Language (text, code) +- ✅ Large datasets (> 100k samples) +- ✅ When relationships are complex and data-dependent + +**When NOT to use:** +- ❌ Small datasets (< 10k samples, use RNN or MLP) +- ❌ Strong structural priors available (images → CNN) +- ❌ Very long sequences (> 16k tokens, use sparse attention) +- ❌ Low-latency requirements (RNN faster) + +**Strengths:** +- Parallel processing (fast training) +- Long-range dependencies (attention) +- State-of-the-art for language + +**Weaknesses:** +- Quadratic complexity O(n²) with sequence length +- Requires large datasets (weak inductive bias) +- High memory usage + + +### 5. Graph Neural Networks (GNN) + +**Inductive bias:** Message passing over graph structure + +**Structure:** +```python +class SimpleGNN(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super().__init__() + self.conv1 = GCNConv(input_dim, hidden_dim) + self.conv2 = GCNConv(hidden_dim, output_dim) + + def forward(self, x, edge_index): + # x: node features (num_nodes, input_dim) + # edge_index: graph structure (2, num_edges) + x = F.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + return x +``` + +**Inductive bias:** Nodes influenced by neighbors (message passing) + +**When to use:** +- ✅ Graph data (social networks, molecules, knowledge graphs) +- ✅ Irregular connectivity (different # of neighbors per node) +- ✅ Relational reasoning + +**When NOT to use:** +- ❌ Grid data (images → CNN) +- ❌ Sequences (text → Transformer) +- ❌ If graph structure doesn't help (test MLP baseline first!) + +**Strengths:** +- Handles irregular structure +- Permutation invariant +- Natural for relational data + +**Weaknesses:** +- Requires meaningful graph structure +- Over-smoothing (too many layers) +- Scalability challenges (large graphs) + + +## Decision Tree: Architecture Selection + +``` +START +| +├─ Is data grid-structured (images)? +│ ├─ YES → Use CNN +│ │ └─ ResNet (general), EfficientNet (mobile), ViT (very large datasets) +│ └─ NO → Continue +│ +├─ Is data sequential (text, time series)? +│ ├─ YES → Check sequence length +│ │ ├─ < 100 timesteps → LSTM/GRU +│ │ ├─ 100-4000 tokens → Transformer +│ │ └─ > 4000 tokens → Sparse Transformer (Longformer) +│ └─ NO → Continue +│ +├─ Is data graph-structured (molecules, social networks)? +│ ├─ YES → Check if structure helps +│ │ ├─ Test MLP baseline first +│ │ └─ If structure helps → GNN (GCN, GraphSAGE, GAT) +│ └─ NO → Continue +│ +└─ Is data tabular (independent features)? + └─ YES → Start simple + ├─ < 1000 samples → Linear / Ridge regression + ├─ 1000-100k samples → Small MLP (2-3 layers) + └─ > 100k samples → Larger MLP or Gradient Boosting (XGBoost) +``` + + +## Principle: Start Simple, Add Complexity Only When Needed + +**Occam's Razor**: Simplest model that solves the problem is best. 
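The progression below walks through this escalation step by step. To decide when to stop adding complexity, a quick parameter-count check against the ~0.01-0.1× parameters-per-sample rule of thumb (discussed later in this skill) usually suffices. A minimal sketch, with an illustrative helper name:

```python
import torch.nn as nn

def capacity_report(model: nn.Module, num_samples: int) -> str:
    """Compare trainable parameter count to dataset size (rule of thumb: ~0.01-0.1 params/sample)."""
    num_params = sum(p.numel() for p in model.parameters())
    return f"{num_params:,} params / {num_samples:,} samples = {num_params / num_samples:.2f} params per sample"

print(capacity_report(nn.Linear(20, 1), num_samples=1_000))
# 21 params / 1,000 samples = 0.02 params per sample -> within budget, stop here

print(capacity_report(nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 1)), num_samples=1_000))
# 5,633 params / 1,000 samples = 5.63 params per sample -> far over budget, expect overfitting
```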
+ +### Progression: + +```python +# Step 1: Linear baseline (ALWAYS start here!) +model = nn.Linear(input_size, num_classes) +# Train and evaluate + +# Step 2: IF linear insufficient, add small MLP +if linear_accuracy < target: + model = nn.Sequential( + nn.Linear(input_size, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + +# Step 3: IF small MLP insufficient, add depth/width +if mlp_accuracy < target: + model = nn.Sequential( + nn.Linear(input_size, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, num_classes) + ) + +# Step 4: IF simple models fail, use specialized architecture +if simple_models_fail: + # Images → CNN + # Sequences → RNN/Transformer + # Graphs → GNN + +# NEVER skip to Step 4 without testing Step 1-3! +``` + +### Why Start Simple? + +1. **Faster iteration**: Linear model trains in seconds, Transformer in hours +2. **Baseline**: Know if complexity helps (compare complex vs simple) +3. **Occam's Razor**: Simple model generalizes better (less overfitting) +4. **Debugging**: Easy to verify simple model works correctly + +### Example: House Price Prediction + +```python +# Dataset: 1000 samples, 20 features + +# WRONG: Start with Transformer +model = HugeTransformer(20, 512, 6, 1) # 10M parameters +# Result: Overfits (10M params / 1000 samples = 10,000:1 ratio!) + +# RIGHT: Start simple +# Step 1: Linear +model = nn.Linear(20, 1) # 21 parameters +# Trains in 1 second, achieves R² = 0.85 (good!) + +# Conclusion: Linear sufficient, stop here. No need for Transformer! +``` + +**Rule**: Add complexity only when simple models demonstrably fail. + + +## Principle: Deep Networks Need Skip Connections + +**Problem**: Plain networks > 10 layers suffer from vanishing gradients and degradation. + +### Vanishing Gradients: + +```python +# Gradient flow in plain 50-layer network: +gradient_layer_1 = gradient_output × (∂L50/∂L49) × (∂L49/∂L48) × ... × (∂L2/∂L1) + +# Each term < 1 (due to activations): +# If each ≈ 0.9, then: 0.9^50 = 0.0000051 (vanishes!) + +# Result: Early layers don't learn (gradients too small) +``` + +### Degradation: + +```python +# Empirical observation (ResNet paper): +20-layer plain network: 85% accuracy +56-layer plain network: 78% accuracy # WORSE with more layers! + +# This is NOT overfitting (training accuracy also drops) +# This is optimization difficulty +``` + +### Solution: Skip Connections (Residual Networks) + +```python +class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) + self.bn1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) + self.bn2 = nn.BatchNorm2d(channels) + + def forward(self, x): + identity = x # Save input + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out = out + identity # Skip connection! + out = F.relu(out) + return out +``` + +**Why skip connections work:** +```python +# Gradient flow with skip connections: +∂loss/∂x = ∂loss/∂out × (1 + ∂F/∂x) +# ↑ +# Always flows! ("+1" term) + +# Even if ∂F/∂x ≈ 0, gradient flows through identity path +``` + +**Results:** +```python +# Without skip connections: +20-layer plain: 85% accuracy +50-layer plain: 78% accuracy (worse!) + +# With skip connections (ResNet): +20-layer ResNet: 87% accuracy +50-layer ResNet: 92% accuracy (better!) +152-layer ResNet: 95% accuracy (even better!) +``` + +**Rule**: For networks > 10 layers, ALWAYS use skip connections. 
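To verify this on your own stack, a minimal sketch comparing first-layer gradient norms with and without skip connections (assumes tanh activations and PyTorch's default initialization; exact values vary by seed and width):

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(use_skip: bool, depth: int = 50, width: int = 64) -> float:
    torch.manual_seed(0)
    layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    h = torch.randn(16, width)
    for layer in layers:
        out = torch.tanh(layer(h))
        h = h + out if use_skip else out  # residual vs. plain stack
    h.sum().backward()
    return layers[0].weight.grad.norm().item()

print("plain:   ", first_layer_grad_norm(use_skip=False))  # typically vanishingly small
print("residual:", first_layer_grad_norm(use_skip=True))   # typically orders of magnitude larger
```

The plain stack's first-layer gradient collapses toward zero while the residual stack keeps it at a usable magnitude, which is exactly the failure mode the rule above guards against.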
+ +### Skip Connection Variants: + +**1. Residual (ResNet):** +```python +out = x + F(x) # Add input to output +``` + +**2. Dense (DenseNet):** +```python +out = torch.cat([x, F(x)], dim=1) # Concatenate input and output +``` + +**3. Highway:** +```python +gate = sigmoid(W_gate @ x) +out = gate * F(x) + (1 - gate) * x # Learned gating +``` + +**Most common**: Residual (simple, effective) + + +## Principle: Balance Depth and Width + +**Depth = # of layers** +**Width = # of channels/neurons per layer** + +### Capacity Formula: + +```python +# Approximate capacity (for CNNs): +capacity ≈ depth × width² + +# Why width²? +# Each layer: input_channels × output_channels × kernel_size² +# Doubling width → 4× parameters per layer +``` + +### Trade-offs: + +**Too deep, too narrow:** +```python +# 100 layers × 8 channels +# Problems: +# - Information bottleneck (8 channels can't represent complex features) +# - Harder to optimize (more layers) +# - Slow inference (100 layers sequential) + +# Example: +model = VeryDeepNarrow(num_layers=100, channels=8) +# Result: 60% accuracy (bottleneck!) +``` + +**Too shallow, too wide:** +```python +# 2 layers × 1024 channels +# Problems: +# - Under-utilizes depth (no hierarchical features) +# - Memory explosion (1024 × 1024 = 1M parameters per layer!) + +# Example: +model = VeryWideShallow(num_layers=2, channels=1024) +# Result: 70% accuracy (doesn't leverage depth) +``` + +**Balanced:** +```python +# 18 layers, gradually increasing width: 64 → 128 → 256 → 512 +# Benefits: +# - Hierarchical features (depth) +# - Sufficient capacity (width) +# - Good optimization (not too deep) + +# Example (ResNet-18): +model = ResNet18() +# Layers: 18, Channels: 64-512 (average ~200) +# Result: 95% accuracy (optimal balance!) +``` + +### Standard Patterns: + +```python +# CNNs: Gradually increase channels as spatial dims decrease +# Input: 224×224×3 +# Layer 1: 224×224×64 (same spatial size, more channels) +# Layer 2: 112×112×128 (half spatial, double channels) +# Layer 3: 56×56×256 (half spatial, double channels) +# Layer 4: 28×28×512 (half spatial, double channels) + +# Why? Compensate for spatial information loss with channel information +``` + +**Rule**: Balance depth and width. Standard pattern: 12-50 layers, 64-512 channels. + + +## Principle: Match Capacity to Data Size + +**Capacity = # of learnable parameters** + +### Parameter Budget: + +```python +# Rule of thumb: parameters should be 0.01-0.1× dataset size + +# Example 1: MNIST (60,000 images) +# Budget: 600 - 6,000 parameters +# Simple CNN: 60,000 parameters (10×) → Works, but might overfit +# LeNet: 60,000 parameters → Classic, works well + +# Example 2: ImageNet (1.2M images) +# Budget: 12,000 - 120,000 parameters +# ResNet-50: 25M parameters (200×) → Works (aggressive augmentation helps) + +# Example 3: Tabular (100 samples, 20 features) +# Budget: 1 - 10 parameters +# Linear: 21 parameters → Perfect fit! +# MLP: 1,000 parameters → Overfits horribly +``` + +### Overfitting Detection: + +```python +# Training accuracy >> Validation accuracy (gap > 5%) +train_acc = 99%, val_acc = 70% # 29% gap → OVERFITTING! + +# Solutions: +# 1. Reduce model capacity (fewer layers/channels) +# 2. Add regularization (dropout, weight decay) +# 3. Collect more data +# 4. Data augmentation + +# Order: Try (1) first (simplest), then (2), then (3)/(4) +``` + +### Underfitting Detection: + +```python +# Training accuracy < target (model too simple) +train_acc = 60%, val_acc = 58% # Both low → UNDERFITTING! + +# Solutions: +# 1. 
Increase model capacity (more layers/channels) +# 2. Train longer +# 3. Reduce regularization + +# Order: Try (2) first (cheapest), then (1), then (3) +``` + +**Rule**: Match parameters to data size. Start small, increase capacity only if underfitting. + + +## Principle: Design for Compute Constraints + +**Constraints:** +1. **Memory**: Model + gradients + optimizer states < GPU VRAM +2. **Latency**: Inference time < requirement (e.g., < 100ms for real-time) +3. **Throughput**: Samples/second > requirement + +### Memory Budget: + +```python +# Memory calculation (training): +# 1. Model parameters (FP32): params × 4 bytes +# 2. Gradients: params × 4 bytes +# 3. Optimizer states (Adam): params × 8 bytes (2× weights) +# 4. Activations: batch_size × feature_maps × spatial_size × 4 bytes + +# Example: ResNet-50 +params = 25M +memory_params = 25M × 4 = 100 MB +memory_gradients = 100 MB +memory_optimizer = 200 MB +memory_activations = batch_size × 64 × 7×7 × 4 ≈ batch_size × 12 KB + +# Total (batch=32): 100 + 100 + 200 + 0.4 = 400 MB +# Fits easily on 4GB GPU! + +# Example: GPT-3 (175B parameters) +memory_params = 175B × 4 = 700 GB +memory_total = 700 + 700 + 1400 = 2800 GB = 2.8 TB! +# Requires 35×A100 (80GB each) +``` + +**Rule**: Calculate memory before training. Don't design models that don't fit. + +### Latency Budget: + +```python +# Inference latency = # operations / throughput + +# Example: Mobile app (< 100ms latency requirement) + +# ResNet-50: +# Operations: 4B FLOPs +# Mobile CPU: 10 GFLOPS +# Latency: 4B / 10G = 0.4 seconds (FAILS!) + +# MobileNetV2: +# Operations: 300M FLOPs +# Mobile CPU: 10 GFLOPS +# Latency: 300M / 10G = 0.03 seconds = 30ms (PASSES!) + +# Solution: Use efficient architectures (MobileNet, EfficientNet) for mobile +``` + +**Rule**: Measure latency. Use efficient architectures if latency-constrained. + + +## Common Architectural Patterns + +### 1. Bottleneck (ResNet) + +**Structure:** +```python +# Standard: 3×3 conv (256 channels) → 3×3 conv (256 channels) +# Parameters: 256 × 256 × 3 × 3 = 590K + +# Bottleneck: 1×1 (256→64) → 3×3 (64→64) → 1×1 (64→256) +# Parameters: 256×64 + 64×64×3×3 + 64×256 = 16K + 37K + 16K = 69K +# Reduction: 590K → 69K (8.5× fewer!) +``` + +**Purpose**: Reduce parameters while maintaining capacity + +**When to use**: Deep networks (> 50 layers) where parameters are a concern + +### 2. Inverted Bottleneck (MobileNetV2) + +**Structure:** +```python +# Bottleneck (ResNet): Wide → Narrow → Wide (256 → 64 → 256) +# Inverted: Narrow → Wide → Narrow (64 → 256 → 64) + +# Why? Efficient for mobile (depthwise separable convolutions) +``` + +**Purpose**: Maximize efficiency (FLOPs per parameter) + +**When to use**: Mobile/edge deployment + +### 3. Multi-scale Features (Inception) + +**Structure:** +```python +# Parallel branches with different kernel sizes: +# Branch 1: 1×1 conv +# Branch 2: 3×3 conv +# Branch 3: 5×5 conv +# Branch 4: 3×3 max pool +# Concatenate all branches + +# Captures features at multiple scales simultaneously +``` + +**Purpose**: Capture multi-scale patterns + +**When to use**: When features exist at multiple scales (object detection) + +### 4. Attention (Transformers, SE-Net) + +**Structure:** +```python +# Squeeze-and-Excitation (SE) block: +# 1. Global average pooling (spatial → channel descriptor) +# 2. FC layer (bottleneck) +# 3. FC layer (restore channels) +# 4. Sigmoid (attention weights) +# 5. 
Multiply input channels by attention weights + +# Result: Emphasize important channels, suppress irrelevant +``` + +**Purpose**: Learn importance of features (channels or positions) + +**When to use**: When not all features equally important + + +## Debugging Architectures + +### Problem 1: Network doesn't learn (loss stays constant) + +**Diagnosis:** +```python +# Check gradient flow +for name, param in model.named_parameters(): + if param.grad is not None: + print(f"{name}: grad_mean={param.grad.mean():.6f}, grad_std={param.grad.std():.6f}") + +# Vanishing: grad_mean ≈ 0, grad_std ≈ 0 → Add skip connections +# Exploding: grad_mean > 1, grad_std > 1 → Gradient clipping or lower LR +``` + +**Solutions:** +- Add skip connections (ResNet) +- Check initialization (Xavier or He initialization) +- Lower learning rate +- Check data preprocessing (normalized inputs?) + +### Problem 2: Overfitting (train >> val) + +**Diagnosis:** +```python +train_acc = 99%, val_acc = 70% # 29% gap → Overfitting + +# Check parameter/data ratio: +num_params = sum(p.numel() for p in model.parameters()) +data_size = len(train_dataset) +ratio = num_params / data_size + +# If ratio > 1: Model has more parameters than data points! +``` + +**Solutions (in order):** +1. Reduce capacity (fewer layers/channels) +2. Add dropout / weight decay +3. Data augmentation +4. Collect more data + +### Problem 3: Underfitting (train and val both low) + +**Diagnosis:** +```python +train_acc = 65%, val_acc = 63% # Both low → Underfitting + +# Model too simple for task complexity +``` + +**Solutions (in order):** +1. Train longer (more epochs) +2. Increase capacity (more layers/channels) +3. Reduce regularization (lower dropout/weight decay) +4. Check learning rate (too low?) + +### Problem 4: Slow training + +**Diagnosis:** +```python +# Profile forward/backward pass +import time + +start = time.time() +loss = criterion(model(inputs), targets) +forward_time = time.time() - start + +start = time.time() +loss.backward() +backward_time = time.time() - start + +# If backward_time >> forward_time: Gradient computation bottleneck +``` + +**Solutions:** +- Use mixed precision (FP16) +- Reduce batch size (if memory-bound) +- Use gradient accumulation (simulate large batch) +- Simplify architecture (fewer layers) + + +## Design Checklist + +Before finalizing an architecture: + +### ☐ Match inductive bias to problem +- Images → CNN +- Sequences → RNN/Transformer +- Graphs → GNN +- Tabular → MLP + +### ☐ Start simple, add complexity only when needed +- Test linear baseline first +- Add complexity incrementally +- Compare performance at each step + +### ☐ Use skip connections for deep networks (> 10 layers) +- ResNet for CNNs +- Pre-norm for Transformers +- Gradient flow is critical + +### ☐ Balance depth and width +- Not too deep and narrow (bottleneck) +- Not too shallow and wide (under-utilizes depth) +- Standard: 12-50 layers, 64-512 channels + +### ☐ Match capacity to data size +- Parameters ≈ 0.01-0.1× dataset size +- Monitor train/val gap (overfitting indicator) + +### ☐ Respect compute constraints +- Memory: Model + gradients + optimizer + activations < VRAM +- Latency: Inference time < requirement +- Use efficient architectures if constrained (MobileNet, EfficientNet) + +### ☐ Verify gradient flow +- Check gradients in early layers (should be non-zero) +- Use skip connections if vanishing + +### ☐ Benchmark against baselines +- Compare to simple model (linear, small MLP) +- Ensure complexity adds value (% improvement > 5%) + + +## 
Anti-Patterns + +### Anti-pattern 1: "Architecture X is state-of-the-art, so I'll use it" + +**Wrong:** +```python +# Transformer is SOTA for NLP, so use for tabular data (100 samples) +model = HugeTransformer(...) # 10M parameters +# Result: Overfits horribly (100 samples / 10M params = 0.00001 ratio!) +``` + +**Right:** +```python +# Match architecture to problem AND data size +# Tabular + small data → Linear or small MLP +model = nn.Linear(20, 1) # 21 parameters (appropriate!) +``` + +### Anti-pattern 2: "More layers = better" + +**Wrong:** +```python +# 100-layer plain network (no skip connections) +for i in range(100): + layers.append(nn.Conv2d(64, 64, 3, padding=1)) +# Result: Doesn't train (vanishing gradients) +``` + +**Right:** +```python +# 50-layer ResNet (with skip connections) +# Each block: out = x + F(x) # Skip connection +# Result: Trains well, high accuracy +``` + +### Anti-pattern 3: "Deeper + narrower = efficient" + +**Wrong:** +```python +# 100 layers × 8 channels = information bottleneck +model = VeryDeepNarrow(100, 8) +# Result: 60% accuracy (8 channels insufficient) +``` + +**Right:** +```python +# 18 layers, 64-512 channels (balanced) +model = ResNet18() # Balanced depth and width +# Result: 95% accuracy +``` + +### Anti-pattern 4: "Ignore constraints, optimize later" + +**Wrong:** +```python +# Design 1.5B parameter model for 24GB GPU +model = HugeModel(1.5e9) +# Result: OOM (out of memory), can't train +``` + +**Right:** +```python +# Calculate memory first: +# 1.5B params × 4 bytes = 6GB (weights) +# + 6GB (gradients) + 12GB (Adam) + 8GB (activations) = 32GB +# > 24GB → Doesn't fit! + +# Design for hardware: +model = ReasonableSizeModel(200e6) # 200M parameters (fits!) +``` + +### Anti-pattern 5: "Hyperparameters will fix architectural problems" + +**Wrong:** +```python +# Architecture: MLP for images (wrong inductive bias) +# Response: "Just tune learning rate!" +for lr in [0.1, 0.01, 0.001, 0.0001]: + train(model, lr=lr) +# Result: All fail (architecture is wrong!) +``` + +**Right:** +```python +# Fix architecture first (use CNN for images) +model = ResNet18() # Correct inductive bias +# Then tune hyperparameters +``` + + +## Summary + +**Core principles:** + +1. **Inductive bias**: Match architecture to problem structure (CNN for images, RNN/Transformer for sequences, GNN for graphs) + +2. **Occam's Razor**: Start simple (linear, small MLP), add complexity only when needed + +3. **Skip connections**: Use for networks > 10 layers (ResNet, DenseNet) + +4. **Depth-width balance**: Not too deep+narrow (bottleneck) or too shallow+wide (under-utilizes depth) + +5. **Capacity**: Match parameters to data size (0.01-0.1× dataset size) + +6. **Constraints**: Design for available memory, latency, throughput + +**Decision framework:** +- Images → CNN (ResNet, EfficientNet) +- Short sequences → LSTM +- Long sequences → Transformer +- Graphs → GNN (test if structure helps first!) +- Tabular → Linear or small MLP + +**Key insight**: Architecture design is about matching structural assumptions to problem structure, not about using the "best" or "most complex" model. Simple models often win. + +**When in doubt**: Start with the simplest model that could plausibly work. Add complexity only when you have evidence it helps. 
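As a closing sanity check for the constraint principle above, a rough training-memory estimator (assumes FP32 weights and Adam; activation memory is workload-dependent and excluded, so treat the result as a lower bound):

```python
def training_memory_gb(num_params: float, bytes_per_param: int = 4, optimizer_copies: int = 2) -> float:
    """Lower bound on training memory: weights + gradients + optimizer states (Adam keeps ~2 extra copies)."""
    total_copies = 1 + 1 + optimizer_copies
    return num_params * bytes_per_param * total_copies / 1e9

print(training_memory_gb(25e6))   # ResNet-50 scale: ~0.4 GB before activations
print(training_memory_gb(1.5e9))  # 1.5B params: ~24 GB before activations (already at a 24GB GPU's limit)
```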
diff --git a/skills/using-neural-architectures/attention-mechanisms-catalog.md b/skills/using-neural-architectures/attention-mechanisms-catalog.md new file mode 100644 index 0000000..2a339b5 --- /dev/null +++ b/skills/using-neural-architectures/attention-mechanisms-catalog.md @@ -0,0 +1,824 @@ + +# Attention Mechanisms Catalog + +## When to Use This Skill + +Use this skill when you need to: +- ✅ Select attention mechanism for long sequences (> 2k tokens) +- ✅ Optimize memory usage (GPU OOM errors) +- ✅ Speed up training or inference +- ✅ Understand exact vs approximate attention trade-offs +- ✅ Choose between Flash, sparse, or linear attention +- ✅ Implement cross-attention for multimodal models + +**Do NOT use this skill for:** +- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`) +- ❌ High-level architecture selection (use `using-neural-architectures`) +- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`) + + +## Core Principle + +**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve: +- **O(n²) with less memory**: Flash Attention (exact, 4x less memory) +- **O(n × w)**: Sparse attention (exact, sliding window) +- **O(n)**: Linear attention (approximate, 1-3% accuracy loss) + +**Default recommendation:** Flash Attention (exact + fast + memory-efficient) + + +## Part 1: Complexity Hierarchy + +### Standard Self-Attention (Baseline) + +**Formula:** +```python +Attention(Q, K, V) = softmax(Q K^T / √d_k) V +``` + +**Complexity:** +- Time: O(n² · d) where n = seq_len, d = d_model +- Memory: O(n²) for attention matrix +- Exact: Yes (no approximation) + +**Memory breakdown (4k tokens, d=768):** +``` +Attention scores: 4096² × 4 bytes = 64MB per layer +Multi-head (12 heads): 64MB × 12 = 768MB per layer +16 layers: 768MB × 16 = 12GB just for attention! +Batch size 8: 12GB × 8 = 96GB (impossible on single GPU) +``` + +**When to use:** +- Sequence length < 2k tokens +- Standard use case (most models) +- Pair with Flash Attention optimization + +**Limitations:** +- Memory explosion for long sequences +- Quadratic scaling impractical beyond 4k tokens + + +## Part 2: Flash Attention ⭐ (Modern Default) + +### What is Flash Attention? + +**Breakthrough (2022):** Exact attention with 4x less memory, 2-3x faster + +**Key insight:** +- Standard attention is **memory-bound** (not compute-bound) +- GPUs: Fast compute (TFLOPS), slow memory bandwidth (GB/s) +- Bottleneck: Moving n² attention matrix to/from HBM + +**Solution:** +- Tile attention computation +- Recompute instead of store intermediate values +- Fuse operations (reduce memory transfers) +- Result: Same O(n²) compute, O(n) memory + +### Algorithm + +``` +Standard attention (3 memory operations): +1. Compute scores: S = Q K^T (store n² matrix) +2. Softmax: P = softmax(S) (store n² matrix) +3. Output: O = P V (store n×d matrix) + +Flash Attention (tiled): +1. Divide Q, K, V into blocks +2. For each Q block: + - Load block to SRAM (fast memory) + - For each K, V block: + - Compute attention for this tile + - Update output incrementally + - Never materialize full n² matrix! +3. 
Result: Same output, O(n) memory +``` + +### Performance + +**Benchmarks (A100 GPU, 2k tokens):** + +Standard attention: +- Memory: 4GB for batch_size=8 +- Speed: 150ms/batch +- Max batch size: 16 + +Flash Attention: +- Memory: 1GB for batch_size=8 **(4x reduction)** +- Speed: 75ms/batch **(2x faster)** +- Max batch size: 64 **(4x larger)** + +**Flash Attention 2 (2023 update):** +- Further optimized: 2-3x faster than Flash Attention 1 +- Better parallelism +- Supports more head dimensions + +### When to Use + +✅ **ALWAYS use Flash Attention when:** +- Sequence length < 16k tokens +- Need exact attention (no approximation) +- Available in your framework + +**It's a FREE LUNCH:** +- No accuracy loss (mathematically exact) +- Faster training AND inference +- Less memory usage +- Drop-in replacement + +### Implementation + +**PyTorch 2.0+ (built-in):** +```python +import torch.nn.functional as F + +# Automatic Flash Attention (if available) +output = F.scaled_dot_product_attention( + query, key, value, + attn_mask=None, + dropout_p=0.0, + is_causal=False +) +# PyTorch automatically uses Flash Attention if: +# - CUDA available +# - Sequence length suitable +# - No attention mask (or causal mask) +``` + +**HuggingFace Transformers:** +```python +from transformers import AutoModel + +# Enable Flash Attention 2 +model = AutoModel.from_pretrained( + "bert-base-uncased", + attn_implementation="flash_attention_2", # Requires flash-attn package + torch_dtype=torch.float16 +) +``` + +**Manual installation:** +```bash +pip install flash-attn --no-build-isolation +``` + +### Limitations + +❌ **Flash Attention NOT suitable when:** +- Sequence length > 16k (memory still grows quadratically) +- Custom attention masks (complex patterns not supported) +- Inference on CPU (CUDA-only) + +**For > 16k tokens:** Use sparse or linear attention + + +## Part 3: Sparse Attention (Exact for Long Sequences) + +### Concept + +**Idea:** Each token attends to subset of tokens (not all) +- Sliding window: Local context +- Global tokens: Long-range connections +- Result: O(n × window_size) instead of O(n²) + +**Key property:** Still EXACT attention (not approximate) +- Just more structured attention pattern +- No accuracy loss if pattern matches task + +### Variant 1: Longformer + +**Pattern:** Sliding window + global attention + +``` +Attention pattern (window=2, global=[0]): + 0 1 2 3 4 5 +0 [ 1 1 1 1 1 1 ] ← Global token (attends to all) +1 [ 1 1 1 0 0 0 ] ← Window: tokens 0-2 +2 [ 1 1 1 1 0 0 ] ← Window: tokens 1-3 +3 [ 1 0 1 1 1 0 ] ← Window: tokens 2-4 +4 [ 1 0 0 1 1 1 ] ← Window: tokens 3-5 +5 [ 1 0 0 0 1 1 ] ← Window: tokens 4-5 + +Complexity: O(n × (window + num_global)) +``` + +**Components:** +1. **Sliding window**: Each token attends to w/2 tokens before and after +2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens +3. **Dilated windows**: Optional (stride > 1 for longer context) + +**Implementation:** +```python +from transformers import LongformerModel + +model = LongformerModel.from_pretrained("allenai/longformer-base-4096") + +# Attention mask (shape: batch, seq_len) +attention_mask = torch.ones(batch_size, seq_len) +attention_mask[:, 0] = 2 # 2 = global attention for [CLS] token + +output = model(input_ids, attention_mask=attention_mask) +``` + +**Memory comparison (4k tokens, window=512):** +``` +Standard: 4096² = 16M elements → 64MB +Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!) 
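At 16k tokens with the same window, the gap widens:
Standard:   16384² = 268M elements → ~1GB
Longformer: 16384 × 512 = 8M elements → 32MB (32x reduction!)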
+``` + +**When to use:** +- Documents: 4k-16k tokens (legal, scientific papers) +- Need full context but can't fit O(n²) +- Task has local + global structure + +**Pretrained models:** +- `allenai/longformer-base-4096`: Max 4096 tokens +- `allenai/longformer-large-4096`: Larger version + +### Variant 2: BigBird + +**Pattern:** Random + window + global + +``` +Attention pattern: +- Sliding window: Like Longformer +- Random connections: Each token attends to r random tokens +- Global tokens: Special tokens attend to all + +Complexity: O(n × (window + r + num_global)) +``` + +**Key difference from Longformer:** +- Random connections help information flow +- Theoretically proven to approximate full attention + +**When to use:** +- Similar to Longformer +- Slightly better for tasks needing long-range +- Less widely adopted than Longformer + +**Implementation:** +```python +from transformers import BigBirdModel + +model = BigBirdModel.from_pretrained( + "google/bigbird-roberta-base", + attention_type="block_sparse" # or "original_full" +) +``` + +### Sparse Attention Decision + +``` +Sequence length < 4k: +→ Flash Attention (exact, no pattern needed) + +Sequence length 4k-16k: +→ Longformer (sliding window + global) +→ Best for: Documents, long-form text + +Sequence length > 16k: +→ Longformer if possible +→ Linear attention if Longformer too slow +``` + + +## Part 4: Linear Attention (Approximate for Very Long) + +### Concept + +**Idea:** Approximate softmax attention with linear operations +- Complexity: O(n × k) where k << n +- Trade-off: 1-3% accuracy loss +- Benefit: Can handle very long sequences (> 16k) + +**Key property:** APPROXIMATE (not exact) +- Do NOT use if accuracy critical +- Good for extremely long sequences where exact is impossible + +### Variant 1: Performer + +**Method:** Random Fourier Features to approximate softmax(Q K^T) + +**Formula:** +```python +# Standard attention +Attention(Q, K, V) = softmax(Q K^T) V + +# Performer approximation +φ(Q) ≈ φ(K)^T ≈ softmax(Q K^T) +Attention(Q, K, V) ≈ φ(Q) (φ(K)^T V) + +# Complexity: O(n × k) where k = feature dimension +``` + +**Key trick:** +- Compute φ(K)^T V first: (k × d) matrix (small!) 
+- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d) +- Never materialize n² attention matrix + +**Implementation:** +```python +# From performer-pytorch library +from performer_pytorch import Performer + +model = Performer( + dim=512, + depth=6, + heads=8, + dim_head=64, + causal=False, + nb_features=256 # k = number of random features +) +``` + +**Accuracy:** +- Typical loss: 1-2% vs standard attention +- Depends on nb_features (more features = better approximation) +- k=256 usually sufficient + +**When to use:** +- Sequence length > 16k tokens +- Accuracy loss acceptable (not critical task) +- Need better than sparse attention (no structure assumptions) + +### Variant 2: Linformer + +**Method:** Project K and V to lower dimension + +**Formula:** +```python +# Standard attention (n × n attention matrix) +Attention(Q, K, V) = softmax(Q K^T / √d_k) V + +# Linformer (project K, V to n × k where k << n) +K_proj = E K # E: (k × n) projection matrix +V_proj = F V # F: (k × n) projection matrix + +Attention(Q, K, V) ≈ softmax(Q K_proj^T / √d_k) V_proj +# Attention matrix: (n × k) instead of (n × n) +``` + +**Complexity:** +- Time: O(n × k × d) where k << n +- Memory: O(n × k) instead of O(n²) + +**Implementation:** +```python +# From linformer library +from linformer import Linformer + +model = Linformer( + dim=512, + seq_len=8192, + depth=12, + heads=8, + k=256 # Projected dimension +) +``` + +**Accuracy:** +- Typical loss: 1-3% vs standard attention +- More loss than Performer +- Fixed sequence length (k is tied to max_seq_len) + +**When to use:** +- Fixed-length long sequences +- Memory more critical than speed +- Accuracy loss OK (2-3%) + +### Linear Attention Decision + +``` +Need exact attention: +→ Flash Attention or Sparse Attention (NOT linear) + +Sequence > 16k, accuracy critical: +→ Sparse Attention (Longformer) + +Sequence > 16k, accuracy loss OK: +→ Performer (better) or Linformer + +Sequence > 100k: +→ State space models (S4, Mamba, not attention) +``` + + +## Part 5: Cross-Attention (Multimodal) + +### Concept + +**Self-attention:** Q, K, V from same source +**Cross-attention:** Q from one source, K/V from another + +**Use cases:** +- Multimodal: vision → language (image captioning) +- Seq2seq: source language → target language (translation) +- RAG: query → document retrieval +- Conditioning: generation conditioned on context + +### Architecture + +```python +class CrossAttention(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + self.mha = MultiHeadAttention(d_model, num_heads) + + def forward(self, query_source, key_value_source, mask=None): + # query_source: (batch, n_q, d_model) - e.g., text tokens + # key_value_source: (batch, n_kv, d_model) - e.g., image patches + + # Q from query source + Q = self.W_q(query_source) + + # K, V from key-value source + K = self.W_k(key_value_source) + V = self.W_v(key_value_source) + + # Attention: (batch, n_q, d_model) + output = attention(Q, K, V, mask) + return output +``` + +### Example: Image Captioning + +**Task:** Generate caption from image + +**Architecture:** +1. **Image Encoder:** ViT processes image → image features (n_patches × d) +2. **Text Decoder:** Autoregressive text generation +3. **Cross-Attention:** Text queries image features + +```python +class ImageCaptioningDecoder(nn.Module): + def forward(self, text_tokens, image_features): + # 1. 
Self-attention on text (causal) + text = self.text_self_attention( + query=text, + key=text, + value=text, + causal_mask=True # Don't see future words + ) + + # 2. Cross-attention (text queries image) + text = self.cross_attention( + query=text, # From text decoder + key=image_features, # From image encoder + value=image_features # From image encoder + # No causal mask! Can attend to all image patches + ) + + # 3. Feed-forward + text = self.feed_forward(text) + + return text +``` + +**Attention flow:** +- Text token "cat" → High attention to cat region in image +- Text token "sitting" → High attention to posture in image + +### Example: Retrieval-Augmented Generation (RAG) + +**Task:** Generate answer using retrieved documents + +```python +class RAGDecoder(nn.Module): + def forward(self, query_tokens, document_embeddings): + # 1. Self-attention on query + query = self.query_self_attention(query, query, query) + + # 2. Cross-attention (query → documents) + query = self.cross_attention( + query=query, # What we're generating + key=document_embeddings, # Retrieved docs + value=document_embeddings # Retrieved docs + ) + + # Query learns to extract relevant info from docs + + return query +``` + +### When to Use Cross-Attention + +✅ **Use cross-attention when:** +- Two different modalities (vision + language) +- Conditioning generation on context (RAG) +- Seq2seq with different input/output (translation) +- Query-document matching + +❌ **Don't use cross-attention when:** +- Same modality (use self-attention) +- No clear query vs key-value separation + + +## Part 6: Other Attention Variants + +### Axial Attention (2D Images) + +**Idea:** For 2D data (images), attend along each axis separately + +``` +Standard 2D attention: H×W tokens → (HW)² attention matrix +Axial attention: + - Row attention: Each row attends to itself (H × W²) + - Column attention: Each column attends to itself (W × H²) + - Total: O(HW × (H + W)) << O((HW)²) +``` + +**When to use:** +- High-resolution images +- 2D positional structure important + +### Block-Sparse Attention + +**Idea:** Divide attention into blocks, attend only within/across blocks + +**Pattern:** +``` +Block size = 64 tokens +- Local block: Attend within same block +- Vertical stripe: Attend to corresponding position in other blocks +``` + +**Used in:** Sparse Transformer (OpenAI), GPT-3 + +### Multi-Query Attention (MQA) + +**Idea:** One K/V head shared across all Q heads + +**Benefit:** +- Smaller KV cache during inference +- Much faster decoding (4-8x) +- Trade-off: ~1% accuracy loss + +**Used in:** PaLM, Falcon + +### Grouped-Query Attention (GQA) + +**Idea:** Middle ground between multi-head and multi-query +- Group Q heads share K/V heads +- Example: 32 Q heads → 8 K/V heads (4:1 ratio) + +**Benefit:** +- 4x smaller KV cache +- Minimal accuracy loss (< 0.5%) + +**Used in:** LLaMA-2, Mistral + + +## Part 7: Decision Framework + +### By Sequence Length + +``` +< 2k tokens: +→ Flash Attention + Exact, fast, standard + +2k-4k tokens: +→ Flash Attention + Still manageable with modern GPUs + +4k-16k tokens: +→ Sparse Attention (Longformer, BigBird) + Exact, designed for documents +→ OR Flash Attention if batch size = 1 + +> 16k tokens: +→ Sparse Attention + If task has local structure +→ Linear Attention (Performer) + If accuracy loss OK (1-2%) +→ State Space Models (S4, Mamba) + If sequence > 100k +``` + +### By Memory Constraints + +``` +GPU OOM with standard attention: +1. Try Flash Attention (4x less memory, free lunch) +2. 
If still OOM, reduce batch size +3. If batch size = 1 and still OOM, use sparse attention +4. Last resort: Linear attention (if accuracy loss OK) + +DON'T: +- Gradient checkpointing (slower, use Flash Attention instead) +- Throwing more GPUs (algorithmic problem, not hardware) +``` + +### By Accuracy Requirements + +``` +Must be exact (no approximation): +→ Flash Attention or Sparse Attention + Never use linear attention! + +Accuracy loss acceptable (1-3%): +→ Linear Attention (Performer, Linformer) + Only for very long sequences (> 16k) + +Critical task (medical, legal): +→ Exact attention only + Flash Attention or Sparse Attention +``` + +### By Task Type + +``` +Classification / Understanding: +→ Standard + Flash Attention + Sequence usually < 2k + +Document processing: +→ Longformer (4096 tokens) + Designed for documents + +Generation (LLM): +→ Flash Attention for training +→ + GQA/MQA for inference (faster decoding) + +Multimodal (vision + language): +→ Cross-attention for modality fusion +→ Self-attention within each modality + +Retrieval-augmented: +→ Cross-attention (query → documents) +``` + + +## Part 8: Implementation Checklist + +### Using Flash Attention + +**PyTorch 2.0+:** +```python +# Automatic (recommended) +output = F.scaled_dot_product_attention(query, key, value) + +# Verify Flash Attention is used +import torch.backends.cuda +print(torch.backends.cuda.flash_sdp_enabled()) # Should be True +``` + +**HuggingFace:** +```python +model = AutoModel.from_pretrained( + "model-name", + attn_implementation="flash_attention_2", + torch_dtype=torch.float16 # Flash Attention needs fp16/bf16 +) +``` + +**Requirements:** +- CUDA GPU (not CPU) +- PyTorch >= 2.0 OR flash-attn package +- fp16 or bf16 dtype (not fp32) + +### Using Sparse Attention + +**Longformer:** +```python +from transformers import LongformerModel, LongformerTokenizer + +tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096") +model = LongformerModel.from_pretrained("allenai/longformer-base-4096") + +# Attention mask +# 0 = no attention, 1 = local attention, 2 = global attention +attention_mask = torch.ones(batch_size, seq_len) +attention_mask[:, 0] = 2 # [CLS] token gets global attention + +outputs = model(input_ids, attention_mask=attention_mask) +``` + +**Custom sparse pattern:** +```python +# Create custom block-sparse mask +def create_block_sparse_mask(seq_len, block_size): + num_blocks = seq_len // block_size + mask = torch.zeros(seq_len, seq_len) + + for i in range(num_blocks): + start = i * block_size + end = start + block_size + mask[start:end, start:end] = 1 # Local block + + return mask +``` + +### Using Cross-Attention + +```python +class DecoderWithCrossAttention(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + self.self_attn = MultiHeadAttention(d_model, num_heads) + self.cross_attn = MultiHeadAttention(d_model, num_heads) + + def forward(self, decoder_input, encoder_output, causal_mask=None): + # Self-attention (causal) + x = self.self_attn( + query=decoder_input, + key=decoder_input, + value=decoder_input, + mask=causal_mask + ) + + # Cross-attention (Q from decoder, K/V from encoder) + x = self.cross_attn( + query=x, # From decoder + key=encoder_output, # From encoder + value=encoder_output, # From encoder + mask=None # No causal mask for cross-attention! 
+ ) + + return x +``` + + +## Part 9: Common Mistakes + +### Mistake 1: Ignoring Flash Attention + +**Symptom:** Training slow, high memory usage +**Fix:** Always use Flash Attention for < 16k tokens + +### Mistake 2: Using Linear Attention Unnecessarily + +**Symptom:** 1-3% accuracy loss for no reason +**Fix:** Use Flash Attention (exact) unless sequence > 16k + +### Mistake 3: Gradient Checkpointing Instead of Flash Attention + +**Symptom:** Training 20% slower +**Fix:** Flash Attention gives memory savings AND speed + +### Mistake 4: Cross-Attention with Causal Mask + +**Symptom:** Decoder can't attend to encoder properly +**Fix:** Causal mask only for self-attention, NOT cross-attention + +### Mistake 5: Accepting O(n²) Memory + +**Symptom:** GPU OOM for > 4k tokens +**Fix:** Use sparse or Flash Attention, don't just add GPUs + + +## Summary: Quick Reference + +### Attention Selection + +``` +Sequence length: + < 2k → Flash Attention (default) + 2-4k → Flash Attention + 4-16k → Longformer (documents) or Flash Attention (batch=1) + > 16k → Sparse or Linear Attention + +Memory constrained: + First: Try Flash Attention (4x less memory) + Still OOM: Use sparse attention (Longformer) + Last resort: Linear attention (accuracy loss) + +Speed critical: + Training: Flash Attention (2x faster) + Inference: Flash Attention + GQA/MQA + +Accuracy critical: + Use exact attention only (Flash or Sparse) + NEVER linear attention + +Multimodal: + Cross-attention for modality fusion +``` + +### Implementation + +``` +PyTorch 2.0+: + F.scaled_dot_product_attention() # Auto Flash Attention + +HuggingFace: + attn_implementation="flash_attention_2" + +Longformer: + LongformerModel.from_pretrained("allenai/longformer-base-4096") + +Custom: + Inherit from nn.Module, implement forward() +``` + + +## Next Steps + +After mastering this skill: +- `llm-specialist/llm-inference-optimization`: Apply attention optimizations to inference +- `llm-specialist/context-window-management`: Manage long contexts in LLMs +- `architecture-design-principles`: Understand broader design trade-offs + +**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns). diff --git a/skills/using-neural-architectures/cnn-families-and-selection.md b/skills/using-neural-architectures/cnn-families-and-selection.md new file mode 100644 index 0000000..834b40c --- /dev/null +++ b/skills/using-neural-architectures/cnn-families-and-selection.md @@ -0,0 +1,622 @@ + +# CNN Families and Selection: Choosing the Right Convolutional Network + + +CNNs are the foundation of computer vision. Different families have vastly different trade-offs: +- Accuracy vs Speed vs Size +- Dataset size requirements +- Deployment target (cloud vs edge vs mobile) +- Task type (classification vs detection vs segmentation) + +This skill helps you choose the RIGHT CNN for YOUR constraints. 
+ + +## When to Use This Skill + +Use this skill when: +- ✅ Selecting CNN for vision task (classification, detection, segmentation) +- ✅ Comparing CNN families (ResNet vs EfficientNet vs MobileNet) +- ✅ Optimizing for specific constraints (latency, size, accuracy) +- ✅ Understanding CNN evolution (why newer architectures exist) +- ✅ Deployment-specific selection (cloud, edge, mobile) + +DO NOT use for: +- ❌ Non-vision tasks (use sequence-models-comparison or other skills) +- ❌ Training optimization (use training-optimization pack) +- ❌ Implementation details (use pytorch-engineering pack) + +**When in doubt:** If choosing WHICH CNN → this skill. If implementing/training CNN → other skills. + + +## Selection Framework + +### Step 1: Identify Constraints + +**Before recommending ANY architecture, ask:** + +| Constraint | Question | Impact | +|------------|----------|--------| +| **Deployment** | Where will model run? | Cloud → Any, Edge → MobileNet/EfficientNet-Lite, Mobile → MobileNetV3 | +| **Latency** | Speed requirement? | Real-time (< 10ms) → MobileNet, Batch (> 100ms) → Any | +| **Model Size** | Parameter/memory budget? | < 10M params → MobileNet, < 50M → ResNet/EfficientNet, Any → Large models OK | +| **Dataset Size** | Training images? | < 10k → Small models, 10k-100k → Medium, > 100k → Large | +| **Accuracy** | Required accuracy? | Competitive → EfficientNet-B4+, Production → ResNet-50/EfficientNet-B2 | +| **Task Type** | Classification/detection/segmentation? | Detection → FPN-compatible, Segmentation → Multi-scale | + +**Critical:** Get answers to these BEFORE recommending architecture. + +### Step 2: Apply Decision Tree + +``` +START: What's your primary constraint? + +┌─ DEPLOYMENT TARGET +│ ├─ Cloud / Server +│ │ └─ Dataset size? +│ │ ├─ Small (< 10k) → ResNet-18, EfficientNet-B0 +│ │ ├─ Medium (10k-100k) → ResNet-50, EfficientNet-B2 +│ │ └─ Large (> 100k) → ResNet-101, EfficientNet-B4, ViT +│ │ +│ ├─ Edge Device (Jetson, Coral) +│ │ └─ Latency requirement? +│ │ ├─ Real-time (< 10ms) → MobileNetV3-Small, EfficientNet-Lite0 +│ │ ├─ Medium (10-50ms) → MobileNetV3-Large, EfficientNet-Lite2 +│ │ └─ Relaxed (> 50ms) → EfficientNet-B0, ResNet-18 +│ │ +│ └─ Mobile (iOS/Android) +│ └─ MobileNetV3-Small (fastest), MobileNetV3-Large (balanced) +│ + INT8 quantization (route to ml-production) +│ +├─ ACCURACY PRIORITY (cloud deployment assumed) +│ ├─ Maximum accuracy → EfficientNet-B7, ResNet-152, ViT-Large +│ ├─ Balanced → EfficientNet-B2/B3, ResNet-50 +│ └─ Fast training → ResNet-18, EfficientNet-B0 +│ +├─ EFFICIENCY PRIORITY +│ └─ Best accuracy per FLOP → EfficientNet family (B0-B7) +│ (EfficientNet dominates ResNet on Pareto frontier) +│ +└─ TASK TYPE + ├─ Classification → Any CNN (use constraint-based selection above) + ├─ Object Detection → ResNet + FPN, EfficientDet, YOLOv8 (CSPDarknet) + └─ Segmentation → ResNet + U-Net, EfficientNet + DeepLabV3 +``` + + +## CNN Family Catalog + +### 1. 
ResNet Family (2015) - The Standard Baseline + +**Architecture:** Residual connections (skip connections) enable very deep networks + +**Variants:** +- ResNet-18: 11M params, 1.8 GFLOPs, 69.8% ImageNet +- ResNet-34: 22M params, 3.7 GFLOPs, 73.3% ImageNet +- ResNet-50: 25M params, 4.1 GFLOPs, 76.1% ImageNet +- ResNet-101: 44M params, 7.8 GFLOPs, 77.4% ImageNet +- ResNet-152: 60M params, 11.6 GFLOPs, 78.3% ImageNet + +**When to Use:** +- ✅ **Baseline choice**: Well-tested, widely supported +- ✅ **Transfer learning**: Excellent pre-trained weights available +- ✅ **Object detection**: Standard backbone for Faster R-CNN, Mask R-CNN +- ✅ **Interpretability**: Simple architecture, easy to understand + +**When NOT to Use:** +- ❌ **Edge/mobile deployment**: Too large and slow +- ❌ **Efficiency priority**: EfficientNet beats ResNet on accuracy/FLOP +- ❌ **Small datasets (< 10k)**: Use ResNet-18, not ResNet-50+ + +**Key Insight:** Skip connections solve vanishing gradient, enable depth + +**Code Example:** +```python +import torchvision.models as models + +# For cloud/server (good dataset) +model = models.resnet50(pretrained=True) + +# For small dataset or faster training +model = models.resnet18(pretrained=True) + +# For maximum accuracy (cloud only) +model = models.resnet101(pretrained=True) +``` + + +### 2. EfficientNet Family (2019) - Best Efficiency + +**Architecture:** Compound scaling (depth + width + resolution) optimized via neural architecture search + +**Variants:** +- EfficientNet-B0: 5M params, 0.4 GFLOPs, 77.3% ImageNet +- EfficientNet-B1: 8M params, 0.7 GFLOPs, 79.2% ImageNet +- EfficientNet-B2: 9M params, 1.0 GFLOPs, 80.3% ImageNet +- EfficientNet-B3: 12M params, 1.8 GFLOPs, 81.7% ImageNet +- EfficientNet-B4: 19M params, 4.2 GFLOPs, 82.9% ImageNet +- EfficientNet-B7: 66M params, 37 GFLOPs, 84.4% ImageNet + +**When to Use:** +- ✅ **Efficiency matters**: Best accuracy per FLOP/parameter +- ✅ **Cloud deployment**: B2-B4 sweet spot for production +- ✅ **Limited compute**: B0 matches ResNet-50 accuracy at 10x fewer FLOPs +- ✅ **Scaling needs**: Want to scale model up/down systematically + +**When NOT to Use:** +- ❌ **Real-time mobile**: Use MobileNet (EfficientNet has more layers) +- ❌ **Very small datasets**: Can overfit despite efficiency +- ❌ **Simplicity needed**: More complex than ResNet + +**Key Insight:** Compound scaling balances depth, width, and resolution optimally + +**Efficiency Comparison:** +``` +Same accuracy as ResNet-50 (76%): +- ResNet-50: 25M params, 4.1 GFLOPs +- EfficientNet-B0: 5M params, 0.4 GFLOPs (10x more efficient!) + +Better accuracy (82.9%): +- ResNet-152: 60M params, 11.6 GFLOPs → 78.3% ImageNet +- EfficientNet-B4: 19M params, 4.2 GFLOPs → 82.9% ImageNet + (Better accuracy with 3x fewer params and 3x less compute) +``` + +**Code Example:** +```python +import timm # PyTorch Image Models library + +# Balanced choice (production) +model = timm.create_model('efficientnet_b2', pretrained=True) + +# Efficiency priority (edge) +model = timm.create_model('efficientnet_b0', pretrained=True) + +# Accuracy priority (research) +model = timm.create_model('efficientnet_b4', pretrained=True) +``` + + +### 3. 
MobileNet Family (2017-2019) - Mobile Optimized + +**Architecture:** Depthwise separable convolutions (drastically reduce compute) + +**Variants:** +- MobileNetV1: 4.2M params, 0.6 GFLOPs, 70.6% ImageNet +- MobileNetV2: 3.5M params, 0.3 GFLOPs, 72.0% ImageNet +- MobileNetV3-Small: 2.5M params, 0.06 GFLOPs, 67.4% ImageNet +- MobileNetV3-Large: 5.4M params, 0.2 GFLOPs, 75.2% ImageNet + +**When to Use:** +- ✅ **Mobile deployment**: iOS/Android apps +- ✅ **Edge devices**: Raspberry Pi, Jetson Nano +- ✅ **Real-time inference**: < 100ms latency +- ✅ **Extreme efficiency**: < 10M parameters budget + +**When NOT to Use:** +- ❌ **Cloud deployment with no constraints**: EfficientNet or ResNet better accuracy +- ❌ **Accuracy priority**: Sacrifices accuracy for speed +- ❌ **Large datasets with compute**: Can afford better models + +**Key Insight:** Depthwise separable convolutions = standard conv split into depthwise + pointwise (9x fewer operations) + +**Deployment Performance:** +``` +Raspberry Pi 4 inference (224×224 image): +- ResNet-50: ~2000ms (unusable) +- ResNet-18: ~600ms (slow) +- MobileNetV2: ~150ms (acceptable) +- MobileNetV3-Large: ~80ms (good) +- MobileNetV3-Small: ~40ms (fast) + +With INT8 quantization: +- MobileNetV3-Large: ~30ms (production-ready) +- MobileNetV3-Small: ~15ms (real-time) +``` + +**Code Example:** +```python +import torchvision.models as models + +# For mobile deployment +model = models.mobilenet_v3_large(pretrained=True) + +# For ultra-low latency (sacrifice accuracy) +model = models.mobilenet_v3_small(pretrained=True) + +# Quantization for mobile (route to ml-production skill for details) +# Achieves 2-4x speedup with minimal accuracy loss +``` + + +### 4. Inception Family (2014-2016) - Multi-Scale Features + +**Architecture:** Multi-scale convolutions in parallel (inception modules) + +**Variants:** +- InceptionV3: 24M params, 5.7 GFLOPs, 77.5% ImageNet +- InceptionV4: 42M params, 12.3 GFLOPs, 80.0% ImageNet +- Inception-ResNet: Hybrid with residual connections + +**When to Use:** +- ✅ **Multi-scale features**: Objects at different sizes +- ✅ **Object detection**: Good backbone for detection +- ✅ **Historical interest**: Understanding multi-scale approaches + +**When NOT to Use:** +- ❌ **Simplicity needed**: Complex architecture, hard to modify +- ❌ **Efficiency priority**: EfficientNet better +- ❌ **Modern projects**: Largely superseded by ResNet/EfficientNet + +**Key Insight:** Parallel multi-scale convolutions (1×1, 3×3, 5×5) capture different receptive fields + +**Status:** Mostly historical - ResNet and EfficientNet have replaced Inception in practice + + +### 5. 
DenseNet Family (2017) - Dense Connections + +**Architecture:** Every layer connects to every other layer (dense connections) + +**Variants:** +- DenseNet-121: 8M params, 2.9 GFLOPs, 74.4% ImageNet +- DenseNet-169: 14M params, 3.4 GFLOPs, 75.6% ImageNet +- DenseNet-201: 20M params, 4.3 GFLOPs, 76.9% ImageNet + +**When to Use:** +- ✅ **Parameter efficiency**: Good accuracy with few parameters +- ✅ **Feature reuse**: Dense connections enable feature reuse +- ✅ **Small datasets**: Better gradient flow helps with limited data + +**When NOT to Use:** +- ❌ **Inference speed priority**: Dense connections slow (high memory bandwidth) +- ❌ **Training speed**: Slower to train than ResNet +- ❌ **Production deployment**: Less mature ecosystem than ResNet + +**Key Insight:** Dense connections improve gradient flow, enable feature reuse, but slow inference + +**Status:** Theoretically elegant, but ResNet/EfficientNet more practical + + +### 6. VGG Family (2014) - Historical Baseline + +**Architecture:** Very deep (16-19 layers), small 3×3 convolutions, many parameters + +**Variants:** +- VGG-16: 138M params, 15.5 GFLOPs, 71.5% ImageNet +- VGG-19: 144M params, 19.6 GFLOPs, 71.1% ImageNet + +**When to Use:** +- ❌ **DON'T use VGG for new projects** +- Historical understanding only + +**Why NOT to Use:** +- Massive parameter count (138M vs ResNet-50's 25M) +- Poor accuracy for size +- Superseded by ResNet (2015) + +**Key Insight:** Proved that depth matters, but skip connections (ResNet) are better + +**Status:** **Obsolete** - use ResNet or EfficientNet instead + + +## Practical Selection Guide + +### Scenario 1: Cloud/Server Deployment + +**Goal:** Best accuracy, no compute constraints + +**Recommendation:** +``` +Small dataset (< 10k images): +→ EfficientNet-B0 or ResNet-18 + (Avoid overfitting with smaller model) + +Medium dataset (10k-100k images): +→ EfficientNet-B2 or ResNet-50 + (Balanced accuracy and efficiency) + +Large dataset (> 100k images): +→ EfficientNet-B4 or ResNet-101 + (Can afford larger model) + +Maximum accuracy (research): +→ EfficientNet-B7 or Vision Transformer + (If dataset > 1M images and compute unlimited) +``` + + +### Scenario 2: Edge Deployment (Jetson, Coral TPU) + +**Goal:** Optimize for edge hardware latency + +**Recommendation:** +``` +Real-time requirement (< 10ms): +→ MobileNetV3-Small or EfficientNet-Lite0 + + INT8 quantization + +Medium latency (10-50ms): +→ MobileNetV3-Large or EfficientNet-Lite2 + +Relaxed latency (> 50ms): +→ EfficientNet-B0 or ResNet-18 +``` + +**Critical:** Profile on actual edge hardware. Quantization is mandatory (route to ml-production). 
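To make "profile on actual edge hardware" concrete, here is a minimal latency-measurement sketch (model choices follow the examples above; the warm-up and iteration counts are arbitrary). Run it on the target device itself — desktop numbers do not transfer — and on CUDA devices wrap the timed region with `torch.cuda.synchronize()`.

```python
import time
import torch
import torchvision.models as models

def mean_latency_ms(model, runs=50, warmup=10, size=224):
    # Rough single-image latency: warm up, then average timed runs.
    # Weights don't affect latency, so random init is fine here.
    model.eval()
    x = torch.randn(1, 3, size, size)
    with torch.inference_mode():
        for _ in range(warmup):
            model(x)
        start = time.perf_counter()
        for _ in range(runs):
            model(x)
        return (time.perf_counter() - start) / runs * 1000

for name, ctor in [
    ("MobileNetV3-Small", models.mobilenet_v3_small),
    ("MobileNetV3-Large", models.mobilenet_v3_large),
    ("ResNet-18", models.resnet18),
]:
    print(f"{name}: {mean_latency_ms(ctor()):.1f} ms/image")
```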
+ + +### Scenario 3: Mobile Deployment (iOS/Android) + +**Goal:** On-device inference, minimal battery drain + +**Recommendation:** +``` +All mobile deployments: +→ MobileNetV3-Large (balanced) +→ MobileNetV3-Small (fastest, less accurate) + +Always use: +- INT8 quantization (2-4x speedup) +- CoreML (iOS) or TFLite (Android) optimization +- Benchmark on target device before deploying +``` + +**Expected latency (iPhone 12, INT8 quantized):** +- MobileNetV3-Small: 5-10ms +- MobileNetV3-Large: 15-25ms + + +### Scenario 4: Object Detection + +**Goal:** Select backbone for detection framework + +**Recommendation:** +``` +Faster R-CNN: +→ ResNet-50 + FPN (standard) +→ ResNet-101 + FPN (more accuracy) + +YOLOv8: +→ CSPDarknet (built-in, optimized) + +EfficientDet: +→ EfficientNet + BiFPN (best efficiency) + +Custom detection: +→ ResNet or EfficientNet as backbone +→ Add Feature Pyramid Network (FPN) for multi-scale +``` + +**Note:** Detection adds significant compute on top of backbone. Choose efficient backbone. + + +### Scenario 5: Semantic Segmentation + +**Goal:** Dense pixel-wise prediction + +**Recommendation:** +``` +U-Net style: +→ ResNet-18/34 as encoder (fast) +→ EfficientNet-B0 as encoder (efficient) + +DeepLabV3: +→ ResNet-50 (standard) +→ MobileNetV3 (mobile deployment) + +Key: Segmentation requires multi-scale features +→ Ensure backbone has skip connections or FPN +``` + + +## Trade-Off Analysis + +### Accuracy vs Efficiency (Pareto Frontier) + +**ImageNet Top-1 Accuracy vs FLOPs:** + +``` +Efficiency Winners (best accuracy per FLOP): +1. EfficientNet-B0: 77.3% @ 0.4 GFLOPs (best efficiency) +2. EfficientNet-B2: 80.3% @ 1.0 GFLOPs +3. EfficientNet-B4: 82.9% @ 4.2 GFLOPs + +Accuracy Winners (best absolute accuracy): +1. EfficientNet-B7: 84.4% @ 37 GFLOPs +2. ViT-Large: 85.2% @ 190 GFLOPs (requires huge dataset) +3. ResNet-152: 78.3% @ 11.6 GFLOPs (dominated by EfficientNet) + +Speed Winners (lowest latency): +1. MobileNetV3-Small: 67.4% @ 0.06 GFLOPs (50ms on mobile) +2. MobileNetV3-Large: 75.2% @ 0.2 GFLOPs (100ms on mobile) +3. EfficientNet-Lite0: 75.0% @ 0.4 GFLOPs +``` + +**Key Takeaway:** EfficientNet dominates ResNet on Pareto frontier (better accuracy at same compute). + + +### Parameters vs Accuracy + +**For same ~75% ImageNet accuracy:** +``` +VGG-16: 138M params (❌ terrible efficiency) +ResNet-50: 25M params +EfficientNet-B0: 5M params (✅ 5x fewer parameters!) +MobileNetV3-Large: 5M params (fast inference) +``` + +**Conclusion:** Modern architectures (EfficientNet, MobileNet) achieve same accuracy with far fewer parameters. 
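These parameter counts are easy to verify yourself with torchvision (the constructors below are the standard ones; counts include the default ImageNet classification heads):

```python
import torchvision.models as models

def millions_of_params(model):
    # Total trainable parameters, in millions
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

for name, ctor in [
    ("VGG-16", models.vgg16),
    ("ResNet-50", models.resnet50),
    ("EfficientNet-B0", models.efficientnet_b0),
    ("MobileNetV3-Large", models.mobilenet_v3_large),
]:
    print(f"{name}: {millions_of_params(ctor()):.1f}M parameters")
```

Expect roughly 138M, 25.6M, 5.3M, and 5.5M respectively — consistent with the table above.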
+ + +## Common Pitfalls + +### Pitfall 1: Defaulting to ResNet-50 +**Symptom:** Using ResNet-50 without considering alternatives + +**Why it's wrong:** EfficientNet-B0 matches ResNet-50 accuracy with 10x less compute + +**Fix:** Consider EfficientNet family first (better efficiency) + + +### Pitfall 2: Choosing Large Model for Small Dataset +**Symptom:** Using ResNet-101 with < 10k images + +**Why it's wrong:** Model will overfit (too many parameters for data) + +**Fix:** +- < 10k images → ResNet-18 or EfficientNet-B0 +- 10k-100k → ResNet-50 or EfficientNet-B2 +- > 100k → Can use larger models + + +### Pitfall 3: Using Desktop Model on Mobile +**Symptom:** Trying to run ResNet-50 on mobile device + +**Why it's wrong:** 2000ms inference time is unusable + +**Fix:** Use MobileNetV3 + quantization for mobile (15-30ms) + + +### Pitfall 4: Ignoring Task Type +**Symptom:** Using standard CNN for object detection without FPN + +**Why it's wrong:** Detection needs multi-scale features + +**Fix:** Use detection-specific frameworks (YOLOv8, Faster R-CNN) with appropriate backbone + + +### Pitfall 5: Believing "Bigger = Better" +**Symptom:** Choosing ResNet-152 over ResNet-50 without justification + +**Why it's wrong:** Diminishing returns - 3x compute for 1.3% accuracy, will overfit on small data + +**Fix:** Match model capacity to dataset size, consider efficiency + + +## Evolution and Historical Context + +**Why CNNs evolved the way they did:** + +``` +2012: AlexNet +→ Proved deep learning works for vision +→ 8 layers, 60M params + +2014: VGG +→ Deeper is better (16-19 layers) +→ But: 138M params (too many) + +2014: Inception/GoogLeNet +→ Multi-scale convolutions +→ More efficient than VGG + +2015: ResNet ★ +→ Skip connections enable very deep networks (152 layers) +→ Solved vanishing gradient problem +→ Became standard baseline + +2017: MobileNet +→ Mobile deployment needs +→ Depthwise separable convolutions (9x fewer ops) + +2017: DenseNet +→ Dense connections for feature reuse +→ Parameter efficient but slow inference + +2019: EfficientNet ★ +→ Compound scaling (depth + width + resolution) +→ Neural architecture search +→ Dominates Pareto frontier (best accuracy per FLOP) +→ New standard for efficiency + +2020: Vision Transformer +→ Attention-based (no convolutions) +→ Requires very large datasets (> 1M images) +→ For research/large-scale applications +``` + +**Current Recommendations (2025):** +- Cloud: **EfficientNet** (best efficiency) or ResNet (simplicity) +- Edge: **EfficientNet-Lite** or MobileNetV3 +- Mobile: **MobileNetV3** + quantization +- Detection: **EfficientDet** or YOLOv8 +- Baseline: **ResNet** (simple, well-tested) + + +## Decision Checklist + +Before choosing CNN, answer these: + +``` +☐ Deployment target? (cloud/edge/mobile) +☐ Latency requirement? (< 10ms / 10-100ms / > 100ms) +☐ Model size budget? (< 10M / 10-50M / unlimited params) +☐ Dataset size? (< 10k / 10k-100k / > 100k images) +☐ Accuracy priority? (maximum / production / fast iteration) +☐ Task type? (classification / detection / segmentation) +☐ Efficiency matters? 
(yes → EfficientNet, no → flexibility) + +Based on answers: +→ Mobile → MobileNetV3 +→ Edge → EfficientNet-Lite or MobileNetV3 +→ Cloud + efficiency → EfficientNet +→ Cloud + simplicity → ResNet +→ Maximum accuracy → EfficientNet-B7 or ViT +→ Small dataset → Small models (ResNet-18, EfficientNet-B0) +``` + + +## Integration with Other Skills + +**After selecting CNN architecture:** + +**Training the model:** +→ `yzmir/training-optimization/using-training-optimization` +- Optimizer selection (Adam, SGD, AdamW) +- Learning rate schedules +- Data augmentation + +**Implementing in PyTorch:** +→ `yzmir/pytorch-engineering/using-pytorch-engineering` +- Custom modifications to pre-trained models +- Multi-GPU training +- Performance optimization + +**Deploying to production:** +→ `yzmir/ml-production/using-ml-production` +- Quantization (INT8, FP16) +- Model serving (TorchServe, ONNX) +- Optimization for edge/mobile (TFLite, CoreML) + +**If architecture is unstable (very deep):** +→ `yzmir/neural-architectures/normalization-techniques` +- Normalization layers (BatchNorm, LayerNorm) +- Skip connections +- Initialization strategies + + +## Summary + +**CNN Selection in One Table:** + +| Scenario | Recommendation | Why | +|----------|----------------|-----| +| Cloud, balanced | EfficientNet-B2 | Best efficiency, 80% accuracy | +| Cloud, max accuracy | EfficientNet-B4 | 83% accuracy, reasonable compute | +| Cloud, simple baseline | ResNet-50 | Well-tested, widely used | +| Edge device | MobileNetV3-Large | Optimized for edge, 75% accuracy | +| Mobile app | MobileNetV3-Small + quantization | < 20ms inference | +| Small dataset (< 10k) | ResNet-18 or EfficientNet-B0 | Avoid overfitting | +| Object detection | ResNet-50 + FPN, EfficientDet | Multi-scale features | +| Segmentation | ResNet + U-Net, DeepLabV3 | Dense prediction | + +**Key Principles:** +1. **Match model capacity to dataset size** (small data → small model) +2. **EfficientNet dominates ResNet on efficiency** (same accuracy, less compute) +3. **Mobile needs mobile-specific architectures** (MobileNet, quantization) +4. **Task type matters** (detection/segmentation need multi-scale features) +5. **Bigger ≠ always better** (diminishing returns, overfitting risk) + +**When in doubt:** Start with **EfficientNet-B2** (cloud) or **MobileNetV3-Large** (edge/mobile). + + +**END OF SKILL** diff --git a/skills/using-neural-architectures/generative-model-families.md b/skills/using-neural-architectures/generative-model-families.md new file mode 100644 index 0000000..ca5e9fb --- /dev/null +++ b/skills/using-neural-architectures/generative-model-families.md @@ -0,0 +1,811 @@ + +# Generative Model Families + +## When to Use This Skill + +Use this skill when you need to: +- ✅ Select generative model for image/audio/video generation +- ✅ Understand VAE vs GAN vs Diffusion trade-offs +- ✅ Decide between training from scratch vs fine-tuning +- ✅ Address mode collapse in GANs +- ✅ Choose between quality, speed, and training stability +- ✅ Understand modern landscape (Stable Diffusion, StyleGAN, etc.) 
+ +**Do NOT use this skill for:** +- ❌ Text generation (use `llm-specialist` pack) +- ❌ Architecture implementation details (use model-specific docs) +- ❌ High-level architecture selection (use `using-neural-architectures`) + + +## Core Principle + +**Generative models have fundamental trade-offs:** +- **Quality vs Stability**: GANs sharp but unstable, VAEs blurry but stable +- **Quality vs Speed**: Diffusion high-quality but slow, GANs fast +- **Explicitness vs Flexibility**: Autoregressive/Flow have likelihood, GANs don't + +**Modern default (2025):** Diffusion models (best quality + stability) + + +## Part 1: Model Family Overview + +### The Five Families + +**1. VAE (Variational Autoencoder)** +- **Approach**: Learn latent space with encoder-decoder +- **Quality**: Blurry (6/10) +- **Training**: Very stable +- **Use**: Latent space exploration, NOT high-quality generation + +**2. GAN (Generative Adversarial Network)** +- **Approach**: Adversarial game (generator vs discriminator) +- **Quality**: Sharp (9/10) +- **Training**: Unstable (adversarial dynamics) +- **Use**: High-quality generation, fast inference + +**3. Diffusion Models** +- **Approach**: Iterative denoising +- **Quality**: Very sharp (9.5/10) +- **Training**: Stable +- **Use**: Modern default for high-quality generation + +**4. Autoregressive Models** +- **Approach**: Sequential generation (pixel-by-pixel, token-by-token) +- **Quality**: Good (7-8/10) +- **Training**: Stable +- **Use**: Explicit likelihood, sequential data + +**5. Flow Models** +- **Approach**: Invertible transformations +- **Quality**: Good (7-8/10) +- **Training**: Stable +- **Use**: Exact likelihood, invertibility needed + +### Quick Comparison + +| Model | Quality | Training Stability | Inference Speed | Mode Collapse | Likelihood | +|-------|---------|-------------------|----------------|---------------|------------| +| VAE | 6/10 (blurry) | 10/10 | Fast | No | Approximate | +| GAN | 9/10 | 3/10 | Fast | Yes | No | +| Diffusion | 9.5/10 | 9/10 | Slow | No | Approximate | +| Autoregressive | 7-8/10 | 9/10 | Very slow | No | Exact | +| Flow | 7-8/10 | 8/10 | Fast (both ways) | No | Exact | + + +## Part 2: VAE (Variational Autoencoder) + +### Architecture + +**Components:** +1. **Encoder**: x → z (image to latent) +2. **Latent space**: z ~ N(μ, σ²) +3. **Decoder**: z → x' (latent to reconstruction) + +**Loss function:** +```python +# ELBO (Evidence Lower Bound) +loss = reconstruction_loss + KL_divergence + +# Reconstruction: How well decoder reconstructs input +reconstruction_loss = MSE(x, x_reconstructed) + +# KL: How close latent is to standard normal +KL_divergence = KL(q(z|x) || p(z)) +``` + +### Why VAE is Blurry + +**Problem**: MSE loss encourages pixel-wise averaging + +**Example:** +- Dataset: Faces with both smiles and no smiles +- VAE learns: "Average face has half-smile blur" +- Result: Blurry, hedges between modes + +**Mathematical reason:** +- MSE minimization = mean prediction +- Mean of sharp images = blurry image + +### When to Use VAE + +✅ **Use VAE for:** +- Latent space exploration (interpolation, arithmetic) +- Anomaly detection (reconstruction error) +- Disentangled representations (β-VAE) +- Compression (lossy, with latent codes) + +❌ **DON'T use VAE for:** +- High-quality image generation (use GAN or Diffusion!) 
+- Sharp, realistic outputs + +### Implementation + +```python +import torch +import torch.nn as nn + +class VAE(nn.Module): + def __init__(self, latent_dim=128): + super().__init__() + # Encoder + self.encoder = nn.Sequential( + nn.Conv2d(3, 32, 4, 2, 1), # 64x64 -> 32x32 + nn.ReLU(), + nn.Conv2d(32, 64, 4, 2, 1), # 32x32 -> 16x16 + nn.ReLU(), + nn.Conv2d(64, 128, 4, 2, 1), # 16x16 -> 8x8 + nn.ReLU(), + nn.Flatten() + ) + + self.fc_mu = nn.Linear(128 * 8 * 8, latent_dim) + self.fc_logvar = nn.Linear(128 * 8 * 8, latent_dim) + + # Decoder + self.fc_decode = nn.Linear(latent_dim, 128 * 8 * 8) + self.decoder = nn.Sequential( + nn.ConvTranspose2d(128, 64, 4, 2, 1), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, 4, 2, 1), + nn.ReLU(), + nn.ConvTranspose2d(32, 3, 4, 2, 1), + nn.Sigmoid() + ) + + def reparameterize(self, mu, logvar): + # Reparameterization trick: z = μ + σ * ε + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + # Encode + h = self.encoder(x) + mu = self.fc_mu(h) + logvar = self.fc_logvar(h) + + # Sample latent + z = self.reparameterize(mu, logvar) + + # Decode + h = self.fc_decode(z) + h = h.view(-1, 128, 8, 8) + x_recon = self.decoder(h) + + return x_recon, mu, logvar + + def loss_function(self, x, x_recon, mu, logvar): + # Reconstruction loss + recon_loss = F.mse_loss(x_recon, x, reduction='sum') + + # KL divergence + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + + return recon_loss + kl_loss +``` + + +## Part 3: GAN (Generative Adversarial Network) + +### Architecture + +**Components:** +1. **Generator**: z → x (noise to image) +2. **Discriminator**: x → [0, 1] (image to real/fake probability) + +**Adversarial Training:** +```python +# Discriminator loss: Classify real as real, fake as fake +D_loss = -log(D(x_real)) - log(1 - D(G(z))) + +# Generator loss: Fool discriminator +G_loss = -log(D(G(z))) + +# Minimax game: +min_G max_D V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))] +``` + +### Training Instability + +**Problem**: Adversarial dynamics are unstable + +**Common issues:** +1. **Mode collapse**: Generator produces limited variety +2. **Non-convergence**: Oscillation, never settles +3. **Vanishing gradients**: Discriminator too strong, generator can't learn +4. **Hyperparameter sensitivity**: Learning rates critical + +**Solutions:** +- Spectral normalization (StyleGAN2) +- Progressive growing (start low-res, increase) +- Minibatch discrimination (penalize lack of diversity) +- Wasserstein loss (WGAN, more stable) + +### Mode Collapse + +**What is it?** +- Generator produces subset of distribution +- Example: Face GAN only generates 10 face types + +**Why it happens:** +- Generator exploits discriminator weaknesses +- Finds "easy" samples that fool discriminator +- Forgets other modes + +**Detection:** +```python +# Check diversity: Generate many samples +samples = generator.generate(n=1000) +diversity = compute_pairwise_distance(samples) +if diversity < threshold: + print("Mode collapse detected!") +``` + +**Solutions:** +- Minibatch discrimination +- Unrolled GANs (slow but helps) +- Switch to diffusion (no mode collapse by design!) 
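The `compute_pairwise_distance` call in the detection snippet above is left abstract. Below is a minimal sketch of one way to implement it, using mean pixel-space L2 distance over random sample pairs as a crude diversity proxy (`generator`, `latent_dim`, and `threshold` are placeholders; feature-space metrics such as LPIPS diversity are more reliable in practice):

```python
import torch

def mean_pairwise_distance(samples, num_pairs=10_000):
    # samples: (N, C, H, W) batch of generated images
    flat = samples.flatten(start_dim=1)              # (N, C*H*W)
    n = flat.size(0)
    idx_a = torch.randint(0, n, (num_pairs,))
    idx_b = torch.randint(0, n, (num_pairs,))
    dists = (flat[idx_a] - flat[idx_b]).norm(dim=1)  # L2 distance per pair
    return dists.mean().item()

# Usage sketch:
# with torch.no_grad():
#     samples = generator(torch.randn(1000, latent_dim))
# if mean_pairwise_distance(samples) < threshold:
#     print("Low sample diversity - possible mode collapse")
```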
+ +### Modern GANs + +**StyleGAN2 (2020):** +- State-of-the-art for faces +- Style-based generator +- Spectral normalization for stability +- Resolution: 1024×1024 + +**StyleGAN3 (2021):** +- Alias-free architecture +- Better animation/video + +**When to use GAN:** +✅ Fast inference needed (50ms per image) +✅ Pretrained model available (StyleGAN2) +✅ Can tolerate training difficulty + +❌ Training instability unacceptable +❌ Mode collapse problematic +❌ Starting from scratch (use diffusion instead) + +### Implementation (Basic GAN) + +```python +class Generator(nn.Module): + def __init__(self, latent_dim=100, img_channels=3): + super().__init__() + self.model = nn.Sequential( + nn.Linear(latent_dim, 128 * 8 * 8), + nn.ReLU(), + nn.Unflatten(1, (128, 8, 8)), + nn.ConvTranspose2d(128, 64, 4, 2, 1), # 8x8 -> 16x16 + nn.ReLU(), + nn.ConvTranspose2d(64, 32, 4, 2, 1), # 16x16 -> 32x32 + nn.ReLU(), + nn.ConvTranspose2d(32, img_channels, 4, 2, 1), # 32x32 -> 64x64 + nn.Tanh() + ) + + def forward(self, z): + return self.model(z) + +class Discriminator(nn.Module): + def __init__(self, img_channels=3): + super().__init__() + self.model = nn.Sequential( + nn.Conv2d(img_channels, 32, 4, 2, 1), # 64x64 -> 32x32 + nn.LeakyReLU(0.2), + nn.Conv2d(32, 64, 4, 2, 1), # 32x32 -> 16x16 + nn.LeakyReLU(0.2), + nn.Conv2d(64, 128, 4, 2, 1), # 16x16 -> 8x8 + nn.LeakyReLU(0.2), + nn.Flatten(), + nn.Linear(128 * 8 * 8, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.model(x) + +# Training loop +for real_images in dataloader: + # Train discriminator + fake_images = generator(noise) + D_real = discriminator(real_images) + D_fake = discriminator(fake_images.detach()) + + D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake)) + D_loss.backward() + optimizer_D.step() + + # Train generator + D_fake = discriminator(fake_images) + G_loss = -torch.mean(torch.log(D_fake)) + G_loss.backward() + optimizer_G.step() +``` + + +## Part 4: Diffusion Models (Modern Default) + +### Architecture + +**Concept**: Learn to reverse a diffusion (noising) process + +**Forward process** (fixed): +```python +# Gradually add noise to image +x_0 (original) → x_1 → x_2 → ... → x_T (pure noise) + +# At each step: +x_t = √(1 - β_t) * x_{t-1} + √β_t * ε +where ε ~ N(0, I), β_t = noise schedule +``` + +**Reverse process** (learned): +```python +# Model learns to denoise +x_T (noise) → x_{T-1} → ... → x_1 → x_0 (image) + +# Model predicts: ε_θ(x_t, t) +# Then: x_{t-1} = (x_t - √β_t * ε_θ(x_t, t)) / √(1 - β_t) +``` + +**Training:** +```python +# Simple loss: Predict the noise +loss = MSE(ε, ε_θ(x_t, t)) + +# x_t = noisy image at step t +# ε = actual noise added +# ε_θ(x_t, t) = model's noise prediction +``` + +### Why Diffusion is Excellent + +**Advantages:** +1. **High quality**: State-of-the-art (better than GAN) +2. **Stable training**: Standard MSE loss (no adversarial dynamics) +3. **No mode collapse**: By design, covers full distribution +4. **Controllable**: Easy to add conditioning (text, class, etc.) + +**Disadvantages:** +1. **Slow inference**: 50-1000 denoising steps (vs GAN's 1 step) +2. 
**Compute intensive**: T forward passes (T = 50-1000) + +**Speed comparison:** +``` +GAN: 1 forward pass = 50ms +Diffusion (T=50): 50 forward passes = 2.5 seconds +Diffusion (T=1000): 1000 forward passes = 50 seconds +``` + +**Speedup techniques:** +- DDIM (fewer steps, 10-50 instead of 1000) +- DPM-Solver (fast sampler) +- Latent diffusion (Stable Diffusion, denoise in latent space) + +### Modern Diffusion Models + +**Stable Diffusion (2022+):** +- Latent diffusion (denoise in VAE latent space) +- Text conditioning (CLIP text encoder) +- Pretrained on billions of images +- Fine-tunable + +**DALL-E 2 (2022):** +- Prior network (text → image embedding) +- Diffusion decoder (embedding → image) + +**Imagen (2022, Google):** +- Text conditioning with T5 encoder +- Cascaded diffusion (64×64 → 256×256 → 1024×1024) + +**When to use Diffusion:** +✅ High-quality generation (best quality) +✅ Stable training (standard loss) +✅ Diversity needed (no mode collapse) +✅ Conditioning (text-to-image, class-conditional) + +❌ Need fast inference (< 1 second) +❌ Real-time generation + +### Implementation (DDPM) + +```python +class DiffusionModel(nn.Module): + def __init__(self, img_channels=3): + super().__init__() + # U-Net architecture + self.model = UNet( + in_channels=img_channels, + out_channels=img_channels, + time_embedding_dim=256 + ) + + def forward(self, x_t, t): + # Predict noise ε at timestep t + return self.model(x_t, t) + +# Training +def train_step(model, x_0): + # Sample random timestep + t = torch.randint(0, T, (batch_size,)) + + # Sample noise + ε = torch.randn_like(x_0) + + # Create noisy image x_t + α_t = alpha_schedule[t] + x_t = torch.sqrt(α_t) * x_0 + torch.sqrt(1 - α_t) * ε + + # Predict noise + ε_pred = model(x_t, t) + + # Loss: MSE between actual and predicted noise + loss = F.mse_loss(ε_pred, ε) + return loss + +# Sampling (generation) +@torch.no_grad() +def sample(model, shape): + # Start from pure noise + x_t = torch.randn(shape) + + # Iteratively denoise + for t in reversed(range(T)): + # Predict noise + ε_pred = model(x_t, t) + + # Denoise one step + α_t = alpha_schedule[t] + x_t = (x_t - (1 - α_t) / torch.sqrt(1 - ᾱ_t) * ε_pred) / torch.sqrt(α_t) + + # Add noise (except last step) + if t > 0: + x_t += torch.sqrt(β_t) * torch.randn_like(x_t) + + return x_t # x_0 (generated image) +``` + + +## Part 5: Autoregressive Models + +### Concept + +**Idea**: Model probability as product of conditionals +``` +p(x) = p(x_1) * p(x_2|x_1) * p(x_3|x_1,x_2) * ... * p(x_n|x_1,...,x_{n-1}) +``` + +**For images**: Generate pixel-by-pixel (or patch-by-patch) + +**Architectures:** +- **PixelCNN**: Convolutional with masked kernels +- **PixelCNN++**: Improved with mixture of logistics +- **VQ-VAE + PixelCNN**: Two-stage (learn discrete codes, model codes) +- **ImageGPT**: GPT-style Transformer for images + +### Advantages + +✅ **Explicit likelihood**: Can compute p(x) exactly +✅ **Stable training**: Standard cross-entropy loss +✅ **Theoretical guarantees**: Proper probability model + +### Disadvantages + +❌ **Very slow generation**: Sequential (can't parallelize) +❌ **Limited quality**: Worse than GAN/Diffusion for high-res +❌ **Resolution scaling**: Impractical for 1024×1024 (1M pixels!) + +**Speed comparison:** +``` +GAN: Generate 1024×1024 in 50ms (parallel) +PixelCNN: Generate 32×32 in 5 seconds (sequential!) +ImageGPT: Generate 256×256 in 30 seconds + +For 1024×1024: 1M pixels × 5ms/pixel = 83 minutes! 
+``` + +### When to Use + +✅ **Use autoregressive for:** +- Explicit likelihood needed (compression, evaluation) +- Small images (32×32, 64×64) +- Two-stage models (VQ-VAE + Transformer) + +❌ **Don't use for:** +- High-resolution images (too slow) +- Real-time generation +- Quality-critical applications (use diffusion) + +### Modern Usage + +**Two-stage approach (DALL-E, VQ-GAN):** +1. **Stage 1**: VQ-VAE learns discrete codes + - Image → 32×32 grid of codes (instead of 1M pixels) +2. **Stage 2**: Autoregressive model (Transformer) on codes + - Much faster (32×32 = 1024 codes, not 1M pixels) + + +## Part 6: Flow Models + +### Concept + +**Idea**: Invertible transformations +``` +z ~ N(0, I) ←→ x ~ p_data + +f: z → x (forward) +f⁻¹: x → z (inverse) +``` + +**Requirement**: f must be invertible and differentiable + +**Advantage**: Exact likelihood via change-of-variables +``` +log p(x) = log p(z) + log |det(∂f⁻¹/∂x)| +``` + +### Architectures + +**RealNVP (2017):** +- Coupling layers (affine transformations) +- Invertible by design + +**Glow (2018, OpenAI):** +- Actnorm, invertible 1×1 convolutions +- Multi-scale architecture + +**When to use Flow:** +✅ Exact likelihood needed (better than VAE) +✅ Invertibility needed (both z→x and x→z) +✅ Stable training (standard loss) + +❌ Architecture constraints (must be invertible) +❌ Quality not as good as GAN/Diffusion + +### Modern Status + +**Mostly superseded by Diffusion:** +- Diffusion has better quality +- Diffusion more flexible (no invertibility constraint) +- Flow models still used in specialized applications + + +## Part 7: Decision Framework + +### By Primary Goal + +``` +Goal: High-quality images +→ Diffusion (modern default) +→ OR GAN if pretrained available + +Goal: Fast inference +→ GAN (50ms per image) +→ Avoid Diffusion (too slow for real-time) + +Goal: Training stability +→ Diffusion or VAE (standard loss) +→ Avoid GAN (adversarial training hard) + +Goal: Latent space exploration +→ VAE (smooth interpolation) +→ Avoid GAN (no encoder) + +Goal: Explicit likelihood +→ Autoregressive or Flow +→ For evaluation, compression + +Goal: Diversity (no mode collapse) +→ Diffusion (by design) +→ OR VAE (stable) +→ Avoid GAN (mode collapse common) +``` + +### By Data Type + +``` +Images (high-quality): +→ Diffusion (Stable Diffusion) +→ OR GAN (StyleGAN2) + +Images (small, 32×32): +→ Any model works +→ Try VAE first (simplest) + +Audio waveforms: +→ WaveGAN +→ OR Diffusion (WaveGrad) + +Video: +→ Video Diffusion (limited) +→ OR GAN (StyleGAN-V) + +Text: +→ Autoregressive (GPT) +→ NOT VAE/GAN/Diffusion (discrete tokens) +``` + +### By Training Budget + +``` +Large budget (millions $, pretrain from scratch): +→ Diffusion (Stable Diffusion scale) +→ Billions of images, weeks on cluster + +Medium budget (thousands $, train from scratch): +→ GAN or Diffusion +→ 10k-1M images, days on GPU + +Small budget (hundreds $, fine-tune): +→ Fine-tune Stable Diffusion (LoRA) +→ 1k-10k images, hours on consumer GPU + +Tiny budget (research, small scale): +→ VAE (simplest, most stable) +→ Few thousand images, CPU possible +``` + +### Modern Recommendations (2025) + +**For new projects:** +1. **Default: Diffusion** + - Fine-tune Stable Diffusion or train from scratch + - Best quality + stability + +2. **If need speed: GAN** + - Use pretrained StyleGAN2 if available + - Or train GAN (if can tolerate instability) + +3. 
**If need latent space: VAE** + - For interpolation, not generation quality + +**AVOID:** +- Training GAN from scratch (unless necessary) +- Using VAE for high-quality generation +- Autoregressive for high-res images + + +## Part 8: Training from Scratch vs Fine-Tuning + +### Stable Diffusion Example + +**Pretraining (what Stability AI did):** +- Dataset: LAION-5B (5 billion images) +- Compute: 150,000 A100 GPU hours +- Cost: ~$600,000 +- Time: Weeks on massive cluster +- **DON'T DO THIS!** + +**Fine-tuning (what users do):** +- Dataset: 10k-100k domain images +- Compute: 100-1000 GPU hours +- Cost: $100-1,000 +- Time: Days on single A100 +- **DO THIS!** + +**LoRA (Low-Rank Adaptation):** +- Efficient fine-tuning (fewer parameters) +- Dataset: 1k-5k images +- Compute: 10-100 GPU hours +- Cost: $10-100 +- Time: Hours on consumer GPU (RTX 3090) +- **Best for small budgets!** + +### Decision + +``` +Have pretrained model in your domain: +→ Fine-tune (don't retrain!) + +No pretrained model: +→ Train from scratch (small model) +→ OR find closest pretrained and fine-tune + +Budget < $1000: +→ LoRA fine-tuning +→ OR train small model (64×64) + +Budget < $100: +→ LoRA with free Colab +→ OR VAE from scratch (cheap) +``` + + +## Part 9: Common Mistakes + +### Mistake 1: VAE for High-Quality Generation + +**Symptom:** Blurry outputs +**Fix:** Use GAN or Diffusion for quality +**VAE is for:** Latent space, not generation + +### Mistake 2: Ignoring Mode Collapse + +**Symptom:** GAN generates same images +**Fix:** Spectral norm, minibatch discrimination +**Better:** Switch to Diffusion (no mode collapse) + +### Mistake 3: Training Stable Diffusion from Scratch + +**Symptom:** Burning money, poor results +**Fix:** Fine-tune pretrained model +**Reality:** Pretraining costs $600k+ + +### Mistake 4: Slow Inference with Diffusion + +**Symptom:** 50 seconds per image +**Fix:** Use DDIM (fewer steps, 10-50) +**OR:** Use GAN if speed critical + +### Mistake 5: Wrong Loss for GAN + +**Symptom:** Training diverges +**Fix:** Use Wasserstein loss (WGAN) +**OR:** Spectral normalization +**Better:** Switch to Diffusion (standard loss) + + +## Summary: Quick Reference + +### Model Selection + +``` +High quality + stable training: +→ Diffusion (modern default) + +Fast inference required: +→ GAN (if pretrained) or trained GAN + +Latent space exploration: +→ VAE + +Explicit likelihood: +→ Autoregressive or Flow + +Small images (< 64×64): +→ Any model (start with VAE) + +Large images (> 256×256): +→ Diffusion or GAN (avoid autoregressive) +``` + +### Quality Ranking + +``` +1. Diffusion (9.5/10) +2. GAN (9/10) +3. Autoregressive (7-8/10) +4. Flow (7-8/10) +5. VAE (6/10 - blurry) +``` + +### Training Stability Ranking + +``` +1. VAE (10/10) +2. Diffusion (9/10) +3. Autoregressive (9/10) +4. Flow (8/10) +5. GAN (3/10 - very unstable) +``` + +### Modern Stack (2025) + +``` +Image generation: Stable Diffusion (fine-tuned) +Fast inference: StyleGAN2 (if available) +Latent space: VAE +Research: Diffusion (easiest to train) +``` + + +## Next Steps + +After mastering this skill: +- `llm-specialist/llm-finetuning-strategies`: Apply to text generation +- `architecture-design-principles`: Understand design trade-offs +- `training-optimization`: Optimize training for your chosen model + +**Remember:** Diffusion models dominate in 2025. Use them unless you have specific reason not to (speed, latent space, likelihood). 
diff --git a/skills/using-neural-architectures/graph-neural-networks-basics.md b/skills/using-neural-architectures/graph-neural-networks-basics.md new file mode 100644 index 0000000..b809a4f --- /dev/null +++ b/skills/using-neural-architectures/graph-neural-networks-basics.md @@ -0,0 +1,625 @@ + +# Graph Neural Networks Basics + +## When to Use This Skill + +Use this skill when you need to: +- ✅ Work with graph-structured data (molecules, social networks, citations) +- ✅ Understand why CNN/RNN don't work on graphs +- ✅ Learn message passing framework +- ✅ Choose between GCN, GraphSAGE, GAT +- ✅ Decide if GNN is appropriate (vs simple model) +- ✅ Implement permutation-invariant aggregations + +**Do NOT use this skill for:** +- ❌ Sequential data (use RNN/Transformer) +- ❌ Grid data (use CNN) +- ❌ High-level architecture selection (use `using-neural-architectures`) + + +## Core Principle + +**Graphs have irregular structure.** CNN (grid) and RNN (sequence) don't work. + +**GNN solution:** Message passing +- Nodes aggregate information from neighbors +- Multiple layers = multi-hop neighborhoods +- Permutation invariant (order doesn't matter) + +**Critical question:** Does graph structure actually help? (Test: Compare with/without edges) + + +## Part 1: Why GNN (Not CNN/RNN) + +### Problem: Graph Structure + +**Graph components:** +- **Nodes**: Entities (atoms, users, papers) +- **Edges**: Relationships (bonds, friendships, citations) +- **Features**: Node/edge attributes + +**Key property:** Irregular structure +- Each node has variable number of neighbors +- No fixed spatial arrangement +- Permutation invariant (node order doesn't matter) + +### Why CNN Doesn't Work + +**CNN assumption:** Regular grid structure + +**Example:** Image (2D grid) +``` +Every pixel has exactly 8 neighbors: +[■][■][■] +[■][X][■] ← Center pixel has 8 neighbors (fixed!) +[■][■][■] + +CNN kernel: 3×3 (fixed size, fixed positions) +``` + +**Graph reality:** Irregular neighborhoods +``` +Node A: 2 neighbors (H, C) +Node B: 4 neighbors (C, C, C, H) +Node C: 1 neighbor (H) + +No fixed kernel size or position! +``` + +**CNN limitations:** +- Requires fixed-size neighborhoods → Graphs have variable-size +- Assumes spatial locality → Graphs have arbitrary connectivity +- Depends on node ordering → Should be permutation invariant + +### Why RNN Doesn't Work + +**RNN assumption:** Sequential structure + +**Example:** Text (1D sequence) +``` +"The cat sat" → [The] → [cat] → [sat] +Clear sequential order, temporal dependencies +``` + +**Graph reality:** No inherent sequence +``` +Social network: +A — B — C +| | +D ——————E + +What's the "sequence"? A→B→C? A→D→E? No natural ordering! +``` + +**RNN limitations:** +- Requires sequential order → Graphs have no natural order +- Processes one element at a time → Graphs have parallel connections +- Order-dependent → Should be permutation invariant + +### GNN Solution + +**Key innovation:** Message passing on graph structure +- Operate directly on nodes and edges +- Variable-size neighborhoods (handled naturally) +- Permutation invariant aggregations + + +## Part 2: Message Passing Framework + +### Core Mechanism + +**Message passing in 3 steps:** + +**1. Aggregate neighbor messages** +```python +# Node i aggregates from neighbors N(i) +messages = [h_j for j in neighbors(i)] +aggregated = aggregate(messages) # e.g., mean, sum, max +``` + +**2. 
Update node representation** +```python +# Combine own features with aggregated messages +h_i_new = update(h_i_old, aggregated) # e.g., neural network +``` + +**3. Repeat for L layers** +- Layer 1: Node sees 1-hop neighbors +- Layer 2: Node sees 2-hop neighbors +- Layer L: Node sees L-hop neighborhood + +### Concrete Example: Social Network + +**Task:** Predict user interests + +**Graph:** +``` + B (sports) + | +A ---+--- C (cooking) + | + D (music) +``` + +**Layer 1: 1-hop neighbors** +```python +# Node A aggregates from direct friends +h_A_layer1 = update( + h_A, + aggregate([h_B, h_C, h_D]) +) +# Now h_A includes friend interests +``` + +**Layer 2: 2-hop neighbors (friends of friends)** +```python +# B's friends: E, F +# C's friends: G, H +# D's friends: I + +h_A_layer2 = update( + h_A_layer1, + aggregate([h_B', h_C', h_D']) # h_B' includes E, F +) +# Now h_A includes friends-of-friends! +``` + +**Key insight:** More layers = larger receptive field (L-hop neighborhood) + +### Permutation Invariance + +**Critical property:** Same graph → same output (regardless of node ordering) + +**Example:** +```python +Graph: A-B, B-C + +Node list 1: [A, B, C] +Node list 2: [C, B, A] + +Output MUST be identical! (Same graph, different ordering) +``` + +**Invariant aggregations:** +- ✅ Mean: `mean([1, 2, 3]) == mean([3, 2, 1])` +- ✅ Sum: `sum([1, 2, 3]) == sum([3, 2, 1])` +- ✅ Max: `max([1, 2, 3]) == max([3, 2, 1])` + +**NOT invariant:** +- ❌ LSTM: `LSTM([1, 2, 3]) != LSTM([3, 2, 1])` +- ❌ Concatenate: `[1, 2, 3] != [3, 2, 1]` + +**Implementation:** +```python +# CORRECT: Permutation invariant +def aggregate(neighbor_features): + return torch.mean(neighbor_features, dim=0) + +# WRONG: Order-dependent! +def aggregate(neighbor_features): + return LSTM(neighbor_features) # Output depends on order +``` + + +## Part 3: GNN Architectures + +### Architecture 1: GCN (Graph Convolutional Network) + +**Key idea:** Spectral convolution on graphs (simplified) + +**Formula:** +```python +h_i^(l+1) = σ(∑_{j∈N(i)} W^(l) h_j^(l) / √(|N(i)| |N(j)|)) + +# Normalize by degree (√(deg(i) * deg(j))) +``` + +**Aggregation:** Weighted mean (degree-normalized) + +**Properties:** +- Transductive (needs full graph at training) +- Computationally efficient +- Good baseline + +**When to use:** +- Full graph available at training time +- Starting point (simplest GNN) +- Small to medium graphs + +**Implementation:** +```python +from torch_geometric.nn import GCNConv + +class GCN(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = GCNConv(in_channels, hidden_channels) + self.conv2 = GCNConv(hidden_channels, out_channels) + + def forward(self, x, edge_index): + # x: Node features (N, in_channels) + # edge_index: Graph connectivity (2, E) + + # Layer 1 + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + + # Layer 2 + x = self.conv2(x, edge_index) + + return x +``` + +### Architecture 2: GraphSAGE + +**Key idea:** Sample and aggregate (inductive learning) + +**Formula:** +```python +# Sample fixed-size neighborhood +neighbors_sampled = sample(neighbors(i), k=10) + +# Aggregate +h_N = aggregate({h_j for j in neighbors_sampled}) + +# Concatenate and transform +h_i^(l+1) = σ(W^(l) [h_i^(l); h_N]) +``` + +**Aggregation:** Mean, max, or LSTM (but mean/max preferred for invariance) + +**Key innovation:** Sampling +- Sample fixed number of neighbors (e.g., 10) +- Makes computation tractable for large graphs +- Enables inductive 
learning (generalizes to unseen nodes) + +**When to use:** +- Large graphs (millions of nodes) +- Need inductive capability (new nodes appear) +- Training on subset, testing on full graph + +**Implementation:** +```python +from torch_geometric.nn import SAGEConv + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = SAGEConv(in_channels, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, out_channels) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + x = self.conv2(x, edge_index) + return x +``` + +### Architecture 3: GAT (Graph Attention Network) + +**Key idea:** Learn attention weights for neighbors + +**Formula:** +```python +# Attention scores +α_ij = attention(h_i, h_j) # How important is neighbor j to node i? + +# Normalize (softmax) +α_ij = softmax_j(α_ij) + +# Weighted aggregation +h_i^(l+1) = σ(∑_{j∈N(i)} α_ij W h_j^(l)) +``` + +**Key innovation:** Learned neighbor importance +- Not all neighbors equally important +- Attention mechanism decides weights +- Multi-head attention (like Transformer) + +**When to use:** +- Neighbors have varying importance +- Need interpretability (attention weights) +- Have sufficient data (attention needs more data to learn) + +**Implementation:** +```python +from torch_geometric.nn import GATConv + +class GAT(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, heads=8): + super().__init__() + self.conv1 = GATConv(in_channels, hidden_channels, heads=heads) + self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.elu(x) + x = F.dropout(x, p=0.6, training=self.training) + x = self.conv2(x, edge_index) + return x +``` + +### Architecture Comparison + +| Feature | GCN | GraphSAGE | GAT | +|---------|-----|-----------|-----| +| Aggregation | Degree-weighted mean | Mean/max/LSTM | Attention-weighted | +| Neighbor weighting | Fixed (by degree) | Equal | Learned | +| Inductive | No | Yes | Yes | +| Scalability | Medium | High (sampling) | Medium | +| Interpretability | Low | Low | High (attention) | +| Complexity | Low | Medium | High | + +### Decision Tree + +``` +Starting out / Small graph: +→ GCN (simplest baseline) + +Large graph (millions of nodes): +→ GraphSAGE (sampling enables scalability) + +Need inductive learning (new nodes): +→ GraphSAGE or GAT + +Neighbors have different importance: +→ GAT (attention learns importance) + +Need interpretability: +→ GAT (attention weights explain predictions) + +Production deployment: +→ GraphSAGE (most robust and scalable) +``` + + +## Part 4: When NOT to Use GNN + +### Critical Question + +**Does graph structure actually help?** + +**Test:** Compare model with and without edges +```python +# Baseline: MLP on node features only +mlp_accuracy = train_mlp(node_features, labels) + +# GNN: Use node features + graph structure +gnn_accuracy = train_gnn(node_features, edges, labels) + +# Decision: +if gnn_accuracy - mlp_accuracy < 2%: + print("Graph structure doesn't help much") + print("Use simpler model (MLP or XGBoost)") +else: + print("Graph structure adds value") + print("Use GNN") +``` + +### Scenarios Where GNN Doesn't Help + +**1. 
Node features dominate** +``` +User churn prediction: +- Node features: Usage hours, demographics, subscription → Highly predictive +- Graph edges: Sparse user interactions → Weak signal +- Result: MLP 85%, GNN 86% (not worth complexity!) +``` + +**2. Sparse graphs** +``` +Graph with 1000 nodes, 100 edges (0.01% density): +- Most nodes have 0-1 neighbors +- No information to aggregate +- GNN reduces to MLP +``` + +**3. Random graph structure** +``` +If edges are random (no homophily): +- Neighbor labels uncorrelated +- Aggregation adds noise +- Simple model better +``` + +### When GNN DOES Help + +✅ **Molecular property prediction** +- Structure is PRIMARY signal +- Atom types + bonds determine properties +- GNN: Huge improvement over fingerprints + +✅ **Citation networks** +- Paper quality correlated with neighbors +- "You are what you cite" +- Clear homophily + +✅ **Social recommendation** +- Friends have similar preferences +- Graph structure informative +- GNN: Moderate to large improvement + +✅ **Knowledge graphs** +- Entities connected by relations +- Multi-hop reasoning valuable +- GNN captures complex patterns + +### Decision Framework + +``` +1. Start simple: + - Try MLP or XGBoost on node features + - Establish baseline performance + +2. Check graph structure value: + - Does edge information correlate with target? + - Is there homophily (similar nodes connected)? + - Test: Remove edges, compare performance + +3. Use GNN if: + - Graph structure adds >2-5% accuracy + - Structure is interpretable (not random) + - Have enough nodes for GNN to learn + +4. Stick with simple if: + - Node features alone sufficient + - Graph structure weak/random + - Small dataset (< 1000 nodes) +``` + + +## Part 5: Practical Implementation + +### Using PyTorch Geometric + +**Installation:** +```bash +pip install torch-geometric +``` + +**Basic workflow:** +```python +import torch +from torch_geometric.data import Data +from torch_geometric.nn import GCNConv + +# 1. Create graph data +x = torch.tensor([[feature1], [feature2], ...]) # Node features +edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) # Edges (COO format) +y = torch.tensor([label1, label2, ...]) # Node labels + +data = Data(x=x, edge_index=edge_index, y=y) + +# 2. Define model +class GNN(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(in_features, 64) + self.conv2 = GCNConv(64, num_classes) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + x = F.relu(self.conv1(x, edge_index)) + x = F.dropout(x, training=self.training) + x = self.conv2(x, edge_index) + return F.log_softmax(x, dim=1) + +# 3. Train +model = GNN() +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + +for epoch in range(200): + model.train() + optimizer.zero_grad() + out = model(data) + loss = F.nll_loss(out[train_mask], data.y[train_mask]) + loss.backward() + optimizer.step() +``` + +### Edge Index Format + +**COO (Coordinate) format:** +```python +# Edge list: (0→1), (1→2), (2→0) +edge_index = torch.tensor([ + [0, 1, 2], # Source nodes + [1, 2, 0] # Target nodes +]) + +# For undirected graph, include both directions: +edge_index = torch.tensor([ + [0, 1, 1, 2, 2, 0], # Source + [1, 0, 2, 1, 0, 2] # Target +]) +``` + +### Mini-batching Graphs + +**Problem:** Graphs have different sizes + +**Solution:** Batch graphs as one large disconnected graph +```python +from torch_geometric.data import DataLoader + +# Create dataset +dataset = [Data(...), Data(...), ...] 
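# For example (hypothetical toy graphs, shown only to make the elided Data(...)
# entries concrete -- each graph needs node features `x` of shape (N, F) and an
# `edge_index` of shape (2, E), mirroring step 1 of the basic workflow above):
# dataset = [
#     Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]])),
#     Data(x=torch.randn(4, 16), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]])),
# ]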
+ +# DataLoader handles batching +loader = DataLoader(dataset, batch_size=32, shuffle=True) + +for batch in loader: + # batch contains multiple graphs as one large graph + # batch.batch: Indicator which nodes belong to which graph + out = model(batch.x, batch.edge_index) +``` + + +## Part 6: Common Mistakes + +### Mistake 1: LSTM Aggregation + +**Symptom:** Different outputs for same graph with reordered nodes +**Fix:** Use mean/sum/max aggregation (permutation invariant) + +### Mistake 2: Forgetting Edge Direction + +**Symptom:** Information flows wrong way +**Fix:** For undirected graphs, add edges in both directions + +### Mistake 3: Too Many Layers + +**Symptom:** Performance degrades, over-smoothing +**Fix:** Use 2-3 layers (most graphs have small diameter) +**Explanation:** Too many layers → all nodes converge to same representation + +### Mistake 4: Not Testing Simple Baseline + +**Symptom:** Complex GNN with minimal improvement +**Fix:** Always test MLP on node features first + +### Mistake 5: Using GNN on Euclidean Data + +**Symptom:** CNN/RNN would work better +**Fix:** Use GNN only for irregular graph structure (not grids/sequences) + + +## Part 7: Summary + +### Quick Reference + +**When to use GNN:** +- Graph-structured data (molecules, social networks, citations) +- Irregular neighborhoods (not grid/sequence) +- Graph structure informative (test this!) + +**Architecture selection:** +``` +Start: GCN (simplest) +Large graph: GraphSAGE (scalable) +Inductive learning: GraphSAGE or GAT +Neighbor importance: GAT (attention) +``` + +**Key principles:** +- Message passing: Aggregate neighbors + Update node +- Permutation invariance: Use mean/sum/max (not LSTM) +- Test baseline: MLP first, GNN if structure helps +- Layers: 2-3 sufficient (more = over-smoothing) + +**Implementation:** +- PyTorch Geometric: Standard library +- COO format: Edge index as 2×E tensor +- Batching: Merge graphs into one large graph + + +## Next Steps + +After mastering this skill: +- `transformer-architecture-deepdive`: Understand attention (used in GAT) +- `architecture-design-principles`: Design principles for graph architectures +- Advanced GNNs: Graph Transformers, Equivariant GNNs + +**Remember:** Not all graph data needs GNN. Test if graph structure actually helps! (Compare with MLP baseline) diff --git a/skills/using-neural-architectures/normalization-techniques.md b/skills/using-neural-architectures/normalization-techniques.md new file mode 100644 index 0000000..b7fe576 --- /dev/null +++ b/skills/using-neural-architectures/normalization-techniques.md @@ -0,0 +1,915 @@ + +# Normalization Techniques + +## Context + +You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering: +- **Batch size dependency**: BatchNorm fails with small batches (< 8) +- **Architecture mismatch**: BatchNorm breaks RNNs/Transformers (use LayerNorm) +- **Task-specific needs**: Style transfer needs InstanceNorm, not BatchNorm +- **Modern alternatives**: RMSNorm simpler and faster than LayerNorm for LLMs + +**This skill prevents normalization cargo-culting and provides architecture-specific selection.** + +## Why Normalization Matters + +**Problem: Internal Covariate Shift** +During training, layer input distributions shift as previous layers update. 
This causes: +- Vanishing/exploding gradients (deep networks) +- Slow convergence (small learning rates required) +- Training instability (loss spikes) + +**Solution: Normalization** +Normalize activations to have stable statistics (mean=0, std=1). Benefits: +- **10x faster convergence**: Can use larger learning rates +- **Better generalization**: Regularization effect +- **Enables deep networks**: 50+ layers without gradient issues +- **Less sensitive to initialization**: Weights can start further from optimal + +**Key insight**: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize. + + +## Normalization Families + +### 1. Batch Normalization (BatchNorm) + +**What it does:** +Normalizes across the batch dimension for each channel/feature. + +**Formula:** +``` +Given input x with shape (B, C, H, W): # Batch, Channel, Height, Width + +For each channel c: + μ_c = mean(x[:, c, :, :]) # Mean over batch + spatial dims + σ_c = std(x[:, c, :, :]) # Std over batch + spatial dims + + x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε) + + # Learnable scale and shift + y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c +``` + +**When to use:** +- ✅ CNNs for classification (ResNet, EfficientNet) +- ✅ Large batch sizes (≥ 16) +- ✅ IID data (image classification, object detection) + +**When NOT to use:** +- ❌ Small batch sizes (< 8): Noisy statistics cause training failure +- ❌ RNNs/LSTMs: Breaks temporal dependencies +- ❌ Transformers: Batch dependency problematic for variable-length sequences +- ❌ Style transfer: Batch statistics erase style information + +**Batch size dependency:** +```python +batch_size = 32: # ✓ Works well (stable statistics) +batch_size = 16: # ✓ Acceptable +batch_size = 8: # ✓ Marginal (consider GroupNorm) +batch_size = 4: # ✗ Unstable (use GroupNorm) +batch_size = 2: # ✗ FAILS! (noisy statistics) +batch_size = 1: # ✗ Undefined (no batch to normalize over!) +``` + +**PyTorch example:** +```python +import torch.nn as nn + +# For Conv2d +bn = nn.BatchNorm2d(num_features=64) # 64 channels +x = torch.randn(32, 64, 28, 28) # Batch=32, Channels=64 +y = bn(x) + +# For Linear +bn = nn.BatchNorm1d(num_features=128) # 128 features +x = torch.randn(32, 128) # Batch=32, Features=128 +y = bn(x) +``` + +**Inference mode:** +```python +# Training: Uses batch statistics +model.train() +y = bn(x) # Normalizes using current batch mean/std + +# Inference: Uses running statistics (accumulated during training) +model.eval() +y = bn(x) # Normalizes using running_mean/running_std +``` + + +### 2. Layer Normalization (LayerNorm) + +**What it does:** +Normalizes across the feature dimension for each sample independently. + +**Formula:** +``` +Given input x with shape (B, C): # Batch, Features + +For each sample b: + μ_b = mean(x[b, :]) # Mean over features + σ_b = std(x[b, :]) # Std over features + + x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε) + + # Learnable scale and shift + y[b, :] = γ * x_norm[b, :] + β +``` + +**When to use:** +- ✅ Transformers (BERT, GPT, T5) +- ✅ RNNs/LSTMs (maintains temporal independence) +- ✅ Small batch sizes (batch-independent!) 
+- ✅ Variable-length sequences +- ✅ Reinforcement learning (batch_size=1 common) + +**Advantages over BatchNorm:** +- ✅ **Batch-independent**: Works with batch_size=1 +- ✅ **No running statistics**: Inference = training (no mode switching) +- ✅ **Sequence-friendly**: Doesn't mix information across timesteps + +**PyTorch example:** +```python +import torch.nn as nn + +# For Transformer +ln = nn.LayerNorm(normalized_shape=512) # d_model=512 +x = torch.randn(32, 128, 512) # Batch=32, SeqLen=128, d_model=512 +y = ln(x) # Normalizes last dimension independently per (batch, position) + +# For RNN hidden states +ln = nn.LayerNorm(normalized_shape=256) # hidden_size=256 +h = torch.randn(32, 256) # Batch=32, Hidden=256 +h_norm = ln(h) +``` + +**Key difference from BatchNorm:** +```python +# BatchNorm: Normalizes across batch dimension +# Given (B=32, C=64, H=28, W=28) +# Computes 64 means/stds (one per channel, across batch + spatial) + +# LayerNorm: Normalizes across feature dimension +# Given (B=32, L=128, D=512) +# Computes 32×128 means/stds (one per (batch, position), across features) +``` + + +### 3. Group Normalization (GroupNorm) + +**What it does:** +Normalizes channels in groups, batch-independent. + +**Formula:** +``` +Given input x with shape (B, C, H, W): +Divide C channels into G groups (C must be divisible by G) + +For each sample b and group g: + channels = x[b, g*(C/G):(g+1)*(C/G), :, :] # Channels in group g + μ_{b,g} = mean(channels) # Mean over channels in group + spatial + σ_{b,g} = std(channels) # Std over channels in group + spatial + + x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε) +``` + +**When to use:** +- ✅ Small batch sizes (< 8) +- ✅ CNNs with batch_size=1 (style transfer, RL) +- ✅ Object detection/segmentation (often use small batches) +- ✅ When BatchNorm unstable but want spatial normalization + +**Group size selection:** +```python +# num_groups trade-off: +num_groups = 1: # = LayerNorm (all channels together) +num_groups = C: # = InstanceNorm (each channel separate) +num_groups = 32: # Standard choice (good balance) + +# Rule: C must be divisible by num_groups +channels = 64, num_groups = 32: # ✓ 64/32 = 2 channels per group +channels = 64, num_groups = 16: # ✓ 64/16 = 4 channels per group +channels = 64, num_groups = 30: # ✗ 64/30 not integer! +``` + +**PyTorch example:** +```python +import torch.nn as nn + +# For small batch CNN +gn = nn.GroupNorm(num_groups=32, num_channels=64) +x = torch.randn(2, 64, 28, 28) # Batch=2 (small!) +y = gn(x) # Works well even with batch=2 + +# Compare performance: +batch_sizes = [1, 2, 4, 8, 16, 32] + +bn = nn.BatchNorm2d(64) +gn = nn.GroupNorm(32, 64) + +for bs in batch_sizes: + x = torch.randn(bs, 64, 28, 28) + + # BatchNorm gets more stable with larger batch + # GroupNorm consistent across all batch sizes +``` + +**Empirical results (He et al. 2018):** +``` +ImageNet classification with ResNet-50: +batch_size = 32: BatchNorm = 76.5%, GroupNorm = 76.3% (tie) +batch_size = 8: BatchNorm = 75.8%, GroupNorm = 76.1% (GroupNorm wins!) +batch_size = 2: BatchNorm = 72.1%, GroupNorm = 75.3% (GroupNorm wins!) +``` + + +### 4. Instance Normalization (InstanceNorm) + +**What it does:** +Normalizes each sample and channel independently (no batch mixing). 
+ +**Formula:** +``` +Given input x with shape (B, C, H, W): + +For each sample b and channel c: + μ_{b,c} = mean(x[b, c, :, :]) # Mean over spatial dimensions only + σ_{b,c} = std(x[b, c, :, :]) # Std over spatial dimensions only + + x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε) +``` + +**When to use:** +- ✅ Style transfer (neural style, CycleGAN, pix2pix) +- ✅ Image-to-image translation +- ✅ When batch/channel mixing destroys information + +**Why for style transfer:** +```python +# Style transfer goal: Transfer style while preserving content +# BatchNorm: Mixes statistics across batch (erases individual style!) +# InstanceNorm: Per-image statistics (preserves each image's style) + +# Example: Neural style transfer +content_image = load_image("photo.jpg") +style_image = load_image("starry_night.jpg") + +# With BatchNorm: Output loses content image's unique characteristics +# With InstanceNorm: Content characteristics preserved, style applied +``` + +**PyTorch example:** +```python +import torch.nn as nn + +# For style transfer generator +in_norm = nn.InstanceNorm2d(num_features=64) +x = torch.randn(1, 64, 256, 256) # Single image +y = in_norm(x) # Normalizes each channel independently + +# CycleGAN generator architecture +class Generator(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, 7, padding=3) + self.in1 = nn.InstanceNorm2d(64) # NOT BatchNorm! + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.in1(x) # Per-image normalization + x = self.relu(x) + return x +``` + +**Relation to GroupNorm:** +```python +# InstanceNorm is GroupNorm with num_groups = num_channels +InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C) +``` + + +### 5. RMS Normalization (RMSNorm) + +**What it does:** +Simplified LayerNorm that only rescales (no recentering), faster and simpler. + +**Formula:** +``` +Given input x: + +# LayerNorm (2 steps): +x_centered = x - mean(x) # 1. Center +x_norm = x_centered / std(x) # 2. Scale + +# RMSNorm (1 step): +rms = sqrt(mean(x²)) # Root Mean Square +x_norm = x / rms # Only scale, no centering +``` + +**When to use:** +- ✅ Modern LLMs (LLaMA, Mistral, Gemma) +- ✅ When speed matters (15-20% faster than LayerNorm) +- ✅ Large Transformer models (billions of parameters) + +**Advantages:** +- ✅ **Simpler**: One operation instead of two +- ✅ **Faster**: ~15-20% speedup over LayerNorm +- ✅ **Numerically stable**: No subtraction (avoids catastrophic cancellation) +- ✅ **Same performance**: Empirically matches LayerNorm quality + +**PyTorch implementation:** +```python +import torch +import torch.nn as nn + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + # Compute RMS + rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) + + # Normalize + x_norm = x / rms + + # Scale (learnable) + return self.weight * x_norm + +# Usage in Transformer +rms = RMSNorm(dim=512) # d_model=512 +x = torch.randn(32, 128, 512) # Batch, SeqLen, d_model +y = rms(x) +``` + +**Speed comparison (LLaMA-7B, A100 GPU):** +``` +LayerNorm: 1000 tokens/sec +RMSNorm: 1180 tokens/sec # 18% faster! 
+ +# For large models, this adds up: +# 1 million tokens: 180 seconds saved +``` + +**Modern LLM adoption:** +```python +# LLaMA (Meta, 2023): RMSNorm +# Mistral (Mistral AI, 2023): RMSNorm +# Gemma (Google, 2024): RMSNorm +# PaLM (Google, 2022): RMSNorm + +# Older models: +# GPT-2/3 (OpenAI): LayerNorm +# BERT (Google, 2018): LayerNorm +``` + + +## Architecture-Specific Selection + +### CNN (Convolutional Neural Networks) + +**Default: BatchNorm** +```python +import torch.nn as nn + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1) + self.bn = nn.BatchNorm2d(out_channels) # After conv + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) # Normalize + x = self.relu(x) + return x +``` + +**Exception: Small batch sizes** +```python +# If batch_size < 8, use GroupNorm instead +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1) + self.norm = nn.GroupNorm(32, out_channels) # GroupNorm for small batches + self.relu = nn.ReLU(inplace=True) +``` + +**Exception: Style transfer** +```python +# Use InstanceNorm for style transfer +class StyleConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1) + self.norm = nn.InstanceNorm2d(out_channels) # Per-image normalization + self.relu = nn.ReLU(inplace=True) +``` + + +### RNN / LSTM (Recurrent Neural Networks) + +**Default: LayerNorm** +```python +import torch.nn as nn + +class NormalizedLSTM(nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.ln = nn.LayerNorm(hidden_size) # Normalize hidden states + + def forward(self, x): + # x: (batch, seq_len, input_size) + output, (h_n, c_n) = self.lstm(x) + # output: (batch, seq_len, hidden_size) + + # Normalize each timestep's output + output_norm = self.ln(output) # Applies independently per timestep + return output_norm, (h_n, c_n) +``` + +**Why NOT BatchNorm:** +```python +# BatchNorm in RNN mixes information across timesteps! +# Given (batch=32, seq_len=100, hidden=256) + +# BatchNorm would compute: +# mean/std over (batch × seq_len) = 3200 values +# This mixes t=0 with t=99 (destroys temporal structure!) 
+ +# LayerNorm computes: +# mean/std over hidden_size = 256 values per (batch, timestep) +# Each timestep normalized independently (preserves temporal structure) +``` + +**Layer-wise normalization in stacked RNN:** +```python +class StackedNormalizedLSTM(nn.Module): + def __init__(self, input_size, hidden_size, num_layers): + super().__init__() + self.layers = nn.ModuleList() + + for i in range(num_layers): + in_size = input_size if i == 0 else hidden_size + self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True)) + self.layers.append(nn.LayerNorm(hidden_size)) # After each LSTM layer + + def forward(self, x): + for lstm, ln in zip(self.layers[::2], self.layers[1::2]): + x, _ = lstm(x) + x = ln(x) # Normalize between layers + return x +``` + + +### Transformer + +**Default: LayerNorm (or RMSNorm for modern/large models)** + +**Two placement options: Pre-norm vs Post-norm** + +**Post-norm (original Transformer, "Attention is All You Need"):** +```python +class TransformerLayerPostNorm(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, num_heads) + self.ffn = nn.Sequential( + nn.Linear(d_model, 4 * d_model), + nn.ReLU(), + nn.Linear(4 * d_model, d_model) + ) + self.ln1 = nn.LayerNorm(d_model) + self.ln2 = nn.LayerNorm(d_model) + + def forward(self, x): + # Post-norm: Apply normalization AFTER residual + x = self.ln1(x + self.attn(x, x, x)[0]) # Normalize after adding + x = self.ln2(x + self.ffn(x)) # Normalize after adding + return x +``` + +**Pre-norm (modern, more stable):** +```python +class TransformerLayerPreNorm(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, num_heads) + self.ffn = nn.Sequential( + nn.Linear(d_model, 4 * d_model), + nn.ReLU(), + nn.Linear(4 * d_model, d_model) + ) + self.ln1 = nn.LayerNorm(d_model) + self.ln2 = nn.LayerNorm(d_model) + + def forward(self, x): + # Pre-norm: Apply normalization BEFORE sublayer + x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0] # Normalize before attention + x = x + self.ffn(self.ln2(x)) # Normalize before FFN + return x +``` + +**Pre-norm vs Post-norm comparison:** +```python +# Post-norm (original): +# - Less stable (requires careful initialization + warmup) +# - Slightly better performance IF training succeeds +# - Hard to train deep models (> 12 layers) + +# Pre-norm (modern): +# - More stable (easier to train deep models) +# - Standard for large models (GPT-3: 96 layers!) 
+# - Recommended default + +# Empirical: GPT-2, BERT (post-norm, ≤12 layers) +# GPT-3, T5, LLaMA (pre-norm, ≥24 layers) +``` + +**Using RMSNorm instead:** +```python +class TransformerLayerRMSNorm(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, num_heads) + self.ffn = nn.Sequential( + nn.Linear(d_model, 4 * d_model), + nn.ReLU(), + nn.Linear(4 * d_model, d_model) + ) + self.rms1 = RMSNorm(d_model) # 15-20% faster than LayerNorm + self.rms2 = RMSNorm(d_model) + + def forward(self, x): + # Pre-norm with RMSNorm (LLaMA style) + x = x + self.attn(self.rms1(x), self.rms1(x), self.rms1(x))[0] + x = x + self.ffn(self.rms2(x)) + return x +``` + + +### GAN (Generative Adversarial Network) + +**Generator: InstanceNorm or no normalization** +```python +class Generator(nn.Module): + def __init__(self): + super().__init__() + # Use InstanceNorm for image-to-image translation + self.conv1 = nn.Conv2d(3, 64, 7, padding=3) + self.in1 = nn.InstanceNorm2d(64) # Per-image normalization + + def forward(self, x): + x = self.conv1(x) + x = self.in1(x) # Preserves per-image characteristics + return x +``` + +**Discriminator: No normalization or LayerNorm** +```python +class Discriminator(nn.Module): + def __init__(self): + super().__init__() + # Often no normalization (BatchNorm can hurt GAN training) + self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1) + # No normalization here + + def forward(self, x): + x = self.conv1(x) + # Directly to activation (no norm) + return x +``` + +**Why avoid BatchNorm in GANs:** +```python +# BatchNorm in discriminator: +# - Mixes real and fake samples in batch +# - Leaks information (discriminator can detect batch composition) +# - Hurts training stability + +# Recommendation: +# Generator: InstanceNorm (for image translation) or no norm +# Discriminator: No normalization or LayerNorm +``` + + +## Decision Framework + +### Step 1: Check batch size + +```python +if batch_size >= 8: + consider_batchnorm = True +else: + use_groupnorm_or_layernorm = True # BatchNorm will be unstable +``` + +### Step 2: Check architecture + +```python +if architecture == "CNN": + if batch_size >= 8: + use_batchnorm() + else: + use_groupnorm(num_groups=32) + + # Exception: Style transfer + if task == "style_transfer": + use_instancenorm() + +elif architecture in ["RNN", "LSTM", "GRU"]: + use_layernorm() # NEVER BatchNorm! + +elif architecture == "Transformer": + if model_size == "large": # > 1B parameters + use_rmsnorm() # 15-20% faster + else: + use_layernorm() + + # Placement: Pre-norm (more stable) + use_prenorm_placement() + +elif architecture == "GAN": + if component == "generator": + if task == "image_translation": + use_instancenorm() + else: + use_no_norm() # Or InstanceNorm + elif component == "discriminator": + use_no_norm() # Or LayerNorm +``` + +### Step 3: Verify placement + +```python +# CNNs: After convolution, before activation +x = conv(x) +x = norm(x) # Here! +x = relu(x) + +# RNNs: After LSTM, normalize hidden states +output, (h, c) = lstm(x) +output = norm(output) # Here! + +# Transformers: Pre-norm (modern) or post-norm (original) +# Pre-norm (recommended): +x = x + sublayer(norm(x)) # Normalize before sublayer + +# Post-norm (original): +x = norm(x + sublayer(x)) # Normalize after residual +``` + + +## Implementation Checklist + +### Before adding normalization: + +1. ☐ **Check batch size**: If < 8, avoid BatchNorm +2. 
☐ **Check architecture**: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm +3. ☐ **Check task**: Style transfer→InstanceNorm +4. ☐ **Verify placement**: After conv/linear, before activation (CNNs) +5. ☐ **Test training stability**: Loss should decrease smoothly + +### During training: + +6. ☐ **Monitor running statistics** (BatchNorm): Check running_mean/running_var are updating +7. ☐ **Test inference mode**: Verify model.eval() uses running stats correctly +8. ☐ **Check gradient flow**: Normalization should help, not hurt gradients + +### If training is unstable: + +9. ☐ **Try different normalization**: BatchNorm→GroupNorm, LayerNorm→RMSNorm +10. ☐ **Try pre-norm** (Transformers): More stable than post-norm +11. ☐ **Reduce learning rate**: Normalization allows larger LR, but start conservatively + + +## Common Mistakes + +### Mistake 1: BatchNorm with small batches + +```python +# WRONG: BatchNorm with batch_size=2 +model = ResNet50(norm_layer=nn.BatchNorm2d) +dataloader = DataLoader(dataset, batch_size=2) # Too small! + +# RIGHT: GroupNorm for small batches +model = ResNet50(norm_layer=lambda channels: nn.GroupNorm(32, channels)) +dataloader = DataLoader(dataset, batch_size=2) # Works! +``` + +### Mistake 2: BatchNorm in RNN + +```python +# WRONG: BatchNorm in LSTM +class BadLSTM(nn.Module): + def __init__(self): + super().__init__() + self.lstm = nn.LSTM(100, 256) + self.bn = nn.BatchNorm1d(256) # WRONG! Mixes timesteps + + def forward(self, x): + output, _ = self.lstm(x) + output = output.permute(0, 2, 1) # (B, H, T) + output = self.bn(output) # Mixes timesteps! + return output + +# RIGHT: LayerNorm in LSTM +class GoodLSTM(nn.Module): + def __init__(self): + super().__init__() + self.lstm = nn.LSTM(100, 256) + self.ln = nn.LayerNorm(256) # Per-timestep normalization + + def forward(self, x): + output, _ = self.lstm(x) + output = self.ln(output) # Independent per timestep + return output +``` + +### Mistake 3: Forgetting model.eval() + +```python +# WRONG: Using training mode during inference +model.train() # BatchNorm uses batch statistics +predictions = model(test_data) # Batch statistics from test data (leakage!) + +# RIGHT: Use eval mode during inference +model.eval() # BatchNorm uses running statistics +with torch.no_grad(): + predictions = model(test_data) # Uses accumulated running stats +``` + +### Mistake 4: Post-norm for deep Transformers + +```python +# WRONG: Post-norm for 24-layer Transformer (unstable!) +class DeepTransformerPostNorm(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([ + TransformerLayerPostNorm(512, 8) for _ in range(24) + ]) # Hard to train! + +# RIGHT: Pre-norm for deep Transformers +class DeepTransformerPreNorm(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([ + TransformerLayerPreNorm(512, 8) for _ in range(24) + ]) # Stable training! +``` + +### Mistake 5: Wrong normalization for style transfer + +```python +# WRONG: BatchNorm for style transfer +class StyleGenerator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 7, padding=3) + self.norm = nn.BatchNorm2d(64) # WRONG! 
Mixes styles across batch + +# RIGHT: InstanceNorm for style transfer +class StyleGenerator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 7, padding=3) + self.norm = nn.InstanceNorm2d(64) # Per-image normalization +``` + + +## Performance Impact + +### Training speed: + +```python +# Without normalization: 100 epochs to converge +# With normalization: 10 epochs to converge (10x faster!) + +# Reason: Larger learning rates possible +lr_no_norm = 0.001 # Must be small (unstable otherwise) +lr_with_norm = 0.01 # Can be 10x larger (normalization stabilizes) +``` + +### Inference speed: + +```python +# Normalization overhead (relative to no normalization): +BatchNorm: +2% (minimal, cached running stats) +LayerNorm: +3-5% (compute mean/std per forward pass) +RMSNorm: +2-3% (faster than LayerNorm) +GroupNorm: +5-8% (more computation than BatchNorm) +InstanceNorm: +3-5% (similar to LayerNorm) + +# For most models: Overhead is negligible compared to conv/linear layers +``` + +### Memory usage: + +```python +# Normalization memory (per layer): +BatchNorm: 2 × num_channels (running_mean, running_std) + 2 × num_channels (γ, β) +LayerNorm: 2 × normalized_shape (γ, β) +RMSNorm: 1 × normalized_shape (γ only, no β) + +# Example: 512 channels +BatchNorm: 4 × 512 = 2048 parameters +LayerNorm: 2 × 512 = 1024 parameters +RMSNorm: 1 × 512 = 512 parameters # Most efficient! +``` + + +## When NOT to Normalize + +**Case 1: Final output layer** +```python +# Don't normalize final predictions +class Classifier(nn.Module): + def __init__(self): + super().__init__() + self.backbone = ResNet50() # Normalization inside + self.fc = nn.Linear(2048, 1000) + # NO normalization here! (final logits should be unnormalized) + + def forward(self, x): + x = self.backbone(x) + x = self.fc(x) # Raw logits + return x # Don't normalize! +``` + +**Case 2: Very small networks** +```python +# Single-layer network: Normalization overkill +class TinyNet(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(784, 10) # MNIST classifier + # No normalization needed (network too simple) + + def forward(self, x): + return self.fc(x) +``` + +**Case 3: When debugging** +```python +# Remove normalization to isolate issues +# If training fails with normalization, try without to check if: +# - Initialization is correct +# - Loss function is correct +# - Data is correctly preprocessed +``` + + +## Modern Recommendations (2025) + +### CNNs: +- **Default**: BatchNorm (if batch_size ≥ 8) +- **Small batches**: GroupNorm (num_groups=32) +- **Style transfer**: InstanceNorm + +### RNNs/LSTMs: +- **Default**: LayerNorm +- **Never**: BatchNorm (breaks temporal structure) + +### Transformers: +- **Small models** (< 1B): LayerNorm + pre-norm +- **Large models** (≥ 1B): RMSNorm + pre-norm (15-20% faster) +- **Avoid**: Post-norm for deep models (> 12 layers) + +### GANs: +- **Generator**: InstanceNorm (image translation) or no norm +- **Discriminator**: No normalization or LayerNorm +- **Avoid**: BatchNorm (leaks information) + +### Emerging: +- **RMSNorm adoption increasing**: LLaMA, Mistral, Gemma all use RMSNorm +- **Pre-norm becoming standard**: More stable for deep networks +- **GroupNorm gaining traction**: Object detection, small-batch training + + +## Summary + +**Normalization is mandatory for modern deep learning.** The question is which normalization, not whether to normalize. + +**Quick decision tree:** +1. **Batch size ≥ 8?** → Consider BatchNorm (CNNs) +2. 
**Batch size < 8?** → Use GroupNorm (CNNs) or LayerNorm (all) +3. **RNN/LSTM?** → LayerNorm (never BatchNorm!) +4. **Transformer?** → LayerNorm or RMSNorm with pre-norm +5. **Style transfer?** → InstanceNorm +6. **GAN?** → InstanceNorm (generator) or no norm (discriminator) + +**Modern defaults:** +- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8) +- RNNs: LayerNorm +- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models) +- GANs: InstanceNorm (generator), no norm (discriminator) + +**Key insight**: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere"—it fails for small batches, RNNs, Transformers, and style transfer. diff --git a/skills/using-neural-architectures/sequence-models-comparison.md b/skills/using-neural-architectures/sequence-models-comparison.md new file mode 100644 index 0000000..c7d0a3d --- /dev/null +++ b/skills/using-neural-architectures/sequence-models-comparison.md @@ -0,0 +1,714 @@ + +# Sequence Models Comparison: Choosing the Right Architecture for Sequential Data + + +Sequence modeling has evolved rapidly: +- 2014-2017: LSTM/GRU dominated +- 2017+: Transformers revolutionized the field +- 2018+: TCN emerged as efficient alternative +- 2021+: Sparse Transformers for very long sequences +- 2022+: State Space Models (S4) for extreme lengths + +Don't default to LSTM (outdated) or blindly use Transformers (not always appropriate). +Match architecture to your sequence characteristics. + + +## When to Use This Skill + +Use this skill when: +- ✅ Selecting model for sequential/temporal data +- ✅ Comparing RNN vs LSTM vs Transformer +- ✅ Deciding on sequence architecture for time series, text, audio +- ✅ Understanding modern alternatives to LSTM +- ✅ Optimizing for sequence length, speed, or accuracy + +DO NOT use for: +- ❌ Vision tasks (use cnn-families-and-selection) +- ❌ Graph-structured data (use graph-neural-networks-basics) +- ❌ LLM-specific questions (use llm-specialist pack) + +**When in doubt:** If data is sequential/temporal → this skill. + + +## Selection Framework + +### Step 1: Identify Key Characteristics + +**Before recommending, ask:** + +| Characteristic | Question | Impact | +|----------------|----------|--------| +| **Sequence Length** | Typical length? | Short (< 100) → LSTM/CNN, Medium (100-1k) → Transformer, Long (> 1k) → Sparse Transformer/S4 | +| **Data Type** | Language, time series, audio? | Language → Transformer, Time series → TCN/Transformer, Audio → Specialized | +| **Data Volume** | Training examples? | Small (< 10k) → LSTM/TCN, Large (> 100k) → Transformer | +| **Latency** | Real-time needed? | Yes → TCN/LSTM, No → Transformer | +| **Deployment** | Cloud/edge/mobile? | Edge → TCN/LSTM, Cloud → Any | + +### Step 2: Apply Decision Tree + +``` +START: What's your primary constraint? 
+ +┌─ SEQUENCE LENGTH +│ ├─ Short (< 100 steps) +│ │ ├─ Language → BiLSTM or small Transformer +│ │ └─ Time series → TCN or LSTM +│ │ +│ ├─ Medium (100-1000 steps) +│ │ ├─ Language → Transformer (BERT-style) +│ │ └─ Time series → Transformer or TCN +│ │ +│ ├─ Long (1000-10000 steps) +│ │ ├─ Sparse Transformer (Longformer, BigBird) +│ │ └─ Hierarchical models +│ │ +│ └─ Very Long (> 10000 steps) +│ └─ State Space Models (S4) +│ +├─ DATA TYPE +│ ├─ Natural Language +│ │ ├─ < 50k data → BiLSTM or DistilBERT +│ │ └─ > 50k data → Transformer (BERT, RoBERTa) +│ │ +│ ├─ Time Series +│ │ ├─ Fast training → TCN +│ │ ├─ Long sequences → Transformer +│ │ └─ Multivariate → Transformer with cross-series attention +│ │ +│ └─ Audio +│ ├─ Waveform → WaveNet (TCN-based) +│ └─ Spectrograms → CNN + Transformer +│ +└─ COMPUTATIONAL CONSTRAINT + ├─ Edge device → TCN or small LSTM + ├─ Real-time latency → TCN (parallel inference) + └─ Cloud, no constraint → Transformer +``` + + +## Architecture Catalog + +### 1. RNN (Recurrent Neural Networks) - Legacy Foundation + +**Architecture:** Basic recurrent cell with hidden state + +**Status:** **OUTDATED** - don't use for new projects + +**Why it existed:** +- First neural approach to sequences +- Hidden state captures temporal information +- Theoretically can model any sequence + +**Why it failed:** +- Vanishing gradient (can't learn long dependencies) +- Very slow training (sequential processing) +- Replaced by LSTM in 2014 + +**When to mention:** +- Historical context only +- Teaching purposes +- Never recommend for production + +**Key Insight:** Proved neural nets could handle sequences, but impractical due to vanishing gradients + + +### 2. LSTM (Long Short-Term Memory) - Legacy Standard + +**Architecture:** Gated recurrent cell (forget, input, output gates) + +**Complexity:** O(n) memory, sequential processing + +**Strengths:** +- Solves vanishing gradient (gates maintain long-term info) +- Works well for short-medium sequences (< 500 steps) +- Small datasets (< 10k examples) +- Low memory footprint + +**Weaknesses:** +- Sequential processing (slow training, can't parallelize) +- Still struggles with very long sequences (> 1000 steps) +- Slow inference (especially bidirectional) +- Superseded by Transformers for most language tasks + +**When to Use:** +- ✅ Small datasets (< 10k examples) +- ✅ Short sequences (< 100 steps) +- ✅ Edge deployment (low memory) +- ✅ Baseline comparison + +**When NOT to Use:** +- ❌ Large datasets (Transformer better) +- ❌ Long sequences (> 500 steps) +- ❌ Modern NLP (Transformer standard) +- ❌ Fast training needed (TCN better) + +**Code Example:** +```python +class SeqLSTM(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, + num_layers=2, + batch_first=True, + bidirectional=True) + self.fc = nn.Linear(hidden_size * 2, num_classes) + + def forward(self, x): + # x: (batch, seq_len, features) + lstm_out, _ = self.lstm(x) + # Use last timestep + out = self.fc(lstm_out[:, -1, :]) + return out +``` + +**Status:** Legacy but still useful for specific cases (small data, edge deployment) + + +### 3. 
GRU (Gated Recurrent Unit) - Simplified LSTM + +**Architecture:** Simplified gating (2 gates instead of 3) + +**Advantages over LSTM:** +- Fewer parameters (faster training) +- Similar performance in many tasks +- Lower memory + +**Disadvantages:** +- Still sequential (same as LSTM) +- No major advantage over LSTM in practice +- Also superseded by Transformers + +**When to Use:** +- Same as LSTM, but prefer LSTM for slightly better performance +- Use if computational savings matter + +**Status:** Rarely recommended - if using recurrent, prefer LSTM or move to Transformer/TCN + + +### 4. Transformer - Modern Standard + +**Architecture:** Self-attention mechanism, parallel processing + +**Complexity:** +- Memory: O(n²) for sequence length n +- Compute: O(n²d) where d is embedding dimension + +**Strengths:** +- ✅ Parallel processing (fast training) +- ✅ Captures long-range dependencies (better than LSTM) +- ✅ State-of-the-art for language (BERT, GPT) +- ✅ Pre-trained models available +- ✅ Scales with data (more data = better performance) + +**Weaknesses:** +- ❌ Quadratic memory (struggles with sequences > 1000) +- ❌ Needs more data than LSTM (> 10k examples) +- ❌ Slower inference than TCN +- ❌ Harder to interpret than RNN + +**When to Use:** +- ✅ **Natural language** (current standard) +- ✅ Medium sequences (100-1000 tokens) +- ✅ Large datasets (> 50k examples) +- ✅ Pre-training available (BERT, GPT) +- ✅ Accuracy priority + +**When NOT to Use:** +- ❌ Short sequences (< 50 tokens) - LSTM/CNN competitive, simpler +- ❌ Very long sequences (> 2000) - quadratic memory explodes +- ❌ Small datasets (< 10k) - will overfit +- ❌ Edge deployment - large model size + +**Memory Analysis:** +```python +# Standard Transformer attention + +# For sequence length n=1000, batch_size=32, embedding_dim=512: +attention_weights = softmax(Q @ K^T / sqrt(d)) # Shape: (32, 1000, 1000) +# Memory: 32 * 1000 * 1000 * 4 bytes = 128 MB just for attention! + +# For n=5000: +# Memory: 32 * 5000 * 5000 * 4 bytes = 3.2 GB per batch! +# → Impossible on most GPUs +``` + +**Code Example:** +```python +from transformers import BertModel, BertTokenizer + +# Pre-trained BERT for text classification +class TransformerClassifier(nn.Module): + def __init__(self, num_classes): + super().__init__() + self.bert = BertModel.from_pretrained('bert-base-uncased') + self.classifier = nn.Linear(768, num_classes) + + def forward(self, input_ids, attention_mask): + outputs = self.bert(input_ids=input_ids, + attention_mask=attention_mask) + # Use [CLS] token representation + pooled = outputs.pooler_output + return self.classifier(pooled) + +# Fine-tuning +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +model = TransformerClassifier(num_classes=2) +``` + +**Status:** **Current standard for NLP**, competitive for time series with large data + + +### 5. 
TCN (Temporal Convolutional Network) - Underrated Alternative + +**Architecture:** 1D convolutions with dilated causal convolutions + +**Complexity:** O(n) memory, fully parallel processing + +**Strengths:** +- ✅ **Parallel training** (much faster than LSTM) +- ✅ **Parallel inference** (faster than LSTM/Transformer) +- ✅ Linear memory (no quadratic blow-up) +- ✅ Large receptive field (dilation) +- ✅ Works well for time series +- ✅ Simple architecture + +**Weaknesses:** +- ❌ Less popular (fewer pre-trained models) +- ❌ Not standard for language (Transformer dominates) +- ❌ Fixed receptive field (vs adaptive attention) + +**When to Use:** +- ✅ **Time series forecasting** (often BETTER than LSTM) +- ✅ **Fast training needed** (2-3x faster than LSTM) +- ✅ **Fast inference** (real-time applications) +- ✅ Long sequences (linear memory) +- ✅ Audio processing (WaveNet is TCN-based) + +**When NOT to Use:** +- ❌ Natural language with pre-training available (use Transformer) +- ❌ Need very large receptive field (Transformer better) + +**Performance Comparison:** +``` +Time series forecasting (1000-step sequences): + +Training speed: +- LSTM: 100% (baseline, sequential) +- TCN: 35% (2.8x faster, parallel) +- Transformer: 45% (2.2x faster) + +Inference speed: +- LSTM: 100% (sequential) +- TCN: 20% (5x faster, parallel) +- Transformer: 60% (1.7x faster) + +Accuracy (similar across all three): +- LSTM: Baseline +- TCN: Equal or slightly better +- Transformer: Equal or slightly better (needs more data) + +Conclusion: TCN wins on speed, matches accuracy +``` + +**Code Example:** +```python +class TCN(nn.Module): + def __init__(self, input_channels, num_channels, kernel_size=3): + super().__init__() + + layers = [] + num_levels = len(num_channels) + + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = input_channels if i == 0 else num_channels[i-1] + out_channels = num_channels[i] + + # Causal dilated convolution + layers.append( + nn.Conv1d(in_channels, out_channels, kernel_size, + padding=(kernel_size-1) * dilation_size, + dilation=dilation_size) + ) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(0.2)) + + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + +# Usage for time series +# Input: (batch, channels, sequence_length) +model = TCN(input_channels=1, num_channels=[64, 128, 256]) +``` + +**Key Insight:** Dilated convolutions create exponentially large receptive field (2^k) while maintaining linear memory + +**Status:** **Excellent for time series**, underrated, should be considered before LSTM + + +### 6. 
Sparse Transformers - Long Sequence Specialists + +**Architecture:** Modified attention patterns to reduce complexity + +**Variants:** +- **Longformer**: Local + global attention +- **BigBird**: Random + local + global attention +- **Linformer**: Low-rank projection of keys/values +- **Performer**: Kernel approximation of attention + +**Complexity:** O(n log n) or O(n) depending on variant + +**When to Use:** +- ✅ **Long sequences** (1000-10000 tokens) +- ✅ Document processing (multi-page documents) +- ✅ Long-context language modeling +- ✅ When standard Transformer runs out of memory + +**Trade-offs:** +- Slightly lower accuracy than full attention (approximation) +- More complex implementation +- Fewer pre-trained models + +**Example Use Cases:** +- Legal document analysis (10k+ tokens) +- Scientific paper understanding +- Long-form text generation +- Time series with thousands of steps + +**Status:** Specialized for long sequences, active research area + + +### 7. State Space Models (S4) - Cutting Edge + +**Architecture:** Structured state space with efficient recurrence + +**Complexity:** O(n log n) training, O(n) inference + +**Strengths:** +- ✅ **Very long sequences** (10k-100k steps) +- ✅ Linear inference complexity +- ✅ Strong theoretical foundations +- ✅ Handles continuous-time sequences + +**Weaknesses:** +- ❌ Newer (less mature ecosystem) +- ❌ Complex mathematics +- ❌ Fewer pre-trained models +- ❌ Harder to implement + +**When to Use:** +- ✅ Extremely long sequences (> 10k steps) +- ✅ Audio (raw waveforms, 16kHz sampling) +- ✅ Medical signals (ECG, EEG) +- ✅ Research applications + +**Status:** **Cutting edge** (2022+), promising for very long sequences + + +## Practical Selection Guide + +### Scenario 1: Natural Language Processing + +**Short text (< 50 tokens, e.g., tweets, titles):** +``` +Small dataset (< 10k): +→ BiLSTM or 1D CNN (simple, effective) + +Large dataset (> 10k): +→ DistilBERT (smaller Transformer, 40M params) +→ Or BiLSTM if latency critical +``` + +**Medium text (50-512 tokens, e.g., reviews, articles):** +``` +Standard approach: +→ BERT, RoBERTa, or similar (110M params) +→ Fine-tune on task-specific data + +Small dataset: +→ DistilBERT (66M params, faster, similar accuracy) +``` + +**Long documents (> 512 tokens):** +``` +→ Longformer (4096 tokens max) +→ BigBird (4096 tokens max) +→ Hierarchical: Process in chunks, aggregate +``` + + +### Scenario 2: Time Series Forecasting + +**Short sequences (< 100 steps):** +``` +Fast training: +→ TCN (2-3x faster than LSTM) + +Small dataset: +→ LSTM or simple models (ARIMA, Prophet) + +Baseline: +→ LSTM (well-tested) +``` + +**Medium sequences (100-1000 steps):** +``` +Best accuracy: +→ Transformer (if data > 50k examples) + +Fast training/inference: +→ TCN (parallel processing) + +Multivariate: +→ Transformer with cross-series attention +``` + +**Long sequences (> 1000 steps):** +``` +→ Sparse Transformer (Informer for time series) +→ Hierarchical models (chunk + aggregate) +→ State Space Models (S4) +``` + + +### Scenario 3: Audio Processing + +**Waveform (raw audio, 16kHz):** +``` +→ WaveNet (TCN-based) +→ State Space Models (S4) +``` + +**Spectrograms (mel-spectrograms):** +``` +→ CNN + BiLSTM (traditional) +→ CNN + Transformer (modern) +``` + +**Speech recognition:** +``` +→ Transformer (Wav2Vec 2.0, Whisper) +→ Pre-trained models available +``` + + +## Trade-Off Analysis + +### Speed Comparison + +**Training speed (1000-step sequences):** +``` +LSTM: 100% (baseline, sequential) +GRU: 75% (simpler gates) +TCN: 35% 
(2.8x faster, parallel)
Transformer: 45% (2.2x faster, parallel)

Conclusion: TCN fastest for training
```

**Inference speed:**
```
LSTM: 100% (sequential)
BiLSTM: 200% (2x passes)
TCN: 20% (5x faster, parallel)
Transformer: 60% (faster, but attention overhead)

Conclusion: TCN fastest for inference
```


### Memory Comparison

**Sequence length n=1000, batch=32:**
```
LSTM: ~500 MB (linear in n)
Transformer: ~2 GB (quadratic in n)
TCN: ~400 MB (linear in n)
Sparse Transformer: ~800 MB (n log n)

For n=5000:
LSTM: ~2 GB
Transformer: OUT OF MEMORY (50 GB needed!)
TCN: ~2 GB
Sparse Transformer: ~4 GB
```


### Accuracy vs Data Size

**Small dataset (< 10k examples):**
```
LSTM: ★★★★☆ (works well with little data)
Transformer: ★★☆☆☆ (overfits, needs more data)
TCN: ★★★★☆ (similar to LSTM)

Winner: LSTM or TCN
```

**Large dataset (> 100k examples):**
```
LSTM: ★★★☆☆ (good but plateaus)
Transformer: ★★★★★ (best, scales with data)
TCN: ★★★★☆ (competitive)

Winner: Transformer
```


## Common Pitfalls

### Pitfall 1: Using LSTM in 2025 Without Considering Modern Alternatives
**Symptom:** Defaulting to LSTM for all sequence tasks

**Why it's wrong:** Transformers (language) and TCN (time series) are often better

**Fix:** Consider Transformer for language, TCN for time series, LSTM for small data/edge only


### Pitfall 2: Using Standard Transformer for Very Long Sequences
**Symptom:** Running out of memory on sequences > 1000 tokens

**Why it's wrong:** O(n²) memory explodes

**Fix:** Use Sparse Transformer (Longformer, BigBird) or hierarchical approach


### Pitfall 3: Not Trying TCN for Time Series
**Symptom:** Struggling with slow LSTM training

**Why it's wrong:** TCN is 2-3x faster, often more accurate

**Fix:** Try TCN before optimizing LSTM


### Pitfall 4: Using Transformer for Small Datasets
**Symptom:** Transformer overfits on < 10k examples

**Why it's wrong:** Transformers need large datasets to work well

**Fix:** Use LSTM or TCN for small datasets, or use a pre-trained Transformer


### Pitfall 5: Ignoring Sequence Length Constraints
**Symptom:** Choosing architecture without considering typical sequence length

**Why it's wrong:** Architecture effectiveness varies dramatically with length

**Fix:** Match architecture to sequence length (short → LSTM/CNN, long → Sparse Transformer)


## Evolution Timeline

**Understanding why architectures evolved:**

```
1990s-2013: Basic RNN
→ Vanishing gradient problem
→ Can't learn long dependencies

1997: LSTM (Hochreiter & Schmidhuber)
→ Gates solve vanishing gradient
→ Became the standard for sequences in the deep learning era (~2014)

2014: GRU
→ Simplified LSTM
→ Similar performance, fewer parameters

2017: Transformer (Attention Is All You Need)
→ Self-attention replaces recurrence
→ Parallel processing (fast training)
→ Revolutionized NLP

2018: TCN (Temporal Convolutional Networks)
→ Dilated convolutions for sequences
→ Often better than LSTM for time series
→ Underrated alternative

2020: Sparse Transformers
→ Reduce quadratic complexity
→ Enable longer sequences

2021: State Space Models (S4)
→ Very long sequences (10k-100k)
→ Theoretical foundations
→ Cutting edge research

Current (2025):
- NLP: Transformer standard (BERT, GPT)
- Time Series: TCN or Transformer
- Audio: Specialized (WaveNet, Transformer)
- Edge: LSTM or TCN (low memory)
```


## Decision Checklist

Before choosing sequence model:

```
☐ Sequence length? 
(< 100 / 100-1k / > 1k) +☐ Data type? (language / time series / audio / other) +☐ Dataset size? (< 10k / 10k-100k / > 100k) +☐ Latency requirement? (real-time / batch / offline) +☐ Deployment target? (cloud / edge / mobile) +☐ Pre-trained models available? (yes / no) +☐ Training speed critical? (yes / no) + +Based on answers: +→ Language + large data → Transformer +→ Language + small data → BiLSTM or DistilBERT +→ Time series + speed → TCN +→ Time series + accuracy + large data → Transformer +→ Very long sequences → Sparse Transformer or S4 +→ Edge deployment → TCN or LSTM +→ Real-time latency → TCN +``` + + +## Integration with Other Skills + +**For language-specific questions:** +→ `yzmir/llm-specialist/using-llm-specialist` +- LLM-specific Transformers (GPT, BERT variants) +- Fine-tuning strategies +- Prompt engineering + +**For Transformer internals:** +→ `yzmir/neural-architectures/transformer-architecture-deepdive` +- Attention mechanisms +- Positional encoding +- Transformer variants + +**After selecting architecture:** +→ `yzmir/training-optimization/using-training-optimization` +- Optimizer selection +- Learning rate schedules +- Handling sequence-specific training issues + + +## Summary + +**Quick Reference Table:** + +| Use Case | Best Choice | Alternative | Avoid | +|----------|-------------|-------------|-------| +| Short text (< 50 tokens) | BiLSTM, DistilBERT | 1D CNN | Full BERT (overkill) | +| Long text (> 512 tokens) | Longformer, BigBird | Hierarchical | Standard BERT (memory) | +| Time series (< 1k steps) | TCN, Transformer | LSTM | Basic RNN | +| Time series (> 1k steps) | Sparse Transformer, S4 | Hierarchical | Standard Transformer | +| Small dataset (< 10k) | LSTM, TCN | Simple models | Transformer (overfits) | +| Large dataset (> 100k) | Transformer | TCN | LSTM (plateaus) | +| Edge deployment | TCN, LSTM | Quantized Transformer | Large Transformer | +| Real-time inference | TCN | Small LSTM | BiLSTM, Transformer | + +**Key Principles:** +1. **Don't default to LSTM** (outdated for most tasks) +2. **Transformer for language** (current standard, if data sufficient) +3. **TCN for time series** (fast, effective, underrated) +4. **Match to sequence length** (short → LSTM/CNN, long → Sparse Transformer) +5. **Consider modern alternatives** (don't stop at LSTM vs Transformer) + + +**END OF SKILL** diff --git a/skills/using-neural-architectures/transformer-architecture-deepdive.md b/skills/using-neural-architectures/transformer-architecture-deepdive.md new file mode 100644 index 0000000..b5e63ad --- /dev/null +++ b/skills/using-neural-architectures/transformer-architecture-deepdive.md @@ -0,0 +1,937 @@ + +# Transformer Architecture Deep Dive + +## When to Use This Skill + +Use this skill when you need to: +- ✅ Implement a Transformer from scratch +- ✅ Understand HOW and WHY self-attention works +- ✅ Choose between encoder, decoder, or encoder-decoder architectures +- ✅ Decide if Vision Transformer (ViT) is appropriate for your vision task +- ✅ Understand modern variants (RoPE, ALiBi, GQA, MQA) +- ✅ Debug Transformer implementation issues +- ✅ Optimize Transformer performance + +**Do NOT use this skill for:** +- ❌ High-level architecture selection (use `using-neural-architectures`) +- ❌ Attention mechanism comparison (use `attention-mechanisms-catalog`) +- ❌ LLM-specific topics like prompt engineering (use `llm-specialist` pack) + + +## Core Principle + +**Transformers are NOT magic.** They are: +1. Self-attention mechanism (information retrieval) +2. 
+ Position encoding (break permutation invariance) +3. + Residual connections + Layer norm (training stability) +4. + Feed-forward networks (non-linearity) + +Understanding the mechanism beats cargo-culting implementations. + + +## Part 1: Self-Attention Mechanism Explained + +### The Information Retrieval Analogy + +**Self-attention = Querying a database:** +- **Query (Q)**: "What am I looking for?" +- **Key (K)**: "What do I contain?" +- **Value (V)**: "What information do I have?" + +**Process:** +1. Compare your query with all keys (compute similarity) +2. Weight values by similarity +3. Return weighted sum of values + +**Example:** Sentence: "The cat sat on the mat" +Token "sat" (verb): +- High attention to "cat" (subject) → Learns verb-subject relationship +- High attention to "mat" (object) → Learns verb-object relationship +- Low attention to "the", "on" (function words) + +### Mathematical Breakdown + +**Given input X:** (batch, seq_len, d_model) + +**Step 1: Project to Q, K, V** +```python +Q = X @ W_Q # (batch, seq_len, d_k) +K = X @ W_K # (batch, seq_len, d_k) +V = X @ W_V # (batch, seq_len, d_v) + +# Typically: d_k = d_v = d_model / num_heads +``` + +**Step 2: Compute attention scores** (similarity) +```python +scores = Q @ K.transpose(-2, -1) # (batch, seq_len, seq_len) +# scores[i, j] = similarity between query_i and key_j +``` + +**Geometric interpretation:** +- Dot product measures vector alignment +- q · k = ||q|| ||k|| cos(θ) +- Similar vectors → Large dot product → High attention +- Orthogonal vectors → Zero dot product → No attention + +**Step 3: Scale by √d_k** (CRITICAL!) +```python +scores = scores / math.sqrt(d_k) +``` + +**WHY scaling?** +- Dot products grow with dimension: Var(q · k) = d_k +- Example: d_k=64 → Random dot products ~ ±64 +- Large scores → Softmax saturates → Gradients vanish +- Scaling: Keep scores ~ O(1) regardless of dimension + +**Without scaling:** Softmax([30, 25, 20]) ≈ [0.99, 0.01, 0.00] (saturated!) +**With scaling:** Softmax([3, 2.5, 2]) ≈ [0.50, 0.30, 0.20] (healthy gradients) + +**Step 4: Softmax to get attention weights** +```python +attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len) +# Each row sums to 1 (probability distribution) +# attn_weights[i, j] = "how much token i attends to token j" +``` + +**Step 5: Weight values** +```python +output = attn_weights @ V # (batch, seq_len, d_v) +# Each token's output = weighted average of all values +``` + +**Complete formula:** +```python +Attention(Q, K, V) = softmax(Q K^T / √d_k) V +``` + +### Why Three Matrices (Q, K, V)? + +**Could we use just one?** Attention(X, X, X) +**Yes, but Q/K/V separation enables:** + +1. **Asymmetry**: Query can differ from key (search ≠ database) +2. **Decoupling**: What you search for (Q@K) ≠ what you retrieve (V) +3. 
**Cross-attention**: Q from one source, K/V from another + - Example: Decoder queries encoder (translation) + +**Modern optimization:** Multi-Query Attention (MQA), Grouped-Query Attention (GQA) +- Share K/V across heads (fewer parameters, faster inference) + +### Implementation Example + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class SelfAttention(nn.Module): + def __init__(self, d_model, d_k=None): + super().__init__() + self.d_k = d_k or d_model + + self.W_q = nn.Linear(d_model, self.d_k) + self.W_k = nn.Linear(d_model, self.d_k) + self.W_v = nn.Linear(d_model, self.d_k) + + def forward(self, x, mask=None): + # x: (batch, seq_len, d_model) + Q = self.W_q(x) # (batch, seq_len, d_k) + K = self.W_k(x) # (batch, seq_len, d_k) + V = self.W_v(x) # (batch, seq_len, d_k) + + # Attention scores + scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) + # scores: (batch, seq_len, seq_len) + + # Apply mask if provided (for causal attention) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + + # Attention weights + attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len) + + # Weighted sum of values + output = torch.matmul(attn_weights, V) # (batch, seq_len, d_k) + + return output, attn_weights +``` + +**Complexity:** O(n² · d) where n = seq_len, d = d_model +- **Quadratic in sequence length** (bottleneck for long sequences) +- For n=1000, d=512: 1000² × 512 = 512M operations + + +## Part 2: Multi-Head Attention + +### Why Multiple Heads? + +**Single-head attention** learns one attention pattern. +**Multi-head attention** learns multiple parallel patterns: +- Head 1: Syntactic relationships (subject-verb) +- Head 2: Semantic similarity +- Head 3: Positional proximity +- Head 4: Long-range dependencies + +**Analogy:** Ensemble of attention functions, each specializing in different patterns. + +### Head Dimension Calculation + +**CRITICAL CONSTRAINT:** num_heads must divide d_model evenly! 
+ +```python +d_model = 512 +num_heads = 8 +d_k = d_model // num_heads # 512 / 8 = 64 + +# Each head operates in d_k dimensions +# Concatenate all heads → back to d_model dimensions +``` + +**Common configurations:** +- BERT-base: d_model=768, heads=12, d_k=64 +- GPT-2: d_model=768, heads=12, d_k=64 +- GPT-3 175B: d_model=12288, heads=96, d_k=128 +- LLaMA-2 70B: d_model=8192, heads=64, d_k=128 + +**Rule of thumb:** d_k (head dimension) should be 64-128 +- Too small (d_k < 32): Limited representational capacity +- Too large (d_k > 256): Redundant, wasteful + +### Multi-Head Implementation + +```python +class MultiHeadAttention(nn.Module): + def __init__(self, d_model, num_heads): + super().__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.d_k = d_model // num_heads + + # Single linear layers for all heads (more efficient) + self.W_q = nn.Linear(d_model, d_model) + self.W_k = nn.Linear(d_model, d_model) + self.W_v = nn.Linear(d_model, d_model) + self.W_o = nn.Linear(d_model, d_model) # Output projection + + def split_heads(self, x): + # x: (batch, seq_len, d_model) + batch_size, seq_len, d_model = x.size() + # Reshape to (batch, seq_len, num_heads, d_k) + x = x.view(batch_size, seq_len, self.num_heads, self.d_k) + # Transpose to (batch, num_heads, seq_len, d_k) + return x.transpose(1, 2) + + def forward(self, x, mask=None): + batch_size = x.size(0) + + # Linear projections + Q = self.W_q(x) # (batch, seq_len, d_model) + K = self.W_k(x) + V = self.W_v(x) + + # Split into multiple heads + Q = self.split_heads(Q) # (batch, num_heads, seq_len, d_k) + K = self.split_heads(K) + V = self.split_heads(V) + + # Attention + scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + attn_weights = F.softmax(scores, dim=-1) + + # Weighted sum + attn_output = torch.matmul(attn_weights, V) + # attn_output: (batch, num_heads, seq_len, d_k) + + # Concatenate heads + attn_output = attn_output.transpose(1, 2).contiguous() + # (batch, seq_len, num_heads, d_k) + attn_output = attn_output.view(batch_size, -1, self.d_model) + # (batch, seq_len, d_model) + + # Final linear projection + output = self.W_o(attn_output) + + return output, attn_weights +``` + +### Modern Variants: GQA and MQA + +**Problem:** K/V caching during inference is memory-intensive +- LLaMA-2 70B: 8192 × 64 heads × 2 (K + V) = 1M parameters per token cached! + +**Solution 1: Multi-Query Attention (MQA)** +- **One** K/V head shared across **all** Q heads +- Benefit: Dramatically faster inference (smaller KV cache) +- Trade-off: ~1-2% accuracy loss + +```python +# MQA: Single K/V projection +self.W_k = nn.Linear(d_model, d_k) # Not d_model! +self.W_v = nn.Linear(d_model, d_k) +self.W_q = nn.Linear(d_model, d_model) # Multiple Q heads +``` + +**Solution 2: Grouped-Query Attention (GQA)** +- Middle ground: Group multiple Q heads per K/V head +- Example: 32 Q heads → 8 K/V heads (4 Q per K/V) +- Benefit: 4x smaller KV cache, minimal accuracy loss + +**Used in:** LLaMA-2, Mistral, Mixtral + + +## Part 3: Position Encoding + +### Why Position Encoding? + +**Problem:** Self-attention is **permutation-invariant** +- Attention("cat sat mat") = Attention("mat cat sat") +- No inherent notion of position or order! 
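A minimal sketch that makes this concrete, reusing the `SelfAttention` module from Part 1: permuting the input tokens simply permutes the outputs in the same way, so the attention output carries no information about order.

```python
import torch

torch.manual_seed(0)
attn = SelfAttention(d_model=16)   # the module defined in Part 1

x = torch.randn(1, 4, 16)          # a batch of 4 "tokens"
perm = torch.tensor([2, 0, 3, 1])  # an arbitrary reordering

out, _ = attn(x)
out_shuffled, _ = attn(x[:, perm])

# Shuffled input -> identically shuffled output: order is invisible to attention
print(torch.allclose(out[:, perm], out_shuffled, atol=1e-6))  # True
```

Adding position information, as below, is what breaks this symmetry.
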
+ +**Solution:** Add position information to embeddings + +### Strategy 1: Sinusoidal Position Encoding (Original) + +**Formula:** +```python +PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) +PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) +``` + +**Implementation:** +```python +def sinusoidal_position_encoding(seq_len, d_model): + pe = torch.zeros(seq_len, d_model) + position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * + (-math.log(10000.0) / d_model)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + return pe + +# Usage: Add to input embeddings +x = token_embeddings + positional_encoding +``` + +**Properties:** +- Deterministic (no learned parameters) +- Extrapolates to unseen lengths (geometric properties) +- Relative positions: PE(pos+k) is linear function of PE(pos) + +**When to use:** Variable-length sequences in NLP + +### Strategy 2: Learned Position Embeddings + +```python +self.pos_embedding = nn.Embedding(max_seq_len, d_model) + +# Usage +positions = torch.arange(seq_len, device=x.device) +x = token_embeddings + self.pos_embedding(positions) +``` + +**Properties:** +- Learnable (adapts to data) +- Cannot extrapolate beyond max_seq_len + +**When to use:** +- Fixed-length sequences +- Vision Transformers (image patches) +- When training data covers all positions + +### Strategy 3: Rotary Position Embeddings (RoPE) ⭐ + +**Modern approach (2021+):** Rotate Q and K in complex plane + +**Key advantages:** +- Encodes **relative** positions naturally +- Better long-range decay properties +- No addition to embeddings (applied in attention) + +**Used in:** GPT-NeoX, PaLM, LLaMA, LLaMA-2, Mistral + +```python +def apply_rotary_pos_emb(x, cos, sin): + # x: (batch, num_heads, seq_len, d_k) + # Split into even/odd + x1, x2 = x[..., ::2], x[..., 1::2] + # Rotate + return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) +``` + +### Strategy 4: ALiBi (Attention with Linear Biases) ⭐ + +**Simplest modern approach:** Add bias to attention scores (no embeddings!) 
+ +```python +# Bias matrix: -1 * distance +# [[0, -1, -2, -3], +# [0, 0, -1, -2], +# [0, 0, 0, -1], +# [0, 0, 0, 0]] + +scores = Q @ K^T / √d_k + alibi_bias +``` + +**Key advantages:** +- **Best extrapolation** to longer sequences +- No positional embeddings (simpler) +- Per-head slopes (different decay rates) + +**Used in:** BLOOM + +### Position Encoding Selection Guide + +| Use Case | Recommended | Why | +|----------|-------------|-----| +| NLP (variable length) | RoPE or ALiBi | Better extrapolation | +| NLP (fixed length) | Learned embeddings | Adapts to data | +| Vision (ViT) | 2D learned embeddings | Spatial structure | +| Long sequences (>2k) | ALiBi | Best extrapolation | +| Legacy/compatibility | Sinusoidal | Original Transformer | + +**Modern trend (2023+):** RoPE and ALiBi dominate over sinusoidal + + +## Part 4: Architecture Variants + +### Variant 1: Encoder-Only (Bidirectional) + +**Architecture:** +- Self-attention: Each token attends to **ALL** tokens (past + future) +- No masking (bidirectional context) + +**Examples:** BERT, RoBERTa, ELECTRA, DeBERTa + +**Use cases:** +- Text classification +- Named entity recognition +- Question answering (extract span from context) +- Sentence embeddings + +**Key property:** Sees full context → Good for **understanding** + +**Implementation:** +```python +class TransformerEncoder(nn.Module): + def __init__(self, d_model, num_heads, d_ff, num_layers): + super().__init__() + self.layers = nn.ModuleList([ + EncoderLayer(d_model, num_heads, d_ff) + for _ in range(num_layers) + ]) + + def forward(self, x, mask=None): + for layer in self.layers: + x = layer(x, mask) # No causal mask! + return x + +class EncoderLayer(nn.Module): + def __init__(self, d_model, num_heads, d_ff): + super().__init__() + self.self_attn = MultiHeadAttention(d_model, num_heads) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model) + ) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, mask=None): + # Self-attention + residual + norm + attn_output, _ = self.self_attn(x, mask) + x = self.norm1(x + attn_output) + + # Feed-forward + residual + norm + ff_output = self.feed_forward(x) + x = self.norm2(x + ff_output) + + return x +``` + +### Variant 2: Decoder-Only (Autoregressive) + +**Architecture:** +- Self-attention with **causal masking** +- Each token attends ONLY to past tokens (not future) + +**Causal mask (lower triangular):** +```python +# mask[i, j] = 1 if j <= i else 0 +[[1, 0, 0, 0], # Token 0 sees only itself + [1, 1, 0, 0], # Token 1 sees tokens 0-1 + [1, 1, 1, 0], # Token 2 sees tokens 0-2 + [1, 1, 1, 1]] # Token 3 sees all +``` + +**Examples:** GPT, GPT-2, GPT-3, GPT-4, LLaMA, Mistral + +**Use cases:** +- Text generation +- Language modeling +- Code generation +- Autoregressive prediction + +**Key property:** Generates sequentially → Good for **generation** + +**Implementation:** +```python +def create_causal_mask(seq_len, device): + # Lower triangular matrix + mask = torch.tril(torch.ones(seq_len, seq_len, device=device)) + return mask + +class TransformerDecoder(nn.Module): + def __init__(self, d_model, num_heads, d_ff, num_layers): + super().__init__() + self.layers = nn.ModuleList([ + DecoderLayer(d_model, num_heads, d_ff) + for _ in range(num_layers) + ]) + + def forward(self, x): + seq_len = x.size(1) + causal_mask = create_causal_mask(seq_len, x.device) + + for layer in self.layers: + x = layer(x, causal_mask) # Apply causal mask! 
+ return x +``` + +**Modern trend (2023+):** Decoder-only architectures dominate +- Can do both generation AND understanding (via prompting) +- Simpler than encoder-decoder (no cross-attention) +- Scales better to massive sizes + +### Variant 3: Encoder-Decoder (Seq2Seq) + +**Architecture:** +- **Encoder**: Bidirectional self-attention (understands input) +- **Decoder**: Causal self-attention (generates output) +- **Cross-attention**: Decoder queries encoder outputs + +**Cross-attention mechanism:** +```python +# Q from decoder, K and V from encoder +Q = decoder_hidden @ W_q +K = encoder_output @ W_k +V = encoder_output @ W_v + +cross_attn = softmax(Q K^T / √d_k) V +``` + +**Examples:** T5, BART, mT5, original Transformer (2017) + +**Use cases:** +- Translation (input ≠ output language) +- Summarization (long input → short output) +- Question answering (generate answer, not extract) + +**When to use:** Input and output are fundamentally different + +**Implementation:** +```python +class EncoderDecoderLayer(nn.Module): + def __init__(self, d_model, num_heads, d_ff): + super().__init__() + self.self_attn = MultiHeadAttention(d_model, num_heads) + self.cross_attn = MultiHeadAttention(d_model, num_heads) # NEW! + self.feed_forward = FeedForward(d_model, d_ff) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, decoder_input, encoder_output, causal_mask=None): + # 1. Self-attention (causal) + self_attn_out, _ = self.self_attn(decoder_input, causal_mask) + x = self.norm1(decoder_input + self_attn_out) + + # 2. Cross-attention (Q from decoder, K/V from encoder) + cross_attn_out, _ = self.cross_attn.forward_cross( + query=x, + key=encoder_output, + value=encoder_output + ) + x = self.norm2(x + cross_attn_out) + + # 3. Feed-forward + ff_out = self.feed_forward(x) + x = self.norm3(x + ff_out) + + return x +``` + +### Architecture Selection Guide + +| Task | Architecture | Why | +|------|--------------|-----| +| Classification | Encoder-only | Need full bidirectional context | +| Text generation | Decoder-only | Autoregressive generation | +| Translation | Encoder-decoder or Decoder-only | Different languages, or use prompting | +| Summarization | Encoder-decoder or Decoder-only | Length mismatch, or use prompting | +| Q&A (extract) | Encoder-only | Find span in context | +| Q&A (generate) | Decoder-only | Generate freeform answer | + +**2023+ trend:** Decoder-only can do everything via prompting (but less parameter-efficient for some tasks) + + +## Part 5: Vision Transformers (ViT) + +### From Images to Sequences + +**Key insight:** Treat image as sequence of patches + +**Process:** +1. Split image into patches (e.g., 16×16 pixels) +2. Flatten each patch → 1D vector +3. Linear projection → token embeddings +4. Add 2D positional embeddings +5. Prepend [CLS] token (for classification) +6. 
Feed to Transformer encoder

**Example:** 224×224 image, 16×16 patches
- Number of patches: (224/16)² = 196
- Each patch: 16 × 16 × 3 = 768 dimensions
- Transformer input: 197 tokens (196 patches + 1 [CLS])

### ViT Implementation

```python
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 d_model=768, num_heads=12, num_layers=12, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size ** 2
        self.patch_dim = patch_dim

        # Patch embedding (linear projection of flattened patches)
        self.patch_embed = nn.Linear(patch_dim, d_model)

        # [CLS] token (learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Position embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        # Transformer encoder
        self.encoder = TransformerEncoder(d_model, num_heads,
                                          d_ff=4*d_model, num_layers=num_layers)

        # Classification head
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, channels, height, width)
        batch_size = x.size(0)

        # Divide into patches
        x = x.unfold(2, self.patch_size, self.patch_size)
        x = x.unfold(3, self.patch_size, self.patch_size)
        # (batch, channels, num_patches_h, num_patches_w, patch_size, patch_size)

        # Group each patch's channels and pixels together before flattening
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        # (batch, num_patches_h, num_patches_w, channels, patch_size, patch_size)
        x = x.view(batch_size, -1, self.patch_dim)
        # (batch, num_patches, patch_dim)

        # Linear projection
        x = self.patch_embed(x)  # (batch, num_patches, d_model)

        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, num_patches+1, d_model)

        # Add positional embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.encoder(x)

        # Classification: Use [CLS] token
        cls_output = x[:, 0]  # (batch, d_model)
        logits = self.head(cls_output)

        return logits
```

### ViT vs CNN: Critical Differences

**1. Inductive Bias**

| Property | CNN | ViT |
|----------|-----|-----|
| Locality | Strong (conv kernel) | Weak (global attention) |
| Translation invariance | Strong (weight sharing) | Weak (position embeddings) |
| Hierarchy | Strong (pooling layers) | None (flat patches) |

**Implication:** CNN has strong priors, ViT learns from data

**2. Data Requirements**

| Dataset Size | CNN | ViT (from scratch) | ViT (pretrained) |
|--------------|-----|-------------------|------------------|
| Small (< 100k) | ✅ Good | ❌ Fails | ✅ Good |
| Medium (100k-1M) | ✅ Excellent | ⚠️ Poor | ✅ Good |
| Large (> 1M) | ✅ Excellent | ⚠️ OK | ✅ Excellent |
| Huge (> 100M) | ✅ Excellent | ✅ SOTA | N/A |

**Key finding:** ViT needs 100M+ images to train from scratch
- Original ViT: Trained on JFT-300M (300 million images)
- Without massive data, ViT underperforms CNNs significantly

**3. Computational Cost**

**Example: 224×224 images**

| Model | Parameters | GFLOPs | Inference (GPU) |
|-------|-----------|--------|-----------------|
| ResNet-50 | 25M | 4.1 | ~30ms |
| EfficientNet-B0 | 5M | 0.4 | ~10ms |
| ViT-B/16 | 86M | 17.6 | ~100ms |

**Implication:** ViT is 40x more expensive than EfficientNet!

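A quick way to sanity-check the parameter gap yourself (a sketch; assumes a torchvision recent enough to ship `vit_b_16`):

```python
import torchvision.models as models

def count_params_m(model):
    # Total parameter count in millions
    return sum(p.numel() for p in model.parameters()) / 1e6

print(f"ResNet-50: {count_params_m(models.resnet50()):.1f}M parameters")  # ~25.6M
print(f"ViT-B/16:  {count_params_m(models.vit_b_16()):.1f}M parameters")  # ~86.6M
```
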
+ +### When to Use ViT + +✅ **Use ViT when:** +- Large dataset (> 1M images) OR using pretrained weights +- Computational cost acceptable (cloud, large GPU) +- Best possible accuracy needed +- Can fine-tune from ImageNet-21k checkpoint + +❌ **Use CNN when:** +- Small/medium dataset (< 1M images) and training from scratch +- Limited compute/memory +- Edge deployment (mobile, embedded) +- Need architectural inductive biases + +### Hybrid Approaches (2022-2023) + +**ConvNeXt:** CNN with ViT design choices +- Matches ViT accuracy with CNN efficiency +- Works better on small datasets + +**Swin Transformer:** Hierarchical ViT with local windows +- Shifted windows for efficiency +- O(n) complexity instead of O(n²) +- Better for dense prediction (segmentation) + +**CoAtNet:** Mix conv layers (early) + Transformer layers (late) +- Gets both inductive bias and global attention + + +## Part 6: Implementation Checklist + +### Critical Details + +**1. Layer Norm Placement** + +**Post-norm (original):** +```python +x = x + self_attn(x) +x = layer_norm(x) +``` + +**Pre-norm (modern, recommended):** +```python +x = x + self_attn(layer_norm(x)) +``` + +**Why pre-norm?** More stable training, less sensitive to learning rate + +**2. Attention Dropout** + +Apply dropout to **attention weights**, not Q/K/V! + +```python +attn_weights = F.softmax(scores, dim=-1) +attn_weights = F.dropout(attn_weights, p=0.1, training=self.training) # HERE! +output = torch.matmul(attn_weights, V) +``` + +**3. Feed-Forward Dimension** + +Typically: d_ff = 4 × d_model +- BERT: d_model=768, d_ff=3072 +- GPT-2: d_model=768, d_ff=3072 + +**4. Residual Connections** + +ALWAYS use residual connections (essential for training)! + +```python +x = x + self_attn(x) # Residual +x = x + feed_forward(x) # Residual +``` + +**5. Initialization** + +Use Xavier/Glorot initialization for attention weights: +```python +nn.init.xavier_uniform_(self.W_q.weight) +nn.init.xavier_uniform_(self.W_k.weight) +nn.init.xavier_uniform_(self.W_v.weight) +``` + + +## Part 7: When NOT to Use Transformers + +### Limitation 1: Small Datasets + +**Problem:** Transformers have weak inductive bias (learn from data) + +**Impact:** +- ViT: Fails on < 100k images without pretraining +- NLP: BERT needs 100M+ tokens for pretraining + +**Solution:** Use models with stronger priors (CNN for vision, smaller models for text) + +### Limitation 2: Long Sequences + +**Problem:** O(n²) memory complexity + +**Impact:** +- Standard Transformer: n=10k → 100M attention scores +- GPU memory: 10k² × 4 bytes = 400MB per sample! + +**Solution:** +- Sparse attention (Longformer, BigBird) +- Linear attention (Linformer, Performer) +- Flash Attention (memory-efficient kernel) +- State space models (S4, Mamba) + +### Limitation 3: Edge Deployment + +**Problem:** Large model size, high latency + +**Impact:** +- ViT-B: 86M parameters, ~100ms inference +- Mobile/embedded: Need < 10M parameters, < 50ms + +**Solution:** Efficient CNNs (MobileNet, EfficientNet) or distilled models + +### Limitation 4: Real-Time Processing + +**Problem:** Sequential generation in decoder (cannot parallelize at inference) + +**Impact:** GPT-style models generate one token at a time + +**Solution:** Non-autoregressive models, speculative decoding, or smaller models + + +## Part 8: Common Mistakes + +### Mistake 1: Forgetting Causal Mask + +**Symptom:** Decoder "cheats" by seeing future tokens + +**Fix:** Always apply causal mask to decoder self-attention! 
+ +```python +causal_mask = torch.tril(torch.ones(seq_len, seq_len)) +``` + +### Mistake 2: Wrong Dimension for Multi-Head + +**Symptom:** Runtime error or dimension mismatch + +**Fix:** Ensure d_model % num_heads == 0 + +```python +assert d_model % num_heads == 0, "d_model must be divisible by num_heads" +``` + +### Mistake 3: Forgetting Position Encoding + +**Symptom:** Model ignores word order + +**Fix:** Always add position information! + +```python +x = token_embeddings + positional_encoding +``` + +### Mistake 4: Wrong Softmax Dimension + +**Symptom:** Attention weights don't sum to 1 per query + +**Fix:** Softmax over last dimension (keys) + +```python +attn_weights = F.softmax(scores, dim=-1) # Sum over keys for each query +``` + +### Mistake 5: No Residual Connections + +**Symptom:** Training diverges or converges very slowly + +**Fix:** Always add residual connections! + +```python +x = x + self_attn(x) +x = x + feed_forward(x) +``` + + +## Summary: Quick Reference + +### Architecture Selection + +``` +Classification/Understanding → Encoder-only (BERT-style) +Generation/Autoregressive → Decoder-only (GPT-style) +Seq2Seq (input ≠ output) → Encoder-decoder (T5-style) or Decoder-only with prompting +``` + +### Position Encoding Selection + +``` +NLP (variable length) → RoPE or ALiBi +NLP (fixed length) → Learned embeddings +Vision (ViT) → 2D learned embeddings +Long sequences (> 2k) → ALiBi (best extrapolation) +``` + +### Multi-Head Configuration + +``` +Small models (d_model < 512): 4-8 heads +Medium models (d_model 512-1024): 8-12 heads +Large models (d_model > 1024): 12-32 heads +Rule: d_k (head dimension) should be 64-128 +``` + +### ViT vs CNN + +``` +ViT: Large dataset (> 1M) OR pretrained weights +CNN: Small dataset (< 1M) OR edge deployment +``` + +### Implementation Essentials + +``` +✅ Pre-norm (more stable than post-norm) +✅ Residual connections (essential!) +✅ Causal mask for decoder +✅ Attention dropout (on weights, not Q/K/V) +✅ d_ff = 4 × d_model (feed-forward dimension) +✅ Check: d_model % num_heads == 0 +``` + + +## Next Steps + +After mastering this skill: +- `attention-mechanisms-catalog`: Explore attention variants (sparse, linear, Flash) +- `llm-specialist/llm-finetuning-strategies`: Apply to language models +- `architecture-design-principles`: Understand design trade-offs + +**Remember:** Transformers are NOT magic. Understanding the mechanism (information retrieval via Q/K/V) beats cargo-culting implementations.