# Attention Mechanisms Catalog

## When to Use This Skill

Use this skill when you need to:

- ✅ Select an attention mechanism for long sequences (> 2k tokens)
- ✅ Optimize memory usage (GPU OOM errors)
- ✅ Speed up training or inference
- ✅ Understand exact vs approximate attention trade-offs
- ✅ Choose between Flash, sparse, or linear attention
- ✅ Implement cross-attention for multimodal models

**Do NOT use this skill for:**

- ❌ Basic Transformer understanding (use `transformer-architecture-deepdive`)
- ❌ High-level architecture selection (use `using-neural-architectures`)
- ❌ LLM-specific optimization (use `llm-specialist/llm-inference-optimization`)

## Core Principle

**Not all attention is O(n²).** Standard self-attention has quadratic complexity, but modern variants achieve:

- **O(n²) compute, O(n) memory**: Flash Attention (exact; never materializes the n×n matrix)
- **O(n × w)**: Sparse attention (exact softmax over a restricted sliding-window pattern)
- **O(n)**: Linear attention (approximate, typically 1-3% accuracy loss)

**Default recommendation:** Flash Attention (exact + fast + memory-efficient)

## Part 1: Complexity Hierarchy

### Standard Self-Attention (Baseline)

**Formula:**

```
Attention(Q, K, V) = softmax(Q K^T / √d_k) V
```

**Complexity:**

- Time: O(n² · d) where n = seq_len, d = d_model
- Memory: O(n²) for attention matrix
- Exact: Yes (no approximation)

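
For reference, the formula can be written out directly. The following is a minimal pure-Python sketch that assumes lists of row vectors and omits batching, heads, and learned projections. The helper names are hypothetical. It builds the full score matrix `S`, and that matrix is where the O(n²) memory cost comes from.

```python
import math

def softmax(row):
    # Subtract the max before exponentiating for numerical stability
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def standard_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v."""
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # Full n_q x n_k score matrix: the O(n^2) memory cost
    S = [[scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
    P = [softmax(row) for row in S]  # row-wise softmax
    d_v = len(V[0])
    # O = P V: each output row is a weighted average of V's rows
    return [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(d_v)] for p in P]

Q = K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(standard_attention(Q, K, V))
```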
**Memory breakdown (4k tokens, d=768, fp32 scores, one sequence):**

```
Attention scores (one head): 4096² × 4 bytes = 64MB
Multi-head (12 heads): 64MB × 12 = 768MB per layer
16 layers: 768MB × 16 = 12GB just for attention scores!
Batch size 8: 12GB × 8 = 96GB (impossible on single GPU)
```

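
A small helper can check the arithmetic above. The helper `attention_score_bytes` is a hypothetical name. It counts only one fp32 n×n score matrix per head, so it is a lower bound: training also keeps the softmax output and other activations. The "MB/GB" figures in the breakdown correspond to binary MiB/GiB here.

```python
def attention_score_bytes(seq_len, num_heads=1, num_layers=1, batch_size=1, bytes_per_elem=4):
    """Bytes to materialize one n x n score matrix per head (fp32 by default)."""
    return seq_len * seq_len * bytes_per_elem * num_heads * num_layers * batch_size

MiB, GiB = 2**20, 2**30
print(attention_score_bytes(4096) / MiB)                                          # 64.0 (one head)
print(attention_score_bytes(4096, num_heads=12) / MiB)                            # 768.0 (one layer)
print(attention_score_bytes(4096, num_heads=12, num_layers=16) / GiB)             # 12.0
print(attention_score_bytes(4096, num_heads=12, num_layers=16, batch_size=8) / GiB)  # 96.0
```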
**When to use:**

- Sequence length < 2k tokens
- Standard use case (most models)
- Pair with Flash Attention optimization

**Limitations:**

- Memory explosion for long sequences
- Quadratic memory becomes impractical beyond ~4k tokens without a memory-efficient kernel

## Part 2: Flash Attention ⭐ (Modern Default)

### What is Flash Attention?

**Breakthrough (2022):** Exact attention whose extra memory grows linearly instead of quadratically with sequence length, with reported speedups of roughly 2-3x

**Key insight:**

- Standard attention is **memory-bound** (not compute-bound)
- GPUs: Fast compute (TFLOPS), much slower access to main GPU memory (HBM)
- Bottleneck: Moving the n² attention matrix to/from HBM

**Solution:**

- Tile the attention computation so each block fits in fast on-chip SRAM
- Recompute attention in the backward pass instead of storing the n² intermediate matrices
- Fuse operations into one kernel (fewer memory transfers)
- Result: Same O(n²) compute, O(n) memory

### Algorithm

```
Standard attention (3 memory operations):
1. Compute scores: S = Q K^T (store n² matrix)
2. Softmax: P = softmax(S) (store n² matrix)
3. Output: O = P V (store n×d matrix)

Flash Attention (tiled):
1. Divide Q, K, V into blocks
2. For each Q block:
   - Load block to SRAM (fast on-chip memory)
   - For each K, V block:
     - Compute attention scores for this tile
     - Update the running row max and softmax denominator (online softmax)
     - Rescale and update the output incrementally
   - Never materialize the full n² matrix!
3. Result: Same output, O(n) memory
```
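
The tiled loop can be sketched in plain Python to show why the result is still exact. This is an illustration under simplifying assumptions, not the FlashAttention CUDA kernel: it handles one query at a time, uses Python lists, and `flash_attention_tiled` and its `block_size` parameter are hypothetical names. The key piece is the online softmax. A running max and denominator let each K/V block be folded in without ever storing all the scores.

```python
import math

def flash_attention_tiled(Q, K, V, block_size=2):
    """Exact softmax(Q K^T / sqrt(d_k)) V, folding in one K/V block at a time."""
    d_k, d_v = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d_k)
    out = []
    for q in Q:
        m = float("-inf")   # running max of the scores seen so far
        denom = 0.0         # running softmax denominator
        acc = [0.0] * d_v   # running unnormalized output
        for start in range(0, len(K), block_size):
            K_blk = K[start:start + block_size]
            V_blk = V[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
            m_new = max(m, max(scores))
            # Rescale the old state to the new max (exp(-inf) == 0.0 on the first block)
            correction = math.exp(m - m_new)
            denom *= correction
            acc = [a * correction for a in acc]
            for s, v in zip(scores, V_blk):
                w = math.exp(s - m_new)
                denom += w
                acc = [a + w * v_c for a, v_c in zip(acc, v)]
            m = m_new
        out.append([a / denom for a in acc])  # normalize once at the end
    return out
```

Changing `block_size` only reorders the work. The output stays the same up to floating-point rounding, which is why the tiled computation counts as exact attention.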

### Performance

**Illustrative benchmarks (A100 GPU, 2k tokens; exact figures vary with model size, head dimension, and kernel version):**

Standard attention:

- Memory: 4GB for batch_size=8
- Speed: 150ms/batch
- Max batch size: 16

Flash Attention:

- Memory: 1GB for batch_size=8 **(4x reduction)**
- Speed: 75ms/batch **(2x faster)**
- Max batch size: 64 **(4x larger)**

**Flash Attention 2 (2023 update):**

- Further optimized: roughly 2x faster than Flash Attention 1
- Better parallelism (also parallelizes over the sequence length)
- Supports more head dimensions (up to 256)

### When to Use

✅ **ALWAYS use Flash Attention when:**

- Sequence length < 16k tokens
- Need exact attention (no approximation)
- Available in your framework

**It's a FREE LUNCH:**

- No accuracy loss (mathematically exact; only floating-point rounding differs)
- Faster training AND inference
- Less memory usage
- Drop-in replacement (on supported GPUs with fp16/bf16 inputs)

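### Implementation

**PyTorch 2.0+ (built-in):**

```python
import torch
import torch.nn.functional as F

# Shapes: (batch, num_heads, seq_len, head_dim)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
query = torch.randn(2, 8, 1024, 64, device=device, dtype=dtype)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Automatic Flash Attention (if available)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
)
# PyTorch picks the Flash kernel when, among other conditions:
# - inputs are CUDA tensors in fp16/bf16
# - the head dimension is supported by the kernel
# - attn_mask is None (use is_causal=True for causal masking)
# Otherwise it falls back to the memory-efficient or math implementation.
```
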
**HuggingFace Transformers:**

```python
import torch
from transformers import AutoModel

# Enable Flash Attention 2 (requires the flash-attn package, a CUDA GPU,
# and a model architecture that supports it in your transformers version)
model = AutoModel.from_pretrained(
    "bert-base-uncased",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
```

**Manual installation:**

```bash
pip install flash-attn --no-build-isolation
```

### Limitations

❌ **Flash Attention NOT suitable when:**

- Sequence length > 16k (memory is linear, but compute still grows quadratically)
- Custom attention masks (flash-attn handles causal and sliding-window masks, not arbitrary ones; PyTorch falls back to a slower kernel when given an arbitrary mask)
- Inference on CPU (the Flash kernels are CUDA-only)

**For > 16k tokens:** Use sparse or linear attention

## Part 3: Sparse Attention (Exact for Long Sequences)

### Concept

**Idea:** Each token attends to a subset of tokens (not all)

- Sliding window: Local context
- Global tokens: Long-range connections
- Result: O(n × window_size) instead of O(n²)

**Key property:** Exact softmax over the allowed positions (no kernel approximation)

- Just a more structured attention pattern; tokens outside the pattern are not attended directly
- Little or no accuracy loss if the pattern matches the task

### Variant 1: Longformer

**Pattern:** Sliding window + global attention

```
Attention pattern (window=2, global=[0]):

     0  1  2  3  4  5
0 [  1  1  1  1  1  1 ]  ← Global token (attends to all)
1 [  1  1  1  0  0  0 ]  ← Window: tokens 0-2
2 [  1  1  1  1  0  0 ]  ← Window: tokens 1-3 (+ global 0)
3 [  1  0  1  1  1  0 ]  ← Window: tokens 2-4 (+ global 0)
4 [  1  0  0  1  1  1 ]  ← Window: tokens 3-5 (+ global 0)
5 [  1  0  0  0  1  1 ]  ← Window: tokens 4-5 (+ global 0)

Complexity: O(n × (window + num_global))
```

**Components:**

1. **Sliding window**: Each token attends to w/2 tokens before and after
2. **Global tokens**: Special tokens (like [CLS]) attend to all tokens, and all tokens attend to them
3. **Dilated windows**: Optional (stride > 1 for longer context)

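
A small mask builder reproduces the pattern above. `longformer_mask` is a hypothetical helper. It assumes that window w means w/2 tokens on each side and that global attention is symmetric, matching the components listed here. Real Longformer kernels never build this dense mask.

```python
def longformer_mask(seq_len, window, global_tokens=()):
    """1/0 mask where row i marks the keys query i may attend to."""
    half = window // 2
    # Sliding window: |i - j| <= w/2
    mask = [[1 if abs(i - j) <= half else 0 for j in range(seq_len)] for i in range(seq_len)]
    for g in global_tokens:
        for j in range(seq_len):
            mask[g][j] = 1  # global token attends to every position
            mask[j][g] = 1  # every position attends to the global token
    return mask

for row in longformer_mask(6, window=2, global_tokens=[0]):
    print(row)
```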
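**Implementation:**

```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
batch_size, seq_len = 1, 4096
# Dummy token IDs for illustration; use a LongformerTokenizer for real text
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))

# attention_mask: 1 = attend (local window), 0 = padding
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
# global_attention_mask: 1 = global attention, 0 = local only
global_attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
global_attention_mask[:, 0] = 1  # global attention for the first ([CLS]-style) token

output = model(
    input_ids,
    attention_mask=attention_mask,
    global_attention_mask=global_attention_mask,
)
```
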
**Memory comparison (4k tokens, window=512, fp32 scores, one head):**

```
Standard: 4096² = 16M elements → 64MB
Longformer: 4096 × 512 = 2M elements → 8MB (8x reduction!)
```

**When to use:**

- Documents: 4k-16k tokens (legal, scientific papers)
- Need full context but can't fit O(n²)
- Task has local + global structure

**Pretrained models:**

- `allenai/longformer-base-4096`: Max 4096 tokens
- `allenai/longformer-large-4096`: Larger version (also 4096 tokens)

### Variant 2: BigBird

**Pattern:** Random + window + global

```
Attention pattern:
- Sliding window: Like Longformer
- Random connections: Each token attends to r random tokens
- Global tokens: Special tokens attend to all

Complexity: O(n × (window + r + num_global))
```

**Key difference from Longformer:**

- Random connections shorten the path between distant tokens, which helps information flow
- The BigBird paper proves that this sparse pattern (with global tokens) keeps key theoretical properties of full attention (universal approximation, Turing completeness)

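
The three BigBird components can be combined in a token-level sketch. This is an assumption for illustration only: the actual BigBird implementation works on blocks of tokens, and `bigbird_mask` is a hypothetical helper. A seeded RNG keeps the random connections reproducible.

```python
import random

def bigbird_mask(seq_len, window, num_random, global_tokens=(), seed=0):
    """1/0 mask combining a sliding window, random keys, and global tokens."""
    rng = random.Random(seed)
    half = window // 2
    mask = [[1 if abs(i - j) <= half else 0 for j in range(seq_len)] for i in range(seq_len)]
    for i in range(seq_len):
        # Each query adds num_random extra keys from outside its window
        candidates = [j for j in range(seq_len) if not mask[i][j]]
        for j in rng.sample(candidates, min(num_random, len(candidates))):
            mask[i][j] = 1
    for g in global_tokens:
        for j in range(seq_len):
            mask[g][j] = mask[j][g] = 1  # symmetric global attention
    return mask

mask = bigbird_mask(16, window=2, num_random=2, global_tokens=[0])
print([sum(row) for row in mask])  # keys per query: at most window + r + global, except the global row
```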
**When to use:**

- Similar to Longformer
- Slightly better for tasks needing long-range dependencies
- Less widely adopted than Longformer

**Implementation:**

```python
from transformers import BigBirdModel

model = BigBirdModel.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="block_sparse",  # or "original_full"
)
# For short inputs, block_sparse falls back to full attention (with a warning)
```

### Sparse Attention Decision

```
Sequence length < 4k:
  → Flash Attention (exact, no pattern needed)

Sequence length 4k-16k:
  → Longformer (sliding window + global)
  → Best for: Documents, long-form text

Sequence length > 16k:
  → Longformer if possible
  → Linear attention if Longformer too slow
```

## Part 4: Linear Attention (Approximate for Very Long)

### Concept

**Idea:** Approximate softmax attention with linear operations

- Complexity: O(n × k × d) time and O(n × k) memory, where k << n
- Trade-off: typically 1-3% accuracy loss (task-dependent)
- Benefit: Can handle very long sequences (> 16k)

**Key property:** APPROXIMATE (not exact)

- Do NOT use if accuracy is critical
- Good for extremely long sequences where exact attention is impossible

### Variant 1: Performer

**Method:** Positive random features (FAVOR+) that approximate the softmax kernel exp(q · k / √d_k)

**Formula:**

```
# Standard attention
Attention(Q, K, V) = softmax(Q K^T / √d_k) V

# Performer approximation (φ maps each row to k random features)
exp(q · k / √d_k) ≈ φ(q) · φ(k)
Attention(Q, K, V) ≈ D^-1 φ(Q) (φ(K)^T V)
D = diag(φ(Q) (φ(K)^T 1))    # 1 = all-ones vector (row normalizer)

# Complexity: O(n × k × d) where k = number of random features
```

**Key trick:**

- Compute φ(K)^T V first: (k × d) matrix (small!)
- Then multiply by φ(Q): O(n × k × d) instead of O(n² × d)
- Never materialize n² attention matrix

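
The key trick can be sketched in plain Python. This is a simplified version of the idea, not the performer-pytorch implementation: it uses i.i.d. Gaussian projections instead of Performer's orthogonal ones, has no causal variant, and `positive_random_features` / `performer_attention` are hypothetical names. With φ(x) = exp(w·x − |x|²/2)/√m, the expected value of φ(q)·φ(k) is exp(q·k).

```python
import math
import random

def positive_random_features(X, W):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), one feature row per row of X."""
    m = len(W)
    feats = []
    for x in X:
        half_sq_norm = sum(v * v for v in x) / 2.0
        feats.append([
            math.exp(sum(w_i * x_i for w_i, x_i in zip(w, x)) - half_sq_norm) / math.sqrt(m)
            for w in W
        ])
    return feats

def performer_attention(Q, K, V, num_features=64, seed=0):
    """Non-causal linear attention: D^-1 phi(Q) (phi(K)^T V)."""
    d = len(Q[0])
    rng = random.Random(seed)
    # i.i.d. Gaussian projections (Performer itself uses orthogonal ones)
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(num_features)]
    # Scale inputs by d^(-1/4) so phi(q) . phi(k) estimates exp(q . k / sqrt(d))
    s = d ** -0.25
    phi_Q = positive_random_features([[s * v for v in q] for q in Q], W)
    phi_K = positive_random_features([[s * v for v in k] for k in K], W)
    n_k, d_v = len(K), len(V[0])
    # Key trick: phi(K)^T V (m x d_v) and phi(K)^T 1 (length m) are built once
    KV = [[sum(phi_K[j][f] * V[j][c] for j in range(n_k)) for c in range(d_v)]
          for f in range(num_features)]
    K_sum = [sum(phi_K[j][f] for j in range(n_k)) for f in range(num_features)]
    out = []
    for pq in phi_Q:
        norm = sum(a * b for a, b in zip(pq, K_sum))  # diagonal entry of D
        out.append([sum(pq[f] * KV[f][c] for f in range(num_features)) / norm
                    for c in range(d_v)])
    return out
```

No n×n matrix appears anywhere. Because the features are positive, every output row is still a weighted average of V's rows, just with approximate weights.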
**Implementation:**

```python
# From the performer-pytorch library (pip install performer-pytorch)
import torch
from performer_pytorch import Performer

model = Performer(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    causal=False,
    nb_features=256,  # k = number of random features
)

x = torch.randn(1, 8192, 512)  # (batch, seq_len, dim) embeddings
output = model(x)              # (1, 8192, 512)
```

**Accuracy:**

- Typical loss: 1-2% vs standard attention
- Depends on nb_features (more features = better approximation)
- k=256 is usually sufficient

**When to use:**

- Sequence length > 16k tokens
- Accuracy loss acceptable (not a critical task)
- Task lacks the local structure that sparse attention assumes (no structure assumptions)

### Variant 2: Linformer

**Method:** Project K and V along the sequence axis to a lower dimension

**Formula:**

```
# Standard attention (n × n attention matrix)
Attention(Q, K, V) = softmax(Q K^T / √d_k) V

# Linformer: project K, V from n × d to k × d, where k << n
K_proj = E K    # E: (k × n) learned projection matrix
V_proj = F V    # F: (k × n) learned projection matrix

Attention(Q, K, V) ≈ softmax(Q K_proj^T / √d_k) V_proj
# Attention matrix: (n × k) instead of (n × n)
```

**Complexity:**

- Time: O(n × k × d) where k << n
- Memory: O(n × k) instead of O(n²)

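
The projection step is easiest to see through the shapes. The sketch below is a minimal pure-Python version that assumes random E and F as stand-ins for Linformer's learned projections. `linformer_attention` and the small matrix helpers are hypothetical names.

```python
import math
import random

def matmul(A, B):
    # (r x m) @ (m x c) with lists of rows
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out

def linformer_attention(Q, K, V, E, F):
    """Q, K: n x d_k; V: n x d_v; E, F: k x n (learned in Linformer)."""
    d_k = len(Q[0])
    K_proj = matmul(E, K)  # k x d_k: sequence axis compressed from n to k
    V_proj = matmul(F, V)  # k x d_v
    scale = 1.0 / math.sqrt(d_k)
    P = softmax_rows([[scale * s for s in row] for row in matmul(Q, transpose(K_proj))])
    return P, matmul(P, V_proj)  # P is n x k; output is n x d_v

rng = random.Random(0)

def rand(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

n, d, k = 8, 4, 2
Q, K, V = rand(n, d), rand(n, d), rand(n, d)
E, F = rand(k, n), rand(k, n)  # random stand-ins for learned projections
P, out = linformer_attention(Q, K, V, E, F)
print(len(P), len(P[0]))  # 8 2: attention matrix is n x k, not n x n
```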
**Implementation:**

```python
# From the linformer library (pip install linformer)
import torch
from linformer import Linformer

model = Linformer(
    dim=512,
    seq_len=8192,
    depth=12,
    heads=8,
    k=256,  # projected length
)

x = torch.randn(1, 8192, 512)  # pass exactly seq_len tokens
output = model(x)              # (1, 8192, 512)
```

**Accuracy:**

- Typical loss: 1-3% vs standard attention
- Often more loss than Performer
- Fixed sequence length (E and F are k × n, so n is set when the model is built)

**When to use:**

- Fixed-length long sequences
- Memory more critical than speed
- Accuracy loss OK (2-3%)

### Linear Attention Decision

```
Need exact attention:
  → Flash Attention or Sparse Attention (NOT linear)

Sequence > 16k, accuracy critical:
  → Sparse Attention (Longformer)

Sequence > 16k, accuracy loss OK:
  → Performer (better) or Linformer

Sequence > 100k:
  → State space models (S4, Mamba, not attention)
```

## Part 5: Cross-Attention (Multimodal)

### Concept

**Self-attention:** Q, K, V from same source

**Cross-attention:** Q from one source, K/V from another

**Use cases:**

- Multimodal: vision → language (image captioning)
- Seq2seq: source language → target language (translation)
- Retrieval-augmented generation: generated text attends to retrieved documents
- Conditioning: generation conditioned on context

### Architecture

```python
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # nn.MultiheadAttention holds the Q, K, V and output projections
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, query_source, key_value_source, key_padding_mask=None):
        # query_source: (batch, n_q, d_model) - e.g., text tokens
        # key_value_source: (batch, n_kv, d_model) - e.g., image patches
        # Q is projected from query_source; K and V from key_value_source
        output, _ = self.mha(
            query=query_source,
            key=key_value_source,
            value=key_value_source,
            key_padding_mask=key_padding_mask,  # (batch, n_kv), True = ignore
        )
        # Attention output: (batch, n_q, d_model), one vector per query token
        return output
```

### Example: Image Captioning

**Task:** Generate caption from image

**Architecture:**

1. **Image Encoder:** ViT processes image → image features (n_patches × d)
2. **Text Decoder:** Autoregressive text generation
3. **Cross-Attention:** Text queries image features

```python
import torch
import torch.nn as nn

class ImageCaptioningDecoder(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.text_self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, text, image_features):
        # text: (batch, n_text, d_model) embedded caption tokens
        # image_features: (batch, n_patches, d_model) from the image encoder
        n = text.size(1)
        # 1. Self-attention on text (causal): True marks future words to block
        causal_mask = torch.triu(
            torch.ones(n, n, dtype=torch.bool, device=text.device), diagonal=1
        )
        attn, _ = self.text_self_attention(text, text, text, attn_mask=causal_mask)
        text = text + attn  # residual (LayerNorm omitted for brevity)

        # 2. Cross-attention (text queries image)
        # No causal mask! Every text token can attend to all image patches
        attn, _ = self.cross_attention(text, image_features, image_features)
        text = text + attn

        # 3. Feed-forward
        return text + self.feed_forward(text)
```

**Attention flow:**

- Text token "cat" → High attention to cat region in image
- Text token "sitting" → High attention to posture in image

### Example: Retrieval-Augmented Generation (RAG)

**Task:** Generate answer using retrieved documents (here via cross-attention over encoded documents)

```python
import torch.nn as nn

class RAGDecoder(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.query_self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, query, document_embeddings):
        # query: (batch, n_q, d_model) embedded tokens being generated
        # document_embeddings: (batch, n_doc_tokens, d_model) encoded retrieved docs
        # 1. Self-attention on query (add a causal mask when generating)
        attn, _ = self.query_self_attention(query, query, query)
        query = query + attn

        # 2. Cross-attention (query → documents)
        attn, _ = self.cross_attention(query, document_embeddings, document_embeddings)

        # Query learns to extract relevant info from docs
        return query + attn
```

### When to Use Cross-Attention

✅ **Use cross-attention when:**

- Two different modalities (vision + language)
- Conditioning generation on context (RAG)
- Seq2seq with different input/output (translation)
- Query-document matching

❌ **Don't use cross-attention when:**

- There is only a single input sequence (use self-attention)
- No clear query vs key-value separation

## Part 6: Other Attention Variants

### Axial Attention (2D Images)

**Idea:** For 2D data (images), attend along each axis separately

```
Standard 2D attention: H×W tokens → (HW)² attention matrix

Axial attention:
- Row attention: Each token attends within its row (H × W² scores)
- Column attention: Each token attends within its column (W × H² scores)
- Total: O(HW × (H + W)) << O((HW)²)
```

**When to use:**

- High-resolution images
- 2D positional structure important

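
The row-then-column scheme can be sketched on a small grid. This version makes simplifying assumptions: there are no learned projections (Q = K = V = the token vectors), and `attend`, `axial_attention`, and `score_counts` are hypothetical names. After both passes, every token has received information from every other token in two hops.

```python
import math

def attend(seq):
    """Plain self-attention over a list of vectors (Q = K = V = seq)."""
    d = len(seq[0])
    out = []
    for q in seq:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in seq]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        total = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, seq)) / total for c in range(d)])
    return out

def axial_attention(grid):
    """grid: H x W x d. Row pass (W x W scores per row), then column pass (H x H per column)."""
    H, W = len(grid), len(grid[0])
    rows = [attend(grid[r]) for r in range(H)]
    cols = [attend([rows[r][c] for r in range(H)]) for c in range(W)]
    return [[cols[c][r] for c in range(W)] for r in range(H)]  # back to H x W x d

def score_counts(H, W):
    full = (H * W) ** 2              # every token scores every other token
    axial = H * W * W + W * H * H    # = H*W*(H + W)
    return full, axial

print(score_counts(64, 64))  # (16777216, 524288): 32x fewer scores
```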
### Block-Sparse Attention

**Idea:** Divide the sequence into blocks; attend within each block plus a fixed set of positions in other blocks

**Pattern:**

```
Block size = 64 tokens
- Local block: Attend within same block
- Vertical stripe: Attend to the same offset in other blocks (strided)
```

**Used in:** Sparse Transformer (OpenAI), GPT-3 (alternating dense and sparse layers)

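
The local-block plus vertical-stripe pattern described here can be sketched as a mask. This sketch follows the description above and is not the exact Sparse Transformer kernels. `strided_sparse_mask` is a hypothetical helper, and the causal default assumes an autoregressive model.

```python
def strided_sparse_mask(seq_len, block_size, causal=True):
    """1/0 mask: same block (local) or same offset in another block (stripe)."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            local = i // block_size == j // block_size
            stripe = (i - j) % block_size == 0
            # causal=True drops keys after the query (j > i)
            allowed = (local or stripe) and (j <= i or not causal)
            row.append(1 if allowed else 0)
        mask.append(row)
    return mask

mask = strided_sparse_mask(16, block_size=4)
print([j for j, v in enumerate(mask[9]) if v])  # [1, 5, 8, 9]: keys visible to query 9
```

Each row has at most `block_size + seq_len / block_size` ones. With `block_size ≈ √n`, that gives the O(n√n) cost of this family of patterns.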
### Multi-Query Attention (MQA)

**Idea:** One K/V head shared across all Q heads

**Benefit:**

- KV cache shrinks by a factor equal to the number of heads during inference
- Much faster decoding (often reported as 4-8x)
- Trade-off: ~1% accuracy loss (model-dependent)

**Used in:** PaLM, Falcon

### Grouped-Query Attention (GQA)

**Idea:** Middle ground between multi-head and multi-query

- Each group of Q heads shares one K/V head
- Example: 32 Q heads → 8 K/V heads (4:1 ratio)

**Benefit:**

- 4x smaller KV cache than multi-head attention (for the 4:1 example)
- Minimal accuracy loss (< 0.5%)

**Used in:** LLaMA-2 (34B and 70B), Mistral

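
The KV-cache savings and the head grouping come down to simple arithmetic. The configuration below (32 layers, 32 query heads, head_dim 128, 4096 tokens, fp16) is a hypothetical 7B-class example, and `kv_cache_bytes` / `kv_head_for` are illustrative names.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    # Factor 2: separate K and V caches in every layer (fp16 = 2 bytes/element)
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

def kv_head_for(q_head, num_q_heads, num_kv_heads):
    # GQA: consecutive groups of query heads share one K/V head (MQA: num_kv_heads = 1)
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

GiB = 2**30
# Hypothetical 7B-class config: 32 layers, 32 query heads, head_dim 128, 4096 tokens
for name, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(num_layers=32, num_kv_heads=kv_heads, head_dim=128, seq_len=4096)
    print(name, size / GiB)  # MHA 2.0, GQA 0.5, MQA 0.0625
```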
## Part 7: Decision Framework

### By Sequence Length

```
< 2k tokens:
  → Flash Attention
    Exact, fast, standard

2k-4k tokens:
  → Flash Attention
    Still manageable with modern GPUs

4k-16k tokens:
  → Sparse Attention (Longformer, BigBird)
    Exact, designed for documents
  → OR Flash Attention if batch size = 1

> 16k tokens:
  → Sparse Attention
    If task has local structure
  → Linear Attention (Performer)
    If accuracy loss OK (1-2%)
  → State Space Models (S4, Mamba)
    If sequence > 100k
```

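
The sequence-length guide above can be encoded as a small function. The thresholds come from this guide. The exact boundary handling (`<=` at 4096 and 16384 tokens) and the function name and return labels are assumptions made for illustration.

```python
def choose_attention(seq_len, accuracy_loss_ok=False, local_structure=True):
    """Map the guide's sequence-length rules to a recommendation label."""
    if seq_len <= 4096:
        return "flash"   # < 2k and 2k-4k both map to Flash Attention
    if seq_len <= 16384:
        return "sparse"  # Longformer/BigBird, or Flash if batch size = 1
    if seq_len > 100_000:
        return "ssm"     # S4, Mamba
    if local_structure or not accuracy_loss_ok:
        return "sparse"  # exact attention when accuracy matters
    return "linear"      # Performer, when some accuracy loss is acceptable

print(choose_attention(32768, accuracy_loss_ok=True, local_structure=False))  # linear
```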
### By Memory Constraints

```
GPU OOM with standard attention:
1. Try Flash Attention (attention memory drops from O(n²) to O(n))
2. If still OOM, reduce batch size
3. If batch size = 1 and still OOM, use sparse attention
4. Last resort: Linear attention (if accuracy loss OK)

DON'T:
- Reach for gradient checkpointing first (it adds recomputation time;
  enable Flash Attention first and add checkpointing only if still needed)
- Throw more GPUs at it (algorithmic problem, not hardware)
```

### By Accuracy Requirements

```
Must be exact (no approximation):
  → Flash Attention or Sparse Attention
    Never use linear attention!

Accuracy loss acceptable (1-3%):
  → Linear Attention (Performer, Linformer)
    Only for very long sequences (> 16k)

Critical task (medical, legal):
  → Exact attention only
    Flash Attention or Sparse Attention
```

### By Task Type

```
Classification / Understanding:
  → Standard + Flash Attention
    Sequence usually < 2k

Document processing:
  → Longformer (4096 tokens)
    Designed for documents

Generation (LLM):
  → Flash Attention for training
  → + GQA/MQA for inference (faster decoding)

Multimodal (vision + language):
  → Cross-attention for modality fusion
  → Self-attention within each modality

Retrieval-augmented:
  → Cross-attention (query → documents)
```

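## Part 8: Implementation Checklist

### Using Flash Attention

**PyTorch 2.0+:**

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
output = F.scaled_dot_product_attention(q, k, v)  # automatic (recommended)

# Only reports that the Flash backend is *allowed*, not that it was used
print(torch.backends.cuda.flash_sdp_enabled())

# Require Flash: raises an error if the inputs are not eligible
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(q, k, v)
```
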
**HuggingFace:**

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,  # Flash Attention needs fp16/bf16
)
```

**Requirements:**

- CUDA GPU (not CPU); the flash-attn 2 package targets Ampere-class GPUs or newer
- PyTorch >= 2.0 OR flash-attn package
- fp16 or bf16 dtype (not fp32)

### Using Sparse Attention

**Longformer:**

```python
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Replace with a long document.", return_tensors="pt")

# attention_mask (from the tokenizer): 1 = attend, 0 = padding
# global_attention_mask: 1 = global attention, 0 = local (sliding window) only
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # first token (<s>, the [CLS] equivalent) gets global attention

outputs = model(**inputs, global_attention_mask=global_attention_mask)
```

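**Custom sparse pattern:**

```python
import torch

# Create a custom block-sparse mask (True = query i may attend to key j)
def create_block_sparse_mask(seq_len, block_size):
    # Ceil division so a trailing partial block is still covered
    num_blocks = (seq_len + block_size - 1) // block_size
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

    for i in range(num_blocks):
        start = i * block_size
        end = min(start + block_size, seq_len)
        mask[start:end, start:end] = True  # Local block

    return mask

# Usable as a boolean attn_mask in F.scaled_dot_product_attention (True = attend).
# A dense mask still costs O(n²) memory and rules out the Flash kernel;
# real savings need a block-sparse kernel.
```
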
### Using Cross-Attention

```python
import torch.nn as nn

class DecoderWithCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, decoder_input, encoder_output, causal_mask=None):
        # causal_mask: (n_dec, n_dec) bool, True = future position (blocked)
        # Self-attention (causal)
        attn, _ = self.self_attn(
            query=decoder_input,
            key=decoder_input,
            value=decoder_input,
            attn_mask=causal_mask,
        )
        x = decoder_input + attn  # residual connection

        # Cross-attention (Q from decoder, K/V from encoder)
        attn, _ = self.cross_attn(
            query=x,               # From decoder
            key=encoder_output,    # From encoder
            value=encoder_output,  # From encoder
            attn_mask=None,        # No causal mask for cross-attention!
        )
        return x + attn
```

## Part 9: Common Mistakes

### Mistake 1: Ignoring Flash Attention

**Symptom:** Training slow, high memory usage

**Fix:** Always use Flash Attention for < 16k tokens

### Mistake 2: Using Linear Attention Unnecessarily

**Symptom:** 1-3% accuracy loss for no reason

**Fix:** Use Flash Attention (exact) unless sequence > 16k

### Mistake 3: Gradient Checkpointing Instead of Flash Attention

**Symptom:** Training ~20% slower from recomputation

**Fix:** Enable Flash Attention first (it saves memory AND time); add checkpointing only if still needed

### Mistake 4: Cross-Attention with Causal Mask

**Symptom:** Decoder can't attend to encoder properly

**Fix:** Causal mask only for self-attention, NOT cross-attention

### Mistake 5: Accepting O(n²) Memory

**Symptom:** GPU OOM for > 4k tokens

**Fix:** Use Flash or sparse attention; don't just add GPUs

## Summary: Quick Reference

### Attention Selection

```
Sequence length:
  < 2k   → Flash Attention (default)
  2-4k   → Flash Attention
  4-16k  → Longformer (documents) or Flash Attention (batch=1)
  > 16k  → Sparse or Linear Attention

Memory constrained:
  First: Try Flash Attention (O(n) attention memory)
  Still OOM: Use sparse attention (Longformer)
  Last resort: Linear attention (accuracy loss)

Speed critical:
  Training: Flash Attention (often ~2x faster)
  Inference: Flash Attention + GQA/MQA

Accuracy critical:
  Use exact attention only (Flash or Sparse)
  NEVER linear attention

Multimodal:
  Cross-attention for modality fusion
```

### Implementation

```
PyTorch 2.0+:
  F.scaled_dot_product_attention()  # Auto Flash Attention

HuggingFace:
  attn_implementation="flash_attention_2"

Longformer:
  LongformerModel.from_pretrained("allenai/longformer-base-4096")

Custom:
  Inherit from nn.Module, implement forward()
```

## Next Steps

After mastering this skill:

- `llm-specialist/llm-inference-optimization`: Apply attention optimizations to inference
- `llm-specialist/context-window-management`: Manage long contexts in LLMs
- `architecture-design-principles`: Understand broader design trade-offs

**Remember:** Flash Attention is the modern default. Use it unless you have a specific reason not to (> 16k tokens, custom patterns).