# Batch Size and Memory Tradeoffs
## Overview
Batch size is one of the most misunderstood hyperparameters. Most engineers think: "larger batch = faster training = better". Wrong. Batch size affects convergence speed, generalization, memory usage, and actual wall-clock training time in complex ways. **Larger batch size is NOT always better.**
**Core principle**: Batch size selection is a system optimization problem, not a memory constraint problem. Choose batch size based on computational speed, convergence requirements, and generalization targets - not just what fits in memory.
## When to Use This Skill
**Use this skill when:**
- Choosing batch size for new training
- Training is slow and considering larger batches
- Out-of-memory errors during training
- Learning rate needs adjustment after batch size change
- Distributed training needs batch size scaling
- Gradient accumulation considerations
- User asks "what batch size should I use?"
- Training accuracy varies widely between batch sizes
- Convergence takes too long or is unstable
- Memory per sample calculation needed
- Comparing training speed: iterations vs epochs vs wall-clock time
- Fine-tuning with different batch sizes than pre-training
**Symptoms you need this skill:**
- "I have memory, what's the maximum batch size?" (wrong question)
- "Larger batches train faster, so use 512?" (incomplete)
- "Batch size doesn't affect accuracy, only speed?" (false)
- "Gradient accumulation is a workaround for small memory?" (misconception)
- "Just scale learning rate by 2x when doubling batch size?" (incomplete)
- "We get OOM at batch 256, so use 128 forever" (not optimized)
**Don't use when:**
- User has pure memory/infrastructure questions (use pytorch-engineering)
- User asks about optimizer selection (use optimizer-selection-framework)
- User asks about learning rate scheduling (use learning-rate-scheduling)
- User has general training failure (not batch-size specific)
## Core Patterns
### Pattern 1: The Batch Size Tradeoff Space
**The critical insight**: Batch size affects FOUR coupled dimensions simultaneously - optimizing for one affects the others.
**The four dimensions:**
```
1. TRAINING SPEED (iterations to converge)
├─ Larger batch → fewer iterations to convergence ✓
├─ BUT: Gradient noise decreases (some noise helps optimization and generalization)
└─ Result: Mixed - can't just maximize batch
2. COMPUTATIONAL EFFICIENCY (wall-clock time)
├─ Larger batch → amortize overhead per sample ✓
├─ BUT: Larger batch → need larger LR (unstable)
├─ AND: Gradient accumulation = repeated backward (slow)
└─ Result: Optimal ≠ Maximum
3. GENERALIZATION (test accuracy)
├─ Smaller batch → noisier gradients → better regularization ✓
├─ Larger batch → cleaner gradient → overfit risk ✗
├─ BUT: Can compensate with stronger regularization
└─ Result: Batch size ↔ regularization coupling
4. MEMORY USAGE (GPU memory required)
├─ Larger batch → linear increase in activation memory
├─ Parameters constant regardless of batch
├─ Optimizer state constant regardless of batch
└─ Result: Memory ∝ batch size (linear only for activations)
```
**The mental model:**
```
LARGER BATCH:
✓ Fewer iterations to convergence
✓ Better computational efficiency (up to point)
✗ Worse generalization (harder to regularize)
✗ Requires larger learning rate (instability risk)
✗ Higher memory usage
SMALLER BATCH:
✗ More iterations to convergence
✗ Worse computational efficiency
✓ Better generalization (noise helps)
✓ Smaller learning rates are stable
✓ Lower memory usage
```
**Finding the sweet spot** (a code sketch follows this list):
- Start with batch size that uses ~80% GPU memory
- Adjust learning rate using linear scaling rule
- Monitor validation accuracy
- If validation accuracy drops → batch too large, reduce or regularize
- If training is slow → may need gradient accumulation, not larger batch
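The list above can be codified as a short probing loop. A minimal sketch, assuming a single CUDA device with the model already on it; `model`, `criterion`, and `make_batch(batch_size)` are placeholder names you would supply for your own task, not part of any library:
```python
# Minimal sketch of the sweet-spot procedure above.
import torch

def pick_starting_batch(model, criterion, make_batch, device,
                        candidates=(32, 64, 128, 256, 512),
                        memory_fraction=0.8,
                        base_lr=1e-3, base_batch=32):
    """Return (batch_size, scaled_lr): the largest candidate whose single-step
    peak memory stays under `memory_fraction` of the GPU, with a linearly
    scaled learning rate (see Pattern 2)."""
    budget = memory_fraction * torch.cuda.get_device_properties(device).total_memory
    chosen = candidates[0]
    for bs in candidates:
        try:
            inputs, targets = make_batch(bs)
            inputs, targets = inputs.to(device), targets.to(device)
            torch.cuda.reset_peak_memory_stats(device)
            loss = criterion(model(inputs), targets)
            loss.backward()
            model.zero_grad(set_to_none=True)
            if torch.cuda.max_memory_allocated(device) < budget:
                chosen = bs
            else:
                break
        except RuntimeError as e:
            if "out of memory" in str(e).lower():   # stop probing on OOM
                torch.cuda.empty_cache()
                break
            raise
    scaled_lr = base_lr * (chosen / base_batch)     # linear scaling rule
    return chosen, scaled_lr
```
The result is only a starting point: the convergence and generalization checks in the patterns below still apply.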
### Pattern 2: Linear Learning Rate Scaling Rule
**The rule that changes everything:**
If you increase batch size by factor K, increase learning rate by factor K.
```
New LR = Old LR × (New Batch Size / Old Batch Size)
```
**Why this works (the math):**
```
Gradient descent update: param = param - lr * gradient

With batch size B, the gradient is the average over B samples:
    gradient_B = (1/B) * sum(per-sample gradients)

Doubling the batch to 2B halves the number of updates per epoch.
If the gradient changes slowly across consecutive steps, then k small
steps with learning rate lr ≈ one large step with learning rate k*lr,
so scale lr by the same factor as the batch to cover the same ground
per epoch. (A pure variance argument suggests sqrt(k) scaling; the
linear rule is what works empirically at moderate k, with warmup.)

Empirically validated: Goyal et al. (2017), "Accurate, Large Minibatch
SGD: Training ImageNet in 1 Hour"
```
**Implementation:**
```python
# Pattern 1: Direct scaling
original_lr = 0.001
original_batch_size = 32
new_batch_size = 128
scaling_factor = new_batch_size / original_batch_size # 4x
new_lr = original_lr * scaling_factor # 0.004
# Pattern 2: When changing both batch AND learning rate
def compute_scaled_lr(base_lr, base_batch_size, current_batch_size):
"""
Compute learning rate for new batch size using linear scaling rule.
Args:
base_lr: Learning rate at reference batch size
base_batch_size: Batch size where base_lr was tuned (usually 32 or 256)
current_batch_size: New batch size
Returns:
Scaled learning rate
WHY: Linear scaling rule keeps update magnitude constant
"""
scale_factor = current_batch_size / base_batch_size
return base_lr * scale_factor
# Example: ResNet-50 training (ImageNet baseline)
# Reference: batch=256, lr=0.1
# Now training at: batch=1024
scaled_lr = compute_scaled_lr(0.1, 256, 1024) # 0.4
print(f"Batch 256 with lr=0.1 → Batch 1024 with lr={scaled_lr}")
```
**When linear scaling works:**
```python
# CASE 1: Scaling works well
# Batch: 32 → 256 (8x increase)
# Learning rate: 0.001 → 0.008 (8x)
# Training: ✓ Converges normally, same final accuracy
# Wall-clock: ✓ Faster (fewer iterations, better hardware utilization)
# CASE 2: Scaling doesn't work
# Batch: 32 → 1024 (32x increase!)
# Learning rate: 0.001 → 0.032 (32x)
# Problem: Learning rate too large, training diverges
# Solution: Need warmup phase
```
**The Critical Caveat: WARMUP IS REQUIRED**
```python
# WRONG: Apply full scaled LR immediately
optimizer = torch.optim.SGD(model.parameters(), lr=0.032) # Too large!
for epoch in range(100):
    for batch, targets in train_loader:
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()  # Loss diverges on first iteration!
# CORRECT: Warmup phase before scaled LR
def warmup_lr_schedule(base_lr, current_batch_size, reference_batch_size,
current_step, warmup_steps):
"""
Linear warmup from 0 to scaled LR.
WHY: Large LR jumps can cause divergence.
Gradual warmup lets model adapt to larger updates.
"""
scaled_lr = base_lr * (current_batch_size / reference_batch_size)
if current_step < warmup_steps:
# Linear warmup: ramp from 0 to scaled_lr
return scaled_lr * (current_step / warmup_steps)
else:
# Full scaled LR after warmup
return scaled_lr
# Implementation with PyTorch scheduler
from torch.optim.lr_scheduler import LambdaLR
def get_warmup_scheduler(optimizer, warmup_steps):
base_lrs = [param_group['lr'] for param_group in optimizer.param_groups]
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda)
# Training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.032)
scheduler = get_warmup_scheduler(optimizer, warmup_steps=1000)
for epoch in range(100):
    for batch, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Gradually increase LR during warmup
```
**Practical guidelines:**
| Batch size increase | Learning rate scale | Warmup needed? | Why |
|---|---|---|---|
| 2x (64→128) | 2x (0.001→0.002) | No | Safe, gradual |
| 4x (64→256) | 4x (0.001→0.004) | Maybe | Starting to matter |
| 8x (64→512) | 8x (0.001→0.008) | Yes | Risky without warmup |
| 16x+ (64→1024) | 16x+ (0.001→0.016) | Critical | Risk of divergence |
### Pattern 3: Gradient Accumulation - The Alternative to Large Batches
**What gradient accumulation does:**
Gradient accumulation simulates a large batch size without the large GPU memory. Instead of one forward+backward pass over a batch of 256, run 8 forward+backward passes over batches of 32 and accumulate gradients before stepping. Same effective batch, roughly 1/8th of the activation memory.
**How it works:**
```python
# SIMPLE APPROACH (without accumulation)
batch_size = 256
effective_batch_size = 256  # Process the full batch at once
# Memory required: HIGH - may not fit on the GPU
for batch, target in train_loader:  # batch.size(0) == 256
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# GRADIENT ACCUMULATION APPROACH
batch_size = 32
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps  # 256, same as above!
# Memory required: LOW - only a batch of 32 in memory at once
optimizer.zero_grad()
for _, (batch, target) in zip(range(accumulation_steps), train_loader):
    output = model(batch)  # batch.size(0) == 32
    loss = criterion(output, target) / accumulation_steps  # scale so gradients average
    loss.backward()        # Accumulate gradients (don't zero!)
    # Don't call optimizer.step() yet!
optimizer.step()  # Update weights once accumulation is complete
# Effect: the update matches a single batch_size=256 step
```
**When to use gradient accumulation:**
```python
# CASE 1: Model too large to fit large batch
# Model: GPT-2 (124M parameters)
# Available GPU: 24GB
# Desired batch: 512 per GPU
# Fits in memory: No, only 32 fits
# Solution: Accumulate 16 steps of batch 32 = effective 512
model_params = 124_000_000                        # 124M
param_memory = model_params * 4                   # bytes (FP32 weights)
optimizer_memory = model_params * 8               # Adam: m + v, 2 states x 4 bytes per param
batch_size = 32
sequence_length = 512
activation_memory_per_sample = param_memory / 10  # Rough rule of thumb
total_memory = param_memory + optimizer_memory + (batch_size * activation_memory_per_sample)
# ≈ 3 GB per micro-batch step - comfortably under 24 GB
# Accumulation doesn't grow this: 16 steps of batch 32 still peak around the same
# CASE 2: Distributed training across 8 GPUs
# Per-GPU batch: 32
# Number of GPUs: 8
# Local accumulation: 4 steps
# Total effective: 32 * 8 * 4 = 1024 (synchronized across 8 GPUs)
# Accumulation enables large total batch without massive per-GPU batch
```
**The memory math:**
```
Memory with gradient accumulation:

Without accumulation (batch_size = 256):
- Parameters: fixed
- Optimizer state: fixed (2x parameter memory for Adam: m and v)
- Activations: O(batch_size) = O(256)
- Gradients (w.r.t. parameters): fixed, same size as parameters
- Total ≈ 1.0x baseline memory

With accumulation (batch_size = 32, steps = 8):
- Parameters: fixed (same)
- Optimizer state: fixed (same)
- Activations: O(batch_size) = O(32) = 8x SMALLER
- Gradients: fixed (accumulated into the same .grad buffers)
- Total: activation memory drops ~8x; params, optimizer state, gradients unchanged

Savings: large when activations dominate the budget (often 50-85% of the total)
Cost: 8 forward+backward passes per update instead of 1
Net wall-clock: typically ~1.5-2x slower (per-step overhead, lower GPU utilization)
```
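Because the accumulated loss is scaled by `1/accumulation_steps`, the summed gradient matches the full-batch gradient up to floating-point error. A self-contained toy check of that equivalence (the tiny linear model is purely illustrative):
```python
# Toy check that accumulation reproduces the full-batch gradient.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 1)
x, y = torch.randn(256, 16), torch.randn(256, 1)
loss_fn = torch.nn.MSELoss()

# One pass over the full batch of 256
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()
model.zero_grad()

# 8 accumulated micro-batches of 32, each loss scaled by 1/8
for i in range(8):
    xb, yb = x[i * 32:(i + 1) * 32], y[i * 32:(i + 1) * 32]
    (loss_fn(model(xb), yb) / 8).backward()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))  # True
```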
**Implementation patterns:**
```python
# Pattern 1: Manual gradient accumulation
num_accumulation_steps = 8
optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
output = model(batch)
loss = criterion(output, target)
# Scale loss by accumulation steps
# WHY: Otherwise gradient magnitudes stack up across steps
loss = loss / num_accumulation_steps
loss.backward() # Accumulate gradients
if (step + 1) % num_accumulation_steps == 0:
optimizer.step() # Update after accumulation complete
optimizer.zero_grad()
# Pattern 2: With learning rate adjustment
# Accumulation itself is transparent to the optimizer - what matters is the
# effective batch: effective_batch = batch_size * num_accumulation_steps
# So tune/scale the LR for the effective batch, NOT the per-step micro-batch
original_lr = 0.1 # Tuned for batch_size = 32
num_accumulation_steps = 8
effective_batch = 32 * 8 # 256
# Linear scaling rule based on effective batch:
# Batch 32 → 256 is 8x increase
# So LR: 0.1 → 0.8 (8x)
new_lr = original_lr * 8 # 0.8
optimizer = torch.optim.SGD(model.parameters(), lr=new_lr)
# Pattern 3: Distributed training with gradient accumulation
# Per-GPU batch: 32
# Number of GPUs: 8
# Accumulation steps: 4
# Effective batch: 32 * 8 * 4 = 1024
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model)
num_accumulation_steps = 4
optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
output = model(batch)
loss = criterion(output, target)
loss = loss / num_accumulation_steps
loss.backward()
if (step + 1) % num_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
# Pattern 4: With synchronization (distributed)
class GradientAccumulator:
    """Skips DDP gradient synchronization on intermediate accumulation steps."""
    def __init__(self, model, num_accumulation_steps):
        self.model = model
        self.num_steps = num_accumulation_steps
        self.step_count = 0
    def should_sync_gradients(self):
        """
        In DDP, only sync gradients on the micro-batch where optimizer.step()
        will run. This avoids an all-reduce on every backward pass.
        """
        return (self.step_count + 1) % self.num_steps == 0
    def backward(self, loss):
        loss = loss / self.num_steps
        if self.should_sync_gradients():
            loss.backward()                 # Triggers the DDP all-reduce
        else:
            with self.model.no_sync():      # Skip gradient sync for this micro-batch
                loss.backward()
        self.step_count += 1
```
**Gradient accumulation vs large batch - when to choose:**
```python
# When gradient accumulation is GOOD choice:
# 1. Memory-constrained (can't fit large batch)
# 2. Need large effective batch for convergence
# 3. Can tolerate ~1.5-2x slowdown
# 4. Training wall-clock time not critical
# When gradient accumulation is BAD choice:
# 1. Can fit desired batch size in memory
# 2. Training speed is critical (wall-clock matters)
# 3. Already have good convergence with smaller batches
# 4. Reduced gradient noise is important for task
# Comparison:
#                     LARGE BATCH      GRADIENT ACCUMULATION
# Memory              High             Low (activations / accumulation_steps)
# Wall-clock time     Fast             ~1.5-2x slower
# Convergence speed   Good             Same (same effective batch)
# Implementation      Simple           Requires manual loop
# Memory savings      None             Activation memory ~1/accumulation_steps
# When to use         Memory OK        Memory constrained
```
### Pattern 4: Memory Estimation and Optimization
**Understanding memory components:**
```
Total GPU Memory = Parameters + Optimizer State + Activations + Gradients
Example: Training BERT-base (110M params) with batch_size=32, seq_len=512
1. PARAMETERS (Fixed)
- BERT: 110M × 4 bytes (FP32) = 440 MB
- Or 110M × 2 bytes (FP16) = 220 MB
2. OPTIMIZER STATE (Fixed)
- SGD: No extra state = 0 MB
- Adam: m + v = 2 × params = 880 MB (FP32) or 440 MB (FP16)
- AdamW: Same as Adam
3. ACTIVATIONS (Linear in batch_size, seq_len)
- Stored during forward pass (for backward)
- One hidden-state tensor: batch × seq_len × hidden_dim × 4 bytes
- = 32 × 512 × 768 × 4 bytes ≈ 50 MB
- A layer stores several such tensors (attention, MLP intermediates) ≈ 320 MB per layer
- × 12 layers ≈ 3.8 GB
4. GRADIENTS (Linear in batch_size)
- Stored after backward, until optimizer.step()
- Same size as parameters = 440 MB
TOTAL MEMORY = 440 + 880 + 3800 + 440 = ~5.6 GB
Typical budget: Use ~80% GPU = 19 GB with 24GB GPU
Room for more: Can increase batch from 32 → 128 safely
```
**Memory calculation framework:**
```python
def estimate_memory_usage(
num_params: int,
batch_size: int,
seq_length: int,
hidden_dim: int,
num_layers: int,
dtype_bytes: int = 4, # 4 for FP32, 2 for FP16
optimizer: str = "adam", # or "sgd"
use_gradient_checkpointing: bool = False,
):
"""
Estimate memory for training a transformer model.
Args:
num_params: Total parameters
batch_size: Batch size
seq_length: Sequence length
hidden_dim: Hidden dimension (for activation estimation)
num_layers: Number of transformer layers
dtype_bytes: 4 for FP32, 2 for FP16, 1 for INT8
optimizer: "sgd" (no state), "adam" (8x params)
use_gradient_checkpointing: If True, reduce activation memory
Returns:
Memory in GB
WHY: Helps choose batch size without trial-and-error OOM
"""
# 1. Parameter memory
param_memory = num_params * dtype_bytes
# 2. Optimizer state
if optimizer.lower() == "adam":
opt_memory = 2 * num_params * dtype_bytes # m + v
elif optimizer.lower() == "adamw":
opt_memory = 2 * num_params * dtype_bytes # m + v
else: # SGD
opt_memory = 0
# 3. Activation memory (transformer-specific)
# Activations = hidden states + attention weights stored during forward
# Per layer: batch × seq_len × hidden_dim × 4 bytes
# × num_layers
activation_memory_per_layer = batch_size * seq_length * hidden_dim * dtype_bytes
total_activation_memory = activation_memory_per_layer * num_layers
if use_gradient_checkpointing:
# With checkpointing: only save activations for last layer
# (recompute others during backward)
total_activation_memory = activation_memory_per_layer # Only 1 layer
# 4. Gradient memory (same as parameter memory)
gradient_memory = num_params * dtype_bytes
# Total
total_bytes = param_memory + opt_memory + total_activation_memory + gradient_memory
total_gb = total_bytes / (1024**3)
return total_gb
# Example: BERT training
memory_gb = estimate_memory_usage(
num_params=110_000_000, # BERT-base
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4, # FP32
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory: {memory_gb:.1f} GB")  # ~2.2 GB (lower than the 5.6 GB walkthrough: this model counts one stored tensor per layer)
# Optimize by reducing batch
memory_gb_batch16 = estimate_memory_usage(
num_params=110_000_000,
batch_size=16, # 2x smaller
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4,
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory with batch 16: {memory_gb_batch16:.1f} GB")  # ~1.9 GB
# Optimize by mixed precision
memory_gb_fp16 = estimate_memory_usage(
num_params=110_000_000,
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=2, # FP16 instead of FP32
optimizer="adam",
use_gradient_checkpointing=False,
)
print(f"Memory with FP16: {memory_gb_fp16:.1f} GB")  # ~1.1 GB
# Optimize with checkpointing
memory_gb_ckpt = estimate_memory_usage(
num_params=110_000_000,
batch_size=32,
seq_length=512,
hidden_dim=768,
num_layers=12,
dtype_bytes=4,
optimizer="adam",
use_gradient_checkpointing=True, # Save only last layer activations
)
print(f"Memory with checkpointing: {memory_gb_ckpt:.1f} GB")  # ~1.7 GB
```
**Memory optimization techniques:**
```python
# Technique 1: Gradient Checkpointing
# Recompute activations instead of storing them
# Memory: O(sqrt(num_layers)) instead of O(num_layers)
# Cost: ~30% slower training (recompute activations during backward)
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def forward(self, x):
# Forward: compute and store activations
# Backward: recompute activations during backward
return checkpoint(self._forward, x, use_reentrant=False)
def _forward(self, x):
x = self.attention(x)
x = self.feedforward(x)
return x
# Technique 2: Mixed Precision (FP16)
# Use FP16 for forward+backward activations (roughly halves activation memory)
# Keep FP32 master weights (avoid accumulating rounding error)
# Memory: ~50% reduction in activations
# Speed: 1.3-2x faster on modern GPUs with tensor cores
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch, target in train_loader:
optimizer.zero_grad()
with autocast(): # Automatic FP16 casting
output = model(batch)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# Technique 3: Quantization-Aware Training
# Store weights in INT8 or FP8
# Requires special hardware support
# Memory: 75-90% reduction
# Speed: 2-4x faster
# Technique 4: Batch Size Scheduling
# Start with a small batch, increase during training
# Reason: gradient noise early in training helps generalization;
# a larger batch later speeds up the remaining epochs with little cost
# Memory: grows gradually as the batch grows
def get_adaptive_batch_size(epoch, total_epochs):
    """Increase batch size linearly as training progresses"""
    base_batch = 32
    max_batch = 256
    # Linear interpolation from base_batch to max_batch over training
    batch = base_batch + (max_batch - base_batch) * (epoch / total_epochs)
    return int(batch)
```
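The analytical estimates above are approximations, so it is worth cross-checking them against a real measurement. A minimal sketch using PyTorch's CUDA memory counters; `model`, `criterion`, `optimizer`, `inputs`, and `targets` are placeholders for your own objects:
```python
# Cross-check the analytical estimate with a measurement of one training step.
import torch

def measure_peak_training_memory(model, criterion, optimizer, inputs, targets, device):
    """Run one forward/backward/step and report peak allocated CUDA memory in GB."""
    model.to(device)
    inputs, targets = inputs.to(device), targets.to(device)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()                    # include optimizer-state allocation in the peak
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1024**3
```
Measuring after a full step (including `optimizer.step()`) captures the optimizer-state allocation as well as activations and gradients.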
### Pattern 5: Batch Size Effects on Convergence and Generalization
**The generalization gap - why bigger batch = worse accuracy:**
```
Generalization Gap = Test Accuracy (Large Batch) - Test Accuracy (Small Batch)
Why small batch generalizes better:
1. Gradient Noise: Small batch = noisy gradients
- Noise acts as regularization
- Forces model to find robust minima
- Larger noise → larger generalization margin
2. Loss Landscape: SGD with noise explores landscape differently
- Large batch: Gradient descent (exact gradient)
- Small batch: Stochastic gradient (noisy)
- Noise helps escape sharp minima (bad generalization)
- Leads to flat minima (good generalization)
3. Batch Normalization Interaction:
- BN computes statistics per batch
- Larger batch → more stable statistics
- More stable → less regularization effect
- Less regularization → worse generalization
Real numbers (ResNet-50 on ImageNet):
- Batch 256: 76.0% accuracy
- Batch 1024: 74.8% accuracy (1.2% gap!)
- Batch 4096: 72.0% accuracy (4% gap!)
```
**The sharp minima problem:**
```
SMALL BATCH (32):
Loss landscape: Finds FLAT minima
- Small change in weights → loss increases slowly
- Generalizes well (robust to input variations)
- Test accuracy ≈ Train accuracy
- Variance: Higher (gradient noise)
LARGE BATCH (1024):
Loss landscape: Finds SHARP minima
- Small change in weights → loss increases quickly
- Generalizes poorly (sensitive to input variations)
- Test accuracy << Train accuracy (overfitting)
- Variance: Lower (stable gradients)
SOLUTION: Add regularization to large batch training
- L2 regularization (weight decay)
- Dropout
- Data augmentation
- Label smoothing
```
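As a concrete illustration of the regularizers listed above, here is one way they are typically wired up in PyTorch; the specific values are illustrative defaults, not tuned recommendations, and the toy classifier head stands in for a real model:
```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Dropout(p=0.2),                                   # dropout
    nn.Linear(512, 100),
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)     # label smoothing
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.4, momentum=0.9,
    weight_decay=5e-4,                                   # L2 regularization (weight decay)
)
# Data augmentation (e.g. random crops/flips, RandAugment) is applied in the
# data pipeline rather than here.
```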
**Batch size effects on different architectures:**
```python
# Architecture 1: ResNets (well-studied)
# Batch 256: 76.0% top-1 accuracy (ImageNet)
# Batch 1024: 74.8% (-1.2%)
# Batch 4096: 72.0% (-4%)
# Conclusion: Batch size matters; the gap grows sharply with batch size
# Architecture 2: Vision Transformers
# Batch 512: 82.0% accuracy
# Batch 1024: 81.8% (-0.2%)
# Batch 4096: 81.0% (-1%)
# Conclusion: Less sensitive to batch size (more robust)
# Architecture 3: BERT (Language)
# Batch 128: 89.0% GLUE score
# Batch 256: 88.8% (-0.2%)
# Batch 512: 88.2% (-0.8%)
# Conclusion: Moderate sensitivity
# WHY THE DIFFERENCES?
# - ResNets: Simple architecture, sharp minima
# - Vision Transformers: Attention provides regularization
# - BERT: Pre-training + fine-tuning, already regularized
```
**Empirical guidelines for batch size vs generalization:**
```python
# Rule 1: Start with batch 128-256
# Most tasks achieve good accuracy at this range
# Memory reasonable on modern GPUs
# Generalization gap minimal
# Rule 2: If increasing batch size - add regularization
def add_regularization_for_large_batch(batch_size, base_batch=256):
"""Adjust regularization strength for larger batch size"""
# Start from base: batch 256, weight_decay 0.0001
# Double batch → increase regularization
scale_factor = batch_size / base_batch
weight_decay = 0.0001 * (scale_factor ** 0.5) # sqrt scale
dropout = 0.1 # Add dropout
label_smoothing = 0.1 # Label smoothing helps
return {
'weight_decay': weight_decay,
'dropout': dropout,
'label_smoothing': label_smoothing,
}
# Rule 3: Validate on validation set
# Don't assume scaling rule works for accuracy
# Larger batch might need different epochs/learning rate schedule
# Rule 4: Gradient accumulation behaves like the large effective batch
# The accumulated gradient equals the large-batch gradient (BatchNorm is the
# exception: its statistics still come from each micro-batch)
# It just takes longer (multiple backward passes per update)
# So don't expect small-batch noise benefits - the generalization behavior
# matches the effective batch, not the micro-batch
```
### Pattern 6: Finding Optimal Batch Size (Not Just Maximum)
**The batch size selection framework:**
```
Step 1: Calculate memory budget
→ Max memory available (e.g., 24GB GPU)
→ Estimate parameters + optimizer state
→ Available for batch = Total - (params + opt state)
Step 2: Estimate per-sample memory
→ Run small batch (8), measure memory
→ Divide by 8 to get per-sample
→ Max batch = Available Memory / per-sample
Step 3: Find memory-safe batch
→ Use 80% of max (leaves margin)
→ This is maximum batch that's safe
Step 4: Check convergence at maximum batch
→ Train model with maximum safe batch
→ Compare accuracy to smaller batches
→ If >2% accuracy drop: reduce batch or add regularization
Step 5: Optimize for wall-clock time
→ Profile training time at different batch sizes
→ Wall-clock = (iterations) × (time per iteration)
→ Iterations = (samples / batch) × epochs
→ Find batch that minimizes wall-clock time
→ Often NOT the maximum batch!
Step 6: Select based on task requirements
→ If convergence matters more: smaller batch
→ If speed matters more: larger batch
→ If memory constrained: gradient accumulation
→ If fine-tuning: smaller batch (preserve pre-training)
```
**Implementation:**
```python
def find_optimal_batch_size(
model,
train_loader,
criterion,
device,
target_accuracy=None,
time_budget_seconds=None,
):
"""
Find optimal batch size by profiling at different sizes.
Args:
model: PyTorch model to profile
train_loader: DataLoader
criterion: Loss function
device: torch.device
target_accuracy: If set, find batch that achieves this
time_budget_seconds: If set, find fastest batch within budget
Returns:
Optimal batch size, profiling results
WHY: Maximum batch ≠ optimal batch
"""
batch_sizes = [32, 64, 128, 256, 512]
results = {}
for batch_size in batch_sizes:
# Measure memory for this batch size
try:
batch, target = next(iter(train_loader))
batch = batch[:batch_size].to(device)
target = target[:batch_size].to(device)
torch.cuda.reset_peak_memory_stats(device)
with torch.cuda.device(device):
output = model(batch)
loss = criterion(output, target)
loss.backward()
memory_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
# Measure iteration time
import time
start = time.time()
for _ in range(10):
output = model(batch)
loss = criterion(output, target)
loss.backward()
iteration_time = (time.time() - start) / 10
# Calculate total training time
# Assume 100 epochs, 50k samples
iterations_per_epoch = 50000 // batch_size
total_iterations = iterations_per_epoch * 100
total_time = total_iterations * iteration_time
results[batch_size] = {
'memory_mb': memory_mb,
'iteration_time_ms': iteration_time * 1000,
'total_time_hours': total_time / 3600,
}
except RuntimeError as e:
results[batch_size] = {'error': str(e)}
    # Find optimal based on criteria
    if target_accuracy is not None:
        # Placeholder: choosing by accuracy requires training and evaluating
        # at each batch size, which this memory/speed profile does not do
        return min(bs for bs in results if 'error' not in results[bs])
elif time_budget_seconds is not None:
# Choose largest batch within time budget
valid = {bs: r for bs, r in results.items()
if 'error' not in r and r['total_time_hours'] * 3600 < time_budget_seconds}
return max(valid.keys()) if valid else None
else:
# Default: choose largest batch within 80% memory limit
memory_limit = 0.8 * torch.cuda.get_device_properties(device).total_memory / (1024**2)
valid = {bs: r for bs, r in results.items()
if 'error' not in r and r['memory_mb'] < memory_limit}
return max(valid.keys()) if valid else None
# Batch size discovery loop
def discover_optimal_batch_size(model, train_loader, criterion, device, start_batch=8):
    """
    Progressive batch size search starting from small.
    Pattern: Double batch size until OOM, then back off.
    Assumes the loader yields batches at least as large as the sizes tested.
    """
    full_batch, full_target = next(iter(train_loader))
    prev_batch = start_batch
    batch_size = start_batch

    def try_batch(size):
        batch = full_batch[:size].to(device)
        target = full_target[:size].to(device)
        loss = criterion(model(batch), target)
        loss.backward()
        model.zero_grad(set_to_none=True)

    while True:
        try:
            try_batch(batch_size)
            print(f"✓ Batch {batch_size} works")
            prev_batch = batch_size
            batch_size *= 2  # Try 2x next
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            torch.cuda.empty_cache()
            print(f"✗ Batch {batch_size} OOM")
            print(f"→ Largest power-of-two batch that fits: {prev_batch}")
            # Check whether 1.5x of the last working size also fits
            test_batch = int(prev_batch * 1.5)
            try:
                try_batch(test_batch)
                print(f"✓ Batch {test_batch} also works, use this")
                return test_batch
            except RuntimeError:
                torch.cuda.empty_cache()
                return prev_batch
```
**Batch size selection by use case:**
```python
# Use Case 1: Maximum accuracy matters (research, publication)
# → Choose smaller batch (128-256)
# → More gradient noise = better generalization
# → Willing to train longer if accuracy is better
optimal_batch_size = 128
# Use Case 2: Training speed matters (prototyping, iteration)
# → Choose larger batch (512-1024)
# → Trade some accuracy for wall-clock speed
# → Need to add regularization to reduce generalization gap
optimal_batch_size = 512
regularization_strength = 'strong' # weight_decay, dropout
# Use Case 3: Memory severely constrained (mobile, edge)
# → Choose small batch (16-32)
# → Use gradient accumulation to simulate larger batch
# → Accept lower accuracy if necessary
optimal_batch_size = 16
accumulation_steps = 8 # Simulate batch 128
# Use Case 4: Fine-tuning small dataset
# → Choose small batch (16-32)
# → Preserve pre-training (smaller updates)
# → Larger batch risks forgetting pre-trained knowledge
optimal_batch_size = 16
# Use Case 5: Large model, large dataset
# → Choose medium-large batch (256-512)
# → Gradient accumulation for effective larger batch
# → Mixed precision for memory savings
optimal_batch_size = 256
use_mixed_precision = True
use_gradient_accumulation = False # Fits with mixed precision
# Use Case 6: Distributed training (multiple GPUs/TPUs)
# → Per-GPU batch: 32-64
# → Accumulation: 4-8 steps
# → Total effective: per_gpu * num_gpus * accumulation
# → Large total effective batch, small per-GPU batch
per_gpu_batch = 64
num_gpus = 8
accumulation_steps = 4
effective_batch = 64 * 8 * 4 # 2048
```
## Common Pitfalls
**Pitfall 1: Confusing Maximum Batch with Optimal Batch**
**Symptom**: "I have 24GB memory, so I should use the largest batch that fits"
**Why it breaks**: Larger batch = worse generalization. Maximum batch might achieve 2-3% lower accuracy.
**Fix**: Use 80% of maximum batch size, validate accuracy, adjust if needed.
```python
# WRONG
max_batch = find_max_batch_that_fits(model, memory=24_000_000_000)
train(model, batch_size=max_batch) # Likely overfit
# CORRECT
safe_batch = int(max_batch * 0.8) # 80% of maximum
train(model, batch_size=safe_batch)
validate_accuracy(model) # Check if acceptable
if accuracy_drop > 0.02:  # more than 2 percentage points
    reduce_batch_size(int(safe_batch * 0.8))
    add_regularization()
```
**Pitfall 2: Ignoring Learning Rate Scaling**
**Symptom**: "I doubled my batch size, training diverges now"
**Why it breaks**: Gradient magnitudes decrease with larger batch, so learning rate must increase proportionally.
**Fix**: Use linear scaling rule: new_lr = old_lr × (new_batch / old_batch)
```python
# WRONG
batch_size = 64
learning_rate = 0.001
# Increase batch without adjusting LR
batch_size = 256
# Learning rate still 0.001 - too small!
# Gradient updates too conservative, very slow convergence
# CORRECT
batch_size = 64
learning_rate = 0.001
batch_size = 256
learning_rate = 0.001 * (256 / 64) # Scale by 4x
# = 0.004
```
**Pitfall 3: Using Huge Learning Rate Without Warmup**
**Symptom**: "I scaled my learning rate by 10x and now training diverges immediately"
**Why it breaks**: Very large learning rate jumps cause instability. Model can't adapt.
**Fix**: Add linear warmup phase: gradually increase LR from 0 to scaled value.
```python
# WRONG
scaled_lr = 0.001 * 10  # 0.01
optimizer = SGD(model.parameters(), lr=scaled_lr)
for epoch in range(100):
    for batch, target in train_loader:
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()  # Diverges on first iteration!

# CORRECT
base_lr = 0.001
scaled_lr = base_lr * 10  # 0.01
warmup_steps = 1000

def lr_lambda(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))  # ramp 0 → 1 of scaled_lr
    return 1.0  # hold at scaled_lr after warmup

optimizer = SGD(model.parameters(), lr=scaled_lr)  # base LR is the scaled LR
scheduler = LambdaLR(optimizer, lr_lambda)
for epoch in range(100):
    for batch, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()
        optimizer.step()
        scheduler.step()
```
**Pitfall 4: Gradient Accumulation Without LR Adjustment**
**Symptom**: "I added gradient accumulation but training is much slower to converge"
**Why it breaks**: Accumulation itself doesn't require LR change, but if effective batch increased, LR should too.
**Fix**: Adjust LR based on effective batch size, not per-GPU batch size.
```python
# WRONG
batch_size = 32 # Per-GPU
num_accumulation = 8
# Learning rate still tuned for batch 32
# Effective batch = 32 × 8 = 256
# But LR not scaled for batch 256
# Convergence slower because LR too conservative
# CORRECT
batch_size = 32
num_accumulation = 8
effective_batch = batch_size * num_accumulation # 256
# Get LR for batch 32
base_lr_batch32 = 0.001
# Scale for batch 256
lr_batch256 = base_lr_batch32 * (256 / 32) # 0.008
optimizer = SGD(model.parameters(), lr=lr_batch256)
```
**Pitfall 5: Assuming Batch Size Doesn't Affect Accuracy**
**Symptom**: "Batch size only affects speed, not accuracy"
**Why it breaks**: Batch size strongly affects generalization (1-4% gap is common).
**Fix**: Always validate final accuracy at different batch sizes. Larger batch might need different hyperparameters.
```python
# WRONG - assume accuracy independent of batch
batch_sizes = [64, 256, 1024]
for batch_size in batch_sizes:
model = train(batch_size=batch_size, learning_rate=0.001)  # Same LR for every batch size!
accuracy = evaluate(model)
# Accuracy will differ significantly!
# CORRECT - adjust hyperparameters per batch
for batch_size in batch_sizes:
lr = 0.001 * (batch_size / 64) # Scale LR
weight_decay = 0.0001 * (batch_size / 64) ** 0.5 # Increase regularization
model = train(batch_size=batch_size, learning_rate=lr, weight_decay=weight_decay)
accuracy = evaluate(model)
# More consistent accuracy across batch sizes
```
**Pitfall 6: Not Considering Synchronous vs Asynchronous Batch Norm**
**Symptom**: "My distributed training accuracy is much worse than single-GPU"
**Why it breaks**: Batch norm computes statistics per batch. Distributed training with small per-GPU batch = incorrect statistics.
**Fix**: Use SyncBatchNorm for correct statistics across all GPUs.
```python
# WRONG - Synchronous data parallel, asynchronous BN
from torch.nn.parallel import DataParallel
model = DataParallel(model, device_ids=[0, 1, 2, 3])
# Each GPU has batch_size=32
# BN computes stats from only its 32 samples
# Stats unstable, training broken
# CORRECT - Synchronous batch norm
from torch.nn.modules.batchnorm import SyncBatchNorm
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, find_unused_parameters=False)
# Each GPU: batch 32, but BN aggregates across all 4 GPUs = 128
# Stats computed from all 128 samples, stable
```
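For completeness, a slightly fuller sketch of the fix, assuming the job is launched with `torchrun` (which sets `LOCAL_RANK`) and an NCCL backend; `build_model()` is a placeholder for your own model constructor:
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)                         # placeholder builder
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)   # BN stats across all GPUs
model = DDP(model, device_ids=[local_rank])
```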
**Pitfall 7: Gradient Accumulation Too Large (>16x)**
**Symptom**: "I'm accumulating gradients over 32 steps but training diverges"
**Why it breaks**: The micro-batch gradients are all computed at the same weights, so they are not stale - the real problems are that BatchNorm statistics come from tiny micro-batches, updates become very infrequent, and the learning rate required for such a large effective batch is hard to stabilize.
**Fix**: Keep accumulation ≤ 16x with reasonably sized micro-batches. Use distributed training for larger effective batches.
```python
# WRONG - excessive accumulation over tiny micro-batches
batch_size = 4
accumulation_steps = 32  # effective batch = 128, built from micro-batches of 4
# BatchNorm statistics come from only 4 samples per step (very noisy),
# and weights update only once every 32 micro-batches

# CORRECT - reasonable accumulation
batch_size = 32
accumulation_steps = 8  # effective batch = 256
# Micro-batches are large enough for stable statistics

# OR use distributed training instead
per_gpu_batch = 32
num_gpus = 8
effective_batch = per_gpu_batch * num_gpus  # 256 - same effective batch, no accumulation
# Better wall-clock time and per-device statistics
```
**Pitfall 8: Mixing Gradient Accumulation with Exponential Moving Average (EMA)**
**Symptom**: "I'm using gradient accumulation with learning rate scheduler and EMA, but training is unstable"
**Why it breaks**: EMA expects one update per optimizer step. With accumulation, updating the EMA on every backward pass folds the unchanged weights in repeatedly, which distorts the effective decay.
**Fix**: Update EMA only when you call optimizer.step(), not every backward pass.
```python
# WRONG - updating EMA every backward pass
ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)
for step, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward()
ema_model.update(model.parameters()) # Called every iteration!
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# CORRECT - update EMA only on optimizer.step()
for step, batch in enumerate(train_loader):
loss = criterion(model(batch), target)
loss.backward()
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
ema_model.update(model.parameters()) # Only here!
```
**Pitfall 9: Batch Size Doubling Without Validation**
**Symptom**: "I increased batch from 64 to 128 based on linear scaling rule, but accuracy dropped 2%"
**Why it breaks**: Linear scaling rule gives convergence rate, not accuracy guarantee. Generalization gap widens.
**Fix**: Always validate on holdout set when changing batch size. Accept accuracy drop or add regularization.
```python
# WRONG - assume linear scaling guarantees accuracy
original_batch = 64
original_lr = 0.001
original_accuracy = 0.85
new_batch = 128
new_lr = 0.001 * (128 / 64) # 0.002
new_accuracy = 0.83 # Dropped 2%! Should have validated first
# CORRECT - validate and adjust regularization if needed
new_batch = 128
new_lr = 0.001 * (128 / 64)
model = train(lr=new_lr, batch=new_batch)
val_accuracy = validate(model)
if val_accuracy < 0.84: # Acceptable drop?
# Add regularization for larger batch
model = train(
lr=new_lr,
batch=new_batch,
weight_decay=0.0002, # Increase
dropout=0.2, # Add/increase
)
val_accuracy = validate(model)
```
**Pitfall 10: Using Maximum Batch in Fine-tuning**
**Symptom**: "I fine-tuned with large batch size and catastrophically forgot pre-training"
**Why it breaks**: Large batch = large updates. Pre-trained weights overwritten too quickly.
**Fix**: Fine-tuning requires SMALLER batch size (32-64) and smaller learning rate than pre-training.
```python
# WRONG - fine-tuning with large batch
pretrained_model = load_pretrained_bert()
batch_size = 512 # Large!
learning_rate = 0.001 # Too large!
model = fine_tune(pretrained_model, batch_size=512, lr=0.001)
# Overfit to task, forget pre-trained knowledge
# Pre-training lost!
# CORRECT - conservative fine-tuning
pretrained_model = load_pretrained_bert()
batch_size = 32 # Small, conservative
learning_rate = 0.00001 # Tiny, preserve pre-training
model = fine_tune(
pretrained_model,
batch_size=batch_size,
lr=learning_rate,
weight_decay=0.001, # Strong L2 regularization
)
# Preserves pre-training knowledge, adapts carefully
```
## Practical Decision Framework
### Quick Batch Size Decision Tree
```
1. How much GPU memory do you have?
├─ < 8 GB: Start with batch 16-32
├─ 8-16 GB: Start with batch 32-64
├─ 16-24 GB: Start with batch 64-128
└─ 24+ GB: Start with batch 128-256
2. Can you fit your target batch in memory?
├─ Yes: Use it (with LR scaling)
├─ No, by <2x: Use gradient accumulation
└─ No, by >2x: Use smaller batch + stronger regularization
3. Is accuracy your priority or speed?
├─ Accuracy: Use smaller batch (32-128)
├─ Speed: Use larger batch (256-1024)
└─ Both: Gradient accumulation + mixed precision
4. Are you fine-tuning or training from scratch?
├─ Fine-tuning: Use small batch (16-32), small LR
└─ From scratch: Use medium batch (64-256), scale LR
5. Are you using distributed training?
├─ Yes: Per-GPU batch 32-64, accumulate for effective 256-512
└─ No: Single GPU batch 64-256
```
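Step 1 of the tree (and the fine-tuning branch of step 4) can be encoded as a small heuristic helper; it is a rough starting point only, and the numbers simply mirror the tree above:
```python
def suggest_starting_batch(gpu_memory_gb: float, fine_tuning: bool = False) -> int:
    """Return the lower end of the tree's suggested starting range."""
    if fine_tuning:
        return 16
    if gpu_memory_gb < 8:
        return 16
    if gpu_memory_gb < 16:
        return 32
    if gpu_memory_gb < 24:
        return 64
    return 128

print(suggest_starting_batch(24.0))                     # 128
print(suggest_starting_batch(10.0, fine_tuning=True))   # 16
```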
## Red Flags - Stop and Clarify
| Excuse | Reality | What To Do |
|--------|---------|-----------|
| "Just use the maximum batch that fits" | Worse generalization likely. Need to validate accuracy. | Measure accuracy at 80% of max, validate trade-offs. |
| "Linear scaling rule means I don't need to validate" | Rule gives convergence rate, not accuracy guarantee. Generalization gap exists. | Always validate final accuracy with new batch size. |
| "Gradient accumulation is just for memory-constrained settings" | It's a legitimate technique with trade-offs (slowness) worth understanding. | Use when memory constrained; understand slowdown cost. |
| "Batch size only affects speed, not accuracy" | Incorrect. Batch size strongly affects final accuracy (1-4% typical gap). | Always measure accuracy, expect gap, add regularization. |
| "I'll use the batch size from a paper, it should work" | Different model, data, hardware - need to validate. | Use paper as starting point, but validate and adjust. |
| "Larger batch = faster training" | Depends on what you measure (iterations vs epochs vs wall-clock). | Measure actual wall-clock time at different batch sizes. |
| "Just double the learning rate when doubling batch" | Linear scaling rule requires warmup for large increases. | Add warmup phase, measure convergence. |
| "Fine-tuning works same as pre-training, just different data" | Fine-tuning needs much smaller batch and LR (preserve pre-training). | Use batch 16-32, LR 10-100x smaller than pre-training. |
## Advanced Patterns: Batch Size Optimization in Production
### Pattern 7: Batch Size Scheduling During Training
**Increasing batch size as training progresses - when and why:**
```python
# Intuition: Start with small batch (good generalization),
# increase later (finish training faster)
def get_scheduled_batch_size(epoch, total_epochs, base_batch=32, max_batch=256):
"""
Increase batch size linearly with epochs.
WHY: Start small for generalization, increase for speed later.
Research shows this works well for long training.
"""
# Linear increase: 0 → 100% over training
scale = epoch / total_epochs
return int(base_batch + (max_batch - base_batch) * scale)
# Usage in training loop
for epoch in range(total_epochs):
    batch_size = get_scheduled_batch_size(epoch, total_epochs)
    # Keep LR matched to the current batch size (linear scaling rule)
    lr = 0.001 * (batch_size / 32)
    update_learning_rate(optimizer, lr)
    for batch, target in get_data_loader(batch_size=batch_size):
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
# Alternative: exponential schedule
def get_exponential_batch_schedule(epoch, total_epochs, base_batch=32, max_batch=256):
    """Exponential increase instead of linear"""
    scale = epoch / total_epochs
    return int(base_batch * (max_batch / base_batch) ** scale)
```
**When batch size scheduling is valuable:**
```
GOOD FIT:
- Long training (100+ epochs)
- Starting generalization is important
- Speed only matters at end
- Example: ResNet on ImageNet
NOT NEEDED:
- Short training (10-20 epochs)
- Already regularized enough (BERT fine-tuning)
- Batch size well-chosen from start
```
### Pattern 8: Batch Size vs Other Hyperparameters
**Understanding interactions with other hyperparameters:**
```python
# Interaction 1: Batch size ↔ Learning rate
# Already covered: linear scaling rule
# Interaction 2: Batch size ↔ Weight decay
# Larger batch → worse generalization
# Solution: Increase weight decay when increasing batch
# Typical: weight_decay ~ sqrt(batch_size)
def adjust_weight_decay(base_wd=0.0001, base_batch=256, new_batch=512):
"""Scale weight decay with batch size"""
return base_wd * (new_batch / base_batch) ** 0.5
# Example
wd_batch_256 = 0.0001
wd_batch_512 = adjust_weight_decay(wd_batch_256, 256, 512) # 0.000141
# Interaction 3: Batch size ↔ Dropout
# Larger batch → add/increase dropout
# Dropout magnitude depends on layer, typically 0.1-0.5
def adjust_dropout(base_dropout=0.1, base_batch=256, new_batch=512):
"""Increase dropout for larger batches"""
# Dropout strength ~ sqrt(batch_size)
scale = (new_batch / base_batch) ** 0.5
return min(base_dropout * scale, 0.5) # Cap at 0.5
# Interaction 4: Batch size ↔ Number of epochs
# Larger batch → each epoch contains fewer updates, so more epochs are
# needed to reach the same number of updates
# iterations = (samples / batch) × epochs
base_batch = 64
base_epochs = 100
base_iterations = (50000 / base_batch) * base_epochs  # ~78,000 updates
new_batch = 256
# Keeping the update count constant at batch 256 would need:
new_epochs = base_iterations / (50000 / new_batch)  # 400 epochs (4x)
# In practice, larger batches converge in somewhat fewer updates,
# so ~1.5-2x the epochs is a common compromise - profile to confirm
# Interaction 5: Batch size ↔ Optimizer choice
# SGD: works well at all batch sizes
# Momentum: accumulates larger steps, works best with smaller batch
# Adam: adaptive, less sensitive to batch size
# RMSprop: similar to Adam
# Recommendation:
# - Small batch (32-128): SGD with momentum or Adam
# - Large batch (512+): Adam (more stable) or SGD with warmup + large LR
# Interaction 6: Batch size ↔ Normalization technique
# Batch Norm: statistics from batch, larger batch = better stats
# Layer Norm: independent of batch size
# Group Norm: middle ground, works well with any batch size
# If using BatchNorm with small batch (< 16):
# → Use SyncBatchNorm across devices
# → Or use GroupNorm instead
# If using BatchNorm with large batch (> 1024):
# → Standard BatchNorm fine
# → May want to reduce BN momentum (accumulate stats slower)
```
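For interaction 6, a common practical move when the per-device batch is small is to swap BatchNorm for GroupNorm. A sketch of that swap; the 32-group default and the threshold of 16 are heuristics, not hard rules:
```python
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, num_groups: int = 32) -> nn.Module:
    """Recursively replace BatchNorm2d layers with GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            groups = num_groups if channels % num_groups == 0 else 1  # GN needs groups | channels
            setattr(module, name, nn.GroupNorm(groups, channels))
        else:
            replace_bn_with_gn(child, num_groups)
    return module

def pick_norm(model: nn.Module, per_device_batch: int) -> nn.Module:
    if per_device_batch < 16:        # BN statistics too noisy at this size
        return replace_bn_with_gn(model)
    return model                      # BatchNorm is fine at larger per-device batches
```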
## Rationalization Table: Common Excuses About Batch Size
| Rationalization | Why It's Wrong | Correct Approach |
|---|---|---|
| "Larger batch is always better for speed" | Wall-clock time depends on iterations AND time-per-iteration. Larger batch may have lower throughput. | Profile wall-clock time at different batch sizes, choose fastest. |
| "I'll tune batch size last, it's not important" | Batch size affects convergence rate, generalization, and stability early. Tuning last wastes time. | Choose good batch size early (based on memory), validate accuracy. |
| "Maximum batch that fits = optimal batch" | Generalization gap widens with batch size (1-4% typical). The maximum batch may miss your accuracy target. | Use 80% of max, validate on validation set, adjust if needed. |
| "Linear scaling rule means I don't validate" | Scaling rule gives convergence rate. Accuracy still varies with batch size due to generalization gap. | Always validate test/validation accuracy with new batch. |
| "Gradient accumulation is slow, don't use it" | True, it's slower (1.5-2x). But if memory is bottleneck, only alternative. Choose based on constraints. | Use when memory constrained. Accept slowdown. Don't use if memory OK. |
| "I don't need warmup, I'll just use scaled LR" | Large LR jumps cause divergence. Warmup prevents this. | Add linear warmup phase for scaled LR. |
| "My paper used batch X, I'll use that" | Different model, data, hardware converge differently. Paper batch might not be optimal for you. | Use paper as starting point. Validate and adjust for your setup. |
| "Fine-tuning uses same batch as pre-training" | Fine-tuning needs much smaller batch (preserve knowledge). Using pre-training batch erases pre-training. | Use batch 10-20x smaller than pre-training. Use tiny LR. |
| "Batch size only affects speed, not accuracy" | Batch size strongly affects generalization (1-4% gap common). Different final accuracy with different batch. | Expect accuracy variation with batch. Validate at each batch size. |
| "I increased batch, why is training slower?" | Fewer iterations (good) but longer per-iteration (bad). Total wall-clock = iterations × time-per-iteration. | Profile actual wall-clock time. May need gradient accumulation. |
| "I'll start with large batch to save memory" | Large batch → bad generalization early → harder to recover later. Start small, increase if needed. | Start with batch 32-64, increase during training if memory allows. |
## Comprehensive Example: Training a Vision Transformer
Let's put it all together with a real example:
```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.optim import AdamW
from torchvision import models, datasets, transforms
def train_vision_transformer_optimized():
"""
Complete example: training Vision Transformer with batch size optimization.
"""
# Step 1: Model and data
device = torch.device("cuda:0")
model = models.vit_b_16(pretrained=False).to(device)
criterion = torch.nn.CrossEntropyLoss()
# Dataset (ImageNet-scale)
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
])
    # Step 2: Determine batch size
    # ViT-Base: 86M parameters
    # GPU: 40GB A100
    # Memory estimate: params (344MB) + optimizer (688MB) + activations
    # Can fit batch 256-512
    base_batch = 256
    num_accumulation_steps = 1  # Can fit directly
    effective_batch = base_batch
    # Dataset and loader (the path is a placeholder for your ImageNet location)
    dataset = datasets.ImageFolder("/path/to/imagenet/train", transform=transform)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=base_batch, shuffle=True,
        num_workers=8, pin_memory=True)
    # Step 3: Initialize optimizer with scaled LR
    # Base LR tuned for batch 256
    base_lr = 1e-4
    scaled_lr = base_lr * (effective_batch / 256)  # 1e-4 (no scaling needed)
    optimizer = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.05)
    # Step 4: Warmup scheduler
    warmup_steps = 1000
    total_steps = 100 * len(dataset) // effective_batch  # 100 epochs
def warmup_cosine_schedule(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
return 0.5 * (1.0 + torch.cos(
torch.tensor(3.14159) *
(step - warmup_steps) / (total_steps - warmup_steps)
)).item()
scheduler = LambdaLR(optimizer, warmup_cosine_schedule)
# Step 5: Training loop with gradient accumulation (even though not needed)
# Good practice for larger models
model.train()
optimizer.zero_grad()
for epoch in range(100):
for step, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# Forward + backward
logits = model(images)
loss = criterion(logits, labels)
loss = loss / num_accumulation_steps
loss.backward()
# Update on accumulation step
if (step + 1) % num_accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 100 == 0:
print(f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, "
f"lr {optimizer.param_groups[0]['lr']:.2e}")
# Validate every epoch
val_accuracy = validate(model, device)
print(f"Epoch {epoch} validation accuracy: {val_accuracy:.2%}")
return model
# Key patterns demonstrated:
# 1. Batch size chosen based on memory (80% of max)
# 2. Learning rate scaled for batch size
# 3. Warmup phase for gradual LR increase
# 4. Cosine annealing for LR decay
# 5. Gradient accumulation structure (even if not needed)
# 6. Gradient clipping for stability
# 7. Regular validation to monitor accuracy
```
## Summary: Batch Size and Memory Decision Making
**The core principle:** Batch size is a system design choice affecting convergence, generalization, speed, and memory simultaneously. There is no universal "right" batch size - it depends on your constraints and priorities.
**The decision process:**
1. **Memory constraint**: Start with 80% of maximum batch
2. **Convergence**: Scale learning rate 1:1 with batch increase (with warmup)
3. **Generalization**: Validate accuracy, reduce if gap >2% (or add regularization)
4. **Performance**: Profile wall-clock time at different batch sizes
5. **Architecture**: Different models have different optimal batches
**The key insights:**
- Larger batch = faster iterations but worse generalization
- Linear scaling rule requires warmup for large increases
- Gradient accumulation is a legitimate technique (understand slowdown cost)
- Fine-tuning requires smaller batch than pre-training
- Distributed training needs care with batch norm and gradient updates
- Always measure, validate, and adjust - don't assume rules apply to your case
**The testing approach:**
When pressure-tested, this skill should:
- Explain why maximum batch ≠ optimal batch (generalization gap)
- Provide concrete examples of linear scaling rule with warmup
- Address gradient accumulation systematically (when, why, cost)
- Discuss memory estimation and optimization techniques
- Help select batch size based on constraints AND priorities
- Resist rationalizations and always recommend validation
## References and Further Reading
**Key papers:**
- Goyal et al. (2017), "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" - linear scaling rule and warmup
- You et al. (2019), "Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes" - large-batch optimization theory (LAMB)
- Smith et al. (2017), "Don't Decay the Learning Rate, Increase the Batch Size" - batch size schedules
**Techniques mentioned:**
- Batch Normalization: Ioffe & Szegedy (2015)
- Layer Normalization: Ba et al. (2016)
- Mixed Precision Training: Micikevicius et al. (2017)
- Gradient Checkpointing: Chen et al. (2016)
**Related Yzmir Skills:**
- `learning-rate-scheduling` - LR schedule choices beyond linear scaling
- `gradient-management` - Gradient clipping and accumulation for stability
- `optimization-algorithms` - Optimizer selection and hyperparameter tuning