Normalization Techniques
Context
You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering:
- Batch size dependency: BatchNorm fails with small batches (< 8)
- Architecture mismatch: BatchNorm breaks RNNs/Transformers (use LayerNorm)
- Task-specific needs: Style transfer needs InstanceNorm, not BatchNorm
- Modern alternatives: RMSNorm is simpler and faster than LayerNorm for LLMs
This skill prevents normalization cargo-culting and provides architecture-specific selection.
Why Normalization Matters
Problem: Internal Covariate Shift
During training, layer input distributions shift as previous layers update. This causes:
- Vanishing/exploding gradients (deep networks)
- Slow convergence (small learning rates required)
- Training instability (loss spikes)
Solution: Normalization
Normalize activations to have stable statistics (mean=0, std=1). Benefits:
- Faster convergence (often several-fold, depending on model and task): Can use larger learning rates
- Better generalization: Regularization effect
- Enables deep networks: 50+ layers without gradient issues
- Less sensitive to initialization: Weights can start further from optimal
Key insight: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize.
Normalization Families
1. Batch Normalization (BatchNorm)
What it does: Normalizes across the batch dimension for each channel/feature.
Formula:
Given input x with shape (B, C, H, W): # Batch, Channel, Height, Width
For each channel c:
μ_c = mean(x[:, c, :, :]) # Mean over batch + spatial dims
σ_c = std(x[:, c, :, :]) # Std over batch + spatial dims
x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε)
# Learnable scale and shift
y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c
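The per-channel formula above can be checked numerically. The sketch below recomputes it by hand and compares the result with `nn.BatchNorm2d` in training mode. The tensor sizes and the seed are arbitrary choices made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5)  # (B, C, H, W), sizes chosen arbitrarily

bn = nn.BatchNorm2d(num_features=4)
bn.train()  # training mode: normalize with the current batch statistics
y_ref = bn(x)

# Per-channel statistics over batch + spatial dims (biased variance)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_norm = (x - mean) / torch.sqrt(var + bn.eps)

# Learnable scale and shift, broadcast over (B, H, W)
gamma = bn.weight.view(1, -1, 1, 1)
beta = bn.bias.view(1, -1, 1, 1)
y_manual = gamma * x_norm + beta

bn_matches = torch.allclose(y_ref, y_manual, atol=1e-5)
print(bn_matches)
```

Note that the normalization step uses the biased variance (divide by N, not N-1), which is why `unbiased=False` appears in the manual version.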
When to use:
- ✅ CNNs for classification (ResNet, EfficientNet)
- ✅ Large batch sizes (≥ 16)
- ✅ IID data (image classification, object detection)
When NOT to use:
- ❌ Small batch sizes (< 8): Noisy statistics cause training failure
- ❌ RNNs/LSTMs: Breaks temporal dependencies
- ❌ Transformers: Batch dependency problematic for variable-length sequences
- ❌ Style transfer: Batch statistics erase style information
Batch size dependency:
batch_size = 32: # ✓ Works well (stable statistics)
batch_size = 16: # ✓ Acceptable
batch_size = 8: # ✓ Marginal (consider GroupNorm)
batch_size = 4: # ✗ Unstable (use GroupNorm)
batch_size = 2: # ✗ FAILS! (noisy statistics)
batch_size = 1: # ✗ Degenerate (BatchNorm1d on (1, C) input raises an error in training mode; BatchNorm2d falls back to single-image spatial statistics)
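One way to see why small batches hurt is to measure how much the batch mean itself fluctuates. The sketch below draws many batches from a standard normal distribution and reports the spread of their means. The helper name `batch_mean_noise` and the trial count are assumptions made for this illustration, not part of any library.

```python
import torch

torch.manual_seed(0)

def batch_mean_noise(batch_size, trials=2000):
    """Std of the batch-mean estimate across many random batches."""
    batches = torch.randn(trials, batch_size)  # true mean 0, true std 1
    return batches.mean(dim=1).std().item()

noise = {bs: batch_mean_noise(bs) for bs in (2, 4, 8, 16, 32)}
for bs, value in noise.items():
    # Theory: the std of a mean over bs independent samples is 1 / sqrt(bs)
    print(f"batch_size={bs:2d}  std of batch mean ~ {value:.3f}")
```

This models the feature-vector case, where each statistic pools only B values. With spatial dimensions each channel pools B×H×W values, so the noise is smaller, but it still grows as the batch shrinks.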
PyTorch example:
import torch
import torch.nn as nn
# For Conv2d
bn = nn.BatchNorm2d(num_features=64) # 64 channels
x = torch.randn(32, 64, 28, 28) # Batch=32, Channels=64
y = bn(x)
# For Linear
bn = nn.BatchNorm1d(num_features=128) # 128 features
x = torch.randn(32, 128) # Batch=32, Features=128
y = bn(x)
Inference mode:
# Training: Uses batch statistics
model.train()
y = bn(x) # Normalizes using current batch mean/std
# Inference: Uses running statistics (accumulated during training)
model.eval()
y = bn(x) # Normalizes using running_mean/running_var
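The switch between batch statistics and running statistics can be made concrete. The sketch below runs one training step, checks how PyTorch updates `running_mean` and `running_var` with its default momentum of 0.1, then confirms that eval mode normalizes with those buffers. The data distribution and sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)  # default momentum=0.1
x = torch.randn(16, 3) * 2.0 + 5.0   # data with mean ~5, std ~2

bn.train()
_ = bn(x)  # one training step updates the running statistics

# Running stats start at mean=0, var=1 and move toward the batch stats.
# The running variance is updated with the unbiased batch variance.
expected_mean = 0.9 * torch.zeros(3) + 0.1 * x.mean(dim=0)
expected_var = 0.9 * torch.ones(3) + 0.1 * x.var(dim=0, unbiased=True)
mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-5)

bn.eval()
with torch.no_grad():
    y_eval = bn(x)  # uses running_mean / running_var, not the batch
y_manual = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
eval_ok = torch.allclose(y_eval, y_manual, atol=1e-4)
print(mean_ok, var_ok, eval_ok)
```

After a single step the running statistics are still far from the data statistics, which is why evaluating a barely trained BatchNorm model can give poor results.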
2. Layer Normalization (LayerNorm)
What it does: Normalizes across the feature dimension for each sample independently.
Formula:
Given input x with shape (B, C): # Batch, Features
For each sample b:
μ_b = mean(x[b, :]) # Mean over features
σ_b = std(x[b, :]) # Std over features
x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε)
# Learnable scale and shift
y[b, :] = γ * x_norm[b, :] + β
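As a quick check of the formula above, the sketch below recomputes LayerNorm by hand and also verifies the batch-independence property: a sample normalized alone gives the same output as the same sample inside a batch. Shapes and seed are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16)  # (B, C)
ln = nn.LayerNorm(normalized_shape=16)

# Per-sample statistics over the feature dimension (biased variance)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
y_manual = ln.weight * (x - mean) / torch.sqrt(var + ln.eps) + ln.bias
ln_matches = torch.allclose(ln(x), y_manual, atol=1e-5)

# Batch independence: a sample normalized alone gives the same result
alone = ln(x[:1])
in_batch = ln(x)[:1]
batch_independent = torch.allclose(alone, in_batch, atol=1e-6)
print(ln_matches, batch_independent)
```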
When to use:
- ✅ Transformers (BERT, GPT, T5)
- ✅ RNNs/LSTMs (maintains temporal independence)
- ✅ Small batch sizes (batch-independent!)
- ✅ Variable-length sequences
- ✅ Reinforcement learning (batch_size=1 common)
Advantages over BatchNorm:
- ✅ Batch-independent: Works with batch_size=1
- ✅ No running statistics: Inference = training (no mode switching)
- ✅ Sequence-friendly: Doesn't mix information across timesteps
PyTorch example:
import torch
import torch.nn as nn
# For Transformer
ln = nn.LayerNorm(normalized_shape=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch=32, SeqLen=128, d_model=512
y = ln(x) # Normalizes last dimension independently per (batch, position)
# For RNN hidden states
ln = nn.LayerNorm(normalized_shape=256) # hidden_size=256
h = torch.randn(32, 256) # Batch=32, Hidden=256
h_norm = ln(h)
Key difference from BatchNorm:
# BatchNorm: Normalizes across batch dimension
# Given (B=32, C=64, H=28, W=28)
# Computes 64 means/stds (one per channel, across batch + spatial)
# LayerNorm: Normalizes across feature dimension
# Given (B=32, L=128, D=512)
# Computes 32×128 means/stds (one per (batch, position), across features)
3. Group Normalization (GroupNorm)
What it does: Normalizes channels in groups, batch-independent.
Formula:
Given input x with shape (B, C, H, W):
Divide C channels into G groups (C must be divisible by G)
For each sample b and group g:
channels = x[b, g*(C/G):(g+1)*(C/G), :, :] # Channels in group g
μ_{b,g} = mean(channels) # Mean over channels in group + spatial
σ_{b,g} = std(channels) # Std over channels in group + spatial
x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε)
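The grouping in the formula above amounts to a reshape. The sketch below implements it by hand and compares against `nn.GroupNorm`. The sizes (`B`, `C`, `H`, `W`, `G`) are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, G = 2, 8, 5, 5, 4
x = torch.randn(B, C, H, W)
gn = nn.GroupNorm(num_groups=G, num_channels=C)

# Reshape so each group's channels and spatial positions share one axis
grouped = x.view(B, G, -1)  # (B, G, C/G * H * W)
mean = grouped.mean(dim=-1, keepdim=True)
var = grouped.var(dim=-1, unbiased=False, keepdim=True)
x_norm = ((grouped - mean) / torch.sqrt(var + gn.eps)).view(B, C, H, W)

# Per-channel scale and shift (identity at initialization)
y_manual = gn.weight.view(1, C, 1, 1) * x_norm + gn.bias.view(1, C, 1, 1)
gn_matches = torch.allclose(gn(x), y_manual, atol=1e-5)
print(gn_matches)
```

No statistic is shared between samples, which is why the result does not depend on the batch size.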
When to use:
- ✅ Small batch sizes (< 8)
- ✅ CNNs with batch_size=1 (style transfer, RL)
- ✅ Object detection/segmentation (often use small batches)
- ✅ When BatchNorm unstable but want spatial normalization
Group size selection:
# num_groups trade-off:
num_groups = 1: # Same normalization as LayerNorm over (C, H, W) (all channels together; affine parameters stay per-channel)
num_groups = C: # Same normalization as InstanceNorm (each channel separate)
num_groups = 32: # Standard choice (good balance)
# Rule: C must be divisible by num_groups
channels = 64, num_groups = 32: # ✓ 64/32 = 2 channels per group
channels = 64, num_groups = 16: # ✓ 64/16 = 4 channels per group
channels = 64, num_groups = 30: # ✗ 64/30 not integer!
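When channel counts vary across layers, the divisibility rule can be handled with a small helper. The function below is a hypothetical sketch (the name `pick_num_groups` and the fallback policy are assumptions, not a PyTorch API): it returns the largest group count that does not exceed the preferred value and divides the channel count.

```python
def pick_num_groups(num_channels, preferred=32):
    """Largest group count <= preferred that divides num_channels.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    for groups in range(min(preferred, num_channels), 0, -1):
        if num_channels % groups == 0:
            return groups
    return 1  # only reached if num_channels < 1

for channels in (64, 48, 30, 7):
    print(channels, pick_num_groups(channels))
```

For prime channel counts this falls back to one channel per group, which behaves like InstanceNorm, so it is worth checking whether that is what you want for such layers.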
PyTorch example:
import torch
import torch.nn as nn
# For small batch CNN
gn = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 28, 28) # Batch=2 (small!)
y = gn(x) # Works well even with batch=2
# Compare behavior across batch sizes:
batch_sizes = [1, 2, 4, 8, 16, 32]
bn = nn.BatchNorm2d(64)
gn = nn.GroupNorm(32, 64)
for bs in batch_sizes:
    x = torch.randn(bs, 64, 28, 28)
    y_bn = bn(x) # Batch statistics: noisier as bs shrinks
    y_gn = gn(x) # Per-sample statistics: same behavior at every bs
Empirical results (Wu & He, 2018):
ImageNet classification with ResNet-50 (top-1 accuracy derived from the reported error rates; batch size is images per GPU):
batch_size = 32: BatchNorm = 76.4%, GroupNorm = 75.9% (BatchNorm slightly ahead)
batch_size = 8: BatchNorm = 76.0%, GroupNorm = 76.0% (tie)
batch_size = 2: BatchNorm = 65.3%, GroupNorm = 75.9% (GroupNorm wins!)
4. Instance Normalization (InstanceNorm)
What it does: Normalizes each sample and channel independently (no batch mixing).
Formula:
Given input x with shape (B, C, H, W):
For each sample b and channel c:
μ_{b,c} = mean(x[b, c, :, :]) # Mean over spatial dimensions only
σ_{b,c} = std(x[b, c, :, :]) # Std over spatial dimensions only
x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε)
When to use:
- ✅ Style transfer (neural style, CycleGAN, pix2pix)
- ✅ Image-to-image translation
- ✅ When batch/channel mixing destroys information
Why for style transfer:
# Style transfer goal: Transfer style while preserving content
# BatchNorm: Mixes statistics across batch (erases individual style!)
# InstanceNorm: Per-image statistics (preserves each image's style)
# Example: Neural style transfer
content_image = load_image("photo.jpg")
style_image = load_image("starry_night.jpg")
# With BatchNorm: Output loses content image's unique characteristics
# With InstanceNorm: Content characteristics preserved, style applied
PyTorch example:
import torch
import torch.nn as nn
# For style transfer generator
in_norm = nn.InstanceNorm2d(num_features=64)
x = torch.randn(1, 64, 256, 256) # Single image
y = in_norm(x) # Normalizes each channel independently
# CycleGAN generator architecture
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64) # NOT BatchNorm!
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x) # Per-image normalization
        x = self.relu(x)
        return x
Relation to GroupNorm:
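# InstanceNorm is GroupNorm with num_groups = num_channels
# (same normalization; PyTorch defaults differ: InstanceNorm2d has affine=False, GroupNorm has affine=True)
InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C)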
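The relation can be verified numerically. In the sketch below both layers use their PyTorch defaults; GroupNorm's learnable scale and shift start at 1 and 0, so the outputs agree at initialization. The channel count and image size are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 8
x = torch.randn(2, C, 16, 16)

inst = nn.InstanceNorm2d(C)                          # affine=False by default
group = nn.GroupNorm(num_groups=C, num_channels=C)   # affine=True, starts at scale 1, shift 0

same_output = torch.allclose(inst(x), group(x), atol=1e-5)
print(same_output)
```

Once training updates GroupNorm's scale and shift, the two outputs diverge unless InstanceNorm is also built with `affine=True`.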
5. RMS Normalization (RMSNorm)
What it does: Simplified LayerNorm that only rescales (no recentering), faster and simpler.
Formula:
Given input x:
# LayerNorm (2 steps):
x_centered = x - mean(x) # 1. Center
x_norm = x_centered / std(x) # 2. Scale
# RMSNorm (1 step):
rms = sqrt(mean(x²)) # Root Mean Square
x_norm = x / rms # Only scale, no centering
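To make the difference from LayerNorm concrete, the sketch below implements the rescaling step as a plain function (without the learnable weight) and checks two properties: the output has unit root mean square, and on zero-mean input it agrees with LayerNorm because centering becomes a no-op. The function name `rms_norm` and the `eps` value are assumptions for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def rms_norm(x, eps=1e-6):
    """RMSNorm without the learnable scale: divide by the root mean square."""
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 64) + 3.0  # deliberately non-zero mean
y = rms_norm(x)
unit_rms = torch.allclose(y.pow(2).mean(dim=-1), torch.ones(4), atol=1e-4)

# On zero-mean input, centering is a no-op, so both norms agree
x_centered = x - x.mean(dim=-1, keepdim=True)
agree = torch.allclose(
    rms_norm(x_centered),
    F.layer_norm(x_centered, (64,), eps=1e-6),
    atol=1e-4,
)
print(unit_rms, agree)
```

On the uncentered `x` the two normalizations differ, since RMSNorm keeps the mean offset in the output instead of removing it.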
When to use:
- ✅ Modern LLMs (LLaMA, Mistral, Gemma)
- ✅ When speed matters (15-20% faster than LayerNorm)
- ✅ Large Transformer models (billions of parameters)
Advantages:
- ✅ Simpler: One operation instead of two
- ✅ Faster: ~15-20% speedup over LayerNorm
- ✅ Numerically stable: No subtraction (avoids catastrophic cancellation)
- ✅ Same performance: Empirically matches LayerNorm quality
PyTorch implementation:
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize
        x_norm = x / rms
        # Scale (learnable)
        return self.weight * x_norm
# Usage in Transformer
rms = RMSNorm(dim=512) # d_model=512
x = torch.randn(32, 128, 512) # Batch, SeqLen, d_model
y = rms(x)
Speed comparison (LLaMA-7B, A100 GPU):
LayerNorm: 1000 tokens/sec
RMSNorm: 1180 tokens/sec # 18% faster!
# For large models, this adds up:
# 1 million tokens: 1000 s vs. ~847 s, about 153 seconds saved (actual gains vary with hardware and kernels)
Modern LLM adoption:
# LLaMA (Meta, 2023): RMSNorm
# Mistral (Mistral AI, 2023): RMSNorm
# Gemma (Google, 2024): RMSNorm
# T5 (Google, 2019): RMSNorm-style LayerNorm (no mean subtraction, no bias)
# Older models:
# GPT-2/3 (OpenAI): LayerNorm
# BERT (Google, 2018): LayerNorm
Architecture-Specific Selection
CNN (Convolutional Neural Networks)
Default: BatchNorm
import torch.nn as nn
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels) # After conv
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x) # Normalize
        x = self.relu(x)
        return x
Exception: Small batch sizes
# If batch_size < 8, use GroupNorm instead
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(32, out_channels) # GroupNorm for small batches
        self.relu = nn.ReLU(inplace=True)
Exception: Style transfer
# Use InstanceNorm for style transfer
class StyleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels) # Per-image normalization
        self.relu = nn.ReLU(inplace=True)
RNN / LSTM (Recurrent Neural Networks)
Default: LayerNorm
import torch.nn as nn
class NormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.ln = nn.LayerNorm(hidden_size) # Normalize hidden states
    def forward(self, x):
        # x: (batch, seq_len, input_size)
        output, (h_n, c_n) = self.lstm(x)
        # output: (batch, seq_len, hidden_size)
        # Normalize each timestep's output
        output_norm = self.ln(output) # Applies independently per timestep
        return output_norm, (h_n, c_n)
Why NOT BatchNorm:
# BatchNorm in RNN mixes information across timesteps!
# Given (batch=32, seq_len=100, hidden=256)
# BatchNorm would compute:
# mean/std over (batch × seq_len) = 3200 values
# This mixes t=0 with t=99 (destroys temporal structure!)
# LayerNorm computes:
# mean/std over hidden_size = 256 values per (batch, timestep)
# Each timestep normalized independently (preserves temporal structure)
Layer-wise normalization in stacked RNN:
class StackedNormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_size = input_size if i == 0 else hidden_size
            self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True))
            self.layers.append(nn.LayerNorm(hidden_size)) # After each LSTM layer
    def forward(self, x):
        for lstm, ln in zip(self.layers[::2], self.layers[1::2]):
            x, _ = lstm(x)
            x = ln(x) # Normalize between layers
        return x
Transformer
Default: LayerNorm (or RMSNorm for modern/large models)
Two placement options: Pre-norm vs Post-norm
Post-norm (original Transformer, "Attention is All You Need"):
class TransformerLayerPostNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
    def forward(self, x):
        # Post-norm: Apply normalization AFTER residual
        x = self.ln1(x + self.attn(x, x, x)[0]) # Normalize after adding
        x = self.ln2(x + self.ffn(x)) # Normalize after adding
        return x
Pre-norm (modern, more stable):
class TransformerLayerPreNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
    def forward(self, x):
        # Pre-norm: Apply normalization BEFORE sublayer
        h = self.ln1(x) # Normalize once, reuse as query/key/value
        x = x + self.attn(h, h, h)[0] # Normalize before attention
        x = x + self.ffn(self.ln2(x)) # Normalize before FFN
        return x
Pre-norm vs Post-norm comparison:
# Post-norm (original):
# - Less stable (requires careful initialization + warmup)
# - Slightly better performance IF training succeeds
# - Harder to train as depth grows
# Pre-norm (modern):
# - More stable (easier to train deep models)
# - Standard for large models (GPT-3: 96 layers!)
# - Recommended default
# Empirical: original Transformer, BERT (post-norm)
# GPT-2, GPT-3, T5, LLaMA (pre-norm)
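One intuition for the stability difference is that pre-norm leaves the residual path untouched, while post-norm renormalizes it at every layer. The sketch below isolates this with a sublayer that outputs zeros. `ResidualNormBlock` is a hypothetical wrapper written for this illustration, not a PyTorch class.

```python
import torch
import torch.nn as nn

class ResidualNormBlock(nn.Module):
    """Wraps a sublayer with either pre-norm or post-norm placement.

    Hypothetical helper for illustration only.
    """
    def __init__(self, d_model, sublayer, pre_norm=True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.pre_norm = pre_norm

    def forward(self, x):
        if self.pre_norm:
            return x + self.sublayer(self.norm(x))  # residual path untouched
        return self.norm(x + self.sublayer(x))      # norm sits on the residual path

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 5, d_model) * 10.0  # large-scale input

# A sublayer that outputs zeros isolates what happens to the residual stream
zero = nn.Linear(d_model, d_model)
nn.init.zeros_(zero.weight)
nn.init.zeros_(zero.bias)

pre = ResidualNormBlock(d_model, zero, pre_norm=True)
post = ResidualNormBlock(d_model, zero, pre_norm=False)

pre_is_identity = torch.allclose(pre(x), x)          # x passes through unchanged
post_is_rescaled = not torch.allclose(post(x), x)    # x is renormalized
print(pre_is_identity, post_is_rescaled)
```

With pre-norm, gradients also reach earlier layers through that untouched identity path, which is a common explanation for why deep pre-norm stacks tolerate less careful warmup.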
Using RMSNorm instead:
class TransformerLayerRMSNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.rms1 = RMSNorm(d_model) # 15-20% faster than LayerNorm
        self.rms2 = RMSNorm(d_model)
    def forward(self, x):
        # Pre-norm with RMSNorm (LLaMA style)
        h = self.rms1(x) # Normalize once, reuse as query/key/value
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.rms2(x))
        return x
GAN (Generative Adversarial Network)
Generator: InstanceNorm or no normalization
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Use InstanceNorm for image-to-image translation
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64) # Per-image normalization
    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x) # Preserves per-image characteristics
        return x
Discriminator: No normalization or LayerNorm
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Often no normalization (BatchNorm can hurt GAN training)
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        # No normalization here
    def forward(self, x):
        x = self.conv1(x)
        # Directly to activation (no norm)
        return x
Why avoid BatchNorm in GANs:
# BatchNorm in discriminator:
# - Mixes real and fake samples in batch
# - Leaks information (discriminator can detect batch composition)
# - Hurts training stability
# Recommendation:
# Generator: InstanceNorm (for image translation) or no norm
# Discriminator: No normalization or LayerNorm
Decision Framework
Step 1: Check batch size
if batch_size >= 8:
    consider_batchnorm = True
else:
    use_groupnorm_or_layernorm = True # BatchNorm will be unstable
Step 2: Check architecture
if architecture == "CNN":
    if batch_size >= 8:
        use_batchnorm()
    else:
        use_groupnorm(num_groups=32)
    # Exception: Style transfer
    if task == "style_transfer":
        use_instancenorm()
elif architecture in ["RNN", "LSTM", "GRU"]:
    use_layernorm() # NEVER BatchNorm!
elif architecture == "Transformer":
    if model_size == "large": # > 1B parameters
        use_rmsnorm() # 15-20% faster
    else:
        use_layernorm()
    # Placement: Pre-norm (more stable)
    use_prenorm_placement()
elif architecture == "GAN":
    if component == "generator":
        if task == "image_translation":
            use_instancenorm()
        else:
            use_no_norm() # Or InstanceNorm
    elif component == "discriminator":
        use_no_norm() # Or LayerNorm
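The branching above can be condensed into a single runnable function. This is a minimal sketch: the name `choose_norm`, its argument names, and the returned string labels are all assumptions made for illustration, and it returns only the primary recommendation for each case.

```python
def choose_norm(architecture, batch_size=32, task=None,
                model_size="small", component=None):
    """Return the name of a normalization layer following the rules above.

    Hypothetical helper: argument names and labels are assumptions
    made for this sketch, not an established API.
    """
    if architecture == "CNN":
        if task == "style_transfer":
            return "InstanceNorm"
        return "BatchNorm" if batch_size >= 8 else "GroupNorm"
    if architecture in ("RNN", "LSTM", "GRU"):
        return "LayerNorm"
    if architecture == "Transformer":
        return "RMSNorm" if model_size == "large" else "LayerNorm"
    if architecture == "GAN":
        if component == "generator" and task == "image_translation":
            return "InstanceNorm"
        return "no norm"
    raise ValueError(f"Unknown architecture: {architecture!r}")

print(choose_norm("CNN", batch_size=2))
print(choose_norm("Transformer", model_size="large"))
```

Raising on an unknown architecture, rather than silently defaulting to BatchNorm, keeps the helper from reintroducing the cargo-cult default this guide argues against.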
Step 3: Verify placement
# CNNs: After convolution, before activation
x = conv(x)
x = norm(x) # Here!
x = relu(x)
# RNNs: After LSTM, normalize hidden states
output, (h, c) = lstm(x)
output = norm(output) # Here!
# Transformers: Pre-norm (modern) or post-norm (original)
# Pre-norm (recommended):
x = x + sublayer(norm(x)) # Normalize before sublayer
# Post-norm (original):
x = norm(x + sublayer(x)) # Normalize after residual
Implementation Checklist
Before adding normalization:
- ☐ Check batch size: If < 8, avoid BatchNorm
- ☐ Check architecture: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm
- ☐ Check task: Style transfer→InstanceNorm
- ☐ Verify placement: After conv/linear, before activation (CNNs)
- ☐ Test training stability: Loss should decrease smoothly
During training:
- ☐ Monitor running statistics (BatchNorm): Check running_mean/running_var are updating
- ☐ Test inference mode: Verify model.eval() uses running stats correctly
- ☐ Check gradient flow: Normalization should help, not hurt gradients
If training is unstable:
- ☐ Try different normalization: BatchNorm→GroupNorm, LayerNorm→RMSNorm
- ☐ Try pre-norm (Transformers): More stable than post-norm
- ☐ Reduce learning rate: Normalization allows larger LR, but start conservatively
Common Mistakes
Mistake 1: BatchNorm with small batches
# WRONG: BatchNorm with batch_size=2
from torchvision.models import resnet50
model = resnet50(norm_layer=nn.BatchNorm2d)
dataloader = DataLoader(dataset, batch_size=2) # Too small!
# RIGHT: GroupNorm for small batches
model = resnet50(norm_layer=lambda channels: nn.GroupNorm(32, channels))
dataloader = DataLoader(dataset, batch_size=2) # Works!
Mistake 2: BatchNorm in RNN
# WRONG: BatchNorm in LSTM
class BadLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.bn = nn.BatchNorm1d(256) # WRONG! Mixes timesteps
    def forward(self, x):
        output, _ = self.lstm(x) # (B, T, H) because batch_first=True
        output = output.permute(0, 2, 1) # (B, H, T)
        output = self.bn(output) # Mixes timesteps!
        return output
# RIGHT: LayerNorm in LSTM
class GoodLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.ln = nn.LayerNorm(256) # Per-timestep normalization
    def forward(self, x):
        output, _ = self.lstm(x) # (B, T, H)
        output = self.ln(output) # Independent per timestep
        return output
Mistake 3: Forgetting model.eval()
# WRONG: Using training mode during inference
model.train() # BatchNorm uses batch statistics
predictions = model(test_data) # Batch statistics from test data (leakage!)
# RIGHT: Use eval mode during inference
model.eval() # BatchNorm uses running statistics
with torch.no_grad():
    predictions = model(test_data) # Uses accumulated running stats
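The practical consequence is easy to demonstrate: in training mode, the prediction for one sample depends on whatever else is in the batch. The sketch below feeds the same sample alongside two different sets of neighbours. The sizes, the shift of 3.0, and the number of warm-up batches are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=4)

# Accumulate running statistics with a few training batches
bn.train()
for _ in range(20):
    bn(torch.randn(32, 4))

sample = torch.randn(1, 4)
batch_a = torch.cat([sample, torch.randn(7, 4)])        # neighbours ~ N(0, 1)
batch_b = torch.cat([sample, torch.randn(7, 4) + 3.0])  # shifted neighbours

with torch.no_grad():
    bn.eval()
    eval_same = torch.allclose(bn(batch_a)[0], bn(batch_b)[0])
    bn.train()
    train_same = torch.allclose(bn(batch_a)[0], bn(batch_b)[0])

print(eval_same, train_same)
```

In eval mode the sample's output is identical in both batches. In train mode it changes with the neighbours, so test-time predictions become dependent on batch composition.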
Mistake 4: Post-norm for deep Transformers
# WRONG: Post-norm for 24-layer Transformer (unstable!)
class DeepTransformerPostNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPostNorm(512, 8) for _ in range(24)
        ]) # Hard to train!
# RIGHT: Pre-norm for deep Transformers
class DeepTransformerPreNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPreNorm(512, 8) for _ in range(24)
        ]) # Stable training!
Mistake 5: Wrong normalization for style transfer
# WRONG: BatchNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.BatchNorm2d(64) # WRONG! Mixes styles across batch
# RIGHT: InstanceNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.InstanceNorm2d(64) # Per-image normalization
Performance Impact
Training speed:
# Illustrative example (the actual speedup depends on model and task):
# Without normalization: 100 epochs to converge
# With normalization: 10 epochs to converge (10x faster!)
# Reason: Larger learning rates possible
lr_no_norm = 0.001 # Must be small (unstable otherwise)
lr_with_norm = 0.01 # Can be 10x larger (normalization stabilizes)
Inference speed:
# Normalization overhead (rough, hardware-dependent estimates relative to no normalization):
BatchNorm: +2% (minimal, cached running stats)
LayerNorm: +3-5% (compute mean/std per forward pass)
RMSNorm: +2-3% (faster than LayerNorm)
GroupNorm: +5-8% (more computation than BatchNorm)
InstanceNorm: +3-5% (similar to LayerNorm)
# For most models: Overhead is negligible compared to conv/linear layers
Memory usage:
# Normalization memory (per layer):
BatchNorm: 2 × num_channels learnable parameters (γ, β) + 2 × num_channels buffer values (running_mean, running_var)
LayerNorm: 2 × normalized_shape (γ, β)
RMSNorm: 1 × normalized_shape (γ only, no β)
# Example: 512 channels
BatchNorm: 2 × 512 = 1024 parameters + 1024 buffer values (2048 stored values)
LayerNorm: 2 × 512 = 1024 parameters
RMSNorm: 1 × 512 = 512 parameters # Most efficient!
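These counts can be read directly off the PyTorch modules. The sketch below separates learnable parameters from buffers (state that is saved with the model but not trained). The helper name `count` is an assumption for this illustration.

```python
import torch
import torch.nn as nn

def count(tensors):
    """Total number of scalar values across an iterable of tensors."""
    return sum(t.numel() for t in tensors)

bn = nn.BatchNorm2d(512)
ln = nn.LayerNorm(512)

bn_params, bn_buffers = count(bn.parameters()), count(bn.buffers())
ln_params, ln_buffers = count(ln.parameters()), count(ln.buffers())
print("BatchNorm:", bn_params, "parameters,", bn_buffers, "buffer values")
print("LayerNorm:", ln_params, "parameters,", ln_buffers, "buffer values")
```

BatchNorm's buffer count comes out one higher than 2 × 512 because PyTorch also stores a scalar `num_batches_tracked` counter next to `running_mean` and `running_var`.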
When NOT to Normalize
Case 1: Final output layer
# Don't normalize final predictions
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = ResNet50() # Placeholder backbone returning 2048-d features (normalization inside)
        self.fc = nn.Linear(2048, 1000)
        # NO normalization here! (final logits should be unnormalized)
    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x) # Raw logits
        return x # Don't normalize!
Case 2: Very small networks
# Single-layer network: Normalization overkill
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10) # MNIST classifier
        # No normalization needed (network too simple)
    def forward(self, x):
        return self.fc(x)
Case 3: When debugging
# Remove normalization to isolate issues
# If training fails with normalization, try without to check if:
# - Initialization is correct
# - Loss function is correct
# - Data is correctly preprocessed
Modern Recommendations (2025)
CNNs:
- Default: BatchNorm (if batch_size ≥ 8)
- Small batches: GroupNorm (num_groups=32)
- Style transfer: InstanceNorm
RNNs/LSTMs:
- Default: LayerNorm
- Never: BatchNorm (breaks temporal structure)
Transformers:
- Small models (< 1B): LayerNorm + pre-norm
- Large models (≥ 1B): RMSNorm + pre-norm (15-20% faster)
- Avoid: Post-norm for deep models (> 12 layers)
GANs:
- Generator: InstanceNorm (image translation) or no norm
- Discriminator: No normalization or LayerNorm
- Avoid: BatchNorm (leaks information)
Emerging:
- RMSNorm adoption increasing: LLaMA, Mistral, Gemma all use RMSNorm
- Pre-norm becoming standard: More stable for deep networks
- GroupNorm gaining traction: Object detection, small-batch training
Summary
Normalization is mandatory for modern deep learning. The question is which normalization, not whether to normalize.
Quick decision tree:
- Batch size ≥ 8? → Consider BatchNorm (CNNs)
- Batch size < 8? → Use GroupNorm (CNNs) or LayerNorm (all)
- RNN/LSTM? → LayerNorm (never BatchNorm!)
- Transformer? → LayerNorm or RMSNorm with pre-norm
- Style transfer? → InstanceNorm
- GAN? → InstanceNorm (generator) or no norm (discriminator)
Modern defaults:
- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8)
- RNNs: LayerNorm
- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models)
- GANs: InstanceNorm (generator), no norm (discriminator)
Key insight: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere"—it fails for small batches, RNNs, Transformers, and style transfer.