# Normalization Techniques

## Context

You're designing a neural network or debugging training instability. Someone suggests "add BatchNorm" without considering:
- **Batch size dependency**: BatchNorm fails with small batches (< 8)
- **Architecture mismatch**: BatchNorm breaks RNNs/Transformers (use LayerNorm)
- **Task-specific needs**: Style transfer needs InstanceNorm, not BatchNorm
- **Modern alternatives**: RMSNorm is simpler and faster than LayerNorm for LLMs

**This skill prevents normalization cargo-culting and provides architecture-specific selection.**

## Why Normalization Matters

**Problem: Internal Covariate Shift**
During training, each layer's input distribution shifts as the layers before it update. (Later analyses attribute normalization's benefit more to a smoother optimization landscape than to covariate shift itself, but the practical symptoms are the same.) Without normalization this causes:
- Vanishing/exploding gradients (deep networks)
- Slow convergence (small learning rates required)
- Training instability (loss spikes)

**Solution: Normalization**
Normalize activations to have stable statistics (mean=0, std=1). Benefits:
- **Faster convergence**: Larger learning rates (often ~10x) stay stable
- **Better generalization**: Mild regularization effect
- **Enables deep networks**: 50+ layers without gradient issues
- **Less sensitive to initialization**: Weights can start further from optimal

**Key insight**: Normalization is NOT optional for modern deep learning. The question is WHICH normalization, not WHETHER to normalize.
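
The recipe every variant below shares: compute a mean and variance over some set of axes, standardize, then apply a learnable scale (γ) and shift (β). A minimal sketch of that shared recipe, here using per-sample axes (LayerNorm-style); the families below differ only in which axes the statistics cover:

```python
import torch

x = torch.randn(4, 8) * 3.0 + 5.0            # badly scaled pre-activations (mean ≈ 5, std ≈ 3)

eps = 1e-5
mean = x.mean(dim=-1, keepdim=True)          # statistics over the feature axis
var = x.var(dim=-1, unbiased=False, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)   # zero mean, unit variance per sample

gamma = torch.ones(8)                        # learnable scale (γ)
beta = torch.zeros(8)                        # learnable shift (β)
y = gamma * x_hat + beta

print(x_hat.mean(dim=-1))                    # ≈ 0 for every sample
print(x_hat.var(dim=-1, unbiased=False))     # ≈ 1 for every sample
```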

## Normalization Families

### 1. Batch Normalization (BatchNorm)

**What it does:**
Normalizes across the batch dimension for each channel/feature.

**Formula:**
```
Given input x with shape (B, C, H, W):   # Batch, Channel, Height, Width

For each channel c:
  μ_c = mean(x[:, c, :, :])   # Mean over batch + spatial dims
  σ_c = std(x[:, c, :, :])    # Std over batch + spatial dims

  x_norm[:, c, :, :] = (x[:, c, :, :] - μ_c) / √(σ_c² + ε)

  # Learnable scale and shift
  y[:, c, :, :] = γ_c * x_norm[:, c, :, :] + β_c
```

**When to use:**
- ✅ CNNs for classification (ResNet, EfficientNet)
- ✅ Large batch sizes (≥ 16)
- ✅ IID data (image classification, object detection)

**When NOT to use:**
- ❌ Small batch sizes (< 8): Noisy statistics cause training failure
- ❌ RNNs/LSTMs: Breaks temporal dependencies
- ❌ Transformers: Batch dependency problematic for variable-length sequences
- ❌ Style transfer: Batch statistics erase style information

**Batch size dependency:**
```
batch_size = 32:  # ✓ Works well (stable statistics)
batch_size = 16:  # ✓ Acceptable
batch_size = 8:   # ✓ Marginal (consider GroupNorm)
batch_size = 4:   # ✗ Unstable (use GroupNorm)
batch_size = 2:   # ✗ Fails in practice (very noisy statistics)
batch_size = 1:   # ✗ Degenerate (statistics from a single sample; BatchNorm1d even errors in training mode)
```

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For Conv2d
bn = nn.BatchNorm2d(num_features=64)   # 64 channels
x = torch.randn(32, 64, 28, 28)        # Batch=32, Channels=64
y = bn(x)

# For Linear
bn = nn.BatchNorm1d(num_features=128)  # 128 features
x = torch.randn(32, 128)               # Batch=32, Features=128
y = bn(x)
```

**Inference mode:**
```python
# Training: uses batch statistics
model.train()
y = bn(x)   # Normalizes using the current batch's mean/std

# Inference: uses running statistics (accumulated during training)
model.eval()
y = bn(x)   # Normalizes using running_mean / running_var
```
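
A small sketch that makes the mode switch visible: the same input is normalized differently in train and eval mode, because eval mode uses the running buffers rather than the batch statistics.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
x = torch.randn(16, 8, 4, 4) * 2 + 3     # deliberately shifted/scaled input

bn.train()
y_train = bn(x)                          # normalized with this batch's statistics

bn.eval()
y_eval = bn(x)                           # normalized with running_mean / running_var

print(torch.allclose(y_train, y_eval))   # False: the two modes normalize differently
```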

### 2. Layer Normalization (LayerNorm)

**What it does:**
Normalizes across the feature dimension for each sample independently.

**Formula:**
```
Given input x with shape (B, C):   # Batch, Features

For each sample b:
  μ_b = mean(x[b, :])   # Mean over features
  σ_b = std(x[b, :])    # Std over features

  x_norm[b, :] = (x[b, :] - μ_b) / √(σ_b² + ε)

  # Learnable scale and shift
  y[b, :] = γ * x_norm[b, :] + β
```

**When to use:**
- ✅ Transformers (BERT, GPT, T5)
- ✅ RNNs/LSTMs (maintains temporal independence)
- ✅ Small batch sizes (batch-independent!)
- ✅ Variable-length sequences
- ✅ Reinforcement learning (batch_size=1 common)

**Advantages over BatchNorm:**
- ✅ **Batch-independent**: Works with batch_size=1
- ✅ **No running statistics**: Inference = training (no mode switching)
- ✅ **Sequence-friendly**: Doesn't mix information across timesteps

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For Transformer
ln = nn.LayerNorm(normalized_shape=512)  # d_model=512
x = torch.randn(32, 128, 512)            # Batch=32, SeqLen=128, d_model=512
y = ln(x)   # Normalizes the last dimension independently per (batch, position)

# For RNN hidden states
ln = nn.LayerNorm(normalized_shape=256)  # hidden_size=256
h = torch.randn(32, 256)                 # Batch=32, Hidden=256
h_norm = ln(h)
```

**Key difference from BatchNorm:**
```python
# BatchNorm: normalizes across the batch dimension
#   Given (B=32, C=64, H=28, W=28)
#   Computes 64 means/stds (one per channel, across batch + spatial)

# LayerNorm: normalizes across the feature dimension
#   Given (B=32, L=128, D=512)
#   Computes 32 × 128 means/stds (one per (batch, position), across features)
```
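
A quick shape check of the statistics each scheme computes, using the same tensor layouts as above:

```python
import torch

# LayerNorm-style statistics: one mean per (batch, position)
x_seq = torch.randn(32, 128, 512)        # (batch, seq_len, d_model)
print(x_seq.mean(dim=-1).shape)          # torch.Size([32, 128])

# BatchNorm-style statistics: one mean per channel
x_img = torch.randn(32, 64, 28, 28)      # (batch, channels, H, W)
print(x_img.mean(dim=(0, 2, 3)).shape)   # torch.Size([64])
```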

### 3. Group Normalization (GroupNorm)

**What it does:**
Normalizes channels in groups, batch-independent.

**Formula:**
```
Given input x with shape (B, C, H, W):
Divide the C channels into G groups (C must be divisible by G)

For each sample b and group g:
  channels = x[b, g*(C/G):(g+1)*(C/G), :, :]   # Channels in group g
  μ_{b,g} = mean(channels)   # Mean over channels in group + spatial
  σ_{b,g} = std(channels)    # Std over channels in group + spatial

  x_norm[b, g*(C/G):(g+1)*(C/G), :, :] = (channels - μ_{b,g}) / √(σ_{b,g}² + ε)
```

**When to use:**
- ✅ Small batch sizes (< 8)
- ✅ CNNs with batch_size=1 (style transfer, RL)
- ✅ Object detection/segmentation (often use small batches)
- ✅ When BatchNorm is unstable but you still want spatial statistics

**Group size selection (see the sketch after this block):**
```
# num_groups trade-off:
num_groups = 1:   # = LayerNorm (all channels together)
num_groups = C:   # = InstanceNorm (each channel separate)
num_groups = 32:  # Standard choice (good balance)

# Rule: C must be divisible by num_groups
channels = 64, num_groups = 32:  # ✓ 64/32 = 2 channels per group
channels = 64, num_groups = 16:  # ✓ 64/16 = 4 channels per group
channels = 64, num_groups = 30:  # ✗ 64/30 is not an integer!
```
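
The boundary cases above can be written directly as `nn.GroupNorm` configurations (affine parameters aside):

```python
import torch.nn as nn

C = 64
gn_layer_like    = nn.GroupNorm(num_groups=1,  num_channels=C)   # one group: all channels + spatial together
gn_instance_like = nn.GroupNorm(num_groups=C,  num_channels=C)   # one group per channel (InstanceNorm-like)
gn_default       = nn.GroupNorm(num_groups=32, num_channels=C)   # common middle ground
```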

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For a small-batch CNN
gn = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 28, 28)   # Batch=2 (small!)
y = gn(x)                        # Works well even with batch=2

# Compare behavior across batch sizes:
batch_sizes = [1, 2, 4, 8, 16, 32]

bn = nn.BatchNorm2d(64)
gn = nn.GroupNorm(32, 64)

for bs in batch_sizes:
    x = torch.randn(bs, 64, 28, 28)
    y_bn = bn(x)   # statistics estimated from only `bs` samples: noisy when bs is small
    y_gn = gn(x)   # statistics per (sample, group): unaffected by batch size

# BatchNorm gets more stable with larger batches;
# GroupNorm behaves identically across all batch sizes.
```

**Empirical trend (Group Normalization, Wu & He 2018):**
```
ImageNet classification with ResNet-50:
batch_size = 32: BatchNorm ≈ GroupNorm (within roughly half a point of top-1 accuracy)
batch_size = 8:  BatchNorm starts to degrade; GroupNorm is essentially unchanged
batch_size = 2:  BatchNorm loses several points; GroupNorm stays close to its batch-32 accuracy
```

### 4. Instance Normalization (InstanceNorm)

**What it does:**
Normalizes each sample and channel independently (no batch mixing).

**Formula:**
```
Given input x with shape (B, C, H, W):

For each sample b and channel c:
  μ_{b,c} = mean(x[b, c, :, :])   # Mean over spatial dimensions only
  σ_{b,c} = std(x[b, c, :, :])    # Std over spatial dimensions only

  x_norm[b, c, :, :] = (x[b, c, :, :] - μ_{b,c}) / √(σ_{b,c}² + ε)
```

**When to use:**
- ✅ Style transfer (neural style, CycleGAN, pix2pix)
- ✅ Image-to-image translation
- ✅ When batch/channel mixing destroys information

**Why for style transfer:**
```python
# Style transfer goal: transfer style while preserving content
# BatchNorm: mixes statistics across the batch (erases individual style!)
# InstanceNorm: per-image statistics (preserves each image's style)

# Example: neural style transfer (load_image is an illustrative placeholder)
content_image = load_image("photo.jpg")
style_image = load_image("starry_night.jpg")

# With BatchNorm: output loses the content image's unique characteristics
# With InstanceNorm: content characteristics preserved, style applied
```

**PyTorch example:**
```python
import torch
import torch.nn as nn

# For a style transfer generator
in_norm = nn.InstanceNorm2d(num_features=64)
x = torch.randn(1, 64, 256, 256)   # Single image
y = in_norm(x)                     # Normalizes each channel independently

# CycleGAN-style generator block
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64)   # NOT BatchNorm!
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x)   # Per-image normalization
        x = self.relu(x)
        return x
```

**Relation to GroupNorm:**
```python
# InstanceNorm is GroupNorm with num_groups = num_channels
# (given matching eps and affine settings)
InstanceNorm2d(C) == GroupNorm(num_groups=C, num_channels=C)
```
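
This equivalence can be checked numerically; a sketch with the affine transform disabled so both modules apply only the bare normalization:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 16, 16)
inorm = nn.InstanceNorm2d(8, affine=False)
gnorm = nn.GroupNorm(num_groups=8, num_channels=8, affine=False)

print(torch.allclose(inorm(x), gnorm(x), atol=1e-5))   # True: same per-(sample, channel) statistics
```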

### 5. RMS Normalization (RMSNorm)

**What it does:**
Simplified LayerNorm that only rescales (no re-centering); simpler and faster.

**Formula:**
```
Given input x:

# LayerNorm (2 steps):
x_centered = x - mean(x)        # 1. Center
x_norm = x_centered / std(x)    # 2. Scale

# RMSNorm (1 step):
rms = sqrt(mean(x²) + ε)        # Root Mean Square
x_norm = x / rms                # Only scale, no centering
```

**When to use:**
- ✅ Modern LLMs (LLaMA, Mistral, Gemma)
- ✅ When speed matters (the normalization step is roughly 15-20% cheaper than LayerNorm)
- ✅ Large Transformer models (billions of parameters)

**Advantages:**
- ✅ **Simpler**: One statistic instead of two
- ✅ **Faster**: ~15-20% speedup over LayerNorm for the normalization step
- ✅ **Drops re-centering safely**: the mean-subtraction step turns out to be unnecessary in practice
- ✅ **Same performance**: Empirically matches LayerNorm quality

**PyTorch implementation:**
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute RMS over the last dimension
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # Normalize
        x_norm = x / rms

        # Scale (learnable)
        return self.weight * x_norm

# Usage in a Transformer
rms = RMSNorm(dim=512)          # d_model=512
x = torch.randn(32, 128, 512)   # Batch, SeqLen, d_model
y = rms(x)
```
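
Note: recent PyTorch releases (2.4+) also ship a built-in `nn.RMSNorm`. A small sketch that uses the built-in when available and falls back to the hand-written module above otherwise:

```python
import torch
import torch.nn as nn

# Prefer the built-in if this PyTorch version has it; otherwise use the RMSNorm class above.
rms = nn.RMSNorm(512, eps=1e-6) if hasattr(nn, "RMSNorm") else RMSNorm(dim=512)

x = torch.randn(32, 128, 512)
print(rms(x).shape)   # torch.Size([32, 128, 512])
```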

**Speed comparison (illustrative; LLaMA-7B-class model on an A100 GPU):**
```
LayerNorm: 1000 tokens/sec
RMSNorm:   1180 tokens/sec   # ~18% faster

# For large models this adds up:
# 1 million tokens: roughly 150 seconds saved at these rates
```

**Modern LLM adoption:**
```
# LLaMA (Meta, 2023): RMSNorm
# Mistral (Mistral AI, 2023): RMSNorm
# Gemma (Google, 2024): RMSNorm
# PaLM (Google, 2022): RMSNorm

# Older models:
# GPT-2/3 (OpenAI): LayerNorm
# BERT (Google, 2018): LayerNorm
```

## Architecture-Specific Selection

### CNN (Convolutional Neural Networks)

**Default: BatchNorm**
```python
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)   # After conv
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)   # Normalize
        x = self.relu(x)
        return x
```

**Exception: Small batch sizes**
```python
# If batch_size < 8, use GroupNorm instead
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(32, out_channels)   # GroupNorm for small batches
        self.relu = nn.ReLU(inplace=True)
```

**Exception: Style transfer**
```python
# Use InstanceNorm for style transfer
class StyleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)   # Per-image normalization
        self.relu = nn.ReLU(inplace=True)
```

### RNN / LSTM (Recurrent Neural Networks)

**Default: LayerNorm**
```python
import torch.nn as nn

class NormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.ln = nn.LayerNorm(hidden_size)   # Normalize hidden states

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        output, (h_n, c_n) = self.lstm(x)
        # output: (batch, seq_len, hidden_size)

        # Normalize each timestep's output
        output_norm = self.ln(output)   # Applied independently per timestep
        return output_norm, (h_n, c_n)
```

**Why NOT BatchNorm:**
```python
# BatchNorm in an RNN mixes information across timesteps!
# Given (batch=32, seq_len=100, hidden=256):

# BatchNorm would compute:
#   each per-feature mean/std over batch × seq_len = 3,200 values
#   This mixes t=0 with t=99 (destroys temporal structure!)

# LayerNorm computes:
#   a mean/std over the 256 hidden units of each (batch, timestep) separately
#   Each timestep is normalized independently (preserves temporal structure)
```

**Layer-wise normalization in a stacked RNN:**
```python
import torch.nn as nn

class StackedNormalizedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            in_size = input_size if i == 0 else hidden_size
            self.layers.append(nn.LSTM(in_size, hidden_size, batch_first=True))
            self.layers.append(nn.LayerNorm(hidden_size))   # After each LSTM layer

    def forward(self, x):
        for lstm, ln in zip(self.layers[::2], self.layers[1::2]):
            x, _ = lstm(x)
            x = ln(x)   # Normalize between layers
        return x
```

### Transformer

**Default: LayerNorm (or RMSNorm for modern/large models)**

**Two placement options: Pre-norm vs Post-norm**

**Post-norm (original Transformer, "Attention Is All You Need"):**
```python
import torch.nn as nn

class TransformerLayerPostNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Post-norm: apply normalization AFTER the residual addition
        x = self.ln1(x + self.attn(x, x, x)[0])   # Normalize after adding
        x = self.ln2(x + self.ffn(x))             # Normalize after adding
        return x
```

**Pre-norm (modern, more stable):**
```python
import torch.nn as nn

class TransformerLayerPreNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm: apply normalization BEFORE each sublayer
        h = self.ln1(x)
        x = x + self.attn(h, h, h)[0]    # Normalize before attention
        x = x + self.ffn(self.ln2(x))    # Normalize before FFN
        return x
```

**Pre-norm vs Post-norm comparison:**
```python
# Post-norm (original):
# - Less stable (requires careful initialization + warmup)
# - Slightly better performance IF training succeeds
# - Hard to train deep models (> 12 layers)

# Pre-norm (modern):
# - More stable (easier to train deep models)
# - Standard for large models (GPT-3: 96 layers!)
# - Recommended default

# Empirical: the original Transformer and BERT use post-norm;
#            GPT-2/GPT-3, T5, and LLaMA use pre-norm (it scales to much deeper stacks)
```

**Using RMSNorm instead:**
```python
import torch.nn as nn

class TransformerLayerRMSNorm(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.rms1 = RMSNorm(d_model)   # RMSNorm class defined above; 15-20% faster than LayerNorm
        self.rms2 = RMSNorm(d_model)

    def forward(self, x):
        # Pre-norm with RMSNorm (LLaMA style)
        h = self.rms1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.rms2(x))
        return x
```

### GAN (Generative Adversarial Network)

**Generator: InstanceNorm or no normalization**
```python
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Use InstanceNorm for image-to-image translation
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(64)   # Per-image normalization

    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x)   # Preserves per-image characteristics
        return x
```

**Discriminator: No normalization or LayerNorm**
```python
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Often no normalization (BatchNorm can hurt GAN training)
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        # No normalization here

    def forward(self, x):
        x = self.conv1(x)
        # Straight to the activation (no norm)
        return x
```

**Why avoid BatchNorm in GANs:**
```python
# BatchNorm in the discriminator:
# - Mixes real and fake samples within a batch
# - Leaks information (the discriminator can detect batch composition)
# - Hurts training stability

# Recommendation:
# Generator: InstanceNorm (for image translation) or no norm
# Discriminator: no normalization or LayerNorm
```

## Decision Framework

### Step 1: Check batch size

```python
if batch_size >= 8:
    consider_batchnorm = True
else:
    use_groupnorm_or_layernorm = True   # BatchNorm will be unstable
```

### Step 2: Check architecture

```python
if architecture == "CNN":
    if batch_size >= 8:
        use_batchnorm()
    else:
        use_groupnorm(num_groups=32)

    # Exception: style transfer
    if task == "style_transfer":
        use_instancenorm()

elif architecture in ["RNN", "LSTM", "GRU"]:
    use_layernorm()   # NEVER BatchNorm!

elif architecture == "Transformer":
    if model_size == "large":   # > 1B parameters
        use_rmsnorm()           # 15-20% faster
    else:
        use_layernorm()

    # Placement: pre-norm (more stable)
    use_prenorm_placement()

elif architecture == "GAN":
    if component == "generator":
        if task == "image_translation":
            use_instancenorm()
        else:
            use_no_norm()   # Or InstanceNorm
    elif component == "discriminator":
        use_no_norm()       # Or LayerNorm
```
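
The same decision tree as runnable code. This is a sketch, not a library API: `make_norm` and its arguments are illustrative names, and large-Transformer users would substitute an RMSNorm module for the returned LayerNorm.

```python
import torch.nn as nn

def make_norm(arch: str, num_features: int, batch_size: int = 32,
              task: str = "", num_groups: int = 32) -> nn.Module:
    """Pick a normalization layer following the decision tree above."""
    if arch == "CNN":
        if task == "style_transfer":
            return nn.InstanceNorm2d(num_features)
        if batch_size >= 8:
            return nn.BatchNorm2d(num_features)
        return nn.GroupNorm(num_groups, num_features)
    if arch in ("RNN", "LSTM", "GRU", "Transformer"):
        return nn.LayerNorm(num_features)
    # GANs and anything else: a safe batch-independent default
    return nn.LayerNorm(num_features)

# Example: a backbone trained with batch_size=2 gets GroupNorm
print(make_norm("CNN", num_features=64, batch_size=2))   # GroupNorm(32, 64, ...)
```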

### Step 3: Verify placement

```python
# CNNs: after the convolution, before the activation
x = conv(x)
x = norm(x)   # Here!
x = relu(x)

# RNNs: after the LSTM, normalize the hidden states
output, (h, c) = lstm(x)
output = norm(output)   # Here!

# Transformers: pre-norm (modern) or post-norm (original)
# Pre-norm (recommended):
x = x + sublayer(norm(x))   # Normalize before the sublayer

# Post-norm (original):
x = norm(x + sublayer(x))   # Normalize after the residual
```

## Implementation Checklist

### Before adding normalization:

1. ☐ **Check batch size**: If < 8, avoid BatchNorm
2. ☐ **Check architecture**: CNN→BatchNorm, RNN→LayerNorm, Transformer→LayerNorm/RMSNorm
3. ☐ **Check task**: Style transfer→InstanceNorm
4. ☐ **Verify placement**: After conv/linear, before activation (CNNs)
5. ☐ **Test training stability**: Loss should decrease smoothly

### During training:

6. ☐ **Monitor running statistics** (BatchNorm): Check that running_mean/running_var are updating (see the sketch after this list)
7. ☐ **Test inference mode**: Verify model.eval() uses running stats correctly
8. ☐ **Check gradient flow**: Normalization should help, not hurt, gradients
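
A quick way to check item 6 in practice, using the `running_mean`/`running_var` buffers that PyTorch's BatchNorm keeps:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)
print(bn.running_mean[:4], bn.running_var[:4])   # initialized to 0s and 1s

bn.train()
_ = bn(torch.randn(16, 64, 28, 28))              # one training-mode forward pass

print(bn.running_mean[:4], bn.running_var[:4])   # values have moved: the statistics are updating
```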

### If training is unstable:

9. ☐ **Try a different normalization**: BatchNorm→GroupNorm, LayerNorm→RMSNorm
10. ☐ **Try pre-norm** (Transformers): More stable than post-norm
11. ☐ **Reduce the learning rate**: Normalization allows larger LRs, but start conservatively

## Common Mistakes

### Mistake 1: BatchNorm with small batches

```python
# WRONG: BatchNorm with batch_size=2
model = ResNet50(norm_layer=nn.BatchNorm2d)      # illustrative; e.g. torchvision's resnet50 accepts a norm_layer argument
dataloader = DataLoader(dataset, batch_size=2)   # Too small!

# RIGHT: GroupNorm for small batches
model = ResNet50(norm_layer=lambda channels: nn.GroupNorm(32, channels))
dataloader = DataLoader(dataset, batch_size=2)   # Works!
```

### Mistake 2: BatchNorm in an RNN

```python
import torch.nn as nn

# WRONG: BatchNorm in an LSTM
class BadLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.bn = nn.BatchNorm1d(256)   # WRONG! Mixes timesteps

    def forward(self, x):
        output, _ = self.lstm(x)
        output = output.permute(0, 2, 1)   # (B, H, T)
        output = self.bn(output)           # Mixes timesteps!
        return output

# RIGHT: LayerNorm in an LSTM
class GoodLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, batch_first=True)
        self.ln = nn.LayerNorm(256)   # Per-timestep normalization

    def forward(self, x):
        output, _ = self.lstm(x)
        output = self.ln(output)   # Independent per timestep
        return output
```

### Mistake 3: Forgetting model.eval()

```python
# WRONG: using training mode during inference
model.train()   # BatchNorm uses batch statistics
predictions = model(test_data)   # Batch statistics come from the test data (leakage!)

# RIGHT: use eval mode during inference
model.eval()    # BatchNorm uses running statistics
with torch.no_grad():
    predictions = model(test_data)   # Uses the accumulated running stats
```

### Mistake 4: Post-norm for deep Transformers

```python
# WRONG: post-norm for a 24-layer Transformer (unstable!)
class DeepTransformerPostNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPostNorm(512, 8) for _ in range(24)
        ])   # Hard to train!

# RIGHT: pre-norm for deep Transformers
class DeepTransformerPreNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerPreNorm(512, 8) for _ in range(24)
        ])   # Stable training!
```

### Mistake 5: Wrong normalization for style transfer

```python
import torch.nn as nn

# WRONG: BatchNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.BatchNorm2d(64)   # WRONG! Mixes styles across the batch

# RIGHT: InstanceNorm for style transfer
class StyleGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 7, padding=3)
        self.norm = nn.InstanceNorm2d(64)   # Per-image normalization
```

## Performance Impact

### Training speed:

```python
# Without normalization: 100 epochs to converge
# With normalization: 10 epochs to converge (roughly 10x faster; illustrative numbers)

# Reason: larger learning rates become possible
lr_no_norm = 0.001   # Must be small (unstable otherwise)
lr_with_norm = 0.01  # Can be ~10x larger (normalization stabilizes training)
```

### Inference speed:

```
# Normalization overhead (relative to no normalization, rough figures):
BatchNorm:    +2%   (minimal; uses cached running stats)
LayerNorm:    +3-5% (computes mean/std every forward pass)
RMSNorm:      +2-3% (cheaper than LayerNorm)
GroupNorm:    +5-8% (more computation than BatchNorm)
InstanceNorm: +3-5% (similar to LayerNorm)

# For most models the overhead is negligible compared to conv/linear layers
```

### Memory usage:

```
# Normalization storage (per layer):
BatchNorm: 2 × num_channels learnable (γ, β) + 2 × num_channels buffers (running_mean, running_var)
LayerNorm: 2 × normalized_shape learnable (γ, β)
RMSNorm:   1 × normalized_shape learnable (γ only, no β)

# Example: 512 channels/features
BatchNorm: 4 × 512 = 2048 stored values (1024 learnable + 1024 running-stat buffers)
LayerNorm: 2 × 512 = 1024 parameters
RMSNorm:   1 × 512 = 512 parameters   # Most compact
```
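
A quick way to verify these counts, separating learnable parameters from running-stat buffers (a sketch):

```python
import torch.nn as nn

def counts(m: nn.Module):
    params = sum(p.numel() for p in m.parameters())
    buffers = sum(b.numel() for b in m.buffers())
    return params, buffers

print(counts(nn.BatchNorm2d(512)))    # (1024, 1025): γ/β plus running_mean, running_var, num_batches_tracked
print(counts(nn.LayerNorm(512)))      # (1024, 0): γ/β only
print(counts(nn.GroupNorm(32, 512)))  # (1024, 0): γ/β only
```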

## When NOT to Normalize

**Case 1: Final output layer**
```python
import torch.nn as nn

# Don't normalize final predictions
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = ResNet50()   # Normalization lives inside the backbone
        self.fc = nn.Linear(2048, 1000)
        # NO normalization here! (final logits should be unnormalized)

    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)   # Raw logits
        return x         # Don't normalize!
```

**Case 2: Very small networks**
```python
import torch.nn as nn

# Single-layer network: normalization is overkill
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)   # MNIST classifier
        # No normalization needed (the network is too shallow to benefit)

    def forward(self, x):
        return self.fc(x)
```

**Case 3: When debugging**
```python
# Remove normalization to isolate issues.
# If training fails with normalization, try without it to check whether:
# - initialization is correct
# - the loss function is correct
# - the data is correctly preprocessed
```

## Modern Recommendations (2025)

### CNNs:
- **Default**: BatchNorm (if batch_size ≥ 8)
- **Small batches**: GroupNorm (num_groups=32)
- **Style transfer**: InstanceNorm

### RNNs/LSTMs:
- **Default**: LayerNorm
- **Never**: BatchNorm (breaks temporal structure)

### Transformers:
- **Small models** (< 1B): LayerNorm + pre-norm
- **Large models** (≥ 1B): RMSNorm + pre-norm (15-20% faster normalization)
- **Avoid**: Post-norm for deep models (> 12 layers)

### GANs:
- **Generator**: InstanceNorm (image translation) or no norm
- **Discriminator**: No normalization or LayerNorm
- **Avoid**: BatchNorm (leaks information)

### Emerging:
- **RMSNorm adoption increasing**: LLaMA, Mistral, Gemma all use RMSNorm
- **Pre-norm becoming standard**: More stable for deep networks
- **GroupNorm gaining traction**: Object detection, small-batch training

## Summary

**Normalization is mandatory for modern deep learning.** The question is which normalization, not whether to normalize.

**Quick decision tree:**
1. **Batch size ≥ 8?** → Consider BatchNorm (CNNs)
2. **Batch size < 8?** → Use GroupNorm (CNNs) or LayerNorm (all)
3. **RNN/LSTM?** → LayerNorm (never BatchNorm!)
4. **Transformer?** → LayerNorm or RMSNorm with pre-norm
5. **Style transfer?** → InstanceNorm
6. **GAN?** → InstanceNorm (generator) or no norm (discriminator)

**Modern defaults:**
- CNNs: BatchNorm (batch ≥ 8) or GroupNorm (batch < 8)
- RNNs: LayerNorm
- Transformers: RMSNorm + pre-norm (large models) or LayerNorm + pre-norm (small models)
- GANs: InstanceNorm (generator), no norm (discriminator)

**Key insight**: Match normalization to architecture and batch size. Don't cargo-cult "add BatchNorm everywhere": it fails for small batches, RNNs, Transformers, and style transfer.