Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 09:00:00 +08:00
commit 955d5c6743
12 changed files with 6996 additions and 0 deletions

View File

@@ -0,0 +1,824 @@
# Attention Mechanisms Catalog
## When to Use This Skill
Use this skill when you need to:
- ✅ Select attention mechanism for long sequences (> 2k tokens)
- ✅ Optimize memory usage (GPU OOM errors)
- ✅ Speed up training or inference
- ✅ Understand exact vs approximate attention trade-offs
- ✅ Choose between Flash, sparse, or linear attention
- ✅ Implement cross-attention for multimodal models
**Do NOT use this skill for:**
- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`)
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`)
## Core Principle
**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve:
- **O(n²) with less memory**: Flash Attention (exact, 4x less memory)
- **O(n × w)**: Sparse attention (exact, sliding window)
- **O(n)**: Linear attention (approximate, 1-3% accuracy loss)
**Default recommendation:** Flash Attention (exact + fast + memory-efficient)
## Part 1: Complexity Hierarchy
### Standard Self-Attention (Baseline)
**Formula:**
```python
Attention(Q, K, V) = softmax(Q K^T / d_k) V
```
**Complexity:**
- Time: O(n² · d) where n = seq_len, d = d_model
- Memory: O(n²) for attention matrix
- Exact: Yes (no approximation)
**Memory breakdown (4k tokens, d=768):**
```
Attention scores: 4096² × 4 bytes = 64MB per layer
Multi-head (12 heads): 64MB × 12 = 768MB per layer
16 layers: 768MB × 16 = 12GB just for attention!
Batch size 8: 12GB × 8 = 96GB (impossible on single GPU)
```
**When to use:**
- Sequence length < 2k tokens
- Standard use case (most models)
- Pair with Flash Attention optimization
**Limitations:**
- Memory explosion for long sequences
- Quadratic scaling impractical beyond 4k tokens
## Part 2: Flash Attention ⭐ (Modern Default)
### What is Flash Attention?
**Breakthrough (2022):** Exact attention with 4x less memory, 2-3x faster
**Key insight:**
- Standard attention is **memory-bound** (not compute-bound)
- GPUs: Fast compute (TFLOPS), slow memory bandwidth (GB/s)
- Bottleneck: Moving n² attention matrix to/from HBM
**Solution:**
- Tile attention computation
- Recompute instead of store intermediate values
- Fuse operations (reduce memory transfers)
- Result: Same O(n²) compute, O(n) memory
### Algorithm
```
Standard attention (3 memory operations):
1. Compute scores: S = Q K^T (store n² matrix)
2. Softmax: P = softmax(S) (store n² matrix)
3. Output: O = P V (store n×d matrix)
Flash Attention (tiled):
1. Divide Q, K, V into blocks
2. For each Q block:
- Load block to SRAM (fast memory)
- For each K, V block:
- Compute attention for this tile
- Update output incrementally
- Never materialize full n² matrix!
3. Result: Same output, O(n) memory
```
### Performance
**Benchmarks (A100 GPU, 2k tokens):**
Standard attention:
- Memory: 4GB for batch_size=8
- Speed: 150ms/batch
- Max batch size: 16
Flash Attention:
- Memory: 1GB for batch_size=8 **(4x reduction)**
- Speed: 75ms/batch **(2x faster)**
- Max batch size: 64 **(4x larger)**
**Flash Attention 2 (2023 update):**
- Further optimized: 2-3x faster than Flash Attention 1
- Better parallelism
- Supports more head dimensions
### When to Use
**ALWAYS use Flash Attention when:**
- Sequence length < 16k tokens
- Need exact attention (no approximation)
- Available in your framework
**It's a FREE LUNCH:**
- No accuracy loss (mathematically exact)
- Faster training AND inference
- Less memory usage
- Drop-in replacement
### Implementation
**PyTorch 2.0+ (built-in):**
```python
import torch.nn.functional as F
# Automatic Flash Attention (if available)
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
# PyTorch automatically uses Flash Attention if:
# - CUDA available
# - Sequence length suitable
# - No attention mask (or causal mask)
```
**HuggingFace Transformers:**
```python
from transformers import AutoModel
# Enable Flash Attention 2
model = AutoModel.from_pretrained(
"bert-base-uncased",
attn_implementation="flash_attention_2", # Requires flash-attn package
torch_dtype=torch.float16
)
```
**Manual installation:**
```bash
pip install flash-attn --no-build-isolation
```
### Limitations
**Flash Attention NOT suitable when:**
- Sequence length > 16k (memory still grows quadratically)
- Custom attention masks (complex patterns not supported)
- Inference on CPU (CUDA-only)
**For > 16k tokens:** Use sparse or linear attention
## Part 3: Sparse Attention (Exact for Long Sequences)
### Concept
**Idea:** Each token attends to subset of tokens (not all)
- Sliding window: Local context
- Global tokens: Long-range connections
- Result: O(n × window_size) instead of O(n²)
**Key property:** Still EXACT attention (not approximate)
- Just more structured attention pattern
- No accuracy loss if pattern matches task
### Variant 1: Longformer
**Pattern:** Sliding window + global attention
```
Attention pattern (window=2, global=[0]):
0 1 2 3 4 5
0 [ 1 1 1 1 1 1 ] ← Global token (attends to all)
1 [ 1 1 1 0 0 0 ] ← Window: tokens 0-2
2 [ 1 1 1 1 0 0 ] ← Window: tokens 1-3
3 [ 1 0 1 1 1 0 ] ← Window: tokens 2-4
4 [ 1 0 0 1 1 1 ] ← Window: tokens 3-5
5 [ 1 0 0 0 1 1 ] ← Window: tokens 4-5
Complexity: O(n × (window + num_global))
```
**Components:**
1. **Sliding window**: Each token attends to w/2 tokens before and after
2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens
3. **Dilated windows**: Optional (stride > 1 for longer context)
**Implementation:**
```python
from transformers import LongformerModel
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
# Attention mask (shape: batch, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 0] = 2 # 2 = global attention for [CLS] token
output = model(input_ids, attention_mask=attention_mask)
```
**Memory comparison (4k tokens, window=512):**
```
Standard: 4096² = 16M elements → 64MB
Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!)
```
**When to use:**
- Documents: 4k-16k tokens (legal, scientific papers)
- Need full context but can't fit O(n²)
- Task has local + global structure
**Pretrained models:**
- `allenai/longformer-base-4096`: Max 4096 tokens
- `allenai/longformer-large-4096`: Larger version
### Variant 2: BigBird
**Pattern:** Random + window + global
```
Attention pattern:
- Sliding window: Like Longformer
- Random connections: Each token attends to r random tokens
- Global tokens: Special tokens attend to all
Complexity: O(n × (window + r + num_global))
```
**Key difference from Longformer:**
- Random connections help information flow
- Theoretically proven to approximate full attention
**When to use:**
- Similar to Longformer
- Slightly better for tasks needing long-range
- Less widely adopted than Longformer
**Implementation:**
```python
from transformers import BigBirdModel
model = BigBirdModel.from_pretrained(
"google/bigbird-roberta-base",
attention_type="block_sparse" # or "original_full"
)
```
### Sparse Attention Decision
```
Sequence length < 4k:
→ Flash Attention (exact, no pattern needed)
Sequence length 4k-16k:
→ Longformer (sliding window + global)
→ Best for: Documents, long-form text
Sequence length > 16k:
→ Longformer if possible
→ Linear attention if Longformer too slow
```
## Part 4: Linear Attention (Approximate for Very Long)
### Concept
**Idea:** Approximate softmax attention with linear operations
- Complexity: O(n × k) where k << n
- Trade-off: 1-3% accuracy loss
- Benefit: Can handle very long sequences (> 16k)
**Key property:** APPROXIMATE (not exact)
- Do NOT use if accuracy critical
- Good for extremely long sequences where exact is impossible
### Variant 1: Performer
**Method:** Random Fourier Features to approximate softmax(Q K^T)
**Formula:**
```python
# Standard attention
Attention(Q, K, V) = softmax(Q K^T) V
# Performer approximation
φ(Q) φ(K)^T softmax(Q K^T)
Attention(Q, K, V) φ(Q) (φ(K)^T V)
# Complexity: O(n × k) where k = feature dimension
```
**Key trick:**
- Compute φ(K)^T V first: (k × d) matrix (small!)
- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d)
- Never materialize n² attention matrix
**Implementation:**
```python
# From performer-pytorch library
from performer_pytorch import Performer
model = Performer(
dim=512,
depth=6,
heads=8,
dim_head=64,
causal=False,
nb_features=256 # k = number of random features
)
```
**Accuracy:**
- Typical loss: 1-2% vs standard attention
- Depends on nb_features (more features = better approximation)
- k=256 usually sufficient
**When to use:**
- Sequence length > 16k tokens
- Accuracy loss acceptable (not critical task)
- Need better than sparse attention (no structure assumptions)
### Variant 2: Linformer
**Method:** Project K and V to lower dimension
**Formula:**
```python
# Standard attention (n × n attention matrix)
Attention(Q, K, V) = softmax(Q K^T / d_k) V
# Linformer (project K, V to n × k where k << n)
K_proj = E K # E: (k × n) projection matrix
V_proj = F V # F: (k × n) projection matrix
Attention(Q, K, V) softmax(Q K_proj^T / d_k) V_proj
# Attention matrix: (n × k) instead of (n × n)
```
**Complexity:**
- Time: O(n × k × d) where k << n
- Memory: O(n × k) instead of O(n²)
**Implementation:**
```python
# From linformer library
from linformer import Linformer
model = Linformer(
dim=512,
seq_len=8192,
depth=12,
heads=8,
k=256 # Projected dimension
)
```
**Accuracy:**
- Typical loss: 1-3% vs standard attention
- More loss than Performer
- Fixed sequence length (k is tied to max_seq_len)
**When to use:**
- Fixed-length long sequences
- Memory more critical than speed
- Accuracy loss OK (2-3%)
### Linear Attention Decision
```
Need exact attention:
→ Flash Attention or Sparse Attention (NOT linear)
Sequence > 16k, accuracy critical:
→ Sparse Attention (Longformer)
Sequence > 16k, accuracy loss OK:
→ Performer (better) or Linformer
Sequence > 100k:
→ State space models (S4, Mamba, not attention)
```
## Part 5: Cross-Attention (Multimodal)
### Concept
**Self-attention:** Q, K, V from same source
**Cross-attention:** Q from one source, K/V from another
**Use cases:**
- Multimodal: vision → language (image captioning)
- Seq2seq: source language → target language (translation)
- RAG: query → document retrieval
- Conditioning: generation conditioned on context
### Architecture
```python
class CrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_heads)
def forward(self, query_source, key_value_source, mask=None):
# query_source: (batch, n_q, d_model) - e.g., text tokens
# key_value_source: (batch, n_kv, d_model) - e.g., image patches
# Q from query source
Q = self.W_q(query_source)
# K, V from key-value source
K = self.W_k(key_value_source)
V = self.W_v(key_value_source)
# Attention: (batch, n_q, d_model)
output = attention(Q, K, V, mask)
return output
```
### Example: Image Captioning
**Task:** Generate caption from image
**Architecture:**
1. **Image Encoder:** ViT processes image → image features (n_patches × d)
2. **Text Decoder:** Autoregressive text generation
3. **Cross-Attention:** Text queries image features
```python
class ImageCaptioningDecoder(nn.Module):
def forward(self, text_tokens, image_features):
# 1. Self-attention on text (causal)
text = self.text_self_attention(
query=text,
key=text,
value=text,
causal_mask=True # Don't see future words
)
# 2. Cross-attention (text queries image)
text = self.cross_attention(
query=text, # From text decoder
key=image_features, # From image encoder
value=image_features # From image encoder
# No causal mask! Can attend to all image patches
)
# 3. Feed-forward
text = self.feed_forward(text)
return text
```
**Attention flow:**
- Text token "cat" → High attention to cat region in image
- Text token "sitting" → High attention to posture in image
### Example: Retrieval-Augmented Generation (RAG)
**Task:** Generate answer using retrieved documents
```python
class RAGDecoder(nn.Module):
def forward(self, query_tokens, document_embeddings):
# 1. Self-attention on query
query = self.query_self_attention(query, query, query)
# 2. Cross-attention (query → documents)
query = self.cross_attention(
query=query, # What we're generating
key=document_embeddings, # Retrieved docs
value=document_embeddings # Retrieved docs
)
# Query learns to extract relevant info from docs
return query
```
### When to Use Cross-Attention
**Use cross-attention when:**
- Two different modalities (vision + language)
- Conditioning generation on context (RAG)
- Seq2seq with different input/output (translation)
- Query-document matching
**Don't use cross-attention when:**
- Same modality (use self-attention)
- No clear query vs key-value separation
## Part 6: Other Attention Variants
### Axial Attention (2D Images)
**Idea:** For 2D data (images), attend along each axis separately
```
Standard 2D attention: H×W tokens → (HW)² attention matrix
Axial attention:
- Row attention: Each row attends to itself (H × W²)
- Column attention: Each column attends to itself (W × H²)
- Total: O(HW × (H + W)) << O((HW)²)
```
**When to use:**
- High-resolution images
- 2D positional structure important
### Block-Sparse Attention
**Idea:** Divide attention into blocks, attend only within/across blocks
**Pattern:**
```
Block size = 64 tokens
- Local block: Attend within same block
- Vertical stripe: Attend to corresponding position in other blocks
```
**Used in:** Sparse Transformer (OpenAI), GPT-3
### Multi-Query Attention (MQA)
**Idea:** One K/V head shared across all Q heads
**Benefit:**
- Smaller KV cache during inference
- Much faster decoding (4-8x)
- Trade-off: ~1% accuracy loss
**Used in:** PaLM, Falcon
### Grouped-Query Attention (GQA)
**Idea:** Middle ground between multi-head and multi-query
- Group Q heads share K/V heads
- Example: 32 Q heads → 8 K/V heads (4:1 ratio)
**Benefit:**
- 4x smaller KV cache
- Minimal accuracy loss (< 0.5%)
**Used in:** LLaMA-2, Mistral
## Part 7: Decision Framework
### By Sequence Length
```
< 2k tokens:
→ Flash Attention
Exact, fast, standard
2k-4k tokens:
→ Flash Attention
Still manageable with modern GPUs
4k-16k tokens:
→ Sparse Attention (Longformer, BigBird)
Exact, designed for documents
→ OR Flash Attention if batch size = 1
> 16k tokens:
→ Sparse Attention
If task has local structure
→ Linear Attention (Performer)
If accuracy loss OK (1-2%)
→ State Space Models (S4, Mamba)
If sequence > 100k
```
### By Memory Constraints
```
GPU OOM with standard attention:
1. Try Flash Attention (4x less memory, free lunch)
2. If still OOM, reduce batch size
3. If batch size = 1 and still OOM, use sparse attention
4. Last resort: Linear attention (if accuracy loss OK)
DON'T:
- Gradient checkpointing (slower, use Flash Attention instead)
- Throwing more GPUs (algorithmic problem, not hardware)
```
### By Accuracy Requirements
```
Must be exact (no approximation):
→ Flash Attention or Sparse Attention
Never use linear attention!
Accuracy loss acceptable (1-3%):
→ Linear Attention (Performer, Linformer)
Only for very long sequences (> 16k)
Critical task (medical, legal):
→ Exact attention only
Flash Attention or Sparse Attention
```
### By Task Type
```
Classification / Understanding:
→ Standard + Flash Attention
Sequence usually < 2k
Document processing:
→ Longformer (4096 tokens)
Designed for documents
Generation (LLM):
→ Flash Attention for training
→ + GQA/MQA for inference (faster decoding)
Multimodal (vision + language):
→ Cross-attention for modality fusion
→ Self-attention within each modality
Retrieval-augmented:
→ Cross-attention (query → documents)
```
## Part 8: Implementation Checklist
### Using Flash Attention
**PyTorch 2.0+:**
```python
# Automatic (recommended)
output = F.scaled_dot_product_attention(query, key, value)
# Verify Flash Attention is used
import torch.backends.cuda
print(torch.backends.cuda.flash_sdp_enabled()) # Should be True
```
**HuggingFace:**
```python
model = AutoModel.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16 # Flash Attention needs fp16/bf16
)
```
**Requirements:**
- CUDA GPU (not CPU)
- PyTorch >= 2.0 OR flash-attn package
- fp16 or bf16 dtype (not fp32)
### Using Sparse Attention
**Longformer:**
```python
from transformers import LongformerModel, LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
# Attention mask
# 0 = no attention, 1 = local attention, 2 = global attention
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 0] = 2 # [CLS] token gets global attention
outputs = model(input_ids, attention_mask=attention_mask)
```
**Custom sparse pattern:**
```python
# Create custom block-sparse mask
def create_block_sparse_mask(seq_len, block_size):
num_blocks = seq_len // block_size
mask = torch.zeros(seq_len, seq_len)
for i in range(num_blocks):
start = i * block_size
end = start + block_size
mask[start:end, start:end] = 1 # Local block
return mask
```
### Using Cross-Attention
```python
class DecoderWithCrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
def forward(self, decoder_input, encoder_output, causal_mask=None):
# Self-attention (causal)
x = self.self_attn(
query=decoder_input,
key=decoder_input,
value=decoder_input,
mask=causal_mask
)
# Cross-attention (Q from decoder, K/V from encoder)
x = self.cross_attn(
query=x, # From decoder
key=encoder_output, # From encoder
value=encoder_output, # From encoder
mask=None # No causal mask for cross-attention!
)
return x
```
## Part 9: Common Mistakes
### Mistake 1: Ignoring Flash Attention
**Symptom:** Training slow, high memory usage
**Fix:** Always use Flash Attention for < 16k tokens
### Mistake 2: Using Linear Attention Unnecessarily
**Symptom:** 1-3% accuracy loss for no reason
**Fix:** Use Flash Attention (exact) unless sequence > 16k
### Mistake 3: Gradient Checkpointing Instead of Flash Attention
**Symptom:** Training 20% slower
**Fix:** Flash Attention gives memory savings AND speed
### Mistake 4: Cross-Attention with Causal Mask
**Symptom:** Decoder can't attend to encoder properly
**Fix:** Causal mask only for self-attention, NOT cross-attention
### Mistake 5: Accepting O(n²) Memory
**Symptom:** GPU OOM for > 4k tokens
**Fix:** Use sparse or Flash Attention, don't just add GPUs
## Summary: Quick Reference
### Attention Selection
```
Sequence length:
< 2k → Flash Attention (default)
2-4k → Flash Attention
4-16k → Longformer (documents) or Flash Attention (batch=1)
> 16k → Sparse or Linear Attention
Memory constrained:
First: Try Flash Attention (4x less memory)
Still OOM: Use sparse attention (Longformer)
Last resort: Linear attention (accuracy loss)
Speed critical:
Training: Flash Attention (2x faster)
Inference: Flash Attention + GQA/MQA
Accuracy critical:
Use exact attention only (Flash or Sparse)
NEVER linear attention
Multimodal:
Cross-attention for modality fusion
```
### Implementation
```
PyTorch 2.0+:
F.scaled_dot_product_attention() # Auto Flash Attention
HuggingFace:
attn_implementation="flash_attention_2"
Longformer:
LongformerModel.from_pretrained("allenai/longformer-base-4096")
Custom:
Inherit from nn.Module, implement forward()
```
## Next Steps
After mastering this skill:
- `llm-specialist/llm-inference-optimization`: Apply attention optimizations to inference
- `llm-specialist/context-window-management`: Manage long contexts in LLMs
- `architecture-design-principles`: Understand broader design trade-offs
**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns).