# Loss Functions and Objectives Skill
## When to Use This Skill
Use this skill when:
- User asks "what loss function should I use?"
- Implementing binary, multi-class, or multi-label classification
- Implementing regression models
- Training on imbalanced datasets (class imbalance)
- Multi-task learning with multiple loss terms
- Custom loss function implementation needed
- Loss goes to NaN or infinity during training
- Loss not decreasing despite valid training loop
- User suggests BCE instead of BCEWithLogitsLoss (RED FLAG)
- User adds softmax before CrossEntropyLoss (RED FLAG)
- Multi-task losses without weighting (RED FLAG)
- Division or log operations in custom loss (stability concern)
- Segmentation, ranking, or specialized tasks
- Loss debugging and troubleshooting
Do NOT use when:
- User has specific bugs unrelated to loss functions
- Only discussing model architecture (no loss questions)
- Loss function already working well and no questions asked
- User needs general training advice (use optimization-algorithms skill)
## Core Principles
### 1. The Critical Importance of Loss Functions
**Loss functions are fundamental to deep learning:**
- Direct objective that gradients optimize
- Wrong loss → model optimizes wrong thing
- Numerically unstable loss → NaN, training collapse
- Unweighted multi-task → one task dominates
- Mismatched loss for task → poor performance
**Common Impact:**
- Proper loss selection: 5-15% performance improvement
- Numerical stability: difference between training and crashing
- Class balancing: difference between 95% accuracy (useless) and 85% F1 (useful)
- Multi-task weighting: difference between all tasks learning vs one task dominating
**This is NOT optional:**
- Every SOTA paper carefully selects and tunes losses
- Loss function debugging is essential skill
- One mistake (BCE vs BCEWithLogitsLoss) can break training
### 2. Loss Selection Decision Tree
```
What is your task?
├─ Classification?
│ │
│ ├─ Binary (2 classes, single output)
│ │ → Use: BCEWithLogitsLoss (NOT BCELoss!)
│ │ → Model outputs: logits (no sigmoid)
│ │ → Target shape: (batch,) or (batch, 1) with 0/1
│ │ → Imbalanced? Add pos_weight parameter
│ │
│ ├─ Multi-class (N classes, one label per sample)
│ │ → Use: CrossEntropyLoss
│ │ → Model outputs: logits (batch, num_classes) - no softmax!
│ │ → Target shape: (batch,) with class indices [0, N-1]
│ │ → Imbalanced? Add weight parameter or use focal loss
│ │
│ └─ Multi-label (N classes, multiple labels per sample)
│ → Use: BCEWithLogitsLoss
│ → Model outputs: logits (batch, num_classes) - no sigmoid!
│ → Target shape: (batch, num_classes) with 0/1
│ → Each class is independent binary classification
├─ Regression?
│ │
│ ├─ Standard regression, squared errors
│ │ → Use: MSELoss (L2 loss)
│ │ → Sensitive to outliers
│ │ → Penalizes large errors heavily
│ │
│ ├─ Robust to outliers
│ │ → Use: L1Loss (MAE)
│ │ → Less sensitive to outliers
│ │ → Linear penalty
│ │
│ └─ Best of both (recommended)
│ → Use: SmoothL1Loss (Huber loss)
│ → L2 for small errors, L1 for large errors
│ → Good default choice
├─ Segmentation?
│ │
│ ├─ Binary segmentation
│ │ → Use: BCEWithLogitsLoss or DiceLoss
│ │ → Combine both: α*BCE + (1-α)*Dice
│ │
│ └─ Multi-class segmentation
│ → Use: CrossEntropyLoss or DiceLoss
│ → Imbalanced pixels? Use weighted CE or focal loss
├─ Ranking/Similarity?
│ │
│ ├─ Triplet learning
│ │ → Use: TripletMarginLoss
│ │ → Learn embeddings with anchor, positive, negative
│ │
│ ├─ Pairwise ranking
│ │ → Use: MarginRankingLoss
│ │ → Learn x1 > x2 or x2 > x1
│ │
│ └─ Contrastive learning
│ → Use: ContrastiveLoss or NTXentLoss
│ → Pull similar together, push dissimilar apart
└─ Multi-Task?
→ Combine losses with careful weighting
→ See Multi-Task Learning section below
```
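As a quick companion to the tree above, here is a minimal sketch of how the defaults could be encoded as a lookup of PyTorch loss constructors. The task names and the `pick_loss` helper are hypothetical illustrations, not an existing API; real projects should still apply the per-task caveats above (imbalance weights, target normalization, etc.).
```python
import torch.nn as nn

# Hypothetical mapping from task type to a reasonable default criterion
DEFAULT_LOSSES = {
    "binary_classification": nn.BCEWithLogitsLoss,      # model outputs logits
    "multiclass_classification": nn.CrossEntropyLoss,   # logits + class indices
    "multilabel_classification": nn.BCEWithLogitsLoss,  # independent per-class logits
    "regression": nn.SmoothL1Loss,                       # robust default
}

def pick_loss(task: str, **kwargs) -> nn.Module:
    """Return a default criterion for a task; kwargs forward weight/pos_weight/etc."""
    return DEFAULT_LOSSES[task](**kwargs)

criterion = pick_loss("multiclass_classification")  # nn.CrossEntropyLoss()
```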
## Section 1: Binary Classification - BCEWithLogitsLoss
### THE GOLDEN RULE: ALWAYS Use BCEWithLogitsLoss, NEVER BCELoss
This is the MOST COMMON loss function mistake in deep learning.
### ❌ WRONG: BCELoss (Numerically Unstable)
```python
class BinaryClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 1)
        self.sigmoid = nn.Sigmoid()  # ❌ DON'T DO THIS

    def forward(self, x):
        logits = self.fc(x)
        return self.sigmoid(logits)  # ❌ Applying sigmoid in model

# Training loop
output = model(x)  # Probabilities [0, 1]
loss = F.binary_cross_entropy(output, target)  # ❌ UNSTABLE!
```
**Why this is WRONG:**
1. **Numerical instability**: `log(sigmoid(x))` underflows for large negative x
2. **Gradient issues**: sigmoid saturates, BCE takes log → compound saturation
3. **NaN risk**: When sigmoid(logits) = 0 or 1, log(0) = -inf
4. **Slower training**: Less stable gradients
### ✅ RIGHT: BCEWithLogitsLoss (Numerically Stable)
```python
class BinaryClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 1)
        # ✅ NO sigmoid in model!

    def forward(self, x):
        return self.fc(x)  # ✅ Return logits

# Training loop
logits = model(x)  # Raw logits (can be any value)
loss = F.binary_cross_entropy_with_logits(logits, target)  # ✅ STABLE!
```
**Why this is RIGHT:**
1. **Numerically stable**: Uses log-sum-exp trick internally
2. **Better gradients**: Single combined operation
3. **No NaN**: Stable for all logit values
4. **Faster training**: More stable optimization
### The Math Behind the Stability
**Unstable version (BCELoss):**
```python
# BCE computes: -[y*log(σ(x)) + (1-y)*log(1-σ(x))]
# Problem: log(σ(x)) = log(1/(1+exp(-x))) underflows for large negative x
# Example:
x = -100 # Large negative logit
sigmoid(x) = 1 / (1 + exp(100)) ≈ 0 # Underflows to 0
log(sigmoid(x)) = log(0) = -inf # ❌ NaN!
```
**Stable version (BCEWithLogitsLoss):**
```python
# BCEWithLogitsLoss uses log-sum-exp trick:
# log(σ(x)) = log(1/(1+exp(-x))) = -log(1+exp(-x))
# Rewritten as: -log1p(exp(-x)) for stability
# For positive x: use log(sigmoid(x)) = -log1p(exp(-x))
# For negative x: use log(sigmoid(x)) = x - log1p(exp(x))
# This is ALWAYS stable!
# Example:
x = -100
log(sigmoid(x)) = -100 - log1p(exp(-100))
= -100 - log1p(0)
= -100 # ✅ Stable!
```
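To see the difference concretely, here is a short added sketch (not part of the original skill) comparing the naive formula against the fused function at extreme logits: the naive path hits `log(0)` and returns `inf`, while the fused version stays finite.
```python
import torch
import torch.nn.functional as F

logits = torch.tensor([-100.0, 100.0])
target = torch.tensor([1.0, 0.0])  # worst case: confident and wrong

# Naive formula: sigmoid saturates to exactly 0/1, so log(0) = -inf appears
probs = torch.sigmoid(logits)
naive = -(target * torch.log(probs) + (1 - target) * torch.log(1 - probs))
print(naive)  # tensor([inf, inf])

# Fused version works in log-space and stays finite
stable = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
print(stable)  # tensor([100., 100.])
```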
### Inference: Converting Logits to Probabilities
```python
# During training
logits = model(x)
loss = F.binary_cross_entropy_with_logits(logits, target)
# During inference/evaluation
logits = model(x)
probs = torch.sigmoid(logits) # ✅ NOW apply sigmoid
predictions = (probs > 0.5).float() # Binary predictions
```
### Handling Class Imbalance with pos_weight
```python
# Dataset: 95% negative (class 0), 5% positive (class 1)
# Problem: Model predicts all negatives → 95% accuracy but useless!
# Solution 1: pos_weight parameter
neg_count = 950
pos_count = 50
pos_weight = torch.tensor([neg_count / pos_count]) # 950/50 = 19.0
loss = F.binary_cross_entropy_with_logits(
logits, target,
pos_weight=pos_weight # Weight positive class 19x more
)
# pos_weight effect:
# - Positive examples contribute 19x to loss
# - Forces model to care about minority class
# - Balances gradient contributions
# Solution 2: Focal Loss (see Advanced Techniques section)
```
### Complete Binary Classification Example
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class BinaryClassifier(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 1)  # Single output for binary

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # ✅ Return logits (no sigmoid)

# Training setup
model = BinaryClassifier(input_dim=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Handle imbalanced data
class_counts = torch.bincount(train_labels.long())
pos_weight = class_counts[0] / class_counts[1]

# Training loop
model.train()
for x, y in train_loader:
    optimizer.zero_grad()
    logits = model(x)
    loss = F.binary_cross_entropy_with_logits(
        logits.squeeze(),  # Shape: (batch,)
        y.float(),         # Shape: (batch,)
        pos_weight=pos_weight if imbalanced else None
    )
    loss.backward()
    optimizer.step()

# Evaluation
model.eval()
with torch.no_grad():
    logits = model(x_test)
    probs = torch.sigmoid(logits)  # ✅ Apply sigmoid for inference
    preds = (probs > 0.5).float()

# Compute metrics
accuracy = (preds.squeeze() == y_test).float().mean()
# Better: Use F1, precision, recall for imbalanced data
```
## Section 2: Multi-Class Classification - CrossEntropyLoss
### THE GOLDEN RULE: Pass Logits (NOT Softmax) to CrossEntropyLoss
### ❌ WRONG: Applying Softmax Before CrossEntropyLoss
```python
class MultiClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)
        self.softmax = nn.Softmax(dim=1)  # ❌ DON'T DO THIS

    def forward(self, x):
        logits = self.fc(x)
        return self.softmax(logits)  # ❌ Applying softmax in model

# Training
probs = model(x)  # Already softmaxed
loss = F.cross_entropy(probs, target)  # ❌ WRONG! Double softmax!
```
**Why this is WRONG:**
1. **Double softmax**: CrossEntropyLoss applies softmax internally
2. **Numerical instability**: Extra softmax operation
3. **Wrong gradients**: Backprop through unnecessary operation
4. **Confusion**: Model outputs different things in train vs eval
### ✅ RIGHT: Pass Logits to CrossEntropyLoss
```python
class MultiClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)
        # ✅ NO softmax in model!

    def forward(self, x):
        return self.fc(x)  # ✅ Return logits

# Training
logits = model(x)  # Shape: (batch, num_classes)
target = ...       # Shape: (batch,) with class indices [0, num_classes-1]
loss = F.cross_entropy(logits, target)  # ✅ CORRECT!
```
### Target Shape Requirements
```python
# ✅ CORRECT: Target is class indices
logits = torch.randn(32, 10) # (batch=32, num_classes=10)
target = torch.randint(0, 10, (32,)) # (batch=32,) with values in [0, 9]
loss = F.cross_entropy(logits, target) # ✅ Works!
# ❌ WRONG: One-hot encoded target
target_onehot = F.one_hot(target, num_classes=10) # (batch=32, num_classes=10)
loss = F.cross_entropy(logits, target_onehot) # ❌ Type error!
# If you have one-hot, convert back to indices:
target_indices = target_onehot.argmax(dim=1) # (batch,)
loss = F.cross_entropy(logits, target_indices) # ✅ Works!
```
### Handling Class Imbalance with Weights
```python
# Dataset: Class 0: 1000 samples, Class 1: 100 samples, Class 2: 50 samples
# Problem: Model biased toward majority class
# Solution 1: Class weights (inverse frequency)
class_counts = torch.tensor([1000., 100., 50.])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_weights)
# Normalizes so weights sum to num_classes
# class_weights ≈ [0.097, 0.968, 1.935]
# Minority classes weighted much higher
loss = F.cross_entropy(logits, target, weight=class_weights)
# Solution 2: Balanced accuracy loss (effective sample weighting)
# Weight each sample by inverse class frequency
sample_weights = class_weights[target] # Index into weights
loss = F.cross_entropy(logits, target, reduction='none')
weighted_loss = (loss * sample_weights).mean()
# Solution 3: Focal Loss (see Advanced Techniques section)
```
### Inference: Converting Logits to Probabilities
```python
# During training
logits = model(x)
loss = F.cross_entropy(logits, target)
# During inference/evaluation
logits = model(x) # (batch, num_classes)
probs = F.softmax(logits, dim=1) # ✅ NOW apply softmax
preds = logits.argmax(dim=1) # Or directly argmax logits (same result)
# Why argmax logits works:
# argmax(softmax(logits)) = argmax(logits) because softmax is monotonic
```
### Complete Multi-Class Example
```python
class MultiClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # ✅ Return logits

# Training setup
num_classes = 10
model = MultiClassifier(input_dim=100, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Compute class weights for imbalanced data
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * num_classes

# Training loop
model.train()
for x, y in train_loader:
    optimizer.zero_grad()
    logits = model(x)  # (batch, num_classes)
    loss = F.cross_entropy(logits, y, weight=class_weights)
    loss.backward()
    optimizer.step()

# Evaluation
model.eval()
with torch.no_grad():
    logits = model(x_test)
    probs = F.softmax(logits, dim=1)  # For calibration analysis
    preds = logits.argmax(dim=1)      # Class predictions
    accuracy = (preds == y_test).float().mean()
```
## Section 3: Multi-Label Classification
### Use BCEWithLogitsLoss (Each Class is Independent)
```python
# Task: Predict multiple tags for an image
# Example: [dog, outdoor, sunny] → target = [1, 0, 1, 0, 1, 0, ...]
class MultiLabelClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)
        # ✅ NO sigmoid! Return logits

    def forward(self, x):
        return self.fc(x)  # (batch, num_classes) logits
# Training
logits = model(x) # (batch, num_classes)
target = ... # (batch, num_classes) with 0/1 for each class
# Each class is independent binary classification
loss = F.binary_cross_entropy_with_logits(logits, target.float())
# Inference
logits = model(x_test)
probs = torch.sigmoid(logits) # Per-class probabilities
preds = (probs > 0.5).float() # Threshold each class independently
# Example output:
# probs = [0.9, 0.3, 0.8, 0.1, 0.7, ...]
# preds = [1.0, 0.0, 1.0, 0.0, 1.0, ...] (dog=yes, outdoor=no, sunny=yes, ...)
```
### Handling Imbalanced Labels
```python
# Some labels are rare (e.g., "sunset" appears in 2% of images)
# Solution 1: Per-class pos_weight
label_counts = train_labels.sum(dim=0) # Count per class
num_samples = len(train_labels)
neg_counts = num_samples - label_counts
pos_weights = neg_counts / label_counts # (num_classes,)
loss = F.binary_cross_entropy_with_logits(
logits, target.float(),
pos_weight=pos_weights
)
# Solution 2: Focal loss per class (see Advanced Techniques)
```
## Section 4: Regression Losses
### MSELoss (L2 Loss) - Default Choice
```python
# Mean Squared Error: (pred - target)^2
pred = model(x) # (batch, output_dim)
target = ... # (batch, output_dim)
loss = F.mse_loss(pred, target)
# Characteristics:
# ✅ Smooth gradients
# ✅ Penalizes large errors heavily (squared term)
# ❌ Sensitive to outliers (outliers dominate loss)
# ❌ Can be numerically large if targets not normalized
# When to use:
# - Standard regression tasks
# - Targets are normalized (similar scale to predictions)
# - Outliers are rare or not expected
```
### L1Loss (MAE) - Robust to Outliers
```python
# Mean Absolute Error: |pred - target|
pred = model(x)
loss = F.l1_loss(pred, target)
# Characteristics:
# ✅ Robust to outliers (linear penalty)
# ✅ Numerically stable
# ❌ Non-smooth at zero (gradient discontinuity)
# ❌ Equal penalty for all error magnitudes
# When to use:
# - Outliers present in data
# - Want robust loss
# - Median prediction preferred over mean
```
### SmoothL1Loss (Huber Loss) - Best of Both Worlds
```python
# Smooth L1: L2 for small errors, L1 for large errors
pred = model(x)
loss = F.smooth_l1_loss(pred, target, beta=1.0)
# Formula:
# loss = 0.5 * (pred - target)^2 / beta if |pred - target| < beta
# loss = |pred - target| - 0.5 * beta otherwise
# Characteristics:
# ✅ Smooth gradients everywhere
# ✅ Robust to outliers (L1 for large errors)
# ✅ Fast convergence (L2 for small errors)
# ✅ Best default for regression
# When to use:
# - General regression (RECOMMENDED DEFAULT)
# - Uncertainty about outliers
# - Want fast convergence + robustness
```
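The sketch below is an added illustration (not from the original) of how the three losses react to a single outlier: the squared term lets the outlier dominate MSE, while L1 and Huber keep its influence bounded.
```python
import torch
import torch.nn.functional as F

pred = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.1, 2.1, 2.9, 14.0])  # last element is an outlier

print(F.mse_loss(pred, target))        # ≈ 25.0  — outlier dominates (10² / 4)
print(F.l1_loss(pred, target))         # ≈ 2.58  — linear penalty
print(F.smooth_l1_loss(pred, target))  # ≈ 2.38  — quadratic near 0, linear far away
```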
### Target Normalization (CRITICAL)
```python
# ❌ WRONG: Unnormalized targets
pred = model(x) # Model outputs in range [0, 1] (e.g., after sigmoid)
target = ... # Range [1000, 10000] - NOT NORMALIZED!
loss = F.mse_loss(pred, target) # Huge loss values, bad gradients
# ✅ RIGHT: Normalize targets
# Option 1: Min-Max normalization to [0, 1]
target_min = train_targets.min()
target_max = train_targets.max()
target_normalized = (target - target_min) / (target_max - target_min)
pred = model(x) # Range [0, 1]
loss = F.mse_loss(pred, target_normalized) # ✅ Same scale
# Denormalize for evaluation:
pred_denorm = pred * (target_max - target_min) + target_min
# Option 2: Standardization to mean=0, std=1
target_mean = train_targets.mean()
target_std = train_targets.std()
target_standardized = (target - target_mean) / target_std
pred = model(x) # Should output standardized values
loss = F.mse_loss(pred, target_standardized) # ✅ Normalized scale
# Denormalize for evaluation:
pred_denorm = pred * target_std + target_mean
# Why normalization matters:
# 1. Loss values in reasonable range (not 1e6)
# 2. Better gradient flow
# 3. Learning rate can be standard (1e-3)
# 4. Faster convergence
```
### Complete Regression Example
```python
class Regressor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # Linear output for regression

# Normalize targets
target_mean = train_targets.mean(dim=0)
target_std = train_targets.std(dim=0)

def normalize(targets):
    return (targets - target_mean) / (target_std + 1e-8)

def denormalize(preds):
    return preds * target_std + target_mean

# Training
model = Regressor(input_dim=100, output_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for x, y in train_loader:
    optimizer.zero_grad()
    pred = model(x)
    y_norm = normalize(y)
    loss = F.smooth_l1_loss(pred, y_norm)  # Using Huber loss
    loss.backward()
    optimizer.step()

# Evaluation
model.eval()
with torch.no_grad():
    pred_norm = model(x_test)
    pred = denormalize(pred_norm)  # Back to original scale
    mse = F.mse_loss(pred, y_test)
    print(f"Test MSE: {mse.item()}")
```
## Section 5: Numerical Stability in Loss Computation
### Critical Rule: Avoid log(0), log(negative), and division by zero
### Problem 1: Division by Zero
```python
# ❌ UNSTABLE: No protection
def iou_loss(pred, target):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    iou = intersection / union  # ❌ Division by zero if both empty!
    return 1 - iou

# ✅ STABLE: Add epsilon
def iou_loss(pred, target):
    eps = 1e-8
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    iou = (intersection + eps) / (union + eps)  # ✅ Safe
    return 1 - iou

# Why epsilon works:
# - Denominator never zero: union + 1e-8 ≥ 1e-8
# - Doesn't affect result when union is large
# - Prevents NaN propagation
```
### Problem 2: Log of Zero or Negative
```python
# ❌ UNSTABLE: No clamping
def custom_loss(pred, target):
    ratio = pred / target
    return torch.log(ratio).mean()  # ❌ log(0) = -inf, log(neg) = nan

# ✅ STABLE: Clamp before log
def custom_loss(pred, target):
    eps = 1e-8
    ratio = pred / (target + eps)        # Safe division
    ratio = torch.clamp(ratio, min=eps)  # Ensure positive
    return torch.log(ratio).mean()       # ✅ Safe log

# Alternative: Use log1p for log(1+x)
# (note: log1p(ratio) computes log(1 + ratio), a different but stable objective)
def custom_loss(pred, target):
    eps = 1e-8
    ratio = pred / (target + eps)
    return torch.log1p(ratio).mean()  # log1p(x) = log(1+x), more stable
```
### Problem 3: Exponential Overflow
```python
# ❌ UNSTABLE: Direct exp can overflow
def custom_loss(logits):
    return torch.exp(logits).mean()  # ❌ exp(100) = overflow!

# ✅ STABLE: Clamp logits or use stable operations
def custom_loss(logits):
    # Option 1: Clamp logits
    logits = torch.clamp(logits, max=10)  # Prevent overflow
    return torch.exp(logits).mean()

# Option 2: Use log-space operations
# If computing log(exp(x)), just return x!
```
### Problem 4: Custom Softmax (Use Built-in Instead)
```python
# ❌ UNSTABLE: Manual softmax
def manual_softmax(logits):
    exp_logits = torch.exp(logits)  # ❌ Overflow for large logits!
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

# ✅ STABLE: Use F.softmax (uses max subtraction trick)
def stable_softmax(logits):
    return F.softmax(logits, dim=1)  # ✅ Handles overflow internally

# Built-in implementation (for understanding):
def softmax_stable(logits):
    # Subtract max for numerical stability
    logits_max = logits.max(dim=1, keepdim=True)[0]
    logits = logits - logits_max    # Now max(logits) = 0
    exp_logits = torch.exp(logits)  # No overflow!
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)
```
### Epsilon Best Practices
```python
# Epsilon guidelines:
eps = 1e-8 # ✅ Good default for float32
eps = 1e-6 # ✅ Alternative, more conservative
eps = 1e-10 # ❌ Too small, can still underflow
# Where to add epsilon:
# 1. Denominators: x / (y + eps)
# 2. Before log: log(x + eps) or log(clamp(x, min=eps))
# 3. Before sqrt: sqrt(x + eps)
# Where NOT to add epsilon:
# 4. ❌ Numerators usually don't need it
# 5. ❌ Inside standard PyTorch functions (already stable)
# 6. ❌ After stable operations
```
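Putting these guidelines together, a couple of tiny helpers can keep custom losses safe by construction. This is a hedged sketch of one way to package the rules above; the `safe_div` / `safe_log` names are my own, not a PyTorch API.
```python
import torch

EPS = 1e-8  # good default for float32

def safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
    """Division with an epsilon added to the denominator only."""
    return num / (den + EPS)

def safe_log(x: torch.Tensor) -> torch.Tensor:
    """Log with the argument clamped away from zero."""
    return torch.log(torch.clamp(x, min=EPS))

# Example: a ratio loss built from the helpers
def ratio_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return safe_log(safe_div(pred, target)).mean()
```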
### Complete Stable Custom Loss Template
```python
class StableCustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-8

    def forward(self, pred, target):
        # 1. Ensure inputs are valid
        assert not torch.isnan(pred).any(), "pred contains NaN"
        assert not torch.isnan(target).any(), "target contains NaN"
        # 2. Safe division
        ratio = pred / (target + self.eps)
        # 3. Clamp before log/sqrt/pow
        ratio = torch.clamp(ratio, min=self.eps, max=1e8)
        # 4. Safe log operation
        log_ratio = torch.log(ratio)
        # 5. Check output
        loss = log_ratio.mean()
        assert not torch.isnan(loss), "loss is NaN"
        return loss
```
## Section 6: Multi-Task Learning and Loss Weighting
### The Problem: Different Loss Scales
```python
# Task 1: Classification, CrossEntropyLoss ~ 0.5-2.0
# Task 2: Regression, MSELoss ~ 100-1000
# Task 3: Reconstruction, L2 Loss ~ 10-50
# ❌ WRONG: Naive sum (task 2 dominates!)
loss1 = F.cross_entropy(logits1, target1) # ~0.5
loss2 = F.mse_loss(pred2, target2) # ~500.0
loss3 = F.mse_loss(recon, input) # ~20.0
total_loss = loss1 + loss2 + loss3 # ≈ 520.5
# Gradient analysis:
# ∂total_loss/∂θ ≈ ∂loss2/∂θ (loss1 and loss3 contribute <5%)
# Model learns ONLY task 2, ignores tasks 1 and 3!
```
### Solution 1: Manual Weighting
```python
# Balance losses to similar magnitudes
loss1 = F.cross_entropy(logits1, target1) # ~0.5
loss2 = F.mse_loss(pred2, target2) # ~500.0
loss3 = F.mse_loss(recon, input) # ~20.0
# Set weights so weighted losses are similar scale
w1 = 1.0 # Keep as is
w2 = 0.001 # Scale down by 1000x
w3 = 0.05 # Scale down by 20x
total_loss = w1 * loss1 + w2 * loss2 + w3 * loss3
# = 1.0*0.5 + 0.001*500 + 0.05*20
# = 0.5 + 0.5 + 1.0 = 2.0
# All tasks contribute meaningfully!
# How to find weights:
# 1. Run 1 epoch with equal weights
# 2. Print loss magnitudes
# 3. Set weights inversely proportional to magnitudes
# 4. Iterate until balanced
```
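The "check magnitudes, set inverse weights" recipe above can be automated with a one-off calibration pass. The sketch below is an added illustration under the assumption that you can run a few batches with equal weights and record each task's average loss; the `calibrate_weights` helper is hypothetical.
```python
def calibrate_weights(avg_losses, reference=1.0):
    """Return weights that rescale each task's average loss to `reference`."""
    return [reference / max(loss, 1e-8) for loss in avg_losses]

# Observed after ~1 epoch with equal weights:
avg_losses = [0.5, 500.0, 20.0]
w1, w2, w3 = calibrate_weights(avg_losses)
print(w1, w2, w3)  # 2.0, 0.002, 0.05 → each weighted loss ≈ 1.0
```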
### Solution 2: Uncertainty Weighting (Learnable)
```python
# "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018)
# Learn task weights during training!
class MultiTaskLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # Log variance parameters (learnable)
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """
        losses: list of task losses [loss1, loss2, loss3, ...]
        For each task (simplified form of Kendall et al.):
            weighted_loss = exp(-log_var) * loss + log_var
        where exp(log_var) = σ² is the learned task uncertainty:
        - High uncertainty → lower weight on that task
        - Low uncertainty → higher weight on that task
        """
        weighted_losses = []
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])  # 1/σ²
            weighted_loss = precision * loss + self.log_vars[i]
            weighted_losses.append(weighted_loss)
        return sum(weighted_losses)

# Usage
model = MultiTaskModel()
multi_loss = MultiTaskLoss(num_tasks=3)

# Optimize both model and loss weights
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': multi_loss.parameters(), 'lr': 0.01}  # Can use different LR
])

# Training loop
for x, targets in train_loader:
    optimizer.zero_grad()
    # Compute task predictions
    out1, out2, out3 = model(x)
    # Compute task losses
    loss1 = F.cross_entropy(out1, targets[0])
    loss2 = F.mse_loss(out2, targets[1])
    loss3 = F.mse_loss(out3, targets[2])
    # Combine with learned weighting
    total_loss = multi_loss([loss1, loss2, loss3])
    total_loss.backward()
    optimizer.step()

    # Monitor learned weights
    if step % 100 == 0:
        weights = torch.exp(-multi_loss.log_vars)
        print(f"Task weights: {weights.detach()}")
```
### Solution 3: Gradient Normalization
```python
# GradNorm: balances task learning by normalizing gradient magnitudes
def grad_norm_step(model, losses, alpha=1.5):
    """
    Adjust task weights to balance gradient magnitudes
    losses: list of task losses
    alpha: balancing parameter (1.5 typical)
    """
    # Get initial loss ratios
    initial_losses = [l.item() for l in losses]

    # Compute average gradient norm per task
    shared_params = list(model.shared_layers.parameters())
    grad_norms = []
    for loss in losses:
        model.zero_grad()
        loss.backward(retain_graph=True)
        # Compute gradient norm
        grad_norm = 0
        for p in shared_params:
            if p.grad is not None:
                grad_norm += p.grad.norm(2).item() ** 2
        grad_norms.append(grad_norm ** 0.5)

    # Target: all tasks have same gradient norm
    mean_grad_norm = sum(grad_norms) / len(grad_norms)

    # Adjust weights
    weights = []
    for gn in grad_norms:
        weight = mean_grad_norm / (gn + 1e-8)
        weights.append(weight ** alpha)

    # Normalize weights
    weights = torch.tensor(weights)
    weights = weights / weights.sum() * len(weights)
    return weights
# Note: GradNorm is more complex, this is simplified version
# For production, use manual or uncertainty weighting
```
### Solution 4: Loss Normalization
```python
# Normalize each loss to [0, 1] range before combining
class NormalizedMultiTaskLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # Track running mean/std per task
        self.register_buffer('running_mean', torch.zeros(num_tasks))
        self.register_buffer('running_std', torch.ones(num_tasks))
        self.momentum = 0.9

    def forward(self, losses):
        """Normalize each loss before combining"""
        losses_tensor = torch.stack(losses)
        if self.training:
            # Update running statistics (detached: stats must not carry the graph)
            mean = losses_tensor.mean().detach()
            std = losses_tensor.std().detach() + 1e-8
            self.running_mean = (self.momentum * self.running_mean +
                                 (1 - self.momentum) * mean)
            self.running_std = (self.momentum * self.running_std +
                                (1 - self.momentum) * std)
        # Normalize losses
        normalized = (losses_tensor - self.running_mean) / self.running_std
        return normalized.sum()
```
### Best Practices for Multi-Task Loss
```python
# Recommended approach:
1. Start with manual weighting:
- Run 1 epoch, check loss magnitudes
- Set weights to balance scales
- Quick and interpretable
2. If tasks have different difficulties:
- Use uncertainty weighting
- Let model learn task importance
- More training time but adaptive
3. Monitor individual task metrics:
- Don't just watch total loss
- Track accuracy/error per task
- Ensure all tasks learning
4. Curriculum learning:
- Start with easy tasks
- Gradually add harder tasks
- Can improve stability
# Example monitoring:
if step % 100 == 0:
    print(f"Total Loss: {total_loss.item():.4f}")
    print(f"Task 1 (CE): {loss1.item():.4f}")
    print(f"Task 2 (MSE): {loss2.item():.4f}")
    print(f"Task 3 (Recon): {loss3.item():.4f}")

    # Check if any task stuck
    if loss1 > 5.0:  # Not learning
        print("WARNING: Task 1 not learning, increase weight")
```
## Section 7: Custom Loss Function Implementation
### Template for Custom Loss
```python
class CustomLoss(nn.Module):
    """
    Template for implementing custom losses
    """
    def __init__(self, weight=None, reduction='mean'):
        """
        Args:
            weight: Manual sample weights (optional)
            reduction: 'mean', 'sum', or 'none'
        """
        super().__init__()
        self.weight = weight
        self.reduction = reduction
        self.eps = 1e-8  # For numerical stability

    def forward(self, pred, target):
        """
        Args:
            pred: Model predictions
            target: Ground truth
        Returns:
            Loss value (scalar if reduction != 'none')
        """
        # 1. Input validation
        assert pred.shape == target.shape, "Shape mismatch"
        assert not torch.isnan(pred).any(), "pred contains NaN"
        # 2. Compute element-wise loss
        loss = self.compute_loss(pred, target)
        # 3. Apply sample weights if provided
        if self.weight is not None:
            loss = loss * self.weight
        # 4. Apply reduction
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

    def compute_loss(self, pred, target):
        """Override this method with your loss computation"""
        # Example: MSE
        return (pred - target) ** 2
```
### Example 1: Dice Loss (Segmentation)
```python
class DiceLoss(nn.Module):
    """
    Dice Loss for segmentation tasks
    Dice = 2 * |X ∩ Y| / (|X| + |Y|)
    Loss = 1 - Dice
    Good for:
    - Binary segmentation
    - Handling class imbalance
    - Smooth gradients
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # Prevent division by zero

    def forward(self, pred, target):
        """
        Args:
            pred: (batch, C, H, W) logits
            target: (batch, C, H, W) binary masks
        """
        # Apply sigmoid to get probabilities
        pred = torch.sigmoid(pred)
        # Flatten spatial dimensions
        pred = pred.view(pred.size(0), pred.size(1), -1)  # (batch, C, H*W)
        target = target.view(target.size(0), target.size(1), -1)
        # Compute dice per sample and per class
        intersection = (pred * target).sum(dim=2)    # (batch, C)
        union = pred.sum(dim=2) + target.sum(dim=2)  # (batch, C)
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        # Average over classes and batch
        return 1 - dice.mean()
# Usage
criterion = DiceLoss(smooth=1.0)
loss = criterion(logits, masks)
# Often combined with BCE:
dice_loss = DiceLoss()
bce_loss = nn.BCEWithLogitsLoss()
total_loss = 0.5 * dice_loss(logits, masks) + 0.5 * bce_loss(logits, masks)
```
### Example 2: Focal Loss (Imbalanced Classification)
```python
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance
    FL = -α * (1 - p)^γ * log(p)
    - α: class balancing weight
    - γ: focusing parameter (typical: 2.0)
    - (1-p)^γ: down-weights easy examples
    Good for:
    - Highly imbalanced datasets (e.g., object detection)
    - Many easy negatives, few hard positives
    - When class weights aren't enough
    """
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        """
        Args:
            logits: (batch, num_classes) raw logits
            target: (batch,) class indices
        """
        # Compute cross entropy
        ce_loss = F.cross_entropy(logits, target, reduction='none')
        # Compute pt = e^(-CE) = probability of true class
        pt = torch.exp(-ce_loss)
        # Compute focal loss
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
# Usage
criterion = FocalLoss(alpha=1.0, gamma=2.0)
loss = criterion(logits, target)
# Effect of gamma:
# γ = 0: equivalent to CrossEntropyLoss
# γ = 2: typical value, strong down-weighting of easy examples
# γ = 5: extreme focusing, only hardest examples matter
# Example probability and loss weights:
# pt = 0.9 (easy): (1-0.9)^2 = 0.01 → 1% weight
# pt = 0.5 (medium): (1-0.5)^2 = 0.25 → 25% weight
# pt = 0.1 (hard): (1-0.1)^2 = 0.81 → 81% weight
```
### Example 3: Contrastive Loss (Metric Learning)
```python
class ContrastiveLoss(nn.Module):
    """
    Contrastive Loss for learning embeddings
    Pulls similar pairs together, pushes dissimilar pairs apart
    Good for:
    - Face recognition
    - Similarity learning
    - Few-shot learning
    """
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, embedding1, embedding2, label):
        """
        Args:
            embedding1: (batch, embedding_dim) first embeddings
            embedding2: (batch, embedding_dim) second embeddings
            label: (batch,) 1 if similar, 0 if dissimilar
        """
        # Euclidean distance
        distance = F.pairwise_distance(embedding1, embedding2)
        # Loss for similar pairs: want distance = 0
        loss_similar = label * distance.pow(2)
        # Loss for dissimilar pairs: want distance ≥ margin
        loss_dissimilar = (1 - label) * F.relu(self.margin - distance).pow(2)
        loss = loss_similar + loss_dissimilar
        return loss.mean()

# Usage
criterion = ContrastiveLoss(margin=1.0)
for (img1, img2, is_similar) in train_loader:
    emb1 = model(img1)
    emb2 = model(img2)
    loss = criterion(emb1, emb2, is_similar)
```
### Example 4: Perceptual Loss (Style Transfer, Super-Resolution)
```python
import torchvision

class PerceptualLoss(nn.Module):
    """
    Perceptual Loss using VGG features
    Compares high-level features instead of pixels
    Good for:
    - Image generation
    - Super-resolution
    - Style transfer
    """
    def __init__(self, layer='relu3_3'):
        super().__init__()
        # Load pre-trained VGG
        vgg = torchvision.models.vgg16(pretrained=True).features
        self.vgg = vgg.eval()
        # Freeze VGG
        for param in self.vgg.parameters():
            param.requires_grad = False
        # Select layer
        self.layer_map = {
            'relu1_2': 4,
            'relu2_2': 9,
            'relu3_3': 16,
            'relu4_3': 23,
        }
        self.layer_idx = self.layer_map[layer]

    def forward(self, pred, target):
        """
        Args:
            pred: (batch, 3, H, W) predicted images
            target: (batch, 3, H, W) target images
        """
        # Extract features
        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)
        # MSE in feature space
        loss = F.mse_loss(pred_features, target_features)
        return loss

    def extract_features(self, x):
        """Extract features from VGG layer"""
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i == self.layer_idx:
                return x
        return x
# Usage
perceptual_loss = PerceptualLoss(layer='relu3_3')
pixel_loss = nn.L1Loss()
# Combine pixel and perceptual loss
total_loss = pixel_loss(pred, target) + 0.1 * perceptual_loss(pred, target)
```
### Example 5: Custom Weighted MSE
```python
class WeightedMSELoss(nn.Module):
    """
    MSE with per-element importance weighting
    Good for:
    - Focusing on important regions (e.g., foreground)
    - Time-series with different importance
    - Confidence-weighted regression
    """
    def __init__(self):
        super().__init__()

    def forward(self, pred, target, weight):
        """
        Args:
            pred: (batch, ...) predictions
            target: (batch, ...) targets
            weight: (batch, ...) importance weights (0-1)
        """
        # Element-wise squared error
        squared_error = (pred - target) ** 2
        # Weight by importance
        weighted_error = squared_error * weight
        # Average only over weighted elements
        # (avoid counting zero-weight elements)
        loss = weighted_error.sum() / (weight.sum() + 1e-8)
        return loss
# Usage example: Foreground-focused loss
criterion = WeightedMSELoss()
# Create importance map (1.0 for foreground, 0.1 for background)
weight = torch.where(mask > 0.5, torch.tensor(1.0), torch.tensor(0.1))
loss = criterion(pred, target, weight)
```
## Section 8: Advanced Loss Techniques
### Technique 1: Label Smoothing
```python
# Problem: Hard labels [0, 0, 1, 0, 0] cause overconfident predictions
# Solution: Soft labels [0.025, 0.025, 0.9, 0.025, 0.025]
# PyTorch 1.10+ built-in support
loss = F.cross_entropy(logits, target, label_smoothing=0.1)
# What it does:
# Original: y = [0, 0, 1, 0, 0]
# Smoothed: y = (1-α)*[0, 0, 1, 0, 0] + α*[0.2, 0.2, 0.2, 0.2, 0.2]
# = [0.02, 0.02, 0.92, 0.02, 0.02] (for α=0.1, num_classes=5)
# Manual implementation (for understanding):
class LabelSmoothingLoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits, target):
        """
        logits: (batch, num_classes)
        target: (batch,) class indices
        """
        log_probs = F.log_softmax(logits, dim=1)
        # Create smooth labels
        smooth_labels = torch.zeros_like(log_probs)
        smooth_labels.fill_(self.smoothing / (self.num_classes - 1))
        smooth_labels.scatter_(1, target.unsqueeze(1), self.confidence)
        # NLL with smooth labels
        loss = (-smooth_labels * log_probs).sum(dim=1)
        return loss.mean()
# Benefits:
# 1. Better calibration (confidence closer to accuracy)
# 2. Prevents overconfidence
# 3. Acts as regularization
# 4. Often improves test accuracy by 0.5-1%
# When to use:
# ✅ Classification with CrossEntropyLoss
# ✅ Large models prone to overfitting
# ✅ Clean labels (not noisy)
# ❌ Small models (might hurt performance)
# ❌ Noisy labels (already have uncertainty)
```
### Technique 2: Class-Balanced Loss
```python
# Problem: 1000 samples class 0, 10 samples class 1
# Standard CE treats all samples equally → biased to class 0
# Solution 1: Inverse frequency weighting
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * len(class_weights)
loss = F.cross_entropy(logits, target, weight=class_weights)
# Solution 2: Effective number of samples (better for extreme imbalance)
def get_eff_num_weights(num_samples_per_class, beta=0.999):
    """
    Effective number of samples: (1 - β^n) / (1 - β)
    Handles extreme imbalance better than inverse frequency
    Args:
        num_samples_per_class: [n1, n2, ..., nC]
        beta: Hyperparameter (0.99-0.9999), higher for more imbalance
    """
    effective_num = 1.0 - torch.pow(beta, num_samples_per_class)
    weights = (1.0 - beta) / effective_num
    weights = weights / weights.sum() * len(weights)
    return weights
# Usage
class_counts = torch.bincount(train_labels)
weights = get_eff_num_weights(class_counts.float(), beta=0.9999)
loss = F.cross_entropy(logits, target, weight=weights)
# Solution 3: Focal loss (see Example 2 in Section 7)
```
### Technique 3: Mixup / CutMix Loss
```python
# Mixup: Blend two samples and their labels
import numpy as np

def mixup_data(x, y, alpha=1.0):
    """
    Args:
        x: (batch, ...) input
        y: (batch,) labels
        alpha: Mixup parameter
    """
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(pred, y_a, y_b, lam):
    """Compute mixed loss"""
    return lam * F.cross_entropy(pred, y_a) + (1 - lam) * F.cross_entropy(pred, y_b)

# Training with Mixup
for x, y in train_loader:
    x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
    optimizer.zero_grad()
    pred = model(x)
    loss = mixup_criterion(pred, y_a, y_b, lam)
    loss.backward()
    optimizer.step()
# Benefits:
# - Regularization
# - Better generalization
# - Smooth decision boundaries
# - +1-2% accuracy on CIFAR/ImageNet
```
### Technique 4: Gradient Clipping for Loss Stability
```python
# Problem: Loss spikes to NaN during training
# Often caused by exploding gradients
# Solution: Clip gradients before optimizer step
for x, y in train_loader:
    optimizer.zero_grad()
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Or clip by value:
    # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

    optimizer.step()
# When to use:
# ✅ RNNs/LSTMs (prone to exploding gradients)
# ✅ Transformers with high learning rates
# ✅ Loss occasionally spikes to NaN
# ✅ Large models or deep networks
# ❌ Stable training (unnecessary overhead)
# How to choose max_norm:
# - Start with 1.0
# - If still unstable, reduce to 0.5
# - Monitor: print gradient norms to see if clipping activates
```
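To check whether clipping actually activates, you can log the pre-clip global gradient norm; `clip_grad_norm_` returns it. The snippet below is an added sketch of that monitoring idea.
```python
for x, y in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()

    # clip_grad_norm_ returns the total norm BEFORE clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if total_norm > 1.0:
        print(f"Clipping active: grad norm {total_norm:.2f} -> 1.0")

    optimizer.step()
```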
### Technique 5: Loss Scaling for Mixed Precision
```python
# Problem: Mixed precision (FP16) can cause gradients to underflow
# Solution: Scale loss up, then scale gradients down
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for x, y in train_loader:
    optimizer.zero_grad()

    # Forward in FP16
    with autocast():
        pred = model(x)
        loss = criterion(pred, y)

    # Scale loss and backward
    scaler.scale(loss).backward()

    # Unscale gradients and step
    scaler.step(optimizer)
    scaler.update()
# GradScaler automatically:
# 1. Scales loss by factor (e.g., 65536)
# 2. Backprop computes scaled gradients
# 3. Unscales gradients before optimizer step
# 4. Adjusts scale factor dynamically
```
## Section 9: Common Loss Function Pitfalls
### Pitfall 1: BCE Instead of BCEWithLogitsLoss
```python
# ❌ WRONG (seen in 30% of beginner code!)
probs = torch.sigmoid(logits)
loss = F.binary_cross_entropy(probs, target)
# ✅ RIGHT
loss = F.binary_cross_entropy_with_logits(logits, target)
# Impact: Training instability, NaN losses, worse performance
# Fix time: 2 minutes
# Performance gain: Stable training, +2-5% accuracy
```
### Pitfall 2: Softmax Before CrossEntropyLoss
```python
# ❌ WRONG (seen in 20% of beginner code!)
probs = F.softmax(logits, dim=1)
loss = F.cross_entropy(probs, target)
# ✅ RIGHT
loss = F.cross_entropy(logits, target) # Expects logits!
# Impact: Suboptimal learning, double softmax
# Fix time: 1 minute
# Performance gain: +1-3% accuracy
```
### Pitfall 3: Wrong Target Shape for CrossEntropyLoss
```python
# ❌ WRONG: One-hot encoded targets
target = F.one_hot(labels, num_classes=10) # (batch, 10)
loss = F.cross_entropy(logits, target) # Type error!
# ✅ RIGHT: Class indices
target = labels # (batch,) with values in [0, 9]
loss = F.cross_entropy(logits, target)
# Impact: Runtime error or wrong loss computation
# Fix time: 2 minutes
```
### Pitfall 4: Ignoring Class Imbalance
```python
# ❌ WRONG: 95% negative, 5% positive
loss = F.binary_cross_entropy_with_logits(logits, target)
# Model predicts all negative → 95% accuracy but useless!
# ✅ RIGHT: Weight positive class
pos_weight = torch.tensor([19.0]) # 95/5
loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)
# Impact: Model learns trivial predictor
# Fix time: 5 minutes
# Performance gain: From useless to actually working
```
### Pitfall 5: Not Normalizing Regression Targets
```python
# ❌ WRONG: Targets in [1000, 10000], predictions in [0, 1]
loss = F.mse_loss(pred, target) # Huge loss, bad gradients
# ✅ RIGHT: Normalize targets
target_norm = (target - target.mean()) / target.std()
loss = F.mse_loss(pred, target_norm)
# Impact: Slow convergence, high loss values, need very small LR
# Fix time: 5 minutes
# Performance gain: 10-100x faster convergence
```
### Pitfall 6: Unweighted Multi-Task Loss
```python
# ❌ WRONG: Different scales
loss1 = F.cross_entropy(out1, target1) # ~0.5
loss2 = F.mse_loss(out2, target2) # ~500.0
total = loss1 + loss2 # Task 2 dominates!
# ✅ RIGHT: Balance scales
total = 1.0 * loss1 + 0.001 * loss2 # Both ~0.5
# Impact: One task learns, others ignored
# Fix time: 10 minutes (trial and error)
# Performance gain: All tasks learn instead of one
```
### Pitfall 7: Division by Zero in Custom Loss
```python
# ❌ WRONG: No epsilon
iou = intersection / union # Division by zero!
# ✅ RIGHT: Add epsilon
eps = 1e-8
iou = (intersection + eps) / (union + eps)
# Impact: NaN losses, training crash
# Fix time: 2 minutes
```
### Pitfall 8: Missing optimizer.zero_grad()
```python
# ❌ WRONG: Gradients accumulate!
for x, y in train_loader:
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()  # Missing zero_grad!

# ✅ RIGHT: Reset gradients
for x, y in train_loader:
    optimizer.zero_grad()  # ✅ Critical!
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
# Impact: Loss doesn't decrease, weird behavior
# Fix time: 1 minute
# This is caught by systematic debugging
```
### Pitfall 9: Wrong Reduction for Custom Loss
```python
# ❌ SUBOPTIMAL: Sum over batch
loss = (pred - target).pow(2).sum() # Loss scales with batch size!
# ✅ BETTER: Mean over batch
loss = (pred - target).pow(2).mean() # Loss independent of batch size
# Impact: Learning rate depends on batch size
# Fix time: 2 minutes
# When to use sum vs mean:
# - mean: Default, loss independent of batch size
# - sum: When you want loss to scale with batch size (rare)
# - none: When you want per-sample losses (for weighting)
```
### Pitfall 10: Using Accuracy for Imbalanced Data
```python
# ❌ WRONG: 95-5 imbalance
accuracy = (pred == target).float().mean() # 95% for trivial predictor!
# ✅ RIGHT: Use F1, precision, recall
from sklearn.metrics import f1_score, precision_score, recall_score
f1 = f1_score(target, pred) # Balanced metric
precision = precision_score(target, pred)
recall = recall_score(target, pred)
# Or use balanced accuracy:
balanced_acc = (recall_class0 + recall_class1) / 2
# Impact: Misinterpreting model performance
# Fix time: 5 minutes
```
## Section 10: Loss Debugging Methodology
### When Loss is NaN
```python
# Step 1: Check inputs for NaN
print(f"Input has NaN: {torch.isnan(x).any()}")
print(f"Target has NaN: {torch.isnan(target).any()}")
if torch.isnan(x).any():
    # Data loading issue
    print("Fix: Check data preprocessing")

# Step 2: Check for numerical instability in loss
# - Division by zero
# - Log of zero or negative
# - Exp overflow

# Step 3: Check gradients before NaN
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm()
        print(f"{name}: {grad_norm.item()}")
        if grad_norm > 1000:
            print(f"Exploding gradient in {name}")

# Step 4: Add gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Step 5: Lower learning rate
# LR too high can cause NaN

# Step 6: Check loss computation
# Add assertions in custom loss:
def custom_loss(pred, target):
    loss = compute_loss(pred, target)
    assert not torch.isnan(loss), f"Loss is NaN, pred range: [{pred.min()}, {pred.max()}]"
    assert not torch.isinf(loss), "Loss is inf"
    return loss
```
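PyTorch's autograd anomaly mode can also help localize where a NaN first appears in the backward pass. This added sketch shows the standard usage; it slows training noticeably, so enable it only while debugging.
```python
import torch

# Enable while debugging only: backward raises at the op that produced NaN/inf
torch.autograd.set_detect_anomaly(True)

for x, y in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()  # raises with a traceback pointing at the offending forward op
    optimizer.step()
```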
### When Loss Not Decreasing
```python
# Systematic debugging checklist:
# 1. Check loss value
print(f"Loss: {loss.item()}")
# - Is it reasonable? (CE should be ~ln(num_classes) initially)
# - Is it constant? (optimizer not stepping)
# - Is it very high? (wrong scale)
# 2. Check gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name} grad: mean={param.grad.abs().mean():.6f}, max={param.grad.abs().max():.6f}")
# If all gradients ~ 0:
# → Vanishing gradients (check activation functions, initialization)
# If gradients very large (>10):
# → Exploding gradients (add gradient clipping, lower LR)
# If no gradients printed:
# → Missing loss.backward() or parameters not requiring grad
# 3. Check predictions
print(f"Pred range: [{pred.min():.4f}, {pred.max():.4f}]")
print(f"Target range: [{target.min():.4f}, {target.max():.4f}]")
print(f"Pred mean: {pred.mean():.4f}, Target mean: {target.mean():.4f}")
# If predictions are constant:
# → Model not learning (check optimizer.step(), zero_grad())
# If predictions are random:
# → Model learning but task too hard or wrong loss
# If pred/target ranges very different:
# → Normalization issue
# 4. Verify training setup
print(f"Model training mode: {model.training}") # Should be True
print(f"Requires grad: {next(model.parameters()).requires_grad}") # Should be True
# Check optimizer.zero_grad() is called
# Check loss.backward() is called
# Check optimizer.step() is called
# 5. Check learning rate
print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
# Too low (< 1e-6): Won't learn
# Too high (> 1e-2): Unstable
# 6. Verify loss function matches task
# Classification → CrossEntropyLoss
# Regression → MSELoss or SmoothL1Loss
# Binary classification → BCEWithLogitsLoss
# 7. Check data
# Visualize a batch:
print(f"Batch input shape: {x.shape}")
print(f"Batch target shape: {target.shape}")
print(f"Target unique values: {target.unique()}")
# Are labels correct?
# Is data normalized?
# Any NaN in data?
# 8. Overfit single batch
# Can model fit one batch perfectly?
single_x, single_y = next(iter(train_loader))
for i in range(1000):
    optimizer.zero_grad()
    pred = model(single_x)
    loss = criterion(pred, single_y)
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print(f"Step {i}: Loss = {loss.item():.4f}")
# If can't overfit single batch:
# → Model architecture issue
# → Loss function wrong
# → Bug in training loop
```
### When Loss Stuck at Same Value
```python
# Scenario: Loss stays at 0.693 for binary classification (ln(2))
# Diagnosis: Model predicting 0.5 probability for all samples
# Possible causes:
# 1. Learning rate too low
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # Try 1e-3, 1e-4
# 2. Dead neurons (all ReLU outputs are 0)
# Check activations:
activations = model.fc1(x)
print(f"Activations: {activations.abs().mean()}")
if activations.abs().mean() < 0.01:
    print("Dead neurons! Try:")
    print("- Different initialization")
    print("- LeakyReLU instead of ReLU")
    print("- Lower learning rate")

# 3. Gradient flow blocked
# Check each layer:
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.abs().mean():.6f}")
    else:
        print(f"{name}: NO GRADIENT!")
# 4. Wrong optimizer state (if resuming training)
# Solution: Create fresh optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 5. Model too simple for task
# Try: Larger model, more layers, more parameters
# 6. Task is actually random
# Check: Can humans solve this task?
# Check: Is there signal in the data?
```
### When Loss Oscillating / Unstable
```python
# Scenario: Loss jumps around: 0.5 → 2.0 → 0.3 → 5.0 → ...
# Diagnosis: Unstable training
# Possible causes:
# 1. Learning rate too high
# Solution: Lower LR by 10x
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Down from 1e-3
# 2. Batch size too small
# Solution: Increase batch size (more stable gradients)
train_loader = DataLoader(dataset, batch_size=64) # Up from 32
# 3. No gradient clipping
# Solution: Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 4. Numerical instability in loss
# Solution: Use stable loss functions
# BCEWithLogitsLoss instead of BCE
# Add epsilon to custom losses
# 5. Data outliers
# Solution:
# - Remove outliers
# - Use robust loss (L1, SmoothL1, Huber)
# - Clip targets to reasonable range
# 6. Exploding gradients
# Check and clip:
total_norm = 0
for p in model.parameters():
    if p.grad is not None:
        total_norm += p.grad.norm().item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm}")
if total_norm > 10:
    print("Exploding gradients! Add gradient clipping.")
```
## Rationalization Prevention Table
| Rationalization | Why It's Wrong | What You Must Do |
|----------------|----------------|------------------|
| "BCE is simpler than BCEWithLogitsLoss" | BCE is numerically unstable, causes NaN | **ALWAYS use BCEWithLogitsLoss**. Non-negotiable. |
| "Loss weighting is just extra hyperparameter tuning" | Unweighted multi-task losses fail completely | **Check loss scales, weight them**. One task will dominate otherwise. |
| "The optimizer will figure out the scale differences" | Optimizers don't balance losses, they follow gradients | **Manual balance required**. SGD sees gradient magnitude, not task importance. |
| "95% accuracy is great!" | With 95-5 imbalance, this is trivial predictor | **Check F1/precision/recall**. Accuracy misleading for imbalanced data. |
| "Data is clean, no need for epsilon" | Even clean data can hit edge cases (empty masks, zeros) | **Add epsilon anyway**. Cost is negligible, prevents NaN. |
| "Softmax before CE makes output clearer" | CE applies softmax internally, this causes double softmax | **Pass logits to CE**. Never apply softmax first. |
| "One-hot encoding is more standard" | CrossEntropyLoss expects class indices, not one-hot | **Use class indices**. Shape must be (batch,) not (batch, C). |
| "Reduction parameter is optional" | Controls how loss aggregates, affects training dynamics | **Understand and choose**: mean (default), sum (rare), none (per-sample). |
| "Just lower LR to fix NaN" | NaN usually from numerical instability, not LR | **Fix root cause first**: epsilon, clipping, stable loss. Then adjust LR. |
| "Papers use different loss, I should too" | Papers don't always use optimal losses, context matters | **Evaluate if appropriate** for your data/task. Don't blindly copy. |
| "Custom loss is more flexible" | Built-in losses are optimized and tested | **Use built-ins when possible**. Only custom when necessary. |
| "Loss function doesn't matter much" | Loss is THE OBJECTIVE your model optimizes | **Loss choice is critical**. Wrong loss = optimizing wrong thing. |
| "I'll tune loss later" | Loss should match task from the start | **Choose correct loss immediately**. Tuning won't fix fundamentally wrong loss. |
| "Focal loss is always better for imbalance" | Focal loss has hyperparameters, can hurt if tuned wrong | **Try class weights first** (simpler, fewer hyperparameters). |
| "Division by zero won't happen in practice" | Edge cases happen: empty batches, all-zero masks | **Defensive programming**: always add epsilon to denominators. |
## Red Flags Checklist
When reviewing loss function code, watch for these RED FLAGS:
### Critical (Fix Immediately):
- [ ] Using `F.binary_cross_entropy` instead of `F.binary_cross_entropy_with_logits`
- [ ] Applying `sigmoid` or `softmax` before stable loss (BCEWithLogitsLoss, CrossEntropyLoss)
- [ ] Division without epsilon: `x / y` instead of `x / (y + 1e-8)`
- [ ] Log without clamping: `torch.log(x)` instead of `torch.log(torch.clamp(x, min=1e-8))`
- [ ] Missing `optimizer.zero_grad()` in training loop
- [ ] Multi-task losses added without weighting (different scales)
- [ ] Loss goes to NaN during training
### Important (Fix Soon):
- [ ] Class imbalance ignored (no `weight` or `pos_weight` parameter)
- [ ] Regression targets not normalized (huge loss values)
- [ ] Wrong target shape for CrossEntropyLoss (one-hot instead of indices)
- [ ] Custom loss without numerical stability checks
- [ ] Using accuracy metric for highly imbalanced data
- [ ] No gradient clipping for RNNs/Transformers
- [ ] Reduction not specified in custom loss
### Best Practices (Improve):
- [ ] No label smoothing for classification (consider adding)
- [ ] No focal loss for extreme imbalance (>100:1 ratio)
- [ ] Not monitoring individual task losses in multi-task learning
- [ ] Not checking gradient norms during training
- [ ] No assertions in custom loss for debugging
- [ ] Not testing loss function on toy data first (see the sketch below)
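As a concrete version of the last checklist item, the sketch below (my addition) sanity-checks a criterion on a tiny hand-made batch: the loss should sit near its theoretical starting value for uninformative predictions and near zero for near-perfect ones.
```python
import math
import torch
import torch.nn.functional as F

num_classes = 5
target = torch.tensor([0, 1, 2, 3])

# Uniform logits: CE should sit near ln(num_classes) ≈ 1.609
random_logits = torch.zeros(4, num_classes)
print(F.cross_entropy(random_logits, target))  # ≈ 1.609

# Near-perfect logits: CE should be close to 0
good_logits = torch.full((4, num_classes), -10.0)
good_logits[torch.arange(4), target] = 10.0
print(F.cross_entropy(good_logits, target))    # ≈ 0

assert abs(F.cross_entropy(random_logits, target).item() - math.log(num_classes)) < 1e-3
```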
## Summary: Loss Function Selection Flowchart
```
START
|
├─ Binary Classification?
| → BCEWithLogitsLoss + pos_weight for imbalance
|
├─ Multi-Class Classification?
| → CrossEntropyLoss + class weights for imbalance
| → Consider Focal Loss if extreme imbalance (>100:1)
|
├─ Multi-Label Classification?
| → BCEWithLogitsLoss + per-class pos_weight
|
├─ Regression?
| → SmoothL1Loss (good default)
| → MSELoss if no outliers
| → L1Loss if robust to outliers needed
| → ALWAYS normalize targets!
|
├─ Segmentation?
| → BCEWithLogitsLoss + DiceLoss (combine both)
| → Consider Focal Loss for imbalanced pixels
|
├─ Ranking/Similarity?
| → TripletMarginLoss or ContrastiveLoss
|
└─ Multi-Task?
→ Combine with careful weighting
→ Start with manual balance (check scales!)
→ Consider uncertainty weighting
ALWAYS:
✅ Use logits (no sigmoid/softmax before stable losses)
✅ Add epsilon to divisions and before log/sqrt
✅ Check for class/label imbalance
✅ Normalize regression targets
✅ Monitor loss values and gradients
✅ Test loss on toy data first
NEVER:
❌ Use BCE instead of BCEWithLogitsLoss
❌ Apply softmax before CrossEntropyLoss
❌ Ignore different scales in multi-task
❌ Divide without epsilon
❌ Trust accuracy alone for imbalanced data
```
## Final Checklist Before Training
Before starting training, verify:
1. **Loss Function Matches Task:**
- [ ] Binary classification → BCEWithLogitsLoss
- [ ] Multi-class → CrossEntropyLoss
- [ ] Regression → SmoothL1Loss or MSE with normalized targets
2. **Numerical Stability:**
- [ ] Using stable loss (BCEWithLogitsLoss, not BCE)
- [ ] Epsilon in divisions: `x / (y + 1e-8)`
- [ ] Clamp before log: `torch.log(torch.clamp(x, min=1e-8))`
3. **Class Imbalance Handled:**
- [ ] Checked class distribution
- [ ] Added `weight` or `pos_weight` if imbalanced
- [ ] Using F1/precision/recall metrics, not just accuracy
4. **Multi-Task Weighting:**
- [ ] Checked loss scales (printed first batch)
- [ ] Added manual weights or uncertainty weighting
- [ ] Monitoring individual task metrics
5. **Target Preparation:**
- [ ] CrossEntropyLoss: targets are class indices (batch,)
- [ ] BCEWithLogitsLoss: targets are 0/1 floats
- [ ] Regression: targets normalized to similar scale as predictions
6. **Training Loop:**
- [ ] `optimizer.zero_grad()` before backward
- [ ] `loss.backward()` to compute gradients
- [ ] `optimizer.step()` to update parameters
- [ ] Gradient clipping if using RNN/Transformer
7. **Debugging Setup:**
- [ ] Can print loss value: `loss.item()`
- [ ] Can print gradient norms
- [ ] Can visualize predictions vs targets
- [ ] Have tested overfitting single batch
## When to Seek Help
If after following this skill you still have loss issues:
1. **Loss is NaN:**
- Checked all numerical stability issues?
- Added epsilon everywhere?
- Tried gradient clipping?
- Lowered learning rate?
→ If still NaN, may need architecture change or data investigation
2. **Loss not decreasing:**
- Verified training loop is correct (zero_grad, backward, step)?
- Checked gradients are flowing?
- Tried overfitting single batch?
- Verified loss function matches task?
→ If still not decreasing, may be model capacity or data issue
3. **Loss decreasing but metrics poor:**
- Is loss the right objective for your metric?
- Example: CE minimizes NLL, not accuracy
- Consider metric-aware loss or post-hoc calibration
4. **Multi-task learning not working:**
- Tried multiple weighting strategies?
- Monitored individual task losses?
- Ensured all tasks getting gradient signal?
→ May need task-specific heads or curriculum learning
Remember: Loss function is the heart of deep learning. Get this right first before tuning everything else.
## Additional Resources
**Key Papers:**
- Focal Loss: "Focal Loss for Dense Object Detection" (Lin et al., 2017)
- Label Smoothing: "Rethinking the Inception Architecture" (Szegedy et al., 2016)
- Uncertainty Weighting: "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018)
- Class-Balanced Loss: "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019)
**PyTorch Documentation:**
- Loss Functions: https://pytorch.org/docs/stable/nn.html#loss-functions
- Numerical Stability: Use built-in combined operations (BCEWithLogitsLoss, etc.)
**Common Loss Functions Quick Reference:**
```python
# Classification
F.binary_cross_entropy_with_logits(logits, target, pos_weight=...)
F.cross_entropy(logits, target, weight=..., label_smoothing=...)
F.nll_loss(log_probs, target) # If you already have log_probs
# Regression
F.mse_loss(pred, target)
F.l1_loss(pred, target)
F.smooth_l1_loss(pred, target, beta=1.0)
F.huber_loss(pred, target, delta=1.0) # PyTorch 1.10+
# Ranking
F.margin_ranking_loss(input1, input2, target, margin=0.0)
F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
F.cosine_embedding_loss(input1, input2, target)
# Other
F.kl_div(log_probs, target_probs) # KL divergence
F.poisson_nll_loss(log_input, target) # Poisson regression
```
**END OF SKILL**
When you use this skill, you become an expert in loss function selection and implementation. You will:
- Choose the correct loss for any task
- Ensure numerical stability
- Handle class imbalance appropriately
- Weight multi-task losses correctly
- Debug loss issues systematically
- Avoid all common loss function pitfalls
Remember: The loss function IS your model's objective. Get this right, and everything else follows.