# Gradient Management Skill

## When to Use This Skill

Use this skill when:
- Loss becomes NaN or Inf during training
- Training is unstable with loss spikes
- User asks about gradient clipping
- User wants larger batch size but has OOM issues
- User mentions "exploding gradients" or "vanishing gradients"
- Gradients are very large (>100) or very small (<1e-8)
- Implementing gradient accumulation
- Using mixed precision (AMP) with gradient clipping
- User asks "why is my training unstable?"
- Training Transformers, RNNs, or very deep networks
- User implements gradient accumulation without loss scaling (RED FLAG)
- User clips gradients after optimizer.step() (RED FLAG)
- User doesn't unscale before clipping with AMP (RED FLAG)
- Reinforcement learning (policy gradients often explode)
- Distributed training with gradient synchronization questions
- User says "just lower learning rate" for NaN loss (may need clipping)

Do NOT use when:
- Training is stable with no gradient issues
- User has architecture questions unrelated to gradients
- User only asks about learning rate (use learning-rate-scheduling skill)
- User asks about data issues (different problem space)

## Core Principles

### 1. The Critical Importance of Gradient Management

**Gradients are the foundation of neural network training:**
- Backpropagation computes gradients of the loss w.r.t. parameters
- The optimizer uses gradients to update parameters
- Gradient magnitude determines update size
- Gradient stability determines training stability
- Wrong gradient handling → training failure (NaN, no convergence)

**Common Impact:**
- Gradient clipping: The difference between stable training and NaN loss
- Gradient accumulation: Train with an 8x larger effective batch size on the same hardware
- Proper diagnosis: 1-2 hours to fix vs days of confusion
- Mixed precision integration: 2x speedup without breaking training

**This is NOT optional:**
- Every Transformer paper uses gradient clipping
- Gradient accumulation is standard for large models
- Production training code always monitors gradients
- Ignoring gradients → fragile, unreliable training

### 2. Gradient Flow in Training

**Understanding the training loop gradient flow:**

```python
# Step 1: Zero gradients from previous iteration
optimizer.zero_grad()

# Step 2: Forward pass (compute loss)
output = model(input)
loss = criterion(output, target)

# Step 3: Backward pass (compute gradients)
# This computes: param.grad = ∂loss/∂param for all parameters
loss.backward()

# Step 4: [OPTIONAL] Modify gradients (clipping, scaling, etc.)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Step 5: Optimizer step (update parameters using gradients)
# This does: param = param - lr * param.grad (simplified)
optimizer.step()
```

**Critical ordering:**
1. Gradients are computed by `backward()`
2. Gradients can be modified between `backward()` and `step()`
3. Gradients are consumed by `step()` to update parameters
4. Gradient modifications MUST happen after `backward()` and before `step()`

**Mental model** (see the minimal sketch below):
- `backward()` produces gradients
- Your code can inspect/modify gradients
- `step()` consumes gradients to update parameters
- Modifications after `step()` are useless (gradients already consumed)
- Modifications before `backward()` are useless (gradients don't exist yet)
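
A minimal sketch of that window, using a toy model and made-up data (not from any particular codebase): `.grad` is `None` before `backward()`, populated after it, and consumed by `step()`.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 10), torch.randn(4, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)

print(model.weight.grad)            # None - no gradients exist before backward()
loss.backward()
print(model.weight.grad.norm())     # Populated - safe to inspect or modify here

optimizer.step()                    # Consumes the gradients to update parameters
```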

## Gradient Clipping

### Why Gradient Clipping Matters

**The exploding gradients problem:**
- Deep networks multiply gradients through the chain rule
- Each layer multiplies the gradient by weights and activation derivatives
- If these multiplications are >1, gradients grow exponentially
- Large gradients → large parameter updates → training instability
- Extremely large gradients → NaN or Inf loss

**Real-world symptoms:**
- Loss suddenly jumps to NaN after normal training
- Loss oscillates wildly between iterations
- Training is stable initially, then diverges
- Parameters become NaN or Inf
- Gradient norms >100 or >1000

**Why it happens:**
- Transformers: The attention mechanism can amplify gradients
- RNNs: Backpropagation through time multiplies gradients across timesteps
- Very deep networks: Many layers multiply gradients
- Poor initialization: Large initial weights amplify gradients
- High learning rates: Amplify already-large gradients

### Norm-Based Gradient Clipping (Primary Method)

**The standard solution:**

```python
# Clip gradients by global norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Complete training loop:
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

**What it does:**
1. Computes the total gradient norm: `total_norm = sqrt(sum(g^2 for g in all gradients))`
2. If `total_norm > max_norm`:
   - Scaling factor = `max_norm / total_norm`
   - All gradients are multiplied by this factor
3. Result: Gradient direction preserved, magnitude limited

**Why this is good:**
- Preserves gradient direction (doesn't distort the signal)
- Only scales when needed (if total_norm ≤ max_norm, no change)
- Global view (considers all parameters together)
- Mathematically simple (rescales the gradient vector onto the ball of radius max_norm)
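
For intuition, here is a hand-rolled sketch of the same computation (in real code use the built-in `torch.nn.utils.clip_grad_norm_`; the small epsilon mirrors what the built-in uses to avoid division by zero):

```python
import torch

def clip_by_global_norm(parameters, max_norm: float) -> float:
    """Sketch of norm-based clipping: if the global L2 norm exceeds max_norm,
    scale every gradient by max_norm / total_norm. Returns the pre-clip norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    total_norm = torch.sqrt(sum((g.detach() ** 2).sum() for g in grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:              # only rescale when the norm is too large
        for g in grads:
            g.mul_(clip_coef)
    return total_norm.item()
```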

**Typical values for max_norm:**

```python
# Transformers (BERT, GPT, T5)
max_norm = 1.0   # Most common
max_norm = 5.0   # Sometimes used for very large models

# RNNs/LSTMs
max_norm = 0.5   # More aggressive clipping
max_norm = 1.0   # Also common

# Reinforcement Learning (policy gradients)
max_norm = 0.5   # RL gradients are particularly unstable

# CNNs (ResNets, etc.)
# Usually DON'T clip - residual connections provide stability
# Only clip if you observe instability

# Very deep networks (>100 layers)
max_norm = 1.0   # Helps with stability
```

**When to use norm-based clipping:**
✅ Training Transformers (almost always needed)
✅ Training RNNs/LSTMs (essential for long sequences)
✅ Reinforcement learning (policy gradients)
✅ Any time you see loss → NaN
✅ Loss spikes or wild oscillations
✅ Very deep networks (>50 layers)

**When NOT to use:**
❌ Stable CNN training (ResNet on ImageNet)
❌ Training is already stable with no issues
❌ As a preemptive measure without evidence of need

### Value-Based Gradient Clipping (Rare)

**Clips each gradient element individually:**

```python
# Clip each gradient value to [-clip_value, +clip_value]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# What it does:
for param in model.parameters():
    if param.grad is not None:
        param.grad.clamp_(-clip_value, clip_value)
```

**Difference from norm-based** (see the numeric sketch below):
- Norm-based: Scales the entire gradient vector to limit its total magnitude
- Value-based: Clamps each gradient element independently
- Value-based is MORE aggressive (can change gradient direction)
- Value-based treats all parameters equally (ignores scale differences)
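
A tiny numeric sketch of the direction difference, using a made-up 2-element gradient:

```python
import torch

g = torch.tensor([3.0, 4.0])           # toy "gradient", norm = 5.0

# Norm-based clipping to max_norm=1: scale by 1/5 -> direction preserved
norm_clipped = g * (1.0 / g.norm())     # tensor([0.6, 0.8])

# Value-based clipping to clip_value=1: clamp each element -> direction changes
value_clipped = g.clamp(-1.0, 1.0)      # tensor([1.0, 1.0])

print(norm_clipped, value_clipped)
```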

**When to use value-based clipping:**
- Debugging: Identify which specific parameters have large gradients
- Extreme outliers: Some parameters have huge gradients while others are normal
- Legacy code: Some old papers use this

**Usually prefer norm-based:**
- Norm-based is standard in modern deep learning
- Preserves gradient direction
- Better theoretical properties
- Used in all major Transformer implementations

### Complete Clipping Implementation

```python
import torch
import torch.nn as nn

# Model and optimizer
model = TransformerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop with gradient clipping
for epoch in range(num_epochs):
    for batch in train_loader:
        # 1. Zero gradients
        optimizer.zero_grad()

        # 2. Forward pass
        output = model(batch['input'])
        loss = criterion(output, batch['target'])

        # 3. Backward pass (compute gradients)
        loss.backward()

        # 4. Clip gradients (CRITICAL: after backward, before step)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # 5. Optimizer step (update parameters)
        optimizer.step()
```

**Common mistakes - WRONG ORDER:**

```python
# WRONG: Clipping after optimizer.step()
loss.backward()
optimizer.step()
clip_grad_norm_(model.parameters(), 1.0)  # ❌ Too late! Parameters already updated.

# WRONG: Clipping before backward()
optimizer.zero_grad()
clip_grad_norm_(model.parameters(), 1.0)  # ❌ No gradients exist yet!
loss.backward()
optimizer.step()

# RIGHT: Clipping after backward(), before step()
loss.backward()                           # Compute gradients
clip_grad_norm_(model.parameters(), 1.0)  # Modify gradients
optimizer.step()                          # Use modified gradients
```

### How to Choose max_norm Value

**Start with standard values:**

```python
# Default starting point for Transformers
max_norm = 1.0

# If still unstable (loss spikes)
max_norm = 0.5  # More aggressive clipping

# If training seems too constrained (slow convergence)
max_norm = 2.0  # Less aggressive clipping
```

**Systematic tuning:**

1. **Monitor gradient norms WITHOUT clipping:**

   ```python
   # Check typical gradient magnitudes
   total_norm = 0.0
   for p in model.parameters():
       if p.grad is not None:
           param_norm = p.grad.data.norm(2)
           total_norm += param_norm.item() ** 2
   total_norm = total_norm ** 0.5
   print(f"Gradient norm: {total_norm:.4f}")
   ```

2. **Set max_norm based on typical norms:**
   - If typical norms are 0.5-2.0, set max_norm=2.0 or 3.0
   - If typical norms are 5-10, set max_norm=5.0 or 10.0
   - Goal: Clip outliers without affecting normal gradients

3. **Verify clipping is helping** (a `compute_grad_norm` helper is sketched below):

   ```python
   # Log how often clipping activates
   grad_norm_before = compute_grad_norm(model)
   clip_grad_norm_(model.parameters(), max_norm=1.0)
   grad_norm_after = compute_grad_norm(model)

   if grad_norm_before > max_norm:
       print(f"Clipped: {grad_norm_before:.4f} -> {grad_norm_after:.4f}")
   ```
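
`compute_grad_norm` is not defined in this skill; one possible implementation is below. Note that `torch.nn.utils.clip_grad_norm_` already returns the pre-clip norm, so a single call can replace the "before" measurement.

```python
import torch

def compute_grad_norm(model) -> float:
    """One possible helper: the global L2 norm over all parameter gradients."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5

# Shortcut: clip_grad_norm_ returns the total norm measured BEFORE clipping:
# grad_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```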

**Signs you need clipping:**
- Gradient norms occasionally >10 or >100
- Loss occasionally spikes or becomes NaN
- Training is initially stable then diverges
- Gradient norms grow over time

**Signs your max_norm is too low:**
- Clipping activates on EVERY iteration
- Training converges very slowly
- Gradient norm is always exactly max_norm (always clipping)

**Signs your max_norm is too high:**
- Still getting NaN or loss spikes
- Clipping never activates
- Not solving the stability problem
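
To turn these signs into a number, a small sketch (assuming the usual `model`, `optimizer`, `criterion`, and `train_loader` from the loops above) that logs how often clipping actually fires:

```python
import torch

max_norm = 1.0
clipped_steps, total_steps = 0, 0

for data, target in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()

    # clip_grad_norm_ returns the total norm measured BEFORE clipping
    pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    clipped_steps += int(pre_clip_norm.item() > max_norm)
    total_steps += 1

    optimizer.step()

print(f"Clipping fired on {clipped_steps / total_steps:.1%} of steps")
# Rule of thumb: ~100% -> max_norm likely too low; ~0% while loss still
# spikes or goes NaN -> max_norm likely too high to help.
```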

## Gradient Accumulation

### Why Gradient Accumulation Matters

**The memory vs batch size problem:**
- Larger batch sizes often improve training (more stable gradients)
- Larger batches require more GPU memory
- Memory is limited (GPU VRAM)
- Can't always fit the desired batch size in memory

**Example scenario:**
- Want batch size 256 for stable training
- Only fit batch size 32 in GPU memory
- Can't afford a bigger GPU
- Solution: Gradient accumulation

**What gradient accumulation does:**
- Accumulate gradients over multiple small batches
- Update parameters once with the accumulated gradients
- Equivalent to training with one large batch
- Same results, but fits in memory

**Real-world impact:**
- Train with a 4-8x larger effective batch size on the same hardware
- Standard technique in production training
- Used in all large model training (GPT, BERT, etc.)
- Essential for competitive performance on limited hardware

### Correct Gradient Accumulation Implementation

**The critical implementation:**

```python
# Want effective batch size 256, but can only fit 64 in memory
# Solution: Accumulate over 4 steps (256 = 64 * 4)

accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(train_loader):
    # Forward pass
    output = model(data)
    loss = criterion(output, target)

    # Backward pass with CRITICAL loss scaling
    # MUST divide loss by accumulation_steps!
    (loss / accumulation_steps).backward()

    # Update weights every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

**Why scale loss by accumulation_steps?**

```python
# Without scaling:
loss.backward()  # Adds gradients: param.grad += ∂loss/∂param

# After 4 accumulation steps:
# param.grad = ∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param
# This is 4x larger than a single batch!

# With scaling:
(loss / 4).backward()  # Adds: param.grad += (∂loss/∂param) / 4

# After 4 accumulation steps:
# param.grad = (∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param) / 4
# This is the AVERAGE gradient - equivalent to a single large batch!
```

**Mathematical equivalence** (verified numerically in the sketch below):
- Large batch loss: `L = (l1 + l2 + l3 + l4) / 4` (mean over samples)
- Large batch gradient: `∂L/∂param = (∂l1/∂param + ∂l2/∂param + ∂l3/∂param + ∂l4/∂param) / 4`
- Accumulated gradient: Same result!
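
A quick numerical check of this equivalence on a toy linear model with made-up data (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(8, 5), torch.randn(8, 1)   # 8 toy samples

def full_batch_grad():
    model = nn.Linear(5, 1)
    nn.functional.mse_loss(model(x), y).backward()
    return model.weight.grad.clone(), model

def accumulated_grad(ref_model, steps=4):
    model = nn.Linear(5, 1)
    model.load_state_dict(ref_model.state_dict())   # identical weights
    for xb, yb in zip(x.chunk(steps), y.chunk(steps)):
        (nn.functional.mse_loss(model(xb), yb) / steps).backward()
    return model.weight.grad.clone()

g_full, ref = full_batch_grad()
g_accum = accumulated_grad(ref)
print(torch.allclose(g_full, g_accum, atol=1e-6))   # True (up to float error)
```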

### Common Gradient Accumulation Mistakes

**WRONG: Not scaling loss**

```python
# ❌ WRONG - Gradients are accumulation_steps times too large!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    loss.backward()  # ❌ Not scaled!

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Result: Gradients are summed instead of averaged -
# equivalent to multiplying the learning rate by accumulation_steps
# Acts like the LR is 4x too high → unstable training
```

**WRONG: Scaling gradients instead of loss**

```python
# ❌ WRONG - Inefficient and error-prone!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    loss.backward()

    # Manually scale gradients
    for param in model.parameters():
        if param.grad is not None:
            param.grad /= accumulation_steps  # ❌ Inefficient!

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Why wrong:
# - More code, more error-prone
# - Less efficient (iterates all parameters)
# - As written, it also divides the already-accumulated gradients again
#   on every iteration, so the math is wrong, not just slow
# - Scaling the loss is cleaner and standard
```

**WRONG: Forgetting to zero_grad() after update**

```python
# ❌ WRONG - Gradients keep accumulating forever!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        # ❌ Missing optimizer.zero_grad()!
        # The next accumulation will add to these gradients!

# Result: Gradients are never reset and accumulate across updates
# Acts like accumulation_steps grows over time
```

**WRONG: Zeroing gradients inside the accumulation loop**

```python
# ❌ WRONG - Clears gradients before accumulating!
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    optimizer.zero_grad()  # ❌ Clears the previous accumulation!

    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()

# Result: Only the last batch's gradients are used (no accumulation!)
```

### Complete Gradient Accumulation Implementation

```python
import torch
import torch.nn as nn

# Configuration
batch_size_per_step = 64          # What fits in memory
accumulation_steps = 4            # Accumulate over 4 steps
effective_batch_size = batch_size_per_step * accumulation_steps  # = 256

# DataLoader with smaller batch size
train_loader = DataLoader(dataset, batch_size=batch_size_per_step)

# Model and optimizer
model = TransformerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
optimizer.zero_grad()  # Zero once before the accumulation loop

for epoch in range(num_epochs):
    for i, (data, target) in enumerate(train_loader):
        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass with scaled loss
        (loss / accumulation_steps).backward()

        # Update every accumulation_steps
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Handle remaining batches at the end of the epoch
    # (if the number of batches is not divisible by accumulation_steps)
    if len(train_loader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
```

### Gradient Accumulation with Gradient Clipping

**Correct order:**

```python
accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(train_loader):
    output = model(data)
    loss = criterion(output, target)

    # Scale loss and backward
    (loss / accumulation_steps).backward()

    # Update every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        # Clip BEFORE the optimizer step (on the accumulated gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        optimizer.zero_grad()
```

**Why this order?**
- Gradients accumulate over `accumulation_steps` iterations
- After accumulation, gradients are ready for clipping
- Clip once on the full accumulated gradients
- Then update parameters with the clipped gradients

**WRONG: Clipping on each accumulation step:**

```python
# ❌ WRONG - Clips partial gradients!
for i, (data, target) in enumerate(train_loader):
    loss = criterion(model(data), target)
    (loss / accumulation_steps).backward()

    # ❌ Clipping partial gradients!
    clip_grad_norm_(model.parameters(), max_norm=1.0)

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Why wrong:
# - Clipping partial gradients distorts the accumulation
# - Each partial gradient is ~1/4 of the final gradient
# - Clipping these small gradients uses the wrong threshold
# - Clip ONCE on the final accumulated gradient
```

### Gradient Accumulation with Learning Rate Scheduling

**Correct implementation:**

```python
accumulation_steps = 4
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

optimizer.zero_grad()

for i, (data, target) in enumerate(train_loader):
    output = model(data)
    loss = criterion(output, target)
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

        # Step the scheduler AFTER the optimizer step (once per update)
        scheduler.step()
```

**Key points:**
- The scheduler steps once per parameter update (not per batch)
- Matches the effective batch size timing
- The scheduler sees `num_batches / accumulation_steps` total steps
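
A small sketch of sizing the scheduler accordingly (names as in the loop above):

```python
# Size the scheduler in *updates*, not batches, when accumulating
updates_per_epoch = len(train_loader) // accumulation_steps
total_updates = updates_per_epoch * num_epochs

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates)
```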

## Gradient Diagnosis

### Why Diagnosis Matters

**Don't guess - measure:**
- "Training isn't working" could be many issues
- Gradient issues have specific symptoms
- Measuring gradients identifies the problem
- Diagnosis guides the solution

**What to diagnose:**
1. Gradient magnitudes (too large? too small?)
2. Gradient distribution across layers (vanishing in early layers?)
3. NaN or Inf gradients (numerical issues?)
4. Gradient patterns over time (growing? shrinking?)

### Checking Gradient Magnitudes

**Basic gradient checking:**

```python
def check_gradients(model):
    """Check gradient magnitudes for all parameters"""
    total_norm = 0.0
    param_norms = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            # Compute the gradient norm for this parameter
            param_norm = param.grad.data.norm(2).item()
            param_norms[name] = param_norm
            total_norm += param_norm ** 2

    total_norm = total_norm ** 0.5

    print(f"Total gradient norm: {total_norm:.4f}")

    # Show top 5 largest gradients
    print("\nLargest gradients:")
    for name, norm in sorted(param_norms.items(), key=lambda x: x[1], reverse=True)[:5]:
        print(f"  {name}: {norm:.4f}")

    # Show top 5 smallest gradients
    print("\nSmallest gradients:")
    for name, norm in sorted(param_norms.items(), key=lambda x: x[1])[:5]:
        print(f"  {name}: {norm:.4e}")

    return total_norm

# Usage in training loop:
loss.backward()
grad_norm = check_gradients(model)
optimizer.step()
```

**What to look for:**

```python
# Healthy gradients:
# Total norm: 0.1 to 10
# Layer norms: Similar order of magnitude across layers
# No NaN or Inf values

# Exploding gradients:
# Total norm: >100 or >1000
# Some layers have huge gradients (>10)
# → Solution: Gradient clipping

# Vanishing gradients:
# Total norm: <1e-6
# Early layers have much smaller gradients than late layers
# → Solution: Better activation/initialization/architecture

# NaN gradients:
# Any gradient is NaN or Inf
# → Solution: Check for numerical instability in the loss or model
```

### Comprehensive Gradient Diagnostics

```python
def diagnose_gradients(model, threshold_low=1e-8, threshold_high=100):
    """
    Comprehensive gradient diagnostics with automatic issue detection

    Args:
        model: PyTorch model
        threshold_low: Threshold for vanishing gradients
        threshold_high: Threshold for exploding gradients

    Returns:
        dict with diagnostic information
    """
    diagnostics = {
        'total_norm': 0.0,
        'param_norms': {},
        'has_nan': False,
        'has_inf': False,
        'vanishing': [],
        'exploding': [],
    }

    total_norm = 0.0

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.data

            # Check for NaN or Inf
            if torch.isnan(grad).any():
                diagnostics['has_nan'] = True
                print(f"⚠️ NaN gradient detected in {name}")

            if torch.isinf(grad).any():
                diagnostics['has_inf'] = True
                print(f"⚠️ Inf gradient detected in {name}")

            # Compute norm
            param_norm = grad.norm(2).item()
            diagnostics['param_norms'][name] = param_norm
            total_norm += param_norm ** 2

            # Check for vanishing
            if param_norm < threshold_low:
                diagnostics['vanishing'].append((name, param_norm))

            # Check for exploding
            if param_norm > threshold_high:
                diagnostics['exploding'].append((name, param_norm))

    total_norm = total_norm ** 0.5
    diagnostics['total_norm'] = total_norm

    # Print diagnosis
    print(f"\n{'='*60}")
    print("GRADIENT DIAGNOSTICS")
    print(f"{'='*60}")
    print(f"Total gradient norm: {total_norm:.4f}")

    if diagnostics['has_nan']:
        print("\n🚨 CRITICAL: NaN gradients detected!")
        print("   Possible causes:")
        print("   - Division by zero in loss or model")
        print("   - Log of zero or negative number")
        print("   - Numerical overflow")
        print("   - Already-NaN parameters or inputs")

    if diagnostics['has_inf']:
        print("\n🚨 CRITICAL: Inf gradients detected!")
        print("   Possible causes:")
        print("   - Numerical overflow (very large values)")
        print("   - Division by very small number")
        print("   - Exponential of very large number")

    if total_norm > threshold_high:
        print(f"\n⚠️ EXPLODING GRADIENTS: Total norm {total_norm:.2f} > {threshold_high}")
        print("   Solution: Add gradient clipping")
        print(f"   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm={threshold_high/10:.1f})")
        if diagnostics['exploding']:
            print("\n   Top exploding layers:")
            for name, norm in sorted(diagnostics['exploding'], key=lambda x: x[1], reverse=True)[:5]:
                print(f"   - {name}: {norm:.2f}")

    if total_norm < threshold_low:
        print(f"\n⚠️ VANISHING GRADIENTS: Total norm {total_norm:.2e} < {threshold_low}")
        print("   Possible solutions:")
        print("   - Use ReLU/GELU instead of sigmoid/tanh")
        print("   - Check weight initialization (use He/Xavier)")
        print("   - Add batch normalization")
        print("   - Add residual connections")
        print("   - Increase learning rate (after other fixes)")
        if diagnostics['vanishing']:
            print("\n   Layers with vanishing gradients:")
            for name, norm in sorted(diagnostics['vanishing'], key=lambda x: x[1])[:5]:
                print(f"   - {name}: {norm:.2e}")

    print(f"{'='*60}\n")

    return diagnostics

# Usage:
loss.backward()
diagnostics = diagnose_gradients(model)

if diagnostics['has_nan'] or diagnostics['has_inf']:
    # Stop training, fix the issue
    raise RuntimeError("NaN or Inf gradients detected!")
```

### Gradient Monitoring and Logging

**Log gradient statistics during training:**

```python
import wandb  # or tensorboard

def log_gradient_stats(model, logger, step):
    """Log gradient statistics for monitoring"""

    total_norm = 0.0
    layer_norms = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            # Gradient norm
            grad_norm = param.grad.data.norm(2).item()
            layer_norms[name] = grad_norm
            total_norm += grad_norm ** 2

            # Parameter norm (for ratio calculation)
            param_norm = param.data.norm(2).item()

            # Log individual layer stats
            logger.log({
                f"gradients/{name}/norm": grad_norm,
                f"gradients/{name}/mean": param.grad.data.mean().item(),
                f"gradients/{name}/std": param.grad.data.std().item(),
                f"gradients/{name}/max": param.grad.data.abs().max().item(),
            }, step=step)

            # Log the ratio of gradient norm to parameter norm
            # A healthy ratio is typically 0.001 to 0.01
            if param_norm > 0:
                ratio = grad_norm / param_norm
                logger.log({f"gradients/{name}/ratio": ratio}, step=step)

    total_norm = total_norm ** 0.5

    # Log total gradient norm
    logger.log({"gradients/total_norm": total_norm}, step=step)

    return total_norm

# Usage in training loop:
for step, batch in enumerate(train_loader):
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()

    # Log gradients (before clipping to see true magnitudes)
    grad_norm = log_gradient_stats(model, wandb, step)

    # Clip and update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```

**What to watch in gradient logs:**

```python
# Healthy training:
# - Total gradient norm: Relatively stable (0.1 to 10)
# - Layer norms: Similar across layers (no huge disparities)
# - Ratios: ~0.001 (gradients much smaller than parameters)
# - No sudden spikes or drops to zero

# Warning signs:
# - Total norm suddenly spikes (>100) → exploding gradients
# - Total norm gradually decreases to near-zero → vanishing gradients
# - Early layers have much smaller norms than late layers → vanishing
# - Ratios > 0.1 → updates are too large relative to parameters
# - Sudden drop to zero → dead neurons or broken gradient flow
```

## Vanishing Gradients

### Recognizing Vanishing Gradients

**Symptoms:**
1. Training loss decreases very slowly or not at all
2. Validation metrics don't improve
3. Gradient norms are extremely small (<1e-6)
4. Early layers have much smaller gradients than later layers
5. Training seems "stuck" after initialization

**How to confirm:**

```python
# Check gradient magnitudes by layer depth
loss.backward()

print("Layer-wise gradient norms:")
for name, param in model.named_parameters():
    if param.grad is not None:
        norm = param.grad.norm(2).item()
        print(f"{name}: {norm:.2e}")

# Example output showing vanishing gradients:
# layer1.weight:  5.67e-09  ← Early layer (closest to the input)
# layer5.weight:  2.34e-07
# layer10.weight: 8.91e-06
# layer15.weight: 3.45e-04
# layer20.weight: 1.23e-02  ← Late layer (closest to the output)

# Pattern: Gradients shrink exponentially as they flow backward
# toward the early layers. This is vanishing gradients!
```

### Causes of Vanishing Gradients

**1. Too many layers (very deep networks):**
- Each layer multiplies the gradient by weights during backprop
- If the multiplication factor is <1, gradients shrink exponentially
- More layers = more multiplication = smaller gradients

**2. Saturating activation functions** (see the sketch after this list):
- Sigmoid: `σ'(x) ≈ 0` when `|x|` is large (saturates)
- Tanh: `tanh'(x) ≈ 0` when `|x|` is large
- The gradient flows through: `grad = grad * activation'(x)`
- If `activation'(x) ≈ 0`, the gradient vanishes

**3. Poor weight initialization:**
- Weights too small → activations too small → gradients too small
- Weights initialized uniformly → improper scaling across layers

**4. Learning rate too low:**
- Not a root cause, but can make the problem worse
- Tiny gradients * tiny LR = no learning
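
A minimal sketch of cause 2 with arbitrary widths and depth: build the same deep MLP with a sigmoid vs a ReLU activation and compare the gradient norm that reaches the first (input-side) layer. The sigmoid stack's first-layer gradient typically comes out orders of magnitude smaller.

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(activation_cls, depth: int = 30) -> float:
    """Gradient norm reaching the first layer of a deep MLP (toy sizes)."""
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(64, 64), activation_cls()]
    model = nn.Sequential(*layers, nn.Linear(64, 1))

    model(torch.randn(32, 64)).mean().backward()
    return model[0].weight.grad.norm().item()

print(f"sigmoid: {first_layer_grad_norm(nn.Sigmoid):.2e}")  # tiny - saturation kills the signal
print(f"relu:    {first_layer_grad_norm(nn.ReLU):.2e}")     # much larger at the same depth
```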

### Solutions for Vanishing Gradients

**Solution 1: Use Better Activation Functions**

```python
# AVOID: Sigmoid and Tanh (saturate easily)
class BadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 100),
            nn.Sigmoid(),  # ❌ Saturates, kills gradients
            nn.Linear(100, 100),
            nn.Sigmoid(),  # ❌ Even worse with depth
            nn.Linear(100, 10)
        )

# PREFER: ReLU, GELU, or other non-saturating activations
class GoodModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),   # ✅ Doesn't saturate (for x>0)
            nn.Linear(100, 100),
            nn.GELU(),   # ✅ Smooth, non-saturating
            nn.Linear(100, 10)
        )

# Why it helps:
# ReLU: grad = 1 for x>0, doesn't shrink the gradient
# GELU: Smooth version of ReLU, widely used in Transformers
# Both avoid the saturation that kills gradients
```

**Solution 2: Proper Weight Initialization**

```python
# Use He initialization for ReLU networks
def init_weights(m):
    if isinstance(m, nn.Linear):
        # He initialization: optimal for ReLU
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = GoodModel()
model.apply(init_weights)

# Use Xavier initialization for Tanh/Sigmoid (if you must use them)
def init_weights_xavier(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Why it helps:
# Proper initialization ensures gradients have an appropriate scale
# He init accounts for ReLU's effect on variance
# Xavier init maintains variance across layers for symmetric activations
```

**Solution 3: Batch Normalization**

```python
# Add BatchNorm between layers
class ModelWithBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),  # ✅ Normalizes activations
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),  # ✅ Helps gradient flow
            nn.ReLU(),
            nn.Linear(100, 10)
        )

# For CNNs:
class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),   # ✅ After conv, before activation
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

# Why it helps:
# BatchNorm normalizes activations to have mean=0, std=1
# Prevents activations from getting too small or too large
# Helps maintain gradient scale through the network
# Widely used in modern architectures
```

**Solution 4: Residual Connections (Skip Connections)**

```python
# Add skip connections (ResNet-style)
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        # Skip connection: add input to output
        return x + self.layers(x)  # ✅ Gradient flows through the skip connection

class ResidualNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(
            ResidualBlock(100),
            ResidualBlock(100),
            ResidualBlock(100),
            # Can stack many blocks without vanishing gradients!
        )
        self.output = nn.Linear(100, 10)

    def forward(self, x):
        x = self.blocks(x)
        return self.output(x)

# Why it helps:
# Gradients can flow directly through the skip connections
# Backprop path: the gradient flows through the addition (no multiplication)
# Allows training very deep networks (ResNet-152, ResNet-200)
# Essential for modern deep architectures
```

**Solution 5: Layer Normalization (for Transformers)**

```python
# Transformers use Layer Normalization
class TransformerBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attention = MultiHeadAttention(d_model)
        self.norm1 = nn.LayerNorm(d_model)  # ✅ Layer norm
        self.ffn = FeedForward(d_model)
        self.norm2 = nn.LayerNorm(d_model)  # ✅ Layer norm

    def forward(self, x):
        # Pre-norm architecture (modern standard)
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

# Why Layer Norm:
# BatchNorm doesn't work well for sequences (different lengths)
# LayerNorm normalizes across features (not the batch)
# Standard in Transformers (BERT, GPT, etc.)
```

**Solution 6: Gradient Checkpointing (if memory-constrained)**

```python
# Trade computation for memory (from the pytorch-engineering pack)
from torch.utils.checkpoint import checkpoint

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            ResidualBlock(100) for _ in range(50)  # Very deep!
        ])

    def forward(self, x):
        for block in self.blocks:
            # Use checkpointing to save memory
            x = checkpoint(block, x, use_reentrant=False)
        return x

# Why it helps:
# Allows training deeper networks in the same memory
# Doesn't directly solve vanishing gradients
# But removes the memory constraint that prevents using deeper models
# Compatible with all other solutions (BN, residuals, etc.)
```

### Systematic Approach to Vanishing Gradients

**Step 1: Confirm the diagnosis**

```python
# Check gradient magnitudes
loss.backward()
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.norm(2).item():.2e}")

# Look for: Early layers << Late layers
```

**Step 2: Apply architectural fixes (priority order)**
1. Switch to ReLU/GELU activations (highest impact)
2. Add proper weight initialization (He/Xavier)
3. Add BatchNorm or LayerNorm
4. Add residual connections if very deep (>20 layers)

**Step 3: Verify the improvement**

```python
# After the fixes, check gradients again
# You should see more uniform gradient magnitudes across layers
```

**Step 4: Adjust the learning rate if needed**

```python
# Only AFTER the architectural fixes
# May need a slightly higher LR with better gradient flow
```

**IMPORTANT NOTE: When Small Gradients Are Actually OK**

Don't blindly "fix" small gradients if training is working well:

```python
# Scenario: Gradients are small (1e-7) but training is progressing
# Epoch 1: Loss 2.34, Grad norm: 3.45e-07
# Epoch 2: Loss 1.89, Grad norm: 2.91e-07  ← Loss decreasing!
# Epoch 3: Loss 1.52, Grad norm: 2.34e-07  ← Still improving!

# This is OK! Don't fix what isn't broken.
```

**Healthy small gradients:**
- Training progressing (loss decreasing, metrics improving) ✓
- Gradients relatively uniform across layers
- Gradients stable over time

**Unhealthy vanishing gradients:**
- Training stuck (loss not decreasing)
- Early layers << late layers (1000x difference)
- Gradients decreasing over time

**Key insight:** Absolute gradient magnitude depends on the parameter scale, loss scale, and learning rate. What matters is: **Is the model learning?**

```python
import numpy as np

# Better diagnostic: Check relative gradients across layers
grad_norms = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norms[name] = param.grad.norm(2).item()

# Check the ratio: Are early layers much smaller than late layers?
# NOTE: adjust the name filters to your model; a bare 'layer1' substring
# would also match 'layer10'-'layer19', so match full prefixes.
early_layers = [v for k, v in grad_norms.items() if k.startswith(('layer0.', 'layer1.'))]
late_layers = [v for k, v in grad_norms.items() if k.startswith(('layer19.', 'layer20.'))]

if early_layers and late_layers:
    ratio = np.mean(late_layers) / np.mean(early_layers)
    if ratio > 1000:
        print(f"⚠️ Vanishing gradients: late/early ratio = {ratio:.0f}")
    else:
        print(f"✅ Gradient flow OK: late/early ratio = {ratio:.0f}")
```

**Decision rule:**
- Training working well + gradients stable → No action needed
- Training stuck + early << late → Apply architectural fixes
- Training working + improving over time → Monitor but don't change

## Exploding Gradients

### Recognizing Exploding Gradients

**Symptoms:**
1. Loss suddenly becomes NaN or Inf during training
2. Loss oscillates wildly (jumps up and down)
3. Parameters become very large or NaN
4. Gradient norms >100 or >1000
5. Training is stable initially then suddenly diverges

**How to confirm:**

```python
# Check gradient magnitudes
loss.backward()

total_norm = 0.0
for param in model.parameters():
    if param.grad is not None:
        param_norm = param.grad.data.norm(2)
        total_norm += param_norm.item() ** 2

total_norm = total_norm ** 0.5
print(f"Total gradient norm: {total_norm:.4f}")

# If total_norm > 100: Exploding gradients!
# If any parameter grad norm > 100: Exploding gradients!
```

### Causes of Exploding Gradients

**1. Learning rate too high:**
- Large gradients * large LR = huge parameter updates
- Updates overshoot the optimal values
- Can cause oscillation or divergence

**2. Poor weight initialization:**
- Weights too large → activations too large → gradients too large
- Random initialization without proper scaling

**3. Lack of gradient clipping:**
- Occasional gradient spikes are normal in some architectures
- Without clipping, one spike can break training

**4. Numerical instability in the model:**
- Division by very small numbers
- Exponential of large numbers
- Log of numbers close to zero

**5. Architecture-specific issues** (see the sketch after this list):
- Transformers: The attention mechanism can amplify gradients
- RNNs: Backprop through time multiplies gradients across timesteps
- Very deep networks: Many layers multiply gradients
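
A minimal sketch of the multiplication effect (arbitrary sizes, deliberately oversized initialization): in a deep linear stack each layer multiplies the backward signal, so the first-layer gradient norm grows roughly exponentially with depth.

```python
import torch
import torch.nn as nn

def grad_norm_at_input(depth: int, weight_scale: float = 2.0) -> float:
    """Deep linear stack with oversized weights; returns the first-layer grad norm."""
    torch.manual_seed(0)
    layers = [nn.Linear(64, 64) for _ in range(depth)]
    for layer in layers:
        nn.init.normal_(layer.weight, std=weight_scale / 64 ** 0.5)  # too large on purpose
    model = nn.Sequential(*layers, nn.Linear(64, 1))

    model(torch.randn(16, 64)).mean().backward()
    return model[0].weight.grad.norm().item()

for depth in (5, 10, 20):
    print(f"depth={depth:2d}  first-layer grad norm ≈ {grad_norm_at_input(depth):.2e}")
# With std ≈ 2/sqrt(fan_in), each layer roughly doubles the backward signal,
# so the norm grows exponentially with depth - the exploding-gradient pattern.
```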

### Solutions for Exploding Gradients

**Solution 1: Gradient Clipping (Primary Solution)**

```python
# Add gradient clipping - THE solution for exploding gradients
optimizer.zero_grad()
loss.backward()

# Clip gradients to a maximum norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()

# Why this works:
# Limits gradient magnitude while preserving direction
# Prevents huge parameter updates
# Standard practice for Transformers, RNNs, RL
```

**Solution 2: Lower Learning Rate**

```python
# If gradients are consistently large, try a lower LR
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # Was 1e-3

# But NOTE:
# Gradient clipping is usually BETTER than just lowering the LR
# Clipping handles occasional spikes without limiting normal gradients
# Lowering the LR slows down ALL learning, even when gradients are normal

# Best approach: Use both
# - Gradient clipping for stability (handles spikes)
# - A reasonable learning rate for speed (not too high or too low)
```

**Solution 3: Better Weight Initialization**

```python
# Use proper initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        # He initialization for ReLU
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(m.bias)

model.apply(init_weights)

# Why it helps:
# Proper initialization keeps the weights at an appropriate scale
# Prevents the initial gradients from being too large
# Particularly important for very deep networks
```

**Solution 4: Batch Normalization**

```python
# Add BatchNorm to stabilize training
class StableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),  # ✅ Stabilizes gradients
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 10)
        )

# Why it helps:
# Normalizes activations, which stabilizes gradients
# Reduces internal covariate shift
# Makes training more robust to hyperparameter choices
```

**Solution 5: Check for Numerical Issues**

```python
# AVOID: Operations that can cause numerical instability
eps = 1e-8

# ❌ Division by small numbers
loss = 1.0 / predictions  # If predictions ≈ 0, the loss explodes

# ✅ Add epsilon for stability
loss = 1.0 / (predictions + eps)

# ❌ Log of values close to zero
loss = -torch.log(predictions)  # If predictions ≈ 0, the loss → +∞

# ✅ Add epsilon
loss = -torch.log(predictions + eps)

# ❌ Exp of large values
loss = torch.exp(logits)  # If logits are large, exp explodes

# ✅ Use the log-sum-exp trick or built-in stable functions
loss = F.cross_entropy(logits, targets)  # Handles the numerics internally

# ❌ Custom loss without stability
def unstable_loss(pred, target):
    return ((pred - target) / pred).pow(2).mean()  # Division can explode

# ✅ Add stability
def stable_loss(pred, target):
    return ((pred - target) / (pred.abs() + eps)).pow(2).mean()
```

**Solution 6: Use Residual Connections**

```python
# Residual connections help stability
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return x + self.layers(x)  # ✅ Skip connection provides a stable path

# Why it helps:
# Gradients can flow through the skip connections
# Prevents gradients from exploding through many layers
# Used in all modern deep architectures (ResNet, Transformer, etc.)
```

### Systematic Approach to Exploding Gradients

**Step 1: Confirm the diagnosis**

```python
# Monitor gradient norms
loss.backward()
total_norm = sum(p.grad.data.norm(2).item() ** 2
                 for p in model.parameters() if p.grad is not None) ** 0.5
print(f"Gradient norm: {total_norm:.4f}")

# If the norm is >100 or training diverges: Exploding gradients
```

**Step 2: Apply fixes (priority order)**

1. **Add gradient clipping** (highest priority, most effective)

   ```python
   clip_grad_norm_(model.parameters(), max_norm=1.0)
   ```

2. **Check the learning rate** (if still unstable after clipping)

   ```python
   optimizer = Adam(model.parameters(), lr=1e-4)  # Try lower
   ```

3. **Verify initialization** (if problems appear from the start of training)

   ```python
   model.apply(init_weights)  # Use He/Xavier init
   ```

4. **Check for numerical issues** (if NaN appears)

   ```python
   # Add epsilon to divisions, logs, etc.
   ```

**Step 3: Verify the improvement**

```python
# Monitor gradient norms during training
# They should stay in a reasonable range (0.1 to 10)
# No sudden spikes to >100
# No NaN or Inf
```

### When Clipping Doesn't Fix NaN

**If you've added gradient clipping but still get NaN loss:**

The problem may be in your loss function, not the gradients. Diagnose systematically:

```python
# Step 1: Check if the loss is NaN BEFORE backward()
optimizer.zero_grad()
output = model(batch)
loss = custom_loss(output, target)

# Check loss BEFORE backward
if torch.isnan(loss):
    print("❌ Loss is NaN BEFORE backward - the problem is in the loss function!")
    print(f"   Output range: {output.min():.4f} to {output.max():.4f}")
    print(f"   Target range: {target.min():.4f} to {target.max():.4f}")
    # Don't proceed with backward - fix the loss function first
else:
    print("✅ Loss is valid before backward")
    loss.backward()

    # Check gradients after backward
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"❌ NaN gradient in {name} - gradient issue")
```

**Common loss function numerical issues:**

```python
# ❌ UNSTABLE: Log of zero or negative
def bad_loss(pred, target):
    return -torch.log(pred).mean()  # NaN if pred <= 0!

# ✅ STABLE: Add epsilon
def good_loss(pred, target):
    eps = 1e-8
    return -torch.log(pred + eps).mean()


# ❌ UNSTABLE: Division by zero or a very small number
def bad_loss2(pred, target):
    return (target / pred).mean()  # Explodes if pred ≈ 0

# ✅ STABLE: Add epsilon
def good_loss2(pred, target):
    eps = 1e-8
    return (target / (pred + eps)).mean()


# ❌ UNSTABLE: Sqrt near zero (the gradient of sqrt(x) blows up as x → 0)
def bad_loss3(pred, target):
    diff = pred - target
    return torch.sqrt(diff ** 2).mean()  # Gradient is infinite at diff == 0

# ✅ STABLE: Add epsilon inside the sqrt (or just use .abs() / a smooth L1)
def good_loss3(pred, target):
    eps = 1e-8
    diff = pred - target
    return torch.sqrt(diff ** 2 + eps).mean()


# ❌ UNSTABLE: Exp of large values
def bad_loss4(logits):
    return torch.exp(logits).sum()  # Explodes if logits > 100

# ✅ STABLE: Use built-in stable functions
def good_loss4(logits, targets):
    return F.cross_entropy(logits, targets)  # Handles log-sum-exp internally
```

**Diagnostic order when NaN appears:**

1. **Check the loss before backward()**: `if torch.isnan(loss): ...`
   - If NaN here → fix the loss function (add epsilon, clamp, use stable functions)
   - If not NaN → gradient issue

2. **Check gradients after backward()**:
   - If gradients are NaN → is the clipping placement correct? Unscaling (AMP)?
   - If gradients are OK → parameters NaN from a previous update?

3. **Check parameters**:

   ```python
   for name, param in model.named_parameters():
       if torch.isnan(param).any():
           print(f"❌ NaN in parameter {name} - a previous update caused the NaN")
   ```

**Summary decision tree:**

```
Loss becomes NaN
│
├─ Check: Is the loss NaN before backward()?
│   │
│   ├─ YES → Problem in the loss function
│   │         • Add epsilon to divisions
│   │         • Add epsilon to logs
│   │         • Clamp/stabilize inputs to sqrt
│   │         • Use stable built-in functions
│   │
│   └─ NO → Problem in backward/gradients
│             • Check that gradient clipping is correctly placed
│             • Check unscaling if using AMP
│             • Check for numerical instability in the model
│             • Verify proper initialization
```

## Mixed Precision Training Integration

### Gradient Clipping with AMP

**The critical interaction:**

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
model = TransformerModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in train_loader:
    optimizer.zero_grad()

    # Forward pass with autocast (mixed precision)
    with autocast():
        output = model(batch['input'])
        loss = criterion(output, batch['target'])

    # Backward pass (gradients are SCALED)
    scaler.scale(loss).backward()

    # CRITICAL: Unscale before clipping!
    scaler.unscale_(optimizer)

    # Now clip (on unscaled gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step (the scaler handles it)
    scaler.step(optimizer)
    scaler.update()
```

**Why unscale before clipping?**

```python
# Understanding the problem:

# GradScaler multiplies gradients by a large factor (e.g., 2^16 = 65536)
# This prevents underflow in fp16 gradients
# But clipping should happen on TRUE gradient values, not scaled values

# WITHOUT unscaling:
scaler.scale(loss).backward()  # Gradients are now 65536x larger
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ❌ Clips at 1.0
# But the gradients are scaled! The effective clip threshold is 65536, not 1.0
# Clipping does nothing - gradients are rarely >65536

# WITH unscaling:
scaler.scale(loss).backward()  # Gradients are 65536x larger
scaler.unscale_(optimizer)     # Gradients back to true values
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ✅ Clips at the true 1.0
# Clipping works correctly on the true gradient magnitudes
```

**The flow:**

```
1. Forward pass with autocast() → activations in fp16
2. Compute loss (in fp16 or fp32 depending on the operation)
3. scaler.scale(loss).backward() → multiply gradients by the scale factor
4. scaler.unscale_(optimizer) → divide gradients by the scale factor (back to true values)
5. clip_grad_norm_() → clip the true gradient values
6. scaler.step(optimizer) → check for inf/NaN, update parameters if safe
7. scaler.update() → adjust the scale factor for the next iteration
```

**Complete AMP + Clipping + Accumulation:**

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
accumulation_steps = 4

optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    # Forward pass with autocast
    with autocast():
        output = model(batch['input'])
        loss = criterion(output, batch['target'])

    # Scale loss for accumulation
    scaled_loss = loss / accumulation_steps

    # Backward pass (scaled)
    scaler.scale(scaled_loss).backward()

    # Update every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        # Unscale before clipping
        scaler.unscale_(optimizer)

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step with the scaler
        scaler.step(optimizer)
        scaler.update()

        optimizer.zero_grad()
```

### Common AMP + Gradient Mistakes

**WRONG: Not unscaling before clipping**

```python
# ❌ WRONG - Clipping scaled gradients
scaler.scale(loss).backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ❌ On scaled gradients!
scaler.step(optimizer)
scaler.update()

# Result: Clipping doesn't work, training may diverge
```

**WRONG: Unscaling multiple times**

```python
# ❌ WRONG - Unscaling twice
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # Unscale once
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.unscale_(optimizer)  # ❌ Unscale again!
scaler.step(optimizer)

# Result: Recent PyTorch versions raise an error when unscale_() is called
# twice for the same optimizer between updates; conceptually, a second
# division would leave the gradients far too small
```

**WRONG: Calling step() directly instead of scaler.step()**

```python
# ❌ WRONG - Bypassing the scaler
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()  # ❌ Should use scaler.step()!
scaler.update()

# Result: The scaler can't skip updates when inf/NaN is detected
# Training may diverge from inf/NaN gradients
```

## Advanced Topics

### Per-Layer Gradient Clipping

**When global clipping isn't enough:**

```python
def clip_grad_norm_per_layer(model, max_norm):
    """
    Clip each layer's gradients independently

    Use when some layers have much larger gradients than others
    and global clipping is too aggressive or not aggressive enough
    """
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)):
            # Get the parameters for this layer
            params = [p for p in module.parameters() if p.grad is not None]

            if params:
                # Clip this layer's gradients
                layer_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)

                # Log if clipping was applied
                if layer_norm > max_norm:
                    print(f"Clipped {name}: {layer_norm:.4f} -> {max_norm}")

# Usage:
loss.backward()
clip_grad_norm_per_layer(model, max_norm=1.0)
optimizer.step()

# When to use:
# - Attention layers have much larger gradients than FFN layers
# - Some task heads have huge gradients while the backbone is normal
# - Global clipping clips too much for some layers, too little for others

# Trade-off:
# ✅ More fine-grained control
# ❌ More complex, harder to tune
# ❌ Less common in the literature (harder to compare)
```

### Gradient Noise and Stability

**Adding noise to gradients (advanced technique):**

```python
def add_gradient_noise(model, noise_scale=1e-3):
    """
    Add Gaussian noise to gradients

    Can help with:
    - Escaping sharp minima (better generalization)
    - Privacy (differential privacy)
    - Exploration in RL
    """
    for param in model.parameters():
        if param.grad is not None:
            noise = torch.randn_like(param.grad) * noise_scale
            param.grad.add_(noise)

# Usage:
loss.backward()
add_gradient_noise(model, noise_scale=1e-3)
clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip after adding noise
optimizer.step()

# When to use:
# - Research setting (exploring new techniques)
# - Differential privacy requirements
# - NOT recommended for standard training (adds complexity)
```
|
||
|
||
### Gradient Checkpointing Interaction

**Gradient checkpointing compatibility:**

```python
from torch.utils.checkpoint import checkpoint

# Gradient checkpointing (from pytorch-engineering pack)
# Trades computation for memory by recomputing activations during backward

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=512) for _ in range(24)
        ])

    def forward(self, x):
        for block in self.blocks:
            # Checkpoint each block
            x = checkpoint(block, x, use_reentrant=False)
        return x

# Training with checkpointing + clipping + accumulation:
model = CheckpointedModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4

optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    output = model(batch)  # Uses checkpointing internally
    loss = criterion(output, target)
    (loss / accumulation_steps).backward()  # Recomputes activations

    if (i + 1) % accumulation_steps == 0:
        # Clipping works normally (no special handling needed)
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

# Compatibility:
# ✅ Gradient clipping: Works normally after backward()
# ✅ Gradient accumulation: No special handling needed
# ✅ Mixed precision: Combine with AMP as usual
# ✅ All gradient management techniques: Fully compatible

# Performance note:
# Checkpointing increases backward pass time by ~30-50%
# But enables training much larger models or batch sizes
# Trade computation for memory
```

### Distributed Training Considerations

**Gradient clipping in DDP (DistributedDataParallel):**

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Setup DDP
model = TransformerModel().cuda()
model = DDP(model, device_ids=[local_rank])

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in train_loader:
    optimizer.zero_grad()

    output = model(batch)
    loss = criterion(output, target)
    loss.backward()

    # Gradient clipping in DDP
    # IMPORTANT: Clip AFTER backward() (gradients are already synchronized)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

# How DDP works:
# 1. Forward pass: Each GPU computes independently
# 2. Backward pass: Gradients computed on each GPU
# 3. Gradient synchronization: DDP averages gradients across GPUs (automatic)
# 4. Clipping: Happens AFTER synchronization (on averaged gradients)
# 5. Optimizer step: Each GPU updates identically (same gradients)

# Key points:
# ✅ Clip after backward() as usual - DDP handles synchronization automatically
# ✅ All GPUs see same averaged gradients, so clipping is consistent
# ❌ DON'T manually synchronize gradients (DDP does this)
# ❌ DON'T clip before backward() (gradients don't exist yet)
```

**Gradient accumulation with DDP (Optimized):**

**IMPORTANT:** DDP synchronizes gradients on every backward() by default. With accumulation, this is wasteful - we only need to sync ONCE per update. Use `no_sync()` to optimize.

```python
from contextlib import nullcontext
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Setup DDP
model = TransformerModel().cuda()
model = DDP(model, device_ids=[local_rank])

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4

optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    # Disable gradient synchronization for accumulation steps
    # Only sync on the last accumulation step
    is_accumulation_step = (i + 1) % accumulation_steps != 0

    # Context manager: no_sync() when accumulating, normal when updating
    with model.no_sync() if is_accumulation_step else nullcontext():
        output = model(batch)
        loss = criterion(output, target)
        (loss / accumulation_steps).backward()

    # Update on last accumulation step (gradients are now synchronized)
    if (i + 1) % accumulation_steps == 0:
        # Gradients are synchronized across all GPUs
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```

**How this works:**

```
WITHOUT no_sync() (inefficient):
  Step 1: backward() → sync gradients across GPUs (communication!)
  Step 2: backward() → sync gradients across GPUs (communication!)
  Step 3: backward() → sync gradients across GPUs (communication!)
  Step 4: backward() → sync gradients across GPUs (communication!)
  optimizer.step() → update parameters
  Total: 4 synchronizations per update

WITH no_sync() (optimized):
  Step 1: backward() with no_sync() → no communication
  Step 2: backward() with no_sync() → no communication
  Step 3: backward() with no_sync() → no communication
  Step 4: backward() without no_sync() → sync accumulated gradients (communication!)
  optimizer.step() → update parameters
  Total: 1 synchronization per update

Communication cost: 4x fewer gradient synchronizations per update
```

**Why no_sync() is necessary:**
- DDP normally synchronizes gradients on every backward() (default behavior)
- With accumulation, we only want to sync ONCE (on the last step)
- no_sync() temporarily disables DDP's all-reduce operation
- On the last step (without no_sync()), DDP performs normal synchronization
- Result: Accumulated gradients are synchronized once and correctly averaged

**Complete DDP + Accumulation + Clipping + AMP:**

```python
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext

model = DDP(model, device_ids=[local_rank])
scaler = GradScaler()
accumulation_steps = 4

optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    is_accumulation_step = (i + 1) % accumulation_steps != 0

    # Disable sync on accumulation steps
    with model.no_sync() if is_accumulation_step else nullcontext():
        # Mixed precision forward
        with autocast():
            output = model(batch)
            loss = criterion(output, target)

        # Scale and backward
        scaled_loss = loss / accumulation_steps
        scaler.scale(scaled_loss).backward()

    # Update after accumulation
    if (i + 1) % accumulation_steps == 0:
        # Gradients now synchronized across GPUs
        scaler.unscale_(optimizer)  # Unscale for clipping
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

# This combines ALL techniques correctly:
# ✅ DDP distributed training
# ✅ Gradient accumulation (with loss scaling)
# ✅ Mixed precision (with proper unscaling)
# ✅ Gradient clipping (on correct values)
# ✅ Optimized communication (no_sync())
```

**Performance comparison:**

```python
# Measure with and without no_sync() (illustrative numbers from one setup;
# actual timings depend on model size, GPU count, and interconnect)

# WITHOUT no_sync(): ~40 seconds per epoch (excessive communication)
# WITH no_sync():    ~12 seconds per epoch (optimized communication)
# Speedup: ~3.3x faster with accumulation_steps=4

# The more GPUs you have, the more important no_sync() becomes:
# 2 GPUs: ~2x speedup
# 4 GPUs: ~3x speedup
# 8 GPUs: ~4x speedup
```

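To verify the effect on your own setup, here is a minimal timing sketch. It assumes you wrap the DDP + accumulation loop above in your own `run_epoch()` function (a hypothetical name); `torch.cuda.synchronize()` is called so queued GPU work is included in the measurement:

```python
import time
import torch

def time_epoch(run_epoch):
    """Time one training epoch, including queued GPU work."""
    torch.cuda.synchronize()   # wait for pending kernels before starting the clock
    start = time.perf_counter()
    run_epoch()                # e.g. the DDP + accumulation loop above
    torch.cuda.synchronize()   # wait for the last backward/step to finish
    return time.perf_counter() - start

# Usage (run_epoch_* are hypothetical wrappers around the two loop variants):
# t_sync = time_epoch(run_epoch_without_no_sync)
# t_no_sync = time_epoch(run_epoch_with_no_sync)
# print(f"every-step sync: {t_sync:.1f}s, no_sync: {t_no_sync:.1f}s")
```
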
**Common mistake:**

```python
# ❌ WRONG - Synchronizing on every step (slow!)
model = DDP(model)
accumulation_steps = 4

for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()  # Syncs every time!

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Result: Correct results, but 4x more gradient synchronization than necessary
```

## Common Gradient Pitfalls

### Pitfall 1: Not Clipping When Needed

**Symptom:** Loss becomes NaN after a few epochs, or spikes erratically

**WRONG:**
```python
# User sees NaN loss and thinks: "Must be the learning rate"
optimizer = Adam(model.parameters(), lr=1e-5)  # ❌ Lower LR to "fix" it

# Result: Training is slow and may still diverge
# Root cause (exploding gradients) not addressed
```

**RIGHT:**
```python
# Recognize exploding gradients, add clipping
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Result: Training is stable, no NaN
# This is THE solution for exploding gradients
```

### Pitfall 2: Wrong Gradient Accumulation Scaling

**Symptom:** Gradient accumulation gives worse results than small batch

**WRONG:**
```python
# ❌ Not scaling loss
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    loss.backward()  # ❌ Gradients are 4x too large!

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

**RIGHT:**
```python
# ✅ Scale loss by accumulation_steps
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()  # ✅ Correct scaling

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

### Pitfall 3: Clipping After optimizer.step()

**Symptom:** Clipping doesn't help, training still unstable

**WRONG:**
```python
# ❌ Clipping after step (useless!)
loss.backward()
optimizer.step()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ❌ Too late!
```

**RIGHT:**
```python
# ✅ Clipping after backward, before step
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ✅ Correct timing
optimizer.step()
```

### Pitfall 4: Not Unscaling Before Clipping (AMP)

**Symptom:** Mixed precision training diverges, regular training works

**WRONG:**
```python
# ❌ Clipping scaled gradients
scaler.scale(loss).backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # ❌ Wrong scale!
scaler.step(optimizer)
scaler.update()
```

**RIGHT:**
```python
# ✅ Unscale before clipping
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # ✅ Unscale first!
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```

### Pitfall 5: Forgetting to zero_grad() After Accumulation

**Symptom:** Loss decreases then increases, training unstable

**WRONG:**
```python
# ❌ Missing zero_grad() after update
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        # ❌ Missing optimizer.zero_grad()!
```

**RIGHT:**
```python
# ✅ Zero gradients after update
accumulation_steps = 4
for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()  # ✅ Clear gradients for next accumulation
```

### Pitfall 6: Using Value Clipping Instead of Norm Clipping

**Symptom:** Training works but slower convergence than expected

**SUBOPTIMAL:**
```python
# Value clipping changes gradient direction
clip_grad_value_(model.parameters(), clip_value=0.5)  # Can distort gradients
```

**BETTER:**
```python
# Norm clipping preserves direction
clip_grad_norm_(model.parameters(), max_norm=1.0)  # Preferred method
```

### Pitfall 7: Applying Clipping to All Models

**Symptom:** Unnecessarily slow training, limiting gradient flow

**WRONG:**
```python
# ❌ Clipping when not needed (ResNet on ImageNet)
model = ResNet50()
optimizer = SGD(model.parameters(), lr=0.1)

for batch in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch), target)
    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=1.0)  # ❌ Not needed!
    optimizer.step()

# Result: Limits gradient flow, may slow convergence
```

**RIGHT:**
```python
# ✅ Only clip when needed (training is unstable)
model = ResNet50()
optimizer = SGD(model.parameters(), lr=0.1)

for batch in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch), target)
    loss.backward()
    # No clipping - ResNets are naturally stable
    optimizer.step()

# Only add clipping if you observe:
# - Loss becomes NaN
# - Loss spikes
# - Training instability
```

### Pitfall 8: Not Monitoring Gradients

**Symptom:** Training fails, no visibility into why

**WRONG:**
```python
# ❌ No gradient monitoring
for batch in train_loader:
    loss = train_step(batch)
    # Training fails, no idea why
```

**RIGHT:**
```python
# ✅ Monitor gradient norms
for step, batch in enumerate(train_loader):
    optimizer.zero_grad()
    loss = criterion(model(batch), target)
    loss.backward()

    # Monitor gradients
    if step % 100 == 0:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm ** 0.5
        print(f"Step {step}, Loss: {loss.item():.4f}, Grad norm: {total_norm:.4f}")

    clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

# Now you can see:
# - When gradients explode (norm suddenly large)
# - When gradients vanish (norm goes to zero)
# - How clipping affects training
```

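A lighter-weight variant of the same monitoring: `clip_grad_norm_` returns the total norm it measured (before clipping), so you can log that value directly instead of recomputing it:

```python
# clip_grad_norm_ returns the pre-clipping total norm of the parameters' gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

if step % 100 == 0:
    print(f"Step {step}, Grad norm (pre-clip): {float(grad_norm):.4f}")
```
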
### Pitfall 9: Wrong DDP Gradient Synchronization

**Symptom:** DDP with gradient accumulation is much slower than expected

**WRONG:**
```python
# ❌ DDP synchronizes on every backward (wasteful with accumulation)
model = DDP(model)
accumulation_steps = 4

for i, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    (loss / accumulation_steps).backward()  # ❌ Syncs every time!

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

**RIGHT:**
```python
# ✅ Disable sync except on last accumulation step
model = DDP(model)
accumulation_steps = 4

for i, batch in enumerate(train_loader):
    is_accumulation_step = (i + 1) % accumulation_steps != 0

    with model.no_sync() if is_accumulation_step else nullcontext():
        loss = criterion(model(batch), target)
        (loss / accumulation_steps).backward()

    if (i + 1) % accumulation_steps == 0:
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```

### Pitfall 10: Clipping Too Aggressively

**Symptom:** Training converges very slowly, gradient norm always at max_norm

**WRONG:**
```python
# ❌ max_norm too low, clipping every iteration
clip_grad_norm_(model.parameters(), max_norm=0.01)  # Way too aggressive!

# Result: All gradients clipped, learning very slow
```

**RIGHT:**
```python
# ✅ Monitor and tune max_norm appropriately
# Check typical gradient norms without clipping (helper sketch below)
total_norm = compute_grad_norm(model)
print(f"Gradient norm: {total_norm:.4f}")

# Set max_norm to clip outliers, not normal gradients
# If typical norms are 0.5-2.0, set max_norm=5.0
clip_grad_norm_(model.parameters(), max_norm=5.0)  # Clips outliers only
```

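`compute_grad_norm` above is not a built-in PyTorch function. A minimal sketch of such a helper (same math as the monitoring loop in Pitfall 8), to be called after `backward()` and before `step()`:

```python
import torch

def compute_grad_norm(model, norm_type: float = 2.0) -> float:
    """Return the global norm of all parameter gradients (without modifying them)."""
    norms = [
        p.grad.detach().norm(norm_type)
        for p in model.parameters()
        if p.grad is not None
    ]
    if not norms:
        return 0.0
    return torch.norm(torch.stack(norms), norm_type).item()
```
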
## Rationalization Prevention Table

| When Agent Wants To Say | STOP - Say This Instead |
|-------------------------|-------------------------|
| "Just lower the learning rate" | "This is likely exploding gradients. Add gradient clipping: `clip_grad_norm_(model.parameters(), max_norm=1.0)` BEFORE optimizer.step(). Then adjust LR if still needed." |
| "Try a smaller model to save memory" | "Use gradient accumulation to train with larger effective batch size: Scale loss by `accumulation_steps` and update every N batches. This is standard practice." |
| "Gradient accumulation is complicated" | "It's actually simple: `(loss / accumulation_steps).backward()` to accumulate, `optimizer.step()` every N batches. MUST scale loss - this is critical." |
| "Mixed precision doesn't work with clipping" | "AMP + clipping work together perfectly. You MUST unscale before clipping: `scaler.unscale_(optimizer)` then `clip_grad_norm_()`. This is documented and standard." |
| "Your gradients are too small, just increase LR" | "This is vanishing gradients. Architectural fixes are needed: Use ReLU/GELU activations, proper initialization (He/Xavier), BatchNorm, and residual connections. Increasing LR alone won't fix it." |
| "Clipping is a hack, don't use it" | "Clipping is standard practice in Transformers, RNNs, and RL. Every major paper (BERT, GPT, etc.) uses gradient clipping. It's essential for training stability, not a hack." |
| "The paper didn't use clipping, so you shouldn't" | "Papers don't always document all techniques. Clipping may have been used but not mentioned. If you observe instability (NaN, spikes), add clipping regardless of what the paper says." |
| "Try different optimizer, maybe SGD works better" | "Switching optimizer doesn't fix exploding gradients. Add gradient clipping first, then compare optimizers. Clipping works with any optimizer." |
| "Gradient issues are mysterious and hard to debug" | "Gradient issues are systematic: Check gradient norms. >100 = exploding (clip). <1e-6 = vanishing (fix activations/init). NaN = numerical instability (check loss/model). Clear diagnosis → clear solution." |
| "You can clip anytime in the training loop" | "Clipping MUST happen after backward(), before step(). Timing is critical: backward() creates gradients, clip() modifies them, step() consumes them. Wrong order = useless clipping." |
| "Scale gradients instead of loss for accumulation" | "Scale LOSS, not gradients: `(loss / accumulation_steps).backward()`. Scaling gradients manually is error-prone and inefficient. Loss scaling is the standard, correct way." |
| "Batch norm is optional for deep networks" | "BatchNorm is critical for very deep networks (>20 layers). It normalizes activations and stabilizes gradients. Essential for training stability. Use unless you have specific reason not to." |
| "Residual connections are just a fancy trick" | "Residual connections are fundamental to training deep networks (>50 layers). They provide direct gradient path, preventing vanishing gradients. ResNet, Transformer - all use residuals." |
| "Just clip more aggressively (max_norm=0.01)" | "Too-aggressive clipping limits all gradients, slowing learning. Monitor typical gradient norms. Set max_norm to clip outliers (>100) without affecting normal gradients (1-10)." |
| "DDP handles everything automatically" | "DDP synchronizes gradients on backward(). For accumulation, use `model.no_sync()` on intermediate steps to avoid unnecessary synchronization. Only sync on final accumulation step." |
| "Your model is too complex, that's why training fails" | "Model complexity alone doesn't cause training failure. Gradient issues do. Diagnose gradients first. Most complex models (GPT-3, etc.) train successfully with proper gradient management." |
| "Checkpointing and clipping don't work together" | "They're fully compatible. Checkpoint affects forward/backward computation. Clipping affects gradients after backward(). No interaction - use both together freely." |
| "You need expensive GPUs for large batches" | "Use gradient accumulation for larger effective batches on any GPU. Accumulate over N steps = N× batch size, same memory. Standard technique for training large models on consumer hardware." |
| "Loss → NaN means your data has NaN" | "Usually exploding gradients, not data. Check gradient norms. If >100, add clipping. Data NaN would cause issues immediately, not after several epochs." |
| "Transformers just don't train stably" | "Transformers train extremely well with proper gradient management. BERT, GPT, T5 - all trained successfully. Use gradient clipping (max_norm=1.0), proper LR, and you'll have stable training." |
| "Clipping is expensive, will slow training significantly" | "Clipping overhead is <1%, not 5-10%. It's computing gradient norms (one pass) then scaling. Much cheaper than backward pass. 1% cost to prevent catastrophic training failure is excellent trade-off." |
| "I added clipping but still get NaN, it doesn't work" | "Check if loss is NaN BEFORE backward(). If yes, problem is in loss function (add epsilon to divisions/logs), not gradients. If no, check clipping placement and AMP unscaling." |
| "Accumulation scaling depends on batch content" | "NO. Accumulation scaling is ALWAYS `accumulation_steps` (constant). Sample weighting is separate concern (handled in loss function). Don't confuse these two independent concepts." |
| "Paper doesn't mention clipping, so I shouldn't use it" | "Papers don't document all implementation details. Clipping may have been used but not mentioned. Check official code if available. If your training is unstable, add clipping - stable training is prerequisite for valid comparison." |
| "Different sources give conflicting advice on clipping" | "Context matters. Transformers/RNNs usually need clipping. CNNs usually don't. Decide based on YOUR architecture and stability. Monitor gradient norms. If you see spikes >100 or NaN, add clipping. Empiricism over dogma." |
| "I use PyTorch Lightning, so I need to manually add clipping" | "Lightning has built-in clipping: `Trainer(gradient_clip_val=1.0, gradient_clip_algorithm='norm')`. No manual code needed. Check your framework docs - most have built-in gradient management features (sketch below)." |
| "My model is complex so gradients will always be problematic" | "Model complexity doesn't determine gradient behavior. GPT-3 (175B parameters) trains successfully. Proper gradient management (clipping, architecture, initialization) enables training ANY size model. Complexity is not the issue." |
| "Small gradients mean training is broken" | "Only if training is stuck. If loss is decreasing and metrics improving, small absolute gradient values are OK. What matters: relative gradients across layers and whether learning is happening. Don't fix what isn't broken." |

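For reference on the Lightning row above: with PyTorch Lightning the clipping lives in the `Trainer` rather than in your loop. A minimal sketch, assuming `pytorch_lightning` is installed and `MyLightningModule` / `train_dataloader` are your own (hypothetical) objects:

```python
import pytorch_lightning as pl

# Norm-based clipping handled by the Trainer - no manual clip_grad_norm_ call needed
trainer = pl.Trainer(
    max_epochs=10,
    gradient_clip_val=1.0,           # same role as max_norm above
    gradient_clip_algorithm="norm",  # or "value" for value clipping
)
# trainer.fit(MyLightningModule(), train_dataloader)  # hypothetical module/loader
```
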
## Red Flags Checklist

When you see these patterns, IMMEDIATELY suggest gradient management solutions:

### Critical Red Flags (Fix Immediately)

🚨 **Loss becomes NaN during training**
- Cause: Exploding gradients or numerical instability
- Solution: Add gradient clipping + check for numerical issues in loss/model

🚨 **User implements gradient accumulation without scaling loss**
```python
# ❌ RED FLAG
loss.backward()  # Should be: (loss / accumulation_steps).backward()
```
- Impact: Gradients are accumulation_steps times too large
- Solution: Scale loss by accumulation_steps

🚨 **User clips gradients after optimizer.step()**
```python
# ❌ RED FLAG
optimizer.step()
clip_grad_norm_(...)  # Too late!
```
- Impact: Clipping does nothing (gradients already consumed)
- Solution: Move clipping between backward() and step()

🚨 **User uses AMP + clipping without unscaling**
```python
# ❌ RED FLAG
scaler.scale(loss).backward()
clip_grad_norm_(...)  # Should unscale first!
```
- Impact: Clipping wrong magnitude (on scaled gradients)
- Solution: Add scaler.unscale_(optimizer) before clipping

### Warning Signs (Suggest Improvements)

⚠️ **Training transformers/RNNs without gradient clipping**
- Likely to hit exploding gradients eventually
- Suggest preemptive clipping (max_norm=1.0)

⚠️ **Very deep network (>20 layers) with sigmoid/tanh activations**
- Vanishing gradients likely
- Suggest ReLU/GELU + BatchNorm + residual connections

⚠️ **User says "want larger batch but OOM"**
- Perfect use case for gradient accumulation
- Explain technique and correct implementation

⚠️ **Gradient norms consistently >10 or <1e-6**
- Exploding or vanishing gradients
- Diagnose and suggest appropriate solution

⚠️ **User lowers learning rate to fix NaN loss**
- Treating symptom, not cause
- Likely exploding gradients - suggest clipping

⚠️ **DDP training with gradient accumulation, no no_sync()**
- Inefficient (synchronizing unnecessarily)
- Suggest no_sync() on accumulation steps

⚠️ **User asks "is gradient clipping necessary?"**
- Depends on architecture and stability
- Provide decision criteria (Transformers: yes, CNNs: maybe not)

⚠️ **Custom loss function with divisions or logs**
- Potential numerical instability
- Check for epsilon additions and proper handling (see the sketch after this list)

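What "epsilon additions" means in practice, as a minimal sketch (the loss itself is a made-up example; `1e-8` is a typical choice, not a universal constant):

```python
import torch

def safe_ratio_log_loss(pred, target, eps: float = 1e-8):
    """Example of numerically guarded divisions and logs inside a custom loss."""
    # ❌ Unstable: 0/0 gives NaN, log(0) gives -inf
    # loss = torch.log(pred / target).mean()

    # ✅ Guarded: add eps to the denominator, clamp before the log
    ratio = pred / (target + eps)
    loss = torch.log(ratio.clamp_min(eps)).mean()
    return loss
```
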
### Optimization Opportunities (Mention If Relevant)

💡 **User monitors loss but not gradients**
- Suggest logging gradient norms for better visibility

💡 **User training large model on single GPU with small batch**
- Suggest gradient accumulation for better results

💡 **Gradient clipping activates every iteration**
- max_norm might be too low
- Suggest monitoring and tuning threshold

💡 **Using value clipping instead of norm clipping**
- Suggest norm clipping (preserves direction)

## Summary

**Gradient management is essential for reliable training:**

1. **Gradient Clipping**
   - PRIMARY solution for exploding gradients (NaN, spikes)
   - Use norm-based clipping: `clip_grad_norm_(model.parameters(), max_norm=1.0)`
   - Place after backward(), before step()
   - Standard for Transformers, RNNs, RL

2. **Gradient Accumulation**
   - Train with larger effective batch size on same hardware
   - MUST scale loss: `(loss / accumulation_steps).backward()`
   - Update every N steps, zero_grad() after update
   - Standard technique in production training

3. **Gradient Diagnosis**
   - Don't guess - measure gradient norms
   - >100: Exploding (clip)
   - <1e-6: Vanishing (fix architecture)
   - NaN: Numerical issues (check loss/model)

4. **Vanishing Gradients**
   - Use ReLU/GELU activations (not sigmoid/tanh)
   - Proper initialization (He for ReLU, Xavier for tanh)
   - Add BatchNorm/LayerNorm
   - Add residual connections for deep networks (see the sketch after this list)

5. **Exploding Gradients**
   - Add gradient clipping (primary solution)
   - Check learning rate (secondary)
   - Verify initialization
   - Check for numerical issues

6. **Mixed Precision Integration**
   - MUST unscale before clipping: `scaler.unscale_(optimizer)`
   - Then clip on true gradient values
   - Standard pattern in AMP training

7. **Common Pitfalls**
   - Not scaling loss in accumulation (gradients too large)
   - Clipping after step() (useless)
   - Not unscaling before clipping in AMP
   - Forgetting zero_grad() after accumulation
   - Not monitoring gradients (no visibility)

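Illustrating item 4 above, a minimal sketch of the architectural fixes for vanishing gradients (GELU activation, LayerNorm, a residual connection, He/Kaiming initialization); the block layout is a generic example, not a prescribed architecture:

```python
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    """x + MLP(norm(x)): normalization plus a skip path keeps gradients flowing."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)
        self.act = nn.GELU()
        # He/Kaiming init suits ReLU-family activations
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')

    def forward(self, x):
        # Residual connection: gradient has a direct path around the MLP
        return x + self.fc2(self.act(self.fc1(self.norm(x))))
```
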
**This is NOT optional:**
- Gradient management determines training success or failure
- Every production training system handles gradients properly
- The difference between reliable training and mysterious failures

**Master these techniques and you'll have stable, efficient training.**