# Attention Mechanisms Catalog
## When to Use This Skill
Use this skill when you need to:
- ✅ Select attention mechanism for long sequences (> 2k tokens)
- ✅ Optimize memory usage (GPU OOM errors)
- ✅ Speed up training or inference
- ✅ Understand exact vs approximate attention trade-offs
- ✅ Choose between Flash, sparse, or linear attention
- ✅ Implement cross-attention for multimodal models
**Do NOT use this skill for:**
- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`)
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`)
## Core Principle
**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve:
- **O(n²) with less memory**: Flash Attention (exact, 4x less memory)
- **O(n × w)**: Sparse attention (exact, sliding window)
- **O(n)**: Linear attention (approximate, 1-3% accuracy loss)
**Default recommendation:** Flash Attention (exact + fast + memory-efficient)
## Part 1: Complexity Hierarchy
### Standard Self-Attention (Baseline)
**Formula:**
```python
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
```
**Complexity:**
- Time: O(n² · d) where n = seq_len, d = d_model
- Memory: O(n²) for attention matrix
- Exact: Yes (no approximation)
**Memory breakdown (4k tokens, d=768):**
```
Attention scores (per head): 4096² × 4 bytes = 64MB
Multi-head (12 heads): 64MB × 12 = 768MB per layer
16 layers: 768MB × 16 = 12GB just for attention!
Batch size 8: 12GB × 8 = 96GB (impossible on single GPU)
```
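To rerun this arithmetic for your own configuration, a back-of-envelope script (fp32 scores, counting only the attention matrices) looks like this:
```python
# Back-of-envelope check of the numbers above (fp32 scores, attention matrices only)
seq_len, heads, layers, batch = 4096, 12, 16, 8
per_head = seq_len ** 2 * 4                       # bytes for one head's score matrix
total = per_head * heads * layers * batch
print(f"{per_head / 2**20:.0f} MB per head, {total / 2**30:.0f} GB total")  # 64 MB, 96 GB
```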
**When to use:**
- Sequence length < 2k tokens
- Standard use case (most models)
- Pair with Flash Attention optimization
**Limitations:**
- Memory explosion for long sequences
- Quadratic scaling impractical beyond 4k tokens
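For reference, a minimal PyTorch sketch of the formula above; the `scores` tensor is exactly the O(n²) term that explodes:
```python
import torch

def standard_attention(Q, K, V):
    """Reference implementation. Q, K, V: (batch, heads, n, head_dim).
    The `scores` tensor is (batch, heads, n, n) -- the O(n²) memory term."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    return torch.softmax(scores, dim=-1) @ V

Q = torch.randn(8, 12, 1024, 64)   # (batch, heads, seq_len, head_dim)
K = torch.randn(8, 12, 1024, 64)
V = torch.randn(8, 12, 1024, 64)
out = standard_attention(Q, K, V)  # scores alone: 8 × 12 × 1024² × 4 bytes ≈ 0.4 GB
```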
## Part 2: Flash Attention ⭐ (Modern Default)
### What is Flash Attention?
**Breakthrough (2022):** Exact attention with 4x less memory, 2-3x faster
**Key insight:**
- Standard attention is **memory-bound** (not compute-bound)
- GPUs: Fast compute (TFLOPS), slow memory bandwidth (GB/s)
- Bottleneck: Moving n² attention matrix to/from HBM
**Solution:**
- Tile attention computation
- Recompute instead of store intermediate values
- Fuse operations (reduce memory transfers)
- Result: Same O(n²) compute, O(n) memory
### Algorithm
```
Standard attention (3 memory operations):
1. Compute scores: S = Q K^T (store n² matrix)
2. Softmax: P = softmax(S) (store n² matrix)
3. Output: O = P V (store n×d matrix)

Flash Attention (tiled):
1. Divide Q, K, V into blocks
2. For each Q block:
   - Load the block into SRAM (fast on-chip memory)
   - For each K, V block:
     - Compute attention for this tile
     - Update the output incrementally (online softmax)
   - Never materialize the full n² matrix!
3. Result: same output, O(n) memory
```
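The following is a pedagogical PyTorch-level sketch of the tiling idea with an online softmax. The real Flash Attention fuses this into a single CUDA kernel and also tiles over Q; block sizes and names here are illustrative:
```python
import torch

def tiled_attention(Q, K, V, block_size=64):
    """Educational sketch of Flash-Attention-style tiling (not the fused kernel).

    K/V are processed in blocks while a running max and normalizer maintain a
    numerically stable softmax, so the full (n x n) score matrix never exists.
    Q, K, V: (n, d); returns (n, d).
    """
    n, d = Q.shape
    scale = d ** -0.5
    out = torch.zeros_like(Q)                       # running (unnormalized) output
    row_max = torch.full((n, 1), float("-inf"))     # running max per query row
    row_sum = torch.zeros(n, 1)                     # running softmax normalizer

    for start in range(0, n, block_size):
        Kb = K[start:start + block_size]            # (b, d)
        Vb = V[start:start + block_size]            # (b, d)
        scores = (Q @ Kb.T) * scale                 # (n, b) -- only one block wide

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)   # rescale old accumulators
        p = torch.exp(scores - new_max)             # (n, b)

        out = out * correction + p @ Vb
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max

    return out / row_sum

# Sanity check against standard attention on a small example
Q, K, V = torch.randn(3, 128, 32).unbind(0)
ref = torch.softmax((Q @ K.T) * 32 ** -0.5, dim=-1) @ V
assert torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-4)
```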
### Performance
**Benchmarks (A100 GPU, 2k tokens):**
Standard attention:
- Memory: 4GB for batch_size=8
- Speed: 150ms/batch
- Max batch size: 16
Flash Attention:
- Memory: 1GB for batch_size=8 **(4x reduction)**
- Speed: 75ms/batch **(2x faster)**
- Max batch size: 64 **(4x larger)**
**Flash Attention 2 (2023 update):**
- Further optimized: 2-3x faster than Flash Attention 1
- Better parallelism
- Supports more head dimensions
### When to Use
**ALWAYS use Flash Attention when:**
- Sequence length < 16k tokens
- Need exact attention (no approximation)
- Available in your framework
**It's a FREE LUNCH:**
- No accuracy loss (mathematically exact)
- Faster training AND inference
- Less memory usage
- Drop-in replacement
### Implementation
**PyTorch 2.0+ (built-in):**
```python
import torch.nn.functional as F
# Automatic Flash Attention (if available)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
)
# PyTorch automatically uses Flash Attention if:
# - CUDA available
# - Sequence length suitable
# - No attention mask (or causal mask)
```
**HuggingFace Transformers:**
```python
import torch
from transformers import AutoModel

# Enable Flash Attention 2
model = AutoModel.from_pretrained(
    "bert-base-uncased",
    attn_implementation="flash_attention_2",  # Requires the flash-attn package
    torch_dtype=torch.float16,
)
```
**Manual installation:**
```bash
pip install flash-attn --no-build-isolation
```
### Limitations
**Flash Attention NOT suitable when:**
- Sequence length > 16k (memory is O(n), but compute still scales quadratically, so it becomes slow)
- Custom attention masks (complex patterns not supported)
- Inference on CPU (CUDA-only)
**For > 16k tokens:** Use sparse or linear attention
## Part 3: Sparse Attention (Exact for Long Sequences)
### Concept
**Idea:** Each token attends to subset of tokens (not all)
- Sliding window: Local context
- Global tokens: Long-range connections
- Result: O(n × window_size) instead of O(n²)
**Key property:** Still EXACT attention (not approximate)
- Just more structured attention pattern
- No accuracy loss if pattern matches task
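As a concrete illustration (framework-agnostic; names are illustrative), a sliding-window + global boolean mask can be built like this:
```python
import torch

def sliding_window_mask(seq_len, window, global_idx=(0,)):
    """Boolean mask (True = may attend): each token sees ±window//2 neighbours;
    global tokens attend to, and are attended by, every position."""
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window // 2   # sliding window
    for g in global_idx:
        mask[g, :] = True    # global token attends to all tokens
        mask[:, g] = True    # all tokens attend to the global token
    return mask

mask = sliding_window_mask(seq_len=4096, window=512)
# ~4096 × 512 allowed positions instead of 4096² -- the O(n × w) pattern
```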
### Variant 1: Longformer
**Pattern:** Sliding window + global attention
```
Attention pattern (window=2, global=[0]):
0 1 2 3 4 5
0 [ 1 1 1 1 1 1 ] ← Global token (attends to all)
1 [ 1 1 1 0 0 0 ] ← Window: tokens 0-2
2 [ 1 1 1 1 0 0 ] ← Window: tokens 1-3
3 [ 1 0 1 1 1 0 ] ← Window: tokens 2-4
4 [ 1 0 0 1 1 1 ] ← Window: tokens 3-5
5 [ 1 0 0 0 1 1 ] ← Window: tokens 4-5
Complexity: O(n × (window + num_global))
```
**Components:**
1. **Sliding window**: Each token attends to w/2 tokens before and after
2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens
3. **Dilated windows**: Optional (stride > 1 for longer context)
**Implementation:**
```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
# Local attention uses the standard attention mask;
# global attention is set with a separate mask (1 = global, 0 = local only)
attention_mask = torch.ones(batch_size, seq_len)
global_attention_mask = torch.zeros(batch_size, seq_len)
global_attention_mask[:, 0] = 1   # [CLS] token gets global attention
output = model(input_ids, attention_mask=attention_mask,
               global_attention_mask=global_attention_mask)
```
**Memory comparison (4k tokens, window=512):**
```
Standard: 4096² = 16M elements → 64MB
Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!)
```
**When to use:**
- Documents: 4k-16k tokens (legal, scientific papers)
- Need full context but can't fit O(n²)
- Task has local + global structure
**Pretrained models:**
- `allenai/longformer-base-4096`: Max 4096 tokens
- `allenai/longformer-large-4096`: Larger version
### Variant 2: BigBird
**Pattern:** Random + window + global
```
Attention pattern:
- Sliding window: Like Longformer
- Random connections: Each token attends to r random tokens
- Global tokens: Special tokens attend to all
Complexity: O(n × (window + r + num_global))
```
**Key difference from Longformer:**
- Random connections help information flow
- Theoretically proven to approximate full attention
**When to use:**
- Similar to Longformer
- Slightly better for tasks needing long-range
- Less widely adopted than Longformer
**Implementation:**
```python
from transformers import BigBirdModel
model = BigBirdModel.from_pretrained(
"google/bigbird-roberta-base",
attention_type="block_sparse" # or "original_full"
)
```
### Sparse Attention Decision
```
Sequence length < 4k:
  → Flash Attention (exact, no pattern needed)

Sequence length 4k-16k:
  → Longformer (sliding window + global)
    Best for: documents, long-form text

Sequence length > 16k:
  → Longformer if possible
  → Linear attention if Longformer is too slow
```
## Part 4: Linear Attention (Approximate for Very Long)
### Concept
**Idea:** Approximate softmax attention with linear operations
- Complexity: O(n × k) where k << n
- Trade-off: 1-3% accuracy loss
- Benefit: Can handle very long sequences (> 16k)
**Key property:** APPROXIMATE (not exact)
- Do NOT use if accuracy critical
- Good for extremely long sequences where exact is impossible
### Variant 1: Performer
**Method:** Random Fourier Features to approximate softmax(Q K^T)
**Formula:**
```python
# Standard attention
Attention(Q, K, V) = softmax(Q K^T) V
# Performer approximation: φ(Q) φ(K)^T ≈ softmax(Q K^T)
Attention(Q, K, V) ≈ φ(Q) (φ(K)^T V)
# Complexity: O(n × k) where k = feature dimension
```
**Key trick:**
- Compute φ(K)^T V first: (k × d) matrix (small!)
- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d)
- Never materialize n² attention matrix
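A minimal sketch of this associativity trick, using the simple `elu(x) + 1` feature map from kernelized linear attention as a stand-in for Performer's FAVOR+ random features:
```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    """Kernelized linear attention with phi(x) = elu(x) + 1 (a positive feature map).
    Computing phi(K)^T V first keeps everything O(n); no (n × n) matrix is built.
    Q, K: (n, d_k), V: (n, d_v)."""
    Qf, Kf = F.elu(Q) + 1, F.elu(K) + 1               # φ(Q), φ(K): (n, d_k)
    kv = Kf.T @ V                                     # φ(K)^T V: (d_k, d_v) -- small
    normalizer = Qf @ Kf.sum(dim=0, keepdim=True).T   # (n, 1)
    return (Qf @ kv) / normalizer                     # (n, d_v)

Q, K, V = torch.randn(3, 16384, 64).unbind(0)
out = linear_attention(Q, K, V)   # never materializes a 16384 × 16384 matrix
```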
**Implementation:**
```python
# From performer-pytorch library
from performer_pytorch import Performer
model = Performer(
dim=512,
depth=6,
heads=8,
dim_head=64,
causal=False,
nb_features=256 # k = number of random features
)
```
**Accuracy:**
- Typical loss: 1-2% vs standard attention
- Depends on nb_features (more features = better approximation)
- k=256 usually sufficient
**When to use:**
- Sequence length > 16k tokens
- Accuracy loss acceptable (not critical task)
- Need better than sparse attention (no structure assumptions)
### Variant 2: Linformer
**Method:** Project K and V to lower dimension
**Formula:**
```python
# Standard attention (n × n attention matrix)
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

# Linformer (project K, V along the sequence dimension from n down to k, k << n)
K_proj = E K   # E: (k × n) projection matrix
V_proj = F V   # F: (k × n) projection matrix
Attention(Q, K, V) ≈ softmax(Q K_proj^T / sqrt(d_k)) V_proj

# Attention matrix: (n × k) instead of (n × n)
```
**Complexity:**
- Time: O(n × k × d) where k << n
- Memory: O(n × k) instead of O(n²)
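A single-head sketch of the projection idea (hypothetical module; the real Linformer shares the E/F projections across heads and layers):
```python
import torch
import torch.nn as nn

class LinformerSelfAttention(nn.Module):
    """Single-head sketch of Linformer-style projection, assuming a fixed seq_len."""

    def __init__(self, d_model, seq_len, k=256):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        # Learned projections along the sequence dimension: n -> k
        self.E = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.F = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.scale = d_model ** -0.5

    def forward(self, x):                                    # x: (batch, n, d_model)
        Q, K, V = self.to_q(x), self.to_k(x), self.to_v(x)
        K_proj = self.E @ K                                  # (batch, k, d_model)
        V_proj = self.F @ V                                  # (batch, k, d_model)
        scores = Q @ K_proj.transpose(-2, -1) * self.scale   # (batch, n, k), not (n, n)
        return torch.softmax(scores, dim=-1) @ V_proj

attn = LinformerSelfAttention(d_model=512, seq_len=8192, k=256)
out = attn(torch.randn(2, 8192, 512))
```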
**Implementation:**
```python
# From linformer library
from linformer import Linformer
model = Linformer(
dim=512,
seq_len=8192,
depth=12,
heads=8,
k=256 # Projected dimension
)
```
**Accuracy:**
- Typical loss: 1-3% vs standard attention
- More loss than Performer
- Fixed sequence length (the E/F projection matrices are tied to max_seq_len)
**When to use:**
- Fixed-length long sequences
- Memory more critical than speed
- Accuracy loss OK (2-3%)
### Linear Attention Decision
```
Need exact attention:
  → Flash Attention or Sparse Attention (NOT linear)

Sequence > 16k, accuracy critical:
  → Sparse Attention (Longformer)

Sequence > 16k, accuracy loss OK:
  → Performer (better) or Linformer

Sequence > 100k:
  → State space models (S4, Mamba), not attention
```
## Part 5: Cross-Attention (Multimodal)
### Concept
**Self-attention:** Q, K, V from same source
**Cross-attention:** Q from one source, K/V from another
**Use cases:**
- Multimodal: vision → language (image captioning)
- Seq2seq: source language → target language (translation)
- RAG: query → document retrieval
- Conditioning: generation conditioned on context
### Architecture
```python
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, query_source, key_value_source, mask=None):
        # query_source:     (batch, n_q, d_model)  - e.g., text tokens
        # key_value_source: (batch, n_kv, d_model) - e.g., image patches
        # Q comes from the query source; K and V from the key-value source
        output, _ = self.mha(
            query=query_source,
            key=key_value_source,
            value=key_value_source,
            key_padding_mask=mask,
        )
        return output   # (batch, n_q, d_model)
```
### Example: Image Captioning
**Task:** Generate caption from image
**Architecture:**
1. **Image Encoder:** ViT processes image → image features (n_patches × d)
2. **Text Decoder:** Autoregressive text generation
3. **Cross-Attention:** Text queries image features
```python
class ImageCaptioningDecoder(nn.Module):
    # Sketch: assumes self.text_self_attention, self.cross_attention and
    # self.feed_forward are attention/FFN sub-layers defined in __init__
    def forward(self, text, image_features):
        # 1. Self-attention on text (causal)
        text = self.text_self_attention(
            query=text,
            key=text,
            value=text,
            causal_mask=True        # Don't see future words
        )
        # 2. Cross-attention (text queries image)
        text = self.cross_attention(
            query=text,             # From text decoder
            key=image_features,     # From image encoder
            value=image_features    # From image encoder
            # No causal mask! Can attend to all image patches
        )
        # 3. Feed-forward
        text = self.feed_forward(text)
        return text
```
**Attention flow:**
- Text token "cat" → High attention to cat region in image
- Text token "sitting" → High attention to posture in image
### Example: Retrieval-Augmented Generation (RAG)
**Task:** Generate answer using retrieved documents
```python
class RAGDecoder(nn.Module):
    # Sketch: assumes self.query_self_attention and self.cross_attention
    # are attention sub-layers defined in __init__
    def forward(self, query, document_embeddings):
        # 1. Self-attention on the query tokens
        query = self.query_self_attention(query, query, query)
        # 2. Cross-attention (query → documents)
        query = self.cross_attention(
            query=query,                  # What we're generating
            key=document_embeddings,      # Retrieved docs
            value=document_embeddings     # Retrieved docs
        )
        # The query learns to extract relevant info from the docs
        return query
```
### When to Use Cross-Attention
**Use cross-attention when:**
- Two different modalities (vision + language)
- Conditioning generation on context (RAG)
- Seq2seq with different input/output (translation)
- Query-document matching
**Don't use cross-attention when:**
- Same modality (use self-attention)
- No clear query vs key-value separation
## Part 6: Other Attention Variants
### Axial Attention (2D Images)
**Idea:** For 2D data (images), attend along each axis separately
```
Standard 2D attention: H×W tokens → (HW)² attention matrix
Axial attention:
- Row attention: Each row attends to itself (H × W²)
- Column attention: Each column attends to itself (W × H²)
- Total: O(HW × (H + W)) << O((HW)²)
```
**When to use:**
- High-resolution images
- 2D positional structure important
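A minimal single-head sketch (Q = K = V, shapes assumed as commented) of attending along rows and then columns:
```python
import torch
import torch.nn.functional as F

def axial_attention(x):
    """Single-head axial attention sketch with Q = K = V = x.
    x: (batch, H, W, d). The largest score matrix is (W × W) or (H × H),
    never (HW × HW)."""
    b, H, W, d = x.shape
    # Row attention: each of the H rows is an independent length-W sequence
    rows = x.reshape(b * H, W, d)
    rows = F.scaled_dot_product_attention(rows, rows, rows)
    x = rows.reshape(b, H, W, d)
    # Column attention: each of the W columns is an independent length-H sequence
    cols = x.transpose(1, 2).reshape(b * W, H, d)
    cols = F.scaled_dot_product_attention(cols, cols, cols)
    return cols.reshape(b, W, H, d).transpose(1, 2)

out = axial_attention(torch.randn(2, 64, 64, 128))   # no (4096 × 4096) score matrix
```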
### Block-Sparse Attention
**Idea:** Divide attention into blocks, attend only within/across blocks
**Pattern:**
```
Block size = 64 tokens
- Local block: Attend within same block
- Vertical stripe: Attend to corresponding position in other blocks
```
**Used in:** Sparse Transformer (OpenAI), GPT-3
### Multi-Query Attention (MQA)
**Idea:** One K/V head shared across all Q heads
**Benefit:**
- Smaller KV cache during inference
- Much faster decoding (4-8x)
- Trade-off: ~1% accuracy loss
**Used in:** PaLM, Falcon
### Grouped-Query Attention (GQA)
**Idea:** Middle ground between multi-head and multi-query
- Group Q heads share K/V heads
- Example: 32 Q heads → 8 K/V heads (4:1 ratio)
**Benefit:**
- 4x smaller KV cache
- Minimal accuracy loss (< 0.5%)
**Used in:** LLaMA-2, Mistral
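A minimal sketch of the K/V head sharing at the core of GQA (`num_kv_heads = 1` recovers MQA; equal head counts recover standard multi-head attention):
```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """Sketch of GQA: groups of query heads share one K/V head.
    q: (batch, num_q_heads, seq, head_dim); k, v: (batch, num_kv_heads, seq, head_dim)."""
    group_size = q.shape[1] // k.shape[1]
    # Expand the shared K/V heads so every query head in a group sees the same K/V
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = torch.randn(1, 32, 128, 64)         # 32 query heads
k = torch.randn(1, 8, 128, 64)          # 8 K/V heads → 4x smaller KV cache
v = torch.randn(1, 8, 128, 64)
out = grouped_query_attention(q, k, v)  # (1, 32, 128, 64)
```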
## Part 7: Decision Framework
### By Sequence Length
```
< 2k tokens:
  → Flash Attention
    Exact, fast, standard

2k-4k tokens:
  → Flash Attention
    Still manageable with modern GPUs

4k-16k tokens:
  → Sparse Attention (Longformer, BigBird)
    Exact, designed for documents
  → OR Flash Attention if batch size = 1

> 16k tokens:
  → Sparse Attention
    If task has local structure
  → Linear Attention (Performer)
    If accuracy loss OK (1-2%)
  → State Space Models (S4, Mamba)
    If sequence > 100k
```
### By Memory Constraints
```
GPU OOM with standard attention:
  1. Try Flash Attention (4x less memory, free lunch)
  2. If still OOM, reduce batch size
  3. If batch size = 1 and still OOM, use sparse attention
  4. Last resort: linear attention (if accuracy loss is OK)

DON'T:
  - Reach for gradient checkpointing first (slower; use Flash Attention instead)
  - Throw more GPUs at it (it's an algorithmic problem, not a hardware one)
```
### By Accuracy Requirements
```
Must be exact (no approximation):
  → Flash Attention or Sparse Attention
    Never use linear attention!

Accuracy loss acceptable (1-3%):
  → Linear Attention (Performer, Linformer)
    Only for very long sequences (> 16k)

Critical task (medical, legal):
  → Exact attention only
    Flash Attention or Sparse Attention
```
### By Task Type
```
Classification / understanding:
  → Standard + Flash Attention
    Sequence usually < 2k

Document processing:
  → Longformer (4096 tokens)
    Designed for documents

Generation (LLM):
  → Flash Attention for training
  → + GQA/MQA for inference (faster decoding)

Multimodal (vision + language):
  → Cross-attention for modality fusion
  → Self-attention within each modality

Retrieval-augmented:
  → Cross-attention (query → documents)
```
## Part 8: Implementation Checklist
### Using Flash Attention
**PyTorch 2.0+:**
```python
# Automatic (recommended)
output = F.scaled_dot_product_attention(query, key, value)
# Verify Flash Attention is used
import torch.backends.cuda
print(torch.backends.cuda.flash_sdp_enabled()) # Should be True
```
**HuggingFace:**
```python
model = AutoModel.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16 # Flash Attention needs fp16/bf16
)
```
**Requirements:**
- CUDA GPU (not CPU)
- PyTorch >= 2.0 OR flash-attn package
- fp16 or bf16 dtype (not fp32)
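To confirm the Flash kernel is actually selected, you can disable the fallback backends so the call fails loudly if Flash can't run. The context manager below is the PyTorch 2.0/2.1 API (`torch.backends.cuda.sdp_kernel`); newer releases expose `torch.nn.attention.sdpa_kernel` instead:
```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)

# Disable the fallback backends so this raises if the Flash kernel can't be used
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
```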
### Using Sparse Attention
**Longformer:**
```python
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

# Local attention via attention_mask; global attention via global_attention_mask
attention_mask = torch.ones(batch_size, seq_len)
global_attention_mask = torch.zeros(batch_size, seq_len)
global_attention_mask[:, 0] = 1   # [CLS] token gets global attention

outputs = model(input_ids, attention_mask=attention_mask,
                global_attention_mask=global_attention_mask)
```
**Custom sparse pattern:**
```python
import torch

# Create a custom block-sparse mask (local blocks only)
def create_block_sparse_mask(seq_len, block_size):
    num_blocks = seq_len // block_size
    mask = torch.zeros(seq_len, seq_len)
    for i in range(num_blocks):
        start = i * block_size
        end = start + block_size
        mask[start:end, start:end] = 1   # Attend within the local block
    return mask
```
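The resulting 0/1 mask can be passed to `F.scaled_dot_product_attention` as a boolean mask (True = may attend). Note that an explicit `attn_mask` usually makes PyTorch fall back from the Flash kernel to the memory-efficient or math backend:
```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 1024, 64)
mask = create_block_sparse_mask(seq_len=1024, block_size=64).bool()   # True = may attend
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)         # mask broadcasts over heads
```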
### Using Cross-Attention
```python
# Assumes a MultiHeadAttention module (e.g., from `transformer-architecture-deepdive`)
class DecoderWithCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        # Self-attention (causal)
        x = self.self_attn(
            query=decoder_input,
            key=decoder_input,
            value=decoder_input,
            mask=causal_mask,
        )
        # Cross-attention (Q from decoder, K/V from encoder)
        x = self.cross_attn(
            query=x,                 # From decoder
            key=encoder_output,      # From encoder
            value=encoder_output,    # From encoder
            mask=None,               # No causal mask for cross-attention!
        )
        return x
```
## Part 9: Common Mistakes
### Mistake 1: Ignoring Flash Attention
**Symptom:** Training slow, high memory usage
**Fix:** Always use Flash Attention for < 16k tokens
### Mistake 2: Using Linear Attention Unnecessarily
**Symptom:** 1-3% accuracy loss for no reason
**Fix:** Use Flash Attention (exact) unless sequence > 16k
### Mistake 3: Gradient Checkpointing Instead of Flash Attention
**Symptom:** Training 20% slower
**Fix:** Flash Attention gives memory savings AND speed
### Mistake 4: Cross-Attention with Causal Mask
**Symptom:** Decoder can't attend to encoder properly
**Fix:** Causal mask only for self-attention, NOT cross-attention
### Mistake 5: Accepting O(n²) Memory
**Symptom:** GPU OOM for > 4k tokens
**Fix:** Use sparse or Flash Attention, don't just add GPUs
## Summary: Quick Reference
### Attention Selection
```
Sequence length:
  < 2k   → Flash Attention (default)
  2-4k   → Flash Attention
  4-16k  → Longformer (documents) or Flash Attention (batch=1)
  > 16k  → Sparse or Linear Attention

Memory constrained:
  First: try Flash Attention (4x less memory)
  Still OOM: use sparse attention (Longformer)
  Last resort: linear attention (accuracy loss)

Speed critical:
  Training: Flash Attention (2x faster)
  Inference: Flash Attention + GQA/MQA

Accuracy critical:
  Use exact attention only (Flash or Sparse)
  NEVER linear attention

Multimodal:
  Cross-attention for modality fusion
```
### Implementation
```
PyTorch 2.0+:
  F.scaled_dot_product_attention()   # Auto Flash Attention

HuggingFace:
  attn_implementation="flash_attention_2"

Longformer:
  LongformerModel.from_pretrained("allenai/longformer-base-4096")

Custom:
  Inherit from nn.Module, implement forward()
```
## Next Steps
After mastering this skill:
- `llm-specialist/llm-inference-optimization`: Apply attention optimizations to inference
- `llm-specialist/context-window-management`: Manage long contexts in LLMs
- `architecture-design-principles`: Understand broader design trade-offs
**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns).