# Batch Size and Memory Tradeoffs ## Overview Batch size is one of the most misunderstood hyperparameters. Most engineers think: "larger batch = faster training = better". Wrong. Batch size affects convergence speed, generalization, memory usage, and actual wall-clock training time in complex ways. **Larger batch size is NOT always better.** **Core principle**: Batch size selection is a system optimization problem, not a memory constraint problem. Choose batch size based on computational speed, convergence requirements, and generalization targets - not just what fits in memory. ## When to Use This Skill **Use this skill when:** - Choosing batch size for new training - Training is slow and considering larger batches - Out-of-memory errors during training - Learning rate needs adjustment after batch size change - Distributed training needs batch size scaling - Gradient accumulation considerations - User asks "what batch size should I use?" - Training accuracy varies widely between batch sizes - Convergence takes too long or is unstable - Memory per sample calculation needed - Comparing training speed: iterations vs epochs vs wall-clock time - Fine-tuning with different batch sizes than pre-training **Symptoms you need this skill:** - "I have memory, what's the maximum batch size?" (wrong question) - "Larger batches train faster, so use 512?" (incomplete) - "Batch size doesn't affect accuracy, only speed?" (false) - "Gradient accumulation is a workaround for small memory?" (misconception) - "Just scale learning rate by 2x when doubling batch size?" (incomplete) - "We get OOM at batch 256, so use 128 forever" (not optimized) **Don't use when:** - User has pure memory/infrastructure questions (use pytorch-engineering) - User asks about optimizer selection (use optimizer-selection-framework) - User asks about learning rate scheduling (use learning-rate-scheduling) - User has general training failure (not batch-size specific) ## Core Patterns ### Pattern 1: The Batch Size Tradeoff Space **The critical insight**: Batch size affects FOUR independent dimensions simultaneously. Optimize one = impact others. **The four dimensions:** ``` 1. TRAINING SPEED (iterations to converge) ├─ Larger batch → fewer iterations to convergence ✓ ├─ BUT: Gradient variance decreases (noisier gradients are better) └─ Result: Mixed - can't just maximize batch 2. COMPUTATIONAL EFFICIENCY (wall-clock time) ├─ Larger batch → amortize overhead per sample ✓ ├─ BUT: Larger batch → need larger LR (unstable) ├─ AND: Gradient accumulation = repeated backward (slow) └─ Result: Optimal ≠ Maximum 3. GENERALIZATION (test accuracy) ├─ Smaller batch → noisier gradients → better regularization ✓ ├─ Larger batch → cleaner gradient → overfit risk ✗ ├─ BUT: Can compensate with stronger regularization └─ Result: Batch size ↔ regularization coupling 4. 
MEMORY USAGE (GPU memory required)
   ├─ Larger batch → linear increase in activation memory
   ├─ Parameters constant regardless of batch
   ├─ Optimizer state constant regardless of batch
   └─ Result: Memory ∝ batch size (linear only for activations)
```

**The mental model:**

```
LARGER BATCH:
✓ Fewer iterations to convergence
✓ Better computational efficiency (up to a point)
✗ Worse generalization (harder to regularize)
✗ Requires larger learning rate (instability risk)
✗ Higher memory usage

SMALLER BATCH:
✗ More iterations to convergence
✗ Worse computational efficiency
✓ Better generalization (noise helps)
✓ Smaller learning rates are stable
✓ Lower memory usage
```

**Finding the sweet spot:**
- Start with a batch size that uses ~80% of GPU memory
- Adjust the learning rate using the linear scaling rule
- Monitor validation accuracy
- If validation accuracy drops → batch too large; reduce it or add regularization
- If training is slow → you may need gradient accumulation, not a larger batch

### Pattern 2: Linear Learning Rate Scaling Rule

**The rule that changes everything:** If you increase batch size by a factor of K, increase the learning rate by a factor of K.

```
New LR = Old LR × (New Batch Size / Old Batch Size)
```

**Why this works (the intuition):**

```
Gradient Descent Update: param = param - lr * gradient

With batch size B, the gradient is the average over B samples:
    gradient_B = (1/B) * sum(gradients from B samples)
    update_B   = lr * gradient_B

With batch size kB you take k times FEWER optimizer steps per epoch.
Averaging over kB samples cuts the gradient's variance by ~k, but its
expected magnitude stays roughly the same - so to make the same progress
per epoch with 1/k as many steps, each step must be ~k times larger.

Hence: scale lr by k (valid as long as the loss surface changes slowly
across those k smaller steps).

Empirically validated: Goyal et al. (2017),
"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"
```

**Implementation:**

```python
# Pattern 1: Direct scaling
original_lr = 0.001
original_batch_size = 32
new_batch_size = 128

scaling_factor = new_batch_size / original_batch_size  # 4x
new_lr = original_lr * scaling_factor                  # 0.004

# Pattern 2: When changing both batch AND learning rate
def compute_scaled_lr(base_lr, base_batch_size, current_batch_size):
    """
    Compute the learning rate for a new batch size using the linear scaling rule.

    Args:
        base_lr: Learning rate at the reference batch size
        base_batch_size: Batch size where base_lr was tuned (usually 32 or 256)
        current_batch_size: New batch size

    Returns:
        Scaled learning rate

    WHY: The linear scaling rule keeps per-epoch progress roughly constant.
    """
    scale_factor = current_batch_size / base_batch_size
    return base_lr * scale_factor

# Example: ResNet-50 training (ImageNet baseline)
# Reference: batch=256, lr=0.1
# Now training at: batch=1024
scaled_lr = compute_scaled_lr(0.1, 256, 1024)  # 0.4
print(f"Batch 256 with lr=0.1 → Batch 1024 with lr={scaled_lr}")
```

**When linear scaling works:**

```python
# CASE 1: Scaling works well
# Batch: 32 → 256 (8x increase)
# Learning rate: 0.001 → 0.008 (8x)
# Training: ✓ Converges normally, same final accuracy
# Wall-clock: ✓ Faster (fewer iterations, better hardware utilization)

# CASE 2: Scaling doesn't work
# Batch: 32 → 1024 (32x increase!)
# Learning rate: 0.001 → 0.032 (32x)
# Problem: Learning rate too large, training diverges
# Solution: Need a warmup phase
```

**The Critical Caveat: WARMUP IS REQUIRED**

```python
# WRONG: Apply the full scaled LR immediately
optimizer = torch.optim.SGD(model.parameters(), lr=0.032)  # Too large!

for epoch in range(100):
    for batch in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()
        # Loss diverges on first iteration!
```
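Before the full warmup implementation below, here is a minimal sketch of the two numbers the fix needs: the scaled target LR and a warmup length. The "warm up for about five epochs" heuristic and the helper name are illustrative assumptions, not values prescribed above.

```python
# Minimal sketch, assuming the linear scaling rule plus a few epochs of warmup.
def scaled_lr_with_warmup(base_lr, base_batch, new_batch,
                          steps_per_epoch, warmup_epochs=5):
    """Return (target_lr, warmup_steps) when moving to a larger batch."""
    target_lr = base_lr * (new_batch / base_batch)   # linear scaling rule
    warmup_steps = warmup_epochs * steps_per_epoch   # heuristic - tune per task
    return target_lr, warmup_steps

# The failing 32 → 1024 case above, assuming ~200 steps per epoch:
target_lr, warmup_steps = scaled_lr_with_warmup(
    base_lr=0.001, base_batch=32, new_batch=1024, steps_per_epoch=200)
print(target_lr, warmup_steps)  # 0.032, 1000
```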
```python
# CORRECT: Warmup phase before the scaled LR
def warmup_lr_schedule(base_lr, current_batch_size, reference_batch_size,
                       current_step, warmup_steps):
    """
    Linear warmup from 0 to the scaled LR.

    WHY: Large LR jumps can cause divergence. Gradual warmup lets the model
    adapt to larger updates.
    """
    scaled_lr = base_lr * (current_batch_size / reference_batch_size)
    if current_step < warmup_steps:
        # Linear warmup: ramp from 0 to scaled_lr
        return scaled_lr * (current_step / warmup_steps)
    else:
        # Full scaled LR after warmup
        return scaled_lr

# Implementation with a PyTorch scheduler
from torch.optim.lr_scheduler import LambdaLR

def get_warmup_scheduler(optimizer, warmup_steps):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0
    return LambdaLR(optimizer, lr_lambda)

# Training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.032)
scheduler = get_warmup_scheduler(optimizer, warmup_steps=1000)

for epoch in range(100):
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Gradually increase LR
```

**Practical guidelines:**

```
BATCH SIZE INCREASE   LEARNING RATE SCALE    WARMUP NEEDED?   WHY
2x   (64→128)         2x   (0.001→0.002)     No               Safe, gradual
4x   (64→256)         4x   (0.001→0.004)     Maybe            Starting to matter
8x   (64→512)         8x   (0.001→0.008)     YES              Risky without warmup
16x+ (64→1024)        16x+ (0.001→0.016)     CRITICAL         Risk of divergence
```

### Pattern 3: Gradient Accumulation - The Alternative to Large Batches

**What gradient accumulation does:**

Gradient accumulation simulates a large batch size without the large GPU memory. Instead of 1 forward+backward pass of batch 256, do 8 forward+backward passes of batch 32. Same effective batch, roughly 1/8th the activation memory.

**How it works:**

```python
# SIMPLE APPROACH (without accumulation)
batch_size = 256
effective_batch_size = 256
# Memory required: HIGH - the full batch of 256 may not fit on the GPU

for batch, target in train_loader:  # batch.size(0) = 256
    optimizer.zero_grad()
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# GRADIENT ACCUMULATION APPROACH
batch_size = 32
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps  # 256, same as above!
# Memory required: LOW - only a batch of 32 is in memory at once

data_iter = iter(train_loader)
optimizer.zero_grad()
for accumulation_step in range(accumulation_steps):
    batch, target = next(data_iter)       # batch.size(0) = 32
    output = model(batch)
    loss = criterion(output, target)
    loss = loss / accumulation_steps      # average over the effective batch
    loss.backward()                       # Accumulate gradients (don't zero!)
    # Don't call optimizer.step() yet!
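# Why the 1/accumulation_steps scaling above matters: summing eight scaled
# gradients reproduces the average gradient over all 256 samples, so the
# single optimizer.step() below behaves as if the full effective batch had
# been processed in one pass (BatchNorm statistics are the one exception -
# they are still computed per micro-batch of 32).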
optimizer.step()   # Update weights once accumulation is complete
# Effect: weights updated as if we had processed batch_size=256
```

**When to use gradient accumulation:**

```python
# CASE 1: Model too large to fit the desired batch
# Model: GPT-2 (124M parameters)
# Available GPU: 24GB
# Desired batch: 512 per GPU
# Fits in memory: No, only 32 fits
# Solution: Accumulate 16 steps of batch 32 = effective 512

model_params = 124_000_000                        # 124M
param_memory = model_params * 4                   # bytes (FP32)
optimizer_memory = model_params * 8               # Adam m + v in FP32: 8 bytes/param
batch_size = 32
sequence_length = 512
activation_memory_per_sample = param_memory / 10  # rough estimate
total_memory = param_memory + optimizer_memory + (batch_size * activation_memory_per_sample)
# ≈ 3 GB per micro-batch step - far below 24 GB.
# Accumulation adds no activation memory, so 16 steps of batch 32 fit easily.

# CASE 2: Distributed training across 8 GPUs
# Per-GPU batch: 32
# Number of GPUs: 8
# Local accumulation: 4 steps
# Total effective: 32 * 8 * 4 = 1024 (synchronized across 8 GPUs)
# Accumulation enables a large total batch without a massive per-GPU batch
```

**The memory math:**

```
Memory with Gradient Accumulation:

Without accumulation (batch_size = 256):
- Parameters: fixed
- Optimizer state: fixed (m + v for Adam)
- Parameter gradients: fixed (same size as the parameters)
- Activations: O(batch_size) = O(256)   ← the batch-dependent term

With accumulation (batch_size = 32, steps = 8):
- Parameters, optimizer state, parameter gradients: unchanged
- Activations: O(batch_size) = O(32) = 8x SMALLER

Savings: activation memory drops ~8x - often the bulk of batch-dependent memory
Cost: 8 small forward+backward passes per update instead of 1 large one
Net wall-clock: typically ~1.5-2x slower (per-step overhead, synchronization,
                lower hardware utilization at small micro-batches)
```

**Implementation patterns:**

```python
# Pattern 1: Manual gradient accumulation
num_accumulation_steps = 8

optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
    output = model(batch)
    loss = criterion(output, target)

    # Scale the loss by the number of accumulation steps
    # WHY: otherwise gradient magnitudes stack up across steps
    loss = loss / num_accumulation_steps
    loss.backward()  # Accumulate gradients

    if (step + 1) % num_accumulation_steps == 0:
        optimizer.step()        # Update after accumulation is complete
        optimizer.zero_grad()

# Pattern 2: With learning rate adjustment
# IMPORTANT: Don't adjust the learning rate just because of accumulation!
# Accumulation is transparent to the optimizer.
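# The accumulated, 1/N-scaled gradient is the same averaged gradient the
# optimizer would see from one full effective batch, so any LR decision
# should be driven by that effective batch size.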
# Scale is: effective_batch = batch_size * num_accumulation_steps
# So the LR should match effective_batch, NOT the per-GPU micro-batch
original_lr = 0.1                 # Tuned for batch_size = 32
num_accumulation_steps = 8
effective_batch = 32 * 8          # 256

# Linear scaling rule based on the effective batch:
# Batch 32 → 256 is an 8x increase
# So LR: 0.1 → 0.8 (8x)
new_lr = original_lr * 8          # 0.8
optimizer = torch.optim.SGD(model.parameters(), lr=new_lr)

# Pattern 3: Distributed training with gradient accumulation
# Per-GPU batch: 32
# Number of GPUs: 8
# Accumulation steps: 4
# Effective batch: 32 * 8 * 4 = 1024
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model)
num_accumulation_steps = 4

optimizer.zero_grad()
for step, (batch, target) in enumerate(train_loader):
    output = model(batch)
    loss = criterion(output, target)
    loss = loss / num_accumulation_steps
    loss.backward()

    if (step + 1) % num_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

# Pattern 4: With synchronization (distributed)
class GradientAccumulator:
    def __init__(self, model, num_accumulation_steps):
        self.model = model
        self.num_steps = num_accumulation_steps
        self.step_count = 0

    def should_sync_gradients(self):
        """
        In DDP, only sync gradients on the micro-step right before
        optimizer.step(). This reduces communication overhead.
        """
        return (self.step_count + 1) % self.num_steps == 0

    def backward(self, loss):
        loss = loss / self.num_steps
        if self.should_sync_gradients():
            loss.backward()
        else:
            with self.model.no_sync():  # Skip the gradient all-reduce in DDP
                loss.backward()
        self.step_count += 1
```

**Gradient accumulation vs large batch - when to choose:**

```python
# When gradient accumulation is a GOOD choice:
# 1. Memory-constrained (can't fit the large batch)
# 2. Need a large effective batch for convergence
# 3. Can tolerate a ~1.5-2x slowdown
# 4. Training wall-clock time is not critical

# When gradient accumulation is a BAD choice:
# 1. The desired batch size fits in memory
# 2. Training speed is critical (wall-clock matters)
# 3. Convergence is already good with smaller batches
# 4. Reduced gradient noise is important for the task

# Comparison table:
#                    LARGE BATCH      GRADIENT ACCUMULATION
# Memory             High             Low (activations / accumulation)
# Wall-clock time    Fast             ~1.5-2x slower
# Convergence speed  Good             Same (effective batch is the same)
# Implementation     Simple           Requires manual loop
# Memory savings     None             ~85% of activation memory (8x accumulation)
# When to use        When memory OK   When memory constrained
```

### Pattern 4: Memory Estimation and Optimization

**Understanding memory components:**

```
Total GPU Memory = Parameters + Optimizer State + Activations + Gradients

Example: Training BERT-base (110M params) with batch_size=32, seq_len=512

1. PARAMETERS (Fixed)
   - BERT: 110M × 4 bytes (FP32) = 440 MB
   - Or   110M × 2 bytes (FP16) = 220 MB

2. OPTIMIZER STATE (Fixed)
   - SGD (no momentum): no extra state = 0 MB (momentum adds 1 × params = 440 MB)
   - Adam: m + v = 2 × params = 880 MB (FP32) or 440 MB (FP16)
   - AdamW: same as Adam

3. ACTIVATIONS (Linear in batch_size and seq_len)
   - Stored during the forward pass (needed for backward)
   - Hidden states alone: batch × seq_len × hidden_dim × 4 bytes
     = 32 × 512 × 768 × 4 ≈ 48 MB per layer
   - Attention scores and FFN intermediates multiply this by roughly 5-10x
     in practice → ~320 MB per layer
   - × 12 layers ≈ 3.8 GB
4. GRADIENTS (Fixed)
   - One gradient per parameter, stored after backward until optimizer.step()
   - Same size as the parameters = 440 MB

TOTAL MEMORY = 440 + 880 + 3800 + 440 ≈ 5.6 GB
Typical budget: use ~80% of the GPU ≈ 19 GB on a 24 GB card
Room for more: the batch can safely grow from 32 → 128
```

**Memory calculation framework:**

```python
def estimate_memory_usage(
    num_params: int,
    batch_size: int,
    seq_length: int,
    hidden_dim: int,
    num_layers: int,
    dtype_bytes: int = 4,            # 4 for FP32, 2 for FP16
    optimizer: str = "adam",         # or "sgd"
    use_gradient_checkpointing: bool = False,
    activation_factor: float = 7.0,  # rough multiplier for attention scores and
                                     # FFN intermediates; depends on architecture
):
    """
    Estimate memory for training a transformer model.

    Args:
        num_params: Total parameters
        batch_size: Batch size
        seq_length: Sequence length
        hidden_dim: Hidden dimension (for activation estimation)
        num_layers: Number of transformer layers
        dtype_bytes: 4 for FP32, 2 for FP16, 1 for INT8
        optimizer: "sgd" (no extra state), "adam"/"adamw" (2 extra tensors: m and v)
        use_gradient_checkpointing: If True, reduce activation memory
        activation_factor: Multiplier over the bare hidden-state size to account
            for intermediate activations (a rough approximation)

    Returns:
        Memory in GB

    WHY: Helps choose a batch size without trial-and-error OOM
    """
    # 1. Parameter memory
    param_memory = num_params * dtype_bytes

    # 2. Optimizer state
    if optimizer.lower() in ("adam", "adamw"):
        opt_memory = 2 * num_params * dtype_bytes  # m + v
    else:  # SGD
        opt_memory = 0

    # 3. Activation memory (transformer-specific)
    # Hidden states per layer: batch × seq_len × hidden_dim × dtype_bytes,
    # scaled by activation_factor for attention scores and FFN intermediates
    activation_memory_per_layer = (
        batch_size * seq_length * hidden_dim * dtype_bytes * activation_factor
    )
    total_activation_memory = activation_memory_per_layer * num_layers

    if use_gradient_checkpointing:
        # Crude approximation: keep roughly one layer's activations and
        # recompute the rest during backward. Real savings depend on where
        # checkpoints are placed.
        total_activation_memory = activation_memory_per_layer

    # 4. Gradient memory (same size as the parameters)
    gradient_memory = num_params * dtype_bytes

    # Total
    total_bytes = param_memory + opt_memory + total_activation_memory + gradient_memory
    total_gb = total_bytes / (1024**3)
    return total_gb

# Example: BERT training
memory_gb = estimate_memory_usage(
    num_params=110_000_000,   # BERT-base
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,            # FP32
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory: {memory_gb:.1f} GB")  # ~5.6 GB

# Optimize by reducing the batch
memory_gb_batch16 = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=16,            # 2x smaller
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory with batch 16: {memory_gb_batch16:.1f} GB")  # ~3.6 GB

# Optimize with mixed precision
memory_gb_fp16 = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=2,            # FP16 instead of FP32
    optimizer="adam",
    use_gradient_checkpointing=False,
)
print(f"Memory with FP16: {memory_gb_fp16:.1f} GB")  # ~2.8 GB

# Optimize with checkpointing
memory_gb_ckpt = estimate_memory_usage(
    num_params=110_000_000,
    batch_size=32,
    seq_length=512,
    hidden_dim=768,
    num_layers=12,
    dtype_bytes=4,
    optimizer="adam",
    use_gradient_checkpointing=True,   # keep only ~one layer's activations
)
print(f"Memory with checkpointing: {memory_gb_ckpt:.1f} GB")  # ~2.0 GB
```

**Memory optimization techniques:**

```python
# Technique 1: Gradient Checkpointing
# Recompute activations instead of storing them
# Memory: O(sqrt(num_layers)) instead of O(num_layers) with the usual strategy
# Cost: ~30% slower training (activations are recomputed during backward)
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # self.attention and self.feedforward are defined in __init__ (omitted here)
    def forward(self, x):
        # Forward: don't keep intermediate activations
        # Backward: recompute them inside the checkpointed segment
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = self.attention(x)
        x = self.feedforward(x)
        return x

# Technique 2: Mixed Precision (FP16)
# Use FP16 for forward+backward (roughly halves activation memory)
# Keep FP32 master weights (don't accumulate rounding errors)
# Memory: ~50% reduction in activations
# Speed: 1.3-2x faster on modern GPUs
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch, target in train_loader:
    optimizer.zero_grad()
    with autocast():  # Automatic FP16 casting
        output = model(batch)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# Technique 3: Quantization-Aware Training
# Store weights in INT8 or FP8
# Requires special hardware support
# Memory: 75-90% reduction
# Speed: 2-4x faster

# Technique 4: Batch Size Scheduling
# Start with a small batch, increase during training
# Reason: small batch early = better generalization,
#         large batch late = faster finishing
# Memory: grows gradually as needed
def get_adaptive_batch_size(epoch, total_epochs):
    """Increase the batch size as training progresses."""
    base_batch = 32
    max_batch = 256
    # Linear increase: start small, end large
    scale_factor = base_batch + (max_batch - base_batch) * (epoch / total_epochs)
    return int(scale_factor)
```

### Pattern 5: Batch Size Effects on Convergence and Generalization

**The generalization gap - why bigger batch = worse accuracy:**

```
Generalization Gap = Test Accuracy (Large Batch) - Test Accuracy (Small Batch)

Why small batch generalizes better:

1.
Gradient Noise: Small batch = noisy gradients - Noise acts as regularization - Forces model to find robust minima - Larger noise → larger generalization margin 2. Loss Landscape: SGD with noise explores landscape differently - Large batch: Gradient descent (exact gradient) - Small batch: Stochastic gradient (noisy) - Noise helps escape sharp minima (bad generalization) - Leads to flat minima (good generalization) 3. Batch Normalization Interaction: - BN computes statistics per batch - Larger batch → more stable statistics - More stable → less regularization effect - Less regularization → worse generalization Real numbers (ResNet-50 on ImageNet): - Batch 256: 76.0% accuracy - Batch 1024: 74.8% accuracy (1.2% gap!) - Batch 4096: 72.0% accuracy (4% gap!) ``` **The sharp minima problem:** ``` SMALL BATCH (32): Loss landscape: Finds FLAT minima - Small change in weights → loss increases slowly - Generalizes well (robust to input variations) - Test accuracy ≈ Train accuracy - Variance: Higher (gradient noise) LARGE BATCH (1024): Loss landscape: Finds SHARP minima - Small change in weights → loss increases quickly - Generalizes poorly (sensitive to input variations) - Test accuracy << Train accuracy (overfitting) - Variance: Lower (stable gradients) SOLUTION: Add regularization to large batch training - L2 regularization (weight decay) - Dropout - Data augmentation - Label smoothing ``` **Batch size effects on different architectures:** ```python # Architecture 1: ResNets (well-studied) # Batch 256: 76.0% top-1 accuracy (ImageNet) # Batch 1024: 74.8% (-1.2%) # Batch 4096: 72.0% (-4%) # Conclusion: Batch size matters, gap grows exponentially # Architecture 2: Vision Transformers # Batch 512: 82.0% accuracy # Batch 1024: 81.8% (-0.2%) # Batch 4096: 81.0% (-1%) # Conclusion: Less sensitive to batch size (more robust) # Architecture 3: BERT (Language) # Batch 128: 89.0% GLUE score # Batch 256: 88.8% (-0.2%) # Batch 512: 88.2% (-0.8%) # Conclusion: Moderate sensitivity # WHY THE DIFFERENCES? 
# - ResNets: Simple architecture, sharp minima # - Vision Transformers: Attention provides regularization # - BERT: Pre-training + fine-tuning, already regularized ``` **Empirical guidelines for batch size vs generalization:** ```python # Rule 1: Start with batch 128-256 # Most tasks achieve good accuracy at this range # Memory reasonable on modern GPUs # Generalization gap minimal # Rule 2: If increasing batch size - add regularization def add_regularization_for_large_batch(batch_size, base_batch=256): """Adjust regularization strength for larger batch size""" # Start from base: batch 256, weight_decay 0.0001 # Double batch → increase regularization scale_factor = batch_size / base_batch weight_decay = 0.0001 * (scale_factor ** 0.5) # sqrt scale dropout = 0.1 # Add dropout label_smoothing = 0.1 # Label smoothing helps return { 'weight_decay': weight_decay, 'dropout': dropout, 'label_smoothing': label_smoothing, } # Rule 3: Validate on validation set # Don't assume scaling rule works for accuracy # Larger batch might need different epochs/learning rate schedule # Rule 4: Gradient accumulation doesn't help generalization # Accumulation ≠ large batch for gradient statistics # Gradient accumulation has same gradient per parameter # Just takes longer (multiple backward passes) # Generalization benefit same as if you had memory for full batch ``` ### Pattern 6: Finding Optimal Batch Size (Not Just Maximum) **The batch size selection framework:** ``` Step 1: Calculate memory budget → Max memory available (e.g., 24GB GPU) → Estimate parameters + optimizer state → Available for batch = Total - (params + opt state) Step 2: Estimate per-sample memory → Run small batch (8), measure memory → Divide by 8 to get per-sample → Max batch = Available Memory / per-sample Step 3: Find memory-safe batch → Use 80% of max (leaves margin) → This is maximum batch that's safe Step 4: Check convergence at maximum batch → Train model with maximum safe batch → Compare accuracy to smaller batches → If >2% accuracy drop: reduce batch or add regularization Step 5: Optimize for wall-clock time → Profile training time at different batch sizes → Wall-clock = (iterations) × (time per iteration) → Iterations = (samples / batch) × epochs → Find batch that minimizes wall-clock time → Often NOT the maximum batch! Step 6: Select based on task requirements → If convergence matters more: smaller batch → If speed matters more: larger batch → If memory constrained: gradient accumulation → If fine-tuning: smaller batch (preserve pre-training) ``` **Implementation:** ```python def find_optimal_batch_size( model, train_loader, criterion, device, target_accuracy=None, time_budget_seconds=None, ): """ Find optimal batch size by profiling at different sizes. 
Args: model: PyTorch model to profile train_loader: DataLoader criterion: Loss function device: torch.device target_accuracy: If set, find batch that achieves this time_budget_seconds: If set, find fastest batch within budget Returns: Optimal batch size, profiling results WHY: Maximum batch ≠ optimal batch """ batch_sizes = [32, 64, 128, 256, 512] results = {} for batch_size in batch_sizes: # Measure memory for this batch size try: batch, target = next(iter(train_loader)) batch = batch[:batch_size].to(device) target = target[:batch_size].to(device) torch.cuda.reset_peak_memory_stats(device) with torch.cuda.device(device): output = model(batch) loss = criterion(output, target) loss.backward() memory_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2) # Measure iteration time import time start = time.time() for _ in range(10): output = model(batch) loss = criterion(output, target) loss.backward() iteration_time = (time.time() - start) / 10 # Calculate total training time # Assume 100 epochs, 50k samples iterations_per_epoch = 50000 // batch_size total_iterations = iterations_per_epoch * 100 total_time = total_iterations * iteration_time results[batch_size] = { 'memory_mb': memory_mb, 'iteration_time_ms': iteration_time * 1000, 'total_time_hours': total_time / 3600, } except RuntimeError as e: results[batch_size] = {'error': str(e)} # Find optimal based on criteria if target_accuracy is not None: # Choose smallest batch that achieves target accuracy return min(results.keys()) elif time_budget_seconds is not None: # Choose largest batch within time budget valid = {bs: r for bs, r in results.items() if 'error' not in r and r['total_time_hours'] * 3600 < time_budget_seconds} return max(valid.keys()) if valid else None else: # Default: choose largest batch within 80% memory limit memory_limit = 0.8 * torch.cuda.get_device_properties(device).total_memory / (1024**2) valid = {bs: r for bs, r in results.items() if 'error' not in r and r['memory_mb'] < memory_limit} return max(valid.keys()) if valid else None # Batch size discovery loop def discover_optimal_batch_size(model, train_loader, criterion, device): """ Progressive batch size search starting from small. Pattern: Double batch size until OOM, then back off. 
""" batch_size = 8 while True: try: # Try current batch size batch, target = next(iter(train_loader)) batch = batch[:batch_size].to(device) target = target[:batch_size].to(device) output = model(batch) loss = criterion(output, target) loss.backward() print(f"✓ Batch {batch_size} works") # Try 2x prev_batch = batch_size batch_size *= 2 except RuntimeError as e: if "out of memory" in str(e).lower(): # OOM: go back to last working batch optimal_batch = prev_batch print(f"✗ Batch {batch_size} OOM") print(f"→ Use batch size {optimal_batch} (safe margin)") # But check if we can use 1.5x test_batch = int(optimal_batch * 1.5) try: batch = batch[:test_batch].to(device) output = model(batch) loss = criterion(output, target) loss.backward() print(f"✓ Batch {test_batch} also works, use this") return test_batch except: return optimal_batch else: raise ``` **Batch size selection by use case:** ```python # Use Case 1: Maximum accuracy matters (research, publication) # → Choose smaller batch (128-256) # → More gradient noise = better generalization # → Willing to train longer if accuracy is better optimal_batch_size = 128 # Use Case 2: Training speed matters (prototyping, iteration) # → Choose larger batch (512-1024) # → Trade some accuracy for wall-clock speed # → Need to add regularization to reduce generalization gap optimal_batch_size = 512 regularization_strength = 'strong' # weight_decay, dropout # Use Case 3: Memory severely constrained (mobile, edge) # → Choose small batch (16-32) # → Use gradient accumulation to simulate larger batch # → Accept lower accuracy if necessary optimal_batch_size = 16 accumulation_steps = 8 # Simulate batch 128 # Use Case 4: Fine-tuning small dataset # → Choose small batch (16-32) # → Preserve pre-training (smaller updates) # → Larger batch risks forgetting pre-trained knowledge optimal_batch_size = 16 # Use Case 5: Large model, large dataset # → Choose medium-large batch (256-512) # → Gradient accumulation for effective larger batch # → Mixed precision for memory savings optimal_batch_size = 256 use_mixed_precision = True use_gradient_accumulation = False # Fits with mixed precision # Use Case 6: Distributed training (multiple GPUs/TPUs) # → Per-GPU batch: 32-64 # → Accumulation: 4-8 steps # → Total effective: per_gpu * num_gpus * accumulation # → Large total effective batch, small per-GPU batch per_gpu_batch = 64 num_gpus = 8 accumulation_steps = 4 effective_batch = 64 * 8 * 4 # 2048 ``` ## Common Pitfalls ❌ **Pitfall 1: Confusing Maximum Batch with Optimal Batch** → **Symptom**: "I have 24GB memory, so I should use the largest batch that fits" → **Why it breaks**: Larger batch = worse generalization. Maximum batch might achieve 2-3% lower accuracy. → **Fix**: Use 80% of maximum batch size, validate accuracy, adjust if needed. ```python # WRONG max_batch = find_max_batch_that_fits(model, memory=24_000_000_000) train(model, batch_size=max_batch) # Likely overfit # CORRECT safe_batch = int(max_batch * 0.8) # 80% of maximum train(model, batch_size=safe_batch) validate_accuracy(model) # Check if acceptable if accuracy_drop > 2%: reduce_batch_size(safe_batch * 0.8) add_regularization() ``` ❌ **Pitfall 2: Ignoring Learning Rate Scaling** → **Symptom**: "I doubled my batch size, training diverges now" → **Why it breaks**: Gradient magnitudes decrease with larger batch, so learning rate must increase proportionally. 
→ **Fix**: Use linear scaling rule: new_lr = old_lr × (new_batch / old_batch) ```python # WRONG batch_size = 64 learning_rate = 0.001 # Increase batch without adjusting LR batch_size = 256 # Learning rate still 0.001 - too small! # Gradient updates too conservative, very slow convergence # CORRECT batch_size = 64 learning_rate = 0.001 batch_size = 256 learning_rate = 0.001 * (256 / 64) # Scale by 4x # = 0.004 ``` ❌ **Pitfall 3: Using Huge Learning Rate Without Warmup** → **Symptom**: "I scaled my learning rate by 10x and now training diverges immediately" → **Why it breaks**: Very large learning rate jumps cause instability. Model can't adapt. → **Fix**: Add linear warmup phase: gradually increase LR from 0 to scaled value. ```python # WRONG scaled_lr = 0.001 * 10 # 0.01 optimizer = SGD(model, lr=0.01) for epoch in range(100): for batch in train_loader: loss = criterion(model(batch), target) loss.backward() optimizer.step() # Diverges on first iteration! # CORRECT base_lr = 0.001 scaled_lr = 0.001 * 10 # 0.01 warmup_steps = 1000 def lr_lambda(step): if step < warmup_steps: return float(step) / float(max(1, warmup_steps)) * 10 # 0 to 10x over warmup return 1.0 # 10x after warmup optimizer = SGD(model, lr=base_lr) scheduler = LambdaLR(optimizer, lr_lambda) for epoch in range(100): for batch in train_loader: loss = criterion(model(batch), target) loss.backward() optimizer.step() scheduler.step() ``` ❌ **Pitfall 4: Gradient Accumulation Without LR Adjustment** → **Symptom**: "I added gradient accumulation but training is much slower to converge" → **Why it breaks**: Accumulation itself doesn't require LR change, but if effective batch increased, LR should too. → **Fix**: Adjust LR based on effective batch size, not per-GPU batch size. ```python # WRONG batch_size = 32 # Per-GPU num_accumulation = 8 # Learning rate still tuned for batch 32 # Effective batch = 32 × 8 = 256 # But LR not scaled for batch 256 # Convergence slower because LR too conservative # CORRECT batch_size = 32 num_accumulation = 8 effective_batch = batch_size * num_accumulation # 256 # Get LR for batch 32 base_lr_batch32 = 0.001 # Scale for batch 256 lr_batch256 = base_lr_batch32 * (256 / 32) # 0.008 optimizer = SGD(model, lr=lr_batch256) ``` ❌ **Pitfall 5: Assuming Batch Size Doesn't Affect Accuracy** → **Symptom**: "Batch size only affects speed, not accuracy" → **Why it breaks**: Batch size strongly affects generalization (1-4% gap is common). → **Fix**: Always validate final accuracy at different batch sizes. Larger batch might need different hyperparameters. ```python # WRONG - assume accuracy independent of batch batch_sizes = [64, 256, 1024] for batch_size in batch_sizes: model = train(learning_rate=0.001) # Same LR for all! accuracy = evaluate(model) # Accuracy will differ significantly! # CORRECT - adjust hyperparameters per batch for batch_size in batch_sizes: lr = 0.001 * (batch_size / 64) # Scale LR weight_decay = 0.0001 * (batch_size / 64) ** 0.5 # Increase regularization model = train(learning_rate=lr, weight_decay=weight_decay) accuracy = evaluate(model) # More consistent accuracy across batch sizes ``` ❌ **Pitfall 6: Not Considering Synchronous vs Asynchronous Batch Norm** → **Symptom**: "My distributed training accuracy is much worse than single-GPU" → **Why it breaks**: Batch norm computes statistics per batch. Distributed training with small per-GPU batch = incorrect statistics. → **Fix**: Use SyncBatchNorm for correct statistics across all GPUs. 
```python
# WRONG - data parallel with per-GPU (asynchronous) BN
from torch.nn.parallel import DataParallel

model = DataParallel(model, device_ids=[0, 1, 2, 3])
# Each GPU gets batch_size=32
# BN computes statistics from only its own 32 samples
# Statistics are unstable, training quality suffers

# CORRECT - synchronized batch norm
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel

model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, find_unused_parameters=False)
# Each GPU still sees batch 32, but BN aggregates across all 4 GPUs = 128
# Statistics are computed from all 128 samples - stable
```

❌ **Pitfall 7: Gradient Accumulation Too Large (>16x)**
→ **Symptom**: "I'm accumulating gradients over 32 steps but training diverges"
→ **Why it breaks**: Very large accumulation makes each optimizer update extremely expensive (dozens of forward/backward passes), tiny micro-batches waste the hardware, and the resulting huge effective batch needs careful LR scaling and warmup - which is where divergence usually comes from.
→ **Fix**: Keep accumulation ≤ ~16 steps. Use distributed training for larger effective batches.

```python
# WRONG - excessive accumulation
batch_size = 4
accumulation_steps = 32   # effective batch = 128
# 32 passes per update: very slow per update, the micro-batch of 4 wastes
# the GPU, and mis-scaling the loss/LR across that many steps is easy

# CORRECT - reasonable accumulation
batch_size = 32
accumulation_steps = 8    # effective batch = 256, acceptable
# Only 8 passes per update; per-update overhead stays reasonable

# OR use distributed training instead
per_gpu_batch = 32
num_gpus = 8
effective_batch = per_gpu_batch * num_gpus  # 256 - same as above, no accumulation
# Better throughput for the same effective batch
```

❌ **Pitfall 8: Mixing Gradient Accumulation with Exponential Moving Average (EMA)**
→ **Symptom**: "I'm using gradient accumulation with a learning rate scheduler and EMA, but training is unstable"
→ **Why it breaks**: EMA is meant to track the weights after each optimizer update. With accumulation there are several backward passes per update; updating the EMA on every backward pass averages the same unchanged weights repeatedly, so the effective decay no longer matches the update schedule.
→ **Fix**: Update the EMA only when you call optimizer.step(), not on every backward pass.

```python
# WRONG - updating the EMA on every backward pass
ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999)

for step, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    loss.backward()
    ema_model.update(model.parameters())  # Called every iteration!
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# CORRECT - update the EMA only on optimizer.step()
for step, batch in enumerate(train_loader):
    loss = criterion(model(batch), target)
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        ema_model.update(model.parameters())  # Only here!
```

❌ **Pitfall 9: Batch Size Doubling Without Validation**
→ **Symptom**: "I increased the batch from 64 to 128 using the linear scaling rule, but accuracy dropped 2%"
→ **Why it breaks**: The linear scaling rule preserves convergence behavior, not accuracy. The generalization gap widens.
→ **Fix**: Always validate on a holdout set when changing batch size. Accept the accuracy drop or add regularization.

```python
# WRONG - assume linear scaling guarantees accuracy
original_batch = 64
original_lr = 0.001
original_accuracy = 0.85

new_batch = 128
new_lr = 0.001 * (128 / 64)  # 0.002
new_accuracy = 0.83          # Dropped 2%! Should have validated first

# CORRECT - validate, and adjust regularization if needed
new_batch = 128
new_lr = 0.001 * (128 / 64)
model = train(lr=new_lr, batch=new_batch)
val_accuracy = validate(model)
if val_accuracy < 0.84:  # Acceptable drop?
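    # (The threshold and the knob values that follow are illustrative; the
    #  principle is that a larger batch removes gradient noise, so you add
    #  explicit regularization to compensate.)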
# Add regularization for larger batch model = train( lr=new_lr, batch=new_batch, weight_decay=0.0002, # Increase dropout=0.2, # Add/increase ) val_accuracy = validate(model) ``` ❌ **Pitfall 10: Using Maximum Batch in Fine-tuning** → **Symptom**: "I fine-tuned with large batch size and catastrophically forgot pre-training" → **Why it breaks**: Large batch = large updates. Pre-trained weights overwritten too quickly. → **Fix**: Fine-tuning requires SMALLER batch size (32-64) and smaller learning rate than pre-training. ```python # WRONG - fine-tuning with large batch pretrained_model = load_pretrained_bert() batch_size = 512 # Large! learning_rate = 0.001 # Too large! model = fine_tune(pretrained_model, batch_size=512, lr=0.001) # Overfit to task, forget pre-trained knowledge # Pre-training lost! # CORRECT - conservative fine-tuning pretrained_model = load_pretrained_bert() batch_size = 32 # Small, conservative learning_rate = 0.00001 # Tiny, preserve pre-training model = fine_tune( pretrained_model, batch_size=batch_size, lr=learning_rate, weight_decay=0.001, # Strong L2 regularization ) # Preserves pre-training knowledge, adapts carefully ``` ## Practical Decision Framework ### Quick Batch Size Decision Tree ``` 1. How much GPU memory do you have? ├─ < 8 GB: Start with batch 16-32 ├─ 8-16 GB: Start with batch 32-64 ├─ 16-24 GB: Start with batch 64-128 └─ 24+ GB: Start with batch 128-256 2. Can you fit your target batch in memory? ├─ Yes: Use it (with LR scaling) ├─ No, by <2x: Use gradient accumulation └─ No, by >2x: Use smaller batch + stronger regularization 3. Is accuracy your priority or speed? ├─ Accuracy: Use smaller batch (32-128) ├─ Speed: Use larger batch (256-1024) └─ Both: Gradient accumulation + mixed precision 4. Are you fine-tuning or training from scratch? ├─ Fine-tuning: Use small batch (16-32), small LR └─ From scratch: Use medium batch (64-256), scale LR 5. Are you using distributed training? ├─ Yes: Per-GPU batch 32-64, accumulate for effective 256-512 └─ No: Single GPU batch 64-256 ``` ## Red Flags - Stop and Clarify | Excuse | Reality | What To Do | |--------|---------|-----------| | "Just use the maximum batch that fits" | Worse generalization likely. Need to validate accuracy. | Measure accuracy at 80% of max, validate trade-offs. | | "Linear scaling rule means I don't need to validate" | Rule gives convergence rate, not accuracy guarantee. Generalization gap exists. | Always validate final accuracy with new batch size. | | "Gradient accumulation is just for memory-constrained settings" | It's a legitimate technique with trade-offs (slowness) worth understanding. | Use when memory constrained; understand slowdown cost. | | "Batch size only affects speed, not accuracy" | Incorrect. Batch size strongly affects final accuracy (1-4% typical gap). | Always measure accuracy, expect gap, add regularization. | | "I'll use the batch size from a paper, it should work" | Different model, data, hardware - need to validate. | Use paper as starting point, but validate and adjust. | | "Larger batch = faster training" | Depends on what you measure (iterations vs epochs vs wall-clock). | Measure actual wall-clock time at different batch sizes. | | "Just double the learning rate when doubling batch" | Linear scaling rule requires warmup for large increases. | Add warmup phase, measure convergence. | | "Fine-tuning works same as pre-training, just different data" | Fine-tuning needs much smaller batch and LR (preserve pre-training). 
| Use batch 16-32, LR 10-100x smaller than pre-training. |

## Advanced Patterns: Batch Size Optimization in Production

### Pattern 7: Batch Size Scheduling During Training

**Increasing batch size as training progresses - when and why:**

```python
# Intuition: start with a small batch (good generalization),
# increase it later (finish training faster)

def get_scheduled_batch_size(epoch, total_epochs, base_batch=32, max_batch=256):
    """
    Increase the batch size linearly with epochs.

    WHY: Start small for generalization, increase for speed later.
    Research shows this works well for long training.
    """
    # Linear increase: 0 → 100% over training
    scale = epoch / total_epochs
    return int(base_batch + (max_batch - base_batch) * scale)

# Usage in the training loop
for epoch in range(total_epochs):
    batch_size = get_scheduled_batch_size(epoch, total_epochs)
    for batch, target in get_data_loader(batch_size=batch_size):
        # Adjust the learning rate along with the batch size
        lr = 0.001 * (batch_size / 32)   # linear scaling
        update_learning_rate(optimizer, lr)

        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Alternative: exponential schedule
def get_exponential_batch_schedule(epoch, total_epochs, base_batch=32, max_batch=256):
    """Exponential increase instead of linear."""
    scale = epoch / total_epochs
    return int(base_batch * (max_batch / base_batch) ** scale)
```

**When batch size scheduling is valuable:**

```
GOOD FIT:
- Long training (100+ epochs)
- Generalization early in training is important
- Speed only matters toward the end
- Example: ResNet on ImageNet

NOT NEEDED:
- Short training (10-20 epochs)
- Already well regularized (e.g., BERT fine-tuning)
- Batch size well-chosen from the start
```

### Pattern 8: Batch Size vs Other Hyperparameters

**Understanding interactions with other hyperparameters:**

```python
# Interaction 1: Batch size ↔ Learning rate
# Already covered: the linear scaling rule

# Interaction 2: Batch size ↔ Weight decay
# Larger batch → worse generalization
# Solution: increase weight decay when increasing the batch
# Common heuristic: weight_decay ~ sqrt(batch_size)
def adjust_weight_decay(base_wd=0.0001, base_batch=256, new_batch=512):
    """Scale weight decay with batch size (sqrt heuristic)."""
    return base_wd * (new_batch / base_batch) ** 0.5

# Example
wd_batch_256 = 0.0001
wd_batch_512 = adjust_weight_decay(wd_batch_256, 256, 512)  # ~0.000141

# Interaction 3: Batch size ↔ Dropout
# Larger batch → add/increase dropout
# Dropout magnitude depends on the layer, typically 0.1-0.5
def adjust_dropout(base_dropout=0.1, base_batch=256, new_batch=512):
    """Increase dropout for larger batches (sqrt heuristic)."""
    scale = (new_batch / base_batch) ** 0.5
    return min(base_dropout * scale, 0.5)  # cap at 0.5

# Interaction 4: Batch size ↔ Number of epochs
# With the linear scaling rule you normally keep the number of EPOCHS fixed,
# so a 4x larger batch means 4x fewer iterations - that is the speedup.
# Very large batches sometimes need ~1.5-2x MORE epochs to close the
# remaining accuracy gap.
base_batch = 64
base_epochs = 100
iters_per_epoch = 50000 / base_batch                 # ~781 iterations/epoch
base_iterations = iters_per_epoch * base_epochs      # ~78k iterations total

new_batch = 256
new_iterations = (50000 / new_batch) * base_epochs   # ~20k iterations for the
                                                     # same 100 epochs (4x fewer)
# Matching the ORIGINAL iteration count would instead require
# base_iterations / (50000 / new_batch) ≈ 400 epochs - rarely worth it.

# Interaction 5: Batch size ↔ Optimizer choice
# SGD (with momentum): works at all batch sizes; the standard recipe for
#   large-batch vision training, but needs warmup + a scaled LR
# Adam / AdamW: adaptive, often less sensitive to batch size
# RMSprop: similar to Adam
# Recommendation:
# - Small batch (32-128): SGD with momentum or Adam
# - Large batch (512+): Adam/AdamW, or SGD with warmup + a large scaled LR
#
Interaction 6: Batch size ↔ Normalization technique # Batch Norm: statistics from batch, larger batch = better stats # Layer Norm: independent of batch size # Group Norm: middle ground, works well with any batch size # If using BatchNorm with small batch (< 16): # → Use SyncBatchNorm across devices # → Or use GroupNorm instead # If using BatchNorm with large batch (> 1024): # → Standard BatchNorm fine # → May want to reduce BN momentum (accumulate stats slower) ``` ## Rationalization Table: Common Excuses About Batch Size | Rationalization | Why It's Wrong | Correct Approach | |---|---|---| | "Larger batch is always better for speed" | Wall-clock time depends on iterations AND time-per-iteration. Larger batch may have lower throughput. | Profile wall-clock time at different batch sizes, choose fastest. | | "I'll tune batch size last, it's not important" | Batch size affects convergence rate, generalization, and stability early. Tuning last wastes time. | Choose good batch size early (based on memory), validate accuracy. | | "Maximum batch that fits = optimal batch" | Generalization gap widens with batch size (1-4% typical). Maximum might hit accuracy target. | Use 80% of max, validate on validation set, adjust if needed. | | "Linear scaling rule means I don't validate" | Scaling rule gives convergence rate. Accuracy still varies with batch size due to generalization gap. | Always validate test/validation accuracy with new batch. | | "Gradient accumulation is slow, don't use it" | True, it's slower (1.5-2x). But if memory is bottleneck, only alternative. Choose based on constraints. | Use when memory constrained. Accept slowdown. Don't use if memory OK. | | "I don't need warmup, I'll just use scaled LR" | Large LR jumps cause divergence. Warmup prevents this. | Add linear warmup phase for scaled LR. | | "My paper used batch X, I'll use that" | Different model, data, hardware converge differently. Paper batch might not be optimal for you. | Use paper as starting point. Validate and adjust for your setup. | | "Fine-tuning uses same batch as pre-training" | Fine-tuning needs much smaller batch (preserve knowledge). Using pre-training batch erases pre-training. | Use batch 10-20x smaller than pre-training. Use tiny LR. | | "Batch size only affects speed, not accuracy" | Batch size strongly affects generalization (1-4% gap common). Different final accuracy with different batch. | Expect accuracy variation with batch. Validate at each batch size. | | "I increased batch, why is training slower?" | Fewer iterations (good) but longer per-iteration (bad). Total wall-clock = iterations × time-per-iteration. | Profile actual wall-clock time. May need gradient accumulation. | | "I'll start with large batch to save memory" | Large batch → bad generalization early → harder to recover later. Start small, increase if needed. | Start with batch 32-64, increase during training if memory allows. | ## Comprehensive Example: Training a Vision Transformer Let's put it all together with a real example: ```python import torch from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from torch.optim import AdamW from torchvision import models, datasets, transforms def train_vision_transformer_optimized(): """ Complete example: training Vision Transformer with batch size optimization. 
""" # Step 1: Model and data device = torch.device("cuda:0") model = models.vit_b_16(pretrained=False).to(device) criterion = torch.nn.CrossEntropyLoss() # Dataset (ImageNet-scale) transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) # Step 2: Determine batch size # ViT-Base: 86M parameters # GPU: 40GB A100 # Memory estimate: params (344MB) + optimizer (688MB) + activations # Can fit batch 256-512 base_batch = 256 num_accumulation_steps = 1 # Can fit directly effective_batch = base_batch # Step 3: Initialize optimizer with scaled LR # Base LR tuned for batch 256 base_lr = 1e-4 scaled_lr = base_lr * (effective_batch / 256) # 1e-4 (no scaling needed) optimizer = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.05) # Step 4: Warmup scheduler warmup_steps = 1000 total_steps = 100 * len(dataset) // effective_batch def warmup_cosine_schedule(step): if step < warmup_steps: return float(step) / float(max(1, warmup_steps)) return 0.5 * (1.0 + torch.cos( torch.tensor(3.14159) * (step - warmup_steps) / (total_steps - warmup_steps) )).item() scheduler = LambdaLR(optimizer, warmup_cosine_schedule) # Step 5: Training loop with gradient accumulation (even though not needed) # Good practice for larger models model.train() optimizer.zero_grad() for epoch in range(100): for step, (images, labels) in enumerate(train_loader): images = images.to(device) labels = labels.to(device) # Forward + backward logits = model(images) loss = criterion(logits, labels) loss = loss / num_accumulation_steps loss.backward() # Update on accumulation step if (step + 1) % num_accumulation_steps == 0: # Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() if step % 100 == 0: print(f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, " f"lr {optimizer.param_groups[0]['lr']:.2e}") # Validate every epoch val_accuracy = validate(model, device) print(f"Epoch {epoch} validation accuracy: {val_accuracy:.2%}") return model # Key patterns demonstrated: # 1. Batch size chosen based on memory (80% of max) # 2. Learning rate scaled for batch size # 3. Warmup phase for gradual LR increase # 4. Cosine annealing for LR decay # 5. Gradient accumulation structure (even if not needed) # 6. Gradient clipping for stability # 7. Regular validation to monitor accuracy ``` ## Summary: Batch Size and Memory Decision Making **The core principle:** Batch size is a system design choice affecting convergence, generalization, speed, and memory simultaneously. There is no universal "right" batch size - it depends on your constraints and priorities. **The decision process:** 1. **Memory constraint**: Start with 80% of maximum batch 2. **Convergence**: Scale learning rate 1:1 with batch increase (with warmup) 3. **Generalization**: Validate accuracy, reduce if gap >2% (or add regularization) 4. **Performance**: Profile wall-clock time at different batch sizes 5. 
**Architecture**: Different models have different optimal batches **The key insights:** - Larger batch = faster iterations but worse generalization - Linear scaling rule requires warmup for large increases - Gradient accumulation is a legitimate technique (understand slowdown cost) - Fine-tuning requires smaller batch than pre-training - Distributed training needs care with batch norm and gradient updates - Always measure, validate, and adjust - don't assume rules apply to your case **The testing approach:** When pressure-tested, this skill should: - Explain why maximum batch ≠ optimal batch (generalization gap) - Provide concrete examples of linear scaling rule with warmup - Address gradient accumulation systematically (when, why, cost) - Discuss memory estimation and optimization techniques - Help select batch size based on constraints AND priorities - Resist rationalizations and always recommend validation ## References and Further Reading **Key papers:** - Goyal et al. (2017) "Accurate, Large Batch Training" - Linear scaling rule - You et al. (2019) "Large Batch Optimization for Deep Learning" - Theory - Smith et al. (2017) "Don't Decay the Learning Rate" - Learning rate schedules **Techniques mentioned:** - Batch Normalization: Ioffe & Szegedy (2015) - Layer Normalization: Ba et al. (2016) - Mixed Precision Training: Micikevicius et al. (2017) - Gradient Checkpointing: Chen et al. (2016) **Related Yzmir Skills:** - `learning-rate-scheduling` - LR schedule choices beyond linear scaling - `gradient-management` - Gradient clipping and accumulation for stability - `optimization-algorithms` - Optimizer selection and hyperparameter tuning