# Tensor Operations and Memory Management
|
|
|
|
## Overview
|
|
|
|
**Core Principle:** PyTorch memory issues come down to how tensor lifecycle, operation efficiency, and device placement are managed. Fix them at the operation level, not by adding more RAM.
|
|
|
|
Memory leaks, OOM errors, and slow operations are symptoms. Root causes are: gradient retention, inefficient operations, device inconsistency, or Python reference cycles. Systematic diagnosis beats guessing.
|
|
|
|
## When to Use
|
|
|
|
**Use this skill when:**
|
|
- "CUDA out of memory" or GPU memory growing over time
|
|
- Training/inference slower than expected on GPU
|
|
- "CUDA error: device-side assert triggered"
|
|
- Memory usage doesn't decrease after batch/epoch
|
|
- Tensor operations causing performance bottlenecks
|
|
- Mixed precision training causing crashes
|
|
|
|
**Don't use when:**
|
|
- Model architecture design (use neural-architectures)
|
|
- Training convergence issues (use training-optimization)
|
|
- Multi-GPU distributed training strategy (use distributed-training-strategies)
|
|
|
|
**Symptoms triggering this skill:**
|
|
- "Memory keeps growing each epoch"
|
|
- "GPU utilization low but training slow"
|
|
- "Random CUDA crashes"
|
|
- "Tensor operations taking too long"
|
|
|
|
|
|
## Memory Leak Diagnosis Methodology
|
|
|
|
### Systematic Debugging Steps
|
|
|
|
**1. Identify When Memory Grows**
|
|
```python
|
|
import torch
|
|
import gc
|
|
|
|
def diagnose_memory_growth():
|
|
"""Track memory at key points"""
|
|
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
|
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
|
|
print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
|
|
```
|
|
|
|
**Call after:**
|
|
- Each forward pass
|
|
- Each backward pass
|
|
- Each optimizer step
|
|
- Each epoch
|
|
|
|
**What to look for:**
|
|
- Reserved memory grows → fragmentation or retention
|
|
- Allocated grows → tensors not released
|
|
- Max allocated grows → peak usage increasing
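The checks above can be wrapped in a small helper and called at each of these points. A minimal sketch, assuming the usual `model`, `criterion`, `optimizer`, `dataloader`, and `num_epochs` names from your training script:

```python
import torch

def log_cuda_memory(tag: str) -> None:
    """Print allocated / reserved / peak CUDA memory at a labeled point."""
    print(f"[{tag}] "
          f"allocated={torch.cuda.memory_allocated() / 1e9:.2f} GB  "
          f"reserved={torch.cuda.memory_reserved() / 1e9:.2f} GB  "
          f"peak={torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

for epoch in range(num_epochs):
    for data, target in dataloader:
        output = model(data.cuda())
        log_cuda_memory("after forward")
        loss = criterion(output, target.cuda())
        loss.backward()
        log_cuda_memory("after backward")
        optimizer.step()
        optimizer.zero_grad()
        log_cuda_memory("after optimizer step")
    log_cuda_memory(f"after epoch {epoch}")
```

If the "after epoch" numbers climb steadily while the per-step numbers stay flat, something is being retained across iterations (see the checks below).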
|
|
|
|
|
|
**2. Check Gradient Accumulation**
|
|
|
|
```python
|
|
# ❌ WRONG: Gradients accumulate indefinitely
|
|
for epoch in range(num_epochs):
|
|
for batch in dataloader:
|
|
outputs = model(batch)
|
|
loss = criterion(outputs, targets)
|
|
loss.backward() # Gradients ACCUMULATE without clearing!
|
|
# Missing: optimizer.zero_grad()
|
|
|
|
# ✅ CORRECT: Clear gradients each iteration
|
|
for epoch in range(num_epochs):
|
|
for batch in dataloader:
|
|
optimizer.zero_grad() # or model.zero_grad(set_to_none=True)
|
|
outputs = model(batch)
|
|
loss = criterion(outputs, targets)
|
|
loss.backward()
|
|
optimizer.step()
|
|
```
|
|
|
|
**Why this matters:** Each `.backward()` call adds to existing gradients. Without clearing, memory grows unbounded.
|
|
|
|
**set_to_none=True advantage:** Sets `.grad` to `None` instead of writing zeros, which skips a fill kernel and frees the gradient memory until the next backward pass (it is also the default in recent PyTorch versions).
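A self-contained way to see the difference (the default of `zero_grad` varies by PyTorch version, so it is passed explicitly here):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)

model(torch.randn(4, 10)).sum().backward()
model.zero_grad(set_to_none=False)   # zero-fills .grad in place; memory stays allocated
print(model.weight.grad)             # tensor of zeros

model(torch.randn(4, 10)).sum().backward()
model.zero_grad(set_to_none=True)    # drops the gradient tensors entirely
print(model.weight.grad)             # None -- reallocated on the next backward()
```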
|
|
|
|
|
|
**3. Check Tensor Detachment**
|
|
|
|
```python
|
|
# ❌ WRONG: Retaining computation graph
|
|
train_losses = []
|
|
for batch in dataloader:
|
|
loss = compute_loss(batch)
|
|
train_losses.append(loss) # Keeps entire graph!
|
|
loss.backward()
|
|
|
|
# ✅ CORRECT: Detach from graph
|
|
train_losses = []
|
|
for batch in dataloader:
|
|
loss = compute_loss(batch)
|
|
train_losses.append(loss.item()) # .item() extracts scalar, breaks graph
|
|
loss.backward()
|
|
|
|
# ✅ ALSO CORRECT: Detach explicitly
|
|
intermediate_results = []
|
|
for batch in dataloader:
|
|
output = model(batch)
|
|
intermediate_results.append(output.detach()) # Breaks gradient tracking
|
|
```
|
|
|
|
**Why this matters:** Storing tensors with gradients keeps entire computation graph in memory. Use `.item()` for scalars, `.detach()` for tensors you need.
|
|
|
|
|
|
**4. Check Hidden State Retention (RNNs/Attention)**
|
|
|
|
```python
|
|
# ❌ WRONG: Hidden states retain gradients across batches
|
|
hidden = None
|
|
for batch in dataloader:
|
|
output, hidden = rnn(batch, hidden) # hidden retains graph from ALL previous batches!
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
|
|
# ✅ CORRECT: Detach hidden state each batch
|
|
hidden = None
|
|
for batch in dataloader:
|
|
if hidden is not None:
|
|
hidden = hidden.detach() # Break gradient chain
|
|
output, hidden = rnn(batch, hidden)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
```
|
|
|
|
**Why this matters:** RNN/LSTM/GRU hidden states chain gradients across batches, causing unbounded memory growth.
|
|
|
|
|
|
**5. Check Evaluation Context**
|
|
|
|
```python
|
|
# ❌ WRONG: Evaluation builds computation graphs
|
|
model.eval()
|
|
for batch in val_loader:
|
|
outputs = model(batch) # Still tracks gradients!
|
|
val_loss = criterion(outputs, targets)
|
|
|
|
# ✅ CORRECT: Disable gradient tracking
|
|
model.eval()
|
|
with torch.no_grad(): # Critical for memory efficiency
|
|
for batch in val_loader:
|
|
outputs = model(batch)
|
|
val_loss = criterion(outputs, targets)
|
|
|
|
# ✅ ALSO USEFUL: Inference mode (even more efficient)
|
|
with torch.inference_mode():
|
|
for batch in val_loader:
|
|
outputs = model(batch)
|
|
```
|
|
|
|
**Why this matters:** `torch.no_grad()` prevents graph building. `torch.inference_mode()` additionally disables autograd metadata for maximum efficiency.
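The practical difference shows up when a tensor produced during evaluation later re-enters autograd. A self-contained sketch (the exact error message varies by version):

```python
import torch

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    a = x * 2            # no graph recorded, but `a` is an ordinary tensor

with torch.inference_mode():
    b = x * 2            # inference tensor: no graph and no version tracking

(a * x).sum().backward()      # fine: `a` is treated as a constant, gradient flows to x

try:
    (b * x).sum().backward()  # inference tensors cannot be saved for backward
except RuntimeError as err:
    print("rejected by autograd:", err)
```

Use `inference_mode()` for pure inference paths; stick with `no_grad()` if the results might feed back into training.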
|
|
|
|
|
|
**6. Check for Python Reference Cycles**
|
|
|
|
```python
|
|
# ❌ WRONG: Closure captures self
|
|
class Trainer:
|
|
def __init__(self):
|
|
self.callbacks = []
|
|
|
|
def add_callback(self):
|
|
# Lambda captures 'self', creates cycle
|
|
self.callbacks.append(lambda: self.model.train())
|
|
|
|
# ✅ CORRECT: Use weak references or explicit cleanup
|
|
class Trainer:
|
|
def __init__(self):
|
|
self.callbacks = []
|
|
|
|
def clear_callbacks(self):
|
|
self.callbacks.clear() # Explicit cleanup
|
|
gc.collect() # Force garbage collection
|
|
```
|
|
|
|
**When to check:** If memory not freed after training loop completes.
|
|
|
|
**Tools:**
|
|
```python
|
|
import gc
|
|
gc.collect() # Force collection
|
|
torch.cuda.empty_cache() # Release cached memory to OS
|
|
```
|
|
|
|
|
|
## Efficient Tensor Operations
|
|
|
|
### Memory-Efficient Operation Patterns
|
|
|
|
**1. In-Place Operations**
|
|
|
|
```python
|
|
# ❌ WRONG: Creates new tensor each time
|
|
x = torch.randn(1000, 1000)
|
|
x = x + 1 # Allocates new memory
|
|
x = x * 2 # Allocates new memory
|
|
x = torch.relu(x) # Allocates new memory
|
|
|
|
# ✅ CORRECT: In-place operations
|
|
x = torch.randn(1000, 1000)
|
|
x += 1 # In-place, no new allocation
|
|
x *= 2 # In-place
|
|
x.relu_() # In-place (note underscore)
|
|
|
|
# ⚠️ CAUTION: Don't use in-place on tensors needing gradients
|
|
x.requires_grad = True
|
|
x += 1 # ❌ Error! In-place on tensor with gradients
|
|
```
|
|
|
|
**When to use:** Loop iterations, activations in eval mode, preprocessing.
|
|
|
|
**When NOT to use:** Tensors in computation graph (breaks autograd).
|
|
|
|
|
|
**2. Contiguous Tensors**
|
|
|
|
```python
|
|
# Problem: Non-contiguous tensors slow down operations
|
|
x = torch.randn(100, 100)
|
|
x_t = x.t() # Transpose creates VIEW, not contiguous
|
|
|
|
# Check contiguity
|
|
print(x_t.is_contiguous()) # False
|
|
|
|
# ❌ SLOW: Operations on non-contiguous tensors
|
|
result = x_t + 1 # Has to handle strided memory access
|
|
|
|
# ✅ FAST: Make contiguous first
|
|
x_t = x_t.contiguous()
|
|
result = x_t + 1 # Sequential memory access, much faster
|
|
```
|
|
|
|
**Common sources of non-contiguous tensors:**
|
|
- `.transpose()`, `.permute()`, `.expand()`, `.narrow()`, and slicing with a step (`.view()` does not create non-contiguity; it raises an error on incompatible layouts, where `.reshape()` will copy instead)
|
|
|
|
**Rule of thumb:** If doing many operations on a tensor, call `.contiguous()` once upfront.
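A quick timing comparison with `torch.utils.benchmark` (a sketch; absolute numbers depend on hardware, and on small tensors the difference may be negligible):

```python
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(4096, 4096)
views = {
    "non-contiguous view": x.t(),            # transpose: strided memory access
    "contiguous copy": x.t().contiguous(),   # same values, sequential layout
}

for label, t in views.items():
    timer = benchmark.Timer(stmt="t + 1", globals={"t": t})
    print(label, timer.timeit(100))
```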
|
|
|
|
|
|
**3. Device Placement Efficiency**
|
|
|
|
```python
# ❌ VERY SLOW: Redundant CPU-GPU transfers inside the loop
for data, target in dataloader:
    model = model.cuda()                    # re-moving the model every iteration
    output = model(data.cuda())
    loss = criterion(output.cpu(), target)  # pulling outputs back to CPU forces a sync

# ✅ FAST: Move model and criterion once, transfer only the batch
model = model.cuda()
criterion = criterion.cuda()
for data, target in dataloader:
    data, target = data.cuda(), target.cuda()   # only the batch data moves
    output = model(data)
    loss = criterion(output, target)

# ✅ EVEN BETTER: Pin memory for async transfer
dataloader = DataLoader(dataset, pin_memory=True, num_workers=4)
for data, target in dataloader:
    data = data.cuda(non_blocking=True)     # async copy from pinned host memory
    target = target.cuda(non_blocking=True)
    output = model(data)
```
|
|
|
|
**Why this matters:** CPU↔GPU transfers are slow (PCIe bandwidth). Minimize transfers.
|
|
|
|
|
|
**4. Broadcasting Awareness**
|
|
|
|
```python
# Broadcasting does not copy the small tensor (it is expanded via strides),
# but the result of a broadcast op is still a full-size tensor.
x = torch.randn(1000, 1000, 100)  # ~400 MB in float32
y = torch.randn(100)              # ~400 bytes

# ❌ MEMORY INEFFICIENT: Allocates a second ~400 MB tensor for the result
result = x + y

# ✅ MORE EFFICIENT: Add in place if the original x is no longer needed
x.add_(y)  # broadcasts y into x's existing storage, no new allocation

# ✅ ALSO USEFUL: Write into a pre-allocated output buffer
out = torch.empty_like(x)
torch.add(x, y, out=out)  # reuses `out` instead of allocating a fresh result
```
|
|
|
|
**Profiling broadcasting:**
|
|
```python
|
|
import torch.utils.benchmark as benchmark
|
|
|
|
t = benchmark.Timer(
|
|
stmt='x + y',
|
|
globals={'x': x, 'y': y}
|
|
)
|
|
print(t.timeit(100))
|
|
```
|
|
|
|
|
|
**5. Memory Pooling and Allocation**
|
|
|
|
```python
|
|
# ❌ WRONG: Allocating inside loop
|
|
for epoch in range(100):
|
|
for batch in dataloader:
|
|
temp = torch.zeros(1024, 1024).cuda() # Allocate every iteration!
|
|
result = process(batch, temp)
|
|
|
|
# ✅ CORRECT: Pre-allocate reusable buffers
|
|
temp_buffer = torch.zeros(1024, 1024).cuda() # Allocate once
|
|
for epoch in range(100):
|
|
for batch in dataloader:
|
|
temp_buffer.zero_() # Reset in-place
|
|
result = process(batch, temp_buffer)
|
|
```
|
|
|
|
**Why this matters:** Memory allocation/deallocation has overhead. Reuse buffers when size is fixed.
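The same idea extends to operation outputs: many PyTorch ops accept an `out=` argument so results land in a pre-allocated buffer instead of a fresh tensor. A minimal sketch, assuming a CUDA device is available:

```python
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
result = torch.empty(1024, 1024, device="cuda")   # allocated once

for _ in range(1000):
    torch.matmul(a, b, out=result)   # writes into the reused buffer
    # ... consume `result` before the next iteration overwrites it ...
```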
|
|
|
|
|
|
## Device Management Best Practices
|
|
|
|
### Systematic Device Consistency
|
|
|
|
**1. Device Checking Methodology**
|
|
|
|
```python
|
|
def check_device_consistency(model, data, target):
|
|
"""Systematic device checking"""
|
|
print(f"Model on: {next(model.parameters()).device}")
|
|
print(f"Data on: {data.device}")
|
|
print(f"Target on: {target.device}")
|
|
|
|
# Check all model parameters on same device
|
|
devices = {p.device for p in model.parameters()}
|
|
if len(devices) > 1:
|
|
print(f"⚠️ Model parameters on multiple devices: {devices}")
|
|
|
|
# Check all buffers
|
|
buffer_devices = {b.device for b in model.buffers()}
|
|
if len(buffer_devices) > 1:
|
|
print(f"⚠️ Model buffers on multiple devices: {buffer_devices}")
|
|
|
|
# Use before training starts
|
|
check_device_consistency(model, batch['input'], batch['target'])
|
|
```
|
|
|
|
**When to check:**
|
|
- After model initialization
|
|
- After loading checkpoint
|
|
- Before training starts
|
|
- When debugging device-side asserts
|
|
|
|
|
|
**2. Mixed Precision Context Management**
|
|
|
|
```python
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
|
|
# ❌ WRONG: Inconsistent autocast usage
scaler = GradScaler()
for batch in dataloader:
    with autocast():
        output = model(batch)
    loss = criterion(output, target)  # ❌ Loss computed outside autocast!
    scaler.scale(loss).backward()

# ✅ CORRECT: Consistent autocast context
scaler = GradScaler()
for batch in dataloader:
    with autocast():
        output = model(batch)
        loss = criterion(output, target)  # ✅ Loss inside autocast
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
|
|
```
|
|
|
|
**Critical rules:**
|
|
- Forward pass + loss computation in `autocast()` context
|
|
- `scaler.scale()` before `.backward()`
|
|
- `scaler.step()` instead of `optimizer.step()`
|
|
- `scaler.update()` after step
|
|
|
|
|
|
**3. Multi-GPU Device Placement**
|
|
|
|
```python
# ❌ WRONG: Implicit device assumptions
model = nn.DataParallel(model)        # wraps model; outputs gathered on cuda:0
output = model(batch)
loss = criterion(output, target)      # ❌ target might still be on CPU or another device!

# ✅ CORRECT: Explicit device management
device = torch.device("cuda:0")
model = nn.DataParallel(model).to(device)
for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)
    output = model(data)              # output gathered on device 0
    loss = criterion(output, target)  # all on the same device
```
|
|
|
|
**Device placement hierarchy:**
|
|
1. Pin device at start: `device = torch.device("cuda:0")`
|
|
2. Move model once: `model.to(device)`
|
|
3. Move data each batch: `batch.to(device)`
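Put together, the pattern looks like this (a sketch; `model`, `criterion`, and `dataloader` are assumed to exist, and it falls back to CPU when CUDA is unavailable):

```python
import torch

# 1. Pin the device once
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2. Move the model (and the loss module, if it has parameters/buffers) once
model = model.to(device)
criterion = criterion.to(device)

# 3. Move only the data inside the loop
for data, target in dataloader:
    data, target = data.to(device), target.to(device)
    output = model(data)
    loss = criterion(output, target)
```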
|
|
|
|
|
|
## Performance Profiling
|
|
|
|
### Identifying Bottlenecks
|
|
|
|
**1. Memory Profiling**
|
|
|
|
```python
|
|
import torch.cuda
|
|
|
|
# Profile memory usage
|
|
torch.cuda.reset_peak_memory_stats()
|
|
|
|
# Run operation
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
|
|
# Check memory stats
|
|
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
|
|
print(f"Current memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
|
|
|
# Memory summary (detailed)
|
|
print(torch.cuda.memory_summary())
|
|
```
|
|
|
|
**2. Operation Profiling**
|
|
|
|
```python
|
|
from torch.profiler import profile, ProfilerActivity
|
|
|
|
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
|
|
# Print top operations by time
|
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
|
|
|
# Export for visualization
|
|
prof.export_chrome_trace("trace.json") # View in chrome://tracing
|
|
```
|
|
|
|
**What to look for:**
|
|
- Operations taking >10% of time
|
|
- Unexpected memory allocations
|
|
- CPU-GPU synchronization overhead
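To surface unexpected allocations directly, the profiler can also record memory per operator. A sketch using the standard `profile_memory` option (training-loop names are placeholders):

```python
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,      # record allocations per operator
    record_shapes=True,       # attribute them to input shapes
) as prof:
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()

# Largest CUDA memory consumers first
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```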
|
|
|
|
|
|
**3. Memory Snapshot (PyTorch 2.0+)**
|
|
|
|
```python
|
|
import torch.cuda
|
|
|
|
# Record memory snapshots
|
|
torch.cuda.memory._record_memory_history()
|
|
|
|
# Run training iteration
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
|
|
# Save snapshot
|
|
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
|
|
torch.cuda.memory._record_memory_history(enabled=None)
|
|
|
|
# Analyze with:
|
|
# python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle
|
|
```
|
|
|
|
|
|
## Red Flags - Stop and Diagnose
|
|
|
|
**If you catch yourself thinking ANY of these, STOP and follow systematic methodology:**
|
|
|
|
| Red Flag Thought | Reality | What to Do Instead |
|
|
|------------------|---------|-------------------|
|
|
| "I'll just reduce batch size" | Avoids root cause, wastes GPU capacity | Follow memory leak diagnosis first |
|
|
| "I'll add more GPU memory" | Expensive non-solution if there's a leak | Diagnose leak, don't throw hardware at it |
|
|
| "Memory leaks are normal in PyTorch" | FALSE - PyTorch has excellent memory management | There IS a bug in your code, find it |
|
|
| "Too complex to debug, I'll refactor" | Avoidance - same bug will appear in new code | Debug systematically, learn the issue |
|
|
| "Skip profiling, I know what's slow" | Guessing wastes time, profiling gives facts | Always profile before optimizing |
|
|
| "Mixed precision is broken" | AMP works when used correctly | Check autocast context boundaries |
|
|
| "Quick fix: just call empty_cache()" | Doesn't fix leaks, just masks symptoms | Find and fix the leak |
|
|
|
|
**Critical rule:** Memory issues have root causes. Systematic diagnosis ALWAYS faster than guessing.
|
|
|
|
|
|
## Common Rationalizations (Don't Do These)
|
|
|
|
| Excuse | What Really Happens | Correct Approach |
|
|
|--------|-------------------|------------------|
|
|
| "User seems rushed, skip methodology" | Guessing wastes MORE time than systematic diagnosis | 5 minutes of diagnosis saves hours of guessing |
|
|
| "I already tried profiling" | May have looked at wrong metrics or misinterpreted | Re-profile with specific focus from methodology |
|
|
| "This worked on smaller model" | Scaling exposes hidden issues | Same methodology applies, just reveals different bugs |
|
|
| "Documentation says to do X" | May be misunderstanding context or outdated | Check PyTorch version, verify applicability |
|
|
| "I'll optimize later" | Memory issues prevent finishing now | Fix memory first, then optimize if still needed |
|
|
| "It's a CUDA bug" | 99.9% of time it's your code | Assume your bug until proven otherwise |
|
|
|
|
|
|
## Common Pitfalls
|
|
|
|
### Consolidated Pitfall Table
|
|
|
|
| # | Pitfall | Symptom | Root Cause | Fix |
|
|
|---|---------|---------|------------|-----|
|
|
| 1 | Accumulating metrics without detachment | Memory grows linearly with iterations | Storing tensors retains computation graph | Use `.item()` for scalars, `.detach()` for tensors |
|
|
| 2 | Hidden state chaining (RNNs) | Memory grows across batches | Hidden states chain gradients indefinitely | Detach hidden states between batches |
|
|
| 3 | Missing `torch.no_grad()` in eval | High memory usage during validation | Evaluation builds unnecessary graphs | Wrap evaluation in `torch.no_grad()` |
|
|
| 4 | Repeated CPU-GPU transfers | Low GPU utilization, slow training | PCIe bandwidth bottleneck | Move to GPU once, keep there |
|
|
| 5 | Non-contiguous tensor operations | Unexpectedly slow operations | Strided memory access inefficiency | Call `.contiguous()` before repeated ops |
|
|
| 6 | Allocations in loops | Slow iterations, fragmentation | Memory allocation overhead | Pre-allocate and reuse buffers |
|
|
| 7 | Gradient accumulation without clearing | OOM after few iterations | Gradients accumulate unbounded | `optimizer.zero_grad()` every iteration |
|
|
| 8 | Mixed precision context boundaries | Intermittent crashes, NaN values | Loss computed outside autocast | Keep forward + loss inside `autocast()` |
|
|
| 9 | Device inconsistency | "Expected all tensors to be on the same device" errors (or device-side asserts) | Tensors on different devices | Systematic device checking |
|
|
| 10 | Logging with tensors instead of scalars | Memory growth during training | Retaining graphs for logging | Always use `.item()` for logging |
|
|
|
|
|
|
### Memory Leak Pitfalls
|
|
|
|
❌ **Pitfall 1: Accumulating Metrics Without Detachment**
|
|
```python
|
|
# WRONG
|
|
losses = []
|
|
for batch in dataloader:
|
|
loss = criterion(output, target)
|
|
losses.append(loss) # Retains graph!
|
|
|
|
# CORRECT
|
|
losses = []
|
|
for batch in dataloader:
|
|
loss = criterion(output, target)
|
|
losses.append(loss.item()) # Extract scalar
|
|
```
|
|
**Symptom:** Memory grows linearly with iterations
|
|
**Fix:** Use `.item()` or `.detach()` before storing
|
|
|
|
|
|
❌ **Pitfall 2: Hidden State Chaining (RNNs)**
|
|
```python
|
|
# WRONG
|
|
hidden = None
|
|
for batch in dataloader:
|
|
output, hidden = lstm(batch, hidden)
|
|
|
|
# CORRECT
|
|
hidden = None
|
|
for batch in dataloader:
|
|
if hidden is not None:
|
|
hidden = tuple(h.detach() for h in hidden) # LSTM returns tuple
|
|
output, hidden = lstm(batch, hidden)
|
|
```
|
|
**Symptom:** Memory grows across batches in RNN training
|
|
**Fix:** Detach hidden states between batches
|
|
|
|
|
|
❌ **Pitfall 3: Missing torch.no_grad() in Evaluation**
|
|
```python
|
|
# WRONG
|
|
model.eval()
|
|
for batch in val_loader:
|
|
output = model(batch) # Builds graph!
|
|
|
|
# CORRECT
|
|
model.eval()
|
|
with torch.no_grad():
|
|
for batch in val_loader:
|
|
output = model(batch)
|
|
```
|
|
**Symptom:** High memory usage during evaluation
|
|
**Fix:** Always wrap evaluation in `torch.no_grad()`
|
|
|
|
|
|
### Performance Pitfalls
|
|
|
|
❌ **Pitfall 4: Repeated CPU-GPU Transfers**
|
|
```python
|
|
# WRONG
|
|
for batch in dataloader:
|
|
batch = batch.cpu().numpy() # Transfer to CPU
|
|
batch = preprocess(batch) # Process on CPU
|
|
batch = torch.from_numpy(batch).cuda() # Back to GPU
|
|
|
|
# CORRECT
|
|
for batch in dataloader:
|
|
batch = batch.cuda()
|
|
batch = preprocess_gpu(batch) # Keep on GPU
|
|
```
|
|
**Symptom:** Low GPU utilization, slow training
|
|
**Fix:** Minimize CPU↔GPU transfers, use GPU operations
|
|
|
|
|
|
❌ **Pitfall 5: Non-Contiguous Tensor Operations**
|
|
```python
|
|
# WRONG
|
|
x = x.transpose(0, 1) # Creates view
|
|
for _ in range(1000):
|
|
x = x + 1 # Slow on non-contiguous tensor
|
|
|
|
# CORRECT
|
|
x = x.transpose(0, 1).contiguous() # Make contiguous
|
|
for _ in range(1000):
|
|
x = x + 1 # Fast on contiguous tensor
|
|
```
|
|
**Symptom:** Unexpectedly slow operations
|
|
**Fix:** Call `.contiguous()` before repeated operations
|
|
|
|
|
|
❌ **Pitfall 6: Unnecessary Memory Allocations in Loops**
|
|
```python
|
|
# WRONG
|
|
for _ in range(1000):
|
|
temp = torch.zeros(1024, 1024).cuda()
|
|
result = process(temp)
|
|
|
|
# CORRECT
|
|
temp = torch.zeros(1024, 1024).cuda()
|
|
for _ in range(1000):
|
|
temp.zero_() # Reuse buffer
|
|
result = process(temp)
|
|
```
|
|
**Symptom:** Slow iteration, memory fragmentation
|
|
**Fix:** Pre-allocate and reuse buffers
|
|
|
|
|
|
## Debugging Methodology
|
|
|
|
### When You Get CUDA OOM
|
|
|
|
**Step 1: Get current memory state**
|
|
```python
|
|
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
|
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
|
|
print(torch.cuda.memory_summary())
|
|
```
|
|
|
|
**Step 2: Check the obvious**
|
|
- [ ] `optimizer.zero_grad()` called every iteration?
|
|
- [ ] Evaluation wrapped in `torch.no_grad()`?
|
|
- [ ] Storing tensors instead of `.item()`?
|
|
- [ ] RNN hidden states detached?
|
|
|
|
**Step 3: Binary search for leak**
|
|
```python
# Re-enable pieces of the loop one at a time and watch where memory starts growing:
# 1. forward only   2. + loss and backward   3. + optimizer step
torch.cuda.reset_peak_memory_stats()
for step, (data, target) in enumerate(dataloader):
    output = model(data.cuda())                   # stage 1: forward only
    # loss = criterion(output, target.cuda())     # stage 2
    # loss.backward()
    # optimizer.step(); optimizer.zero_grad()     # stage 3
    if step % 50 == 0:
        print(f"step {step}: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
```
|
|
|
|
**Step 4: Profile memory**
|
|
```python
|
|
# Use memory profiler to find exact allocation
|
|
torch.cuda.reset_peak_memory_stats()
|
|
# Run one iteration
|
|
print(torch.cuda.memory_summary())
|
|
```
|
|
|
|
**Step 5: Check for fragmentation**
|
|
```python
|
|
# If reserved >> allocated, you have fragmentation
|
|
allocated = torch.cuda.memory_allocated()
|
|
reserved = torch.cuda.memory_reserved()
|
|
if reserved > allocated * 1.5:
|
|
print("⚠️ Memory fragmentation detected")
|
|
torch.cuda.empty_cache() # May help, but not guaranteed
|
|
```
|
|
|
|
|
|
### When Training is Slow
|
|
|
|
**Step 1: Profile CUDA time**
|
|
```python
|
|
import torch.utils.benchmark as benchmark
|
|
|
|
# Profile one iteration
|
|
def train_step():
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
t = benchmark.Timer(stmt='train_step()', globals=globals())
|
|
print(t.timeit(10))
|
|
```
|
|
|
|
**Step 2: Check GPU utilization**
|
|
```bash
|
|
# In terminal
|
|
nvidia-smi -l 1 # Update every second
|
|
```
|
|
If GPU utilization < 80%, bottleneck is likely:
|
|
- Data loading (use more `num_workers`)
|
|
- CPU preprocessing (move to GPU)
|
|
- CPU-GPU transfers (pin memory)
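The data-loading side of that list usually comes down to `DataLoader` settings. A sketch (good worker counts depend on your CPU and storage; `dataset` is a placeholder):

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,            # parallel CPU-side loading/preprocessing
    pin_memory=True,          # enables fast async .cuda(non_blocking=True) copies
    persistent_workers=True,  # keep workers alive between epochs
    prefetch_factor=2,        # batches pre-fetched per worker
)
```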
|
|
|
|
**Step 3: Profile operations**
|
|
```python
|
|
# Use PyTorch profiler to identify slow operations
|
|
with torch.profiler.profile() as prof:
|
|
train_step()
|
|
print(prof.key_averages().table(sort_by="cuda_time_total"))
|
|
```
|
|
|
|
**Step 4: Check for synchronization**
|
|
```python
|
|
# Frequent CPU-GPU synchronization kills performance
|
|
# Common causes:
|
|
# - .item() calls in training loop
|
|
# - .cpu() calls
|
|
# - print() statements with tensor values
|
|
# - Assertions on tensor values
|
|
```
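Recent PyTorch versions can flag these synchronization points automatically via `torch.cuda.set_sync_debug_mode`. A sketch (debug-only since it adds overhead; the loop names are placeholders):

```python
import torch

torch.cuda.set_sync_debug_mode("warn")   # "error" raises instead of warning

output = model(batch)
loss = criterion(output, target)
print(loss.item())                        # .item() forces a sync and gets flagged here

torch.cuda.set_sync_debug_mode("default")
```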
|
|
|
|
|
|
## Edge Cases and Advanced Scenarios
|
|
|
|
### Edge Case 1: Gradient Checkpointing Interaction
|
|
|
|
**Scenario:** Using gradient checkpointing for large models but still getting OOM
|
|
|
|
```python
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
# Gradient checkpointing trades compute for memory
|
|
# But you can still leak memory!
|
|
|
|
# ❌ WRONG: Leaking even with checkpointing
|
|
class Model(nn.Module):
|
|
def forward(self, x):
|
|
# Checkpointing helps, but if you retain intermediate results...
|
|
intermediate = checkpoint(self.layer1, x)
|
|
self.cached_intermediate = intermediate # ❌ Retains graph!
|
|
return checkpoint(self.layer2, intermediate)
|
|
|
|
# ✅ CORRECT: Don't cache checkpointed results
|
|
class Model(nn.Module):
|
|
def forward(self, x):
|
|
intermediate = checkpoint(self.layer1, x)
|
|
return checkpoint(self.layer2, intermediate)
|
|
# No caching, memory saved
|
|
```
|
|
|
|
**Key insight:** Gradient checkpointing recomputes forward pass during backward. Caching checkpointed results defeats the purpose and leaks memory.
|
|
|
|
|
|
### Edge Case 2: Dynamic Computation Graphs
|
|
|
|
**Scenario:** Graph structure changes each iteration (e.g., different sequence lengths, variable branches)
|
|
|
|
```python
|
|
# Dynamic graphs can cause memory issues if not careful
|
|
|
|
# ❌ WRONG: Accumulating different graphs
|
|
graph_stats = []
|
|
for batch in dataloader:
|
|
# Sequence length varies each batch
|
|
output = model(batch) # Different graph each time
|
|
graph_stats.append(output.grad_fn) # ❌ Retains ALL graphs!
|
|
|
|
# ✅ CORRECT: Don't retain grad_fn, detach appropriately
|
|
for batch in dataloader:
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
# Don't store anything with .grad_fn
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
```
|
|
|
|
**Key insight:** Dynamic graphs are fine, but don't accumulate references to different graphs across iterations.
|
|
|
|
|
|
### Edge Case 3: DistributedDataParallel (DDP) Memory Management
|
|
|
|
**Scenario:** DDP training with memory issues
|
|
|
|
```python
|
|
import os
import torch.distributed as dist
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
# ❌ WRONG: Multiple issues with DDP
|
|
model = MyModel().cuda()
|
|
model = DDP(model) # Missing device_ids!
|
|
|
|
for batch in dataloader:
|
|
output = model(batch)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
# Missing: optimizer.zero_grad() BEFORE backward in some DDP scenarios
|
|
|
|
# ✅ CORRECT: Proper DDP setup
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

model = MyModel().to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)

    optimizer.zero_grad(set_to_none=True)  # Before forward in DDP
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
```
|
|
|
|
**Key insights:**
|
|
- Specify `device_ids` and `output_device` explicitly
|
|
- `zero_grad()` placement critical in DDP
|
|
- Ensure all data on correct local device
|
|
|
|
|
|
### Edge Case 4: Custom CUDA Kernels Memory Management
|
|
|
|
**Scenario:** Using custom CUDA kernels or third-party extensions
|
|
|
|
```python
|
|
# Custom CUDA kernels may not play nice with PyTorch's memory management
|
|
|
|
import custom_cuda_kernel # Hypothetical extension
|
|
|
|
# ❌ WRONG: Not checking tensor lifetime
|
|
def forward(x):
|
|
y = custom_cuda_kernel.process(x) # Allocates CUDA memory
|
|
# If kernel doesn't register with PyTorch, memory not tracked!
|
|
return y
|
|
|
|
# ✅ CORRECT: Verify kernel registers memory with PyTorch
|
|
def forward(x):
|
|
y = custom_cuda_kernel.process(x)
|
|
|
|
# Check if memory is tracked
|
|
print(f"PyTorch tracked: {torch.cuda.memory_allocated()}")
|
|
# If custom kernel allocated memory not shown, manual cleanup needed
|
|
|
|
return y
|
|
```
|
|
|
|
**Key insight:** Custom CUDA code may bypass PyTorch's memory tracking. Use `torch.cuda.memory_allocated()` to verify, and ensure custom kernels use PyTorch's allocator.
|
|
|
|
|
|
### Edge Case 5: Nested Autocast Contexts
|
|
|
|
**Scenario:** Nested autocast contexts (e.g., custom training loop with autocast, calling library that also uses autocast)
|
|
|
|
```python
|
|
from torch.cuda.amp import autocast
|
|
|
|
# ❌ POTENTIAL ISSUE: Nested autocast with different settings
|
|
with autocast(): # Outer context
|
|
output1 = model1(x)
|
|
with autocast(enabled=False): # Inner context disables
|
|
output2 = model2(output1) # Back to float32
|
|
# output1 is float16, output2 is float32
|
|
loss = criterion(output1, output2) # Type mismatch possible!
|
|
|
|
# ✅ CORRECT: Be aware of autocast nesting
|
|
with autocast():
|
|
output1 = model1(x)
|
|
# If you need float32 for specific operation:
|
|
with autocast(enabled=False):
|
|
output2_float32 = model2(output1.float()) # Explicit cast
|
|
loss = criterion(output1, output2_float32.half()) # Explicit cast back
|
|
```
|
|
|
|
**Key insight:** Autocast contexts can nest. Be explicit about dtype when mixing precision contexts.
|
|
|
|
|
|
### Edge Case 6: Memory Fragmentation with Varying Batch Sizes
|
|
|
|
**Scenario:** Training with variable batch sizes causing fragmentation
|
|
|
|
```python
|
|
# Variable batch sizes can fragment memory over time
|
|
|
|
# ❌ WRONG: Varying allocations fragment memory pool
|
|
for batch in dataloader: # Batch sizes: 32, 64, 32, 128, 32...
|
|
output = model(batch) # Different allocations each time
|
|
# CUDA memory becomes fragmented
|
|
# Reserved >> Allocated
|
|
|
|
# ✅ BETTER: Use gradient accumulation with fixed effective batch size
|
|
accumulation_steps = 4
|
|
effective_batch_size = 32
|
|
|
|
for i, (data, target) in enumerate(dataloader):
    # Always process fixed-size mini-batches
    data = data[:effective_batch_size]
    target = target[:effective_batch_size]
    output = model(data)
    loss = criterion(output, target) / accumulation_steps
|
|
loss.backward()
|
|
|
|
if (i + 1) % accumulation_steps == 0:
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
|
|
# ✅ OR: Periodically defragment
|
|
if epoch % 10 == 0:
|
|
torch.cuda.empty_cache() # Release fragmented memory
|
|
gc.collect()
|
|
```
|
|
|
|
**Key insight:** Variable batch sizes fragment CUDA memory pool. Use fixed sizes or periodic cleanup.
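In recent PyTorch versions the caching allocator can also be tuned for this kind of workload through the `PYTORCH_CUDA_ALLOC_CONF` environment variable; treat the exact options as version-dependent and verify them against your release notes:

```python
import os

# Must be set before the first CUDA allocation in the process.
# expandable_segments lets the allocator grow existing segments instead of
# reserving many differently sized blocks; max_split_size_mb limits block splitting.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"

import torch
x = torch.randn(1024, 1024, device="cuda")  # allocator initialized with the settings above
```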
|
|
|
|
|
|
## Quick Reference: Memory & Performance Checklist
|
|
|
|
### Before Training Starts
|
|
- [ ] Model on correct device
|
|
- [ ] All data moved to device once
|
|
- [ ] Pin memory enabled in DataLoader
|
|
- [ ] Gradient checkpointing if needed for large models
|
|
|
|
### Training Loop Must-Haves
|
|
- [ ] `optimizer.zero_grad()` at start of iteration
|
|
- [ ] Only store `.item()` for scalars, not tensors
|
|
- [ ] Detach RNN/LSTM hidden states between batches
|
|
- [ ] Use `torch.no_grad()` for validation
|
|
|
|
### Performance Optimization
|
|
- [ ] Pre-allocate buffers for fixed-size tensors
|
|
- [ ] Call `.contiguous()` before repeated operations on views
|
|
- [ ] Use in-place operations where safe
|
|
- [ ] Minimize CPU-GPU transfers
|
|
- [ ] Use mixed precision (`autocast`) if appropriate
|
|
|
|
### When Debugging
|
|
- [ ] Check memory stats: `torch.cuda.memory_allocated()`
|
|
- [ ] Profile with `torch.profiler`
|
|
- [ ] Check GPU utilization: `nvidia-smi`
|
|
- [ ] Look for non-contiguous tensors: `.is_contiguous()`
|
|
- [ ] Check device consistency across all tensors
|
|
|
|
|
|
## Example: Complete Memory-Efficient Training Loop
|
|
|
|
```python
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
|
|
|
|
# Setup
|
|
device = torch.device("cuda:0")
|
|
model = MyModel().to(device)
|
|
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()  # assumed loss for this example
scaler = GradScaler()  # For mixed precision
|
|
dataloader = DataLoader(dataset, pin_memory=True, num_workers=4)
|
|
|
|
# Pre-allocate if using fixed-size buffers (example)
|
|
# temp_buffer = torch.zeros(batch_size, hidden_dim).to(device)
|
|
|
|
# Training loop
|
|
model.train()
|
|
for epoch in range(num_epochs):
|
|
for batch_idx, (data, target) in enumerate(dataloader):
|
|
# 1. Move data to device (async with pin_memory)
|
|
data = data.to(device, non_blocking=True)
|
|
target = target.to(device, non_blocking=True)
|
|
|
|
# 2. Zero gradients (set_to_none=True for efficiency)
|
|
optimizer.zero_grad(set_to_none=True)
|
|
|
|
# 3. Forward pass with mixed precision
|
|
with autocast():
|
|
output = model(data)
|
|
loss = criterion(output, target)
|
|
|
|
# 4. Backward pass with gradient scaling
|
|
scaler.scale(loss).backward()
|
|
scaler.step(optimizer)
|
|
scaler.update()
|
|
|
|
# 5. Logging (use .item() to avoid retaining graph!)
|
|
if batch_idx % 100 == 0:
|
|
print(f"Loss: {loss.item():.4f}")
|
|
|
|
# 6. Validation (critical: no_grad context)
|
|
model.eval()
|
|
val_loss = 0
|
|
with torch.no_grad():
|
|
for data, target in val_loader:
|
|
data = data.to(device, non_blocking=True)
|
|
target = target.to(device, non_blocking=True)
|
|
output = model(data)
|
|
val_loss += criterion(output, target).item()
|
|
|
|
model.train()
|
|
print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader):.4f}")
|
|
|
|
# Optional: Check memory usage
|
|
if epoch % 10 == 0:
|
|
print(f"Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
|
```
|
|
|
|
**Why this is memory-efficient:**
|
|
1. ✅ Async data transfer with `pin_memory` and `non_blocking=True`
|
|
2. ✅ `set_to_none=True` for gradient zeroing (more efficient)
|
|
3. ✅ Mixed precision with `autocast` (reduces memory)
|
|
4. ✅ Only `.item()` stored for logging (no graph retention)
|
|
5. ✅ Validation wrapped in `torch.no_grad()`
|
|
6. ✅ All tensors on same device (no implicit transfers)
|
|
|
|
|
|
## References
|
|
|
|
**PyTorch Documentation:**
|
|
- Memory Management: https://pytorch.org/docs/stable/notes/cuda.html
|
|
- Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
|
- Mixed Precision: https://pytorch.org/docs/stable/amp.html
|
|
|
|
**Related Skills:**
|
|
- performance-profiling (deeper profiling techniques)
|
|
- distributed-training-strategies (multi-GPU memory management)
|
|
- mixed-precision-and-optimization (detailed autocast usage)
|
|
- debugging-techniques (systematic PyTorch debugging)
|