Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 09:00:00 +08:00
commit 955d5c6743
12 changed files with 6996 additions and 0 deletions

View File

@@ -0,0 +1,915 @@
# Normalization Techniques
## Context
You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering:
- **Batch size dependency**: BatchNorm fails with small batches (< 8)
- **Architecture mismatch**: BatchNorm breaks RNNs/Transformers (use LayerNorm)
- **Task-specific needs**: Style transfer needs InstanceNorm, not BatchNorm
- **Modern alternatives**: RMSNorm simpler and faster than LayerNorm for LLMs
**This skill prevents normalization cargo-culting and provides architecture-specific selection.**
## Why Normalization Matters
**Problem: Internal Covariate Shift**
During training, layer input distributions shift as previous layers update. This causes:
- Vanishing/exploding gradients (deep networks)
- Slow convergence (small learning rates required)
- Training instability (loss spikes)
**Solution: Normalization**
Normalize activations to have stable statistics (mean=0, std=1). Benefits:
- **10x faster convergence**: Can use larger learning rates
- **Better generalization**: Regularization effect
- **Enables deep networks**: 50+ layers without gradient issues
- **Less sensitive to initialization**: Weights can start further from optimal
**Key insight**: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize.
## Normalization Families
### 1. Batch Normalization (BatchNorm)
**What it does:**
Normalizes across the batch dimension for each channel/feature.
**Formula:**
```
Given input x with shape (B, C, H, W): # Batch, Channel, Height, Width
For each channel c:
μ_c = mean(x[:, c, :, :]) # Mean over batch + spatial dims
σ_c = std(x[:, c, :, :]) # Std over batch + spatial dims
x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε)
# Learnable scale and shift
y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c
```
**When to use:**
- ✅ CNNs for classification (ResNet, EfficientNet)
- ✅ Large batch sizes (≥ 16)
- ✅ IID data (image classification, object detection)
**When NOT to use:**
- ❌ Small batch sizes (< 8): Noisy statistics cause training failure
- ❌ RNNs/LSTMs: Breaks temporal dependencies
- ❌ Transformers: Batch dependency problematic for variable-length sequences
- ❌ Style transfer: Batch statistics erase style information
**Batch size dependency:**
```python
batch_size = 32: # ✓ Works well (stable statistics)
batch_size = 16: # ✓ Acceptable
batch_size = 8: # ✓ Marginal (consider GroupNorm)
batch_size = 4: # ✗ Unstable (use GroupNorm)
batch_size = 2: # ✗ FAILS! (noisy statistics)
batch_size = 1: # ✗ Undefined (no batch to normalize over!)
```
**PyTorch example:**
```python
import torch.nn as nn
# For Conv2d
bn = nn.BatchNorm2d(num_features=64) # 64 channels
x = torch.randn(32, 64, 28, 28) # Batch=32, Channels=64
y = bn(x)
# For Linear
bn = nn.BatchNorm1d(num_features=128) # 128 features
x = torch.randn(32, 128) # Batch=32, Features=128
y = bn(x)
```
**Inference mode:**
```python
# Training: Uses batch statistics
model.train()
y = bn(x) # Normalizes using current batch mean/std
# Inference: Uses running statistics (accumulated during training)
model.eval()
y = bn(x) # Normalizes using running_mean/running_std
```
### 2. Layer Normalization (LayerNorm)
**What it does:**
Normalizes across the feature dimension for each sample independently.
**Formula:**
```
Given input x with shape (B, C): # Batch, Features
For each sample b:
μ_b = mean(x[b, :]) # Mean over features
σ_b = std(x[b, :]) # Std over features
x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε)
# Learnable scale and shift
y[b, :] = γ * x_norm[b, :] + β
```
**When to use:**
- ✅ Transformers (BERT, GPT, T5)
- ✅ RNNs/LSTMs (maintains temporal independence)
- ✅ Small batch sizes (batch-independent!)
- ✅ Variable-length sequences
- ✅ Reinforcement learning (batch_size=1 common)
**Advantages over BatchNorm:**
-**Batch-independent**: Works with batch_size=1
-**No running statistics**: Inference = training (no mode switching)
-**Sequence-friendly**: Doesn't mix information across timesteps
**PyTorch example:**
```python
import torch.nn as nn
# For Transformer
ln = nn.LayerNorm(normalized_shape=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch=32, SeqLen=128, d_model=512
y = ln(x) # Normalizes last dimension independently per (batch, position)
# For RNN hidden states
ln = nn.LayerNorm(normalized_shape=256) # hidden_size=256
h = torch.randn(32, 256) # Batch=32, Hidden=256
h_norm = ln(h)
```
**Key difference from BatchNorm:**
```python
# BatchNorm: Normalizes across batch dimension
# Given (B=32, C=64, H=28, W=28)
# Computes 64 means/stds (one per channel, across batch + spatial)
# LayerNorm: Normalizes across feature dimension
# Given (B=32, L=128, D=512)
# Computes 32×128 means/stds (one per (batch, position), across features)
```
### 3. Group Normalization (GroupNorm)
**What it does:**
Normalizes channels in groups, batch-independent.
**Formula:**
```
Given input x with shape (B, C, H, W):
Divide C channels into G groups (C must be divisible by G)
For each sample b and group g:
channels = x[b, g*(C/G):(g+1)*(C/G), :, :] # Channels in group g
μ_{b,g} = mean(channels) # Mean over channels in group + spatial
σ_{b,g} = std(channels) # Std over channels in group + spatial
x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε)
```
**When to use:**
- ✅ Small batch sizes (< 8)
- ✅ CNNs with batch_size=1 (style transfer, RL)
- ✅ Object detection/segmentation (often use small batches)
- ✅ When BatchNorm unstable but want spatial normalization
**Group size selection:**
```python
# num_groups trade-off:
num_groups = 1: # = LayerNorm (all channels together)
num_groups = C: # = InstanceNorm (each channel separate)
num_groups = 32: # Standard choice (good balance)
# Rule: C must be divisible by num_groups
channels = 64, num_groups = 32: # ✓ 64/32 = 2 channels per group
channels = 64, num_groups = 16: # ✓ 64/16 = 4 channels per group
channels = 64, num_groups = 30: # ✗ 64/30 not integer!
```
**PyTorch example:**
```python
import torch.nn as nn
# For small batch CNN
gn = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 28, 28) # Batch=2 (small!)
y = gn(x) # Works well even with batch=2
# Compare performance:
batch_sizes = [1, 2, 4, 8, 16, 32]
bn = nn.BatchNorm2d(64)
gn = nn.GroupNorm(32, 64)
for bs in batch_sizes:
x = torch.randn(bs, 64, 28, 28)
# BatchNorm gets more stable with larger batch
# GroupNorm consistent across all batch sizes
```
**Empirical results (He et al. 2018):**
```
ImageNet classification with ResNet-50:
batch_size = 32: BatchNorm = 76.5%, GroupNorm = 76.3% (tie)
batch_size = 8: BatchNorm = 75.8%, GroupNorm = 76.1% (GroupNorm wins!)
batch_size = 2: BatchNorm = 72.1%, GroupNorm = 75.3% (GroupNorm wins!)
```
### 4. Instance Normalization (InstanceNorm)
**What it does:**
Normalizes each sample and channel independently (no batch mixing).
**Formula:**
```
Given input x with shape (B, C, H, W):
For each sample b and channel c:
μ_{b,c} = mean(x[b, c, :, :]) # Mean over spatial dimensions only
σ_{b,c} = std(x[b, c, :, :]) # Std over spatial dimensions only
x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε)
```
**When to use:**
- ✅ Style transfer (neural style, CycleGAN, pix2pix)
- ✅ Image-to-image translation
- ✅ When batch/channel mixing destroys information
**Why for style transfer:**
```python
# Style transfer goal: Transfer style while preserving content
# BatchNorm: Mixes statistics across batch (erases individual style!)
# InstanceNorm: Per-image statistics (preserves each image's style)
# Example: Neural style transfer
content_image = load_image("photo.jpg")
style_image = load_image("starry_night.jpg")
# With BatchNorm: Output loses content image's unique characteristics
# With InstanceNorm: Content characteristics preserved, style applied
```
**PyTorch example:**
```python
import torch.nn as nn
# For style transfer generator
in_norm = nn.InstanceNorm2d(num_features=64)
x = torch.randn(1, 64, 256, 256) # Single image
y = in_norm(x) # Normalizes each channel independently
# CycleGAN generator architecture
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
self.in1 = nn.InstanceNorm2d(64) # NOT BatchNorm!
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.in1(x) # Per-image normalization
x = self.relu(x)
return x
```
**Relation to GroupNorm:**
```python
# InstanceNorm is GroupNorm with num_groups = num_channels
InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C)
```
### 5. RMS Normalization (RMSNorm)
**What it does:**
Simplified LayerNorm that only rescales (no recentering), faster and simpler.
**Formula:**
```
Given input x:
# LayerNorm (2 steps):
x_centered = x - mean(x) # 1. Center
x_norm = x_centered / std(x) # 2. Scale
# RMSNorm (1 step):
rms = sqrt(mean(x²)) # Root Mean Square
x_norm = x / rms # Only scale, no centering
```
**When to use:**
- ✅ Modern LLMs (LLaMA, Mistral, Gemma)
- ✅ When speed matters (15-20% faster than LayerNorm)
- ✅ Large Transformer models (billions of parameters)
**Advantages:**
-**Simpler**: One operation instead of two
-**Faster**: ~15-20% speedup over LayerNorm
-**Numerically stable**: No subtraction (avoids catastrophic cancellation)
-**Same performance**: Empirically matches LayerNorm quality
**PyTorch implementation:**
```python
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize
x_norm = x / rms
# Scale (learnable)
return self.weight * x_norm
# Usage in Transformer
rms = RMSNorm(dim=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch, SeqLen, d_model
y = rms(x)
```
**Speed comparison (LLaMA-7B, A100 GPU):**
```
LayerNorm: 1000 tokens/sec
RMSNorm: 1180 tokens/sec # 18% faster!
# For large models, this adds up:
# 1 million tokens: 180 seconds saved
```
**Modern LLM adoption:**
```python
# LLaMA (Meta, 2023): RMSNorm
# Mistral (Mistral AI, 2023): RMSNorm
# Gemma (Google, 2024): RMSNorm
# PaLM (Google, 2022): RMSNorm
# Older models:
# GPT-2/3 (OpenAI): LayerNorm
# BERT (Google, 2018): LayerNorm
```
## Architecture-Specific Selection
### CNN (Convolutional Neural Networks)
**Default: BatchNorm**
```python
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels) # After conv
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # Normalize
x = self.relu(x)
return x
```
**Exception: Small batch sizes**
```python
# If batch_size < 8, use GroupNorm instead
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.norm = nn.GroupNorm(32, out_channels) # GroupNorm for small batches
self.relu = nn.ReLU(inplace=True)
```
**Exception: Style transfer**
```python
# Use InstanceNorm for style transfer
class StyleConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.norm = nn.InstanceNorm2d(out_channels) # Per-image normalization
self.relu = nn.ReLU(inplace=True)
```
### RNN / LSTM (Recurrent Neural Networks)
**Default: LayerNorm**
```python
import torch.nn as nn
class NormalizedLSTM(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.ln = nn.LayerNorm(hidden_size) # Normalize hidden states
def forward(self, x):
# x: (batch, seq_len, input_size)
output, (h_n, c_n) = self.lstm(x)
# output: (batch, seq_len, hidden_size)
# Normalize each timestep's output
output_norm = self.ln(output) # Applies independently per timestep
return output_norm, (h_n, c_n)
```
**Why NOT BatchNorm:**
```python
# BatchNorm in RNN mixes information across timesteps!
# Given (batch=32, seq_len=100, hidden=256)
# BatchNorm would compute:
# mean/std over (batch × seq_len) = 3200 values
# This mixes t=0 with t=99 (destroys temporal structure!)
# LayerNorm computes:
# mean/std over hidden_size = 256 values per (batch, timestep)
# Each timestep normalized independently (preserves temporal structure)
```
**Layer-wise normalization in stacked RNN:**
```python
class StackedNormalizedLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
in_size = input_size if i == 0 else hidden_size
self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True))
self.layers.append(nn.LayerNorm(hidden_size)) # After each LSTM layer
def forward(self, x):
for lstm, ln in zip(self.layers[::2], self.layers[1::2]):
x, _ = lstm(x)
x = ln(x) # Normalize between layers
return x
```
### Transformer
**Default: LayerNorm (or RMSNorm for modern/large models)**
**Two placement options: Pre-norm vs Post-norm**
**Post-norm (original Transformer, "Attention is All You Need"):**
```python
class TransformerLayerPostNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x):
# Post-norm: Apply normalization AFTER residual
x = self.ln1(x + self.attn(x, x, x)[0]) # Normalize after adding
x = self.ln2(x + self.ffn(x)) # Normalize after adding
return x
```
**Pre-norm (modern, more stable):**
```python
class TransformerLayerPreNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x):
# Pre-norm: Apply normalization BEFORE sublayer
x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0] # Normalize before attention
x = x + self.ffn(self.ln2(x)) # Normalize before FFN
return x
```
**Pre-norm vs Post-norm comparison:**
```python
# Post-norm (original):
# - Less stable (requires careful initialization + warmup)
# - Slightly better performance IF training succeeds
# - Hard to train deep models (> 12 layers)
# Pre-norm (modern):
# - More stable (easier to train deep models)
# - Standard for large models (GPT-3: 96 layers!)
# - Recommended default
# Empirical: GPT-2, BERT (post-norm, ≤12 layers)
# GPT-3, T5, LLaMA (pre-norm, ≥24 layers)
```
**Using RMSNorm instead:**
```python
class TransformerLayerRMSNorm(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.rms1 = RMSNorm(d_model) # 15-20% faster than LayerNorm
self.rms2 = RMSNorm(d_model)
def forward(self, x):
# Pre-norm with RMSNorm (LLaMA style)
x = x + self.attn(self.rms1(x), self.rms1(x), self.rms1(x))[0]
x = x + self.ffn(self.rms2(x))
return x
```
### GAN (Generative Adversarial Network)
**Generator: InstanceNorm or no normalization**
```python
class Generator(nn.Module):
def __init__(self):
super().__init__()
# Use InstanceNorm for image-to-image translation
self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
self.in1 = nn.InstanceNorm2d(64) # Per-image normalization
def forward(self, x):
x = self.conv1(x)
x = self.in1(x) # Preserves per-image characteristics
return x
```
**Discriminator: No normalization or LayerNorm**
```python
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
# Often no normalization (BatchNorm can hurt GAN training)
self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
# No normalization here
def forward(self, x):
x = self.conv1(x)
# Directly to activation (no norm)
return x
```
**Why avoid BatchNorm in GANs:**
```python
# BatchNorm in discriminator:
# - Mixes real and fake samples in batch
# - Leaks information (discriminator can detect batch composition)
# - Hurts training stability
# Recommendation:
# Generator: InstanceNorm (for image translation) or no norm
# Discriminator: No normalization or LayerNorm
```
## Decision Framework
### Step 1: Check batch size
```python
if batch_size >= 8:
consider_batchnorm = True
else:
use_groupnorm_or_layernorm = True # BatchNorm will be unstable
```
### Step 2: Check architecture
```python
if architecture == "CNN":
if batch_size >= 8:
use_batchnorm()
else:
use_groupnorm(num_groups=32)
# Exception: Style transfer
if task == "style_transfer":
use_instancenorm()
elif architecture in ["RNN", "LSTM", "GRU"]:
use_layernorm() # NEVER BatchNorm!
elif architecture == "Transformer":
if model_size == "large": # > 1B parameters
use_rmsnorm() # 15-20% faster
else:
use_layernorm()
# Placement: Pre-norm (more stable)
use_prenorm_placement()
elif architecture == "GAN":
if component == "generator":
if task == "image_translation":
use_instancenorm()
else:
use_no_norm() # Or InstanceNorm
elif component == "discriminator":
use_no_norm() # Or LayerNorm
```
### Step 3: Verify placement
```python
# CNNs: After convolution, before activation
x = conv(x)
x = norm(x) # Here!
x = relu(x)
# RNNs: After LSTM, normalize hidden states
output, (h, c) = lstm(x)
output = norm(output) # Here!
# Transformers: Pre-norm (modern) or post-norm (original)
# Pre-norm (recommended):
x = x + sublayer(norm(x)) # Normalize before sublayer
# Post-norm (original):
x = norm(x + sublayer(x)) # Normalize after residual
```
## Implementation Checklist
### Before adding normalization:
1.**Check batch size**: If < 8, avoid BatchNorm
2.**Check architecture**: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm
3.**Check task**: Style transfer→InstanceNorm
4.**Verify placement**: After conv/linear, before activation (CNNs)
5.**Test training stability**: Loss should decrease smoothly
### During training:
6.**Monitor running statistics** (BatchNorm): Check running_mean/running_var are updating
7.**Test inference mode**: Verify model.eval() uses running stats correctly
8.**Check gradient flow**: Normalization should help, not hurt gradients
### If training is unstable:
9.**Try different normalization**: BatchNorm→GroupNorm, LayerNorm→RMSNorm
10.**Try pre-norm** (Transformers): More stable than post-norm
11.**Reduce learning rate**: Normalization allows larger LR, but start conservatively
## Common Mistakes
### Mistake 1: BatchNorm with small batches
```python
# WRONG: BatchNorm with batch_size=2
model = ResNet50(norm_layer=nn.BatchNorm2d)
dataloader = DataLoader(dataset, batch_size=2) # Too small!
# RIGHT: GroupNorm for small batches
model = ResNet50(norm_layer=lambda channels: nn.GroupNorm(32, channels))
dataloader = DataLoader(dataset, batch_size=2) # Works!
```
### Mistake 2: BatchNorm in RNN
```python
# WRONG: BatchNorm in LSTM
class BadLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(100, 256)
self.bn = nn.BatchNorm1d(256) # WRONG! Mixes timesteps
def forward(self, x):
output, _ = self.lstm(x)
output = output.permute(0, 2, 1) # (B, H, T)
output = self.bn(output) # Mixes timesteps!
return output
# RIGHT: LayerNorm in LSTM
class GoodLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(100, 256)
self.ln = nn.LayerNorm(256) # Per-timestep normalization
def forward(self, x):
output, _ = self.lstm(x)
output = self.ln(output) # Independent per timestep
return output
```
### Mistake 3: Forgetting model.eval()
```python
# WRONG: Using training mode during inference
model.train() # BatchNorm uses batch statistics
predictions = model(test_data) # Batch statistics from test data (leakage!)
# RIGHT: Use eval mode during inference
model.eval() # BatchNorm uses running statistics
with torch.no_grad():
predictions = model(test_data) # Uses accumulated running stats
```
### Mistake 4: Post-norm for deep Transformers
```python
# WRONG: Post-norm for 24-layer Transformer (unstable!)
class DeepTransformerPostNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayerPostNorm(512, 8) for _ in range(24)
]) # Hard to train!
# RIGHT: Pre-norm for deep Transformers
class DeepTransformerPreNorm(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayerPreNorm(512, 8) for _ in range(24)
]) # Stable training!
```
### Mistake 5: Wrong normalization for style transfer
```python
# WRONG: BatchNorm for style transfer
class StyleGenerator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 7, padding=3)
self.norm = nn.BatchNorm2d(64) # WRONG! Mixes styles across batch
# RIGHT: InstanceNorm for style transfer
class StyleGenerator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 7, padding=3)
self.norm = nn.InstanceNorm2d(64) # Per-image normalization
```
## Performance Impact
### Training speed:
```python
# Without normalization: 100 epochs to converge
# With normalization: 10 epochs to converge (10x faster!)
# Reason: Larger learning rates possible
lr_no_norm = 0.001 # Must be small (unstable otherwise)
lr_with_norm = 0.01 # Can be 10x larger (normalization stabilizes)
```
### Inference speed:
```python
# Normalization overhead (relative to no normalization):
BatchNorm: +2% (minimal, cached running stats)
LayerNorm: +3-5% (compute mean/std per forward pass)
RMSNorm: +2-3% (faster than LayerNorm)
GroupNorm: +5-8% (more computation than BatchNorm)
InstanceNorm: +3-5% (similar to LayerNorm)
# For most models: Overhead is negligible compared to conv/linear layers
```
### Memory usage:
```python
# Normalization memory (per layer):
BatchNorm: 2 × num_channels (running_mean, running_std) + 2 × num_channels (γ, β)
LayerNorm: 2 × normalized_shape (γ, β)
RMSNorm: 1 × normalized_shape (γ only, no β)
# Example: 512 channels
BatchNorm: 4 × 512 = 2048 parameters
LayerNorm: 2 × 512 = 1024 parameters
RMSNorm: 1 × 512 = 512 parameters # Most efficient!
```
## When NOT to Normalize
**Case 1: Final output layer**
```python
# Don't normalize final predictions
class Classifier(nn.Module):
def __init__(self):
super().__init__()
self.backbone = ResNet50() # Normalization inside
self.fc = nn.Linear(2048, 1000)
# NO normalization here! (final logits should be unnormalized)
def forward(self, x):
x = self.backbone(x)
x = self.fc(x) # Raw logits
return x # Don't normalize!
```
**Case 2: Very small networks**
```python
# Single-layer network: Normalization overkill
class TinyNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(784, 10) # MNIST classifier
# No normalization needed (network too simple)
def forward(self, x):
return self.fc(x)
```
**Case 3: When debugging**
```python
# Remove normalization to isolate issues
# If training fails with normalization, try without to check if:
# - Initialization is correct
# - Loss function is correct
# - Data is correctly preprocessed
```
## Modern Recommendations (2025)
### CNNs:
- **Default**: BatchNorm (if batch_size ≥ 8)
- **Small batches**: GroupNorm (num_groups=32)
- **Style transfer**: InstanceNorm
### RNNs/LSTMs:
- **Default**: LayerNorm
- **Never**: BatchNorm (breaks temporal structure)
### Transformers:
- **Small models** (< 1B): LayerNorm + pre-norm
- **Large models** (≥ 1B): RMSNorm + pre-norm (15-20% faster)
- **Avoid**: Post-norm for deep models (> 12 layers)
### GANs:
- **Generator**: InstanceNorm (image translation) or no norm
- **Discriminator**: No normalization or LayerNorm
- **Avoid**: BatchNorm (leaks information)
### Emerging:
- **RMSNorm adoption increasing**: LLaMA, Mistral, Gemma all use RMSNorm
- **Pre-norm becoming standard**: More stable for deep networks
- **GroupNorm gaining traction**: Object detection, small-batch training
## Summary
**Normalization is mandatory for modern deep learning.** The question is which normalization, not whether to normalize.
**Quick decision tree:**
1. **Batch size ≥ 8?** → Consider BatchNorm (CNNs)
2. **Batch size < 8?** → Use GroupNorm (CNNs) or LayerNorm (all)
3. **RNN/LSTM?** → LayerNorm (never BatchNorm!)
4. **Transformer?** → LayerNorm or RMSNorm with pre-norm
5. **Style transfer?** → InstanceNorm
6. **GAN?** → InstanceNorm (generator) or no norm (discriminator)
**Modern defaults:**
- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8)
- RNNs: LayerNorm
- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models)
- GANs: InstanceNorm (generator), no norm (discriminator)
**Key insight**: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere"—it fails for small batches, RNNs, Transformers, and style transfer.