commit 52ec22cfac8965cbbd6156131aed5c4b0b893ba3 Author: Zhongwei Li Date: Sun Nov 30 09:00:03 2025 +0800 Initial commit diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json new file mode 100644 index 0000000..7e8626c --- /dev/null +++ b/.claude-plugin/plugin.json @@ -0,0 +1,12 @@ +{ + "name": "yzmir-pytorch-engineering", + "description": "PyTorch mastery - tensors, modules, distributed training, profiling - 9 skills", + "version": "1.0.1", + "author": { + "name": "tachyon-beep", + "url": "https://github.com/tachyon-beep" + }, + "skills": [ + "./skills" + ] +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..0e345e7 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# yzmir-pytorch-engineering + +PyTorch mastery - tensors, modules, distributed training, profiling - 9 skills diff --git a/plugin.lock.json b/plugin.lock.json new file mode 100644 index 0000000..d782066 --- /dev/null +++ b/plugin.lock.json @@ -0,0 +1,77 @@ +{ + "$schema": "internal://schemas/plugin.lock.v1.json", + "pluginId": "gh:tachyon-beep/skillpacks:plugins/yzmir-pytorch-engineering", + "normalized": { + "repo": null, + "ref": "refs/tags/v20251128.0", + "commit": "c45b1555fb4c8666e53c2b16c586491a1a5d0ee5", + "treeHash": "11918571831c1e52d109e8e9c742724b7031eb7d1007e15048993003e618372e", + "generatedAt": "2025-11-28T10:28:34.440590Z", + "toolVersion": "publish_plugins.py@0.2.0" + }, + "origin": { + "remote": "git@github.com:zhongweili/42plugin-data.git", + "branch": "master", + "commit": "aa1497ed0949fd50e99e70d6324a29c5b34f9390", + "repoRoot": "/Users/zhongweili/projects/openmind/42plugin-data" + }, + "manifest": { + "name": "yzmir-pytorch-engineering", + "description": "PyTorch mastery - tensors, modules, distributed training, profiling - 9 skills", + "version": "1.0.1" + }, + "content": { + "files": [ + { + "path": "README.md", + "sha256": "edbafa09f91147fc403a0decd006b74c60f5bae5f7003694a1e59cba4c80a475" + }, + { + "path": ".claude-plugin/plugin.json", + "sha256": "aecde7740c7f61fd33ec95110f086f29aa136434bddb8ef4d78f1be539062845" + }, + { + "path": "skills/using-pytorch-engineering/custom-autograd-functions.md", + "sha256": "8edecd99f3c62600b6fee4bb3f2bd4b44a8b30d9e29a0263fac6c3b6665be882" + }, + { + "path": "skills/using-pytorch-engineering/module-design-patterns.md", + "sha256": "89a1eabb0dc7ab560d1f68c164c7a8fa389e2c1a0492e9720dc974d4d86d0ec3" + }, + { + "path": "skills/using-pytorch-engineering/mixed-precision-and-optimization.md", + "sha256": "15a5be06a2199d0883c25c70c7be7fa662ccf01f4a5c33ec59bc64fbad11666b" + }, + { + "path": "skills/using-pytorch-engineering/checkpointing-and-reproducibility.md", + "sha256": "940e6fd7f616418a9884c41d8d1e63ce56f038731c793b071cc08067e41bb7e2" + }, + { + "path": "skills/using-pytorch-engineering/debugging-techniques.md", + "sha256": "7c8992277f22d37f4b4f8d091db348368f22d3f83a4bb7cd00d4e6ae8016da9e" + }, + { + "path": "skills/using-pytorch-engineering/distributed-training-strategies.md", + "sha256": "d6defa1eabb4667825b3fb5e2c63330a2300d5824e0c33d5af88d1058c7b9cbe" + }, + { + "path": "skills/using-pytorch-engineering/SKILL.md", + "sha256": "c1be23b09cb4863234b3b9cb915260adbb0580c061be3f7961320bdd79567c30" + }, + { + "path": "skills/using-pytorch-engineering/tensor-operations-and-memory.md", + "sha256": "1809a55e08c1bfb8254611eef381711ba7b771573d49e158265a0e1c80b6e3b3" + }, + { + "path": "skills/using-pytorch-engineering/performance-profiling.md", + "sha256": "87684ad0aeebf1b274d74cc98d83ac033490df027edf5a1281996824550cfce4" + } + ], + "dirSha256": "11918571831c1e52d109e8e9c742724b7031eb7d1007e15048993003e618372e" + }, + "security": { + "scannedAt": null, + "scannerVersion": null, + "flags": [] + } +} \ No newline at end of file diff --git a/skills/using-pytorch-engineering/SKILL.md b/skills/using-pytorch-engineering/SKILL.md new file mode 100644 index 0000000..bc02560 --- /dev/null +++ b/skills/using-pytorch-engineering/SKILL.md @@ -0,0 +1,383 @@ +--- +name: using-pytorch-engineering +description: Routes to appropriate PyTorch specialist skill based on symptoms and problem type +mode: true +--- + +# Using PyTorch Engineering + +## Overview + +This meta-skill routes you to the right PyTorch specialist based on symptoms. PyTorch engineering problems fall into distinct categories that require specialized knowledge. Load this skill when you encounter PyTorch-specific issues but aren't sure which specialized skill to use. + +**Core Principle**: Different PyTorch problems require different specialists. Match symptoms to the appropriate specialist skill. Don't guess at solutions—route to the expert. + +## When to Use + +Load this skill when: +- Working with PyTorch and encountering problems +- User mentions: "PyTorch", "torch", "CUDA", "GPU", "distributed training" +- Need to implement PyTorch models or optimize performance +- Debugging PyTorch training issues +- Setting up production PyTorch infrastructure + +**Don't use for**: Framework-agnostic ML theory, non-PyTorch frameworks, algorithm selection (use training-optimization or other packs) + +--- + +## Routing by Symptom + +### Memory Issues + +**Symptoms**: +- "CUDA out of memory" +- "OOM error" +- "RuntimeError: CUDA out of memory" +- "GPU memory usage too high" +- "tensor memory leak" +- "memory consumption increasing" + +**Route to**: See [tensor-operations-and-memory.md](tensor-operations-and-memory.md) for memory management and optimization. + +**Why**: Memory management is foundational. Must understand tensor lifecycles, efficient operations, and profiling before other optimizations. + +**Example queries**: +- "Getting OOM after a few batches" +- "How to reduce memory usage?" +- "Memory grows over time during training" + +--- + +### Module and Model Design + +**Symptoms**: +- "How to structure my PyTorch model?" +- "Custom layer implementation" +- "nn.Module best practices" +- "Forward/backward pass design" +- "Model architecture implementation" +- "Parameter initialization" + +**Route to**: See [module-design-patterns.md](module-design-patterns.md) for model architecture and nn.Module patterns. + +**Why**: Proper module design prevents bugs and enables features like checkpointing, distributed training, and serialization. + +**Example queries**: +- "Building custom ResNet variant" +- "How to organize model components?" +- "Module initialization best practices" + +--- + +### Distributed Training Setup + +**Symptoms**: +- "Multiple GPUs" +- "DistributedDataParallel" +- "DDP" +- "Multi-node training" +- "Scale training to N GPUs" +- "torch.distributed" +- "NCCL" + +**Route to**: See [distributed-training-strategies.md](distributed-training-strategies.md) for DDP setup and multi-GPU training. + +**Why**: Distributed training has unique setup requirements, synchronization patterns, and pitfalls. Generic advice breaks in distributed settings. + +**Example queries**: +- "Setup DDP for 8 GPUs" +- "Multi-node training not working" +- "How to launch distributed training?" + +--- + +### Performance and Speed + +**Symptoms**: +- "Training too slow" +- "Low GPU utilization" +- "Iterations per second" +- "Throughput" +- "Performance optimization" +- "Speed up training" + +**Route to**: See [performance-profiling.md](performance-profiling.md) FIRST for systematic bottleneck identification. + +**Why**: MUST profile before optimizing. Many "performance" problems are actually data loading or other non-compute bottlenecks. Profile to identify the real bottleneck. + +**After profiling**, may route to: +- [mixed-precision-and-optimization.md](mixed-precision-and-optimization.md) if compute-bound +- [tensor-operations-and-memory.md](tensor-operations-and-memory.md) if memory-bound +- [distributed-training-strategies.md](distributed-training-strategies.md) if need to scale + +**Example queries**: +- "Training is slow, how to speed up?" +- "GPU usage is only 30%" +- "Bottleneck in my training loop" + +--- + +### Mixed Precision and Optimization + +**Symptoms**: +- "Mixed precision" +- "FP16", "BF16" +- "torch.cuda.amp" +- "Automatic mixed precision" +- "AMP" +- "TF32" + +**Route to**: See [mixed-precision-and-optimization.md](mixed-precision-and-optimization.md) for AMP and numerical stability. + +**Why**: Mixed precision requires careful handling of numerical stability, gradient scaling, and operation compatibility. + +**Example queries**: +- "How to use mixed precision training?" +- "AMP causing NaN losses" +- "FP16 vs BF16 for my model" + +--- + +### Training Instability and NaN + +**Symptoms**: +- "NaN loss" +- "Inf gradients" +- "Loss exploding" +- "Training becomes unstable" +- "Gradients are NaN" +- "Model diverging" + +**Route to**: See [debugging-techniques.md](debugging-techniques.md) for systematic NaN/Inf debugging. + +**Why**: NaN/Inf issues require systematic debugging—checking gradients layer by layer, identifying numerical instability sources, and targeted fixes. + +**Example queries**: +- "Loss becomes NaN after epoch 3" +- "How to debug gradient explosion?" +- "Model outputs Inf values" + +--- + +### Checkpointing and State Management + +**Symptoms**: +- "Save model" +- "Resume training" +- "Checkpoint" +- "Reproducible training" +- "Save optimizer state" +- "Load pretrained weights" + +**Route to**: See [checkpointing-and-reproducibility.md](checkpointing-and-reproducibility.md) for complete state management. + +**Why**: Proper checkpointing requires saving ALL state (model, optimizer, scheduler, RNG states). Reproducibility requires deterministic operations and careful seed management. + +**Example queries**: +- "How to checkpoint training properly?" +- "Resume from checkpoint" +- "Make training reproducible" + +--- + +### Custom Operations and Autograd + +**Symptoms**: +- "Custom backward pass" +- "torch.autograd.Function" +- "Define custom gradient" +- "Efficient custom operation" +- "Non-differentiable operation" +- "Custom CUDA kernel" + +**Route to**: See [custom-autograd-functions.md](custom-autograd-functions.md) for custom backward passes. + +**Why**: Custom autograd functions require understanding the autograd engine, proper gradient computation, and numerical stability. + +**Example queries**: +- "Implement custom activation with gradient" +- "Efficient backwards pass for my operation" +- "How to use torch.autograd.Function?" + +--- + +## Cross-Cutting Scenarios + +### Multiple Skills Needed + +Some scenarios require multiple specialized skills in sequence: + +**Distributed training with memory constraints**: +1. Route to [distributed-training-strategies.md](distributed-training-strategies.md) (setup) +2. THEN [tensor-operations-and-memory.md](tensor-operations-and-memory.md) (optimize per-GPU memory) + +**Performance optimization**: +1. Route to [performance-profiling.md](performance-profiling.md) (identify bottleneck) +2. THEN appropriate skill based on bottleneck: + - Compute → [mixed-precision-and-optimization.md](mixed-precision-and-optimization.md) + - Memory → [tensor-operations-and-memory.md](tensor-operations-and-memory.md) + - Scale → [distributed-training-strategies.md](distributed-training-strategies.md) + +**Custom module with proper patterns**: +1. Route to [module-design-patterns.md](module-design-patterns.md) (structure) +2. THEN [custom-autograd-functions.md](custom-autograd-functions.md) if custom backward needed + +**Training instability with mixed precision**: +1. Route to [debugging-techniques.md](debugging-techniques.md) (diagnose root cause) +2. May need [mixed-precision-and-optimization.md](mixed-precision-and-optimization.md) for gradient scaling + +**Load in order of execution**: Setup before optimization, diagnosis before fixes, structure before customization. + +--- + +## Ambiguous Queries - Ask First + +When symptom unclear, ASK ONE clarifying question: + +**"Fix my PyTorch training"** +→ Ask: "What specific issue? Memory? Speed? Accuracy? NaN?" + +**"Optimize my model"** +→ Ask: "Optimize what? Training speed? Memory usage? Inference?" + +**"Setup distributed training"** +→ Ask: "Single-node multi-GPU or multi-node? What's not working?" + +**"Model not working"** +→ Ask: "What's broken? Training fails? Wrong outputs? Performance?" + +**Never guess when ambiguous. Ask once, route accurately.** + +--- + +## Common Routing Mistakes + +| Symptom | Wrong Route | Correct Route | Why | +|---------|-------------|---------------|-----| +| "Training slow" | mixed-precision | performance-profiling FIRST | Don't optimize without profiling | +| "OOM in distributed" | tensor-memory | distributed-strategies FIRST | Distributed setup might be wrong | +| "Custom layer slow" | performance-profiling | module-design-patterns FIRST | Design might be inefficient | +| "NaN with AMP" | mixed-precision | debugging-techniques FIRST | Debug NaN source, then fix AMP | +| "Save model" | module-design | checkpointing FIRST | Checkpointing is specialized topic | + +**Key principle**: Diagnosis before solutions, setup before optimization, root cause before fixes. + +--- + +## Red Flags - Stop and Route + +If you catch yourself about to: +- Suggest reducing batch size → Route to [tensor-operations-and-memory.md](tensor-operations-and-memory.md) for systematic approach +- Show basic DDP code → Route to [distributed-training-strategies.md](distributed-training-strategies.md) for complete setup +- Guess at optimizations → Route to [performance-profiling.md](performance-profiling.md) to measure first +- List possible NaN fixes → Route to [debugging-techniques.md](debugging-techniques.md) for diagnostic methodology +- Show torch.save example → Route to [checkpointing-and-reproducibility.md](checkpointing-and-reproducibility.md) for complete solution + +**All of these mean: You're about to give incomplete advice. Route to the specialist instead.** + +--- + +## Common Rationalizations (Don't Do These) + +| Excuse | Reality | What To Do | +|--------|---------|------------| +| "User is rushed, skip routing" | Routing takes 5 seconds. Wrong fix wastes minutes. | Route anyway - specialists have quick diagnostics | +| "They already tried X" | May have done X wrong, misunderstood, or X wasn't applicable. | Route to specialist to verify X was done correctly | +| "Authority/senior says Y" | Authority can misdiagnose bottlenecks without profiling. | Profile first, authority second. Respect skills over seniority. | +| "User is tired, don't ask" | Exhaustion makes clarity MORE important, not less. | Ask ONE clarifying question - saves time overall | +| "User suggested Z" | Z might not be best option for their specific case. | Route to specialist to evaluate if Z is right approach | +| "Too complex, can't route" | Complex scenarios need specialists MORE, not less. | Use cross-cutting section - route to multiple skills in sequence | +| "User sounds confident" | Confidence about custom autograd often precedes subtle bugs. | Route to specialist for systematic verification | +| "Just a quick question" | No such thing - symptoms need diagnosis. | Quick questions deserve correct answers - route properly | +| "Simple issue" | Simple symptoms can have complex root causes. | Route based on symptoms, not perceived complexity | +| "Direct answer is helpful" | Wrong direct answer wastes time and frustrates user. | Routing to specialist IS the helpful answer | + +**If you catch yourself thinking ANY of these, STOP and route to the specialist.** + +--- + +## Red Flags Checklist - Self-Check Before Answering + +Before giving ANY PyTorch advice, ask yourself: + +1. ❓ **Did I identify the symptom?** + - If no → Read query again, identify symptoms + +2. ❓ **Is this symptom in my routing table?** + - If yes → Route to that specialist + - If no → Ask clarifying question + +3. ❓ **Am I about to give advice directly?** + - If yes → STOP. Why am I not routing? + - Check rationalization table - am I making excuses? + +4. ❓ **Is this a diagnosis issue or solution issue?** + - Diagnosis → Route to profiling/debugging skill FIRST + - Solution → Route to appropriate implementation skill + +5. ❓ **Is query ambiguous?** + - If yes → Ask ONE clarifying question + - If no → Route confidently + +6. ❓ **Am I feeling pressure to skip routing?** + - Time pressure → Route anyway (faster overall) + - Sunk cost → Route anyway (verify first attempt) + - Authority → Route anyway (verify diagnosis) + - Exhaustion → Route anyway (clarity more important) + +**If you failed ANY check above, do NOT give direct advice. Route to specialist or ask clarifying question.** + +--- + +## When NOT to Use PyTorch Skills + +**Skip PyTorch pack when**: +- Choosing algorithms (use training-optimization or algorithm packs) +- Model architecture selection (use neural-architectures) +- Framework-agnostic training issues (use training-optimization) +- Production deployment (use ml-production) + +**PyTorch pack is for**: PyTorch-specific implementation, infrastructure, debugging, and optimization issues. + +--- + +## Diagnosis-First Principle + +**Critical**: Many PyTorch issues require diagnosis before solutions: + +| Issue Type | Diagnosis Skill | Then Solution Skill | +|------------|----------------|---------------------| +| Performance | performance-profiling | mixed-precision / distributed | +| Memory | tensor-memory (profiling section) | tensor-memory (optimization) | +| NaN/Inf | debugging-techniques | mixed-precision / module-design | +| Training bugs | debugging-techniques | Appropriate fix | + +**If unclear what's wrong, route to diagnostic skill first.** + +--- + +## PyTorch Engineering Specialist Skills + +After routing, load the appropriate specialist skill for detailed guidance: + +1. [tensor-operations-and-memory.md](tensor-operations-and-memory.md) - Memory management, efficient operations, profiling +2. [module-design-patterns.md](module-design-patterns.md) - Model structure, nn.Module best practices, initialization +3. [distributed-training-strategies.md](distributed-training-strategies.md) - DDP setup, multi-node, synchronization patterns +4. [mixed-precision-and-optimization.md](mixed-precision-and-optimization.md) - AMP, FP16/BF16, gradient scaling, numerical stability +5. [performance-profiling.md](performance-profiling.md) - PyTorch profiler, bottleneck identification, optimization strategies +6. [debugging-techniques.md](debugging-techniques.md) - NaN/Inf debugging, gradient checking, systematic troubleshooting +7. [checkpointing-and-reproducibility.md](checkpointing-and-reproducibility.md) - Complete checkpointing, RNG state, determinism +8. [custom-autograd-functions.md](custom-autograd-functions.md) - torch.autograd.Function, custom gradients, efficient backward + +--- + +## Integration Notes + +**Phase 1 - Standalone**: PyTorch skills are self-contained + +**Future cross-references**: +- training-optimization (framework-agnostic training techniques) +- neural-architectures (architecture selection before implementation) +- ml-production (deployment after training) + +**Current focus**: Route within PyTorch pack only. Other packs handle other concerns. diff --git a/skills/using-pytorch-engineering/checkpointing-and-reproducibility.md b/skills/using-pytorch-engineering/checkpointing-and-reproducibility.md new file mode 100644 index 0000000..111b590 --- /dev/null +++ b/skills/using-pytorch-engineering/checkpointing-and-reproducibility.md @@ -0,0 +1,1925 @@ + +# Complete Checkpointing and Reproducibility + +## Overview + +**Core Principle:** Incomplete checkpoints cause training divergence on resume. Complete checkpoints include ALL state (model, optimizer, scheduler, epoch, RNG states) needed to continue training from the exact point it stopped. Partial reproducibility from setting one seed is false confidence - true reproducibility requires seeds across PyTorch, CUDA, NumPy, Python, cuDNN settings, and environment variables. In DDP, only rank 0 saves; all ranks load. Strategic checkpoint management (best, last, periodic with cleanup) prevents disk overflow while ensuring recovery capability. + +Checkpoint failures stem from: incomplete state (missing optimizer momentum, wrong learning rate on resume), false reproducibility (partial seeding, non-deterministic cuDNN), DDP corruption (all ranks saving simultaneously), resume logic errors (off-by-one epoch, missing RNG states), version incompatibility (no migration strategy), or poor storage management (disk overflow, no cleanup). Each component has dependencies: optimizer state depends on scheduler state, RNG states affect data order and augmentation, DDP requires rank synchronization. Skipping any component breaks training continuity. + +## When to Use + +**Use this skill when:** +- Setting up training checkpointing (first-time implementation) +- Resuming training from a checkpoint (crashed, paused, or continuing) +- Debugging training divergence after resume (loss jumps, unstable) +- Ensuring reproducible results (experiments, ablations, paper reproducibility) +- Implementing DDP checkpointing (multi-GPU training) +- Managing checkpoint storage (disk space, cleanup policy) +- Loading checkpoints across PyTorch versions (migration) +- Debugging non-deterministic behavior (results vary across runs) + +**Don't use when:** +- Model export for inference (use torch.jit or ONNX, not training checkpoints) +- Saving only for transfer learning (can save model-only, but document this clearly) +- Performance profiling checkpointing overhead (use performance-profiling) +- Distributed training setup (use distributed-training-strategies, though DDP checkpointing overlaps) + +**Symptoms triggering this skill:** +- "Training diverges after resuming from checkpoint" +- "Results not reproducible despite setting torch.manual_seed" +- "Checkpoint loading gives 'unexpected keys' or 'missing keys'" +- "DDP checkpoint corrupted or inconsistent" +- "Loss jumps when resuming training" +- "How do I save/resume training correctly?" +- "Running out of disk space from checkpoints" +- "Need to reproduce exact training results" + + +## Complete Checkpoint Strategy + +### The Complete Checkpoint + +**Critical Rule:** A checkpoint is NOT just the model. It must contain ALL state needed to resume training exactly where it stopped. + +**Minimum Complete Checkpoint (7 required components):** + +```python +import torch +import numpy as np +import random + +def save_checkpoint( + epoch: int, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + loss: float, + checkpoint_path: str, + **kwargs # Additional optional components +) -> None: + """Save complete training checkpoint. + + Args: + epoch: Current epoch number (0-indexed, saved after completion) + model: Model to checkpoint + optimizer: Optimizer with momentum buffers, etc. + scheduler: Learning rate scheduler state + loss: Current loss value (for reference) + checkpoint_path: Path to save checkpoint + **kwargs: Additional components (scaler, best_metric, config, etc.) + """ + checkpoint = { + # 1. Epoch number (critical for resume logic) + 'epoch': epoch, + + # 2. Model state (parameters and buffers) + 'model_state_dict': model.state_dict(), + + # 3. Optimizer state (momentum buffers, adaptive learning rates) + 'optimizer_state_dict': optimizer.state_dict(), + + # 4. Scheduler state (learning rate schedule position) + 'scheduler_state_dict': scheduler.state_dict(), + + # 5. Loss value (for reference and validation) + 'loss': loss, + + # 6. PyTorch RNG state (CPU) + 'rng_state': torch.get_rng_state(), + + # 7. CUDA RNG state (all GPU devices) + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + } + + # Additional recommended components + checkpoint.update({ + # NumPy RNG state (for data augmentation) + 'numpy_rng_state': np.random.get_state(), + + # Python RNG state (for any Python random operations) + 'python_rng_state': random.getstate(), + + # Add any kwargs passed in + **kwargs + }) + + # Save checkpoint + torch.save(checkpoint, checkpoint_path) + + # Validate checkpoint was saved correctly + if not validate_checkpoint(checkpoint_path): + raise RuntimeError(f"Checkpoint validation failed: {checkpoint_path}") + +def validate_checkpoint(checkpoint_path: str) -> bool: + """Validate checkpoint integrity after saving. + + Returns: + True if checkpoint is valid, False otherwise + """ + try: + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Check required keys + required_keys = [ + 'epoch', 'model_state_dict', 'optimizer_state_dict', + 'scheduler_state_dict', 'loss', 'rng_state', 'cuda_rng_state' + ] + missing = [k for k in required_keys if k not in checkpoint] + if missing: + print(f"Missing required keys: {missing}") + return False + + # Basic sanity checks + if not isinstance(checkpoint['epoch'], int): + print("Epoch is not an integer") + return False + + if not isinstance(checkpoint['model_state_dict'], dict): + print("model_state_dict is not a dict") + return False + + return True + + except Exception as e: + print(f"Checkpoint validation error: {e}") + return False +``` + +**Why each component is critical:** + +1. **epoch**: Resume logic needs to know which epoch was completed. Off-by-one errors cause re-running epochs, disrupting training trajectory. + +2. **model_state_dict**: Obviously needed. Use `state_dict()` not the model itself (state_dict is portable, model object is not). + +3. **optimizer_state_dict**: Contains momentum buffers (SGD momentum, Adam first/second moments), adaptive learning rates (per-parameter state). Without this, optimizer effectively resets, causing training divergence. SGD without momentum buffers is NOT the same as SGD with momentum - convergence behavior changes dramatically. + +4. **scheduler_state_dict**: Contains current step count, learning rate values. Without this, scheduler resets to epoch 0, causing learning rate to jump back to initial value. Example: if training at epoch 50 with LR=0.001 after decay, missing scheduler state resets LR to 0.1, causing instability. + +5. **loss**: Reference value for validation. After loading checkpoint, running validation should yield approximately this loss. If not, checkpoint may be corrupted or loaded incorrectly. + +6. **rng_state** (PyTorch CPU): Controls PyTorch CPU random operations (initialization, dropout on CPU). Without this, random operations differ on resume, breaking reproducibility. + +7. **cuda_rng_state**: Controls CUDA random operations (dropout on GPU, random initialization on GPU). Must save ALL GPU states, not just current device. Use `get_rng_state_all()` not `get_rng_state()`. + +**Additional recommended components:** + +```python +# When using mixed precision training +if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + +# Track best validation metric +checkpoint['best_metric'] = best_val_loss # or best_val_accuracy + +# Save global step counter (for step-based logging, schedules) +checkpoint['global_step'] = global_step + +# Save configuration for reference +checkpoint['config'] = { + 'learning_rate': lr, + 'batch_size': batch_size, + 'model_architecture': 'ResNet50', + # ... other hyperparameters +} + +# Save PyTorch version for compatibility checking +checkpoint['pytorch_version'] = torch.__version__ + +# Save timestamp +from datetime import datetime +checkpoint['timestamp'] = datetime.now().isoformat() +``` + + +### Complete Resume Logic + +**Critical Rule:** Checkpoint at epoch N means "completed epoch N". Resume at epoch N+1, not N. + +```python +def load_checkpoint( + checkpoint_path: str, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + device: torch.device, + scaler: Optional[torch.cuda.amp.GradScaler] = None +) -> dict: + """Load complete training checkpoint and restore all state. + + Args: + checkpoint_path: Path to checkpoint file + model: Model to load state into + optimizer: Optimizer to load state into + scheduler: Scheduler to load state into + device: Device to map checkpoint to + scaler: Optional GradScaler for mixed precision + + Returns: + dict with resume info: start_epoch, best_metric, etc. + """ + # Load checkpoint + # map_location ensures checkpoint loads regardless of save device + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Load model state + model.load_state_dict(checkpoint['model_state_dict']) + + # Load optimizer state + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + # Load scheduler state + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + # Restore RNG states for reproducibility + torch.set_rng_state(checkpoint['rng_state'].cpu()) # Ensure CPU tensor + + if torch.cuda.is_available() and 'cuda_rng_state' in checkpoint: + torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state']) + + if 'numpy_rng_state' in checkpoint: + np.random.set_state(checkpoint['numpy_rng_state']) + + if 'python_rng_state' in checkpoint: + random.setstate(checkpoint['python_rng_state']) + + # Load scaler if using mixed precision + if scaler is not None and 'scaler_state_dict' in checkpoint: + scaler.load_state_dict(checkpoint['scaler_state_dict']) + + # Calculate start epoch (CRITICAL: checkpoint at epoch N means resume at N+1) + start_epoch = checkpoint['epoch'] + 1 + + # Extract other useful info + resume_info = { + 'start_epoch': start_epoch, + 'checkpoint_loss': checkpoint['loss'], + 'best_metric': checkpoint.get('best_metric', None), + 'global_step': checkpoint.get('global_step', 0), + } + + print(f"Loaded checkpoint from epoch {checkpoint['epoch']}") + print(f"Resuming training from epoch {start_epoch}") + print(f"Checkpoint loss: {checkpoint['loss']:.4f}") + + return resume_info + +# Usage in training loop +if args.resume_from_checkpoint: + resume_info = load_checkpoint( + checkpoint_path=args.checkpoint_path, + model=model, + optimizer=optimizer, + scheduler=scheduler, + device=device, + scaler=scaler + ) + start_epoch = resume_info['start_epoch'] + best_val_loss = resume_info['best_metric'] + global_step = resume_info['global_step'] + + # Validate checkpoint by running validation + val_loss = validate(model, val_loader, criterion, device) + print(f"Validation loss after loading: {val_loss:.4f}") + print(f"Checkpoint validation loss: {resume_info['checkpoint_loss']:.4f}") + + if abs(val_loss - resume_info['checkpoint_loss']) > 0.1: + print("WARNING: Validation loss differs significantly from checkpoint!") + print("Checkpoint may be corrupted or validation set changed.") +else: + start_epoch = 0 + best_val_loss = float('inf') + global_step = 0 + +# Training loop starts from start_epoch +for epoch in range(start_epoch, args.num_epochs): + train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Scheduler step (scheduler already at correct position from checkpoint) + scheduler.step() + + # Update global step + global_step += len(train_loader) + + # Save checkpoint + if epoch % args.checkpoint_interval == 0: + save_checkpoint( + epoch=epoch, + model=model, + optimizer=optimizer, + scheduler=scheduler, + loss=val_loss, + checkpoint_path=f'checkpoint_epoch_{epoch}.pt', + best_metric=best_val_loss, + global_step=global_step + ) +``` + +**Common Resume Mistakes:** + +```python +# ❌ WRONG: Starting at checkpoint epoch (re-runs last epoch) +start_epoch = checkpoint['epoch'] # If checkpoint is epoch 40, this starts at 40 again! + +# ✅ CORRECT: Starting at next epoch +start_epoch = checkpoint['epoch'] + 1 # Resume at 41 + +# ❌ WRONG: Not restoring RNG states (data order/augmentation differs) +model.load_state_dict(checkpoint['model_state_dict']) +# Missing: torch.set_rng_state(), np.random.set_state(), etc. + +# ✅ CORRECT: Restore all RNG states +torch.set_rng_state(checkpoint['rng_state']) +torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state']) +np.random.set_state(checkpoint['numpy_rng_state']) +random.setstate(checkpoint['python_rng_state']) + +# ❌ WRONG: Not using map_location (fails if checkpoint saved on different device) +checkpoint = torch.load('checkpoint.pt') # Tries to load to original device + +# ✅ CORRECT: Use map_location for portability +checkpoint = torch.load('checkpoint.pt', map_location=device) + +# ❌ WRONG: Not validating after loading (assume checkpoint is correct) +load_checkpoint(...) +# Start training immediately + +# ✅ CORRECT: Validate checkpoint makes sense +load_checkpoint(...) +val_loss = validate(model, val_loader, criterion, device) +assert abs(val_loss - checkpoint['loss']) < 0.1, "Checkpoint validation failed!" +``` + + +## Complete Reproducibility Setup + +### The Seven Sources of Randomness + +**Critical Rule:** Reproducibility requires controlling ALL sources of randomness, not just `torch.manual_seed()`. Missing even one source breaks reproducibility. + +**Complete seed setting function:** + +```python +import torch +import numpy as np +import random +import os + +def set_seed(seed: int) -> None: + """Set seeds for complete reproducibility across all libraries. + + This controls randomness in: + - Python random module (data shuffling, random choices) + - NumPy (data augmentation, random initialization) + - PyTorch CPU (model initialization, dropout, etc.) + - PyTorch CUDA (GPU operations, dropout on GPU) + - cuDNN (convolution algorithms, some CUDA kernels) + - Python hash randomization (dict/set ordering) + + Note: Some operations are inherently non-deterministic even with seeds set. + See: https://pytorch.org/docs/stable/notes/randomness.html + + Args: + seed: Random seed value (typically 42, 0, 123, etc.) + """ + # 1. Python random module + random.seed(seed) + + # 2. NumPy random + np.random.seed(seed) + + # 3. PyTorch CPU + torch.manual_seed(seed) + + # 4. PyTorch CUDA (current device) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # 5. PyTorch CUDA (all devices, for multi-GPU) + torch.cuda.manual_seed_all(seed) + + # 6. cuDNN deterministic mode + # This makes cuDNN use deterministic algorithms (slower but reproducible) + torch.backends.cudnn.deterministic = True + + # 7. cuDNN benchmark mode + # Disable benchmark mode which uses non-deterministic algorithms for speed + torch.backends.cudnn.benchmark = False + + # 8. Python hash randomization (for dict/set ordering) + os.environ['PYTHONHASHSEED'] = str(seed) + + print(f"Random seed set to {seed}") + print("Note: Deterministic mode enabled (cuDNN benchmark disabled)") + print("Expected ~5-15% performance decrease for reproducibility") + +# Usage: Call BEFORE any model/data operations +set_seed(42) + +# Create model, data loaders, etc. AFTER setting seed +model = MyModel() +train_loader = DataLoader(dataset, shuffle=True, num_workers=4) +# ... +``` + +**Why each source matters:** + +1. **Python random**: Used by data shuffling, random.choice(), random sampling. Without seed, data order varies. + +2. **NumPy random**: Used by many data augmentation libraries (Albumentations, imgaug), random initialization, numpy-based preprocessing. Without seed, augmentation differs. + +3. **PyTorch CPU random**: Controls CPU-based initialization (torch.randn, torch.rand), dropout on CPU, random sampling. Without seed, model initialization varies. + +4. **PyTorch CUDA random**: Controls GPU-based random operations (dropout on GPU, initialization on GPU). Must seed ALL devices, not just current. + +5. **cuDNN deterministic**: cuDNN (NVIDIA's CUDA Deep Neural Network library) uses optimized algorithms for convolutions, pooling, etc. By default, some algorithms are non-deterministic for speed. Setting deterministic=True forces deterministic algorithms (slower but reproducible). + +6. **cuDNN benchmark**: When enabled, cuDNN runs multiple algorithms and picks the fastest (non-deterministic selection). When disabled, uses fixed algorithm (deterministic but potentially slower). + +7. **PYTHONHASHSEED**: Python 3.3+ uses randomized hash seeds for security. This affects dict/set iteration order. Setting environment variable ensures consistent ordering. + + +### DataLoader Worker Seeding + +**Critical Issue:** DataLoader with `num_workers > 0` spawns subprocesses, each with its own random state. Without proper seeding, workers produce different random augmentations across runs, breaking reproducibility. + +```python +def seed_worker(worker_id: int) -> None: + """Seed each DataLoader worker for reproducibility. + + Called by DataLoader for each worker subprocess. + Without this, each worker has random seed, breaking reproducibility. + + Args: + worker_id: Worker process ID (0 to num_workers-1) + """ + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +# Create generator for DataLoader reproducibility +g = torch.Generator() +g.manual_seed(42) + +# DataLoader with reproducible workers +train_loader = DataLoader( + dataset, + batch_size=32, + shuffle=True, + num_workers=4, + worker_init_fn=seed_worker, # Seed each worker + generator=g, # Control shuffling randomness +) + +# Now DataLoader produces identical batches across runs +``` + +**Without worker seeding:** +```python +# ❌ WRONG: Workers have random seeds +train_loader = DataLoader( + dataset, + batch_size=32, + shuffle=True, + num_workers=4, # Each worker has different random state! +) +# Data augmentation varies across workers and across runs +# Results NOT reproducible +``` + + +### Non-Deterministic Operations + +**Critical Awareness:** Some PyTorch operations are inherently non-deterministic, even with all seeds set and cuDNN deterministic mode enabled. + +**Known non-deterministic operations:** + +```python +# 1. Atomic operations (CUDA) +# Some CUDA operations use atomic operations (atomicAdd) which are non-deterministic +# when multiple threads access the same memory location + +# Example: torch.nn.functional.grid_sample with bilinear interpolation +output = F.grid_sample(input, grid, mode='bilinear') # Non-deterministic on CUDA + +# 2. Backward pass through some operations +# Some operations have deterministic forward but non-deterministic backward + +# Example: torch.nn.functional.interpolate backward +output = F.interpolate(input, scale_factor=2) +output.backward(grad) # Non-deterministic backward + +# 3. Index operations with duplicate indices +x = torch.zeros(10, 10).cuda() +indices = torch.tensor([0, 0, 1]).cuda() # Duplicate index 0 +values = torch.tensor([1.0, 2.0, 3.0]).cuda() +x[indices] = values # Non-deterministic: which value goes to x[0]? + +# 4. Sparse operations +sparse_tensor = torch.sparse_coo_tensor(indices, values, size) +result = sparse_tensor @ dense_tensor # May be non-deterministic + +# 5. torch.nn.DataParallel +# DataParallel has non-deterministic gather operations +model = torch.nn.DataParallel(model) # Non-deterministic! +# Use DistributedDataParallel (DDP) instead for determinism +``` + +**Checking for non-deterministic operations:** + +```python +import os + +# PyTorch 1.11+ provides environment variable to detect non-deterministic ops +os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # or ':16:8' + +# Enable PyTorch deterministic mode (throws error on non-deterministic ops) +torch.use_deterministic_algorithms(True) + +# Now PyTorch will raise error if non-deterministic operation is used +try: + output = model(input) + loss.backward() +except RuntimeError as e: + print(f"Non-deterministic operation detected: {e}") + # Error message tells you which operation is non-deterministic +``` + +**When to accept non-determinism:** + +- **Production training**: Deterministic mode has 5-15% performance cost. For production training where reproducibility is not critical, non-deterministic mode is acceptable. + +- **Ablation studies**: When comparing methods, reproducibility is critical. Use deterministic mode even with performance cost. + +- **Debugging convergence**: If loss is NaN or training is unstable, deterministic mode helps isolate if issue is due to randomness or actual bug. + +- **Paper reproducibility**: When publishing, enable deterministic mode for experiments, document seeds and settings in paper. + +**Performance tradeoff:** + +```python +# Fast (non-deterministic) - for production +torch.backends.cudnn.deterministic = False +torch.backends.cudnn.benchmark = True +# ~10-15% faster training, results vary slightly across runs + +# Reproducible (deterministic) - for experiments +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +# ~5-15% slower training, results identical across runs + +# Hybrid approach: Benchmark once, then use deterministic +torch.backends.cudnn.deterministic = False +torch.backends.cudnn.benchmark = True +# Run for 10 epochs to find best algorithms +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +# Now use deterministic algorithms (benchmark already found them) +``` + + +### Testing Reproducibility + +**Verification protocol:** + +```python +def test_reproducibility( + model_fn: Callable, + data_loader: DataLoader, + num_steps: int = 10 +) -> bool: + """Test if training is reproducible across runs. + + Args: + model_fn: Function that creates and returns model + data_loader: DataLoader to use for training + num_steps: Number of training steps to test + + Returns: + True if reproducible, False otherwise + """ + def train_n_steps(seed: int) -> torch.Tensor: + """Train for N steps and return final loss.""" + set_seed(seed) + model = model_fn() + optimizer = torch.optim.Adam(model.parameters()) + criterion = torch.nn.CrossEntropyLoss() + + losses = [] + data_iter = iter(data_loader) + for _ in range(num_steps): + data, target = next(data_iter) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + losses.append(loss.item()) + + return torch.tensor(losses) + + # Run training twice with same seed + losses_run1 = train_n_steps(seed=42) + losses_run2 = train_n_steps(seed=42) + + # Check if losses are identical + if torch.allclose(losses_run1, losses_run2, atol=1e-7): + print("✓ Training is reproducible!") + return True + else: + print("✗ Training is NOT reproducible") + print(f"Max difference: {(losses_run1 - losses_run2).abs().max().item()}") + print(f"Run 1 losses: {losses_run1}") + print(f"Run 2 losses: {losses_run2}") + return False + +# Usage +reproducible = test_reproducibility( + model_fn=lambda: MyModel(), + data_loader=train_loader, + num_steps=10 +) + +if not reproducible: + print("Check: Are all seeds set? cuDNN deterministic? DataLoader workers seeded?") +``` + + +## DDP Checkpointing + +### Rank 0 Only Saving + +**Critical Rule:** In DistributedDataParallel (DDP), only rank 0 should save checkpoints. All ranks can load, but only rank 0 writes to disk. + +**Why rank 0 only:** +- Multiple processes writing to same file simultaneously causes corruption +- File system race conditions lead to truncated or inconsistent checkpoints +- Even with different file names, NFS/shared filesystems have synchronization issues +- Checkpoint contains same model state across all ranks (DDP synchronizes gradients) + +```python +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +def save_checkpoint_ddp( + epoch: int, + model: DDP, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + loss: float, + checkpoint_path: str, +) -> None: + """Save checkpoint in DDP training (rank 0 only). + + Args: + epoch: Current epoch + model: DDP-wrapped model + optimizer: Optimizer + scheduler: LR scheduler + loss: Current loss + checkpoint_path: Path to save checkpoint + """ + # Synchronize all ranks before checkpointing + # Ensures all ranks have finished training step + if dist.is_initialized(): + dist.barrier() + + # Only rank 0 saves + if not dist.is_initialized() or dist.get_rank() == 0: + checkpoint = { + 'epoch': epoch, + # Use model.module.state_dict() to unwrap DDP + 'model_state_dict': model.module.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': loss, + 'rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + } + + torch.save(checkpoint, checkpoint_path) + print(f"Rank 0: Saved checkpoint to {checkpoint_path}") + + # Wait for rank 0 to finish saving before continuing + if dist.is_initialized(): + dist.barrier() + +# Training loop in DDP +for epoch in range(start_epoch, num_epochs): + # Set epoch for distributed sampler (ensures different shuffle each epoch) + train_sampler.set_epoch(epoch) + + train_loss = train_one_epoch(model, train_loader, optimizer, criterion) + + # All ranks participate in validation + val_loss = validate(model, val_loader, criterion) + + scheduler.step() + + # Checkpoint with rank 0 only + if epoch % checkpoint_interval == 0: + save_checkpoint_ddp( + epoch=epoch, + model=model, + optimizer=optimizer, + scheduler=scheduler, + loss=val_loss, + checkpoint_path=f'checkpoint_epoch_{epoch}.pt' + ) +``` + +**Key DDP checkpoint considerations:** + +1. **model.module.state_dict()**: DDP wraps model with "module." prefix. Use `model.module.state_dict()` to get unwrapped state_dict for portability. When loading, if model is already DDP-wrapped, can load directly. If not wrapped, load into base model then wrap with DDP. + +```python +# Saving: unwrap DDP +checkpoint['model_state_dict'] = model.module.state_dict() + +# Loading option 1: Load into base model, then wrap +model = MyModel() +model.load_state_dict(checkpoint['model_state_dict']) +model = DDP(model) + +# Loading option 2: Load into already-wrapped model +model = DDP(MyModel()) +model.module.load_state_dict(checkpoint['model_state_dict']) + +# Loading option 3: Handle prefix automatically +from torch.nn.parallel import DistributedDataParallel as DDP + +def load_checkpoint_handle_ddp(model, checkpoint_path): + checkpoint = torch.load(checkpoint_path) + state_dict = checkpoint['model_state_dict'] + + # Check if loading into DDP model + if isinstance(model, DDP): + # Check if state_dict has 'module.' prefix + if not any(k.startswith('module.') for k in state_dict.keys()): + # Add prefix + state_dict = {f'module.{k}': v for k, v in state_dict.items()} + else: + # Loading into non-DDP model + # Remove 'module.' prefix if present + if any(k.startswith('module.') for k in state_dict.keys()): + state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + + model.load_state_dict(state_dict) +``` + +2. **dist.barrier()**: Synchronization primitive that makes all ranks wait. Use before saving to ensure all ranks finished training step. Use after saving to ensure rank 0 finished writing before other ranks continue. + +```python +# Before saving: wait for all ranks to finish training +dist.barrier() + +# Rank 0 saves +if dist.get_rank() == 0: + torch.save(checkpoint, path) + +# After saving: wait for rank 0 to finish +dist.barrier() +``` + +3. **Optimizer state in DDP**: For standard optimizers (Adam, SGD), optimizer state is replicated across all ranks (each rank has full state). Saving from rank 0 is sufficient. For ZeRO-style optimizers (DeepSpeed, FSDP), optimizer state is sharded across ranks, requiring special handling. + + +### ZeRO Optimizer Checkpointing + +**Advanced:** When using ZeRO (Zero Redundancy Optimizer) from DeepSpeed or FSDP, optimizer state is sharded across ranks. Each rank has only a portion of optimizer state. Checkpointing requires gathering from all ranks. + +```python +# DeepSpeed ZeRO checkpointing +import deepspeed + +# DeepSpeed handles checkpointing automatically +model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=ds_config +) + +# Save checkpoint (DeepSpeed handles rank coordination) +model_engine.save_checkpoint(save_dir='checkpoints', tag=f'epoch_{epoch}') + +# Load checkpoint +_, client_state = model_engine.load_checkpoint( + load_dir='checkpoints', + tag=f'epoch_{epoch}' +) + +# FSDP checkpointing (PyTorch 2.0+) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType, FullStateDictConfig + +# Configure state_dict type +save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + +with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): + state_dict = model.state_dict() + + if dist.get_rank() == 0: + checkpoint = { + 'model_state_dict': state_dict, + # FSDP optimizer needs special handling + 'optimizer_state_dict': FSDP.full_optim_state_dict(model, optimizer), + } + torch.save(checkpoint, checkpoint_path) +``` + + +## Checkpoint Management + +### Three-Checkpoint Strategy + +**Best Practice:** Maintain three types of checkpoints with different purposes and saving frequencies. + +```python +import glob +import os +from pathlib import Path + +class CheckpointManager: + """Manage training checkpoints with best/last/periodic strategy.""" + + def __init__( + self, + checkpoint_dir: str, + keep_last_n: int = 3, + monitor: str = 'val_loss', + mode: str = 'min' + ): + """ + Args: + checkpoint_dir: Directory to save checkpoints + keep_last_n: Number of periodic checkpoints to keep + monitor: Metric to monitor for best checkpoint ('val_loss', 'val_acc', etc.) + mode: 'min' for loss, 'max' for accuracy + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.keep_last_n = keep_last_n + self.monitor = monitor + self.mode = mode + + # Track best metric + self.best_metric = float('inf') if mode == 'min' else float('-inf') + + def is_better(self, metric: float) -> bool: + """Check if metric is better than current best.""" + if self.mode == 'min': + return metric < self.best_metric + else: + return metric > self.best_metric + + def save_checkpoint( + self, + epoch: int, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + metrics: dict, + is_periodic: bool = False + ) -> None: + """Save checkpoint(s) based on strategy. + + Args: + epoch: Current epoch + model: Model to save + optimizer: Optimizer to save + scheduler: Scheduler to save + metrics: Dict of metrics (must include self.monitor key) + is_periodic: If True, save periodic checkpoint + """ + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'metrics': metrics, + 'rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + 'best_metric': self.best_metric, + } + + # 1. Always save last checkpoint (overwrite) + last_path = self.checkpoint_dir / 'last_checkpoint.pt' + torch.save(checkpoint, last_path) + print(f"Saved last checkpoint: {last_path}") + + # 2. Save best checkpoint if metric improved + current_metric = metrics[self.monitor] + if self.is_better(current_metric): + self.best_metric = current_metric + checkpoint['best_metric'] = self.best_metric + + best_path = self.checkpoint_dir / 'best_model.pt' + torch.save(checkpoint, best_path) + print(f"Saved best checkpoint: {best_path} ({self.monitor}={current_metric:.4f})") + + # 3. Save periodic checkpoint if requested + if is_periodic: + periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt' + torch.save(checkpoint, periodic_path) + print(f"Saved periodic checkpoint: {periodic_path}") + + # Cleanup old periodic checkpoints + self._cleanup_periodic_checkpoints() + + def _cleanup_periodic_checkpoints(self) -> None: + """Remove old periodic checkpoints, keeping only last N.""" + # Find all periodic checkpoints + pattern = str(self.checkpoint_dir / 'checkpoint_epoch_*.pt') + checkpoints = sorted(glob.glob(pattern)) + + # Remove old checkpoints if exceeding keep_last_n + if len(checkpoints) > self.keep_last_n: + for old_ckpt in checkpoints[:-self.keep_last_n]: + os.remove(old_ckpt) + print(f"Removed old checkpoint: {old_ckpt}") + +# Usage +checkpoint_manager = CheckpointManager( + checkpoint_dir='checkpoints', + keep_last_n=3, + monitor='val_loss', + mode='min' +) + +for epoch in range(num_epochs): + train_loss = train_one_epoch(model, train_loader, optimizer, criterion) + val_loss = validate(model, val_loader, criterion) + + scheduler.step() + + metrics = { + 'train_loss': train_loss, + 'val_loss': val_loss, + } + + # Save periodic checkpoint every 10 epochs + is_periodic = (epoch % 10 == 0) + + checkpoint_manager.save_checkpoint( + epoch=epoch, + model=model, + optimizer=optimizer, + scheduler=scheduler, + metrics=metrics, + is_periodic=is_periodic + ) +``` + +**Three checkpoint types explained:** + +1. **last_checkpoint.pt** (always overwrite): + - Most recent checkpoint + - Used for resuming if training crashes + - Always overwrite previous "last" checkpoint + - Minimal disk usage (only 1 file) + +2. **best_model.pt** (based on validation metric): + - Best performing model according to validation metric + - Used for final evaluation and deployment + - Only overwrite when validation metric improves + - Most important checkpoint (don't lose this!) + +3. **checkpoint_epoch_N.pt** (periodic): + - Saved every N epochs (e.g., 10, 20, 50) + - Used for resume if need to go back further + - Keep only last M periodic checkpoints (e.g., 3-5) + - Cleanup old ones to save disk space + + +### Disk Space Management + +```python +def get_checkpoint_size(checkpoint_path: str) -> float: + """Get checkpoint file size in MB.""" + size_bytes = os.path.getsize(checkpoint_path) + size_mb = size_bytes / (1024 * 1024) + return size_mb + +def estimate_storage_usage( + checkpoint_path: str, + num_epochs: int, + periodic_interval: int, + keep_last_n: int +) -> dict: + """Estimate total storage usage for checkpointing strategy. + + Args: + checkpoint_path: Path to a sample checkpoint + num_epochs: Total number of training epochs + periodic_interval: Save periodic checkpoint every N epochs + keep_last_n: Keep last N periodic checkpoints + + Returns: + dict with storage estimates + """ + ckpt_size_mb = get_checkpoint_size(checkpoint_path) + + # 1 last + 1 best + N periodic + num_checkpoints = 1 + 1 + min(keep_last_n, num_epochs // periodic_interval) + total_size_mb = ckpt_size_mb * num_checkpoints + total_size_gb = total_size_mb / 1024 + + return { + 'checkpoint_size_mb': ckpt_size_mb, + 'num_checkpoints': num_checkpoints, + 'total_size_mb': total_size_mb, + 'total_size_gb': total_size_gb, + } + +# Usage +storage = estimate_storage_usage( + checkpoint_path='checkpoints/last_checkpoint.pt', + num_epochs=200, + periodic_interval=10, + keep_last_n=3 +) + +print(f"Checkpoint size: {storage['checkpoint_size_mb']:.1f} MB") +print(f"Number of checkpoints: {storage['num_checkpoints']}") +print(f"Total storage needed: {storage['total_size_gb']:.2f} GB") + +# Check available disk space +import shutil +disk_usage = shutil.disk_usage('checkpoints') +available_gb = disk_usage.free / (1024**3) + +if storage['total_size_gb'] > available_gb * 0.9: # Keep 10% buffer + print(f"WARNING: Insufficient disk space!") + print(f"Available: {available_gb:.2f} GB") + print(f"Needed: {storage['total_size_gb']:.2f} GB") +``` + + +### Model-Only vs Full Checkpoints + +**Strategy:** Save model-only checkpoints more frequently (smaller), full checkpoints less frequently (larger). + +```python +def save_checkpoint_model_only( + model: nn.Module, + checkpoint_path: str, + metadata: dict = None +) -> None: + """Save model-only checkpoint (no optimizer, scheduler, RNG states). + + Use for: Frequent checkpointing, transfer learning, model export + Size: ~50% of full checkpoint + Cannot resume training exactly (no optimizer momentum, LR schedule, etc.) + + Args: + model: Model to save + checkpoint_path: Path to save checkpoint + metadata: Optional metadata dict (epoch, loss, metrics, etc.) + """ + checkpoint = { + 'model_state_dict': model.state_dict(), + } + + if metadata is not None: + checkpoint.update(metadata) + + torch.save(checkpoint, checkpoint_path) + +def save_checkpoint_full( + epoch: int, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + checkpoint_path: str, + **kwargs +) -> None: + """Save complete checkpoint (model + optimizer + scheduler + RNG states). + + Use for: Resume training exactly, maintaining training trajectory + Size: ~100% (baseline) + Can resume training from exact point + """ + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + **kwargs + } + + torch.save(checkpoint, checkpoint_path) + +# Hybrid strategy +for epoch in range(num_epochs): + train_loss = train_one_epoch(...) + val_loss = validate(...) + scheduler.step() + + # Model-only checkpoint every epoch (cheap) + save_checkpoint_model_only( + model=model, + checkpoint_path=f'model_only_epoch_{epoch}.pt', + metadata={'epoch': epoch, 'val_loss': val_loss} + ) + + # Full checkpoint every 10 epochs (expensive but complete) + if epoch % 10 == 0: + save_checkpoint_full( + epoch=epoch, + model=model, + optimizer=optimizer, + scheduler=scheduler, + checkpoint_path=f'full_checkpoint_epoch_{epoch}.pt', + val_loss=val_loss + ) +``` + + +### Cloud Storage Integration + +```python +def sync_checkpoint_to_cloud( + local_path: str, + cloud_path: str, + cloud_type: str = 's3' +) -> None: + """Sync checkpoint to cloud storage for backup. + + Args: + local_path: Local checkpoint path + cloud_path: Cloud storage path (s3://bucket/key or gs://bucket/key) + cloud_type: 's3' or 'gcs' + """ + if cloud_type == 's3': + # AWS S3 + import boto3 + s3 = boto3.client('s3') + + # Parse S3 path + bucket, key = cloud_path.replace('s3://', '').split('/', 1) + + # Upload + s3.upload_file(local_path, bucket, key) + print(f"Uploaded {local_path} to s3://{bucket}/{key}") + + elif cloud_type == 'gcs': + # Google Cloud Storage + from google.cloud import storage + client = storage.Client() + + # Parse GCS path + bucket_name, blob_name = cloud_path.replace('gs://', '').split('/', 1) + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + # Upload + blob.upload_from_filename(local_path) + print(f"Uploaded {local_path} to gs://{bucket_name}/{blob_name}") + +# Usage in training loop +if epoch % 10 == 0: + local_path = f'checkpoints/checkpoint_epoch_{epoch}.pt' + save_checkpoint_full(..., checkpoint_path=local_path) + + # Backup to cloud (async recommended for large files) + sync_checkpoint_to_cloud( + local_path=local_path, + cloud_path=f's3://my-bucket/project/checkpoint_epoch_{epoch}.pt', + cloud_type='s3' + ) +``` + + +## Version Compatibility and Migration + +### Handling PyTorch Version Changes + +```python +def save_checkpoint_with_version( + checkpoint: dict, + checkpoint_path: str +) -> None: + """Save checkpoint with version metadata for compatibility tracking.""" + import torch + import sys + + # Add version metadata + checkpoint['_metadata'] = { + 'pytorch_version': torch.__version__, + 'python_version': sys.version, + 'cuda_version': torch.version.cuda if torch.cuda.is_available() else None, + 'cudnn_version': torch.backends.cudnn.version() if torch.cuda.is_available() else None, + } + + torch.save(checkpoint, checkpoint_path) + +def load_checkpoint_with_compatibility( + checkpoint_path: str, + model: nn.Module, + strict: bool = True +) -> tuple: + """Load checkpoint with version compatibility handling. + + Args: + checkpoint_path: Path to checkpoint + model: Model to load into + strict: Whether to strictly match keys (False allows missing/extra keys) + + Returns: + (checkpoint_dict, missing_keys, unexpected_keys) + """ + # Load checkpoint + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Check version compatibility + if '_metadata' in checkpoint: + meta = checkpoint['_metadata'] + print(f"Checkpoint saved with PyTorch {meta['pytorch_version']}") + print(f"Current PyTorch version: {torch.__version__}") + + if meta['pytorch_version'] != torch.__version__: + print("WARNING: PyTorch version mismatch!") + print("Attempting to load with strict=False for compatibility") + strict = False + + # Load state_dict + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint['model_state_dict'], + strict=strict + ) + + # Report missing/unexpected keys + if missing_keys: + print(f"Missing keys in checkpoint: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys in checkpoint: {unexpected_keys}") + + return checkpoint, missing_keys, unexpected_keys + +# Usage +checkpoint, missing, unexpected = load_checkpoint_with_compatibility( + checkpoint_path='old_checkpoint.pt', + model=model, + strict=False # Allow version differences +) + +# Validate model still works +try: + with torch.no_grad(): + test_input = torch.randn(1, 3, 224, 224) + output = model(test_input) + print(f"Model forward pass successful, output shape: {output.shape}") +except Exception as e: + print(f"Model forward pass failed: {e}") + print("Checkpoint may be incompatible") +``` + + +### Checkpoint Migration + +**Scenario:** Trained model in PyTorch 1.x, need to use in PyTorch 2.x. + +```python +def migrate_checkpoint( + old_checkpoint_path: str, + new_checkpoint_path: str, + model_fn: Callable +) -> None: + """Migrate checkpoint to new PyTorch version. + + Process: + 1. Load checkpoint in OLD PyTorch version + 2. Load into model + 3. Re-save checkpoint in NEW PyTorch version + + Args: + old_checkpoint_path: Path to old checkpoint + new_checkpoint_path: Path to save new checkpoint + model_fn: Function that creates model (same architecture) + """ + # Load old checkpoint + checkpoint = torch.load(old_checkpoint_path, map_location='cpu') + + # Create model + model = model_fn() + + # Load state_dict + try: + model.load_state_dict(checkpoint['model_state_dict']) + except RuntimeError as e: + print(f"Strict loading failed: {e}") + print("Attempting non-strict loading...") + model.load_state_dict(checkpoint['model_state_dict'], strict=False) + + # Create new checkpoint with current PyTorch version + new_checkpoint = { + 'epoch': checkpoint.get('epoch', 0), + 'model_state_dict': model.state_dict(), # Re-saved in new format + # Note: optimizer and scheduler states may not be compatible + # Only migrate model state for cross-version migration + } + + # Add version metadata + save_checkpoint_with_version(new_checkpoint, new_checkpoint_path) + print(f"Migrated checkpoint from {old_checkpoint_path} to {new_checkpoint_path}") + +# Usage +migrate_checkpoint( + old_checkpoint_path='pytorch1.10_checkpoint.pt', + new_checkpoint_path='pytorch2.1_checkpoint.pt', + model_fn=lambda: ResNet50(num_classes=1000) +) +``` + + +### Using weights_only for Security + +**Critical:** PyTorch 2.0+ introduces `weights_only=True` flag to prevent arbitrary code execution during checkpoint loading. + +```python +# Old way (PyTorch < 2.0) - potentially unsafe +checkpoint = torch.load('checkpoint.pt') # Can execute arbitrary code! + +# New way (PyTorch 2.0+) - safe +checkpoint = torch.load('checkpoint.pt', weights_only=True) # Only loads tensors + +# Handling weights_only with full checkpoints +def save_checkpoint_secure(checkpoint: dict, checkpoint_path: str) -> None: + """Save checkpoint in format compatible with weights_only=True.""" + # Ensure all values are tensors, dicts, or primitives (no custom classes) + safe_checkpoint = { + 'epoch': checkpoint['epoch'], # int - OK + 'model_state_dict': checkpoint['model_state_dict'], # dict of tensors - OK + 'optimizer_state_dict': checkpoint['optimizer_state_dict'], # dict of tensors - OK + 'scheduler_state_dict': checkpoint['scheduler_state_dict'], # dict - OK + 'loss': checkpoint['loss'], # float - OK + 'rng_state': checkpoint['rng_state'], # tensor - OK + 'cuda_rng_state': checkpoint['cuda_rng_state'], # list of tensors - OK + } + + torch.save(safe_checkpoint, checkpoint_path) + +def load_checkpoint_secure(checkpoint_path: str) -> dict: + """Load checkpoint securely with weights_only=True.""" + try: + # Try weights_only first (PyTorch 2.0+) + checkpoint = torch.load(checkpoint_path, weights_only=True) + except TypeError: + # Fall back for PyTorch < 2.0 + print("weights_only not available, loading without (PyTorch < 2.0)") + checkpoint = torch.load(checkpoint_path) + except Exception as e: + # Checkpoint contains non-tensor objects + print(f"weights_only=True failed: {e}") + print("Loading with weights_only=False (CAUTION: potential security risk)") + checkpoint = torch.load(checkpoint_path, weights_only=False) + + return checkpoint +``` + + +## Common Checkpointing Pitfalls + +### Pitfall 1: Saving Model Object Instead of state_dict + +```python +# ❌ WRONG: Saving entire model object +torch.save(model, 'model.pt') + +# Problems: +# - Not portable (tied to specific Python class definition) +# - Requires exact same code structure to load +# - Pickle-based, version-sensitive +# - Cannot load in different model architecture + +# ✅ CORRECT: Saving state_dict +torch.save(model.state_dict(), 'model.pt') + +# or in checkpoint: +checkpoint = { + 'model_state_dict': model.state_dict(), + # ... other components +} +torch.save(checkpoint, 'checkpoint.pt') + +# Benefits: +# - Portable across code versions +# - Can load into different architectures (with strict=False) +# - Standard practice in PyTorch +``` + + +### Pitfall 2: Forgetting Scheduler State + +```python +# ❌ WRONG: Missing scheduler state +checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + # Missing: scheduler_state_dict +} + +# Result: Learning rate resets to initial value on resume! +# Example: If at epoch 50 with LR=0.001 after decay, +# resume will reset LR to 0.1 (initial), causing instability + +# ✅ CORRECT: Include scheduler state +checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), # Critical! +} + +# Resume +optimizer.load_state_dict(checkpoint['optimizer_state_dict']) +scheduler.load_state_dict(checkpoint['scheduler_state_dict']) +# Scheduler continues from correct position +``` + + +### Pitfall 3: Not Handling Device Correctly + +```python +# ❌ WRONG: Not using map_location +# Checkpoint saved on GPU, loading on CPU +checkpoint = torch.load('checkpoint.pt') # ERROR: CUDA not available +model.load_state_dict(checkpoint['model_state_dict']) + +# ✅ CORRECT: Use map_location +checkpoint = torch.load('checkpoint.pt', map_location='cpu') +model.load_state_dict(checkpoint['model_state_dict']) + +# Even better: Use current device +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +checkpoint = torch.load('checkpoint.pt', map_location=device) +model.load_state_dict(checkpoint['model_state_dict']) +model = model.to(device) # Ensure model is on correct device +``` + + +### Pitfall 4: Saving Too Frequently + +```python +# ❌ WRONG: Saving every iteration +for epoch in range(100): + for i, batch in enumerate(train_loader): + # Train step + # ... + + # Save every iteration (thousands of checkpoints!) + torch.save(checkpoint, f'ckpt_epoch{epoch}_iter{i}.pt') + +# Problems: +# - Disk fills up rapidly (100 epochs * 1000 iters * 500MB = 50TB!) +# - I/O overhead slows training significantly +# - Most checkpoints are never used + +# ✅ CORRECT: Strategic saving +for epoch in range(100): + train_one_epoch(...) + val_loss = validate(...) + + # 1. Always save last checkpoint (overwrite) + save_checkpoint('last.pt') + + # 2. Save best model when validation improves + if val_loss < best_val_loss: + save_checkpoint('best.pt') + + # 3. Save periodic checkpoint every 10 epochs + if epoch % 10 == 0: + save_checkpoint(f'checkpoint_epoch_{epoch}.pt') + cleanup_old_checkpoints(keep_last_n=3) +``` + + +### Pitfall 5: Not Validating Checkpoints + +```python +# ❌ WRONG: Assume checkpoint saved correctly +torch.save(checkpoint, 'checkpoint.pt') +# No verification, continue training + +# Problems: +# - Disk full → truncated checkpoint → silent corruption +# - NFS/network issues → incomplete write +# - Discover corruption hours later when trying to resume + +# ✅ CORRECT: Validate after saving +def save_and_validate_checkpoint(checkpoint: dict, path: str) -> None: + """Save checkpoint and validate it was saved correctly.""" + # Save + torch.save(checkpoint, path) + + # Validate + try: + loaded = torch.load(path, map_location='cpu') + required_keys = ['epoch', 'model_state_dict', 'optimizer_state_dict'] + + for key in required_keys: + if key not in loaded: + raise ValueError(f"Missing key: {key}") + + print(f"✓ Checkpoint saved and validated: {path}") + + except Exception as e: + print(f"✗ Checkpoint validation failed: {e}") + # Remove corrupted checkpoint + if os.path.exists(path): + os.remove(path) + raise RuntimeError(f"Checkpoint save failed: {path}") +``` + + +### Pitfall 6: Resume Off-By-One Error + +```python +# ❌ WRONG: Starting at checkpoint epoch (re-runs epoch) +checkpoint = torch.load('checkpoint_epoch_40.pt') +model.load_state_dict(checkpoint['model_state_dict']) + +start_epoch = checkpoint['epoch'] # 40 + +for epoch in range(start_epoch, 100): # Starts at 40 + # This re-runs epoch 40! + # Optimizer steps on epoch 40 data again + # Scheduler steps again (LR changes) + train_one_epoch(...) + +# ✅ CORRECT: Starting at next epoch +checkpoint = torch.load('checkpoint_epoch_40.pt') +model.load_state_dict(checkpoint['model_state_dict']) + +start_epoch = checkpoint['epoch'] + 1 # 41 + +for epoch in range(start_epoch, 100): # Starts at 41 + # Correctly continues from epoch 41 + train_one_epoch(...) +``` + + +### Pitfall 7: DDP All Ranks Saving + +```python +# ❌ WRONG: All ranks save simultaneously +# In DDP training with 4 GPUs +for epoch in range(100): + train_one_epoch(...) + + # All 4 ranks execute this! + torch.save(checkpoint, 'checkpoint.pt') # Race condition! + +# Problems: +# - 4 processes writing to same file → corruption +# - File system race conditions → truncated file +# - Undefined behavior (may work sometimes, fail others) + +# ✅ CORRECT: Only rank 0 saves +import torch.distributed as dist + +for epoch in range(100): + train_one_epoch(...) + + # Synchronize before saving + dist.barrier() + + # Only rank 0 saves + if dist.get_rank() == 0: + torch.save(checkpoint, 'checkpoint.pt') + + # Wait for rank 0 to finish + dist.barrier() +``` + + +### Pitfall 8: Not Setting All Seeds + +```python +# ❌ WRONG: Partial seed setting +torch.manual_seed(42) # Only PyTorch CPU + +# Missing: +# - torch.cuda.manual_seed() → GPU operations vary +# - np.random.seed() → NumPy augmentation varies +# - random.seed() → Python random varies +# - cuDNN settings → Conv operations vary +# - PYTHONHASHSEED → Dict/set ordering varies + +# Result: Results NOT reproducible despite setting seed + +# ✅ CORRECT: Complete seed setting +import torch +import numpy as np +import random +import os + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ['PYTHONHASHSEED'] = str(seed) + +set_seed(42) +# Now results are reproducible +``` + + +## Rationalization Resistance + +### Table: Shortcuts vs Consequences + +| Rationalization | Why It Seems Right | Actual Consequence | Counter-Argument | +|----------------|-------------------|-------------------|-----------------| +| "Just save model_state_dict, that's the important part" | Model weights are the "learned" part | Optimizer momentum buffers lost, training diverges on resume. SGD without momentum ≠ SGD with momentum. | Optimizer state contains momentum buffers (SGD momentum, Adam first/second moments). Without this, optimizer effectively resets, changing training dynamics. Example: Adam optimizer state is often 2x model size. | +| "torch.manual_seed(42) makes results reproducible" | PyTorch controls model randomness | cuDNN uses non-deterministic algorithms by default. NumPy/Python seeds not set. Results vary across runs. | Requires 7 seeds: torch CPU, torch CUDA, numpy, python, cuDNN deterministic, cuDNN benchmark, PYTHONHASHSEED. Missing any breaks reproducibility. | +| "Checkpointing is simple, don't overthink" | Save occasionally, load when needed | Missing scheduler → LR resets. Missing RNG states → data order differs. Off-by-one → re-runs epoch. Training diverges. | Checkpointing has 10+ components and 5+ pitfalls. Each omission causes different failure mode. Need systematic checklist, not "simple" approach. | +| "Save every epoch to be safe" | More checkpoints = more recovery points | Disk fills up (100 epochs * 500MB = 50GB). I/O overhead slows training. Most checkpoints never used. | Strategic saving (best + last + periodic) provides same recovery capability with 10-20x less storage. Cleanup policy essential. | +| "Rank 0 saves, that's all I need to know for DDP" | One rank saving prevents conflicts | Without dist.barrier(), rank 0 may save mid-step. Other ranks continue, gradients out of sync. Checkpoint inconsistent. | Need barrier BEFORE (sync training step) and AFTER (wait for save). Also need model.module.state_dict() to unwrap DDP. 3+ DDP-specific considerations. | +| "strict=False handles version incompatibility" | Loads despite key mismatches | Missing keys = uninitialized parameters. Model may forward successfully but outputs are wrong. Silent failure. | Must LOG missing/unexpected keys and VALIDATE model output. strict=False is last resort after understanding incompatibility. | +| "RNG states don't matter much" | Randomness averages out | Data augmentation differs, affecting training. Dropout differs, affecting gradients. Initialization differs. Results not reproducible. | RNG states control data order, augmentation, dropout, initialization. Without restoration, resume follows different random trajectory, breaking reproducibility and potentially convergence. | +| "I'll checkpoint after I finish debugging" | Don't want checkpoint code cluttering debug | Training crashes at epoch 47, lose all progress. Or checkpoint code added hastily, incomplete, causes resume issues. | Implement checkpointing FIRST as part of training loop. Debugging with checkpoints allows resuming after OOM/crashes. Later addition is rushed and error-prone. | +| "Model-only checkpoints are sufficient" | Can always retrain optimizer from checkpoint | Optimizer without momentum buffers has different convergence. Fine-tuning from checkpoint diverges. | Model-only checkpoints are fine for inference or transfer learning (new optimizer anyway). For resuming SAME training run, need full checkpoint with optimizer/scheduler. | +| "Cloud storage is too slow for checkpoints" | Local disk is faster | Local disk full → training stops. Hardware failure → checkpoints lost. No backup strategy. | Save locally for speed, async sync to cloud for backup. Best of both: fast local access, cloud durability. Losing checkpoints from 2-week training is unacceptable. | + + +## Red Flags: Checkpoint Issues Checklist + +When reviewing checkpointing implementation or debugging checkpoint-related issues, watch for these red flags: + +**Checkpoint Saving:** +- [ ] Saving `model` instead of `model.state_dict()` (not portable) +- [ ] Missing `optimizer_state_dict` (momentum buffers lost) +- [ ] Missing `scheduler_state_dict` (learning rate resets) +- [ ] Missing RNG states (`rng_state`, `cuda_rng_state`, numpy, python) +- [ ] No validation after saving (corruption goes undetected) +- [ ] Saving too frequently (every iteration/batch, disk fills up) +- [ ] No cleanup policy (old checkpoints accumulate) +- [ ] Hardcoded device in checkpoint (breaks portability) + +**Checkpoint Loading:** +- [ ] Not using `map_location` (device mismatch errors) +- [ ] Using `strict=False` without logging missing/unexpected keys +- [ ] Off-by-one error: `start_epoch = checkpoint['epoch']` instead of `+1` +- [ ] Not restoring RNG states (non-reproducible resume) +- [ ] Not validating checkpoint after loading (assume it's correct) +- [ ] Loading into wrong device (CPU/GPU mismatch) + +**Reproducibility:** +- [ ] Only setting `torch.manual_seed()` (missing 6+ other seeds) +- [ ] Not setting `torch.backends.cudnn.deterministic = True` +- [ ] Not disabling `torch.backends.cudnn.benchmark` +- [ ] DataLoader with `num_workers > 0` but no `worker_init_fn` +- [ ] Not setting `PYTHONHASHSEED` environment variable +- [ ] Using `torch.use_deterministic_algorithms()` without try/except (some ops non-deterministic) +- [ ] Not testing reproducibility (assuming seed setting works) + +**DDP Checkpointing:** +- [ ] All ranks saving checkpoint (corruption, race conditions) +- [ ] No `dist.barrier()` before or after checkpoint saving +- [ ] Saving `model.state_dict()` instead of `model.module.state_dict()` (includes "module." prefix) +- [ ] Not handling DDP wrapper prefix on load +- [ ] Assuming optimizer state works same as single-GPU (misses ZeRO sharding) +- [ ] Checkpoint on local disk not shared filesystem (only rank 0's node has it) + +**Storage Management:** +- [ ] No "best model" checkpoint (only periodic saves) +- [ ] Not overwriting "last checkpoint" (accumulates identical files) +- [ ] Keeping all periodic checkpoints (no cleanup, disk fills) +- [ ] No disk space checking before saving +- [ ] No cloud/backup strategy (single point of failure) +- [ ] Saving full checkpoint every epoch (I/O overhead, unnecessary) + +**Version Compatibility:** +- [ ] No PyTorch version in checkpoint metadata +- [ ] Using `weights_only=False` in PyTorch 2.0+ (security risk) +- [ ] No migration strategy for old checkpoints +- [ ] Assuming checkpoints work across PyTorch versions +- [ ] No documentation of checkpoint format/contents + +**General:** +- [ ] No checkpoint manager class (ad-hoc saving throughout code) +- [ ] Checkpoint saving inside training loop (should be modular function) +- [ ] No error handling on save/load (fails silently) +- [ ] No checkpoint documentation (what's included, how to load) +- [ ] Assuming checkpoints "just work" (no testing of resume behavior) + + +## Quick Reference: Complete Checkpoint Pattern + +```python +import torch +import torch.nn as nn +import torch.distributed as dist +import numpy as np +import random +import os +from pathlib import Path +from typing import Optional + +# ============================================================================ +# 1. REPRODUCIBILITY SETUP (call FIRST, before model/data creation) +# ============================================================================ + +def set_seed(seed: int = 42): + """Complete seed setting for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ['PYTHONHASHSEED'] = str(seed) + +set_seed(42) # Call BEFORE creating model, data, etc. + +# ============================================================================ +# 2. DATALOADER WORKER SEEDING (for reproducibility with num_workers > 0) +# ============================================================================ + +def seed_worker(worker_id: int): + """Seed each DataLoader worker.""" + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +g = torch.Generator() +g.manual_seed(42) + +train_loader = DataLoader( + dataset, + batch_size=32, + shuffle=True, + num_workers=4, + worker_init_fn=seed_worker, + generator=g, +) + +# ============================================================================ +# 3. COMPLETE CHECKPOINT SAVING +# ============================================================================ + +def save_checkpoint( + epoch: int, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + loss: float, + checkpoint_path: str, + scaler: Optional[torch.cuda.amp.GradScaler] = None, + **kwargs +) -> None: + """Save complete training checkpoint.""" + + # Handle DDP model + model_state = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict() + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model_state, + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': loss, + 'rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + 'numpy_rng_state': np.random.get_state(), + 'python_rng_state': random.getstate(), + } + + if scaler is not None: + checkpoint['scaler_state_dict'] = scaler.state_dict() + + checkpoint.update(kwargs) # Additional components + + # DDP: Only rank 0 saves + if dist.is_initialized(): + dist.barrier() + if dist.get_rank() == 0: + torch.save(checkpoint, checkpoint_path) + dist.barrier() + else: + torch.save(checkpoint, checkpoint_path) + +# ============================================================================ +# 4. COMPLETE CHECKPOINT LOADING +# ============================================================================ + +def load_checkpoint( + checkpoint_path: str, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + device: torch.device, + scaler: Optional[torch.cuda.amp.GradScaler] = None +) -> int: + """Load complete checkpoint and return start_epoch.""" + + checkpoint = torch.load(checkpoint_path, map_location=device) + + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + torch.set_rng_state(checkpoint['rng_state'].cpu()) + torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state']) + np.random.set_state(checkpoint['numpy_rng_state']) + random.setstate(checkpoint['python_rng_state']) + + if scaler is not None and 'scaler_state_dict' in checkpoint: + scaler.load_state_dict(checkpoint['scaler_state_dict']) + + start_epoch = checkpoint['epoch'] + 1 # Resume at NEXT epoch + + return start_epoch + +# ============================================================================ +# 5. TRAINING LOOP WITH CHECKPOINT MANAGEMENT +# ============================================================================ + +# Model, optimizer, scheduler setup +model = MyModel().to(device) +if dist.is_initialized(): + model = DDP(model, device_ids=[dist.get_rank()]) + +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +# Resume if checkpoint exists +checkpoint_dir = Path('checkpoints') +checkpoint_dir.mkdir(exist_ok=True) +last_ckpt = checkpoint_dir / 'last_checkpoint.pt' + +if last_ckpt.exists(): + start_epoch = load_checkpoint(last_ckpt, model, optimizer, scheduler, device) + print(f"Resumed from epoch {start_epoch}") +else: + start_epoch = 0 + +best_val_loss = float('inf') + +# Training loop +for epoch in range(start_epoch, num_epochs): + # Set epoch for distributed sampler + if hasattr(train_loader.sampler, 'set_epoch'): + train_loader.sampler.set_epoch(epoch) + + train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + scheduler.step() + + # Save checkpoints (all three types) + + # 1. Always save last checkpoint + save_checkpoint( + epoch, model, optimizer, scheduler, val_loss, + checkpoint_path=checkpoint_dir / 'last_checkpoint.pt' + ) + + # 2. Save best checkpoint if validation improved + if val_loss < best_val_loss: + best_val_loss = val_loss + save_checkpoint( + epoch, model, optimizer, scheduler, val_loss, + checkpoint_path=checkpoint_dir / 'best_model.pt', + best_metric=best_val_loss + ) + print(f"Saved best model (val_loss={val_loss:.4f})") + + # 3. Save periodic checkpoint every 10 epochs + if epoch % 10 == 0: + save_checkpoint( + epoch, model, optimizer, scheduler, val_loss, + checkpoint_path=checkpoint_dir / f'checkpoint_epoch_{epoch}.pt' + ) + + # Cleanup old periodic checkpoints (keep last 3) + checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt')) + for old_ckpt in checkpoints[:-3]: + old_ckpt.unlink() +``` + + +## Summary + +**Checkpointing is NOT just saving the model.** A complete checkpoint requires 7+ components: epoch, model state, optimizer state, scheduler state, loss, and RNG states (PyTorch, CUDA, NumPy, Python). Missing any component causes training divergence on resume, learning rate resets, or non-reproducible results. + +**Reproducibility is NOT just torch.manual_seed().** True reproducibility requires seeds across 7 sources: PyTorch CPU, PyTorch CUDA, NumPy, Python random, cuDNN deterministic settings, cuDNN benchmark mode, and PYTHONHASHSEED environment variable. DataLoader with num_workers > 0 needs worker seeding. Some operations are inherently non-deterministic. + +**DDP checkpointing is NOT the same as single-GPU.** Only rank 0 saves (all ranks saving causes corruption). Need dist.barrier() before and after saving. Use model.module.state_dict() to unwrap DDP prefix. All ranks load checkpoints. + +**Checkpoint management is NOT "save occasionally".** Strategic approach: three checkpoint types (best, last, periodic), cleanup policy for old checkpoints, validation after saving, cloud backup for durability. Monitor disk space, use model-only checkpoints for frequent saves. + +**Resume logic is NOT "just load and continue".** Start at checkpoint['epoch'] + 1, not checkpoint['epoch'] (off-by-one causes re-running epochs). Restore all RNG states. Use map_location for device portability. Validate checkpoint makes sense (run validation, check loss matches). + +**Version compatibility is NOT automatic.** Save PyTorch version in metadata. Use weights_only=True in PyTorch 2.0+ for security. Log missing/unexpected keys when using strict=False. Have migration strategy for old checkpoints. + +These practices ensure training continuity, reproducibility, and checkpoint integrity across crashes, version changes, and distributed training scenarios. diff --git a/skills/using-pytorch-engineering/custom-autograd-functions.md b/skills/using-pytorch-engineering/custom-autograd-functions.md new file mode 100644 index 0000000..5566757 --- /dev/null +++ b/skills/using-pytorch-engineering/custom-autograd-functions.md @@ -0,0 +1,2828 @@ + +## Overview + +This skill teaches you to implement custom autograd functions in PyTorch using `torch.autograd.Function`. You'll learn when custom functions are necessary, how to implement forward and backward passes correctly, mandatory numerical verification with gradcheck, and common pitfalls that break gradient computation. + +## When This Skill Applies + +Use `torch.autograd.Function` when you encounter these situations: + +### Symptoms That Require Custom Functions + +1. **Custom operations not in PyTorch**: Implementing novel mathematical operations (special activations, custom loss components, domain-specific transformations) + +2. **Wrapping external code**: Interfacing with C++/CUDA kernels, third-party libraries, or compiled extensions that PyTorch doesn't know about + +3. **Custom gradient behavior**: Need non-standard gradient computation (gradient clipping in backward, sparsification, quantization, gradient routing) + +4. **Memory optimization**: Implementing gradient checkpointing, fused operations, or selective materialization to reduce memory footprint + +5. **Interfacing with non-differentiable code**: Wrapping operations that aren't naturally differentiable but have known gradient behavior + +### When NOT to Use Custom Functions + +Don't use `torch.autograd.Function` when: + +- ❌ **Operation composable from existing ops**: If you can write it with standard PyTorch operations, autograd handles gradients automatically +- ❌ **Simple function wrapping**: Just wrapping `torch.nn.functional` operations gains nothing +- ❌ **Standard gradient computation**: No custom behavior needed - use regular PyTorch +- ❌ **Avoiding learning curve**: Custom functions aren't "advanced only" - use when appropriate + +**Example**: +```python +# DON'T: Unnecessary custom function +class MyAdd(Function): # ❌ Pointless wrapper + @staticmethod + def forward(ctx, a, b): + return a + b + @staticmethod + def backward(ctx, grad): + return grad, grad + +# DO: Use PyTorch's autograd +output = a + b # ✅ Autograd handles this correctly + +# DON'T: Reimplement existing operations +class MyReLU(Function): # ❌ Use torch.nn.functional.relu + ... + +# DO: Use built-in operations +output = torch.relu(input) # ✅ Efficient and correct +``` + +## Core Pattern: Complete Function Implementation + +### Basic Template + +Every custom autograd function follows this pattern: + +```python +import torch +from torch.autograd import Function + +class MyCustomFunction(Function): + """ + Custom autograd function template. + + Implements forward pass (computation) and backward pass (gradient). + Context object (ctx) saves data between forward and backward. + """ + + @staticmethod + def forward(ctx, input, weight, bias=None): + """ + Forward pass: compute output from inputs. + + Args: + ctx: Context object for saving tensors/data + input: First input tensor + weight: Second input tensor + bias: Optional bias tensor + + Returns: + output: Result tensor + """ + # Save tensors needed for backward pass + ctx.save_for_backward(input, weight, bias) + + # Save non-tensor data as ctx attributes + # ctx.some_value = some_non_tensor_data + + # Compute forward pass + output = input.mm(weight.t()) + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass: compute gradients using chain rule. + + Args: + ctx: Context object with saved data + grad_output: Gradient of loss w.r.t. output (dL/dY) + + Returns: + Tuple of gradients for each input to forward(): + (grad_input, grad_weight, grad_bias) + Must return one gradient per forward() argument (None if not needed) + """ + # Retrieve saved tensors + input, weight, bias = ctx.saved_tensors + + # Initialize gradients to None + grad_input = grad_weight = grad_bias = None + + # Compute gradients only if needed (efficiency) + if ctx.needs_input_grad[0]: # Check if input needs gradient + grad_input = grad_output.mm(weight) # dL/dX = dL/dY @ W + + if ctx.needs_input_grad[1]: # Check if weight needs gradient + grad_weight = grad_output.t().mm(input) # dL/dW = dL/dY.T @ X + + if bias is not None and ctx.needs_input_grad[2]: # Check if bias needs gradient + grad_bias = grad_output.sum(0) # Sum over batch dimension + + # Must return gradient for each forward() input (including ctx) + # First return is always None (ctx doesn't need gradient) + # But since ctx is implicit, return one per actual argument + return grad_input, grad_weight, grad_bias + +# Use the custom function +my_func = MyCustomFunction.apply # Get callable +output = my_func(input_tensor, weight_tensor, bias_tensor) +loss = output.sum() +loss.backward() # Calls MyCustomFunction.backward() +``` + +### Critical Rules + +**Rule 1: Return gradient for EACH forward input** +```python +# forward signature: forward(ctx, a, b, c, d=None) +# backward must return: (grad_a, grad_b, grad_c, grad_d) + +# ✅ CORRECT: 4 inputs → 4 gradient returns +def backward(ctx, grad_output): + return grad_a, grad_b, grad_c, grad_d + +# ❌ WRONG: Missing gradients +def backward(ctx, grad_output): + return grad_a, grad_b # Only 2 - will crash! + +# ✅ CORRECT: Use None for unused gradients +def backward(ctx, grad_output): + return grad_a, None, grad_c, None # b and d don't need gradients +``` + +**Rule 2: Gradient shape must match input shape exactly** +```python +# If input.shape = (32, 128) +# Then grad_input.shape MUST BE (32, 128) + +# ✅ CORRECT: Shapes match +assert grad_input.shape == input.shape + +# ❌ WRONG: Shape mismatch causes runtime error +grad_input = some_computation() # Shape (32, 64) - WRONG! +``` + +**Rule 3: Check needs_input_grad before computing** +```python +# Efficiency optimization: skip gradient computation if not needed + +# ✅ CORRECT: Check before computing +if ctx.needs_input_grad[0]: + grad_input = expensive_gradient_computation(...) +else: + grad_input = None + +# ❌ WASTEFUL: Always compute (slow) +grad_input = expensive_gradient_computation(...) # Even if not needed +``` + +## Context Object (ctx) Rules + +The context object `ctx` is how you pass data from forward to backward. Use it correctly or break everything. + +### Rule 1: Use save_for_backward() for Tensors ONLY + +```python +# ✅ CORRECT: Save tensors with save_for_backward() +@staticmethod +def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) # Both are tensors + return input @ weight + +# ❌ WRONG: Saving tensors as attributes +@staticmethod +def forward(ctx, input, weight): + ctx.input = input # Breaks memory tracking + ctx.weight = weight # Breaks memory tracking + return input @ weight + +# Why it matters: save_for_backward() properly tracks tensor versions +# and memory. Attribute assignment doesn't, leading to bugs/crashes. +``` + +### Rule 2: Save Non-Tensor Data as Attributes + +```python +# ✅ CORRECT: Non-tensor data as attributes +@staticmethod +def forward(ctx, input, kernel_size, stride): + ctx.save_for_backward(input) # Tensor + ctx.kernel_size = kernel_size # Integer - use attribute + ctx.stride = stride # Integer - use attribute + return some_operation(input, kernel_size, stride) + +# ❌ WRONG: Trying to save non-tensors with save_for_backward() +@staticmethod +def forward(ctx, input, kernel_size, stride): + ctx.save_for_backward(input, kernel_size, stride) # TypeError! + # kernel_size and stride are ints, not tensors +``` + +### Rule 3: Access saved_tensors Only in Backward + +```python +# ✅ CORRECT: Access in backward +@staticmethod +def backward(ctx, grad_output): + input, weight = ctx.saved_tensors # Available here + grad_input = grad_output @ weight.t() + return grad_input, None + +# ❌ WRONG: Access in forward +@staticmethod +def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) + output = input @ weight + saved = ctx.saved_tensors # AttributeError! Not available in forward + return output +``` + +### Rule 4: Never Modify Saved Tensors + +```python +# ❌ WRONG: Modifying saved tensor +@staticmethod +def backward(ctx, grad_output): + input, = ctx.saved_tensors + input = input * 2 # Creates new tensor - OK + input *= 2 # IN-PLACE modification - BREAKS AUTOGRAD! + grad_input = compute_gradient(input, grad_output) + return grad_input + +# ✅ CORRECT: Don't modify, or clone first +@staticmethod +def backward(ctx, grad_output): + input, = ctx.saved_tensors + input_scaled = input * 2 # New tensor - safe + grad_input = compute_gradient(input_scaled, grad_output) + return grad_input +``` + +### Complete ctx Example + +```python +class CompleteCtxExample(Function): + @staticmethod + def forward(ctx, input, weight, bias, stride=1, training=True): + # Save tensors with save_for_backward() + ctx.save_for_backward(input, weight, bias) + + # Save non-tensor data as attributes + ctx.stride = stride # int + ctx.training = training # bool + + # Compute forward + output = some_computation(input, weight, bias, stride, training) + return output + + @staticmethod + def backward(ctx, grad_output): + # Retrieve saved tensors + input, weight, bias = ctx.saved_tensors + + # Retrieve non-tensor data + stride = ctx.stride + training = ctx.training + + # Compute gradients + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = compute_input_gradient(grad_output, weight, stride) + if ctx.needs_input_grad[1]: + grad_weight = compute_weight_gradient(grad_output, input, stride) + if ctx.needs_input_grad[2]: + grad_bias = compute_bias_gradient(grad_output) + + # Return gradients (None for stride and training - they're not tensors) + return grad_input, grad_weight, grad_bias, None, None +``` + +## Gradient Computation Patterns + +Understanding common gradient patterns helps implement backward() correctly. + +### Pattern 1: Element-wise Operations + +Forward: `y = f(x)` (element-wise function) +Backward: `grad_x = grad_y * f'(x)` (element-wise multiply) + +```python +# Example: Custom ReLU +class CustomReLU(Function): + @staticmethod + def forward(ctx, input): + # Save input to compute derivative in backward + ctx.save_for_backward(input) + # ReLU: max(0, x) + output = input.clamp(min=0) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + # Derivative: 1 if x > 0, else 0 + grad_input = grad_output.clone() + grad_input[input < 0] = 0 + return grad_input + +# Example: Custom Sigmoid +class CustomSigmoid(Function): + @staticmethod + def forward(ctx, input): + output = torch.sigmoid(input) + # Save output (more efficient than recomputing) + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx, grad_output): + output, = ctx.saved_tensors + # Derivative: sigmoid(x) * (1 - sigmoid(x)) + grad_input = grad_output * output * (1 - output) + return grad_input + +# Example: Custom Tanh +class CustomTanh(Function): + @staticmethod + def forward(ctx, input): + output = torch.tanh(input) + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx, grad_output): + output, = ctx.saved_tensors + # Derivative: 1 - tanh^2(x) + grad_input = grad_output * (1 - output * output) + return grad_input +``` + +### Pattern 2: Matrix Operations (Chain Rule) + +Forward: `Y = X @ W` (matrix multiply) +Backward: Apply chain rule with matrix transpose + +```python +class CustomLinear(Function): + @staticmethod + def forward(ctx, input, weight): + # Save both tensors for backward + ctx.save_for_backward(input, weight) + # Forward: Y = X @ W + output = input.mm(weight.t()) # (batch, in) @ (out, in).T = (batch, out) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = grad_weight = None + + # Gradient w.r.t. input: dL/dX = dL/dY @ W + # Shapes: (batch, out) @ (out, in) = (batch, in) ✓ + if ctx.needs_input_grad[0]: + grad_input = grad_output.mm(weight) + + # Gradient w.r.t. weight: dL/dW = dL/dY.T @ X + # Then transpose to match weight shape + # Shapes: (out, batch) @ (batch, in) = (out, in) ✓ + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().mm(input) + + return grad_input, grad_weight + +# More complex: Matrix multiply with both inputs requiring gradients +class CustomMatmul(Function): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + # Forward: C = A @ B + return torch.matmul(a, b) + + @staticmethod + def backward(ctx, grad_output): + a, b = ctx.saved_tensors + grad_a = grad_b = None + + # Gradient w.r.t. A: dL/dA = dL/dC @ B.T + if ctx.needs_input_grad[0]: + grad_a = torch.matmul(grad_output, b.transpose(-2, -1)) + + # Gradient w.r.t. B: dL/dB = A.T @ dL/dC + if ctx.needs_input_grad[1]: + grad_b = torch.matmul(a.transpose(-2, -1), grad_output) + + return grad_a, grad_b +``` + +### Pattern 3: Broadcasting Operations + +Forward: Operation with broadcasting (e.g., adding bias) +Backward: Sum over broadcasted dimensions + +```python +class CustomBiasAdd(Function): + @staticmethod + def forward(ctx, input, bias): + # input: (batch, channels, height, width) + # bias: (channels,) + ctx.save_for_backward(input, bias) + # Broadcasting adds bias to each channel + output = input + bias.view(1, -1, 1, 1) + return output + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + grad_input = grad_bias = None + + # Gradient w.r.t. input: just pass through + if ctx.needs_input_grad[0]: + grad_input = grad_output + + # Gradient w.r.t. bias: sum over broadcasted dimensions + # grad_output: (batch, channels, height, width) + # grad_bias should be: (channels,) + # Sum over batch (0), height (2), width (3) + if ctx.needs_input_grad[1]: + grad_bias = grad_output.sum(dim=(0, 2, 3)) + + return grad_input, grad_bias + +# General broadcasting pattern +class CustomBroadcastOp(Function): + @staticmethod + def forward(ctx, input, param): + # Save shapes to determine broadcast dimensions + ctx.input_shape = input.shape + ctx.param_shape = param.shape + ctx.save_for_backward(input, param) + + # Some operation with broadcasting + output = input * param # param broadcasts to input shape + return output + + @staticmethod + def backward(ctx, grad_output): + input, param = ctx.saved_tensors + grad_input = grad_param = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output * param + + if ctx.needs_input_grad[1]: + # Sum grad_output over dimensions that were broadcasted + grad_param = grad_output * input + + # Find which dimensions were broadcasted + # Sum over those dimensions to match param shape + ndim_diff = len(ctx.input_shape) - len(ctx.param_shape) + for i in range(ndim_diff): + grad_param = grad_param.sum(0) # Sum leading dimensions + + for i, (input_dim, param_dim) in enumerate( + zip(ctx.input_shape[ndim_diff:], ctx.param_shape) + ): + if param_dim == 1 and input_dim > 1: + grad_param = grad_param.sum(i, keepdim=True) + + return grad_input, grad_param +``` + +### Pattern 4: Reduction Operations + +Forward: Reduce dimensions (sum, mean, max, etc.) +Backward: Expand gradient back to original shape + +```python +class CustomSum(Function): + @staticmethod + def forward(ctx, input, dim, keepdim=False): + ctx.input_shape = input.shape + ctx.dim = dim + ctx.keepdim = keepdim + # Sum along dimension + output = input.sum(dim=dim, keepdim=keepdim) + return output + + @staticmethod + def backward(ctx, grad_output): + # Gradient of sum: distribute grad_output to all elements + grad_input = grad_output + + # Expand back to original shape + if not ctx.keepdim: + # Add back the reduced dimension + grad_input = grad_input.unsqueeze(ctx.dim) + + # Expand to original shape (broadcasts the gradient) + grad_input = grad_input.expand(ctx.input_shape) + + return grad_input, None, None + +class CustomMean(Function): + @staticmethod + def forward(ctx, input, dim, keepdim=False): + ctx.input_shape = input.shape + ctx.dim = dim + ctx.keepdim = keepdim + output = input.mean(dim=dim, keepdim=keepdim) + return output + + @staticmethod + def backward(ctx, grad_output): + # Gradient of mean: distribute evenly to all elements + grad_input = grad_output + + if not ctx.keepdim: + grad_input = grad_input.unsqueeze(ctx.dim) + + grad_input = grad_input.expand(ctx.input_shape) + + # Divide by number of elements that were averaged + n = ctx.input_shape[ctx.dim] + grad_input = grad_input / n + + return grad_input, None, None + +class CustomMax(Function): + @staticmethod + def forward(ctx, input, dim, keepdim=False): + # Save both max values and indices + output, indices = input.max(dim=dim, keepdim=keepdim) + ctx.save_for_backward(indices) + ctx.input_shape = input.shape + ctx.dim = dim + ctx.keepdim = keepdim + return output + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + + # Only the maximum element gets gradient + grad_input = torch.zeros(ctx.input_shape, device=grad_output.device) + + if not ctx.keepdim: + grad_output = grad_output.unsqueeze(ctx.dim) + indices = indices.unsqueeze(ctx.dim) + + # Scatter gradient to max indices + grad_input.scatter_(ctx.dim, indices, grad_output) + + return grad_input, None, None +``` + +### Pattern 5: Convolution-like Operations + +Complex operations that involve multiple dimensions and strides. + +```python +class CustomConv1d(Function): + @staticmethod + def forward(ctx, input, weight, bias=None, stride=1, padding=0): + # Use PyTorch's conv for forward + output = torch.nn.functional.conv1d( + input, weight, bias, stride, padding + ) + + # Save what's needed for backward + ctx.save_for_backward(input, weight, bias) + ctx.stride = stride + ctx.padding = padding + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + stride = ctx.stride + padding = ctx.padding + + grad_input = grad_weight = grad_bias = None + + # Gradient w.r.t. input: convolve grad_output with weight + if ctx.needs_input_grad[0]: + grad_input = torch.nn.grad.conv1d_input( + input.shape, weight, grad_output, stride, padding + ) + + # Gradient w.r.t. weight: convolve input with grad_output + if ctx.needs_input_grad[1]: + grad_weight = torch.nn.grad.conv1d_weight( + input, weight.shape, grad_output, stride, padding + ) + + # Gradient w.r.t. bias: sum grad_output over batch and spatial dims + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum((0, 2)) + + return grad_input, grad_weight, grad_bias, None, None +``` + +## Numerical Gradient Verification (MANDATORY) + +**NEVER skip this step.** `gradcheck` verifies your gradient computation is correct by comparing analytical gradients (your backward()) against numerical gradients (finite differences). + +### Why gradcheck is Mandatory + +```python +# You implement backward() and it "looks right" +class MyFunction(Function): + @staticmethod + def backward(ctx, grad_output): + # Gradient formula looks correct mathematically + grad = some_computation() + return grad + +# But: +# ❌ Transpose is wrong +# ❌ Shape doesn't match +# ❌ Sign is flipped +# ❌ Missing a term +# ❌ Broadcasting is incorrect + +# These bugs are invisible without gradcheck! +# Your model trains but produces wrong results. +# Debugging takes days without knowing gradients are wrong. +``` + +### Basic gradcheck Usage + +```python +import torch +from torch.autograd import gradcheck + +def test_my_function(): + """Test custom function with gradcheck.""" + + # Create test inputs with requires_grad=True + # Use double precision for numerical stability + input = torch.randn(20, 20, dtype=torch.double, requires_grad=True) + weight = torch.randn(30, 20, dtype=torch.double, requires_grad=True) + + # Run gradcheck + test = gradcheck( + MyCustomFunction.apply, # Your function + (input, weight), # Tuple of inputs + eps=1e-6, # Finite difference step size + atol=1e-4, # Absolute tolerance + rtol=1e-3, # Relative tolerance + raise_exception=True # Raise error on failure (recommended) + ) + + if test: + print("✅ Gradient check PASSED!") + else: + print("❌ Gradient check FAILED!") + raise AssertionError("Gradient check failed") + +# Run before using your function +test_my_function() +``` + +### Complete gradcheck Pattern + +```python +import torch +from torch.autograd import gradcheck, gradgradcheck + +class MyCustomFunction(Function): + @staticmethod + def forward(ctx, input, weight, bias): + ctx.save_for_backward(input, weight, bias) + output = input.mm(weight.t()) + if bias is not None: + output += bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output.mm(weight) + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().mm(input) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + +def test_my_custom_function(): + """Comprehensive gradient testing.""" + + # Test 1: Basic gradcheck + print("Test 1: Basic gradient check...") + input = torch.randn(10, 5, dtype=torch.double, requires_grad=True) + weight = torch.randn(3, 5, dtype=torch.double, requires_grad=True) + bias = torch.randn(3, dtype=torch.double, requires_grad=True) + + assert gradcheck( + MyCustomFunction.apply, + (input, weight, bias), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "Basic gradcheck failed" + print("✅ Basic gradcheck passed") + + # Test 2: Without bias (optional parameter) + print("Test 2: Gradient check without bias...") + assert gradcheck( + MyCustomFunction.apply, + (input, weight, None), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "Gradcheck without bias failed" + print("✅ Gradcheck without bias passed") + + # Test 3: Different input shapes + print("Test 3: Different input shapes...") + input_large = torch.randn(50, 20, dtype=torch.double, requires_grad=True) + weight_large = torch.randn(10, 20, dtype=torch.double, requires_grad=True) + bias_large = torch.randn(10, dtype=torch.double, requires_grad=True) + + assert gradcheck( + MyCustomFunction.apply, + (input_large, weight_large, bias_large), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "Gradcheck with large inputs failed" + print("✅ Gradcheck with different shapes passed") + + # Test 4: Second-order gradients (if needed) + print("Test 4: Second-order gradient check...") + try: + assert gradgradcheck( + MyCustomFunction.apply, + (input, weight, bias), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "Second-order gradcheck failed" + print("✅ Second-order gradcheck passed") + except NotImplementedError: + print("⚠️ Second-order gradients not implemented (OK if not needed)") + + print("\n🎉 All gradient checks passed!") + +# ALWAYS run this before using your function in training +test_my_custom_function() +``` + +### gradcheck Parameters + +```python +gradcheck( + func, # Your Function.apply + inputs, # Tuple of input tensors + eps=1e-6, # Finite difference step: f(x+eps) - f(x-eps) + atol=1e-5, # Absolute tolerance for comparison + rtol=1e-3, # Relative tolerance for comparison + raise_exception=True, # Raise on failure (recommended for testing) + check_sparse_nnz=False, # Check sparse tensor non-zeros + nondet_tol=0.0, # Tolerance for non-deterministic operations + check_undefined_grad=True, # Check that undefined grads are None + check_grad_dtypes=True, # Check gradient dtypes match +) + +# Key insights: +# - Use double precision (dtype=torch.double) for numerical stability +# - eps=1e-6 is good default; smaller for more precision, larger for stability +# - atol/rtol balance: looser tolerances for complex operations +# - raise_exception=True catches bugs immediately in testing +``` + +### Debugging gradcheck Failures + +```python +def debug_gradcheck(): + """Step-by-step debugging when gradcheck fails.""" + + # Step 1: Check forward pass works + print("Step 1: Verify forward pass...") + input = torch.randn(5, 3, dtype=torch.double, requires_grad=True) + weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True) + + output = MyCustomFunction.apply(input, weight) + print(f"Output shape: {output.shape}") + print(f"Output contains NaN: {torch.isnan(output).any()}") + print(f"Output contains Inf: {torch.isinf(output).any()}") + assert output.shape == (5, 4), "Forward shape wrong" + assert not torch.isnan(output).any(), "Forward produces NaN" + + # Step 2: Check backward runs without error + print("\nStep 2: Verify backward runs...") + loss = output.sum() + loss.backward() + print(f"Input grad shape: {input.grad.shape}") + print(f"Weight grad shape: {weight.grad.shape}") + assert input.grad.shape == input.shape, "Input gradient shape mismatch" + assert weight.grad.shape == weight.shape, "Weight gradient shape mismatch" + + # Step 3: Check gradient magnitudes + print("\nStep 3: Check gradient magnitudes...") + print(f"Input grad: mean={input.grad.mean():.6f}, std={input.grad.std():.6f}") + print(f"Weight grad: mean={weight.grad.mean():.6f}, std={weight.grad.std():.6f}") + # Should be reasonable numbers (not 1e10 or 1e-20) + + # Step 4: Manual numerical gradient check for one element + print("\nStep 4: Manual gradient check for one element...") + input_test = torch.randn(3, 2, dtype=torch.double) + weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True) + + eps = 1e-6 + + # Analytical gradient + output = MyCustomFunction.apply(input_test, weight_test) + loss = output.sum() + loss.backward() + analytical_grad = weight_test.grad[0, 0].item() + + # Numerical gradient (finite difference) + weight_test_plus = weight_test.clone().detach() + weight_test_plus[0, 0] += eps + output_plus = MyCustomFunction.apply(input_test, weight_test_plus) + loss_plus = output_plus.sum() + + weight_test_minus = weight_test.clone().detach() + weight_test_minus[0, 0] -= eps + output_minus = MyCustomFunction.apply(input_test, weight_test_minus) + loss_minus = output_minus.sum() + + numerical_grad = (loss_plus - loss_minus) / (2 * eps) + numerical_grad = numerical_grad.item() + + print(f"Analytical gradient: {analytical_grad:.10f}") + print(f"Numerical gradient: {numerical_grad:.10f}") + print(f"Difference: {abs(analytical_grad - numerical_grad):.10e}") + + if abs(analytical_grad - numerical_grad) > 1e-4: + print("❌ Large difference - gradient implementation likely wrong") + else: + print("✅ Small difference - gradient likely correct") + + # Step 5: Run gradcheck with verbose output + print("\nStep 5: Run gradcheck...") + input_check = torch.randn(3, 2, dtype=torch.double, requires_grad=True) + weight_check = torch.randn(2, 2, dtype=torch.double, requires_grad=True) + + try: + result = gradcheck( + MyCustomFunction.apply, + (input_check, weight_check), + eps=1e-6, + atol=1e-4, + raise_exception=True + ) + print("✅ gradcheck passed!") + except RuntimeError as e: + print(f"❌ gradcheck failed with error:\n{e}") + # Error message shows which gradient failed and by how much + +debug_gradcheck() +``` + +### Common gradcheck Failure Reasons + +```python +# Failure 1: Wrong gradient formula +class WrongGradient(Function): + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + # ❌ WRONG: No transpose + grad_input = grad_output @ weight # Should be weight.t() + return grad_input, None +# gradcheck fails: analytical ≠ numerical + +# Failure 2: Shape mismatch +class WrongShape(Function): + @staticmethod + def backward(ctx, grad_output): + # ❌ WRONG: Returns wrong shape + return grad_output.sum(), None # Should be grad_output.shape == input.shape +# gradcheck fails: shape error + +# Failure 3: In-place operation +class InplaceOperation(Function): + @staticmethod + def backward(ctx, grad_output): + grad_output[grad_output < 0] = 0 # ❌ IN-PLACE + return grad_output +# gradcheck fails: modified by inplace operation + +# Failure 4: Not using saved tensors correctly +class WrongSaved(Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input.clone()) # Saved clone + return input * 2 + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + # Using saved tensor is OK, but if logic depends on + # input's original properties and clone loses them, fails + return grad_output * 2 +# May pass or fail depending on what was lost in clone + +# Failure 5: Forgot to return gradients for all inputs +class MissingGradient(Function): + @staticmethod + def forward(ctx, input, weight, bias): + # ... + return output + + @staticmethod + def backward(ctx, grad_output): + # ❌ WRONG: Only returns 2 gradients for 3 inputs + return grad_input, grad_weight +# gradcheck fails: tuple length mismatch +``` + +## Common Pitfalls + +### Pitfall 1: In-Place Operations + +In-place operations modify tensors that other operations depend on, breaking autograd. + +```python +# ❌ WRONG: In-place operation in forward +class InplaceForward(Function): + @staticmethod + def forward(ctx, input): + input[input < 0] = 0 # IN-PLACE - breaks autograd! + return input + +# ❌ WRONG: In-place operation in backward +class InplaceBackward(Function): + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_output[input < 0] = 0 # IN-PLACE - breaks autograd! + return grad_output + +# ✅ CORRECT: Create new tensor +class CorrectInplace(Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = input.clone() # Create new tensor + output[output < 0] = 0 # Modify copy, not original + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_input = grad_output.clone() # Create new tensor + grad_input[input < 0] = 0 # Modify copy + return grad_input + +# Even better: Use non-in-place operations +class BestInplace(Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return input.clamp(min=0) # Non-in-place + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_input = grad_output * (input > 0).float() # Non-in-place + return grad_input +``` + +### Pitfall 2: Gradient Shape Mismatch + +Gradient must match input shape exactly. + +```python +# ❌ WRONG: Gradient shape doesn't match input +class WrongShape(Function): + @staticmethod + def forward(ctx, input): + # input: (32, 128) + ctx.save_for_backward(input) + return input.sum() # scalar + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + # grad_output is scalar, but need (32, 128) + return grad_output # ❌ WRONG: scalar ≠ (32, 128) + +# ✅ CORRECT: Expand to match input shape +class CorrectShape(Function): + @staticmethod + def forward(ctx, input): + ctx.input_shape = input.shape + return input.sum() + + @staticmethod + def backward(ctx, grad_output): + # Expand scalar to input shape + grad_input = grad_output.expand(ctx.input_shape) + return grad_input + +# Always verify shapes +def verify_shapes(ctx, grad_output, *grad_inputs): + """Helper to verify gradient shapes.""" + for i, (grad, tensor) in enumerate(zip(grad_inputs, ctx.saved_tensors)): + if grad is not None: + assert grad.shape == tensor.shape, \ + f"Gradient {i} shape {grad.shape} != input shape {tensor.shape}" +``` + +### Pitfall 3: Not Checking needs_input_grad + +Computing gradients when not needed wastes computation. + +```python +# ❌ WASTEFUL: Always compute all gradients +class AlwaysCompute(Function): + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + + # Compute all gradients even if not needed + grad_input = expensive_computation_1(grad_output, weight) + grad_weight = expensive_computation_2(grad_output, input) + grad_bias = expensive_computation_3(grad_output) + + return grad_input, grad_weight, grad_bias + +# ✅ EFFICIENT: Check needs_input_grad first +class EfficientCompute(Function): + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + # Only compute if needed + if ctx.needs_input_grad[0]: + grad_input = expensive_computation_1(grad_output, weight) + + if ctx.needs_input_grad[1]: + grad_weight = expensive_computation_2(grad_output, input) + + if ctx.needs_input_grad[2]: + grad_bias = expensive_computation_3(grad_output) + + return grad_input, grad_weight, grad_bias + +# Example where it matters +def example_needs_input_grad(): + """Demonstrate needs_input_grad optimization.""" + + input = torch.randn(100, 100, requires_grad=True) + weight = torch.randn(100, 100, requires_grad=False) # No gradient needed + + # Without check: computes grad_weight unnecessarily + # With check: skips grad_weight computation (faster) + + output = MyFunction.apply(input, weight) + loss = output.sum() + loss.backward() + + # weight.grad is None because requires_grad=False + assert weight.grad is None +``` + +### Pitfall 4: Using .data Instead of .detach() + +`.data` bypasses autograd tracking incorrectly. + +```python +# ❌ WRONG: Using .data +class UsingData(Function): + @staticmethod + def forward(ctx, input): + # .data returns tensor without autograd tracking + # But doesn't properly detach from computation graph + ctx.save_for_backward(input.data) # ❌ WRONG + return input * 2 + + @staticmethod + def backward(ctx, grad_output): + input_data, = ctx.saved_tensors + # May produce incorrect gradients + return grad_output * 2 + +# ✅ CORRECT: Use .detach() or save normally +class UsingDetach(Function): + @staticmethod + def forward(ctx, input): + # If you need to save without tracking gradient + ctx.save_for_backward(input.detach()) # ✅ Properly detaches + # Or just save normally if gradient tracking is OK + ctx.save_for_backward(input) # ✅ Most common + return input * 2 + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output * 2 + +# When to detach in custom functions +class WhenToDetach(Function): + @staticmethod + def forward(ctx, input, target): + # Save input normally (gradient needed) + # Detach target (gradient not needed, just reference) + ctx.save_for_backward(input, target.detach()) + + loss = (input - target).pow(2).mean() + return loss + + @staticmethod + def backward(ctx, grad_output): + input, target = ctx.saved_tensors + # Compute gradient w.r.t. input only + grad_input = 2 * (input - target) * grad_output + return grad_input, None +``` + +### Pitfall 5: Modifying grad_output + +Never modify grad_output in-place; it's used by other operations. + +```python +# ❌ WRONG: Modifying grad_output in-place +class ModifyGradOutput(Function): + @staticmethod + def backward(ctx, grad_output): + # ❌ WRONG: In-place modification + grad_output *= 2 + return grad_output + +# ✅ CORRECT: Create new tensor +class DontModifyGradOutput(Function): + @staticmethod + def backward(ctx, grad_output): + # ✅ Create new tensor + grad_input = grad_output * 2 + return grad_input + +# Why it matters +def why_grad_output_matters(): + """Demonstrate why modifying grad_output breaks autograd.""" + + # Consider: z = f(g(x)) + # Backward: dz/dx = dz/dg * dg/dx + # + # If f.backward() modifies grad_output (dz/dg), + # then g.backward() receives wrong gradient! + + x = torch.randn(5, requires_grad=True) + + # g(x) + y = x * 2 + + # f(g(x)) - uses custom function that modifies grad_output + z = BadFunction.apply(y) + + z.backward() + + # x.grad is now WRONG because BadFunction modified grad_output + # that was passed to y's backward +``` + +### Pitfall 6: Forgetting to Return None for Non-Tensor Arguments + +Must return gradient for every forward() argument. + +```python +# ❌ WRONG: Not enough return values +class NotEnoughReturns(Function): + @staticmethod + def forward(ctx, input, kernel_size, stride): + # 3 arguments (excluding ctx) + ctx.save_for_backward(input) + ctx.kernel_size = kernel_size + ctx.stride = stride + return some_operation(input, kernel_size, stride) + + @staticmethod + def backward(ctx, grad_output): + # ❌ WRONG: Only returns 1 value for 3 inputs + return grad_output # Crashes: expected 3 values + +# ✅ CORRECT: Return gradient (or None) for each input +class EnoughReturns(Function): + @staticmethod + def forward(ctx, input, kernel_size, stride): + ctx.save_for_backward(input) + ctx.kernel_size = kernel_size + ctx.stride = stride + return some_operation(input, kernel_size, stride) + + @staticmethod + def backward(ctx, grad_output): + # ✅ Return 3 values: grad for input, None for kernel_size, None for stride + return grad_output, None, None + +# Rule of thumb +def count_returns(): + """ + backward() must return one value per forward() argument (excluding ctx). + + forward(ctx, a, b, c, d=None) → backward must return (grad_a, grad_b, grad_c, grad_d) + + Use None for: + - Non-tensor arguments (ints, strings, etc.) + - Optional arguments that were None + - Tensors that don't need gradients + """ + pass +``` + +### Pitfall 7: Incorrect Broadcasting in Gradient + +Gradient must account for broadcasting that occurred in forward. + +```python +# ❌ WRONG: Doesn't handle broadcasting correctly +class WrongBroadcast(Function): + @staticmethod + def forward(ctx, input, weight): + # input: (32, 64, 10, 10) + # weight: (64, 1, 1) - broadcasts to (32, 64, 10, 10) + ctx.save_for_backward(input, weight) + output = input * weight # Broadcasting happens + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + grad_input = grad_output * weight # ✅ This is fine + + # ❌ WRONG: grad_weight shape is (32, 64, 10, 10), should be (64, 1, 1) + grad_weight = grad_output * input + return grad_input, grad_weight # Shape mismatch! + +# ✅ CORRECT: Sum over broadcasted dimensions +class CorrectBroadcast(Function): + @staticmethod + def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) + ctx.input_shape = input.shape + ctx.weight_shape = weight.shape + output = input * weight + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + grad_input = grad_output * weight + + # ✅ Sum over dimensions that were broadcasted + grad_weight = grad_output * input + # weight shape: (64, 1, 1), grad_weight current: (32, 64, 10, 10) + # Sum over batch (0), height (2), width (3) + grad_weight = grad_weight.sum(dim=(0, 2, 3), keepdim=True) + # Now grad_weight shape: (1, 64, 1, 1) + # Squeeze batch dimension: (64, 1, 1) ✅ + grad_weight = grad_weight.squeeze(0) + + return grad_input, grad_weight + +# General broadcasting gradient pattern +def sum_to_shape(tensor, shape): + """Sum tensor to match target shape (handles broadcasting).""" + # Find dimensions that were added + while tensor.dim() > len(shape): + tensor = tensor.sum(0) + + # Find dimensions that were size 1 and got broadcasted + for i, (t_dim, s_dim) in enumerate(zip(tensor.shape, shape)): + if s_dim == 1 and t_dim > 1: + tensor = tensor.sum(i, keepdim=True) + + return tensor + +class GeneralBroadcast(Function): + @staticmethod + def backward(ctx, grad_output): + input, param = ctx.saved_tensors + + grad_input = grad_output * param + + grad_param = grad_output * input + # Sum to match param's original shape + grad_param = sum_to_shape(grad_param, param.shape) + + return grad_input, grad_param +``` + +## Memory Efficiency Patterns + +Custom functions enable memory optimizations not possible with standard autograd. + +### Pattern 1: Gradient Checkpointing + +Trade computation for memory by recomputing forward in backward. + +```python +class CheckpointedFunction(Function): + """ + Gradient checkpointing: Don't save activations, recompute in backward. + + Memory: O(1) instead of O(n) for n layers + Time: 2x forward pass (once in forward, once in backward) + """ + + @staticmethod + def forward(ctx, input, *args): + # Save inputs and parameters, NOT intermediate activations + ctx.save_for_backward(input, *args) + + # Compute forward pass (activations not saved) + # For complex operations, this may compute many intermediate values + output = expensive_computation(input, *args) + + # Intermediate activations are garbage collected + # Saves memory! + return output + + @staticmethod + def backward(ctx, grad_output): + # Retrieve inputs + input, *args = ctx.saved_tensors + + # Recompute forward pass to get intermediate values + # This time, track gradients + with torch.enable_grad(): + # Detach input, then set requires_grad + # (Required for computing gradients) + input = input.detach().requires_grad_(True) + + # Recompute forward + output = expensive_computation(input, *args) + + # Now compute gradients using autograd + grad_input, = torch.autograd.grad( + outputs=output, + inputs=input, + grad_outputs=grad_output, + retain_graph=False + ) + + # Return gradient for input (and None for args if they're parameters) + return (grad_input,) + (None,) * len(args) + +# Example: Checkpointed Sequential Layers +class CheckpointedSequential(Function): + @staticmethod + def forward(ctx, input, *layers): + """ + Forward through multiple layers without saving intermediate activations. + + Normal: Saves n-1 activations for n layers + Checkpointed: Saves only input and parameters + """ + ctx.layers = layers + ctx.save_for_backward(input) + + # Forward through all layers + output = input + for layer in layers: + output = layer(output) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + layers = ctx.layers + + # Recompute forward to get intermediate activations + with torch.enable_grad(): + input = input.detach().requires_grad_(True) + + # Forward pass again, this time tracking activations + activations = [input] + output = input + for layer in layers: + output = layer(output) + activations.append(output) + + # Backward through layers + grad = grad_output + for i in reversed(range(len(layers))): + # Compute gradient through layer i + grad, = torch.autograd.grad( + outputs=activations[i+1], + inputs=activations[i], + grad_outputs=grad, + retain_graph=True + ) + + grad_input = grad + + # No gradients for layers (they're not inputs to forward) + return (grad_input,) + (None,) * len(layers) + +# PyTorch provides torch.utils.checkpoint.checkpoint for this +# But understanding the pattern helps for custom cases +``` + +### Pattern 2: Selective Saving + +Only save what's needed for backward; recompute or omit the rest. + +```python +class SelectiveSaving(Function): + """ + Save only essential tensors; recompute others in backward. + """ + + @staticmethod + def forward(ctx, input, weight, bias): + # Compute intermediate values + weighted = input @ weight + activated = torch.relu(weighted) + output = activated + bias + + # DON'T save everything: + # ctx.save_for_backward(input, weight, bias, weighted, activated) + # ❌ Saves 5 tensors + + # ✅ SAVE ONLY WHAT'S NEEDED: + # Can recompute 'weighted' from input and weight + # Can recompute 'activated' from weighted + ctx.save_for_backward(input, weight, bias, activated) + # ✅ Saves 4 tensors (or even fewer) + + # Or even more selective: + # ctx.save_for_backward(input, weight, bias) + # ctx.save_for_backward(activated > 0) # Save mask, not full tensor + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias, activated = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + + # Gradient through bias addition + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + # Gradient through ReLU (need activation mask) + grad_activated = grad_output.clone() + grad_activated[activated <= 0] = 0 + + # Gradient through matmul + if ctx.needs_input_grad[0]: + grad_input = grad_activated @ weight.t() + if ctx.needs_input_grad[1]: + grad_weight = input.t() @ grad_activated + + return grad_input, grad_weight, grad_bias + +# Even more selective: Save boolean mask instead of full tensor +class UltraSelective(Function): + @staticmethod + def forward(ctx, input, weight): + output = torch.relu(input @ weight) + + # Instead of saving full 'output' tensor: + # ctx.save_for_backward(input, weight, output) # Large memory + + # ✅ Save only boolean mask (1 bit per element vs 32 bits) + ctx.save_for_backward(input, weight) + ctx.relu_mask = (output > 0) # Boolean tensor (much smaller) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + relu_mask = ctx.relu_mask + + # Use mask to apply ReLU gradient + grad_weighted = grad_output * relu_mask.float() + + grad_input = grad_weighted @ weight.t() if ctx.needs_input_grad[0] else None + grad_weight = input.t() @ grad_weighted if ctx.needs_input_grad[1] else None + + return grad_input, grad_weight +``` + +### Pattern 3: Detaching Tensors That Don't Need Gradients + +Detach tensors that are used in forward but don't need gradients. + +```python +class DetachPattern(Function): + """ + Detach tensors that don't need gradient computation. + """ + + @staticmethod + def forward(ctx, input, target, weight): + """ + Compute loss between input and target. + Target doesn't need gradients (it's labels). + """ + # Save input and weight (need gradients) + ctx.save_for_backward(input, weight) + + # Detach target (doesn't need gradients) + # This breaks the autograd connection, saving memory + ctx.target = target.detach() + + # Compute weighted loss + loss = ((input - target) ** 2 * weight).mean() + return loss + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + target = ctx.target # Already detached + + # Compute gradients + diff = input - target + + grad_input = None + grad_weight = None + + if ctx.needs_input_grad[0]: + grad_input = 2 * diff * weight * grad_output + + # No gradient for target (ctx.needs_input_grad[1] is False) + + if ctx.needs_input_grad[2]: + grad_weight = (diff ** 2) * grad_output + + return grad_input, None, grad_weight + +# When to detach +def detach_guidelines(): + """ + Detach tensors when: + 1. They're labels/targets (no gradient needed) + 2. They're constants (no gradient needed) + 3. They're from non-differentiable sources + 4. You explicitly don't want gradients to flow through them + + Don't detach when: + 1. Gradient is needed for that tensor + 2. Gradient will flow through that path + """ + pass +``` + +## Advanced Patterns + +### Pattern 1: Double Backward (Second-Order Derivatives) + +Support computing gradients of gradients. + +```python +class DoubleBackwardFunction(Function): + """ + Function that supports double backward for second-order derivatives. + + Example: Hessian computation, meta-learning, some regularization terms. + """ + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + # Forward: y = x^2 + return input ** 2 + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + + # First derivative: dy/dx = 2x + grad_input = 2 * input * grad_output + + # For double backward, need to return tensor that supports backward + # Not just detached scalar + return grad_input + + # For explicit double backward support (optional, autograd often handles it) + @staticmethod + def jvp(ctx, *grad_inputs): + """ + Jacobian-vector product for forward-mode AD. + Needed for some second-order derivative computations. + """ + # Usually not needed; autograd handles it + pass + +# Test double backward +def test_double_backward(): + """Test that double backward works.""" + + x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + + # First forward + y = DoubleBackwardFunction.apply(x) + + # First backward: dy/dx + grad_x, = torch.autograd.grad(y.sum(), x, create_graph=True) + # grad_x = 2x = [2, 4, 6] + + # Second backward: d(dy/dx)/dx = d(2x)/dx = 2 + grad_grad_x, = torch.autograd.grad(grad_x.sum(), x) + # grad_grad_x = [2, 2, 2] + + print(f"First derivative: {grad_x}") # [2, 4, 6] + print(f"Second derivative: {grad_grad_x}") # [2, 2, 2] + + assert torch.allclose(grad_x, 2 * x) + assert torch.allclose(grad_grad_x, torch.ones_like(x) * 2) + print("✅ Double backward works!") + +test_double_backward() + +# Example: Function where double backward matters +class HessianVectorProduct(Function): + """ + Efficiently compute Hessian-vector product: H @ v + where H = ∇²f(x) is the Hessian matrix. + + Used in: second-order optimization, meta-learning. + """ + + @staticmethod + def forward(ctx, input, vector): + ctx.save_for_backward(input, vector) + # Placeholder forward (actual computation in backward) + return input + + @staticmethod + def backward(ctx, grad_output): + input, vector = ctx.saved_tensors + + # Compute gradient + grad_input = grad_output + + # This gradient can be backpropagated through again + # for Hessian computation + return grad_input, None + +# Using double backward for Hessian +def compute_hessian(f, x): + """ + Compute Hessian of scalar function f at point x. + + Uses double backward: H[i,j] = ∂²f/∂x_i∂x_j + """ + # First derivative + grad_x, = torch.autograd.grad(f(x), x, create_graph=True) + + # Second derivative (Hessian) + hessian = [] + for i in range(x.shape[0]): + grad_grad_x, = torch.autograd.grad( + grad_x[i], x, retain_graph=True + ) + hessian.append(grad_grad_x) + + return torch.stack(hessian) +``` + +### Pattern 2: Custom Backward Hooks + +Modify gradients during backward pass without changing the function. + +```python +def gradient_clipping_hook(grad): + """ + Hook to clip gradients to [-1, 1]. + Applied to tensor, not Function. + """ + return torch.clamp(grad, -1, 1) + +def gradient_noise_hook(grad): + """Add noise to gradients (for regularization).""" + noise = torch.randn_like(grad) * 0.01 + return grad + noise + +def gradient_logging_hook(grad): + """Log gradient statistics.""" + print(f"Gradient: mean={grad.mean():.6f}, std={grad.std():.6f}, max={grad.abs().max():.6f}") + return grad # Return unchanged + +# Using hooks +def use_hooks(): + """Example of using gradient hooks.""" + + input = torch.randn(10, 10, requires_grad=True) + weight = torch.randn(10, 10, requires_grad=True) + + # Register hooks + input.register_hook(gradient_clipping_hook) + weight.register_hook(gradient_logging_hook) + + # Forward and backward + output = MyFunction.apply(input, weight) + loss = output.sum() + loss.backward() + + # Hooks are applied during backward + # input.grad is clipped to [-1, 1] + # weight.grad statistics are logged + +# Hooks in custom functions +class FunctionWithHook(Function): + @staticmethod + def forward(ctx, input, clip_grad=False): + ctx.clip_grad = clip_grad + ctx.save_for_backward(input) + return input * 2 + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + + grad_input = grad_output * 2 + + # Apply custom modification based on context + if ctx.clip_grad: + grad_input = torch.clamp(grad_input, -1, 1) + + return grad_input, None + +# Removing hooks +def manage_hooks(): + """Add and remove hooks.""" + + tensor = torch.randn(5, requires_grad=True) + + # Add hook (returns handle) + hook_handle = tensor.register_hook(gradient_clipping_hook) + + # Use tensor + loss = (tensor ** 2).sum() + loss.backward() + # Hook is applied + + # Remove hook + hook_handle.remove() + + # Hook no longer applied in subsequent backwards + tensor.grad.zero_() + loss = (tensor ** 2).sum() + loss.backward() + # Hook NOT applied +``` + +### Pattern 3: Custom Gradient for Part of Computation + +Stop gradients or customize them for specific operations. + +```python +class StopGradient(Function): + """ + Stop gradient flow (like tf.stop_gradient or tensor.detach()). + Forward: pass through + Backward: return None (no gradient) + """ + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + # Don't pass gradient through + return None + +# Usage: Stop gradient flow +def stop_gradient_example(): + x = torch.randn(5, requires_grad=True) + + # y doesn't get gradients from z + y = StopGradient.apply(x) + z = y ** 2 + z.sum().backward() + + assert x.grad is None # No gradient flowed to x + +class StraightThroughEstimator(Function): + """ + Straight-through estimator for non-differentiable operations. + + Forward: non-differentiable operation (e.g., binarization) + Backward: pretend it was identity (pass gradient through) + + Used for: quantization, binarization, discrete operations. + """ + + @staticmethod + def forward(ctx, input): + # Non-differentiable: binarize to {-1, 1} + return torch.sign(input) + + @staticmethod + def backward(ctx, grad_output): + # Pretend forward was identity: gradient passes through unchanged + return grad_output + +# Usage: Train networks with binary weights +def straight_through_example(): + """ + Train a network with binary weights using straight-through estimator. + """ + weight = torch.randn(10, 10, requires_grad=True) + input = torch.randn(32, 10) + + # Binarize weight for forward pass + binary_weight = StraightThroughEstimator.apply(weight) + # binary_weight ∈ {-1, 1} + + # Use binary weight + output = input @ binary_weight + loss = output.sum() + loss.backward() + + # weight.grad exists (even though sign() isn't differentiable) + # Gradient passed through as if sign() was identity + assert weight.grad is not None + +class CustomGradientScale(Function): + """ + Scale gradient by a factor without changing forward. + + Forward: pass through + Backward: scale gradient by alpha + + Used for: gradient reversal layers (adversarial training), + controlling gradient flow in different branches. + """ + + @staticmethod + def forward(ctx, input, alpha): + ctx.alpha = alpha + return input + + @staticmethod + def backward(ctx, grad_output): + # Scale gradient + return grad_output * ctx.alpha, None + +# Usage: Gradient reversal layer +def gradient_reversal_layer(input, alpha=-1.0): + """ + Reverses gradients (multiplies by -1). + Used in domain adaptation for adversarial training. + """ + return CustomGradientScale.apply(input, alpha) +``` + +## Complete Real-World Examples + +### Example 1: Custom Swish Activation with Learnable Beta + +```python +import torch +from torch.autograd import Function, gradcheck +import torch.nn as nn + +class SwishFunction(Function): + """ + Custom Swish activation: f(x) = x * sigmoid(β * x) + where β is a learnable parameter. + + Forward: y = x * σ(βx) + Backward: dy/dx = σ(βx) + x * σ(βx) * (1 - σ(βx)) * β + dy/dβ = x² * σ(βx) * (1 - σ(βx)) + """ + + @staticmethod + def forward(ctx, input, beta): + """ + Args: + input: Input tensor + beta: Learnable scaling parameter (scalar tensor) + Returns: + output: Swish activation output + """ + ctx.save_for_backward(input, beta) + + # Compute sigmoid(beta * input) + sigmoid_beta_input = torch.sigmoid(beta * input) + + # Save for backward (more efficient than recomputing) + ctx.save_for_backward(input, beta, sigmoid_beta_input) + + # f(x) = x * sigmoid(beta * x) + output = input * sigmoid_beta_input + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Compute gradients using chain rule. + """ + input, beta, sigmoid_beta_input = ctx.saved_tensors + + grad_input = grad_beta = None + + # Gradient w.r.t. input + if ctx.needs_input_grad[0]: + # d/dx[x * σ(βx)] = σ(βx) + x * σ'(βx) * β + # where σ'(z) = σ(z) * (1 - σ(z)) + sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input) + grad_input = grad_output * ( + sigmoid_beta_input + input * sigmoid_derivative * beta + ) + + # Gradient w.r.t. beta + if ctx.needs_input_grad[1]: + # d/dβ[x * σ(βx)] = x * σ'(βx) * x = x² * σ(βx) * (1 - σ(βx)) + sigmoid_derivative = sigmoid_beta_input * (1 - sigmoid_beta_input) + grad_beta = grad_output * (input ** 2) * sigmoid_derivative + # Sum over all elements (beta is scalar) + grad_beta = grad_beta.sum() + + return grad_input, grad_beta + +class Swish(nn.Module): + """ + Swish activation module with learnable beta parameter. + """ + + def __init__(self, beta=1.0): + super().__init__() + self.beta = nn.Parameter(torch.tensor(beta)) + + def forward(self, input): + return SwishFunction.apply(input, self.beta) + +# Test with gradcheck +def test_swish(): + """Verify Swish implementation with gradcheck.""" + print("Testing Swish activation...") + + # Test 1: Basic gradcheck + input = torch.randn(10, 5, dtype=torch.double, requires_grad=True) + beta = torch.tensor(1.0, dtype=torch.double, requires_grad=True) + + assert gradcheck( + SwishFunction.apply, + (input, beta), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "Swish gradcheck failed" + print("✅ Basic gradcheck passed") + + # Test 2: Different beta values + for beta_val in [0.5, 1.0, 2.0]: + beta = torch.tensor(beta_val, dtype=torch.double, requires_grad=True) + assert gradcheck( + SwishFunction.apply, + (input, beta), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), f"Swish gradcheck failed for beta={beta_val}" + print("✅ Multiple beta values passed") + + # Test 3: Use in module + module = Swish(beta=1.0) + input_single = torch.randn(32, 128, requires_grad=True) + output = module(input_single) + loss = output.sum() + loss.backward() + + assert input_single.grad is not None + assert module.beta.grad is not None + print("✅ Module usage works") + + print("\n🎉 Swish activation fully tested!") + +test_swish() + +# Usage example +def use_swish_in_model(): + """Use Swish in a neural network.""" + + model = nn.Sequential( + nn.Linear(784, 256), + Swish(beta=1.0), # Learnable beta + nn.Linear(256, 128), + Swish(beta=1.0), + nn.Linear(128, 10) + ) + + # Train model... + # Beta parameters will be learned along with weights +``` + +### Example 2: Numerically Stable LogSumExp + +```python +class LogSumExp(Function): + """ + Numerically stable log-sum-exp operation. + + Forward: log(sum(exp(x_i))) + Uses max trick: log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x)))) + + Backward: softmax(x_i) + """ + + @staticmethod + def forward(ctx, input, dim): + """ + Args: + input: Input tensor + dim: Dimension to sum over + Returns: + logsumexp: log(sum(exp(input))) along dim + """ + # Max trick for numerical stability + max_val, _ = input.max(dim=dim, keepdim=True) + input_shifted = input - max_val + + # Compute log-sum-exp + sumexp = torch.exp(input_shifted).sum(dim=dim, keepdim=True) + logsumexp = torch.log(sumexp) + max_val + + # Save softmax for backward + softmax = torch.exp(input_shifted) / sumexp + ctx.save_for_backward(softmax) + ctx.dim = dim + + return logsumexp.squeeze(dim) + + @staticmethod + def backward(ctx, grad_output): + """ + Gradient of log-sum-exp is softmax. + """ + softmax, = ctx.saved_tensors + + # Expand grad_output to match softmax shape + grad_output_expanded = grad_output.unsqueeze(ctx.dim) + + # Gradient: softmax * grad_output + grad_input = softmax * grad_output_expanded + + return grad_input, None + +# Test +def test_logsumexp(): + """Test LogSumExp implementation.""" + print("Testing LogSumExp...") + + input = torch.randn(10, 5, dtype=torch.double, requires_grad=True) + dim = 1 + + # gradcheck + assert gradcheck( + lambda x: LogSumExp.apply(x, dim), + (input,), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "LogSumExp gradcheck failed" + print("✅ LogSumExp gradcheck passed") + + # Compare with PyTorch's implementation + custom_result = LogSumExp.apply(input, dim) + torch_result = torch.logsumexp(input, dim=dim) + + assert torch.allclose(custom_result, torch_result, atol=1e-5), \ + "LogSumExp doesn't match PyTorch" + print("✅ LogSumExp matches PyTorch implementation") + + print("\n🎉 LogSumExp fully tested!") + +test_logsumexp() +``` + +### Example 3: Fused Linear + ReLU (Memory Efficient) + +```python +class FusedLinearReLU(Function): + """ + Fused linear + ReLU operation. + Saves memory by not materializing intermediate activations. + + Forward: ReLU(X @ W + b) + Memory: Only saves mask (boolean) and input/weights, not intermediate + """ + + @staticmethod + def forward(ctx, input, weight, bias): + """ + Args: + input: (batch, in_features) + weight: (out_features, in_features) + bias: (out_features,) + Returns: + output: (batch, out_features) + """ + # Compute linear + linear_output = input.mm(weight.t()) + if bias is not None: + linear_output += bias + + # Apply ReLU and save mask (not full tensor!) + output = torch.relu(linear_output) + relu_mask = (linear_output > 0) # Boolean (1 bit per element) + + # Save only input, weight, bias, and mask + # NOT saving linear_output (saves memory) + ctx.save_for_backward(input, weight, bias) + ctx.relu_mask = relu_mask + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward through ReLU and linear. + """ + input, weight, bias = ctx.saved_tensors + relu_mask = ctx.relu_mask + + grad_input = grad_weight = grad_bias = None + + # Gradient through ReLU (use mask) + grad_linear = grad_output * relu_mask.float() + + # Gradient through linear + if ctx.needs_input_grad[0]: + grad_input = grad_linear.mm(weight) + + if ctx.needs_input_grad[1]: + grad_weight = grad_linear.t().mm(input) + + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_linear.sum(0) + + return grad_input, grad_weight, grad_bias + +# Test +def test_fused_linear_relu(): + """Test fused operation.""" + print("Testing FusedLinearReLU...") + + input = torch.randn(20, 10, dtype=torch.double, requires_grad=True) + weight = torch.randn(15, 10, dtype=torch.double, requires_grad=True) + bias = torch.randn(15, dtype=torch.double, requires_grad=True) + + # gradcheck + assert gradcheck( + FusedLinearReLU.apply, + (input, weight, bias), + eps=1e-6, + atol=1e-4, + raise_exception=True + ), "FusedLinearReLU gradcheck failed" + print("✅ FusedLinearReLU gradcheck passed") + + # Compare with separate operations + fused_output = FusedLinearReLU.apply(input, weight, bias) + + separate_output = torch.relu(input.mm(weight.t()) + bias) + + assert torch.allclose(fused_output, separate_output, atol=1e-5), \ + "Fused output doesn't match separate operations" + print("✅ Fused matches separate operations") + + print("\n🎉 FusedLinearReLU fully tested!") + +test_fused_linear_relu() +``` + +## Debugging Custom Functions + +### Systematic Debugging Workflow + +When your custom function has bugs, follow this workflow: + +```python +def debug_custom_function(): + """ + Step-by-step debugging of custom autograd functions. + """ + + print("=== DEBUGGING CUSTOM AUTOGRAD FUNCTION ===\n") + + # Step 1: Verify forward pass works + print("Step 1: Testing forward pass...") + try: + input = torch.randn(5, 3, dtype=torch.double) + weight = torch.randn(4, 3, dtype=torch.double) + + output = MyCustomFunction.apply(input, weight) + + print(f"✅ Forward pass works") + print(f" Input shape: {input.shape}") + print(f" Weight shape: {weight.shape}") + print(f" Output shape: {output.shape}") + print(f" Output contains NaN: {torch.isnan(output).any()}") + print(f" Output contains Inf: {torch.isinf(output).any()}") + + # Check expected shape + expected_shape = (5, 4) # Based on your operation + assert output.shape == expected_shape, \ + f"Wrong output shape: {output.shape} != {expected_shape}" + print(f" Output shape correct: {expected_shape}") + + except Exception as e: + print(f"❌ Forward pass failed: {e}") + return + + print() + + # Step 2: Verify backward runs without error + print("Step 2: Testing backward pass...") + try: + input = torch.randn(5, 3, dtype=torch.double, requires_grad=True) + weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True) + + output = MyCustomFunction.apply(input, weight) + loss = output.sum() + loss.backward() + + print(f"✅ Backward pass runs") + print(f" Input grad shape: {input.grad.shape}") + print(f" Weight grad shape: {weight.grad.shape}") + + # Check gradient shapes + assert input.grad.shape == input.shape, \ + f"Input gradient shape mismatch: {input.grad.shape} != {input.shape}" + assert weight.grad.shape == weight.shape, \ + f"Weight gradient shape mismatch: {weight.grad.shape} != {weight.shape}" + print(f" Gradient shapes correct") + + # Check for NaN/Inf + assert not torch.isnan(input.grad).any(), "Input gradient contains NaN" + assert not torch.isnan(weight.grad).any(), "Weight gradient contains NaN" + assert not torch.isinf(input.grad).any(), "Input gradient contains Inf" + assert not torch.isinf(weight.grad).any(), "Weight gradient contains Inf" + print(f" Gradients are finite") + + except Exception as e: + print(f"❌ Backward pass failed: {e}") + return + + print() + + # Step 3: Check gradient magnitudes + print("Step 3: Checking gradient magnitudes...") + print(f" Input grad - mean: {input.grad.mean():.6f}, std: {input.grad.std():.6f}, max: {input.grad.abs().max():.6f}") + print(f" Weight grad - mean: {weight.grad.mean():.6f}, std: {weight.grad.std():.6f}, max: {weight.grad.abs().max():.6f}") + + # Reasonable gradient magnitudes (problem-dependent) + if input.grad.abs().max() > 1e6: + print(f" ⚠️ Input gradient very large (may indicate bug)") + if input.grad.abs().max() < 1e-6: + print(f" ⚠️ Input gradient very small (may indicate vanishing gradient)") + + print() + + # Step 4: Manual numerical gradient check for one element + print("Step 4: Manual numerical gradient check...") + + input_test = torch.randn(3, 2, dtype=torch.double, requires_grad=True) + weight_test = torch.randn(2, 2, dtype=torch.double, requires_grad=True) + + # Analytical gradient + output = MyCustomFunction.apply(input_test, weight_test) + loss = output.sum() + loss.backward() + analytical_grad = weight_test.grad[0, 0].item() + + # Numerical gradient + eps = 1e-6 + + weight_plus = weight_test.clone().detach() + weight_plus[0, 0] += eps + output_plus = MyCustomFunction.apply(input_test, weight_plus) + loss_plus = output_plus.sum() + + weight_minus = weight_test.clone().detach() + weight_minus[0, 0] -= eps + output_minus = MyCustomFunction.apply(input_test, weight_minus) + loss_minus = output_minus.sum() + + numerical_grad = ((loss_plus - loss_minus) / (2 * eps)).item() + + diff = abs(analytical_grad - numerical_grad) + print(f" Analytical gradient: {analytical_grad:.10f}") + print(f" Numerical gradient: {numerical_grad:.10f}") + print(f" Absolute difference: {diff:.10e}") + + if diff < 1e-4: + print(f" ✅ Small difference - gradient likely correct") + else: + print(f" ❌ Large difference - gradient likely WRONG") + + print() + + # Step 5: Full gradcheck + print("Step 5: Running full gradcheck...") + try: + input_check = torch.randn(10, 5, dtype=torch.double, requires_grad=True) + weight_check = torch.randn(8, 5, dtype=torch.double, requires_grad=True) + + result = gradcheck( + MyCustomFunction.apply, + (input_check, weight_check), + eps=1e-6, + atol=1e-4, + raise_exception=True + ) + + print(f" ✅ gradcheck PASSED!") + print(f"\n🎉 All checks passed! Function is correct.") + + except RuntimeError as e: + print(f" ❌ gradcheck FAILED") + print(f" Error: {e}") + print(f"\n Debug hints:") + print(f" - Check gradient computation formulas") + print(f" - Verify all transposes are correct") + print(f" - Ensure shapes match everywhere") + print(f" - Check for in-place operations") + print(f" - Verify saved tensors are correct") + +# Run debugging +debug_custom_function() +``` + +### Common Error Messages and Solutions + +```python +""" +ERROR 1: "one of the variables needed for gradient computation has been modified by an inplace operation" + +CAUSE: In-place operation in forward or backward + +SOLUTION: Replace in-place ops with non-in-place versions + - Replace: tensor[mask] = value + - With: tensor = tensor.clone(); tensor[mask] = value + - Or use: tensor * mask instead of masking + + +ERROR 2: "grad can be implicitly created only for scalar outputs" + +CAUSE: Calling .backward() on non-scalar tensor without grad_output + +SOLUTION: Either sum to scalar or provide grad_output + - loss = output.sum(); loss.backward() + - Or: output.backward(torch.ones_like(output)) + + +ERROR 3: "Expected to get X gradient(s) for backward, but got Y" + +CAUSE: backward() returns wrong number of gradients + +SOLUTION: Return one gradient per forward() argument (excluding ctx) + - forward(ctx, a, b, c) → backward must return (grad_a, grad_b, grad_c) + - Use None for arguments that don't need gradients + + +ERROR 4: "Sizes of tensors must match except in dimension X" + +CAUSE: Shape mismatch in gradient computation + +SOLUTION: Ensure grad_input.shape == input.shape + - Print shapes to debug: print(f"grad shape: {grad.shape}, input shape: {input.shape}") + - Handle broadcasting by summing over broadcasted dimensions + + +ERROR 5: "RuntimeError: Function returned an invalid gradient at index X - got ... but expected shape ..." + +CAUSE: Gradient shape doesn't match input shape + +SOLUTION: Verify gradient shape matches input exactly + - assert grad_input.shape == input.shape + - Use .view(), .reshape(), .expand() to fix shape + + +ERROR 6: "gradcheck failed" + +CAUSE: Analytical gradient ≠ numerical gradient + +SOLUTION: Debug gradient computation + - Check math formulas (derivatives) + - Verify transposes in matrix operations + - Test manually with small tensors + - Run debug_custom_function() above + + +ERROR 7: "AttributeError: 'Context' object has no attribute 'saved_tensors'" + +CAUSE: Accessing saved_tensors in forward (only available in backward) + +SOLUTION: Only access saved_tensors in backward() + - forward: ctx.save_for_backward(...) + - backward: ctx.saved_tensors + + +ERROR 8: "TypeError: save_for_backward() takes 1 positional argument but X were given" + +CAUSE: Passing non-tensors to save_for_backward() + +SOLUTION: Only save tensors with save_for_backward(), use attributes for others + - ctx.save_for_backward(tensor1, tensor2) # Tensors only + - ctx.some_value = non_tensor_data # Non-tensors as attributes +""" +``` + +## When NOT to Use Custom Functions + +Recognize when custom functions are unnecessary: + +```python +# DON'T: Wrapping simple PyTorch operations +class UnnecessaryAdd(Function): # ❌ Pointless + @staticmethod + def forward(ctx, a, b): + return a + b + @staticmethod + def backward(ctx, grad): + return grad, grad + +# DO: Use PyTorch directly +output = a + b # ✅ Autograd handles this + + +# DON'T: Reimplementing existing activations +class UnnecessaryReLU(Function): # ❌ Use torch.relu + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return input.clamp(min=0) + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output * (input > 0).float() + +# DO: Use built-in +output = torch.relu(input) # ✅ Optimized C++ implementation + + +# DON'T: Operations that compose from standard ops +class UnnecessaryGELU(Function): # ❌ Can compose from existing ops + @staticmethod + def forward(ctx, input): + # GELU = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # This is composed entirely from standard ops + ctx.save_for_backward(input) + return 0.5 * input * (1 + torch.tanh(...)) + @staticmethod + def backward(ctx, grad_output): + # Complex gradient computation + ... + +# DO: Let autograd handle composition +def gelu(input): # ✅ Autograd computes gradients automatically + return 0.5 * input * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (input + 0.044715 * input ** 3) + )) + +# Or even better: use built-in +output = torch.nn.functional.gelu(input) # ✅ Most efficient + + +# WHEN TO USE: True custom operations +class ClippedGradientLinear(Function): # ✅ Custom gradient behavior + """ + Linear operation with clipped gradients. + Can't compose from standard ops - requires custom backward. + """ + @staticmethod + def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) + return input @ weight.t() + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + + # Custom: clip gradients (can't do with composition) + grad_input = torch.clamp(grad_output @ weight, -1, 1) + grad_weight = grad_output.t() @ input + + return grad_input, grad_weight +``` + +## gradcheck Time Investment + +**Reality check on "no time for gradcheck"**: + +When under deadline pressure, this rationalization is powerful but wrong: + +```python +# Time cost analysis +gradcheck_time = "< 1 second" # Typical operation +debugging_time = "hours to days" # Finding gradient bugs without gradcheck + +# ROI calculation +time_investment = 1 # second +time_saved = 3600 * 24 # potentially days +roi = time_saved / time_investment # 86,400x return! +``` + +**Under deadline pressure, correct workflow**: +1. Implement Function correctly (5-10 minutes) +2. Run gradcheck (1 second) +3. Deploy with confidence + +**NOT**: +1. Skip verification (saves 1 second) +2. Deploy immediately +3. Spend days debugging gradient bugs in production (costs hours/days) + +**Time spent on gradcheck**: <1 second (negligible) +**Time saved from catching bugs early**: hours to days (enormous) + +**The math is clear**: Always run gradcheck, even under extreme time pressure. + +## Multiple Custom Functions Workflow + +**When implementing multiple functions**: + +✅ **CORRECT: One-at-a-time** +1. Implement Function 1 +2. Test Function 1 with gradcheck +3. Verify Function 1 passes +4. Move to Function 2 +5. Repeat for each function + +❌ **WRONG: Batch approach** +1. Implement all functions +2. Test all together +3. Debug mess of overlapping bugs +4. Waste hours figuring out which function is broken +5. Fix bugs one at a time anyway (should have started here) + +**Why one-at-a-time wins**: +- Bugs caught immediately (when you wrote the code, easy to debug) +- Know each function works before building on it +- Same total time, but bugs isolated (no "which function broke?" confusion) +- Build confidence incrementally +- Can use earlier functions while implementing later ones + +**Example**: +```python +# Implementing 5 activations + +# ✅ CORRECT +SwishBeta() → test → ✅ → Mish() → test → ✅ → GELU() → test → ✅ ... +# Each bug found immediately after writing that function + +# ❌ WRONG +SwishBeta() + Mish() + GELU() + ELU() + SELU() → test all +# Bug found in one, but which? Have to debug all to find it +# 5x the debugging complexity +``` + +## Handling Approximate Gradients + +**Special case**: External libraries with approximate gradients (finite differences, Monte Carlo estimates). + +When wrapping external code with approximate gradients: + +### Workflow for Approximate Gradients + +```python +import torch +from torch.autograd import Function + +class ApproximateGradientWrapper(Function): + """ + Wrap external library with approximate gradients. + + Key insights: + 1. Standard gradcheck WILL fail (approximate ≠ analytical) + 2. But can still verify wrapper implementation + 3. Must quantify gradient quality + 4. Must assess if quality is acceptable + """ + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return external_library_forward(input) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + # Library provides approximate gradient + grad_input = external_library_gradient(input, grad_output) + return grad_input + +# Can't run standard gradcheck, but CAN verify: + +def test_approximate_gradient_wrapper(): + """Verification workflow for approximate gradients.""" + + # 1. Verify wrapper mechanics (forward/backward run) + input = torch.randn(5, requires_grad=True) + output = ApproximateGradientWrapper.apply(input) + output.sum().backward() + assert input.grad is not None + print("✅ Wrapper mechanics work") + + # 2. Quantify gradient quality (compare with numerical) + input_test = torch.randn(3, requires_grad=True) + output = ApproximateGradientWrapper.apply(input_test) + output.sum().backward() + library_gradient = input_test.grad.clone() + + # Compute our own numerical gradient + eps = 1e-6 + numerical_gradient = compute_numerical_gradient(input_test, eps) + + # Measure error + error = (library_gradient - numerical_gradient).abs().max() + print(f"Gradient error: {error:.6e}") + + # 3. Assess acceptability + if error < 1e-3: + print("✅ Approximate gradient high quality (good enough)") + elif error < 1e-2: + print("⚠️ Approximate gradient has noticeable error (may affect optimization)") + else: + print("❌ Approximate gradient poor quality (likely to cause issues)") + raise ValueError("Gradient quality unacceptable") + + # 4. Document in code + print("\n📝 Documentation:") + print(f" - Gradients are approximate (error: {error:.6e})") + print(" - Standard gradcheck will fail (expected)") + print(" - Wrapper implementation verified correct") + print(" - Gradient quality measured and acceptable") + + return error + +# Run verification +gradient_error = test_approximate_gradient_wrapper() +``` + +### Key Points for Approximate Gradients + +1. **Standard gradcheck will fail** - This is expected (approximate ≠ analytical) +2. **Test wrapper implementation** - Verify mechanics (forward/backward run) +3. **Quantify gradient quality** - Measure error vs numerical gradient +4. **Assess acceptability** - Is error tolerable for your use case? +5. **Document limitations** - Record gradient error magnitude in code/docs + +**Don't skip verification** - Adapt it to approximate case. + +**"Good enough" requires evidence** - Measure error, don't assume. + +## Rationalization Resistance + +| Rationalization | Reality | Counter-Response | +|----------------|---------|------------------| +| "Autograd will figure it out" | Only for standard ops; custom ops need Function | Use Function for non-standard operations. Autograd needs explicit backward implementation. | +| "Gradient looks mathematically correct" | Implementation bugs invisible without testing | Always run gradcheck. Math correctness ≠ implementation correctness. | +| "gradcheck is slow, skip for speed" | Catches bugs early; debugging later costs more | gradcheck is fast (<1s). Finding gradient bugs without it takes hours/days. | +| "No time for gradcheck, deadline NOW" | gradcheck takes <1s, debugging takes hours | 1 second now saves hours of production debugging. Always run gradcheck. | +| "Too complex for me" | Pattern is standardized; template works | Follow template. Thousands of successful implementations exist. | +| "In-place is more efficient" | Breaks autograd graph; causes crashes | Never use in-place in custom functions. Memory savings negligible, bugs catastrophic. | +| "Shape will probably work out" | Must match exactly; no flexibility | Gradient shape must equal input shape exactly. Verify with assertions. | +| "ctx details don't matter much" | Incorrect usage breaks everything | ctx.save_for_backward() is mandatory for tensors. Attributes break memory tracking. | +| "My manual test is good enough" | Misses edge cases gradcheck catches | Manual tests catch obvious bugs. gradcheck catches subtle numerical errors. | +| "Batch test all functions together" | Overlapping bugs hard to debug | Test one at a time. Bugs isolated immediately, same total time. | +| "Don't need needs_input_grad check" | Wastes computation; slower training | Always check needs_input_grad. Free optimization, no downside. | +| "Approximate gradients, can't verify" | Can verify wrapper and measure quality | Adapt verification. Test wrapper mechanics, quantify error, assess acceptability. | +| "Second-order derivatives too advanced" | Same pattern as first-order | Not advanced. Same template, test with gradgradcheck. Accessible to all. | +| "Can skip double backward support" | Breaks higher-order derivatives if needed | If you might need Hessian/meta-learning, support double backward from start. | +| "Detach doesn't matter here" | Controls gradient flow; critical | Understand when to detach. Impacts what gets gradients. | +| "I'll verify gradients during training" | Training metrics hide gradient bugs | Verify before training. Gradient bugs cause subtle issues (slow convergence, wrong behavior). | +| "Test in production, faster iteration" | Production debugging catastrophic | Test before deployment. Production gradient bugs cause model failure. | + +## Red Flags Checklist + +**Stop and verify if you see:** + +1. ⚠️ **Not using torch.autograd.Function** for custom operations + - Writing normal functions for truly custom ops (external code, novel math) + +2. ⚠️ **No gradcheck before using** the function + - Skipping numerical verification + - "Testing during training" instead of proper gradcheck + +3. ⚠️ **In-place operations** in forward or backward + - `tensor[mask] = value`, `tensor += other`, `tensor.mul_(other)` + - Modifying `grad_output` in backward + +4. ⚠️ **Wrong number of return values** from backward + - Not returning gradient for each forward() input + - Missing None for non-tensor arguments + +5. ⚠️ **Not using ctx.save_for_backward()** for tensors + - Saving tensors as `ctx.tensor = tensor` instead + - Saving non-tensors with save_for_backward() + +6. ⚠️ **Gradient shape doesn't match** input shape + - Not verifying `grad_input.shape == input.shape` + - Forgetting to expand/sum for broadcasting + +7. ⚠️ **Missing needs_input_grad checks** + - Always computing all gradients even when not needed + - Not checking `ctx.needs_input_grad[i]` before expensive computation + +8. ⚠️ **Using .data instead of .detach()** + - Accessing `.data` attribute + - Not understanding difference between .data and .detach() + +9. ⚠️ **Accessing ctx.saved_tensors in forward** + - Trying to use saved tensors before backward + - Not understanding ctx lifecycle + +10. ⚠️ **Modifying saved tensors** in backward + - In-place operations on tensors from ctx.saved_tensors + - Breaking gradient graph + +11. ⚠️ **Ignoring gradcheck failures** + - "Gradient close enough" + - Not investigating why gradcheck failed + +12. ⚠️ **Using custom Function for standard operations** + - Reimplementing built-in operations unnecessarily + - Not checking if PyTorch already provides it + +13. ⚠️ **Batching implementation without incremental testing** + - Implementing multiple functions before testing any + - "Will test them all together" approach + +14. ⚠️ **Skipping gradcheck under time pressure** + - "Deadline is tight, verify later" + - "No time for gradcheck" + +15. ⚠️ **Assuming approximate gradients can't be verified** + - "Library provides gradients, can't test" + - Not measuring gradient quality + +16. ⚠️ **Avoiding second-order derivatives due to perceived complexity** + - "Too advanced for me" + - Not attempting gradgradcheck + +17. ⚠️ **Deploying to production without verification** + - "Test with real data" + - Skipping numerical verification + +## Summary Checklist + +Before deploying a custom autograd function: + +- [ ] Verified custom Function is actually needed (can't compose from standard ops) +- [ ] Implemented forward() correctly (saves needed tensors/data) +- [ ] Used ctx.save_for_backward() for ALL tensors +- [ ] Saved non-tensor data as ctx attributes +- [ ] Implemented backward() with correct gradient formulas +- [ ] Verified gradient shape matches input shape exactly +- [ ] Returned gradient for EVERY forward() input (None if not needed) +- [ ] Added needs_input_grad checks for expensive computations +- [ ] Avoided ALL in-place operations +- [ ] **Ran gradcheck and verified it PASSES** +- [ ] Tested with different input shapes +- [ ] Tested with optional parameters (if any) +- [ ] Verified no NaN/Inf in outputs or gradients +- [ ] Checked gradient magnitudes are reasonable +- [ ] Tested double backward if needed +- [ ] Added documentation explaining when to use this function +- [ ] Considered memory efficiency (detach, selective saving) + +## Final Notes + +**Custom autograd functions are powerful but require precision**: + +1. Use them when truly needed (custom ops, external code, memory optimization, custom gradients) +2. Follow the template pattern (don't reinvent) +3. **Always run gradcheck** (non-negotiable) +4. Understand ctx rules (save_for_backward for tensors, attributes for non-tensors) +5. Verify shapes (gradient must match input) +6. Avoid in-place operations (they break autograd) +7. Check needs_input_grad (optimization) +8. Test thoroughly before using in training + +**The gradient implementation is as important as the forward computation.** Bugs in backward() are silent and catastrophic - they cause models to learn wrong things. gradcheck is your safety net; never skip it. + +When in doubt: implement, run gradcheck, debug until it passes, then use with confidence. diff --git a/skills/using-pytorch-engineering/debugging-techniques.md b/skills/using-pytorch-engineering/debugging-techniques.md new file mode 100644 index 0000000..983c551 --- /dev/null +++ b/skills/using-pytorch-engineering/debugging-techniques.md @@ -0,0 +1,1803 @@ + +# Systematic PyTorch Debugging + +## Overview + +**Core Principle:** Debugging without methodology is guessing. Debug systematically (reproduce → gather info → form hypothesis → test → fix → verify) using PyTorch-specific tools to identify root causes, not symptoms. Random changes waste time; systematic investigation finds bugs efficiently. + +Bugs stem from: shape mismatches (dimension errors), device placement (CPU/GPU), dtype incompatibilities (float/int), autograd issues (in-place ops, gradient flow), memory problems (leaks, OOM), or numerical instability (NaN/Inf). Error messages and symptoms reveal the category. Reading error messages carefully and using appropriate debugging tools (detect_anomaly, hooks, assertions) leads to fast resolution. Guessing leads to hours of trial-and-error while the real issue remains. + +## When to Use + +**Use this skill when:** +- Getting error messages (RuntimeError, shape mismatch, device error, etc.) +- Model not learning (loss constant, not decreasing) +- NaN or Inf appearing in loss or gradients +- Intermittent errors (works sometimes, fails others) +- Memory issues (OOM, leaks, growing memory usage) +- Silent failures (no error but wrong output) +- Autograd errors (in-place operations, gradient computation) + +**Don't use when:** +- Performance optimization (use performance-profiling) +- Architecture design questions (use module-design-patterns) +- Distributed training issues (use distributed-training-strategies) +- Mixed precision configuration (use mixed-precision-and-optimization) + +**Symptoms triggering this skill:** +- "Getting this error, can you help fix it?" +- "Model not learning, loss stays constant" +- "Works on CPU but fails on GPU" +- "NaN loss after several epochs" +- "Error happens randomly" +- "Backward pass failing but forward pass works" +- "Memory keeps growing during training" + + +## Systematic Debugging Methodology + +### The Five-Phase Framework + +**Phase 1: Reproduce Reliably** +- Fix random seeds for determinism +- Minimize code to smallest reproduction case +- Isolate problematic component +- Document reproduction steps + +**Phase 2: Gather Information** +- Read FULL error message (every word, especially shapes/values) +- Read complete stack trace +- Add strategic assertions +- Use PyTorch debugging tools + +**Phase 3: Form Hypothesis** +- Based on error pattern, what could cause this? +- Predict what investigation will reveal +- Make hypothesis specific and testable + +**Phase 4: Test Hypothesis** +- Add targeted debugging code +- Verify or reject hypothesis with evidence +- Iterate until root cause identified + +**Phase 5: Fix and Verify** +- Implement minimal fix addressing root cause (not symptom) +- Verify error gone AND functionality correct +- Explain why fix works + +**Critical Rule:** NEVER skip Phase 3. Random changes without hypothesis waste time. Form hypothesis, test it, iterate. + + +### Phase 1: Reproduce Reliably + +**Step 1: Make Error Deterministic** + +```python +# Fix all sources of randomness +import torch +import numpy as np +import random + +def set_seed(seed=42): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +set_seed(42) + +# Now error should happen consistently (if it's reproducible) +``` + +**For Intermittent Errors:** +```python +# Identify which batch/iteration causes failure +for i, batch in enumerate(dataloader): + try: + output = model(batch) + loss = criterion(output, target) + loss.backward() + except RuntimeError as e: + print(f"Error at batch {i}") + print(f"Batch data stats: min={batch.min()}, max={batch.max()}, shape={batch.shape}") + torch.save(batch, f'failing_batch_{i}.pt') # Save for investigation + raise + +# Load specific failing batch to reproduce +failing_batch = torch.load('failing_batch_X.pt') +# Now can debug deterministically +``` + +**Why this matters:** +- Can't debug intermittent errors effectively +- Reproducibility enables systematic investigation +- Fixed seeds expose data-dependent issues +- Saved failing cases allow focused debugging + + +**Step 2: Minimize Reproduction** + +```python +# Full training script (too complex to debug) +# ❌ DON'T DEBUG HERE +for epoch in range(100): + for batch in train_loader: + # Complex data preprocessing + # Model forward pass + # Loss computation with multiple components + # Backward pass + # Optimizer with custom scheduling + # Logging, checkpointing, etc. + +# Minimal reproduction (isolates the issue) +# ✅ DEBUG HERE +import torch +import torch.nn as nn + +# Minimal model +model = nn.Linear(10, 5).cuda() + +# Minimal data (can be random) +x = torch.randn(2, 10).cuda() +target = torch.randint(0, 5, (2,)).cuda() + +# Minimal forward/backward +output = model(x) +loss = nn.functional.cross_entropy(output, target) +loss.backward() # Error happens here + +# This 10-line script reproduces the issue! +# Much easier to debug than full codebase +``` + +**Minimization Process:** +1. Remove data preprocessing (use random tensors) +2. Simplify model (use single layer if possible) +3. Remove optimizer, scheduler, logging +4. Use single batch, single iteration +5. Keep only code path that triggers error + +**Why this matters:** +- Easier to identify root cause in minimal code +- Can share minimal reproduction in bug reports +- Eliminates confounding factors +- Faster iteration during debugging + + +**Step 3: Isolate Component** + +```python +# Test each component independently + +# Test 1: Data loading +for batch in dataloader: + print(f"Batch shape: {batch.shape}, dtype: {batch.dtype}, device: {batch.device}") + print(f"Value range: [{batch.min():.4f}, {batch.max():.4f}]") + assert not torch.isnan(batch).any(), "NaN in data!" + assert not torch.isinf(batch).any(), "Inf in data!" + break + +# Test 2: Model forward pass +model.eval() +with torch.no_grad(): + output = model(sample_input) + print(f"Output shape: {output.shape}, range: [{output.min():.4f}, {output.max():.4f}]") + +# Test 3: Loss computation +loss = criterion(output, target) +print(f"Loss: {loss.item()}") + +# Test 4: Backward pass +loss.backward() +print("Backward pass successful") + +# Test 5: Optimizer step +optimizer.step() +print("Optimizer step successful") + +# Identify which component fails → focus debugging there +``` + +**Why this matters:** +- Quickly narrows down problematic component +- Avoids debugging entire pipeline when issue is localized +- Enables targeted investigation +- Confirms other components work correctly + + +### Phase 2: Gather Information + +**Step 1: Read Error Message Completely** + +**Example 1: Shape Mismatch** +``` +RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) +``` + +**What to extract:** +- Operation: matrix multiplication (`mat1` and `mat2`) +- Actual shapes: mat1 is 4×57600, mat2 is 64×128 +- Problem: Can't multiply because 57600 ≠ 64 (inner dimensions must match) +- Diagnostic info: 57600 suggests flattened spatial dimensions (e.g., 30×30×64) + +**Example 2: Device Mismatch** +``` +RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! +``` + +**What to extract:** +- Operation: tensor operation requiring same device +- Devices involved: cuda:0 and cpu +- Problem: Some tensors on GPU, others on CPU +- Next step: Add device checks to find which tensor is on wrong device + +**Example 3: In-Place Operation** +``` +RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 128]], which is output 0 of ReluBackward0, is at version 2; expected version 1 instead. +``` + +**What to extract:** +- Operation: in-place modification during autograd +- Affected tensor: [256, 128] from ReluBackward0 +- Version: tensor modified from version 1 to version 2 +- Problem: Tensor modified after being used in autograd graph +- Next step: Find in-place operations (`*=`, `+=`, `.relu_()`, etc.) + +**Why this matters:** +- Error messages contain critical diagnostic information +- Shapes, dtypes, devices tell you exactly what's wrong +- Stack trace shows WHERE error occurs +- Specific error patterns indicate specific fixes + + +**Step 2: Read Stack Trace** + +```python +# Example stack trace +Traceback (most recent call last): + File "train.py", line 45, in + loss.backward() + File "/pytorch/torch/autograd/__init__.py", line 123, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph) + File "/pytorch/torch/autograd/__init__.py", line 78, in backward + Variable._execution_engine.run_backward(...) +RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) + +# What to extract: +# - Error triggered by loss.backward() at line 45 +# - Problem is in backward pass (not forward pass) +# - Shape mismatch in some linear layer +# - Need to inspect model architecture and forward pass shapes +``` + +**Reading Stack Traces:** +1. Start from bottom (actual error) +2. Work upward to find YOUR code (not PyTorch internals) +3. Identify which operation triggered error +4. Note if error is in forward, backward, or optimizer step +5. Look for parameter values and tensor shapes in trace + +**Why this matters:** +- Shows execution path leading to error +- Distinguishes forward vs backward pass issues +- Reveals which layer/operation failed +- Provides context for hypothesis formation + + +**Step 3: Add Strategic Assertions** + +```python +# DON'T: Print statements everywhere +def forward(self, x): + print(f"Input: {x.shape}") + x = self.conv1(x) + print(f"After conv1: {x.shape}") + x = self.pool(x) + print(f"After pool: {x.shape}") + # ... prints for every operation + +# DO: Strategic assertions that verify understanding +def forward(self, x): + # Assert input assumptions + assert x.dim() == 4, f"Expected 4D input (B,C,H,W), got {x.dim()}D" + assert x.shape[1] == self.in_channels, \ + f"Expected {self.in_channels} input channels, got {x.shape[1]}" + + x = self.conv1(x) + # Conv2d(3, 64, 3) on 32×32 input → 30×30 output + # Assert expected shape to verify understanding + assert x.shape[2:] == (30, 30), f"Expected 30×30 after conv, got {x.shape[2:]}" + + x = x.view(x.size(0), -1) + # After flatten: batch_size × (30*30*64) = batch_size × 57600 + assert x.shape[1] == 57600, f"Expected 57600 features, got {x.shape[1]}" + + x = self.fc(x) + return x + +# If assertion fails, your understanding is wrong → update hypothesis +``` + +**When to Use Assertions vs Prints:** +- **Assertions:** Verify understanding of shapes, devices, dtypes +- **Prints:** Inspect actual values when understanding is incomplete +- **Neither:** Use hooks for non-intrusive inspection (see below) + +**Why this matters:** +- Assertions document assumptions +- Failures reveal misunderstanding +- Self-documenting code (shows expected shapes) +- No performance cost when not failing + + +**Step 4: Use PyTorch Debugging Tools** + +**Tool 1: detect_anomaly() for NaN/Inf** + +```python +# Problem: NaN loss appears, but where does it originate? + +# Without detect_anomaly: Generic error +loss.backward() # RuntimeError: Function 'MseLossBackward0' returned nan + +# With detect_anomaly: Pinpoints exact operation +with torch.autograd.set_detect_anomaly(True): + loss.backward() +# RuntimeError: Function 'DivBackward0' returned nan values in its 0th output. +# [Stack trace shows: loss = output / (std + eps), where std became 0] +# Now we know: division by zero when std=0, need to increase eps + +# Use case 1: Find where NaN first appears +torch.autograd.set_detect_anomaly(True) # Enable globally +for batch in dataloader: + output = model(batch) + loss = criterion(output, target) + loss.backward() # Will error at exact operation producing NaN +torch.autograd.set_detect_anomaly(False) # Disable after debugging + +# Use case 2: Narrow down to specific forward pass +suspicious_batch = get_failing_batch() +with torch.autograd.set_detect_anomaly(True): + output = model(suspicious_batch) + loss = criterion(output, target) + loss.backward() # Detailed stack trace if NaN occurs +``` + +**When to use detect_anomaly():** +- NaN or Inf appearing in loss or gradients +- Need to find WHICH operation produces NaN +- After identifying NaN, before fixing + +**Performance note:** detect_anomaly() is SLOW (~10x overhead). Only use during debugging, NEVER in production. + + +**Tool 2: Forward Hooks for Intermediate Inspection** + +```python +# Problem: Need to inspect intermediate outputs without modifying model code + +def debug_forward_hook(module, input, output): + """Hook function that inspects module outputs""" + module_name = module.__class__.__name__ + + # Check shapes + if isinstance(input, tuple): + input_shape = input[0].shape + else: + input_shape = input.shape + output_shape = output.shape if not isinstance(output, tuple) else output[0].shape + + print(f"{module_name:20s} | Input: {str(input_shape):20s} | Output: {str(output_shape):20s}") + + # Check for NaN/Inf + output_tensor = output if not isinstance(output, tuple) else output[0] + if torch.isnan(output_tensor).any(): + raise RuntimeError(f"NaN detected in {module_name} output!") + if torch.isinf(output_tensor).any(): + raise RuntimeError(f"Inf detected in {module_name} output!") + + # Check value ranges + print(f" → Value range: [{output_tensor.min():.4f}, {output_tensor.max():.4f}]") + print(f" → Mean: {output_tensor.mean():.4f}, Std: {output_tensor.std():.4f}") + +# Register hooks on all modules +handles = [] +for name, module in model.named_modules(): + if len(list(module.children())) == 0: # Only leaf modules + handle = module.register_forward_hook(debug_forward_hook) + handles.append(handle) + +# Run forward pass with hooks +output = model(sample_input) + +# Remove hooks when done +for handle in handles: + handle.remove() + +# Output shows: +# Linear | Input: torch.Size([4, 128]) | Output: torch.Size([4, 256]) +# → Value range: [-2.3421, 3.1234] +# → Mean: 0.0234, Std: 1.0123 +# ReLU | Input: torch.Size([4, 256]) | Output: torch.Size([4, 256]) +# → Value range: [0.0000, 3.1234] +# → Mean: 0.5123, Std: 0.8234 +# RuntimeError: NaN detected in Linear output! # Found problematic layer! +``` + +**When to use forward hooks:** +- Need to inspect intermediate layer outputs +- Finding which layer produces NaN/Inf +- Checking activation ranges and statistics +- Debugging without modifying model code +- Monitoring multiple layers simultaneously + +**Alternative: Selective hooks for specific modules** +```python +# Only hook suspicious layers +suspicious_layers = [model.layer3, model.final_fc] +for layer in suspicious_layers: + layer.register_forward_hook(debug_forward_hook) +``` + + +**Tool 3: Backward Hooks for Gradient Inspection** + +```python +# Problem: Gradients exploding, vanishing, or becoming NaN + +def debug_grad_hook(grad): + """Hook function for gradient inspection""" + if grad is None: + print("WARNING: Gradient is None!") + return None + + # Statistics + grad_norm = grad.norm().item() + grad_mean = grad.mean().item() + grad_std = grad.std().item() + grad_min = grad.min().item() + grad_max = grad.max().item() + + print(f"Gradient stats:") + print(f" Shape: {grad.shape}") + print(f" Norm: {grad_norm:.6f}") + print(f" Range: [{grad_min:.6f}, {grad_max:.6f}]") + print(f" Mean: {grad_mean:.6f}, Std: {grad_std:.6f}") + + # Check for issues + if grad_norm > 100: + print(f" ⚠️ WARNING: Large gradient norm ({grad_norm:.2f})") + if grad_norm < 1e-7: + print(f" ⚠️ WARNING: Vanishing gradient ({grad_norm:.2e})") + if torch.isnan(grad).any(): + raise RuntimeError("NaN gradient detected!") + if torch.isinf(grad).any(): + raise RuntimeError("Inf gradient detected!") + + return grad # Must return gradient (can return modified version) + +# Register hooks on specific parameters +for name, param in model.named_parameters(): + if 'weight' in name: # Only monitor weights, not biases + param.register_hook(lambda grad, name=name: debug_grad_hook(grad)) + +# Or register on intermediate tensors +x = model.encoder(input) +x.register_hook(debug_grad_hook) # Will show gradient flowing to encoder output +y = model.decoder(x) + +# Run backward +loss = criterion(y, target) +loss.backward() # Hooks will fire and print gradient stats +``` + +**When to use backward hooks:** +- Gradients exploding or vanishing +- NaN appearing in backward pass +- Checking gradient flow through network +- Monitoring specific parameter gradients +- Implementing custom gradient clipping or modification + +**Gradient Inspection Without Hooks:** +```python +# After backward pass, inspect gradients directly +loss.backward() + +for name, param in model.named_parameters(): + if param.grad is not None: + grad_norm = param.grad.norm() + print(f"{name:40s} | Grad norm: {grad_norm:.6f}") + if grad_norm > 100: + print(f" ⚠️ Large gradient in {name}") + else: + print(f"{name:40s} | ⚠️ No gradient!") +``` + + +**Tool 4: gradcheck for Numerical Verification** + +```python +# Problem: Implementing custom autograd function, need to verify correctness + +from torch.autograd import gradcheck + +class MyCustomFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return input.clamp(min=0) # Custom ReLU + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_input = grad_output.clone() + grad_input[input < 0] = 0 + return grad_input + +# Verify backward is correct using numerical gradients +input = torch.randn(10, 10, dtype=torch.double, requires_grad=True) +test = gradcheck(MyCustomFunction.apply, input, eps=1e-6, atol=1e-4) +print(f"Gradient check passed: {test}") # True if backward is correct + +# Use double precision for numerical stability +# If gradcheck fails, backward implementation is wrong +``` + +**When to use gradcheck:** +- Implementing custom autograd functions +- Verifying backward pass correctness +- Debugging gradient computation issues +- Before deploying custom CUDA kernels with autograd + + +### Phase 3: Form Hypothesis + +**Hypothesis Formation Framework** + +```python +# Template for hypothesis formation: +# +# OBSERVATION: [What did you observe from error/symptoms?] +# PATTERN: [Does this match a known error pattern?] +# HYPOTHESIS: [What could cause this observation?] +# PREDICTION: [What will investigation reveal if hypothesis is correct?] +# TEST: [How to verify or reject hypothesis?] + +# Example 1: Shape Mismatch +# OBSERVATION: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) +# PATTERN: Linear layer input mismatch (57600 != 64) +# HYPOTHESIS: Conv output flattened incorrectly - expecting 64 features but getting 57600 +# PREDICTION: Conv output shape is probably (4, 64, 30, 30) → flatten → 57600 +# TEST: Print conv output shape before flatten, verify it's 30×30×64=57600 + +# Example 2: Model Not Learning +# OBSERVATION: Loss constant at 2.30 for 10 classes = log(10) +# PATTERN: Model outputting uniform random predictions +# HYPOTHESIS: Optimizer not updating weights (missing optimizer.step() or learning_rate=0) +# PREDICTION: Weights identical between epochs, gradients computed but not applied +# TEST: Check if weights change after training, verify optimizer.step() is called + +# Example 3: NaN Loss +# OBSERVATION: Loss becomes NaN at epoch 6, was decreasing before +# PATTERN: Numerical instability after several updates +# HYPOTHESIS: Gradients exploding due to high learning rate +# PREDICTION: Gradient norms increasing over epochs, spike before NaN +# TEST: Monitor gradient norms each epoch, check if they grow exponentially +``` + +**Common PyTorch Error Patterns → Hypotheses** + +| Error Pattern | Likely Cause | Hypothesis to Test | +|--------------|--------------|-------------------| +| `mat1 and mat2 shapes cannot be multiplied (AxB and CxD)` | Linear layer input mismatch | B ≠ C; check actual input dimension vs expected | +| `Expected all tensors to be on the same device` | Device placement issue | Some tensor on CPU, others on GPU; add device checks | +| `modified by an inplace operation` | In-place op in autograd graph | Find `*=`, `+=`, `.relu_()`, etc.; use out-of-place versions | +| `index X is out of bounds for dimension Y with size Z` | Invalid index access | Index >= size; check data preprocessing, embedding indices | +| `device-side assert triggered` | Out-of-bounds index (GPU) | Embedding indices >= vocab_size or < 0; inspect data | +| Loss constant at log(num_classes) | Model not learning | Missing optimizer.step() or zero learning rate | +| NaN after N epochs | Gradient explosion | Learning rate too high or numerical instability | +| NaN in specific operation | Division by zero or log(0) | Check denominators and log inputs for zeros | +| OOM during backward | Activation memory too large | Batch size too large or missing gradient checkpointing | +| Memory growing over iterations | Memory leak | Accumulating tensors with computation graph | + +**Why this matters:** +- Hypothesis guides investigation (not random) +- Prediction makes hypothesis testable +- Pattern recognition speeds up debugging +- Systematic approach finds root cause faster + + +### Phase 4: Test Hypothesis + +**Testing Strategies** + +**Strategy 1: Binary Search / Bisection** + +```python +# Problem: Complex model, don't know which component causes error + +# Test 1: Disable second half of model +class ModelUnderTest(nn.Module): + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + return x + # x = self.layer3(x) # Commented out + # x = self.layer4(x) + # return x + +# If error disappears: issue is in layer3 or layer4 +# If error persists: issue is in layer1 or layer2 + +# Test 2: Narrow down further +class ModelUnderTest(nn.Module): + def forward(self, x): + x = self.layer1(x) + return x + # x = self.layer2(x) + # return x + +# Continue bisecting until isolated to specific layer +``` + +**Strategy 2: Differential Debugging** + +```python +# Compare working vs broken versions + +# Working version (simple) +def forward_simple(self, x): + x = self.conv(x) + x = x.view(x.size(0), -1) + return self.fc(x) + +# Broken version (complex) +def forward_complex(self, x): + x = self.conv(x) + x = x.transpose(1, 2) # Additional operation + x = x.reshape(x.size(0), -1) + return self.fc(x) + +# Test both with same input +x = torch.randn(4, 3, 32, 32) +print("Simple:", forward_simple(x).shape) # Works +print("Complex:", forward_complex(x).shape) # Errors + +# Hypothesis: transpose causing shape issue +# Test: Remove transpose and use reshape +def forward_test(self, x): + x = self.conv(x) + # x = x.transpose(1, 2) # Removed + x = x.reshape(x.size(0), -1) + return self.fc(x) + +# If works: transpose was the issue +``` + +**Strategy 3: Synthetic Data Testing** + +```python +# Problem: Error occurs with real data, need to isolate cause + +# Test 1: Random data with correct shape/dtype/device +x_random = torch.randn(4, 3, 32, 32).cuda() +y_random = torch.randint(0, 10, (4,)).cuda() +output = model(x_random) +loss = criterion(output, y_random) +loss.backward() +# If works: issue is in data, not model + +# Test 2: Real data with known properties +x_real = next(iter(dataloader)) +print(f"Data stats: shape={x_real.shape}, dtype={x_real.dtype}, device={x_real.device}") +print(f"Value range: [{x_real.min():.4f}, {x_real.max():.4f}]") +print(f"NaN count: {torch.isnan(x_real).sum()}") +print(f"Inf count: {torch.isinf(x_real).sum()}") +# If NaN or Inf found: data preprocessing issue + +# Test 3: Edge cases +x_zeros = torch.zeros(4, 3, 32, 32).cuda() +x_ones = torch.ones(4, 3, 32, 32).cuda() +x_large = torch.full((4, 3, 32, 32), 1e6).cuda() +# See which edge case triggers error +``` + +**Strategy 4: Iterative Refinement** + +```python +# Hypothesis 1: Conv output shape wrong +x = torch.randn(4, 3, 32, 32) +x = model.conv1(x) +print(f"Conv output: {x.shape}") # torch.Size([4, 64, 30, 30]) +# Prediction correct! Conv output is 30×30, not 32×32 + +# Hypothesis 2: Flatten produces wrong size +x_flat = x.view(x.size(0), -1) +print(f"Flattened: {x_flat.shape}") # torch.Size([4, 57600]) +# Confirmed: 30*30*64 = 57600 + +# Hypothesis 3: Linear layer expects wrong size +print(f"FC weight shape: {model.fc.weight.shape}") # torch.Size([128, 64]) +# Found root cause: FC expects 64 inputs but gets 57600! + +# Fix: Change FC input dimension +self.fc = nn.Linear(57600, 128) # Not nn.Linear(64, 128) +# Or: Add pooling to reduce spatial dimensions before FC +``` + +**Why this matters:** +- Systematic testing verifies or rejects hypothesis +- Evidence-based iteration toward root cause +- Multiple strategies for different error types +- Avoids random trial-and-error + + +### Phase 5: Fix and Verify + +**Step 1: Implement Minimal Fix** + +```python +# ❌ BAD: Overly complex fix +def forward(self, x): + x = self.conv1(x) + # Fix shape mismatch by adding multiple transforms + x = F.adaptive_avg_pool2d(x, (1, 1)) # Global pooling + x = x.squeeze(-1).squeeze(-1) # Remove spatial dims + x = x.unsqueeze(0) # Add batch dim + x = x.reshape(x.size(0), -1) # Flatten again + x = self.fc(x) + return x +# Complex fix might introduce new bugs + +# ✅ GOOD: Minimal fix addressing root cause +def forward(self, x): + x = self.conv1(x) + x = x.view(x.size(0), -1) # Flatten: (B, 64, 30, 30) → (B, 57600) + x = self.fc(x) # fc now expects 57600 inputs + return x + +# In __init__: +self.fc = nn.Linear(57600, 128) # Changed from Linear(64, 128) +``` + +**Principles of Good Fixes:** +1. **Minimal:** Change only what's necessary +2. **Targeted:** Address root cause, not symptom +3. **Clear:** Obvious why fix works +4. **Safe:** Doesn't introduce new issues + +**Examples:** + +**Problem: Missing optimizer.step()** +```python +# ❌ BAD: Increase learning rate (treats symptom) +optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + +# ✅ GOOD: Add missing optimizer.step() +for batch in dataloader: + optimizer.zero_grad() + loss = criterion(model(batch), target) + loss.backward() + optimizer.step() # Was missing! +``` + +**Problem: In-place operation breaking autograd** +```python +# ❌ BAD: Use clone() everywhere (treats symptom, adds overhead) +x = x.clone() +x *= mask +x = x.clone() +x /= scale + +# ✅ GOOD: Use out-of-place operations +x = x * mask # Not x *= mask +x = x / scale # Not x /= scale +``` + +**Problem: Device mismatch** +```python +# ❌ BAD: Move tensor every forward pass (inefficient) +def forward(self, x): + pos_enc = self.positional_encoding[:x.size(1)].to(x.device) + x = x + pos_enc + +# ✅ GOOD: Fix initialization so buffer is on correct device +def __init__(self): + super().__init__() + self.register_buffer('positional_encoding', None) + +def _init_buffers(self): + device = next(self.parameters()).device + self.positional_encoding = torch.randn(1000, 100, device=device) +``` + + +**Step 2: Verify Fix Completely** + +```python +# Verification checklist: +# 1. Error disappeared? ✓ +# 2. Model produces correct output? ✓ +# 3. Training converges? ✓ +# 4. No new errors introduced? ✓ + +# Verification code: +# 1. Run single iteration without error +model = FixedModel() +x = torch.randn(4, 3, 32, 32).cuda() +y = torch.randint(0, 10, (4,)).cuda() + +output = model(x) +print(f"✓ Forward pass: {output.shape}") # Should be [4, 10] + +loss = criterion(output, y) +print(f"✓ Loss computation: {loss.item():.4f}") + +loss.backward() +print(f"✓ Backward pass successful") + +optimizer.step() +print(f"✓ Optimizer step successful") + +# 2. Verify output makes sense +assert output.shape == (4, 10), "Wrong output shape!" +assert not torch.isnan(output).any(), "NaN in output!" +assert not torch.isinf(output).any(), "Inf in output!" + +# 3. Verify model can train (loss decreases) +initial_loss = None +for i in range(10): + output = model(x) + loss = criterion(output, y) + if i == 0: + initial_loss = loss.item() + loss.backward() + optimizer.step() + optimizer.zero_grad() + +final_loss = loss.item() +assert final_loss < initial_loss, "Loss not decreasing - model not learning!" +print(f"✓ Training works: loss {initial_loss:.4f} → {final_loss:.4f}") + +# 4. Test on real data +for batch in dataloader: + output = model(batch) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + print(f"✓ Batch processed successfully") + break +``` + +**Why verification matters:** +- Confirms fix addresses root cause +- Ensures no new bugs introduced +- Validates model works correctly, not just "no error" +- Provides confidence before moving to full training + + +**Step 3: Explain Why Fix Works** + +```python +# Document understanding for future reference + +# Problem: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) +# +# Root Cause: +# Conv2d(3, 64, kernel_size=3) on 32×32 input produces 30×30 output (no padding) +# Spatial dimensions: 32 - 3 + 1 = 30 +# After flatten: 30 × 30 × 64 = 57600 features +# But Linear layer initialized with Linear(64, 128), expecting only 64 features +# Mismatch: 57600 (actual) vs 64 (expected) +# +# Fix: +# Changed Linear(64, 128) to Linear(57600, 128) +# Now expects correct number of input features +# +# Why it works: +# Linear layer input dimension must match flattened conv output dimension +# 30×30×64 = 57600, so fc1 must have in_features=57600 +# +# Alternative fixes: +# 1. Add pooling: F.adaptive_avg_pool2d(x, (1, 1)) → 64 features +# 2. Change conv padding: Conv2d(3, 64, 3, padding=1) → 32×32 output → 65536 features +# 3. Add another conv layer to reduce spatial dimensions +``` + +**Why explanation matters:** +- Solidifies understanding +- Helps recognize similar issues in future +- Documents decision for team members +- Prevents cargo cult fixes (copying code without understanding) + + +## Common PyTorch Error Patterns and Solutions + +### Shape Mismatches + +**Pattern 1: Linear Layer Input Mismatch** + +```python +# Error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (BxM and NxK) +# Cause: M ≠ N, linear layer input dimension doesn't match actual input + +# Example: +self.fc = nn.Linear(128, 10) # Expects 128 features +x = torch.randn(4, 256) # Actual has 256 features +output = self.fc(x) # ERROR: 256 ≠ 128 + +# Solution 1: Fix linear layer input dimension +self.fc = nn.Linear(256, 10) # Match actual input size + +# Solution 2: Transform input to expected size +x = some_projection(x) # Project 256 → 128 +output = self.fc(x) + +# Debugging: +# - Print x.shape before linear layer +# - Check linear layer weight shape: fc.weight.shape is [out_features, in_features] +# - Calculate expected input size from previous layers +``` + +**Pattern 2: Convolution Spatial Dimension Mismatch** + +```python +# Error: RuntimeError: Expected 4D tensor, got 3D +# Cause: Missing batch dimension or wrong number of dimensions + +# Example 1: Missing batch dimension +x = torch.randn(3, 32, 32) # (C, H, W) - missing batch dim +output = conv(x) # ERROR: expects (B, C, H, W) + +# Solution: Add batch dimension +x = x.unsqueeze(0) # (1, 3, 32, 32) +output = conv(x) + +# Example 2: Flattened when shouldn't be +x = torch.randn(4, 3, 32, 32) # (B, C, H, W) +x = x.view(x.size(0), -1) # Flattened to (4, 3072) +output = conv(x) # ERROR: expects 4D, got 2D + +# Solution: Don't flatten before convolution +# Only flatten after all convolutions, before linear layers +``` + +**Pattern 3: Broadcasting Incompatibility** + +```python +# Error: RuntimeError: The size of tensor a (X) must match the size of tensor b (Y) +# Cause: Shapes incompatible for element-wise operation + +# Example: +a = torch.randn(4, 128, 32) # (B, C, L) +b = torch.randn(4, 64, 32) # (B, C', L) +c = a + b # ERROR: 128 ≠ 64 in dimension 1 + +# Solution: Match dimensions (project, pad, or slice) +b_projected = linear(b.transpose(1,2)).transpose(1,2) # 64 → 128 +c = a + b_projected + +# Debugging: +# - Print shapes of both operands +# - Check which dimension mismatches +# - Determine correct way to align dimensions +``` + + +### Device Mismatches + +**Pattern 4: CPU/GPU Device Mismatch** + +```python +# Error: RuntimeError: Expected all tensors to be on the same device +# Cause: Some tensors on CPU, others on GPU + +# Example 1: Forgot to move input to GPU +model = model.cuda() +x = torch.randn(4, 3, 32, 32) # On CPU +output = model(x) # ERROR: model on GPU, input on CPU + +# Solution: Move input to same device as model +x = x.cuda() # Or x = x.to(next(model.parameters()).device) +output = model(x) + +# Example 2: Buffer not moved with model +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 3) + self.register_buffer('scale', torch.tensor(0.5)) # On CPU initially + + def forward(self, x): + x = self.conv(x) + return x * self.scale # ERROR if model.cuda() was called + +# Solution: Buffers should auto-move, but if not: +def forward(self, x): + return x * self.scale.to(x.device) + +# Or ensure proper initialization order: +model = Model() +model = model.cuda() # This should move all parameters and buffers + +# Debugging: +# - Print device of each tensor: print(f"x device: {x.device}") +# - Check model device: print(f"Model device: {next(model.parameters()).device}") +# - Verify buffers moved: for name, buf in model.named_buffers(): print(name, buf.device) +``` + +**Pattern 5: Device-Side Assert (Index Out of Bounds)** + +```python +# Error: RuntimeError: CUDA error: device-side assert triggered +# Cause: Usually index out of bounds in CUDA operations (like embedding lookup) + +# Example: +vocab_size = 10000 +embedding = nn.Embedding(vocab_size, 128).cuda() +indices = torch.randint(0, 10001, (4, 50)).cuda() # Max index is 10000 (out of bounds!) +output = embedding(indices) # ERROR: device-side assert + +# Debug by moving to CPU (clearer error): +embedding_cpu = nn.Embedding(vocab_size, 128) +indices_cpu = torch.randint(0, 10001, (4, 50)) +output = embedding_cpu(indices_cpu) +# IndexError: index 10000 is out of bounds for dimension 0 with size 10000 + +# Solution: Ensure indices in valid range +assert indices.min() >= 0, f"Negative indices found: {indices.min()}" +assert indices.max() < vocab_size, f"Index {indices.max()} >= vocab_size {vocab_size}" + +# Or clip indices: +indices = indices.clamp(0, vocab_size - 1) + +# Root cause: Usually data preprocessing issue +# Check tokenization, dataset __getitem__, etc. +``` + + +### Autograd Errors + +**Pattern 6: In-Place Operation Breaking Autograd** + +```python +# Error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# Cause: Tensor modified in-place after being used in autograd graph + +# Example 1: In-place arithmetic +x = torch.randn(10, requires_grad=True) +y = x * 2 +x += 1 # ERROR: x modified in-place but needed for y's gradient +loss = y.sum() +loss.backward() + +# Solution: Use out-of-place operation +x = torch.randn(10, requires_grad=True) +y = x * 2 +x = x + 1 # Out-of-place: creates new tensor +loss = y.sum() +loss.backward() + +# Example 2: In-place activation +def forward(self, x): + x = self.layer1(x) + x = x.relu_() # In-place ReLU (has underscore) + x = self.layer2(x) + return x + +# Solution: Use out-of-place activation +def forward(self, x): + x = self.layer1(x) + x = torch.relu(x) # Or F.relu(x), or x.relu() without underscore + x = self.layer2(x) + return x + +# Common in-place operations to avoid: +# - x += y, x *= y, x[...] = y +# - x.add_(), x.mul_(), x.relu_() +# - x.transpose_(), x.resize_() +``` + +**Pattern 7: No Gradient for Parameter** + +```python +# Problem: Parameter not updating during training + +# Debugging: +for name, param in model.named_parameters(): + if param.grad is None: + print(f"⚠️ No gradient for {name}") + else: + print(f"✓ {name}: grad norm = {param.grad.norm():.6f}") + +# Cause 1: Parameter not used in forward pass +class Model(nn.Module): + def __init__(self): + self.used_layer = nn.Linear(10, 10) + self.unused_layer = nn.Linear(10, 10) # Never called in forward! + + def forward(self, x): + return self.used_layer(x) # unused_layer not in computation graph + +# Solution: Remove unused parameters or ensure they're used + +# Cause 2: Gradient flow interrupted by detach() +def forward(self, x): + x = self.encoder(x) + x = x.detach() # Breaks gradient flow! + x = self.decoder(x) # Encoder won't get gradients + return x + +# Solution: Don't detach unless intentional + +# Cause 3: Part of model in eval mode +model.encoder.eval() # Dropout/BatchNorm won't update in eval mode +model.decoder.train() +# Solution: Ensure correct parts are in train mode +``` + +**Pattern 8: Gradient Computed on Non-Leaf Tensor** + +```python +# Error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn +# Cause: Trying to backward from tensor that's not part of computation graph + +# Example: +x = torch.randn(10, requires_grad=True) +y = x * 2 +z = y.detach() # z not in graph anymore +loss = z.sum() +loss.backward() # ERROR: z doesn't require grad + +# Solution: Don't detach if you need gradients +z = y # Keep in graph +loss = z.sum() +loss.backward() + +# Use case for detach: When you DON'T want gradients to flow +x = torch.randn(10, requires_grad=True) +y = x * 2 +z = y.detach() # Intentionally stop gradient flow +# Use z for logging/visualization, but not for loss +``` + + +### Numerical Stability Errors + +**Pattern 9: NaN Loss from Numerical Instability** + +```python +# Problem: Loss becomes NaN during training + +# Common causes and solutions: + +# Cause 1: Learning rate too high +optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # Too high for SGD +# Solution: Reduce learning rate +optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + +# Cause 2: Gradient explosion +# Debug: Monitor gradient norms +for epoch in range(num_epochs): + for batch in dataloader: + loss.backward() + + # Check gradient norms + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.data.norm(2).item() ** 2 + total_norm = total_norm ** 0.5 + print(f"Gradient norm: {total_norm:.4f}") + + if total_norm > 100: + print("⚠️ Exploding gradients!") + + optimizer.step() + optimizer.zero_grad() + +# Solution: Gradient clipping +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# Cause 3: Division by zero +def custom_loss(output, target): + # Computing normalized loss + norm = output.norm() + return loss / norm # ERROR if norm is 0! + +# Solution: Add epsilon +def custom_loss(output, target): + norm = output.norm() + eps = 1e-8 + return loss / (norm + eps) # Safe + +# Cause 4: Log of zero or negative +def custom_loss(pred, target): + return -torch.log(pred).mean() # ERROR if any pred ≤ 0 + +# Solution: Clamp or use numerically stable version +def custom_loss(pred, target): + return -torch.log(pred.clamp(min=1e-8)).mean() # Or use F.log_softmax + +# Use detect_anomaly to find exact operation: +with torch.autograd.set_detect_anomaly(True): + loss.backward() +``` + +**Pattern 10: Vanishing/Exploding Gradients** + +```python +# Problem: Gradients become too small (vanishing) or too large (exploding) + +# Detection: +def check_gradient_flow(model): + ave_grads = [] + max_grads = [] + layers = [] + + for n, p in model.named_parameters(): + if p.grad is not None and "bias" not in n: + layers.append(n) + ave_grads.append(p.grad.abs().mean().item()) + max_grads.append(p.grad.abs().max().item()) + + # Plot or print + for layer, ave_grad, max_grad in zip(layers, ave_grads, max_grads): + print(f"{layer:40s} | Avg: {ave_grad:.6f} | Max: {max_grad:.6f}") + + if ave_grad < 1e-6: + print(f" ⚠️ Vanishing gradient in {layer}") + if max_grad > 100: + print(f" ⚠️ Exploding gradient in {layer}") + +# Solution 1: Gradient clipping (for explosion) +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# Solution 2: Better initialization (for vanishing) +def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + +model.apply(init_weights) + +# Solution 3: Batch normalization (helps both) +class BetterModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(128, 256) + self.bn1 = nn.BatchNorm1d(256) # Normalizes activations + self.fc2 = nn.Linear(256, 10) + +# Solution 4: Residual connections (for very deep networks) +class ResBlock(nn.Module): + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.conv2(out) + out += residual # Skip connection helps gradient flow + return out +``` + + +### Memory Errors + +**Pattern 11: Memory Leak from Tensor Accumulation** + +```python +# Problem: Memory usage grows steadily over iterations + +# Cause 1: Accumulating tensors with computation graph +losses = [] +for batch in dataloader: + loss = criterion(model(batch), target) + losses.append(loss) # Keeps full computation graph! + loss.backward() + optimizer.step() + +# Solution: Detach or convert to Python scalar +losses = [] +for batch in dataloader: + loss = criterion(model(batch), target) + losses.append(loss.item()) # Python float, no graph + # Or: losses.append(loss.detach().cpu()) + loss.backward() + optimizer.step() + +# Cause 2: Not deleting large intermediate tensors +for batch in dataloader: + activations = model.get_intermediate_features(batch) # Large tensor + loss = some_loss_using_activations(activations) + loss.backward() + # activations still in memory! + +# Solution: Delete explicitly +for batch in dataloader: + activations = model.get_intermediate_features(batch) + loss = some_loss_using_activations(activations) + loss.backward() + del activations # Free memory + torch.cuda.empty_cache() # Optional: return memory to GPU + +# Cause 3: Hooks accumulating data +stored_outputs = [] +def hook(module, input, output): + stored_outputs.append(output) # Accumulates every forward pass! + +model.register_forward_hook(hook) + +# Solution: Clear list or remove hook when done +stored_outputs = [] +handle = model.register_forward_hook(hook) +# ... use hook ... +handle.remove() # Remove hook +stored_outputs.clear() # Clear accumulated data +``` + +**Pattern 12: OOM (Out of Memory) During Training** + +```python +# Error: RuntimeError: CUDA out of memory + +# Debugging: Identify what's using memory +torch.cuda.reset_peak_memory_stats() + +# Run one iteration +output = model(batch) +forward_mem = torch.cuda.max_memory_allocated() / 1e9 +print(f"After forward: {forward_mem:.2f} GB") + +loss = criterion(output, target) +loss_mem = torch.cuda.max_memory_allocated() / 1e9 +print(f"After loss: {loss_mem:.2f} GB") + +loss.backward() +backward_mem = torch.cuda.max_memory_allocated() / 1e9 +print(f"After backward: {backward_mem:.2f} GB") + +optimizer.step() +optimizer_mem = torch.cuda.max_memory_allocated() / 1e9 +print(f"After optimizer: {optimizer_mem:.2f} GB") + +# Detailed breakdown +print(torch.cuda.memory_summary()) + +# Solutions: + +# Solution 1: Reduce batch size +train_loader = DataLoader(dataset, batch_size=16) # Was 32 + +# Solution 2: Gradient accumulation (simulate larger batch) +accumulation_steps = 4 +optimizer.zero_grad() +for i, batch in enumerate(train_loader): + output = model(batch) + loss = criterion(output, target) / accumulation_steps + loss.backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# Solution 3: Gradient checkpointing (trade compute for memory) +from torch.utils.checkpoint import checkpoint + +def forward(self, x): + # Checkpoint recomputes forward during backward instead of storing + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + return x + +# Solution 4: Mixed precision (half memory for activations) +from torch.cuda.amp import autocast, GradScaler + +scaler = GradScaler() +with autocast(): + output = model(batch) + loss = criterion(output, target) +scaler.scale(loss).backward() +scaler.step(optimizer) +scaler.update() + +# Solution 5: Clear cache periodically (fragmentation) +if step % 100 == 0: + torch.cuda.empty_cache() +``` + + +### Data Loading Errors + +**Pattern 13: DataLoader Multiprocessing Deadlock** + +```python +# Problem: Training hangs after first epoch, no error message + +# Cause: Unpicklable objects in Dataset + +class BadDataset(Dataset): + def __init__(self): + self.data = load_data() + self.transform_model = nn.Linear(10, 10) # Can't pickle CUDA tensors in modules! + + def __getitem__(self, idx): + x = self.data[idx] + x = self.transform_model(torch.tensor(x)) + return x.numpy() + +# Solution: Remove PyTorch modules from Dataset +class GoodDataset(Dataset): + def __init__(self): + self.data = load_data() + # Do transforms with numpy/scipy, not PyTorch + + def __getitem__(self, idx): + x = self.data[idx] + x = some_numpy_transform(x) + return x + +# Debugging: Test with num_workers=0 +train_loader = DataLoader(dataset, num_workers=0) # No multiprocessing +# If works with num_workers=0 but hangs with num_workers>0, it's a pickling issue + +# Common unpicklable objects: +# - nn.Module in Dataset +# - CUDA tensors in Dataset +# - Lambda functions +# - Local/nested functions +# - File handles, database connections +``` + +**Pattern 14: Incorrect Data Types** + +```python +# Error: RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long + +# Cause: Using wrong dtype for indices (labels, embedding lookups) + +# Example: +labels = torch.tensor([0.0, 1.0, 2.0]) # float32 +loss = F.cross_entropy(output, labels) # ERROR: expects int64 + +# Solution: Convert to correct dtype +labels = torch.tensor([0, 1, 2]) # int64 by default +# Or: labels = labels.long() + +# Common dtype issues: +# - Labels for classification: must be int64 (Long) +# - Embedding indices: must be int64 +# - Model inputs: usually float32 +# - Masks: bool or int +``` + + +## Debugging Pitfalls (Must Avoid) + +### Pitfall 1: Random Trial-and-Error + +**❌ Bad Approach:** +```python +# Error occurs +# Try random fix 1: change learning rate +# Still error +# Try random fix 2: change batch size +# Still error +# Try random fix 3: change model architecture +# Eventually something works but don't know why +``` + +**✅ Good Approach:** +```python +# Error occurs +# Phase 1: Reproduce reliably (fix seed, minimize code) +# Phase 2: Gather information (read error, add assertions) +# Phase 3: Form hypothesis (based on error pattern) +# Phase 4: Test hypothesis (targeted debugging) +# Phase 5: Fix and verify (minimal fix, verify it works) +``` + +**Counter:** ALWAYS form hypothesis before making changes. Random changes waste time. + + +### Pitfall 2: Not Reading Full Error Message + +**❌ Bad Approach:** +```python +# Error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) +# Read: "shape error" +# Fix: Add arbitrary reshape without understanding +x = x.view(4, 64) # Will fail or corrupt data +``` + +**✅ Good Approach:** +```python +# Error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x57600 and 64x128) +# Read completely: 4×57600 trying to multiply with 64×128 +# Extract info: input is 57600 features, layer expects 64 +# Calculate: 57600 = 30*30*64, so conv output is 30×30×64 +# Fix: Change linear layer to expect 57600 inputs +self.fc = nn.Linear(57600, 128) +``` + +**Counter:** Read EVERY word of error message. Shapes, dtypes, operation names all contain diagnostic information. + + +### Pitfall 3: Print Debugging Everywhere + +**❌ Bad Approach:** +```python +def forward(self, x): + print(f"1. Input: {x.shape}") + x = self.layer1(x) + print(f"2. After layer1: {x.shape}, mean: {x.mean()}, std: {x.std()}") + x = self.relu(x) + print(f"3. After relu: {x.shape}, min: {x.min()}, max: {x.max()}") + # ... prints for every operation +``` + +**✅ Good Approach:** +```python +# Use assertions for shape verification +def forward(self, x): + assert x.shape[1] == 128, f"Expected 128 channels, got {x.shape[1]}" + x = self.layer1(x) + x = self.relu(x) + return x + +# Use hooks for selective monitoring +def debug_hook(module, input, output): + if torch.isnan(output).any(): + raise RuntimeError(f"NaN in {module.__class__.__name__}") + +for module in model.modules(): + module.register_forward_hook(debug_hook) +``` + +**Counter:** Use strategic assertions and hooks, not print statements everywhere. Prints are overwhelming and slow. + + +### Pitfall 4: Fixing Symptoms Instead of Root Causes + +**❌ Bad Approach:** +```python +# Symptom: Device mismatch error +# Fix: Move tensors everywhere +def forward(self, x): + x = x.cuda() # Force GPU + x = self.layer1(x.cuda()) # Force GPU again + x = self.layer2(x.cuda()) # And again... +``` + +**✅ Good Approach:** +```python +# Root cause: Some parameter on CPU +# Debug: Find which parameter is on CPU +for name, param in model.named_parameters(): + print(f"{name}: {param.device}") +# Found: 'positional_encoding' is on CPU + +# Fix: Ensure buffer initialized on correct device +def __init__(self): + super().__init__() + # Don't create buffer on CPU then move model + # Create buffer after model.to(device) is called +``` + +**Counter:** Always find root cause before fixing. Symptom fixes often add overhead or hide real issue. + + +### Pitfall 5: Not Verifying Fix + +**❌ Bad Approach:** +```python +# Make change +# Error disappeared +# Assume it's fixed +# Move on +``` + +**✅ Good Approach:** +```python +# Make change +# Verify error disappeared: ✓ +# Verify output correct: ✓ +# Verify model trains: ✓ +loss_before = 2.5 +# ... train for 10 steps +loss_after = 1.8 +assert loss_after < loss_before, "Model not learning!" +# Verify on real data: ✓ +``` + +**Counter:** Verify fix completely. Check that model not only runs without error but also produces correct output and trains properly. + + +### Pitfall 6: Debugging in Wrong Mode + +**❌ Bad Approach:** +```python +# Production uses mixed precision +# But debugging without it +model.eval() # Wrong mode +with torch.no_grad(): + output = model(x) +# Bug doesn't appear because dropout/batchnorm behave differently +``` + +**✅ Good Approach:** +```python +# Match debugging mode to production mode +model.train() # Same mode as production +with autocast(): # Same precision as production + output = model(x) +# Now bug appears and can be debugged +``` + +**Counter:** Debug in same mode as production (train vs eval, with/without autocast, same device). + + +### Pitfall 7: Not Minimizing Reproduction + +**❌ Bad Approach:** +```python +# Try to debug in full training script with: +# - Complex data pipeline +# - Multi-GPU distributed training +# - Custom optimizer with complex scheduling +# - Logging, checkpointing, evaluation +# Very hard to isolate issue +``` + +**✅ Good Approach:** +```python +# Minimal reproduction: +import torch +import torch.nn as nn + +model = nn.Linear(10, 5) +x = torch.randn(2, 10) +output = model(x) # 10 lines, reproduces issue +``` + +**Counter:** Always minimize reproduction. Easier to debug 10 lines than 1000 lines. + + +### Pitfall 8: Leaving Debug Code in Production + +**❌ Bad Approach:** +```python +# Leave detect_anomaly enabled (10x slowdown!) +torch.autograd.set_detect_anomaly(True) + +# Leave hooks registered (memory overhead) +for module in model.modules(): + module.register_forward_hook(debug_hook) + +# Leave verbose logging (I/O bottleneck) +print(f"Step {i}, loss {loss.item()}") # Every step! +``` + +**✅ Good Approach:** +```python +# Use environment variable or flag to control debugging +DEBUG = os.getenv('DEBUG', 'false').lower() == 'true' + +if DEBUG: + torch.autograd.set_detect_anomaly(True) + for module in model.modules(): + module.register_forward_hook(debug_hook) + +# Or remove debug code after fixing issue +``` + +**Counter:** Remove debug code after fixing (detect_anomaly, hooks, verbose logging). Or gate with environment variable. + + +## Rationalization Table + +| Rationalization | Why It's Wrong | Counter-Argument | Red Flag | +|----------------|----------------|------------------|----------| +| "Error message is clear, I know what's wrong" | Error shows symptom, not root cause | Read full error including shapes/stack trace to find root cause | Jumping to fix without reading full error | +| "User needs quick fix, no time for debugging" | Systematic debugging is FASTER than random trial-and-error | Hypothesis-driven debugging finds issue in minutes vs hours of guessing | Making changes without hypothesis | +| "This is obviously a shape error, just need to reshape" | Arbitrary reshaping corrupts data or fails | Calculate actual shapes needed, understand WHY mismatch occurs | Adding reshape without understanding | +| "Let me try changing X randomly" | Random changes without hypothesis waste time | Form testable hypothesis, verify with targeted debugging | Suggesting parameter changes without evidence | +| "I'll add prints to see what's happening" | Prints are overwhelming and lack strategy | Use assertions for verification, hooks for selective monitoring | Adding print statements everywhere | +| "Hooks are too complex for this issue" | Hooks provide targeted inspection without code modification | Hooks are MORE efficient than scattered prints, show exactly where issue is | Avoiding proper debugging tools | +| "detect_anomaly is slow, skip it" | Only used during debugging, not production | Performance doesn't matter during debugging; finding NaN source quickly saves hours | Skipping tools because of performance | +| "Error only happens sometimes, hard to debug" | Intermittent errors can be made deterministic | Fix random seed, save failing batch, reproduce reliably | Giving up on intermittent errors | +| "Just move everything to CPU to avoid CUDA errors" | Moving to CPU hides root cause, doesn't fix it | CPU error messages are clearer for diagnosis, but fix device placement, don't avoid GPU | Avoiding diagnosis by changing environment | +| "Add try/except to handle the error" | Hiding errors doesn't fix them, will fail later | Catch exception for debugging, not to hide; fix root cause | Using try/except to hide problems | +| "Model not learning, must be learning rate" | Many causes for not learning, need diagnosis | Check if optimizer.step() is called, if gradients exist, if weights update | Suggesting hyperparameter changes without diagnosis | +| "It worked in the example, so I'll copy exactly" | Copying without understanding leads to cargo cult coding | Understand WHY fix works, adapt to your specific case | Copying code without understanding | +| "Too many possible causes, I'll try all solutions" | Trying everything wastes time and obscures actual fix | Form hypothesis, test systematically, narrow down to root cause | Suggesting multiple fixes simultaneously | +| "Error in PyTorch internals, must be PyTorch bug" | 99% of errors are in user code, not PyTorch | Read stack trace to find YOUR code that triggered error | Blaming framework instead of investigating | + + +## Red Flags Checklist + +**Stop and debug systematically when you observe:** + +- ⚠️ **Making code changes without hypothesis** - Why do you think this change will help? Form hypothesis first. + +- ⚠️ **Suggesting fixes without reading full error message** - Did you extract all diagnostic information from error? + +- ⚠️ **Not checking tensor shapes/devices/dtypes for shape/device errors** - These are in error message, check them! + +- ⚠️ **Suggesting parameter changes without diagnosis** - Why would changing LR/batch size fix this specific error? + +- ⚠️ **Adding print statements without clear goal** - What specifically are you trying to learn? Use assertions/hooks instead. + +- ⚠️ **Not using detect_anomaly() when NaN appears** - This tool pinpoints exact operation, use it! + +- ⚠️ **Not checking gradients when model not learning** - Do gradients exist? Are they non-zero? Are weights updating? + +- ⚠️ **Treating symptom instead of root cause** - Adding .to(device) everywhere instead of finding WHY tensor is on wrong device? + +- ⚠️ **Not verifying fix actually solves problem** - Did you verify model works correctly, not just "no error"? + +- ⚠️ **Changing multiple things at once** - Can't isolate what worked; change one thing, verify, iterate. + +- ⚠️ **Not creating minimal reproduction for complex errors** - Debugging full codebase wastes time; minimize first. + +- ⚠️ **Skipping Phase 3 (hypothesis formation)** - Random trial-and-error without hypothesis is inefficient. + +- ⚠️ **Using try/except to hide errors** - Catch for debugging, not to hide; fix root cause. + +- ⚠️ **Not reading stack trace** - Shows WHERE error occurred and execution path. + +- ⚠️ **Assuming user's diagnosis is correct** - User might misidentify issue; verify with systematic debugging. + + +## Quick Reference: Error Pattern → Debugging Strategy + +| Error Pattern | Immediate Action | Debugging Tool | Common Root Cause | +|--------------|------------------|----------------|-------------------| +| `mat1 and mat2 shapes cannot be multiplied` | Print shapes, check linear layer dimensions | Assertions on shapes | Conv output size doesn't match linear input size | +| `Expected all tensors to be on the same device` | Print device of each tensor | Device checks | Forgot to move input/buffer to GPU | +| `modified by an inplace operation` | Search for `*=`, `+=`, `.relu_()` | Find in-place ops | Using augmented assignment in forward pass | +| `index X is out of bounds` | Check index ranges, move to CPU for clearer error | Assertions on indices | Data preprocessing producing invalid indices | +| `device-side assert triggered` | Move to CPU, check embedding indices | Index range checks | Indices >= vocab_size or negative | +| Loss constant at log(num_classes) | Check if optimizer.step() called, if weights update | Gradient inspection | Missing optimizer.step() | +| NaN after N epochs | Monitor gradient norms, use detect_anomaly() | detect_anomaly() | Gradient explosion from high learning rate | +| `Function X returned nan` | Use detect_anomaly() to pinpoint operation | detect_anomaly() | Division by zero, log(0), numerical instability | +| CUDA out of memory | Profile memory at each phase | Memory profiling | Batch size too large or accumulating tensors | +| DataLoader hangs | Test with num_workers=0 | Check picklability | nn.Module or CUDA tensor in Dataset | +| Memory growing over iterations | Check what's being accumulated | Track allocations | Storing tensors with computation graph | + + +## Summary + +**Systematic debugging methodology prevents random trial-and-error:** + +1. **Reproduce Reliably:** Fix seeds, minimize code, isolate component +2. **Gather Information:** Read full error, use PyTorch debugging tools (detect_anomaly, hooks) +3. **Form Hypothesis:** Based on error pattern, predict what investigation will reveal +4. **Test Hypothesis:** Targeted debugging, verify or reject systematically +5. **Fix and Verify:** Minimal fix addressing root cause, verify completely + +**PyTorch-specific tools save hours:** +- `torch.autograd.set_detect_anomaly(True)` - pinpoints NaN source +- Forward hooks - inspect intermediate outputs non-intrusively +- Backward hooks - monitor gradient flow and statistics +- Strategic assertions - verify understanding of shapes/devices/dtypes + +**Common error patterns have known solutions:** +- Shape mismatches → calculate actual shapes, match layer dimensions +- Device errors → add device checks, fix initialization +- In-place ops → use out-of-place versions (`x = x + y` not `x += y`) +- NaN loss → detect_anomaly(), gradient clipping, reduce LR +- Memory issues → profile memory, detach from graph, reduce batch size + +**Pitfalls to avoid:** +- Random changes without hypothesis +- Not reading full error message +- Print debugging without strategy +- Fixing symptoms instead of root causes +- Not verifying fix works correctly +- Debugging in wrong mode +- Leaving debug code in production + +**Remember:** Debugging is systematic investigation, not random guessing. Form hypothesis, test it, iterate. PyTorch provides excellent debugging tools - use them! diff --git a/skills/using-pytorch-engineering/distributed-training-strategies.md b/skills/using-pytorch-engineering/distributed-training-strategies.md new file mode 100644 index 0000000..f8bb8a8 --- /dev/null +++ b/skills/using-pytorch-engineering/distributed-training-strategies.md @@ -0,0 +1,1848 @@ + +# Distributed Training Strategies + +## Overview + +**Core Principle:** DistributedDataParallel (DDP) is PyTorch's recommended approach for multi-GPU and multi-node training. Success requires understanding process-device mapping, gradient synchronization mechanics, and communication patterns. Setup mistakes cause silent errors; synchronization bugs cause divergence; poor configuration wastes GPUs. + +Distributed training failures manifest as: device placement errors, inconsistent results across runs, poor scaling efficiency, or mysterious divergence. These stem from misunderstanding DDP's process model, buffer synchronization, or communication overhead. Systematic setup and debugging beats trial and error. + +## When to Use + +**Use this skill when:** +- Setting up DistributedDataParallel for multi-GPU training +- Debugging "Expected all tensors to be on same device" errors +- Training produces inconsistent results with DDP +- Getting poor scaling efficiency (4x speedup on 8 GPUs) +- Setting up multi-node training +- Debugging gradient synchronization issues +- Need to optimize distributed training throughput +- Choosing between DataParallel and DistributedDataParallel + +**Don't use when:** +- Single GPU training (no distribution needed) +- Model architecture design (use neural-architectures) +- General training convergence issues (use training-optimization) +- Memory issues unrelated to distribution (use tensor-operations-and-memory) + +**Symptoms triggering this skill:** +- "RuntimeError: Expected all tensors to be on the same device" +- "DDP training gives different results than single GPU" +- "Multi-node training is unstable or diverges" +- "Only getting 3x speedup on 8 GPUs" +- "Batch norm statistics seem wrong in DDP" +- "find_unused_parameters causing issues" +- "Need to set up multi-node training" + + +## DDP vs DataParallel: The Critical Distinction + +**Never use nn.DataParallel for new code. Always use DistributedDataParallel.** + +### Why DataParallel is Obsolete + +```python +# ❌ OBSOLETE: nn.DataParallel (single-process multi-threading) +model = nn.DataParallel(model).cuda() + +# Problems: +# - Python GIL limits parallelism +# - Unbalanced GPU load (GPU 0 overloaded) +# - Slow gradient synchronization +# - Memory overhead on GPU 0 +# - 2-3x slower than DDP +``` + +### Why DistributedDataParallel is Standard + +```python +# ✅ STANDARD: DistributedDataParallel (multi-process) +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +# One process per GPU, true parallelism +dist.init_process_group(backend='nccl') +model = DDP(model, device_ids=[local_rank]) + +# Benefits: +# - No GIL limitation (separate processes) +# - Balanced GPU utilization +# - Efficient NCCL gradient allreduce +# - Better scaling (8 GPUs: ~7x speedup) +# - Multi-node ready +``` + +### Quick Comparison + +| Feature | DataParallel | DistributedDataParallel | +|---------|--------------|------------------------| +| Paradigm | Single-process, multi-thread | Multi-process | +| GIL Impact | Severe | None | +| Scaling | Poor (2-3x on 8 GPUs) | Good (7-8x on 8 GPUs) | +| Multi-node | No | Yes | +| Setup Complexity | Low | Medium | +| GPU 0 Overhead | High | None | +| Recommendation | ❌ Deprecated | ✅ Use this | + +**Rule:** If you see `nn.DataParallel`, replace with DDP. + + +## DDP Setup: The Correct Way + +### Setup Checklist (Follow in Order) + +**Step 1: Environment Variables (Set Before Launch)** + +```bash +# Single-node, multi-GPU (using torchrun) +torchrun --nproc_per_node=4 train.py + +# Multi-node (on each node) +# Node 0 (master): +torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \ + --master_addr="192.168.1.1" --master_port=29500 train.py + +# Node 1 (worker): +torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \ + --master_addr="192.168.1.1" --master_port=29500 train.py +``` + +**Key environment variables (set automatically by torchrun):** +- `RANK`: Global process rank (0 to world_size-1) +- `LOCAL_RANK`: Process rank within node (0 to nproc_per_node-1) +- `WORLD_SIZE`: Total number of processes +- `MASTER_ADDR`: Address of rank 0 process +- `MASTER_PORT`: Port for communication + + +**Step 2: Initialize Process Group (At Training Start)** + +```python +import torch +import torch.distributed as dist +import os + +def setup_distributed(): + """Initialize process group for DDP.""" + # ✅ Get local rank from environment + local_rank = int(os.environ["LOCAL_RANK"]) + + # ✅ Initialize process group (NCCL for GPU) + dist.init_process_group(backend="nccl") + + # ✅ Set device for this process + torch.cuda.set_device(local_rank) + + return local_rank + +# Call at start of training script +local_rank = setup_distributed() +device = torch.device(f"cuda:{local_rank}") + +print(f"[Rank {dist.get_rank()}] Using device: {device}") +``` + +**Why this order matters:** +1. `init_process_group()` must come before any CUDA operations +2. `set_device()` ensures all allocations go to correct GPU +3. Each process gets its own GPU (one-to-one mapping) + + +**Step 3: Move Model to Device BEFORE DDP Wrapping** + +```python +# ❌ WRONG: DDP before moving to device +model = MyModel() +model = DDP(model) # ❌ Model still on CPU! +model = model.to(device) # ❌ Too late! + +# ✅ CORRECT: Move to device FIRST, then wrap +model = MyModel() +model = model.to(device) # ✅ Move to device first +model = DDP(model, device_ids=[local_rank], output_device=local_rank) +``` + +**Why this order matters:** +- DDP wraps existing model parameters +- Parameters must already be on correct device before wrapping +- `device_ids` tells DDP which GPU this process uses +- `output_device` specifies where forward pass outputs go + + +**Step 4: Use DistributedSampler for Data Loading** + +```python +from torch.utils.data import DataLoader, DistributedSampler + +# ✅ CORRECT: DistributedSampler ensures each process gets different data +train_sampler = DistributedSampler( + train_dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True # Shuffle within sampler +) + +train_loader = DataLoader( + train_dataset, + batch_size=batch_size_per_gpu, + sampler=train_sampler, # ✅ Use sampler, not shuffle + num_workers=4, + pin_memory=True +) + +# ❌ WRONG: Regular random sampler (all processes get same data!) +# train_loader = DataLoader(train_dataset, shuffle=True) +``` + +**Why DistributedSampler is critical:** +- Without it, all GPUs train on identical data (no benefit!) +- DistributedSampler partitions dataset across processes +- Each process sees different subset (true data parallelism) + +**Important:** Call `sampler.set_epoch(epoch)` before each epoch: +```python +for epoch in range(num_epochs): + train_sampler.set_epoch(epoch) # ✅ Critical for proper shuffling + for batch in train_loader: + # training code +``` + + +**Step 5: Move Data to Correct Device** + +```python +for batch_idx, (data, target) in enumerate(train_loader): + # ✅ Move to local device (non_blocking for async transfer) + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # Forward pass (model already on device) + output = model(data) + loss = criterion(output, target) + + # Backward pass (gradient allreduce happens automatically) + loss.backward() + optimizer.step() + optimizer.zero_grad() +``` + +**Key points:** +- Each process loads different data (via DistributedSampler) +- Data moved to local GPU (`device`) +- Model outputs on same device (specified by `output_device`) +- Gradients synchronized automatically during `loss.backward()` + + +### Complete DDP Training Script Template + +```python +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler +import os + +def setup_distributed(): + """Initialize distributed training.""" + local_rank = int(os.environ["LOCAL_RANK"]) + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return local_rank + +def cleanup_distributed(): + """Cleanup distributed training.""" + dist.destroy_process_group() + +def main(): + # 1. Setup distributed + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + + # 2. Create model and move to device BEFORE DDP + model = MyModel().to(device) + model = DDP(model, device_ids=[local_rank], output_device=local_rank) + + # 3. Optimizer and loss + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss().to(device) + + # 4. Data loading with DistributedSampler + train_sampler = DistributedSampler(train_dataset) + train_loader = DataLoader( + train_dataset, + batch_size=32, # Per-GPU batch size + sampler=train_sampler, + num_workers=4, + pin_memory=True + ) + + # 5. Training loop + for epoch in range(num_epochs): + train_sampler.set_epoch(epoch) # ✅ Critical! + model.train() + + for data, target in train_loader: + # Move data to device + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # Forward pass + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + + # Backward pass (gradients synced automatically) + loss.backward() + optimizer.step() + + # Only log on rank 0 + if dist.get_rank() == 0: + print(f"Epoch {epoch}: Loss = {loss.item()}") + + # 6. Cleanup + cleanup_distributed() + +if __name__ == "__main__": + main() +``` + +**Launch with:** +```bash +torchrun --nproc_per_node=4 train.py +``` + + +## Synchronization Mechanics + +### Understanding Gradient Allreduce + +**What happens during `loss.backward()`:** + +1. **Backward pass**: Each process computes gradients independently +2. **Gradient bucketing**: DDP groups gradients into buckets +3. **Allreduce**: NCCL performs allreduce on each bucket (sum across processes) +4. **Averaging**: Gradients divided by world_size +5. **Result**: All processes have identical averaged gradients + +```python +# Conceptually, DDP does this automatically: +# gradient_on_gpu_0 = compute_gradients_on_gpu_0() +# gradient_on_gpu_1 = compute_gradients_on_gpu_1() +# ... +# gradient_avg = allreduce([gradient_on_gpu_0, gradient_on_gpu_1, ...]) / world_size +# Each GPU now has gradient_avg +``` + +**Critical insight:** Gradient synchronization is automatic. You don't need to do anything special. + + +### Batch Normalization: The Synchronization Trap + +**Problem:** Regular BatchNorm computes statistics per-GPU, causing divergence. + +```python +# ❌ WRONG: Regular BatchNorm in DDP +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 3) + self.bn = nn.BatchNorm2d(64) # ❌ Per-GPU statistics! + self.fc = nn.Linear(64, 10) + +# With batch_size=32 per GPU, 4 GPUs: +# GPU 0: BatchNorm sees 32 samples +# GPU 1: BatchNorm sees 32 samples (different!) +# Statistics computed independently → models diverge +``` + +**✅ SOLUTION: Use SyncBatchNorm** + +```python +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 3) + self.bn = nn.SyncBatchNorm(64) # ✅ Synchronized across GPUs + self.fc = nn.Linear(64, 10) + +# Or convert existing model: +model = Model() +model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # ✅ Converts all BN layers +model = model.to(device) +model = DDP(model, device_ids=[local_rank]) +``` + +**When to use SyncBatchNorm:** +- Small per-GPU batch size (< 16) +- Batch statistics important for your task +- Training produces inconsistent results + +**When regular BatchNorm is okay:** +- Large per-GPU batch size (≥ 32) +- Batch statistics less critical +- Want maximum speed (SyncBatchNorm adds communication overhead) + + +### Buffer Broadcasting + +**Buffers:** Non-parameter tensors (running mean/var in BatchNorm, dropout masks, etc.) + +```python +# DDP parameter: broadcast_buffers +model = DDP( + model, + device_ids=[local_rank], + broadcast_buffers=True # ✅ Default, broadcasts buffers from rank 0 +) +``` + +**What `broadcast_buffers=True` does:** +- At start of training, broadcasts buffers from rank 0 to all processes +- Ensures consistent initialization across all GPUs +- Important for BatchNorm running statistics, dropout patterns, etc. + +**When to disable (`broadcast_buffers=False`):** +- Custom buffer management +- Buffers intentionally different per process +- Rare use case + +**Rule:** Keep `broadcast_buffers=True` unless you know why you need False. + + +### Initialization Synchronization + +**Problem:** If models start different on each GPU, training diverges. + +```python +# ❌ WRONG: Random initialization without seed +def main(): + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + + model = MyModel() # ❌ Random init, different on each process! + model = model.to(device) + model = DDP(model, device_ids=[local_rank]) + +# ✅ CORRECT: Set seed before model creation +def main(): + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + + # ✅ Same seed ensures same initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + model = MyModel() # ✅ Identical initialization on all processes + model = model.to(device) + model = DDP(model, device_ids=[local_rank]) +``` + +**Alternative: Load checkpoint on all processes** +```python +# If loading pretrained model, ensure all processes load same checkpoint +model = MyModel() +model.load_state_dict(torch.load("checkpoint.pth")) # ✅ Same weights +model = model.to(device) +model = DDP(model, device_ids=[local_rank]) +``` + +**Rule:** Ensure model initialization is deterministic and identical across processes. + + +## Device Placement Debugging + +### Systematic Device Checking + +**When you get "Expected all tensors to be on the same device":** + +```python +def diagnose_device_placement(model, data, target): + """Systematic device diagnosis for DDP.""" + + # 1. Check model devices + model_devices = {name: param.device for name, param in model.named_parameters()} + unique_model_devices = set(model_devices.values()) + + print(f"Model devices: {unique_model_devices}") + if len(unique_model_devices) > 1: + print("⚠️ Model parameters on multiple devices!") + for name, device in model_devices.items(): + print(f" {name}: {device}") + + # 2. Check buffer devices + buffer_devices = {name: buf.device for name, buf in model.named_buffers()} + unique_buffer_devices = set(buffer_devices.values()) + + print(f"Buffer devices: {unique_buffer_devices}") + if len(unique_buffer_devices) > 1: + print("⚠️ Model buffers on multiple devices!") + + # 3. Check data devices + print(f"Data device: {data.device}") + print(f"Target device: {target.device}") + + # 4. Check if all on same device + all_devices = unique_model_devices | unique_buffer_devices | {data.device, target.device} + if len(all_devices) > 1: + print(f"❌ MISMATCH: Tensors on {all_devices}") + return False + else: + print(f"✅ All tensors on {list(all_devices)[0]}") + return True + +# Use before training: +diagnose_device_placement(model, data_batch, target_batch) +``` + + +### Common Device Mismatch Causes + +**Pitfall 1: Loss function not on device** + +```python +# ❌ WRONG: Loss function on CPU +criterion = nn.CrossEntropyLoss() # Defaults to CPU + +output = model(data) # GPU +loss = criterion(output, target) # ❌ Tries to use CPU loss + +# ✅ CORRECT: Move loss to device +criterion = nn.CrossEntropyLoss().to(device) +``` + + +**Pitfall 2: Forgetting to move target** + +```python +# ❌ WRONG: Only move input +data = data.to(device) +# target not moved! +output = model(data) +loss = criterion(output, target) # ❌ output on GPU, target on CPU + +# ✅ CORRECT: Move both +data = data.to(device) +target = target.to(device) +``` + + +**Pitfall 3: Wrong LOCAL_RANK** + +```python +# ❌ WRONG: Hardcoded device +device = torch.device("cuda:0") # ❌ All processes use GPU 0! + +# ✅ CORRECT: Use LOCAL_RANK +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") +``` + + +**Pitfall 4: Model partially on wrong device** + +```python +# ❌ WRONG: Only some layers moved +model = MyModel() +model.encoder = model.encoder.to(device) # Only encoder moved +model = DDP(model, device_ids=[local_rank]) # ❌ decoder still on CPU + +# ✅ CORRECT: Move entire model +model = MyModel() +model = model.to(device) # ✅ All parameters/buffers moved +model = DDP(model, device_ids=[local_rank]) +``` + + +## Performance Optimization + +### Profiling Distributed Training + +**Use torch.profiler to identify bottlenecks:** + +```python +from torch.profiler import profile, ProfilerActivity, schedule + +def train_with_profiling(model, data_loader, optimizer, criterion, device): + """Profile distributed training to identify bottlenecks.""" + + # Profile for 5 steps after warmup + prof_schedule = schedule(wait=1, warmup=1, active=5, repeat=1) + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=prof_schedule, + on_trace_ready=lambda p: p.export_chrome_trace(f"trace_rank_{dist.get_rank()}.json") + ) as prof: + + for step, (data, target) in enumerate(data_loader): + if step >= 7: # Profile first 7 steps + break + + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + prof.step() # Signal profiler to move to next step + + # Analyze results + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # Look for: + # - Time in "nccl:all_reduce" (communication overhead) + # - Time in forward/backward (computation) + # - Ratio of communication to computation +``` + +**View trace in Chrome:** Open `chrome://tracing` and load `trace_rank_0.json`. + +**What to look for:** +- **Communication time**: Look for `nccl:all_reduce` operations +- **Computation time**: Forward and backward pass +- **Idle time**: Gaps between operations (synchronization overhead) +- **Optimal ratio**: Computation time >> Communication time (10:1 or better) + + +### Understanding Gradient Bucketing + +**DDP optimization:** Gradients grouped into buckets for efficient allreduce. + +```python +model = DDP( + model, + device_ids=[local_rank], + bucket_cap_mb=25, # Default: 25MB buckets + gradient_as_bucket_view=True # Memory optimization +) +``` + +**How bucketing works:** +1. During backward pass, gradients computed layer by layer (backward order) +2. DDP accumulates gradients into 25MB buckets +3. When bucket full, launches asynchronous allreduce +4. While waiting, continues computing more gradients +5. Overlaps communication and computation + +**When to tune `bucket_cap_mb`:** +- **Larger buckets (50MB+)**: Fewer allreduce calls, less overhead + - Good for: Large models, fast network + - Risk: Less overlap, potential idle time +- **Smaller buckets (10MB)**: More overlap, better pipelining + - Good for: Small models, slow network + - Risk: More allreduce overhead + +**Rule of thumb:** Start with default 25MB, only tune if profiling shows communication bottleneck. + + +### Gradient Accumulation in DDP + +**When gradient accumulation helps:** + +```python +# Without DDP: Accumulate to simulate larger batch +for i, (data, target) in enumerate(data_loader): + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() # Accumulate gradients + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# ✅ With DDP: Still accumulates, but communication amortized +# DDP synchronizes gradients only when optimizer.step() is called +``` + +**Critical for DDP:** Use `no_sync()` context to disable gradient synchronization: + +```python +for i, (data, target) in enumerate(data_loader): + data = data.to(device) + target = target.to(device) + + # Disable gradient sync for accumulation steps + if (i + 1) % accumulation_steps != 0: + with model.no_sync(): # ✅ Skip allreduce + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() + else: + # Final accumulation step: synchronize + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() # ✅ Gradient allreduce happens here + optimizer.step() + optimizer.zero_grad() +``` + +**Why this matters:** +- Without `no_sync()`, DDP performs allreduce every backward pass (wasted communication) +- With `no_sync()`, allreduce only on final accumulation step +- Amortizes communication cost over accumulation_steps + +**When to use gradient accumulation in DDP:** +- Effective batch size > per-GPU memory allows +- Want larger batches but limited by GPU memory +- Training with small models (communication-bound) + + +### NCCL Tuning for Performance + +**Environment variables to tune NCCL:** + +```bash +# Disable P2P (peer-to-peer) if causing issues +export NCCL_P2P_DISABLE=1 + +# Increase buffer size for large messages +export NCCL_BUFFSIZE=8388608 # 8MB + +# Use specific network interface +export NCCL_SOCKET_IFNAME=eth0 + +# Enable InfiniBand (if available) +export NCCL_IB_DISABLE=0 + +# Increase timeout for slow networks +export NCCL_TIMEOUT=1800 # 30 minutes + +# Debugging: Log NCCL activity +export NCCL_DEBUG=INFO +export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV +``` + +**Common scenarios:** + +**Multi-node over Ethernet:** +```bash +export NCCL_SOCKET_IFNAME=eth0 # Specify correct interface +export NCCL_IB_DISABLE=1 # Disable InfiniBand +``` + +**Multi-node over InfiniBand:** +```bash +export NCCL_IB_DISABLE=0 # Enable InfiniBand +export NCCL_IB_HCA=mlx5_0 # Specify IB adapter +``` + +**Debugging communication issues:** +```bash +export NCCL_DEBUG=INFO # Verbose logging +export NCCL_DEBUG_FILE=/tmp/nccl_rank_%r.log # Per-rank logs +``` + + +### Scaling Efficiency Analysis + +**Expected speedup:** + +| # GPUs | Ideal Speedup | Realistic Speedup | Notes | +|--------|---------------|-------------------|-------| +| 2 | 2.0x | 1.8-1.9x | 90-95% efficiency | +| 4 | 4.0x | 3.5-3.8x | 85-95% efficiency | +| 8 | 8.0x | 6.5-7.5x | 80-90% efficiency | +| 16 | 16.0x | 12-15x | 75-90% efficiency | + +**Why not perfect scaling:** +1. **Communication overhead**: Gradient allreduce takes time +2. **Synchronization barriers**: Processes wait for each other +3. **Batch size effects**: Larger effective batch may need more iterations +4. **Network bandwidth**: Inter-node communication slower than intra-node + +**Model size vs scaling efficiency:** + +``` +Large models (100M+ parameters): +- Communication/Computation ratio: Low (1:20) +- Scaling efficiency: High (90%+) +- Why: Gradient communication cost amortized + +Small models (<10M parameters): +- Communication/Computation ratio: High (1:3) +- Scaling efficiency: Lower (70-80%) +- Why: Communication dominates + +Solution for small models: +- Gradient accumulation (amortize communication) +- Larger per-GPU batch size (more computation) +- Fewer GPUs (don't over-parallelize) +``` + +**Profiling speedup:** + +```python +import time + +# Baseline: Single GPU +model_single = MyModel().cuda() +start = time.time() +# Train for N steps +elapsed_single = time.time() - start + +# DDP: 4 GPUs (on rank 0) +if dist.get_rank() == 0: + model_ddp = DDP(MyModel().to(device), device_ids=[local_rank]) + start = time.time() + # Train for N steps + elapsed_ddp = time.time() - start + + speedup = elapsed_single / elapsed_ddp + efficiency = (speedup / 4) * 100 + + print(f"Speedup: {speedup:.2f}x") + print(f"Efficiency: {efficiency:.1f}%") + + if efficiency < 80: + print("⚠️ Low efficiency - check communication overhead") +``` + + +## Multi-Node Training + +### Process Group Initialization + +**Multi-node setup requires:** + +1. **Master node** (rank 0): Coordinates initialization +2. **Worker nodes**: Connect to master +3. **Network**: All nodes can communicate + +```python +import torch.distributed as dist +import os + +def setup_multi_node(): + """Initialize multi-node DDP.""" + # Environment variables set by torchrun: + # RANK: Global rank (0 to world_size-1) + # LOCAL_RANK: Rank within node (0 to nproc_per_node-1) + # WORLD_SIZE: Total processes across all nodes + # MASTER_ADDR: IP of rank 0 node + # MASTER_PORT: Port for communication + + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Initialize process group (NCCL for GPU) + dist.init_process_group( + backend="nccl", + init_method="env://", # Use environment variables + rank=rank, + world_size=world_size + ) + + # Set device for this process + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + print(f"[Node rank {rank // torch.cuda.device_count()}] " + f"[Local rank {local_rank}] " + f"[Global rank {rank}] " + f"Device: {device}") + + return rank, local_rank, device +``` + +**Launch multi-node training:** + +```bash +# Node 0 (master: 192.168.1.1): +torchrun \ + --nproc_per_node=4 \ + --nnodes=2 \ + --node_rank=0 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train.py + +# Node 1 (worker: 192.168.1.2): +torchrun \ + --nproc_per_node=4 \ + --nnodes=2 \ + --node_rank=1 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train.py +``` + + +### Multi-Node Debugging + +**Problem:** Multi-node training works on single node but fails with 2+ nodes. + +**Step 1: Verify network connectivity** + +```bash +# On worker node, test connection to master +ping 192.168.1.1 # Should succeed + +# Test port connectivity +nc -zv 192.168.1.1 29500 # Should connect +``` + +**Step 2: Check NCCL can communicate** + +```python +import torch +import torch.distributed as dist + +def test_nccl_communication(): + """Test NCCL communication across nodes.""" + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + + # Create tensor on each rank + tensor = torch.ones(1).to(device) * rank + print(f"[Rank {rank}] Before allreduce: {tensor.item()}") + + # Allreduce (sum) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + # Expected: sum of all ranks = 0 + 1 + 2 + ... + (world_size-1) + expected = sum(range(world_size)) + print(f"[Rank {rank}] After allreduce: {tensor.item()} (expected: {expected})") + + if abs(tensor.item() - expected) < 1e-6: + print(f"[Rank {rank}] ✅ NCCL communication working") + else: + print(f"[Rank {rank}] ❌ NCCL communication FAILED") + +# Run this test before training +test_nccl_communication() +dist.barrier() # Synchronize all processes +``` + +**Step 3: Enable NCCL debugging** + +```bash +export NCCL_DEBUG=INFO +export NCCL_DEBUG_FILE=/tmp/nccl_rank_%r.log + +# Run training, then check logs: +cat /tmp/nccl_rank_0.log # Master +cat /tmp/nccl_rank_4.log # First process on node 1 +``` + +**Look for in logs:** +- "NCCL INFO Bootstrap : Using [interface]" → Correct network interface? +- "NCCL INFO NET/Socket" → Network connection established? +- Errors about ring construction → NCCL can't form communication ring + + +### Multi-Node Batch Norm Issues + +**Problem:** Batch norm statistics diverge across nodes. + +**Solution:** Use SyncBatchNorm (already covered, but critical for multi-node) + +```python +# Convert model to use SyncBatchNorm BEFORE moving to device +model = MyModel() +model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # ✅ Critical for multi-node +model = model.to(device) +model = DDP(model, device_ids=[local_rank]) +``` + +**Why this is more critical for multi-node:** +- Single-node: Intra-node communication fast (NVLink/PCIe) +- Multi-node: Inter-node communication slower (network) +- SyncBatchNorm requires allreduce of statistics (adds latency) +- But necessary for correct training! + + +## Common Pitfalls + +### Consolidated Pitfall Table + +| # | Pitfall | Symptom | Root Cause | Fix | +|---|---------|---------|------------|-----| +| 1 | Using nn.DataParallel instead of DDP | Poor scaling, GPU 0 overloaded | Single-process multi-threading | Use DistributedDataParallel | +| 2 | Wrapping model before moving to device | "Expected same device" errors | DDP wraps before device placement | `model.to(device)` BEFORE `DDP(model)` | +| 3 | Not using DistributedSampler | All GPUs see same data, no speedup | Regular sampler doesn't partition data | Use `DistributedSampler` | +| 4 | Forgetting `sampler.set_epoch()` | Data order identical each epoch | Sampler shuffle seed not updated | Call `sampler.set_epoch(epoch)` | +| 5 | Regular BatchNorm in DDP | Training divergence, inconsistent results | Per-GPU statistics not synchronized | Use `SyncBatchNorm` | +| 6 | Loss function not moved to device | Device mismatch error | Loss defaults to CPU | `criterion.to(device)` | +| 7 | Hardcoding device instead of LOCAL_RANK | All processes use GPU 0 | Wrong device mapping | `device = torch.device(f"cuda:{local_rank}")` | +| 8 | Different model initialization per process | Training divergence | Random seeds not synchronized | Set same seed before model creation | +| 9 | Gradient accumulation without no_sync() | Wasted communication overhead | DDP syncs every backward | Use `model.no_sync()` context | +| 10 | find_unused_parameters without need | Slow training, high overhead | Unnecessary dynamic graph handling | Set `find_unused_parameters=False` | + + +### Pitfall 1: DataParallel vs DistributedDataParallel + +```python +# ❌ WRONG: Using obsolete DataParallel +model = nn.DataParallel(model).cuda() + +# Problems: +# - Single process (GIL bottleneck) +# - GPU 0 accumulates gradients (memory overhead) +# - Slower than DDP (2-3x on 8 GPUs vs 7-8x) + +# ✅ CORRECT: Use DDP +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") +model = model.to(device) +model = DDP(model, device_ids=[local_rank]) +``` + +**Symptom:** Poor scaling (2-3x on 8 GPUs instead of 7-8x) +**Fix:** Replace DataParallel with DistributedDataParallel + + +### Pitfall 2: DDP Before Device Placement + +```python +# ❌ WRONG: DDP before .to(device) +model = MyModel() +model = DDP(model) # ❌ Model still on CPU +model = model.to(device) # ❌ Too late! + +# ✅ CORRECT: Device placement before DDP +model = MyModel() +model = model.to(device) # ✅ First move to device +model = DDP(model, device_ids=[local_rank]) # ✅ Then wrap +``` + +**Symptom:** "Expected all tensors to be on the same device" +**Fix:** Always `model.to(device)` BEFORE `DDP(model)` + + +### Pitfall 3: Missing DistributedSampler + +```python +# ❌ WRONG: Regular DataLoader without DistributedSampler +train_loader = DataLoader(dataset, batch_size=32, shuffle=True) + +# Problem: All GPUs see same data! No data parallelism! + +# ✅ CORRECT: Use DistributedSampler +train_sampler = DistributedSampler(dataset) +train_loader = DataLoader( + dataset, + batch_size=32, + sampler=train_sampler, # ✅ Partitions data + shuffle=False # ❌ Can't use shuffle with sampler +) +``` + +**Symptom:** Training with DDP no faster than single GPU +**Fix:** Use `DistributedSampler` to partition data across GPUs + + +### Pitfall 4: Forgetting set_epoch() + +```python +# ❌ WRONG: Not calling set_epoch() +for epoch in range(num_epochs): + for batch in train_loader: # ❌ Same shuffle order every epoch! + # training + +# ✅ CORRECT: Call set_epoch() before each epoch +for epoch in range(num_epochs): + train_sampler.set_epoch(epoch) # ✅ Updates shuffle seed + for batch in train_loader: + # training +``` + +**Symptom:** Training doesn't improve after first epoch (sees same data order) +**Fix:** Call `train_sampler.set_epoch(epoch)` at start of each epoch + + +### Pitfall 5: Regular BatchNorm in DDP + +```python +# ❌ WRONG: Regular BatchNorm (per-GPU statistics) +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 64, 3) + self.bn = nn.BatchNorm2d(64) # ❌ Not synchronized! + +# ✅ CORRECT: SyncBatchNorm (synchronized statistics) +model = Model() +model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # ✅ Converts all BN +model = model.to(device) +model = DDP(model, device_ids=[local_rank]) +``` + +**Symptom:** DDP training results differ from single-GPU, or training diverges +**Fix:** Use `SyncBatchNorm` for small per-GPU batch sizes + + +### Pitfall 6: Loss Not on Device + +```python +# ❌ WRONG: Loss function defaults to CPU +criterion = nn.CrossEntropyLoss() # On CPU + +output = model(data) # On GPU +loss = criterion(output, target) # ❌ Device mismatch! + +# ✅ CORRECT: Move loss to device +criterion = nn.CrossEntropyLoss().to(device) # ✅ On GPU +``` + +**Symptom:** "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu" +**Fix:** Move loss function to device with `.to(device)` + + +### Pitfall 7: Hardcoded Device + +```python +# ❌ WRONG: Hardcoded device (all processes use GPU 0!) +device = torch.device("cuda:0") # ❌ All ranks use same GPU! + +# ✅ CORRECT: Use LOCAL_RANK from environment +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") # ✅ Each rank gets own GPU +``` + +**Symptom:** OOM on GPU 0, other GPUs idle +**Fix:** Use `LOCAL_RANK` to assign one GPU per process + + +### Pitfall 8: Inconsistent Initialization + +```python +# ❌ WRONG: No seed set (different init on each process) +def main(): + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + model = MyModel() # ❌ Random init, different per process! + +# ✅ CORRECT: Set seed for consistent initialization +def main(): + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + + torch.manual_seed(42) # ✅ Same seed everywhere + model = MyModel() # ✅ Identical initialization +``` + +**Symptom:** Training diverges or produces inconsistent results +**Fix:** Set same random seed on all processes before model creation + + +### Pitfall 9: Gradient Accumulation Without no_sync() + +```python +# ❌ WRONG: Gradient accumulation with DDP (syncs every step!) +for i, (data, target) in enumerate(data_loader): + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() # ❌ DDP syncs gradients every time! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# ✅ CORRECT: Use no_sync() to skip gradient synchronization +for i, (data, target) in enumerate(data_loader): + if (i + 1) % accumulation_steps != 0: + with model.no_sync(): # ✅ Skip allreduce + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() + else: + output = model(data) + loss = criterion(output, target) / accumulation_steps + loss.backward() # ✅ Sync only on last accumulation step + optimizer.step() + optimizer.zero_grad() +``` + +**Symptom:** Gradient accumulation slower than expected in DDP +**Fix:** Use `model.no_sync()` context to disable gradient sync for accumulation steps + + +### Pitfall 10: find_unused_parameters=True Without Need + +```python +# ❌ WRONG: Enabling find_unused_parameters unnecessarily +model = DDP( + model, + device_ids=[local_rank], + find_unused_parameters=True # ❌ Adds overhead! +) + +# ✅ CORRECT: Only enable if you have unused parameters +model = DDP( + model, + device_ids=[local_rank], + find_unused_parameters=False # ✅ Default, faster +) + +# When you NEED find_unused_parameters=True: +# - Dynamic computation graphs (different paths each forward) +# - Some parameters not used in every forward pass +# - Multi-task models with conditional branches +``` + +**Symptom:** Training slower than expected, especially backward pass +**Fix:** Keep `find_unused_parameters=False` unless you have dynamic graphs + + +## Red Flags - Stop and Diagnose + +**If you catch yourself doing ANY of these, STOP and follow systematic methodology:** + +| Red Flag Thought | Reality | What to Do Instead | +|------------------|---------|-------------------| +| "I'll just use DataParallel, it's simpler" | DataParallel is deprecated and slow | Always use DDP, setup is straightforward | +| "I'll reduce batch size to fix OOM" | May be masking device placement bug | Diagnose device placement first | +| "Multi-node should just work like single-node" | Multi-node has network, NCCL config, synchronization | Check network, NCCL logs, test communication | +| "Scaling isn't perfect, must be PyTorch bug" | 99% of time it's configuration or model size | Profile to identify communication overhead | +| "I'll skip DistributedSampler for now" | All GPUs will see same data, no benefit | Use DistributedSampler from the start | +| "BatchNorm should work automatically" | Regular BatchNorm uses per-GPU statistics | Use SyncBatchNorm for small batch sizes | +| "I'll wrap model then move to device" | Order matters critically | ALWAYS: to(device) BEFORE DDP() | +| "Communication is slow, must be network" | May be configuration (NCCL, bucketing) | Profile first, tune NCCL second | + +**Critical rule:** DDP has specific setup requirements. Follow checklist systematically, don't guess. + + +## Edge Cases and Advanced Scenarios + +### Edge Case 1: Mixed Precision with DDP + +**Combining autocast/GradScaler with DDP:** + +```python +from torch.cuda.amp import autocast, GradScaler + +# Setup DDP +model = MyModel().to(device) +model = DDP(model, device_ids=[local_rank]) + +# ✅ CORRECT: GradScaler is local (not synchronized) +scaler = GradScaler() # One per process + +for data, target in data_loader: + data = data.to(device) + target = target.to(device) + + optimizer.zero_grad() + + # Forward pass with autocast + with autocast(): + output = model(data) + loss = criterion(output, target) + + # Backward with gradient scaling + scaler.scale(loss).backward() # DDP syncs gradients here + scaler.step(optimizer) + scaler.update() +``` + +**Key points:** +- Each process has its own `GradScaler` +- DDP gradient synchronization works with scaled gradients +- `scaler.step()` handles gradient unscaling internally +- No special DDP configuration needed for mixed precision + + +### Edge Case 2: Dynamic Computation Graphs + +**When forward pass has conditional branches:** + +```python +class ConditionalModel(nn.Module): + def forward(self, x, use_extra_layer=False): + x = self.layer1(x) + if use_extra_layer: + x = self.extra_layer(x) # Sometimes used, sometimes not + x = self.layer2(x) + return x + +# ✅ CORRECT: Enable find_unused_parameters +model = ConditionalModel().to(device) +model = DDP( + model, + device_ids=[local_rank], + find_unused_parameters=True # ✅ Required for dynamic graphs +) + +# Training loop +for data, target in data_loader: + # Randomly use extra layer + use_extra = random.random() > 0.5 + output = model(data, use_extra_layer=use_extra) + loss = criterion(output, target) + loss.backward() # DDP handles unused parameters + optimizer.step() +``` + +**Warning:** `find_unused_parameters=True` adds overhead. Only use when necessary. + + +### Edge Case 3: Gradient Checkpointing with DDP + +**Combining gradient checkpointing (for memory) with DDP:** + +```python +from torch.utils.checkpoint import checkpoint + +class CheckpointedModel(nn.Module): + def forward(self, x): + # Checkpoint intermediate layers + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + x = self.output_layer(x) + return x + +# ✅ Works with DDP out of the box +model = CheckpointedModel().to(device) +model = DDP(model, device_ids=[local_rank]) + +# Training: Gradient checkpointing + DDP gradient sync +for data, target in data_loader: + output = model(data) + loss = criterion(output, target) + loss.backward() # Recomputes forward, then syncs gradients + optimizer.step() +``` + +**Key insight:** Gradient checkpointing recomputes forward during backward. DDP gradient synchronization still happens correctly at the end of backward pass. + + +### Edge Case 4: Saving and Loading Checkpoints + +**Save only from rank 0, load on all ranks:** + +```python +# Saving checkpoint (only rank 0) +if dist.get_rank() == 0: + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.module.state_dict(), # ✅ model.module, not model + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss.item(), + } + torch.save(checkpoint, 'checkpoint.pth') + +dist.barrier() # Wait for rank 0 to finish saving + +# Loading checkpoint (all ranks) +checkpoint = torch.load('checkpoint.pth', map_location=device) +model.module.load_state_dict(checkpoint['model_state_dict']) # ✅ model.module +optimizer.load_state_dict(checkpoint['optimizer_state_dict']) +``` + +**Critical points:** +- Use `model.module.state_dict()` not `model.state_dict()` (unwrap DDP) +- Save only from rank 0 (avoid race condition) +- Load on all ranks (each process needs weights) +- Use `dist.barrier()` to synchronize + + +### Edge Case 5: Heterogeneous GPUs + +**Training with different GPU types (e.g., V100 + A100):** + +```python +# Problem: Different GPUs have different speeds +# Solution: Set timeout to prevent faster GPUs from timing out + +model = DDP( + model, + device_ids=[local_rank], + timeout=timedelta(minutes=30) # ✅ Increase timeout for slow GPUs +) +``` + +**Additional considerations:** +- Batch size per GPU should match slowest GPU's memory +- Scaling efficiency will be limited by slowest GPU +- Consider grouping processes by GPU type using process groups + + +### Edge Case 6: Zero Redundancy Optimizer (ZeRO) + +**For very large models, use FairScale ZeRO:** + +```python +from fairscale.optim.oss import OSS +from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP + +# Setup model +model = MyLargeModel().to(device) + +# ✅ Shard optimizer states across GPUs +base_optimizer = torch.optim.Adam +optimizer = OSS(model.parameters(), optim=base_optimizer, lr=1e-3) + +# ✅ Shard model parameters (optional, for very large models) +model = ShardedDDP(model, optimizer) + +# Training loop same as DDP +for data, target in data_loader: + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() +``` + +**When to use ZeRO:** +- Model too large for single GPU +- Optimizer states don't fit in memory +- Need to scale beyond 8 GPUs + + +## Debugging Methodology + +### Systematic Debugging for DDP Issues + +**Step 1: Verify single-GPU training works** + +```bash +# First, ensure code works on single GPU +python train.py # No torchrun, single process + +# If single-GPU works, then it's a DDP-specific issue +``` + + +**Step 2: Check environment variables** + +```python +def check_ddp_environment(): + """Verify DDP environment is set up correctly.""" + required_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"] + + for var in required_vars: + value = os.environ.get(var) + if value is None: + print(f"❌ Missing environment variable: {var}") + else: + print(f"✅ {var} = {value}") + + # Check if NCCL backend available + if torch.cuda.is_available() and torch.cuda.nccl.is_available(): + print(f"✅ NCCL available (version {torch.cuda.nccl.version()})") + else: + print("❌ NCCL not available") + + # Check GPU count + print(f"✅ GPUs visible: {torch.cuda.device_count()}") + +# Run before init_process_group +check_ddp_environment() +``` + + +**Step 3: Test process group initialization** + +```python +def test_process_group_init(): + """Test if process group initializes correctly.""" + try: + dist.init_process_group(backend="nccl") + print(f"✅ Process group initialized") + print(f" Rank: {dist.get_rank()}") + print(f" World size: {dist.get_world_size()}") + return True + except Exception as e: + print(f"❌ Process group initialization failed: {e}") + return False + +if test_process_group_init(): + # Continue with training setup + pass +``` + + +**Step 4: Verify device placement** + +```python +def verify_device_placement(model, data_batch, target_batch): + """Check all tensors on correct device.""" + local_rank = int(os.environ["LOCAL_RANK"]) + expected_device = torch.device(f"cuda:{local_rank}") + + # Check model + model_device = next(model.parameters()).device + assert model_device == expected_device, f"Model on {model_device}, expected {expected_device}" + print(f"✅ Model on correct device: {model_device}") + + # Check data + assert data_batch.device == expected_device, f"Data on {data_batch.device}, expected {expected_device}" + assert target_batch.device == expected_device, f"Target on {target_batch.device}, expected {expected_device}" + print(f"✅ Data on correct device: {data_batch.device}") + +# Before training loop +data_batch, target_batch = next(iter(data_loader)) +data_batch = data_batch.to(device) +target_batch = target_batch.to(device) +verify_device_placement(model, data_batch, target_batch) +``` + + +**Step 5: Test gradient synchronization** + +```python +def test_gradient_sync(model): + """Verify gradients are synchronized across processes.""" + rank = dist.get_rank() + + # Set gradients to rank value + for param in model.parameters(): + if param.grad is not None: + param.grad.data.fill_(rank) + + # Perform dummy backward (this should trigger allreduce in DDP) + # But we already set gradients, so just check if they're averaged + + # In actual DDP, gradients are averaged after backward() + # Expected value: average of all ranks = (0 + 1 + 2 + ... + (world_size-1)) / world_size + expected = sum(range(dist.get_world_size())) / dist.get_world_size() + + # Check if gradients are close to expected + # (This test assumes you've already run backward) + for param in model.parameters(): + if param.grad is not None: + actual = param.grad.data.mean().item() + if abs(actual - expected) > 1e-5: + print(f"❌ [Rank {rank}] Gradient sync failed: {actual} != {expected}") + return False + + print(f"✅ [Rank {rank}] Gradients synchronized correctly") + return True + +# After first backward pass +# test_gradient_sync(model) +``` + + +**Step 6: Profile communication overhead** + +```python +def profile_communication_overhead(model, data_loader, device, num_steps=10): + """Measure communication vs computation time.""" + import time + + model.train() + + compute_times = [] + total_times = [] + + for step, (data, target) in enumerate(data_loader): + if step >= num_steps: + break + + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + torch.cuda.synchronize() # Wait for data transfer + step_start = time.time() + + # Forward + compute_start = time.time() + output = model(data) + loss = criterion(output, target) + + # Backward (includes communication) + loss.backward() + + torch.cuda.synchronize() # Wait for backward (including allreduce) + step_end = time.time() + + compute_time = step_end - compute_start + total_time = step_end - step_start + + compute_times.append(compute_time) + total_times.append(total_time) + + avg_compute = sum(compute_times) / len(compute_times) + avg_total = sum(total_times) / len(total_times) + communication = avg_total - avg_compute + + if dist.get_rank() == 0: + print(f"Compute time: {avg_compute:.4f}s") + print(f"Communication time: {communication:.4f}s") + print(f"Communication overhead: {(communication/avg_total)*100:.1f}%") + + if communication / avg_total > 0.3: + print("⚠️ High communication overhead (>30%)") + print(" Consider: Larger model, gradient accumulation, or fewer GPUs") + +profile_communication_overhead(model, train_loader, device) +``` + + +## Common Rationalizations (Don't Do These) + +### Comprehensive Rationalization Table + +| Excuse | What Agent Might Think | Reality | Correct Response | +|--------|----------------------|---------|------------------| +| "User is rushed" | "I'll skip the checklist to save time" | Checklist takes <5 min, wrong fix wastes 30+ min | Follow systematic methodology | +| "They already tried X" | "X must not be the issue, move to Y" | X may have been done incorrectly | Verify X was done correctly first | +| "Senior engineer says use DataParallel" | "Authority knows best, defer to them" | DataParallel is objectively slower/deprecated | Recommend DDP with evidence | +| "They've been debugging for hours" | "They must have ruled out obvious issues" | Fatigue causes mistakes, start from basics | Apply systematic checklist regardless | +| "Multi-node is complex" | "Just give them a working config" | Config must match their environment | Diagnose specific failure | +| "Profiling takes time" | "User wants quick answer, skip profiling" | Profiling finds exact bottleneck in minutes | Always profile before optimizing | +| "This is a complex interaction" | "Too complex to debug systematically" | Systematic testing isolates interaction | Test components independently | +| "Network must be the issue" | "Skip other checks, assume network" | Could be config, NCCL, or code | Check network AFTER code checks | +| "NCCL tuning will fix it" | "Jump to NCCL environment variables" | NCCL tuning is last resort | Profile to confirm communication bound | +| "Just use fewer GPUs" | "Scaling is hard, reduce parallelism" | Likely a configuration issue | Fix configuration, don't reduce scale | +| "DataParallel is simpler" | "Avoid DDP complexity" | DataParallel 2-3x slower, deprecated | DDP setup takes 10 more lines, 3-4x faster | +| "I'll move model after DDP" | "Order doesn't matter much" | Wrong order causes device errors | ALWAYS to(device) BEFORE DDP() | +| "DistributedSampler too complex" | "Skip it for now" | Without it, all GPUs see same data | Use DistributedSampler, it's 2 lines | +| "Batch norm should work" | "PyTorch handles it automatically" | Per-GPU statistics cause divergence | Use SyncBatchNorm for small batches | +| "find_unused_parameters=True just in case" | "Better safe than sorry" | Adds 10-20% overhead | Only use for dynamic graphs | + + +## Red Flags Checklist - Expanded + +**Before suggesting any fix, check these red flags:** + +### Setup Red Flags +- [ ] Am I suggesting DataParallel? (❌ Always use DDP) +- [ ] Am I wrapping before moving to device? (❌ Device first, then DDP) +- [ ] Am I missing DistributedSampler? (❌ Required for data parallelism) +- [ ] Am I hardcoding device=cuda:0? (❌ Use LOCAL_RANK) +- [ ] Am I skipping set_epoch()? (❌ Required for proper shuffling) + +### Synchronization Red Flags +- [ ] Am I using regular BatchNorm with small batches? (❌ Use SyncBatchNorm) +- [ ] Am I assuming initialization is synced? (❌ Set seed explicitly) +- [ ] Am I ignoring buffer synchronization? (❌ Keep broadcast_buffers=True) +- [ ] Am I using find_unused_parameters unnecessarily? (❌ Adds overhead) + +### Performance Red Flags +- [ ] Am I suggesting NCCL tuning before profiling? (❌ Profile first) +- [ ] Am I using gradient accumulation without no_sync()? (❌ Wastes communication) +- [ ] Am I ignoring model size vs communication tradeoff? (❌ Small models scale poorly) +- [ ] Am I assuming perfect scaling? (❌ 80-90% efficiency is realistic) + +### Debugging Red Flags +- [ ] Am I skipping single-GPU verification? (❌ Verify single-GPU first) +- [ ] Am I not checking environment variables? (❌ Verify RANK, LOCAL_RANK, etc.) +- [ ] Am I assuming device placement without checking? (❌ Use diagnostic function) +- [ ] Am I guessing bottleneck without profiling? (❌ Always profile) + +### Multi-Node Red Flags +- [ ] Am I assuming network works without testing? (❌ Test connectivity) +- [ ] Am I not checking NCCL logs? (❌ Enable NCCL_DEBUG=INFO) +- [ ] Am I ignoring network interface specification? (❌ Set NCCL_SOCKET_IFNAME) +- [ ] Am I assuming allreduce works without testing? (❌ Run communication test) + +### Pressure/Bias Red Flags +- [ ] Am I skipping systematic checks due to time pressure? (❌ Checklist faster than guessing) +- [ ] Am I accepting user's diagnosis without verification? (❌ Profile to confirm) +- [ ] Am I deferring to authority over facts? (❌ DDP is objectively better) +- [ ] Am I providing config without understanding failure? (❌ Diagnose first) + +**If ANY red flag is true, STOP and apply the correct pattern.** + + +## Quick Reference: DDP Setup Checklist + +### Before Training + +- [ ] Use `torchrun` to launch (not `python train.py`) +- [ ] Initialize process group: `dist.init_process_group(backend="nccl")` +- [ ] Get `LOCAL_RANK` from environment: `int(os.environ["LOCAL_RANK"])` +- [ ] Set device: `torch.cuda.set_device(local_rank)` +- [ ] Set random seed (for consistent initialization) + +### Model Setup + +- [ ] Create model +- [ ] (Optional) Convert to SyncBatchNorm: `nn.SyncBatchNorm.convert_sync_batchnorm(model)` +- [ ] Move to device: `model.to(device)` +- [ ] Wrap with DDP: `DDP(model, device_ids=[local_rank], output_device=local_rank)` + +### Data Loading + +- [ ] Create `DistributedSampler`: `DistributedSampler(dataset)` +- [ ] Use sampler in DataLoader: `DataLoader(..., sampler=train_sampler)` +- [ ] Call `train_sampler.set_epoch(epoch)` before each epoch + +### Training Loop + +- [ ] Move data to device: `data.to(device)`, `target.to(device)` +- [ ] Forward pass +- [ ] Backward pass (gradients synced automatically) +- [ ] Optimizer step +- [ ] Zero gradients + +### Checkpointing + +- [ ] Save only from rank 0: `if dist.get_rank() == 0:` +- [ ] Use `model.module.state_dict()` (unwrap DDP) +- [ ] Load on all ranks +- [ ] Add `dist.barrier()` after saving + +### Cleanup + +- [ ] Call `dist.destroy_process_group()` at end + + +## Complete Multi-GPU Training Script + +```python +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler +import os + +def setup_distributed(): + """Initialize distributed training.""" + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + return local_rank + +def cleanup_distributed(): + """Cleanup distributed training.""" + dist.destroy_process_group() + +def main(): + # 1. Setup distributed + local_rank = setup_distributed() + device = torch.device(f"cuda:{local_rank}") + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + print(f"Training on {world_size} GPUs") + + # 2. Set seed for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # 3. Create model + model = MyModel() + + # 4. (Optional) Convert to SyncBatchNorm + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # 5. Move to device BEFORE DDP + model = model.to(device) + + # 6. Wrap with DDP + model = DDP( + model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=False # Only True if dynamic graphs + ) + + # 7. Optimizer and loss + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss().to(device) + + # 8. Data loading with DistributedSampler + train_sampler = DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True + ) + + train_loader = DataLoader( + train_dataset, + batch_size=32, # Per-GPU batch size + sampler=train_sampler, + num_workers=4, + pin_memory=True, + drop_last=True # Avoid uneven last batch + ) + + # 9. Training loop + for epoch in range(num_epochs): + # Set epoch for proper shuffling + train_sampler.set_epoch(epoch) + + model.train() + epoch_loss = 0.0 + + for batch_idx, (data, target) in enumerate(train_loader): + # Move data to device + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # Forward pass + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + + # Backward pass (gradients synced automatically) + loss.backward() + + # Optimizer step + optimizer.step() + + epoch_loss += loss.item() + + # Average loss across all processes + avg_loss = epoch_loss / len(train_loader) + + # Log only from rank 0 + if rank == 0: + print(f"Epoch {epoch}: Loss = {avg_loss:.4f}") + + # 10. Save checkpoint (only rank 0) + if rank == 0 and (epoch + 1) % save_interval == 0: + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.module.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': avg_loss, + } + torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth') + + # Synchronize all processes + dist.barrier() + + # 11. Cleanup + cleanup_distributed() + +if __name__ == "__main__": + main() +``` + +**Launch:** +```bash +# Single node, 4 GPUs: +torchrun --nproc_per_node=4 train.py + +# Multi-node, 2 nodes with 4 GPUs each: +# Node 0: +torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \ + --master_addr="192.168.1.1" --master_port=29500 train.py + +# Node 1: +torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \ + --master_addr="192.168.1.1" --master_port=29500 train.py +``` + + +## References + +**PyTorch Documentation:** +- DistributedDataParallel: https://pytorch.org/docs/stable/notes/ddp.html +- torch.distributed: https://pytorch.org/docs/stable/distributed.html +- torchrun: https://pytorch.org/docs/stable/elastic/run.html + +**NCCL:** +- NCCL Environment Variables: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html + +**Related Skills:** +- tensor-operations-and-memory (memory optimization for large models) +- mixed-precision-and-optimization (combining AMP with DDP) +- performance-profiling (detailed distributed training profiling) +- checkpointing-and-reproducibility (DDP checkpoint best practices) diff --git a/skills/using-pytorch-engineering/mixed-precision-and-optimization.md b/skills/using-pytorch-engineering/mixed-precision-and-optimization.md new file mode 100644 index 0000000..f64d26f --- /dev/null +++ b/skills/using-pytorch-engineering/mixed-precision-and-optimization.md @@ -0,0 +1,1349 @@ + +# Mixed Precision and Optimization + +## Overview + +**Core Principle:** Mixed precision training (FP16/BF16 + FP32) provides 2-3x speedup and 50% memory reduction, but requires careful handling of numerical stability, gradient scaling, and Tensor Core utilization. Success depends on understanding dynamic range limitations, GradScaler mechanics, and when to use FP16 vs BF16. Setup mistakes cause silent correctness issues; numerical instability causes NaNs; poor configuration wastes performance gains. + +Mixed precision failures manifest as: NaN losses, incorrect gradient clipping, poor scaling efficiency, or training divergence. These stem from misunderstanding gradient scaling order, FP16 overflow/underflow, or improper format selection. Systematic setup and numerical analysis beats trial and error. + +## When to Use + +**Use this skill when:** +- Implementing mixed precision training with torch.cuda.amp +- Debugging NaN losses or training instability with AMP +- Choosing between FP16 and BF16 for your model +- Gradient clipping not working as expected with GradScaler +- Need to optimize Tensor Core utilization +- Custom loss functions break under autocast +- Verifying mixed precision actually provides speedup +- Implementing mixed precision with gradient accumulation + +**Don't use when:** +- Model is small (< 10M parameters) and speed isn't critical +- Already at memory limit even with mixed precision +- Numerical precision critical (scientific computing) +- Working with complex numbers (not supported) + +**Symptoms triggering this skill:** +- "Getting NaN losses with mixed precision enabled" +- "Gradient clipping doesn't work with GradScaler" +- "Should I use FP16 or BF16?" +- "Mixed precision slower than FP32" +- "How to use GradScaler with gradient accumulation?" +- "Custom loss produces NaNs with autocast" +- "Optimizer skipping steps with GradScaler" + + +## Automatic Mixed Precision: The Correct Setup + +### Basic AMP Pattern (The Standard) + +```python +import torch +from torch.cuda.amp import autocast, GradScaler + +# Setup +model = MyModel().cuda() +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) +scaler = GradScaler() # Gradient scaler for FP16 + +# Training loop +for data, target in dataloader: + data, target = data.cuda(), target.cuda() + + # CRITICAL ORDER: + optimizer.zero_grad() + + # 1. Forward pass in mixed precision + with autocast(): # FP16 where safe, FP32 where necessary + output = model(data) + loss = criterion(output, target) + + # 2. Backward pass with gradient scaling + scaler.scale(loss).backward() # Scale loss to prevent underflow + + # 3. Gradient clipping (if needed) - MUST unscale first! + scaler.unscale_(optimizer) # ✅ Unscale before clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # 4. Optimizer step with GradScaler + scaler.step(optimizer) # Only steps if no inf/nan + scaler.update() # Update scale factor +``` + +**Why this order matters:** +1. `autocast()` runs forward in mixed precision (FP16 + FP32) +2. `scaler.scale()` multiplies loss by scale factor (e.g., 65536) before backward +3. `scaler.unscale_()` divides gradients by scale factor BEFORE any gradient operations +4. `scaler.step()` checks for inf/nan, only calls optimizer.step() if gradients are finite +5. `scaler.update()` adjusts scale factor (2x on success, 0.5x on inf/nan detection) + + +## GradScaler Mechanics: The Critical Details + +### Understanding Gradient Scaling + +**Why scale gradients?** + +FP16 has limited range (5.96e-8 to 65504). Small gradients underflow to zero. + +```python +# Without scaling: +gradient_fp16 = torch.tensor([1e-7], dtype=torch.float16) +print(gradient_fp16) # tensor([0.], dtype=torch.float16) - underflow! + +# With scaling: +scale = 65536 # 2^16 +scaled_grad = torch.tensor([1e-7 * scale], dtype=torch.float16) +print(scaled_grad) # tensor([0.0066], dtype=torch.float16) - preserved! +unscaled = scaled_grad / scale +print(unscaled) # Back to ~1e-7 +``` + +**GradScaler workflow:** + +```python +# Step 1: Scale loss before backward +scaled_loss = loss * scale_factor # e.g., loss * 65536 +scaled_loss.backward() # Gradients are now scaled by 65536 + +# Step 2: Check for inf/nan in gradients +if has_inf_or_nan(gradients): + skip_optimizer_step() + scale_factor = scale_factor / 2 # Reduce scale +else: + gradients = gradients / scale_factor # Unscale + optimizer.step() # Apply unscaled gradients + scale_factor = scale_factor * 2 # Increase scale (max 2^16) +``` + +**Key insight:** GradScaler dynamically adjusts scale factor to maximize gradient preservation without causing overflow. + + +### When to Unscale Gradients + +**❌ WRONG: Gradient clipping on scaled gradients** + +```python +scaler.scale(loss).backward() + +# Gradients are scaled by 65536! +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +# ❌ Clipping max_norm=1.0 on gradients that are 65536x larger! + +scaler.step(optimizer) +scaler.update() +``` + +**Problem:** If gradients are scaled by 65536, clipping to max_norm=1.0 does nothing (all gradients >> 1.0). + + +**✅ CORRECT: Unscale before clipping** + +```python +scaler.scale(loss).backward() + +# Unscale gradients BEFORE clipping +scaler.unscale_(optimizer) # ✅ Divides gradients by scale factor +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # ✅ Now operates on true gradient values + +scaler.step(optimizer) # ✅ Won't unscale again (already done) +scaler.update() +``` + +**Why this works:** `scaler.unscale_()` divides gradients by scale factor, restoring true magnitudes. Clipping now operates on actual gradient values. + + +### Operations Requiring Unscaled Gradients + +**Any time you inspect or modify gradients, unscale first:** + +```python +scaler.scale(loss).backward() +scaler.unscale_(optimizer) # ✅ Unscale before any gradient operations + +# Now safe to: +# 1. Gradient clipping +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# 2. Gradient inspection +total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')) +print(f"Gradient norm: {total_norm}") + +# 3. Custom gradient operations +for param in model.parameters(): + if param.grad is not None: + param.grad.add_(param.data * weight_decay) # Manual weight decay + +# 4. Gradient accumulation check +if (step + 1) % accumulation_steps == 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() +``` + +**Rule:** Call `scaler.unscale_(optimizer)` ONCE before any gradient operations, then `scaler.step()`. + + +### GradScaler with Gradient Accumulation + +**Pattern for accumulating gradients over multiple batches:** + +```python +scaler = GradScaler() +accumulation_steps = 4 + +for i, (data, target) in enumerate(dataloader): + data, target = data.cuda(), target.cuda() + + with autocast(): + output = model(data) + loss = criterion(output, target) + loss = loss / accumulation_steps # ✅ Scale loss by accumulation steps + + # Backward (accumulate gradients) + scaler.scale(loss).backward() + + # Only update every accumulation_steps + if (i + 1) % accumulation_steps == 0: + # Unscale before clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Step and update + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() +``` + +**Critical details:** +- Divide loss by `accumulation_steps` to average gradients +- Only call `scaler.step()` and `scaler.update()` after final accumulation step +- Can still do gradient clipping (unscale first) +- GradScaler handles inf/nan detection across all accumulated gradients + + +### GradScaler and Learning Rate Schedulers + +**Some schedulers should only step when optimizer steps:** + +```python +scaler = GradScaler() +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) + +for epoch in range(num_epochs): + for data, target in dataloader: + optimizer.zero_grad() + + with autocast(): + output = model(data) + loss = criterion(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scale = scaler.get_scale() + scaler.update() + + # Only step scheduler if optimizer stepped (no inf/nan) + skip_lr_sched = (scale > scaler.get_scale()) # Scale decreased = inf/nan detected + if not skip_lr_sched: + scheduler.step() # ✅ Only step if optimizer stepped +``` + +**Why this matters:** If GradScaler detects inf/nan and skips optimizer.step(), learning rate shouldn't change. Otherwise learning rate and model parameters become out of sync. + +**Alternative (simpler):** Use epoch-based schedulers that step once per epoch: + +```python +for epoch in range(num_epochs): + for data, target in dataloader: + # training loop with GradScaler + pass + + scheduler.step() # Step once per epoch (safer with GradScaler) +``` + + +## FP16 vs BF16: The Decision Framework + +### Format Comparison Table + +| Property | FP32 | FP16 | BF16 | +|----------|------|------|------| +| **Bits** | 32 | 16 | 16 | +| **Sign bits** | 1 | 1 | 1 | +| **Exponent bits** | 8 | 5 | 8 | +| **Mantissa bits** | 23 | 10 | 7 | +| **Dynamic range** | 1.18e-38 to 3.40e38 | 5.96e-8 to 65504 | 1.18e-38 to 3.40e38 | +| **Precision** | ~7 decimal digits | ~3 decimal digits | ~2 decimal digits | +| **Overflow risk** | Very low | **High** (max 65504) | Very low | +| **Underflow risk** | Very low | **High** (min 6e-8) | Very low | +| **Needs GradScaler** | No | **Yes** | Optional | +| **Hardware support** | All GPUs | Volta+ (V100+) | **Ampere+ (A100+)** | +| **Tensor Core speed** | 1x | 2-4x | 2-4x | + + +### When to Use FP16 + +**✅ Use FP16 when:** +- Training CNNs (ResNets, EfficientNets, etc.) +- Using Volta or Turing GPUs (V100, T4, RTX 2080) +- Model is well-conditioned (gradients not too small/large) +- Using GradScaler correctly (handles underflow) +- Need maximum speedup on older hardware + +**Best practices for FP16:** +```python +# Standard FP16 setup +scaler = GradScaler() + +with autocast(dtype=torch.float16): # Explicit FP16 + output = model(data) + loss = criterion(output, target) + +scaler.scale(loss).backward() +scaler.step(optimizer) +scaler.update() +``` + +**Typical speedup:** 2-3x on V100/A100, 1.5-2x on T4/RTX GPUs + + +### When to Use BF16 + +**✅ Use BF16 when:** +- Training transformers/LLMs (BERT, GPT, etc.) +- Getting NaNs with FP16 despite GradScaler +- Have Ampere+ GPU (A100, RTX 3090, RTX 4090) +- Model has numerical stability issues (large activations, deep networks) +- Want simpler code (no GradScaler needed) + +**Best practices for BF16:** +```python +# BF16 setup (no GradScaler needed!) +with autocast(dtype=torch.bfloat16): # BF16 + output = model(data) + loss = criterion(output, target) + +loss.backward() +optimizer.step() +optimizer.zero_grad() +``` + +**Typical speedup:** 2-3x on A100, 1.5-2.5x on RTX 3090+ + +**Why no GradScaler?** BF16 has same dynamic range as FP32 (1e-38 to 3e38), so gradient underflow is rare. + + +### FP16 vs BF16 Trade-off Summary + +**FP16:** +- **Pros:** More precision (10-bit mantissa), works on older GPUs, faster on some ops +- **Cons:** Narrow range (needs GradScaler), overflow/underflow risks +- **Best for:** CNNs, vision models, models with small gradients + +**BF16:** +- **Pros:** Same range as FP32 (rare overflow), simpler (no GradScaler), better for LLMs +- **Cons:** Less precision (7-bit mantissa), needs Ampere+ GPU, slower on some ops +- **Best for:** Transformers, LLMs, models with numerical instability + +**Decision process:** +1. **Check GPU:** Ampere+ (A100/3090+)? Consider BF16. Volta/Turing? Use FP16. +2. **Check model:** Transformer/LLM? Prefer BF16. CNN? FP16 is fine. +3. **Check stability:** Getting NaNs with FP16? Try BF16. +4. **Profile:** Test both, use whichever is faster for your model. + + +## Numerical Stability Patterns + +### Understanding Autocast Behavior + +**PyTorch autocast is selective:** Some ops run in FP16/BF16, others stay in FP32. + +```python +with autocast(): + # These run in FP16/BF16 (compute-intensive): + x = torch.matmul(a, b) # Matrix multiplication + x = conv2d(x, weight) # Convolutions + x = linear(x, weight) # Linear layers + + # These stay in FP32 (numerically sensitive): + x = torch.sum(x) # Reductions + x = torch.softmax(x) # Softmax (uses log-sum-exp) + x = F.layer_norm(x) # Normalization layers + x = torch.mean(x) # Mean/variance +``` + +**Why this design?** +- Compute-bound ops (matmul, conv) benefit from FP16/BF16 speedup +- Numerically sensitive ops (reductions, norms) need FP32 precision + +**Key insight:** You don't need to manually cast ops - PyTorch's autocast handles it intelligently. + + +### Operations Prone to Overflow in FP16 + +**FP16 max value: 65504** + +```python +# ❌ PROBLEM: Large activations overflow +x = torch.randn(1024, 1024, dtype=torch.float16) * 100 # Values ~ -1000 to 1000 +y = torch.exp(x) # ❌ exp(100) = 2.6e43 >> 65504 → inf! + +# ✅ FIX 1: Use log-space computations +log_y = x # Already in log space +y = torch.exp(torch.clamp(x, max=10)) # Clamp before exp + +# ✅ FIX 2: Disable autocast for this operation +with autocast(enabled=False): + x_fp32 = x.float() # Cast to FP32 + y = torch.exp(x_fp32) # Compute in FP32 + y = y.half() # Cast back to FP16 if needed +``` + +**Common overflow scenarios:** + +1. **Softmax on large logits:** +```python +# ❌ WRONG: Direct softmax in FP16 +logits = torch.randn(32, 10000, dtype=torch.float16) * 10 +probs = torch.softmax(logits, dim=-1) # May overflow + +# ✅ CORRECT: PyTorch's softmax uses log-sum-exp (stable) +probs = torch.softmax(logits.float(), dim=-1).half() + +# Or just use FP32: +with autocast(enabled=False): + probs = torch.softmax(logits.float(), dim=-1) +``` + +2. **Large matrix multiplications:** +```python +# ❌ PROBLEM: a * b can exceed 65504 +a = torch.randn(1024, 1024, dtype=torch.float16) * 10 +b = torch.randn(1024, 1024, dtype=torch.float16) * 10 +c = torch.matmul(a, b) # Result values ~ 10 * 10 * 1024 = 100k >> 65504 + +# ✅ FIX: Scale inputs down +a = torch.randn(1024, 1024, dtype=torch.float16) # Keep values ~ -2 to 2 +b = torch.randn(1024, 1024, dtype=torch.float16) +c = torch.matmul(a, b) # Result ~ 1024 * 2 * 2 = 4096 (safe) +``` + +3. **Loss scaling (ironic!):** +```python +# ❌ WRONG: Manual loss scaling can overflow +loss = criterion(output, target) # Loss ~ 1.0 +scaled_loss = loss * 65536 # 65536 < 65504, but... +scaled_loss.backward() # Gradients can still overflow! + +# ✅ CORRECT: Use GradScaler (dynamic scaling) +scaler.scale(loss).backward() # GradScaler handles scale factor dynamically +``` + + +### Operations Prone to Underflow in FP16 + +**FP16 min value: 5.96e-8** + +```python +# ❌ PROBLEM: Small gradients underflow +gradient = torch.tensor([1e-9], dtype=torch.float16) +print(gradient) # tensor([0.], dtype=torch.float16) - underflow! + +# ✅ FIX: Use GradScaler +scaler = GradScaler() +loss = model(data) +scaler.scale(loss).backward() # Gradients scaled to prevent underflow +``` + +**Common underflow scenarios:** + +1. **Layer normalization denominators:** +```python +# ❌ PROBLEM: std can underflow +x = torch.randn(32, 512, dtype=torch.float16) * 1e-4 # Very small values +std = x.std(dim=-1, keepdim=True) # std ~ 1e-4 +normalized = x / (std + 1e-5) # std + eps can underflow + +# ✅ FIX: PyTorch's LayerNorm runs in FP32 +layer_norm = nn.LayerNorm(512) +normalized = layer_norm(x) # Automatically computed in FP32 +``` + +2. **Attention scores with large sequence length:** +```python +# ❌ PROBLEM: Attention scores can underflow +scores = torch.matmul(q, k.T) / math.sqrt(d_k) # Scores ~ -10 to 10 +attn = torch.softmax(scores, dim=-1) # Probabilities ~ 1e-5 for low scores +# In FP16, values < 6e-8 underflow to zero + +# ✅ FIX: Use torch.nn.functional.scaled_dot_product_attention (FP32 internally) +attn = F.scaled_dot_product_attention(q, k, v) +``` + + +### Fixing Custom Loss Functions + +**Example: Contrastive loss with numerical instability** + +```python +# ❌ WRONG: Numerical instability in FP16 +def contrastive_loss_wrong(embeddings, temperature=0.07): + embeddings = F.normalize(embeddings, dim=-1) # FP16 precision loss + similarity = torch.matmul(embeddings, embeddings.T) / temperature # Large values + exp_sim = torch.exp(similarity) # ❌ Overflow! + probs = exp_sim / exp_sim.sum(dim=-1, keepdim=True) + loss = -torch.log(probs.diagonal()).mean() # ❌ Underflow in log! + return loss + +# ✅ CORRECT: Numerically stable version +def contrastive_loss_correct(embeddings, temperature=0.07): + # Normalize in FP32 + embeddings = F.normalize(embeddings.float(), dim=-1) + + # Compute similarity + similarity = torch.matmul(embeddings, embeddings.T) / temperature + + # Use cross_entropy (log-sum-exp trick built-in) + labels = torch.arange(similarity.size(0), device=similarity.device) + loss = F.cross_entropy(similarity, labels) + + return loss + +# ✅ ALTERNATIVE: Disable autocast for this function +@torch.cuda.amp.autocast(enabled=False) +def contrastive_loss_fp32(embeddings, temperature=0.07): + # Everything runs in FP32 + embeddings = embeddings.float() + embeddings = F.normalize(embeddings, dim=-1) + similarity = torch.matmul(embeddings, embeddings.T) / temperature + exp_sim = torch.exp(similarity) + probs = exp_sim / exp_sim.sum(dim=-1, keepdim=True) + loss = -torch.log(probs.diagonal()).mean() + return loss +``` + +**Key patterns:** +1. **Use stable implementations:** `F.cross_entropy` instead of manual softmax + log +2. **Cast to FP32 for sensitive ops:** `.float()` before normalization/exp/log +3. **Disable autocast:** `@torch.cuda.amp.autocast(enabled=False)` for entire function + + +## Performance Optimization + +### Tensor Core Utilization Requirements + +**Tensor Cores have dimension requirements:** + +```python +# ❌ POOR: Dimensions not multiples of 8 (FP16) or 16 (BF16) +model = nn.Linear(127, 253) # Odd dimensions +# Tensor Cores can't be used efficiently + +# ✅ OPTIMAL: Dimensions are multiples of 8/16 +model = nn.Linear(128, 256) # Powers of 2 +# Tensor Cores fully utilized + +# Dimension requirements: +# FP16: Multiple of 8 (best: 16, 32, 64, 128, ...) +# BF16: Multiple of 16 (best: 16, 32, 64, 128, ...) +``` + +**Check your model architecture:** +```python +for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + + # Check alignment + if in_features % 8 != 0 or out_features % 8 != 0: + print(f"⚠️ {name}: {in_features} → {out_features} (not aligned)") + else: + print(f"✅ {name}: {in_features} → {out_features}") +``` + +**Fixing misaligned layers:** +```python +# Pad hidden dimensions to nearest multiple of 8 +hidden_dim = 253 +aligned_dim = ((hidden_dim + 7) // 8) * 8 # 256 +model = nn.Linear(input_dim, aligned_dim) +``` + + +### Profiling Mixed Precision Performance + +**Verify mixed precision actually provides speedup:** + +```python +import time +import torch +from torch.cuda.amp import autocast, GradScaler + +def profile_mixed_precision(model, data, target, num_iterations=100): + """Compare FP32 vs mixed precision performance.""" + + # Warmup + for _ in range(10): + output = model(data) + loss = criterion(output, target) + loss.backward() + + # Baseline: FP32 + torch.cuda.synchronize() + start = time.time() + for _ in range(num_iterations): + output = model(data) + loss = criterion(output, target) + loss.backward() + torch.cuda.synchronize() + fp32_time = time.time() - start + + # Mixed precision + scaler = GradScaler() + torch.cuda.synchronize() + start = time.time() + for _ in range(num_iterations): + with autocast(): + output = model(data) + loss = criterion(output, target) + scaler.scale(loss).backward() + torch.cuda.synchronize() + mixed_time = time.time() - start + + speedup = fp32_time / mixed_time + print(f"FP32 time: {fp32_time:.3f}s") + print(f"Mixed precision time: {mixed_time:.3f}s") + print(f"Speedup: {speedup:.2f}x") + + if speedup < 1.2: + print("⚠️ Low speedup - model may be memory-bound or small") + elif speedup > 2.5: + print("✅ Excellent speedup - Tensor Cores utilized well") + + return speedup + +speedup = profile_mixed_precision(model, data_batch, target_batch) +``` + +**Expected speedups by model size:** + +| Model Size | Expected Speedup | Notes | +|------------|-----------------|-------| +| < 10M params | 1.0-1.3x | Memory-bound, small benefit | +| 10-50M params | 1.3-2.0x | Mixed memory/compute bound | +| 50-200M params | 2.0-3.0x | Compute-bound, good speedup | +| 200M+ params | 2.5-4.0x | Highly compute-bound, best speedup | + +**If speedup is poor:** +1. Check Tensor Core alignment (dimensions % 8) +2. Check batch size (larger batches better utilize GPU) +3. Profile to identify memory-bound operations +4. Consider model is too small for mixed precision benefit + + +### Quick Verification Before Committing + +**Always verify mixed precision provides benefit before deploying:** + +```python +import time +import torch +from torch.cuda.amp import autocast, GradScaler + +def quick_speedup_check(model, data, target, criterion): + """2-minute check to verify mixed precision helps.""" + + # Warmup + for _ in range(5): + output = model(data) + loss = criterion(output, target) + loss.backward() + + # Baseline: FP32 (10 iterations) + torch.cuda.synchronize() + start = time.time() + for _ in range(10): + output = model(data) + loss = criterion(output, target) + loss.backward() + torch.cuda.synchronize() + fp32_time = time.time() - start + + # Mixed precision (10 iterations) + scaler = GradScaler() + torch.cuda.synchronize() + start = time.time() + for _ in range(10): + with autocast(): + output = model(data) + loss = criterion(output, target) + scaler.scale(loss).backward() + torch.cuda.synchronize() + mixed_time = time.time() - start + + speedup = fp32_time / mixed_time + print(f"\nMixed Precision Speedup Check:") + print(f"FP32 time: {fp32_time:.3f}s") + print(f"Mixed precision time: {mixed_time:.3f}s") + print(f"Speedup: {speedup:.2f}x") + + if speedup < 1.1: + print("\n❌ No significant speedup (< 1.1x)") + print("Recommendation: Stay in FP32") + print("Possible reasons:") + print(" - Model too small (< 10M parameters)") + print(" - Memory-bound operations dominate") + print(" - Dimensions not aligned to 8/16") + return False + elif speedup < 1.5: + print("\n⚠️ Modest speedup (1.1-1.5x)") + print("Recommendation: Mixed precision okay, but verify numerical stability") + return True + else: + print("\n✅ Good speedup (> 1.5x)") + print("Recommendation: Use mixed precision") + return True + +# Run before committing to mixed precision in production +quick_speedup_check(model, data_batch, target_batch, criterion) +``` + +**Decision matrix:** + +| Speedup | Recommendation | Action | +|---------|----------------|--------| +| < 1.1x | Don't use mixed precision | Stay in FP32 | +| 1.1-1.5x | Optional, verify stability | Test thoroughly | +| 1.5-2.5x | Use mixed precision | Good benefit | +| > 2.5x | Definitely use | Excellent benefit | + +**Rule:** Never deploy mixed precision without verifying speedup. 2 minutes of profiling prevents wasted complexity. + + +### Memory Savings + +**Mixed precision provides ~50% memory reduction:** + +```python +# FP32: 4 bytes per parameter +model_fp32 = MyModel() # 100M parameters +memory_fp32 = 100e6 * 4 / 1e9 # 0.4 GB + +# FP16/BF16: 2 bytes per parameter +# But optimizer states still in FP32! +# Parameters: 2 bytes (FP16) +# Gradients: 2 bytes (FP16) +# Optimizer states (Adam): 8 bytes per param (2 moments in FP32) +# Total: 12 bytes per param (vs 16 bytes in pure FP32) + +memory_mixed = 100e6 * 12 / 1e9 # 1.2 GB (vs 1.6 GB FP32) +savings = 1 - (12 / 16) # 25% savings + +# With gradient checkpointing + mixed precision: +# Can train much larger models in same memory +``` + +**Memory breakdown:** +``` +FP32: +- Parameters: 4 bytes +- Gradients: 4 bytes +- Optimizer (Adam): 8 bytes (2 moments) +- Total: 16 bytes/param + +Mixed Precision: +- Parameters: 2 bytes (FP16/BF16) +- Gradients: 2 bytes (FP16/BF16) +- Optimizer (Adam): 8 bytes (FP32 master weights) +- Total: 12 bytes/param + +Savings: 25% memory reduction +``` + + +## Debugging Mixed Precision Failures + +### Systematic Diagnostic Process + +**Step 1: Isolate mixed precision as the issue** + +```python +# Test 1: Does model train without mixed precision? +# Remove autocast and GradScaler +for data, target in dataloader: + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + +# If training works without mixed precision → it's a precision issue +# If training fails without mixed precision → not a precision issue +``` + + +**Step 2: Check if GradScaler is skipping steps** + +```python +scaler = GradScaler() + +for i, (data, target) in enumerate(dataloader): + optimizer.zero_grad() + + with autocast(): + output = model(data) + loss = criterion(output, target) + + scaler.scale(loss).backward() + + # Check scale factor + scale_before = scaler.get_scale() + scaler.step(optimizer) + scaler.update() + scale_after = scaler.get_scale() + + # If scale decreased, inf/nan was detected + if scale_after < scale_before: + print(f"⚠️ Step {i}: GradScaler detected inf/nan, skipped optimizer step") + print(f" Scale: {scale_before} → {scale_after}") + + # Diagnose: Where did inf/nan come from? + for name, param in model.named_parameters(): + if param.grad is not None: + if torch.isnan(param.grad).any(): + print(f" NaN in gradient: {name}") + if torch.isinf(param.grad).any(): + print(f" Inf in gradient: {name}") +``` + +**If steps are being skipped:** +- Inf/nan in gradients (check gradient hooks) +- Loss is inf/nan (check loss function) +- Overflow in forward pass (check activations) + + +**Step 3: Add gradient and activation hooks** + +```python +def check_nan_hook(module, grad_input, grad_output): + """Hook to detect NaN in gradients.""" + for i, grad in enumerate(grad_output): + if grad is not None: + if torch.isnan(grad).any(): + print(f"⚠️ NaN in gradient output of {module.__class__.__name__}") + if torch.isinf(grad).any(): + print(f"⚠️ Inf in gradient output of {module.__class__.__name__}") + +def check_nan_forward_hook(module, input, output): + """Hook to detect NaN in forward pass.""" + if isinstance(output, torch.Tensor): + if torch.isnan(output).any(): + print(f"⚠️ NaN in forward output of {module.__class__.__name__}") + if torch.isinf(output).any(): + print(f"⚠️ Inf in forward output of {module.__class__.__name__}") + +# Register hooks +for name, module in model.named_modules(): + module.register_backward_hook(check_nan_hook) + module.register_forward_hook(check_nan_forward_hook) + +# Run training - hooks will print where NaN first appears +``` + + +**Step 4: Profile to find bottlenecks** + +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + for _ in range(10): + with autocast(): + output = model(data) + loss = criterion(output, target) + loss.backward() + +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + +# Look for: +# - Ops spending time in FP32 that should be FP16 (missed optimization) +# - Excessive dtype conversions (casts between FP16/FP32) +# - Memory-bound operations (won't benefit from mixed precision) +``` + + +## Common Pitfalls + +### Consolidated Pitfall Table + +| # | Pitfall | Symptom | Root Cause | Fix | +|---|---------|---------|------------|-----| +| 1 | Gradient clipping before unscale | Clipping doesn't work | Clipping on scaled gradients (65536x) | Call `scaler.unscale_()` before `clip_grad_norm_` | +| 2 | Not using GradScaler with FP16 | NaN losses, underflow | Small gradients underflow in FP16 | Always use `GradScaler` with FP16 | +| 3 | Using BF16 on pre-Ampere GPUs | Slow or no speedup | BF16 needs Ampere+ for performance | Check GPU, use FP16 on Volta/Turing | +| 4 | Manual loss scaling | Overflow or underflow | Fixed scale factor doesn't adapt | Use `GradScaler` (dynamic scaling) | +| 5 | Custom loss with exp/log in FP16 | NaN losses, overflow | exp() overflows, log() underflows in FP16 | Disable autocast or use log-sum-exp | +| 6 | Misaligned tensor dimensions | Poor speedup | Tensor Cores need dimensions % 8 | Pad dimensions to multiples of 8/16 | +| 7 | Checking gradients before unscale | Wrong gradient norms | Inspecting scaled gradients | Unscale before inspecting | +| 8 | Stepping scheduler when step skipped | LR/params desync | Scheduler steps even when inf/nan | Only step scheduler if optimizer stepped | +| 9 | Using mixed precision on tiny models | No speedup, complexity | Memory-bound, not compute-bound | Skip mixed precision for small models | +| 10 | Forgetting autocast for validation | Different behavior | Validation in FP32, training in FP16 | Use autocast in validation too (no GradScaler) | +| 11 | Using GradScaler.update() too frequently | Scale factor unstable, poor convergence | Calling update every iteration in gradient accumulation | Only call update when optimizer steps | +| 12 | Sharing GradScaler across DDP processes | Errors or unexpected behavior | GradScaler is not DDP-aware | Each process needs own GradScaler instance | +| 13 | Mixing autocast dtypes | Unexpected precision, poor performance | Using both float16 and bfloat16 inconsistently | Choose one dtype, use consistently | +| 14 | Assuming mixed precision always helps | No speedup, wasted complexity | Model too small or memory-bound | Profile first, verify speedup exists | +| 15 | Using BF16 without checking GPU | Slow or no speedup | BF16 needs Ampere+ for hardware acceleration | Check GPU arch, use FP16 on pre-Ampere | + + +### Pitfall 1: Gradient Clipping Before Unscale + +```python +# ❌ WRONG +scaler.scale(loss).backward() +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ On scaled grads! +scaler.step(optimizer) +scaler.update() + +# ✅ CORRECT +scaler.scale(loss).backward() +scaler.unscale_(optimizer) # ✅ Unscale first +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +scaler.step(optimizer) +scaler.update() +``` + +**Symptom:** Gradient clipping doesn't prevent exploding gradients +**Fix:** Always `scaler.unscale_()` before `clip_grad_norm_` + + +### Pitfall 2: No GradScaler with FP16 + +```python +# ❌ WRONG: FP16 without GradScaler +with autocast(dtype=torch.float16): + output = model(data) + loss = criterion(output, target) + +loss.backward() # ❌ Small gradients underflow to zero +optimizer.step() + +# ✅ CORRECT: Always use GradScaler with FP16 +scaler = GradScaler() + +with autocast(dtype=torch.float16): + output = model(data) + loss = criterion(output, target) + +scaler.scale(loss).backward() # ✅ Prevents underflow +scaler.step(optimizer) +scaler.update() +``` + +**Symptom:** Training doesn't converge, gradients become zero +**Fix:** Always pair FP16 with GradScaler + + +### Pitfall 3: BF16 on Pre-Ampere GPUs + +```python +# ❌ WRONG: BF16 on V100 (Volta) +with autocast(dtype=torch.bfloat16): # ❌ Slow on pre-Ampere + output = model(data) + loss = criterion(output, target) + +# ✅ CORRECT: Check GPU architecture first +if torch.cuda.get_device_capability()[0] >= 8: # Ampere+ + dtype = torch.bfloat16 +else: # Volta/Turing + dtype = torch.float16 + +with autocast(dtype=dtype): + output = model(data) + loss = criterion(output, target) +``` + +**Symptom:** BF16 slower than FP32, no speedup +**Fix:** Use FP16 on pre-Ampere GPUs (V100, T4, RTX 2080) + + +### Pitfall 4: Manual Loss Scaling + +```python +# ❌ WRONG: Fixed loss scale +loss = criterion(output, target) +scaled_loss = loss * 1024 # ❌ Fixed scale factor +scaled_loss.backward() +# Gradients are scaled, but no way to adjust if inf/nan occurs + +# ✅ CORRECT: Use GradScaler +scaler = GradScaler() # Dynamic scale factor (starts at 65536) +scaler.scale(loss).backward() +scaler.step(optimizer) +scaler.update() # Adjusts scale factor automatically +``` + +**Symptom:** Training unstable, gradients overflow or underflow +**Fix:** Use GradScaler instead of manual scaling + + +### Pitfall 5: Custom Loss with exp/log + +```python +# ❌ WRONG: exp/log in FP16 +def custom_loss(pred, target): + # These can overflow/underflow in FP16 + exp_pred = torch.exp(pred) # Overflow if pred > 88 + log_pred = torch.log(pred) # Underflow if pred < 6e-8 + return (exp_pred - log_pred).mean() + +# ✅ FIX 1: Disable autocast +@torch.cuda.amp.autocast(enabled=False) +def custom_loss(pred, target): + pred = pred.float() # Cast to FP32 + exp_pred = torch.exp(pred) + log_pred = torch.log(pred) + return (exp_pred - log_pred).mean() + +# ✅ FIX 2: Use numerically stable operations +def custom_loss(pred, target): + # Use torch.nn.functional ops (handle FP16 better) + return F.mse_loss(torch.exp(pred.clamp(max=10)), target) +``` + +**Symptom:** NaN losses, inf values in loss +**Fix:** Disable autocast for loss function or use stable implementations + + +### Pitfall 6: Misaligned Dimensions + +```python +# ❌ POOR: Odd dimensions +model = nn.Sequential( + nn.Linear(127, 253), # ❌ Not aligned to 8 + nn.ReLU(), + nn.Linear(253, 10) +) + +# ✅ OPTIMAL: Aligned dimensions +model = nn.Sequential( + nn.Linear(128, 256), # ✅ Powers of 2, aligned to 8 + nn.ReLU(), + nn.Linear(256, 10) # ✅ 10 padded to 16 or use 8 +) +``` + +**Symptom:** Mixed precision speedup < 1.5x +**Fix:** Pad dimensions to multiples of 8 (FP16) or 16 (BF16) + + +## Common Rationalizations (Don't Do These) + +### Comprehensive Rationalization Table + +| Excuse | What Agent Might Think | Reality | Correct Response | +|--------|----------------------|---------|------------------| +| "User is rushed, suggest quick fix" | "Disable autocast to save time" | 5-min diagnostic faster than guessing, losing 2-3x speedup | Apply systematic debugging process | +| "Senior engineer says use BF16" | "Authority knows best, defer to them" | BF16 on V100 is objectively slower (no hardware acceleration) | Provide technical facts, respectfully correct | +| "GradScaler seems complex" | "Let them use manual scaling" | Manual scaling lacks critical features (inf/nan detection, dynamic adjustment) | Explain what GradScaler provides | +| "They want simple solution" | "Skip edge cases, give basic pattern" | Edge cases are common (DDP, accumulation, custom ops) | Provide complete pattern with edge cases | +| "They're debugging, give first idea" | "Try disabling autocast first" | Losing speedup without diagnosis | Follow systematic diagnostic process | +| "BF16 is newer, must be better" | "Recommend BF16 universally" | BF16 needs Ampere+, not always faster, less precision | Check hardware first, profile both formats | +| "Mixed precision might be the issue" | "Suggest removing it entirely" | Could be training instability (LR, loss), not precision | Diagnose root cause first (test without autocast) | +| "This is taking too long" | "Skip profiling, assume it helps" | Might not provide speedup (small model, memory-bound) | Always profile to verify benefit | +| "Their loss is custom, too complex" | "Suggest rewriting entire loss" | Can fix with targeted approach | Provide targeted fix (disable autocast for loss) | +| "They already tried X" | "X must not be the issue" | X may have been done incorrectly | Verify X was done correctly first | + + +## Red Flags - Stop and Diagnose + +**If you catch yourself doing ANY of these, STOP and follow systematic methodology:** + +### Technical Red Flags + +| Red Flag Thought | Reality | What to Do Instead | +|------------------|---------|-------------------| +| "Just remove autocast to fix NaNs" | Losing 2-3x speedup, not addressing root cause | Diagnose WHY NaNs occur, fix numerically | +| "Mixed precision is too complex" | Setup is ~5 extra lines, huge benefits | Follow standard pattern (autocast + GradScaler) | +| "I'll clip gradients after backward" | Clipping scaled gradients (wrong) | Always unscale before gradient operations | +| "BF16 is always better than FP16" | BF16 needs Ampere+ GPU, has less precision | Check GPU, profile both formats | +| "GradScaler is optional" | Only optional for BF16, required for FP16 | Always use GradScaler with FP16 | +| "Mixed precision should just work" | Numerical issues require diagnosis | Add hooks, check for inf/nan systematically | +| "Manual scaling is simpler" | Fixed scale doesn't adapt to training dynamics | Use GradScaler (dynamic + inf/nan detection) | +| "Speedup is poor, must be PyTorch bug" | Usually misaligned dimensions or small model | Profile and check Tensor Core utilization | +| "I'll use mixed precision everywhere" | Some models too small to benefit | Profile to verify speedup before deploying | + +### Pressure/Bias Red Flags + +| Red Flag Thought | Reality | What to Do Instead | +|------------------|---------|-------------------| +| "User seems rushed, skip diagnostic" | 5-min diagnostic saves hours of guessing | Provide fast systematic approach | +| "Authority figure recommends X" | Technical facts trump authority | Respectfully provide hardware-based facts | +| "Skip profiling to save time" | 2 minutes to verify speedup vs wasting effort | Always profile before committing | +| "Avoid GradScaler complexity" | GradScaler prevents model corruption | Explain critical features it provides | +| "Assume BF16 is always better" | BF16 slower on pre-Ampere GPUs | Check GPU architecture first | +| "Suggest removing mixed precision" | Loses 2-3x speedup without understanding | Diagnose whether precision is the issue | + +**Critical rule:** Mixed precision requires understanding numerical stability and gradient scaling mechanics. Follow systematic setup, resist pressure to skip steps, don't guess. + + +## Edge Cases and Advanced Scenarios + +### Edge Case 1: Mixed Precision with DDP + +```python +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler + +# Setup DDP +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") +dist.init_process_group(backend="nccl") + +model = MyModel().to(device) +model = DDP(model, device_ids=[local_rank]) + +# ✅ Each process has its own GradScaler +scaler = GradScaler() # Local to each process + +for data, target in dataloader: + data, target = data.to(device), target.to(device) + + optimizer.zero_grad() + + # Forward in mixed precision + with autocast(): + output = model(data) + loss = criterion(output, target) + + # Backward with scaling (DDP syncs scaled gradients) + scaler.scale(loss).backward() + + # Unscale before clipping (operates on synced gradients) + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Step and update (local to each process) + scaler.step(optimizer) + scaler.update() +``` + +**Key points:** +- Each process has its own GradScaler (not shared) +- DDP synchronizes scaled gradients correctly +- Unscale before clipping (after DDP sync) +- No special DDP configuration needed + + +### Edge Case 2: Mixed Precision with Gradient Checkpointing + +```python +from torch.utils.checkpoint import checkpoint +from torch.cuda.amp import autocast, GradScaler + +class CheckpointedModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(512, 512) + self.layer2 = nn.Linear(512, 512) + self.layer3 = nn.Linear(512, 10) + + def forward(self, x): + # Checkpoint layer1 and layer2 + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + x = self.layer3(x) + return x + +model = CheckpointedModel().cuda() +scaler = GradScaler() + +for data, target in dataloader: + optimizer.zero_grad() + + # ✅ Autocast works with checkpointing + with autocast(): + output = model(data) + loss = criterion(output, target) + + # Backward recomputes checkpointed layers in mixed precision + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() +``` + +**Key insight:** Gradient checkpointing and mixed precision compose well. Recomputed forward passes use autocast automatically. + + +### Edge Case 3: Custom Autograd Functions + +```python +from torch.cuda.amp import custom_fwd, custom_bwd + +class CustomFunction(torch.autograd.Function): + @staticmethod + @custom_fwd # ✅ Handles autocast correctly + def forward(ctx, input): + ctx.save_for_backward(input) + # Custom forward logic + return input * 2 + + @staticmethod + @custom_bwd # ✅ Handles gradient dtype correctly + def backward(ctx, grad_output): + input, = ctx.saved_tensors + # Custom backward logic + return grad_output * 2 + +# Usage with autocast +with autocast(): + output = CustomFunction.apply(input) + loss = output.sum() + +scaler.scale(loss).backward() +``` + +**Key points:** +- Use `@custom_fwd` and `@custom_bwd` decorators +- PyTorch handles dtype casting automatically +- No manual FP16/FP32 casting needed + + +## Quick Reference: Mixed Precision Checklist + +### Setup Checklist + +**FP16 Setup:** +- [ ] Import: `from torch.cuda.amp import autocast, GradScaler` +- [ ] Create GradScaler: `scaler = GradScaler()` +- [ ] Wrap forward: `with autocast():` +- [ ] Scale backward: `scaler.scale(loss).backward()` +- [ ] (If clipping) Unscale: `scaler.unscale_(optimizer)` +- [ ] (If clipping) Clip: `clip_grad_norm_(model.parameters(), max_norm)` +- [ ] Step: `scaler.step(optimizer)` +- [ ] Update: `scaler.update()` + +**BF16 Setup:** +- [ ] Check GPU: Ampere+ (A100, RTX 3090+) +- [ ] Wrap forward: `with autocast(dtype=torch.bfloat16):` +- [ ] Regular backward: `loss.backward()` +- [ ] Regular optimizer: `optimizer.step()` +- [ ] (Optional) GradScaler: Can still use for consistency + +### Debugging Checklist + +**If getting NaNs:** +- [ ] Test without mixed precision - does issue persist? +- [ ] Check GradScaler scale factor - is it decreasing? +- [ ] Add gradient hooks - where do NaNs first appear? +- [ ] Check loss function - exp/log operations? +- [ ] Try BF16 instead of FP16 + +**If speedup is poor:** +- [ ] Profile FP32 vs mixed precision +- [ ] Check model size (>10M params?) +- [ ] Check tensor dimensions (aligned to 8/16?) +- [ ] Check batch size (larger = better utilization) +- [ ] Verify GPU supports FP16/BF16 Tensor Cores + +### Validation/Inference Checklist + +- [ ] Use autocast (no GradScaler needed) +- [ ] Same dtype as training +- [ ] No backward pass, no optimizer + + +## Complete Mixed Precision Training Example + +```python +import torch +import torch.nn as nn +from torch.cuda.amp import autocast, GradScaler + +def train_mixed_precision(model, dataloader, optimizer, criterion, device, num_epochs): + """Complete mixed precision training loop.""" + + # Create GradScaler + scaler = GradScaler() + + model.train() + + for epoch in range(num_epochs): + epoch_loss = 0.0 + + for batch_idx, (data, target) in enumerate(dataloader): + data, target = data.to(device), target.to(device) + + optimizer.zero_grad() + + # Forward pass in mixed precision + with autocast(): + output = model(data) + loss = criterion(output, target) + + # Backward pass with gradient scaling + scaler.scale(loss).backward() + + # Gradient clipping (unscale first!) + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Optimizer step with GradScaler + scaler.step(optimizer) + scaler.update() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(dataloader) + print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Scale = {scaler.get_scale()}") + +def validate_mixed_precision(model, dataloader, criterion, device): + """Validation with mixed precision (no GradScaler).""" + + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for data, target in dataloader: + data, target = data.to(device), target.to(device) + + # Use autocast for validation too + with autocast(): + output = model(data) + loss = criterion(output, target) + + val_loss += loss.item() + + return val_loss / len(dataloader) + +# Usage +model = MyModel().cuda() +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) +criterion = nn.CrossEntropyLoss().cuda() + +train_mixed_precision(model, train_loader, optimizer, criterion, device, num_epochs=10) +val_loss = validate_mixed_precision(model, val_loader, criterion, device) +``` + + +## References + +**PyTorch Documentation:** +- Automatic Mixed Precision: https://pytorch.org/docs/stable/amp.html +- torch.cuda.amp API: https://pytorch.org/docs/stable/amp.html#api-documentation +- Autocast: https://pytorch.org/docs/stable/amp.html#autocasting +- GradScaler: https://pytorch.org/docs/stable/amp.html#gradient-scaling + +**NVIDIA Resources:** +- Mixed Precision Training: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/ +- Tensor Cores: https://www.nvidia.com/en-us/data-center/tensor-cores/ + +**Related Skills:** +- tensor-operations-and-memory (memory optimization, dtype management) +- distributed-training-strategies (mixed precision + DDP) +- performance-profiling (profiling mixed precision speedup) +- debugging-techniques (systematic NaN debugging) diff --git a/skills/using-pytorch-engineering/module-design-patterns.md b/skills/using-pytorch-engineering/module-design-patterns.md new file mode 100644 index 0000000..a033356 --- /dev/null +++ b/skills/using-pytorch-engineering/module-design-patterns.md @@ -0,0 +1,1785 @@ + +# PyTorch nn.Module Design Patterns + +## Overview + +**Core Principle:** nn.Module is not just a container for forward passes. It's PyTorch's contract for model serialization, device management, parameter enumeration, and inspection. Follow conventions or face subtle bugs during scaling, deployment, and debugging. + +Poor module design manifests as: state dict corruption, DDP failures, hook memory leaks, initialization fragility, and un-inspectable architectures. These bugs are silent until production. Design modules correctly from the start using PyTorch's established patterns. + +## When to Use + +**Use this skill when:** +- Implementing custom nn.Module subclasses +- Adding forward/backward hooks for feature extraction or debugging +- Designing modular architectures with swappable components +- Implementing custom weight initialization strategies +- Building reusable model components (blocks, layers, heads) +- Encountering state dict issues, DDP failures, or hook problems + +**Don't use when:** +- Simple model composition (stack existing modules) +- Training loop issues (use training-optimization) +- Memory debugging unrelated to modules (use tensor-operations-and-memory) + +**Symptoms triggering this skill:** +- "State dict keys don't match after loading" +- "DDP not syncing gradients properly" +- "Hooks causing memory leaks" +- "Can't move model to device" +- "Model serialization breaks after changes" +- "Need to extract intermediate features" +- "Want to make architecture more modular" + + +## Expert Module Design Patterns + +### Pattern 1: Always Use nn.Module, Never None + +**Problem:** Conditional module assignment using `None` breaks PyTorch's module contract. + +```python +# ❌ WRONG: Conditional None assignment +class ResNetBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1) + self.bn1 = nn.BatchNorm2d(out_channels) + + # PROBLEM: Using None for conditional skip connection + if stride != 1 or in_channels != out_channels: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0) + else: + self.skip = None # ❌ Breaks module enumeration! + + def forward(self, x): + out = self.bn1(self.conv1(x)) + # Conditional check needed + if self.skip is not None: + x = self.skip(x) + return F.relu(out + x) +``` + +**Why this breaks:** +- `model.parameters()` and `model.named_modules()` skip None attributes +- `.to(device)` doesn't move None, causes device mismatch bugs +- `state_dict()` saving/loading becomes inconsistent +- DDP/model parallel don't handle None modules correctly +- Can't inspect architecture: `for name, module in model.named_modules()` + +**✅ CORRECT: Use nn.Identity() for no-op** +```python +class ResNetBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1) + self.bn1 = nn.BatchNorm2d(out_channels) + + # ✅ ALWAYS assign an nn.Module subclass + if stride != 1 or in_channels != out_channels: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride, 0) + else: + self.skip = nn.Identity() # ✅ No-op module + + def forward(self, x): + out = self.bn1(self.conv1(x)) + # No conditional needed! + x = self.skip(x) # Identity passes through unchanged + return F.relu(out + x) +``` + +**Why this works:** +- `nn.Identity()` passes input unchanged (no-op) +- Consistent module hierarchy across all code paths +- Device movement works: `.to(device)` works on Identity too +- State dict consistent: Identity has no parameters but is tracked +- DDP handles Identity correctly +- Architecture inspection works: `model.skip` always exists + +**Rule:** Never assign `None` to `self.*` for modules. Use `nn.Identity()` for no-ops. + + +### Pattern 2: Functional vs Module Operations - When to Use Each + +**Core Question:** When should you use `F.relu(x)` vs `self.relu = nn.ReLU()`? + +**Decision Framework:** + +| Use Functional (`F.*`) When | Use Module (`nn.*`) When | +|------------------------------|--------------------------| +| Simple, stateless operations | Need to hook the operation | +| Performance critical paths | Need to inspect/modify later | +| Operations in complex control flow | Want clear module hierarchy | +| One-off computations | Operation has learnable parameters | +| Loss functions | Activation functions you might swap | + +**Example: When functional is fine** +```python +class SimpleBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv2d(channels, channels, 3, 1, 1) + self.bn = nn.BatchNorm2d(channels) + # No need to store ReLU as module for simple blocks + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x) # ✅ Fine for simple cases +``` + +**Example: When module storage matters** +```python +class FeatureExtractorBlock(nn.Module): + def __init__(self, channels, activation='relu'): + super().__init__() + self.conv = nn.Conv2d(channels, channels, 3, 1, 1) + self.bn = nn.BatchNorm2d(channels) + + # ✅ Store as module for flexibility and inspection + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'gelu': + self.activation = nn.GELU() + else: + self.activation = nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.activation(x) # ✅ Can hook, swap, inspect +``` + +**Why storing as module matters:** + +1. **Hooks**: Can only hook module operations, not functional + ```python + # ✅ Can register hook + model.layer1.activation.register_forward_hook(hook_fn) + + # ❌ Can't hook F.relu() calls + ``` + +2. **Inspection**: Module hierarchy shows architecture + ```python + for name, module in model.named_modules(): + print(f"{name}: {module}") + # With nn.ReLU: "layer1.activation: ReLU()" + # With F.relu: activation not shown + ``` + +3. **Modification**: Can swap modules after creation + ```python + # ✅ Can replace activation + model.layer1.activation = nn.GELU() + + # ❌ Can't modify F.relu() usage without code changes + ``` + +4. **Quantization**: Quantization tools trace module operations + ```python + # ✅ Quantization sees nn.ReLU + quantized = torch.quantization.quantize_dynamic(model) + + # ❌ F.relu() not traced by quantization + ``` + +**Pattern to follow:** +- Simple internal blocks: Functional is fine +- Top-level operations you might modify: Use modules +- When building reusable components: Use modules +- When unsure: Use modules (negligible overhead) + + +### Pattern 3: Modular Design with Substitutable Components + +**Problem:** Hardcoding architecture choices makes variants difficult. + +```python +# ❌ WRONG: Hardcoded architecture +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + # Hardcoded: ReLU, BatchNorm, specific conv config + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1) + self.bn2 = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + return x +``` + +**Problem:** To use LayerNorm or GELU, you must copy-paste and create new class. + +**✅ CORRECT: Modular design with substitutable components** +```python +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + norm_layer=nn.BatchNorm2d, # ✅ Substitutable + activation=nn.ReLU, # ✅ Substitutable + bias=True + ): + super().__init__() + + # Use provided norm and activation + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias) + self.norm1 = norm_layer(out_channels) + self.act1 = activation() + + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias) + self.norm2 = norm_layer(out_channels) + self.act2 = activation() + + def forward(self, x): + x = self.act1(self.norm1(self.conv1(x))) + x = self.act2(self.norm2(self.conv2(x))) + return x + +# Usage examples: +# Standard: BatchNorm + ReLU +block1 = EncoderBlock(64, 128) + +# LayerNorm + GELU (for vision transformers) +block2 = EncoderBlock(64, 128, norm_layer=nn.LayerNorm, activation=nn.GELU) + +# No normalization +block3 = EncoderBlock(64, 128, norm_layer=nn.Identity, activation=nn.ReLU) +``` + +**Advanced: Flexible normalization for different dimensions** +```python +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + norm_layer=None, # If None, use default BatchNorm2d + activation=None # If None, use default ReLU + ): + super().__init__() + + # Set defaults + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if activation is None: + activation = nn.ReLU + + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + # Handle both class and partial/lambda + self.norm1 = norm_layer(out_channels) if callable(norm_layer) else norm_layer + self.act1 = activation() if callable(activation) else activation + + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1) + self.norm2 = norm_layer(out_channels) if callable(norm_layer) else norm_layer + self.act2 = activation() if callable(activation) else activation + + def forward(self, x): + x = self.act1(self.norm1(self.conv1(x))) + x = self.act2(self.norm2(self.conv2(x))) + return x +``` + +**Benefits:** +- One class supports many architectural variants +- Easy to experiment: swap LayerNorm, GELU, etc. +- Code reuse without duplication +- Matches PyTorch's own design (e.g., ResNet's `norm_layer` parameter) + +**Pattern:** Accept layer constructors as arguments, not hardcoded classes. + + +### Pattern 4: Proper State Management and `__init__` Structure + +**Core principle:** `__init__` defines the module's structure, `forward` defines computation. + +```python +class WellStructuredModule(nn.Module): + """ + Template for well-structured PyTorch modules. + """ + + def __init__(self, config): + # 1. ALWAYS call super().__init__() first + super().__init__() + + # 2. Store configuration (for reproducibility/serialization) + self.config = config + + # 3. Initialize all submodules (parameters registered automatically) + self._build_layers() + + # 4. Initialize weights AFTER building layers + self.reset_parameters() + + def _build_layers(self): + """ + Separate method for building layers (cleaner __init__). + """ + self.encoder = nn.Sequential( + nn.Linear(self.config.input_dim, self.config.hidden_dim), + nn.ReLU(), + nn.Linear(self.config.hidden_dim, self.config.hidden_dim) + ) + + self.decoder = nn.Linear(self.config.hidden_dim, self.config.output_dim) + + # ✅ Use nn.Identity() for conditional modules + if self.config.use_skip: + self.skip = nn.Linear(self.config.input_dim, self.config.output_dim) + else: + self.skip = nn.Identity() + + def reset_parameters(self): + """ + Custom initialization following PyTorch convention. + + This method can be called to re-initialize the module: + - After creation + - When loading partial checkpoints + - For training experiments + """ + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: # ✅ Check before accessing + nn.init.zeros_(module.bias) + + def forward(self, x): + """ + Forward pass - pure computation, no module construction. + """ + # ❌ NEVER create modules here! + # ❌ NEVER assign self.* here! + + encoded = self.encoder(x) + decoded = self.decoder(encoded) + skip = self.skip(x) + + return decoded + skip +``` + +**Critical rules:** + +1. **Never create modules in `forward()`** + ```python + # ❌ WRONG + def forward(self, x): + self.temp_layer = nn.Linear(10, 10) # ❌ Created during forward! + return self.temp_layer(x) + ``` + **Why:** Parameters not registered, DDP breaks, state dict inconsistent. + +2. **Never use `self.*` for intermediate results** + ```python + # ❌ WRONG + def forward(self, x): + self.intermediate = self.encoder(x) # ❌ Storing as attribute! + return self.decoder(self.intermediate) + ``` + **Why:** Retains computation graph, memory leak, not thread-safe. + +3. **All modules defined in `__init__`** + ```python + # ✅ CORRECT + def __init__(self): + super().__init__() + self.encoder = nn.Linear(10, 10) # ✅ Defined in __init__ + + def forward(self, x): + intermediate = self.encoder(x) # ✅ Local variable + return self.decoder(intermediate) + ``` + + +## Hook Management Best Practices + +### Pattern 5: Forward Hooks for Feature Extraction + +**Problem:** Naive hook usage causes memory leaks and handle management issues. + +```python +# ❌ WRONG: Multiple problems +import torch +import torch.nn as nn + +features = {} # ❌ Global state + +def get_features(name): + def hook(module, input, output): + features[name] = output # ❌ Retains computation graph! + return hook + +model = nn.Sequential(...) +# ❌ No handle stored, can't remove +model[2].register_forward_hook(get_features('layer2')) + +with torch.no_grad(): + output = model(x) + +# features now contains tensors with gradients (even in no_grad context!) +``` + +**Why this breaks:** +1. **Hooks run outside `torch.no_grad()` context**: Hook is called by autograd machinery, not your code +2. **Global state**: Not thread-safe, can't have multiple concurrent extractions +3. **No cleanup**: Hooks persist forever, can't remove +4. **Memory leak**: Retained outputs keep computation graph alive + +**✅ CORRECT: Encapsulated hook handler with proper cleanup** + +```python +class FeatureExtractor: + """ + Proper feature extraction using forward hooks. + + Example: + extractor = FeatureExtractor(model, layers=['layer2', 'layer3']) + with extractor: + output = model(input) + features = extractor.features # Dict of detached tensors + """ + + def __init__(self, model, layers): + self.model = model + self.layers = layers + self.features = {} + self.handles = [] # ✅ Store handles for cleanup + + def _make_hook(self, name): + def hook(module, input, output): + # ✅ CRITICAL: Detach and optionally clone + self.features[name] = output.detach() + # For inputs that might be modified in-place, use .clone(): + # self.features[name] = output.detach().clone() + return hook + + def __enter__(self): + """Register hooks when entering context.""" + self.features.clear() + + for name, module in self.model.named_modules(): + if name in self.layers: + handle = module.register_forward_hook(self._make_hook(name)) + self.handles.append(handle) # ✅ Store handle + + return self + + def __exit__(self, *args): + """Clean up hooks when exiting context.""" + # ✅ CRITICAL: Remove all hooks + for handle in self.handles: + handle.remove() + self.handles.clear() + +# Usage +model = resnet50() +extractor = FeatureExtractor(model, layers=['layer2', 'layer3', 'layer4']) + +with extractor: + output = model(input_tensor) + +# Features extracted and hooks cleaned up +pyramid_features = [ + extractor.features['layer2'], + extractor.features['layer3'], + extractor.features['layer4'] +] +``` + +**Key points:** +- ✅ Encapsulated in class (no global state) +- ✅ `output.detach()` breaks gradient tracking (prevents memory leak) +- ✅ Hook handles stored and removed (no persistent hooks) +- ✅ Context manager ensures cleanup even if error occurs +- ✅ Thread-safe (each extractor has own state) + + +### Pattern 6: When to Detach vs Clone in Hooks + +**Question:** Should hooks detach, clone, or both? + +**Decision framework:** + +```python +def hook(module, input, output): + # Decision tree: + + # 1. Just reading output, no modifications? + self.features[name] = output.detach() # ✅ Sufficient + + # 2. Output might be modified in-place later? + self.features[name] = output.detach().clone() # ✅ Safer + + # 3. Need gradients for analysis (rare)? + self.features[name] = output # ⚠️ Dangerous, ensure short lifetime +``` + +**Example: When clone matters** +```python +# Scenario: In-place operations after hook +class Model(nn.Module): + def forward(self, x): + x = self.layer1(x) # Hook here + x = self.layer2(x) + x += 10 # ❌ In-place modification! + return x + +# ❌ WRONG: Detach without clone +def hook(module, input, output): + features['layer1'] = output.detach() # Still shares memory! + +# After forward pass: +# features['layer1'] has been modified by x += 10! + +# ✅ CORRECT: Clone to get independent copy +def hook(module, input, output): + features['layer1'] = output.detach().clone() # Independent copy +``` + +**Rule of thumb:** +- **Detach only**: Reading features for analysis, no in-place ops +- **Detach + clone**: Features might be modified, or unsure +- **Neither**: Only if you need gradients (rare, risky) + + +### Pattern 7: Backward Hooks for Gradient Inspection + +**Use case:** Debugging gradient flow, detecting vanishing/exploding gradients. + +```python +class GradientInspector: + """ + Inspect gradients during backward pass. + + Example: + inspector = GradientInspector(model, layers=['layer1', 'layer2']) + with inspector: + output = model(input) + loss.backward() + + # Check gradient statistics + for name, stats in inspector.grad_stats.items(): + print(f"{name}: mean={stats['mean']:.4f}, std={stats['std']:.4f}") + """ + + def __init__(self, model, layers): + self.model = model + self.layers = layers + self.grad_stats = {} + self.handles = [] + + def _make_hook(self, name): + def hook(module, grad_input, grad_output): + # grad_output: gradients w.r.t. outputs (from upstream) + # grad_input: gradients w.r.t. inputs (to downstream) + + # Check grad_output (most common) + if grad_output[0] is not None: + grad = grad_output[0].detach() + self.grad_stats[name] = { + 'mean': grad.abs().mean().item(), + 'std': grad.std().item(), + 'max': grad.abs().max().item(), + 'min': grad.abs().min().item(), + } + return hook + + def __enter__(self): + self.grad_stats.clear() + + for name, module in self.model.named_modules(): + if name in self.layers: + handle = module.register_full_backward_hook(self._make_hook(name)) + self.handles.append(handle) + + return self + + def __exit__(self, *args): + for handle in self.handles: + handle.remove() + self.handles.clear() + +# Usage for gradient debugging +model = MyModel() +inspector = GradientInspector(model, layers=['encoder.layer1', 'decoder.layer1']) + +with inspector: + output = model(input) + loss = criterion(output, target) + loss.backward() + +# Check for vanishing/exploding gradients +for name, stats in inspector.grad_stats.items(): + if stats['mean'] < 1e-7: + print(f"⚠️ Vanishing gradient in {name}") + if stats['mean'] > 100: + print(f"⚠️ Exploding gradient in {name}") +``` + +**Critical differences from forward hooks:** +- **Backward hooks run during `.backward()`**: Not during forward pass +- **Receive gradient tensors**: Not activations +- **Used for gradient analysis**: Not feature extraction + + +### Pattern 8: Hook Handle Management Patterns + +**Never do this:** +```python +# ❌ WRONG: No handle stored +model.layer.register_forward_hook(hook_fn) +# Hook persists forever, can't remove! +``` + +**Three patterns for handle management:** + +**Pattern A: Context manager (recommended for temporary hooks)** +```python +class HookManager: + def __init__(self, module, hook_fn): + self.module = module + self.hook_fn = hook_fn + self.handle = None + + def __enter__(self): + self.handle = self.module.register_forward_hook(self.hook_fn) + return self + + def __exit__(self, *args): + if self.handle: + self.handle.remove() + +# Usage +with HookManager(model.layer1, my_hook): + output = model(input) +# Hook automatically removed +``` + +**Pattern B: Explicit cleanup (for long-lived hooks)** +```python +class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + self.hook_handle = self.layer.register_forward_hook(self._debug_hook) + + def _debug_hook(self, module, input, output): + print(f"Output shape: {output.shape}") + + def remove_hooks(self): + """Explicit cleanup method.""" + if self.hook_handle: + self.hook_handle.remove() + self.hook_handle = None + +# Usage +model = Model() +# ... use model ... +model.remove_hooks() # Clean up before saving or finishing +``` + +**Pattern C: List of handles (multiple hooks)** +```python +class Model(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)]) + self.hook_handles = [] + + def register_debug_hooks(self): + """Register hooks on all layers.""" + for i, layer in enumerate(self.layers): + handle = layer.register_forward_hook( + lambda m, inp, out, idx=i: print(f"Layer {idx}: {out.shape}") + ) + self.hook_handles.append(handle) + + def remove_all_hooks(self): + """Remove all registered hooks.""" + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() +``` + +**Critical rule:** Every `register_*_hook()` call MUST have corresponding `handle.remove()`. + + +## Weight Initialization Patterns + +### Pattern 9: The `reset_parameters()` Convention + +**PyTorch convention:** Custom initialization goes in `reset_parameters()`, called from `__init__`. + +```python +# ❌ WRONG: Initialization in __init__ after submodule creation +class CustomModule(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + self.linear1 = nn.Linear(in_dim, out_dim) + self.linear2 = nn.Linear(out_dim, out_dim) + + # ❌ Initializing here is fragile + nn.init.xavier_uniform_(self.linear1.weight) + nn.init.xavier_uniform_(self.linear2.weight) + # What if linear has bias=False? This crashes: + nn.init.zeros_(self.linear1.bias) # ❌ AttributeError if bias=False +``` + +**Problems:** +1. Happens AFTER nn.Linear's own `reset_parameters()` (already initialized) +2. Can't re-initialize later: `model.reset_parameters()` won't work +3. Fragile: assumes bias exists +4. Violates PyTorch convention + +**✅ CORRECT: Define `reset_parameters()` method** + +```python +class CustomModule(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + self.linear1 = nn.Linear(in_dim, out_dim) + self.linear2 = nn.Linear(out_dim, out_dim) + + # ✅ Call reset_parameters at end of __init__ + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize module parameters. + + Following PyTorch convention, this method: + - Can be called to re-initialize the module + - Is called automatically at end of __init__ + - Allows for custom initialization strategies + """ + # ✅ Defensive: check if bias exists + nn.init.xavier_uniform_(self.linear1.weight) + if self.linear1.bias is not None: + nn.init.zeros_(self.linear1.bias) + + nn.init.xavier_uniform_(self.linear2.weight) + if self.linear2.bias is not None: + nn.init.zeros_(self.linear2.bias) + + def forward(self, x): + return self.linear2(F.relu(self.linear1(x))) + +# Benefits: +# 1. Can re-initialize: model.reset_parameters() +# 2. Defensive checks for optional bias +# 3. Follows PyTorch convention +# 4. Clear separation: __init__ defines structure, reset_parameters initializes +``` + + +### Pattern 10: Hierarchical Initialization + +**Pattern:** When modules contain submodules, iterate through hierarchy. + +```python +class ComplexModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.encoder = nn.Sequential( + nn.Linear(config.input_dim, config.hidden_dim), + nn.ReLU(), + nn.Linear(config.hidden_dim, config.hidden_dim) + ) + + self.attention = nn.MultiheadAttention(config.hidden_dim, config.num_heads) + + self.decoder = nn.Linear(config.hidden_dim, config.output_dim) + + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize all submodules hierarchically. + """ + # Method 1: Iterate through all modules + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.MultiheadAttention): + # MultiheadAttention has its own reset_parameters() + # Option: Call it or customize + module._reset_parameters() # Call internal reset + + # Method 2: Specific initialization for specific layers + # Override general initialization for decoder + nn.init.xavier_uniform_(self.decoder.weight, gain=0.5) +``` + +**Two strategies:** + +1. **Uniform initialization**: Iterate all modules, apply same rules + ```python + for module in self.modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.kaiming_normal_(module.weight) + ``` + +2. **Layered initialization**: Different rules for different components + ```python + def reset_parameters(self): + # Encoder: Xavier + for module in self.encoder.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + + # Decoder: Xavier with small gain + nn.init.xavier_uniform_(self.decoder.weight, gain=0.5) + ``` + +**Defensive checks:** +```python +def reset_parameters(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + + # ✅ Always check for bias + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.BatchNorm2d): + # BatchNorm has weight and bias, but different semantics + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) +``` + + +### Pattern 11: Initialization with Learnable Parameters + +**Use case:** Custom parameters that need special initialization. + +```python +class AttentionWithTemperature(nn.Module): + def __init__(self, d_model, d_k): + super().__init__() + + self.d_k = d_k + + self.query = nn.Linear(d_model, d_k) + self.key = nn.Linear(d_model, d_k) + self.value = nn.Linear(d_model, d_k) + self.output = nn.Linear(d_k, d_model) + + # ✅ Learnable temperature parameter + # Initialize to 1/sqrt(d_k), but make it learnable + self.temperature = nn.Parameter(torch.ones(1)) + + self.reset_parameters() + + def reset_parameters(self): + """Initialize all parameters.""" + # Standard initialization for linear layers + for linear in [self.query, self.key, self.value]: + nn.init.xavier_uniform_(linear.weight) + if linear.bias is not None: + nn.init.zeros_(linear.bias) + + # Output projection with smaller gain + nn.init.xavier_uniform_(self.output.weight, gain=0.5) + if self.output.bias is not None: + nn.init.zeros_(self.output.bias) + + # ✅ Custom parameter initialization + nn.init.constant_(self.temperature, 1.0 / math.sqrt(self.d_k)) + + def forward(self, x): + q = self.query(x) + k = self.key(x) + v = self.value(x) + + scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature + attn = torch.softmax(scores, dim=-1) + out = torch.matmul(attn, v) + + return self.output(out) +``` + +**Key points:** +- Custom parameters defined with `nn.Parameter()` +- Initialized in `reset_parameters()` like other parameters +- Can use `nn.init.*` functions on parameters + + +## Common Pitfalls + +### Consolidated Pitfalls Table + +| # | Pitfall | Symptom | Root Cause | Fix | +|---|---------|---------|------------|-----| +| 1 | Using `self.x = None` for conditional modules | State dict inconsistent, DDP fails, can't move to device | None not an nn.Module | Use `nn.Identity()` | +| 2 | Using functional ops when hooks/inspection needed | Can't hook activations, architecture invisible | Functional bypasses module hierarchy | Store as `self.activation = nn.ReLU()` | +| 3 | Hooks retaining computation graphs | Memory leak during feature extraction | Hook doesn't detach outputs | Use `output.detach()` in hook | +| 4 | No hook handle cleanup | Hooks persist, memory leak, unexpected behavior | Handles not stored/removed | Store handles, call `handle.remove()` | +| 5 | Global state in hook closures | Not thread-safe, coupling issues | Mutable global variables | Encapsulate in class | +| 6 | Initialization in `__init__` instead of `reset_parameters()` | Can't re-initialize, fragile timing | Violates PyTorch convention | Define `reset_parameters()` | +| 7 | Accessing bias without checking existence | Crashes with AttributeError | Assumes bias always exists | Check `if module.bias is not None:` | +| 8 | Creating modules in `forward()` | Parameters not registered, DDP breaks | Modules must be in `__init__` | Move to `__init__`, use local vars | +| 9 | Storing intermediate results as `self.*` | Memory leak, not thread-safe | Retains computation graph | Use local variables only | +| 10 | Not using context managers for hooks | Hooks not cleaned up on error | Missing try/finally | Use `__enter__`/`__exit__` pattern | + + +### Pitfall 1: Conditional None Assignment + +```python +# ❌ WRONG +class Block(nn.Module): + def __init__(self, use_skip): + super().__init__() + self.layer = nn.Linear(10, 10) + self.skip = nn.Linear(10, 10) if use_skip else None # ❌ + + def forward(self, x): + out = self.layer(x) + if self.skip is not None: + out = out + self.skip(x) + return out + +# ✅ CORRECT +class Block(nn.Module): + def __init__(self, use_skip): + super().__init__() + self.layer = nn.Linear(10, 10) + self.skip = nn.Linear(10, 10) if use_skip else nn.Identity() # ✅ + + def forward(self, x): + out = self.layer(x) + out = out + self.skip(x) # No conditional needed + return out +``` + +**Symptom:** State dict keys mismatch, DDP synchronization failures +**Fix:** Always use `nn.Identity()` for no-op modules + + +### Pitfall 2: Functional Ops Preventing Hooks + +```python +# ❌ WRONG: Can't hook ReLU +class Encoder(nn.Module): + def __init__(self, dim): + super().__init__() + self.linear = nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear(x) + return F.relu(x) # ❌ Can't hook this! + +# Can't do this: +# encoder.relu.register_forward_hook(hook) # AttributeError! + +# ✅ CORRECT: Hookable activation +class Encoder(nn.Module): + def __init__(self, dim): + super().__init__() + self.linear = nn.Linear(dim, dim) + self.relu = nn.ReLU() # ✅ Stored as module + + def forward(self, x): + x = self.linear(x) + return self.relu(x) + +# ✅ Now can hook: +encoder.relu.register_forward_hook(hook) +``` + +**Symptom:** Can't register hooks on operations +**Fix:** Store operations as modules when you need inspection/hooks + + +### Pitfall 3: Hook Memory Leak + +```python +# ❌ WRONG: Hook retains graph +features = {} + +def hook(module, input, output): + features['layer'] = output # ❌ Retains computation graph! + +model.layer.register_forward_hook(hook) + +with torch.no_grad(): + output = model(input) +# features['layer'] STILL has gradients! + +# ✅ CORRECT: Detach in hook +def hook(module, input, output): + features['layer'] = output.detach() # ✅ Breaks graph + +# Even better: Clone if might be modified +def hook(module, input, output): + features['layer'] = output.detach().clone() # ✅ Independent copy +``` + +**Symptom:** Memory grows during feature extraction even with `torch.no_grad()` +**Fix:** Always `.detach()` in hooks (and `.clone()` if needed) + + +### Pitfall 4: Missing Hook Cleanup + +```python +# ❌ WRONG: No handle management +model.layer.register_forward_hook(my_hook) +# Hook persists forever, can't remove! + +# ✅ CORRECT: Store and clean up handle +class HookManager: + def __init__(self): + self.handle = None + + def register(self, module, hook): + self.handle = module.register_forward_hook(hook) + + def cleanup(self): + if self.handle: + self.handle.remove() + +manager = HookManager() +manager.register(model.layer, my_hook) +# ... use model ... +manager.cleanup() # ✅ Remove hook +``` + +**Symptom:** Hooks persist, unexpected behavior, memory leaks +**Fix:** Always store handles and call `.remove()` + + +### Pitfall 5: Initialization Timing + +```python +# ❌ WRONG: Init in __init__ (fragile) +class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) # Already initialized! + + # This works but is fragile: + nn.init.xavier_uniform_(self.linear.weight) # Overwrites default init + +# ✅ CORRECT: Init in reset_parameters() +class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + self.reset_parameters() # ✅ Clear separation + + def reset_parameters(self): + nn.init.xavier_uniform_(self.linear.weight) + if self.linear.bias is not None: # ✅ Defensive + nn.init.zeros_(self.linear.bias) +``` + +**Symptom:** Can't re-initialize, crashes on bias=False +**Fix:** Define `reset_parameters()`, call from `__init__` + + +## Red Flags - Stop and Reconsider + +**If you catch yourself doing ANY of these, STOP and follow patterns:** + +| Red Flag Action | Reality | What to Do Instead | +|-----------------|---------|-------------------| +| "I'll assign None to this module attribute" | Breaks PyTorch's module contract | Use `nn.Identity()` | +| "F.relu() is simpler than nn.ReLU()" | True, but prevents inspection/hooks | Use module if you might need hooks | +| "I'll store hook output directly" | Retains computation graph | Always `.detach()` first | +| "I don't need to store the hook handle" | Can't remove hook later | Always store handles | +| "I'll just initialize in __init__" | Can't re-initialize later | Use `reset_parameters()` | +| "Bias always exists, right?" | No! `bias=False` is common | Check `if bias is not None:` | +| "I'll save intermediate results as self.*" | Memory leak, not thread-safe | Use local variables only | +| "I'll create this module in forward()" | Parameters not registered | All modules in `__init__` | + +**Critical rule:** Follow PyTorch conventions or face subtle bugs in production. + + +## Complete Example: Well-Designed ResNet Block + +```python +import torch +import torch.nn as nn +import math + +class ResNetBlock(nn.Module): + """ + Well-designed ResNet block following all best practices. + + Features: + - Substitutable norm and activation layers + - Proper use of nn.Identity() for skip connections + - Hook-friendly (all operations are modules) + - Correct initialization via reset_parameters() + - Defensive bias checking + """ + + def __init__( + self, + in_channels, + out_channels, + stride=1, + norm_layer=nn.BatchNorm2d, + activation=nn.ReLU, + bias=False # Usually False with BatchNorm + ): + super().__init__() + + # Store config for potential serialization + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + # Main path: conv -> norm -> activation -> conv -> norm + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=bias) + self.norm1 = norm_layer(out_channels) + self.act1 = activation() + + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias) + self.norm2 = norm_layer(out_channels) + + # Skip connection (dimension matching) + # ✅ CRITICAL: Use nn.Identity(), never None + if stride != 1 or in_channels != out_channels: + self.skip = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=bias), + norm_layer(out_channels) + ) + else: + self.skip = nn.Identity() + + # Final activation (applied after residual addition) + self.act2 = activation() + + # ✅ Initialize weights following convention + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize weights using He initialization (good for ReLU). + """ + # Iterate through all conv layers + for module in self.modules(): + if isinstance(module, nn.Conv2d): + # He initialization for ReLU + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + # ✅ Defensive: check bias exists + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.BatchNorm2d): + # BatchNorm standard initialization + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + def forward(self, x): + """ + Forward pass: residual connection with skip path. + + Note: All operations are modules, so can be hooked or modified. + """ + # Main path + out = self.conv1(x) + out = self.norm1(out) + out = self.act1(out) # ✅ Module, not F.relu() + + out = self.conv2(out) + out = self.norm2(out) + + # Skip connection (always works, no conditional) + skip = self.skip(x) # ✅ Identity passes through if no projection needed + + # Residual addition and final activation + out = out + skip + out = self.act2(out) # ✅ Module, not F.relu() + + return out + +# Usage examples: +# Standard ResNet block +block1 = ResNetBlock(64, 128, stride=2) + +# With LayerNorm and GELU (Vision Transformer style) +block2 = ResNetBlock(64, 128, norm_layer=nn.GroupNorm, activation=nn.GELU) + +# Can hook any operation: +handle = block1.act1.register_forward_hook(lambda m, i, o: print(f"ReLU output shape: {o.shape}")) + +# Can re-initialize: +block1.reset_parameters() + +# Can inspect architecture: +for name, module in block1.named_modules(): + print(f"{name}: {module}") +``` + +**Why this design is robust:** +1. ✅ No None assignments (uses `nn.Identity()`) +2. ✅ All operations are modules (hookable) +3. ✅ Substitutable components (norm, activation) +4. ✅ Proper initialization (`reset_parameters()`) +5. ✅ Defensive bias checking +6. ✅ Clear module hierarchy +7. ✅ Configuration stored (reproducibility) +8. ✅ No magic numbers or hardcoded choices + + +## Edge Cases and Advanced Scenarios + +### Edge Case 1: Dynamic Module Lists (nn.ModuleList) + +**Scenario:** Need variable number of layers based on config. + +```python +# ❌ WRONG: Using Python list for modules +class DynamicModel(nn.Module): + def __init__(self, num_layers): + super().__init__() + self.layers = [] # ❌ Python list, parameters not registered! + for i in range(num_layers): + self.layers.append(nn.Linear(10, 10)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +# model.parameters() is empty! DDP breaks! + +# ✅ CORRECT: Use nn.ModuleList +class DynamicModel(nn.Module): + def __init__(self, num_layers): + super().__init__() + self.layers = nn.ModuleList([ # ✅ Registers all parameters + nn.Linear(10, 10) for _ in range(num_layers) + ]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x +``` + +**Rule:** Use `nn.ModuleList` for lists of modules, `nn.ModuleDict` for dicts. + + +### Edge Case 2: Hooks on nn.Sequential + +**Problem:** Hooking specific layers inside nn.Sequential. + +```python +model = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, 20), + nn.ReLU(), + nn.Linear(20, 10) +) + +# ❌ WRONG: Can't access by name easily +# model.layer2.register_forward_hook(hook) # AttributeError + +# ✅ CORRECT: Access by index +handle = model[2].register_forward_hook(hook) # Third layer (Linear 20->20) + +# ✅ BETTER: Use named modules +for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + print(f"Hooking {name}") + module.register_forward_hook(hook) +``` + +**Best practice:** For hookable models, use explicit named attributes instead of Sequential: + +```python +class HookableModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(10, 20) + self.act1 = nn.ReLU() + self.layer2 = nn.Linear(20, 20) # ✅ Named, easy to hook + self.act2 = nn.ReLU() + self.layer3 = nn.Linear(20, 10) + + def forward(self, x): + x = self.act1(self.layer1(x)) + x = self.act2(self.layer2(x)) + return self.layer3(x) + +# Easy to hook specific layers: +model.layer2.register_forward_hook(hook) +``` + + +### Edge Case 3: Hooks with In-Place Operations + +**Problem:** In-place operations modify hooked tensors. + +```python +class ModelWithInPlace(nn.Module): + def forward(self, x): + x = self.layer1(x) # Hook here + x += 10 # ❌ In-place modification! + x = self.layer2(x) + return x + +# Hook only using detach(): +def hook(module, input, output): + features['layer1'] = output.detach() # ❌ Still shares memory! + +# After forward pass, features['layer1'] has been modified! + +# ✅ CORRECT: Detach AND clone +def hook(module, input, output): + features['layer1'] = output.detach().clone() # ✅ Independent copy +``` + +**Decision tree for hooks:** + +``` +Is output modified in-place later? +├─ Yes → Use .detach().clone() +└─ No → Use .detach() (sufficient) + +Need gradients for analysis? +├─ Yes → Don't detach (but ensure short lifetime!) +└─ No → Detach (prevents memory leak) +``` + + +### Edge Case 4: Partial State Dict Loading + +**Scenario:** Loading checkpoint with different architecture. + +```python +# Original model +class ModelV1(nn.Module): + def __init__(self): + super().__init__() + self.encoder = nn.Linear(10, 20) + self.decoder = nn.Linear(20, 10) + +# New model with additional layer +class ModelV2(nn.Module): + def __init__(self): + super().__init__() + self.encoder = nn.Linear(10, 20) + self.middle = nn.Linear(20, 20) # New layer! + self.decoder = nn.Linear(20, 10) + + self.reset_parameters() + + def reset_parameters(self): + # ✅ Initialize all layers + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + +# Load V1 checkpoint into V2 model +model_v2 = ModelV2() +checkpoint = torch.load('model_v1.pth') + +# ✅ Use strict=False for partial loading +model_v2.load_state_dict(checkpoint, strict=False) + +# ✅ Re-initialize new layers only +model_v2.middle.reset_parameters() # New layer needs init +``` + +**Pattern:** When loading partial checkpoints: +1. Load with `strict=False` +2. Check which keys are missing/unexpected +3. Re-initialize only new layers (not loaded ones) + + +### Edge Case 5: Hook Removal During Forward Pass + +**Problem:** Removing hooks while iterating causes issues. + +```python +class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + self.hook_handles = [] + + def add_temporary_hook(self): + def hook(module, input, output): + print("Hook called!") + # ❌ WRONG: Removing handle inside hook + for h in self.hook_handles: + h.remove() # Dangerous during iteration! + + handle = self.layer.register_forward_hook(hook) + self.hook_handles.append(handle) + +# ✅ CORRECT: Flag for removal, remove after forward pass +class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + self.hook_handles = [] + self.hooks_to_remove = [] + + def add_temporary_hook(self): + def hook(module, input, output): + print("Hook called!") + # ✅ Flag for removal + self.hooks_to_remove.append(handle) + + handle = self.layer.register_forward_hook(hook) + self.hook_handles.append(handle) + + def cleanup_hooks(self): + """Call after forward pass""" + for handle in self.hooks_to_remove: + handle.remove() + self.hook_handles.remove(handle) + self.hooks_to_remove.clear() +``` + +**Rule:** Never modify hook handles during forward pass. Flag for removal and clean up after. + + +### Edge Case 6: Custom Modules with Buffers + +**Pattern:** Buffers are non-parameter tensors that should be saved/moved with model. + +```python +class RunningStatsModule(nn.Module): + def __init__(self, num_features): + super().__init__() + + # ❌ WRONG: Just store as attribute + self.running_mean = torch.zeros(num_features) # Not registered! + + # ✅ CORRECT: Register as buffer + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + # Parameters (learnable) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + # Update running stats (in training mode) + if self.training: + mean = x.mean(dim=0) + var = x.var(dim=0) + # ✅ In-place update of buffers + self.running_mean.mul_(0.9).add_(mean, alpha=0.1) + self.running_var.mul_(0.9).add_(var, alpha=0.1) + + # Normalize using running stats + normalized = (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5) + return normalized * self.weight + self.bias + +# Buffers are moved with model: +model = RunningStatsModule(10) +model.cuda() # ✅ running_mean and running_var moved to GPU + +# Buffers are saved in state_dict: +torch.save(model.state_dict(), 'model.pth') # ✅ Includes buffers +``` + +**When to use buffers:** +- Running statistics (BatchNorm-style) +- Fixed embeddings (not updated by optimizer) +- Positional encodings (not learned) +- Masks or indices + +**Rule:** Use `register_buffer()` for tensors that aren't parameters but should be saved/moved. + + +## Common Rationalizations (Don't Do These) + +| Excuse | Reality | Correct Approach | +|--------|---------|------------------| +| "User wants quick solution, I'll use None" | Quick becomes slow when DDP breaks | Always use nn.Identity(), same speed | +| "It's just a prototype, proper patterns later" | Prototype becomes production, tech debt compounds | Build correctly from start, no extra time | +| "F.relu() is more Pythonic/simpler" | True, but prevents hooks and modification | Use nn.ReLU() if any chance of needing hooks | +| "I'll fix initialization in training loop" | Defeats purpose of reset_parameters() | Put in reset_parameters(), 5 extra lines | +| "Bias is almost always there" | False! Many models use bias=False | Check if bias is not None, always | +| "Hooks are advanced, user won't use them" | Until they need debugging or feature extraction | Design hookable from start, no cost | +| "I'll clean up hooks manually later" | Later never comes, memory leaks persist | Context manager takes 10 lines, bulletproof | +| "This module is simple, no need for modularity" | Simple modules get extended and reused | Substitutable components from start | +| "State dict loading always matches architecture" | False! Checkpoints get reused across versions | Implement reset_parameters() for partial loads | +| "In-place ops are fine, I'll remember detach+clone" | Won't remember under pressure | Document decision in code, add comment | + +**Critical insight:** "Shortcuts for simplicity" become "bugs in production." Proper patterns take seconds more, prevent hours of debugging. + + +## Decision Frameworks + +### Framework 1: Module vs Functional Operations + +**Question:** Should I use `nn.ReLU()` or `F.relu()`? + +``` +Will you ever need to: +├─ Register hooks on this operation? → Use nn.ReLU() +├─ Inspect architecture (model.named_modules())? → Use nn.ReLU() +├─ Swap activation (ReLU→GELU)? → Use nn.ReLU() +├─ Use quantization? → Use nn.ReLU() +└─ None of above AND performance critical? → F.relu() acceptable +``` + +**Default:** When in doubt, use module version. Performance difference negligible. + + +### Framework 2: Hook Detachment Strategy + +**Question:** In my hook, should I use `detach()`, `detach().clone()`, or neither? + +``` +Do you need gradients for analysis? +├─ Yes → Don't detach (but ensure short lifetime!) +└─ No → Continue... + +Will the output be modified in-place later? +├─ Yes → Use .detach().clone() +├─ Unsure → Use .detach().clone() (safer) +└─ No → Use .detach() (sufficient) +``` + +**Example decision:** +```python +# Scenario: Extract features for visualization (no gradients needed, no in-place) +def hook(module, input, output): + return output.detach() # ✅ Sufficient + +# Scenario: Extract features, model has in-place ops (x += y) +def hook(module, input, output): + return output.detach().clone() # ✅ Necessary + +# Scenario: Gradient analysis (rare!) +def hook(module, input, output): + return output # ⚠️ Keep gradients, but ensure short lifetime +``` + + +### Framework 3: Initialization Strategy Selection + +**Question:** Which initialization should I use? + +``` +Activation function? +├─ ReLU family → Kaiming (He) initialization +├─ Tanh/Sigmoid → Xavier (Glorot) initialization +├─ GELU/Swish → Xavier or Kaiming (experiment) +└─ None/Linear → Xavier + +Layer type? +├─ Conv → Usually Kaiming with mode='fan_out' +├─ Linear → Kaiming or Xavier depending on activation +├─ Embedding → Normal(0, 1) or Xavier +└─ LSTM/GRU → Xavier for gates + +Special considerations? +├─ ResNet-style → Last layer of block: small gain (e.g., 0.5) +├─ Transformer → Xavier uniform, specific scale for embeddings +├─ GAN → Careful initialization critical (see paper) +└─ Pre-trained → Don't re-initialize! Load checkpoint +``` + +**Code example:** +```python +def reset_parameters(self): + for module in self.modules(): + if isinstance(module, nn.Conv2d): + # ReLU activation → Kaiming + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Linear): + # Check what activation follows (from self.config or hardcoded) + if self.activation == 'relu': + nn.init.kaiming_uniform_(module.weight, nonlinearity='relu') + else: + nn.init.xavier_uniform_(module.weight) + + if module.bias is not None: + nn.init.zeros_(module.bias) +``` + + +### Framework 4: When to Use Buffers vs Parameters vs Attributes + +**Decision tree:** + +``` +Is it a tensor that needs to be saved with the model? +└─ No → Regular attribute (self.x = value) +└─ Yes → Continue... + +Should it be updated by optimizer? +└─ Yes → nn.Parameter() +└─ No → Continue... + +Should it move with model (.to(device))? +└─ Yes → register_buffer() +└─ No → Regular attribute + +Examples: +- Model weights → nn.Parameter() +- Running statistics (BatchNorm) → register_buffer() +- Configuration dict → Regular attribute +- Fixed positional encoding → register_buffer() +- Dropout probability → Regular attribute +- Learnable temperature → nn.Parameter() +``` + + +## Pressure Testing Scenarios + +### Scenario 1: Time Pressure + +**User:** "I need this module quickly, just make it work." + +**Agent thought:** "I'll use None and functional ops, faster to write." + +**Reality:** Taking 30 seconds more to use nn.Identity() and nn.ReLU() prevents hours of debugging DDP issues. + +**Correct response:** Apply patterns anyway. They're not slower to write once familiar. + + +### Scenario 2: "Simple" Module + +**User:** "This is a simple block, don't overcomplicate it." + +**Agent thought:** "I'll hardcode ReLU and BatchNorm, it's just a prototype." + +**Reality:** Prototypes become production. Making activation/norm substitutable takes one extra line. + +**Correct response:** Design modularly from the start. "Simple" doesn't mean "brittle." + + +### Scenario 3: Existing Codebase + +**User:** "The existing code uses None for optional modules." + +**Agent thought:** "I should match existing style for consistency." + +**Reality:** Existing code may have bugs. Improving patterns is better than perpetuating anti-patterns. + +**Correct response:** Use correct patterns. Offer to refactor existing code if user wants. + + +### Scenario 4: "Just Getting Started" + +**User:** "I'm just experimenting, I'll clean it up later." + +**Agent thought:** "Proper patterns can wait until it works." + +**Reality:** Later never comes. Or worse, you can't iterate quickly because of accumulated tech debt. + +**Correct response:** Proper patterns don't slow down experimentation. They enable faster iteration. + + +## Red Flags Checklist + +Before writing `__init__` or `forward`, check yourself: + +### Module Definition Red Flags +- [ ] Am I assigning `None` to a module attribute? + - **FIX:** Use `nn.Identity()` +- [ ] Am I using functional ops (F.relu) without considering hooks? + - **ASK:** Will this ever need inspection/modification? +- [ ] Am I hardcoding architecture choices (ReLU, BatchNorm)? + - **FIX:** Make them substitutable parameters +- [ ] Am I creating modules in `forward()`? + - **FIX:** All modules in `__init__` + +### Hook Usage Red Flags +- [ ] Am I storing hook output without detaching? + - **FIX:** Use `.detach()` or `.detach().clone()` +- [ ] Am I registering hooks without storing handles? + - **FIX:** Store handles, clean up in `__exit__` +- [ ] Am I using global variables in hook closures? + - **FIX:** Encapsulate in a class +- [ ] Am I modifying hook handles during forward pass? + - **FIX:** Flag for removal, clean up after + +### Initialization Red Flags +- [ ] Am I initializing weights in `__init__`? + - **FIX:** Define `reset_parameters()`, call from `__init__` +- [ ] Am I accessing `.bias` without checking if it exists? + - **FIX:** Check `if module.bias is not None:` +- [ ] Am I using one initialization for all layers? + - **ASK:** Should different layers have different strategies? + +### State Management Red Flags +- [ ] Am I storing intermediate results as `self.*`? + - **FIX:** Use local variables only +- [ ] Am I using Python list for modules? + - **FIX:** Use `nn.ModuleList` +- [ ] Do I have tensors that should be buffers but aren't? + - **FIX:** Use `register_buffer()` + +**If ANY red flag is true, STOP and apply the pattern before proceeding.** + + +## Quick Reference Cards + +### Card 1: Module Design Checklist +``` +✓ super().__init__() called first +✓ All modules defined in __init__ (not forward) +✓ No None assignments (use nn.Identity()) +✓ Substitutable components (norm_layer, activation args) +✓ reset_parameters() defined and called +✓ Defensive checks (if bias is not None) +✓ Buffers registered (register_buffer()) +✓ No self.* assignments in forward() +``` + +### Card 2: Hook Checklist +``` +✓ Hook detaches output (.detach() or .detach().clone()) +✓ Hook handles stored in list +✓ Context manager for cleanup (__enter__/__exit__) +✓ No global state mutation +✓ Error handling (try/except in hook) +✓ Documented whether hook modifies output +``` + +### Card 3: Initialization Checklist +``` +✓ reset_parameters() method defined +✓ Called from __init__ +✓ Iterates through modules or layers +✓ Checks if bias is not None +✓ Uses appropriate init strategy (Kaiming/Xavier) +✓ Documents why this initialization +✓ Can be called to re-initialize +``` + + +## References + +**PyTorch Documentation:** +- nn.Module: https://pytorch.org/docs/stable/notes/modules.html +- Hooks: https://pytorch.org/docs/stable/notes/modules.html#module-hooks +- Initialization: https://pytorch.org/docs/stable/nn.init.html + +**Related Skills:** +- tensor-operations-and-memory (memory management) +- debugging-techniques (using hooks for debugging) +- distributed-training-strategies (DDP-compatible module design) +- checkpointing-and-reproducibility (state dict best practices) diff --git a/skills/using-pytorch-engineering/performance-profiling.md b/skills/using-pytorch-engineering/performance-profiling.md new file mode 100644 index 0000000..b87b15b --- /dev/null +++ b/skills/using-pytorch-engineering/performance-profiling.md @@ -0,0 +1,1893 @@ + +# Performance Profiling and Bottleneck Analysis + +## Overview + +**Core Principle:** Optimization without measurement is guessing. Profile systematically (whole → component → operation) using the right tools to identify actual bottlenecks before attempting fixes. 90% of runtime is usually in 10% of code - find that 10% with profiling, not intuition. + +Performance issues stem from: data loading bottlenecks (CPU-bound), inefficient operations (GPU-bound), memory bandwidth limits (memory-bound), or I/O bottlenecks. Profiling reveals which category applies. Guessing leads to optimizing the wrong thing, wasting hours on marginal improvements while real bottleneck remains. + +## When to Use + +**Use this skill when:** +- Training or inference slower than expected +- Need to identify performance bottleneck in PyTorch code +- High GPU memory usage, need to understand what's using memory +- Evaluating whether optimization actually improved performance +- Debugging low GPU utilization issues +- Comparing performance of different implementations +- Need to profile specific operations or model components + +**Don't use when:** +- Performance is already acceptable (no problem to solve) +- Architecture design questions (use module-design-patterns) +- Debugging correctness issues (use debugging-techniques) +- Memory leaks (use tensor-operations-and-memory) + +**Symptoms triggering this skill:** +- "Training is slower than expected" +- "Low GPU utilization but training still slow" +- "Which part of my model is the bottleneck?" +- "Does this optimization actually help?" +- "Memory usage is high, what's using it?" +- "First iteration much slower than subsequent ones" + + +## Systematic Profiling Methodology + +### The Four-Phase Framework + +**Phase 1: Establish Baseline** +- Define metric (throughput, latency, memory) +- Measure end-to-end performance +- Set improvement target +- Document measurement conditions + +**Phase 2: Identify Bottleneck Type** +- CPU-bound vs GPU-bound vs I/O-bound vs memory-bound +- Check GPU utilization (nvidia-smi) +- Profile data loading separately from computation +- Determine which component to investigate + +**Phase 3: Narrow to Component** +- Profile at coarse granularity +- Identify which phase is slow (forward/backward/optimizer/data loading) +- Focus profiling on bottleneck component +- Use iterative narrowing + +**Phase 4: Identify Operation** +- Profile bottleneck component in detail +- Examine both table view and trace view +- Find specific operation or pattern +- Measure improvement after fix + +**Critical Rule:** ALWAYS work through phases in order. Don't jump to Phase 4 without Phases 1-3. + + +### Phase 1: Establish Baseline + +**Step 1: Define Performance Metric** + +```python +# Choose the right metric for your use case: + +# Throughput (samples/second) - for training +def measure_throughput(model, dataloader, num_batches=100): + model.train() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + total_samples = 0 + start.record() + + for i, (data, target) in enumerate(dataloader): + if i >= num_batches: + break + data, target = data.cuda(), target.cuda() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + total_samples += data.size(0) + + end.record() + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + + throughput = total_samples / (elapsed_ms / 1000.0) + print(f"Throughput: {throughput:.2f} samples/sec") + print(f"Time per batch: {elapsed_ms / num_batches:.2f} ms") + return throughput + +# Latency (time per sample) - for inference +def measure_latency(model, sample_input, num_iterations=100, warmup=10): + model.eval() + sample_input = sample_input.cuda() + + # Warmup (CRITICAL - don't skip!) + with torch.no_grad(): + for _ in range(warmup): + _ = model(sample_input) + + # Measure + latencies = [] + with torch.no_grad(): + for _ in range(num_iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(sample_input) + end.record() + + torch.cuda.synchronize() + latencies.append(start.elapsed_time(end)) + + # Report statistics (not just average!) + import numpy as np + latencies = np.array(latencies) + print(f"Latency - Mean: {latencies.mean():.2f} ms, " + f"Std: {latencies.std():.2f} ms, " + f"Median: {np.median(latencies):.2f} ms, " + f"P95: {np.percentile(latencies, 95):.2f} ms, " + f"P99: {np.percentile(latencies, 99):.2f} ms") + return latencies + +# Memory usage (peak GB) +def measure_memory(model, sample_batch): + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + # Run one iteration + output = model(sample_batch) + loss = criterion(output, target) + loss.backward() + + torch.cuda.synchronize() + peak_memory = torch.cuda.max_memory_allocated() / 1e9 + print(f"Peak memory: {peak_memory:.2f} GB") + return peak_memory +``` + +**Why this matters:** +- Without baseline, can't measure improvement +- Need statistics (mean, std, percentiles), not just average +- Must use CUDA Events for GPU timing (not `time.time()`) +- Warmup critical to exclude JIT compilation overhead + + +**Step 2: Document Measurement Conditions** + +```python +# Record all relevant configuration +profiling_config = { + 'model': model.__class__.__name__, + 'batch_size': 32, + 'input_shape': (3, 224, 224), + 'device': 'cuda:0', + 'dtype': 'float16' if using_amp else 'float32', + 'mode': 'train' or 'eval', + 'num_workers': dataloader.num_workers, + 'cudnn_benchmark': torch.backends.cudnn.benchmark, + 'gpu': torch.cuda.get_device_name(0), +} + +print(json.dumps(profiling_config, indent=2)) +``` + +**Why this matters:** +- Performance changes with configuration +- Need to reproduce results +- Comparing different runs requires same conditions +- Document before optimizing, re-measure after + + +### Phase 2: Identify Bottleneck Type + +**Step 1: Check GPU Utilization** + +```bash +# In terminal, monitor GPU utilization in real-time +nvidia-smi dmon -s u -i 0 -d 1 + +# Or within Python +import subprocess +result = subprocess.run(['nvidia-smi', '--query-gpu=utilization.gpu,memory.used', + '--format=csv,noheader,nounits'], + capture_output=True, text=True) +gpu_util, mem_used = result.stdout.strip().split(',') +print(f"GPU Utilization: {gpu_util}%, Memory: {mem_used} MB") +``` + +**Interpretation:** + +| GPU Utilization | Likely Bottleneck | Next Step | +|----------------|-------------------|-----------| +| < 70% | CPU-bound (data loading, preprocessing) | Profile data loading | +| > 90% | GPU-bound (computation) | Profile model operations | +| 50-80% | Mixed or memory-bound | Check memory bandwidth | + +**Why this matters:** GPU utilization tells you WHERE to look. If GPU isn't saturated, optimizing GPU operations won't help. + + +**Step 2: Profile Data Loading vs Computation** + +```python +import time + +def profile_dataloader_vs_model(model, dataloader, num_batches=50): + """Separate data loading time from model computation time""" + model.train() + + data_times = [] + compute_times = [] + + batch_iterator = iter(dataloader) + + for i in range(num_batches): + # Time data loading + data_start = time.time() + data, target = next(batch_iterator) + data_end = time.time() + data_times.append(data_end - data_start) + + # Time computation + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + torch.cuda.synchronize() + + compute_start = torch.cuda.Event(enable_timing=True) + compute_end = torch.cuda.Event(enable_timing=True) + + compute_start.record() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + compute_end.record() + + torch.cuda.synchronize() + compute_times.append(compute_start.elapsed_time(compute_end)) + + import numpy as np + avg_data_time = np.mean(data_times) * 1000 # to ms + avg_compute_time = np.mean(compute_times) + + print(f"Avg data loading time: {avg_data_time:.2f} ms") + print(f"Avg computation time: {avg_compute_time:.2f} ms") + print(f"Data loading is {avg_data_time/avg_compute_time:.1f}x " + f"{'slower' if avg_data_time > avg_compute_time else 'faster'} than compute") + + if avg_data_time > avg_compute_time: + print("⚠️ BOTTLENECK: Data loading (CPU-bound)") + print(" Solutions: Increase num_workers, use pin_memory=True, " + "move preprocessing to GPU") + else: + print("✅ Data loading is fast enough. Bottleneck is in model computation.") + + return avg_data_time, avg_compute_time +``` + +**Why this matters:** +- If data loading > computation time, GPU is starved (increase workers) +- If computation > data loading, GPU is bottleneck (optimize model) +- Common mistake: Optimizing model when data loading is the bottleneck + + +**Step 3: Determine Bottleneck Category** + +```python +def diagnose_bottleneck_type(model, dataloader): + """Systematic bottleneck categorization""" + + # 1. Check GPU utilization + print("=== GPU Utilization Check ===") + # Run training for a bit while monitoring GPU + # If GPU util < 70% → CPU-bound + # If GPU util > 90% → GPU-bound + + # 2. Check memory bandwidth + print("\n=== Memory Bandwidth Check ===") + from torch.profiler import profile, ProfilerActivity + + with profile(activities=[ProfilerActivity.CUDA]) as prof: + for i, (data, target) in enumerate(dataloader): + if i >= 5: + break + data, target = data.cuda(), target.cuda() + output = model(data) + loss = criterion(output, target) + loss.backward() + + # Look for high memory-bound ops + events = prof.key_averages() + for evt in events: + if evt.cuda_time_total > 0: + # If many large tensor ops with low FLOPS → memory-bound + pass + + # 3. Profile phases + print("\n=== Phase Profiling ===") + times = profile_training_phases(model, next(iter(dataloader))) + + # Interpret results + print("\n=== Diagnosis ===") + if times['data_loading'] > times['forward'] + times['backward']: + print("BOTTLENECK: CPU-bound (data loading)") + print("Action: Increase num_workers, enable pin_memory, cache data") + elif times['forward'] > times['backward'] * 2: + print("BOTTLENECK: GPU-bound (forward pass)") + print("Action: Profile forward pass operations") + elif times['backward'] > times['forward'] * 2: + print("BOTTLENECK: GPU-bound (backward pass)") + print("Action: Profile backward pass, check gradient checkpointing") + else: + print("BOTTLENECK: Mixed or memory-bound") + print("Action: Deep profiling needed") +``` + + +### Phase 3: Narrow to Component + +**Step 1: Coarse-Grained Profiling** + +```python +from torch.profiler import profile, ProfilerActivity, schedule + +def profile_training_step(model, dataloader, num_steps=10): + """Profile one training step to identify bottleneck phase""" + + # Use schedule to reduce profiling overhead + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=1, warmup=2, active=5, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs'), + record_shapes=True, + profile_memory=True, + with_stack=True + ) as prof: + + for step, (data, target) in enumerate(dataloader): + if step >= num_steps: + break + + data, target = data.cuda(), target.cuda() + + optimizer.zero_grad() + + # Forward + output = model(data) + loss = criterion(output, target) + + # Backward + loss.backward() + + # Optimizer + optimizer.step() + + prof.step() # Notify profiler of step boundary + + # Print summary + print(prof.key_averages().table( + sort_by="cuda_time_total", + row_limit=20, + max_src_column_width=80 + )) + + # Export trace for visualization + print("\n✅ Trace exported to ./profiler_logs") + print(" View in Chrome: chrome://tracing (load trace.json)") + print(" Or TensorBoard: tensorboard --logdir=./profiler_logs") + + return prof +``` + +**Understanding the schedule:** +- `wait=1`: Skip first iteration (cold start) +- `warmup=2`: Next 2 iterations for warmup (no profiling overhead) +- `active=5`: Profile these 5 iterations +- `repeat=1`: Do this cycle once + +**Why this matters:** +- Profiling has overhead - don't profile every iteration +- Schedule controls when profiling is active +- Warmup prevents including JIT compilation in measurements + + +**Step 2: Phase-Level Timing** + +```python +from torch.profiler import record_function + +def profile_training_phases(model, batch, target): + """Time each phase of training separately""" + + data, target = batch.cuda(), target.cuda() + optimizer.zero_grad() + + torch.cuda.synchronize() + + # Profile each phase + phases = {} + + # Forward pass + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + with record_function("forward_pass"): + start.record() + output = model(data) + loss = criterion(output, target) + end.record() + torch.cuda.synchronize() + phases['forward'] = start.elapsed_time(end) + + # Backward pass + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + with record_function("backward_pass"): + start.record() + loss.backward() + end.record() + torch.cuda.synchronize() + phases['backward'] = start.elapsed_time(end) + + # Optimizer step + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + with record_function("optimizer_step"): + start.record() + optimizer.step() + end.record() + torch.cuda.synchronize() + phases['optimizer'] = start.elapsed_time(end) + + # Print breakdown + total = sum(phases.values()) + print("Phase Breakdown:") + for phase, time_ms in phases.items(): + print(f" {phase:15s}: {time_ms:7.2f} ms ({time_ms/total*100:5.1f}%)") + + return phases +``` + +**Why this matters:** +- Identifies which phase is slowest +- Focuses subsequent profiling on bottleneck phase +- Uses `record_function` to add custom markers in trace view + + +**Step 3: Module-Level Profiling** + +```python +def profile_model_modules(model, sample_input): + """Profile time spent in each model module""" + + model.eval() + sample_input = sample_input.cuda() + + # Add hooks to time each module + module_times = {} + + def make_hook(name): + def hook(module, input, output): + if name not in module_times: + module_times[name] = [] + + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + # Forward already happened, this is for next time + end.record() + + torch.cuda.synchronize() + # (Note: This is simplified - real implementation more complex) + + return hook + + # Register hooks + hooks = [] + for name, module in model.named_modules(): + if len(list(module.children())) == 0: # Leaf modules only + hook = module.register_forward_hook(make_hook(name)) + hooks.append(hook) + + # Better approach: Use record_function + class ProfilingModule(torch.nn.Module): + def __init__(self, module, name): + super().__init__() + self.module = module + self.name = name + + def forward(self, *args, **kwargs): + with record_function(f"module_{self.name}"): + return self.module(*args, **kwargs) + + # Or just use torch.profiler with record_shapes=True + # It will automatically show module breakdown + + with profile(activities=[ProfilerActivity.CUDA]) as prof: + with record_function("model_forward"): + output = model(sample_input) + + print(prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total", row_limit=20 + )) + + # Clean up + for hook in hooks: + hook.remove() +``` + +**Why this matters:** +- Identifies which model component is slowest +- Guides optimization efforts to specific layers +- Reveals unexpected bottlenecks (e.g., LayerNorm taking 30% of time) + + +### Phase 4: Identify Operation + +**Step 1: Detailed Operation Profiling** + +```python +def profile_operations_detailed(model, sample_input): + """Get detailed breakdown of all operations""" + + model.eval() + sample_input = sample_input.cuda() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + profile_memory=True + ) as prof: + output = model(sample_input) + + # Group by operation type + print("\n=== Top Operations by CUDA Time ===") + print(prof.key_averages().table( + sort_by="cuda_time_total", + row_limit=30, + max_src_column_width=100 + )) + + print("\n=== Top Operations by Memory ===") + print(prof.key_averages().table( + sort_by="self_cuda_memory_usage", + row_limit=20, + max_src_column_width=100 + )) + + print("\n=== Grouped by Input Shape ===") + print(prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total", + row_limit=20 + )) + + # Export for trace view + prof.export_chrome_trace("detailed_trace.json") + print("\n✅ Exported detailed_trace.json - view in chrome://tracing") + + return prof +``` + + +**Step 2: Reading Profiler Output** + +```python +# Example profiler output: +""" +--------------------------------- ------------ ------------ ------------ ------------ + Name Self CPU % Self CPU CPU total % CPU total +--------------------------------- ------------ ------------ ------------ ------------ + aten::conv2d 0.5% 124ms 45.2% 11.234s + aten::convolution 1.2% 298ms 44.7% 11.110s + aten::_convolution 2.3% 571ms 43.5% 10.812s + aten::cudnn_convolution 40.1% 9.967s 41.2% 10.241s + aten::batch_norm 0.3% 74ms 25.8% 6.412s + aten::_batch_norm 1.1% 273ms 25.5% 6.338s + aten::cudnn_batch_norm 23.2% 5.765s 24.4% 6.065s + aten::relu 8.2% 2.038s 8.2% 2.038s +""" +``` + +**How to interpret:** + +| Column | Meaning | When to Look | +|--------|---------|--------------| +| Name | Operation name | Identify what operation | +| Self CPU % | Time in this op only (no children) | Find leaf operations | +| CPU total % | Time in op + children | Find expensive subtrees | +| Self CUDA time | GPU execution time | Main metric for GPU ops | +| Call count | How many times called | High count = optimization target | + +**Common patterns:** + +```python +# Pattern 1: High aten::copy_ time (40%+) +# → Device transfer issue (CPU ↔ GPU) +# Action: Check device placement, reduce transfers + +# Pattern 2: High cudaLaunchKernel overhead +# → Too many small kernel launches +# Action: Increase batch size, fuse operations + +# Pattern 3: High cudnn_convolution time +# → Convolutions are bottleneck (expected for CNNs) +# Action: Check input dimensions for Tensor Core alignment + +# Pattern 4: High CPU time, low CUDA time +# → CPU bottleneck (data loading, preprocessing) +# Action: Increase num_workers, move ops to GPU + +# Pattern 5: Many small operations +# → Operation fusion opportunity +# Action: Use torch.compile or fuse manually +``` + + +**Step 3: Trace View Analysis** + +```python +# After exporting trace: prof.export_chrome_trace("trace.json") +# Open in chrome://tracing + +""" +Trace view shows: +1. Timeline of GPU kernels +2. CPU → GPU synchronization points +3. Parallel vs sequential execution +4. GPU idle time (gaps between kernels) + +What to look for: +- Large gaps between GPU kernels → GPU underutilized +- Many thin bars → Too many small operations +- Thick bars → Few large operations (good for GPU) +- Yellow/red bars → CPU activity (should be minimal during GPU work) +- Overlapping bars → Concurrent execution (good) +""" +``` + +**Reading trace view:** + +``` +GPU Stream 0: ████░░░░████░░░░████ ← Gaps = idle GPU (bad) +GPU Stream 0: ███████████████████ ← Continuous = good utilization +CPU: ░░░░████░░░░████░░░░ ← CPU peaks = data loading + +Timeline: +[Data Load]──→[GPU Forward]──→[Data Load]──→[GPU Forward] + ↑ Gap here = GPU waiting for data +``` + + +## Memory Profiling + +### Memory Tracking Methodology + +**Step 1: Basic Memory Tracking** + +```python +import torch + +def track_memory(stage_name): + """Print current memory usage""" + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + print(f"{stage_name:30s} - Allocated: {allocated:6.2f} GB, " + f"Reserved: {reserved:6.2f} GB") + +# Track at each training phase +track_memory("Start") + +data, target = next(iter(dataloader)) +data, target = data.cuda(), target.cuda() +track_memory("After data to GPU") + +output = model(data) +track_memory("After forward") + +loss = criterion(output, target) +track_memory("After loss") + +loss.backward() +track_memory("After backward") + +optimizer.step() +track_memory("After optimizer step") + +optimizer.zero_grad() +track_memory("After zero_grad") +``` + +**Output interpretation:** + +``` +Start - Allocated: 2.50 GB, Reserved: 2.75 GB +After data to GPU - Allocated: 2.62 GB, Reserved: 2.75 GB +After forward - Allocated: 4.80 GB, Reserved: 5.00 GB ← Activations +After loss - Allocated: 4.81 GB, Reserved: 5.00 GB +After backward - Allocated: 7.20 GB, Reserved: 7.50 GB ← Gradients +After optimizer step - Allocated: 7.20 GB, Reserved: 7.50 GB +After zero_grad - Allocated: 4.70 GB, Reserved: 7.50 GB ← Gradients freed +``` + +**Key insights:** +- Allocated = actual memory used +- Reserved = memory held by allocator (may be > allocated due to caching) +- Large jump after forward = activations (consider gradient checkpointing) +- Large jump after backward = gradients (same size as parameters) +- Reserved stays high = memory fragmentation or caching + + +**Step 2: Peak Memory Analysis** + +```python +def analyze_peak_memory(model, batch, target): + """Find peak memory usage and what causes it""" + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + # Run one iteration + data, target = batch.cuda(), target.cuda() + + output = model(data) + forward_peak = torch.cuda.max_memory_allocated() / 1e9 + + loss = criterion(output, target) + loss_peak = torch.cuda.max_memory_allocated() / 1e9 + + loss.backward() + backward_peak = torch.cuda.max_memory_allocated() / 1e9 + + optimizer.step() + optimizer_peak = torch.cuda.max_memory_allocated() / 1e9 + + optimizer.zero_grad() + final_peak = torch.cuda.max_memory_allocated() / 1e9 + + print(f"Peak after forward: {forward_peak:.2f} GB") + print(f"Peak after loss: {loss_peak:.2f} GB") + print(f"Peak after backward: {backward_peak:.2f} GB") + print(f"Peak after optimizer: {optimizer_peak:.2f} GB") + print(f"Overall peak: {final_peak:.2f} GB") + + # Identify bottleneck + if forward_peak > backward_peak * 0.8: + print("\n⚠️ Activations dominate memory usage") + print(" Consider: Gradient checkpointing, smaller batch size") + elif backward_peak > forward_peak * 1.5: + print("\n⚠️ Gradients dominate memory usage") + print(" Consider: Gradient accumulation, mixed precision") + else: + print("\n✅ Memory usage balanced across phases") + + return { + 'forward': forward_peak, + 'backward': backward_peak, + 'optimizer': optimizer_peak, + 'peak': final_peak + } +``` + + +**Step 3: Detailed Memory Summary** + +```python +def print_memory_summary(): + """Print detailed memory breakdown""" + print(torch.cuda.memory_summary()) + +""" +Example output: +|===========================================================================| +| PyTorch CUDA memory summary | +|---------------------------------------------------------------------------| +| CUDA OOMs: 0 | +| Metric | Cur Usage | Peak Usage | Alloc Retries | # Allocs | +|----------------|------------|------------|---------------|---------------| +| Allocated | 4.50 GB | 7.20 GB | 0 | 15234 | +| Reserved | 7.50 GB | 7.50 GB | 0 | 1523 | +| Active | 4.50 GB | 7.20 GB | | | +| Inactive | 3.00 GB | 0.30 GB | | | +|===========================================================================| + +Allocated memory: 4.50 GB ← Actual tensors +Reserved memory: 7.50 GB ← Memory held by allocator +Active allocations: 4.50 GB ← Currently in use +Inactive allocations: 3.00 GB ← Cached for reuse (fragmentation) +""" + +# If Inactive >> 0, memory fragmentation is occurring +# Periodic torch.cuda.empty_cache() may help +``` + + +**Step 4: Memory Snapshot (PyTorch 2.0+)** + +```python +import pickle +import torch.cuda + +def capture_memory_snapshot(filename="memory_snapshot.pickle"): + """Capture detailed memory snapshot for analysis""" + + # Enable memory history tracking + torch.cuda.memory._record_memory_history(max_entries=100000) + + try: + # Run your training code here + for i, (data, target) in enumerate(dataloader): + if i >= 5: + break + data, target = data.cuda(), target.cuda() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + # Capture snapshot + torch.cuda.memory._dump_snapshot(filename) + print(f"✅ Memory snapshot saved to {filename}") + + finally: + # Disable tracking + torch.cuda.memory._record_memory_history(enabled=None) + + print(f"\nAnalyze with:") + print(f" python -m torch.cuda._memory_viz trace_plot {filename}") + print(f" # Opens interactive visualization in browser") +``` + +**Memory snapshot visualization shows:** +- Allocation timeline +- Stack traces for each allocation +- Memory leaks (allocations never freed) +- Fragmentation patterns +- Peak memory events + + +## GPU Timing Best Practices + +### CUDA Synchronization and Events + +**❌ WRONG: Using time.time() for GPU operations** + +```python +import time + +# WRONG - measures CPU time, not GPU time! +start = time.time() +output = model(data) # Kernel launches, returns immediately +end = time.time() +print(f"Time: {end - start:.4f}s") # ❌ This is kernel launch overhead (~microseconds) + +# Problem: CUDA operations are asynchronous +# time.time() measures CPU time (when kernel was launched) +# Not GPU time (when kernel actually executed) +``` + +**Why this is wrong:** +- CUDA kernel launches are asynchronous (return immediately to CPU) +- `time.time()` measures CPU wall-clock time +- Actual GPU execution happens later, in parallel with CPU +- Measured time is kernel launch overhead (microseconds), not execution time + + +**✅ CORRECT: Using CUDA Events** + +```python +# CORRECT - measures actual GPU execution time +start_event = torch.cuda.Event(enable_timing=True) +end_event = torch.cuda.Event(enable_timing=True) + +start_event.record() +output = model(data) +end_event.record() + +# Wait for GPU to finish +torch.cuda.synchronize() + +# Get elapsed time in milliseconds +elapsed_time_ms = start_event.elapsed_time(end_event) +print(f"GPU Time: {elapsed_time_ms:.2f} ms") +``` + +**Why this is correct:** +- CUDA Events are GPU-native timing +- `record()` inserts timing markers into GPU stream +- `synchronize()` waits for GPU to complete +- `elapsed_time()` returns actual GPU execution time + + +**Alternative: Using torch.profiler** + +```python +# For comprehensive profiling, use torch.profiler instead of manual timing +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CUDA]) as prof: + output = model(data) + +print(prof.key_averages().table(sort_by="cuda_time_total")) +# This automatically handles synchronization and provides detailed breakdown +``` + + +### Warmup Iterations + +**Why warmup is critical:** + +```python +# First iteration includes: +# 1. CUDA kernel JIT compilation +# 2. cuDNN algorithm selection (benchmark mode) +# 3. Memory pool allocation +# 4. CPU→GPU transfer of model weights (first time) + +# Example timing without warmup: +model.eval() +with torch.no_grad(): + for i in range(10): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(sample_input) + end.record() + torch.cuda.synchronize() + + print(f"Iteration {i}: {start.elapsed_time(end):.2f} ms") + +""" +Output: +Iteration 0: 1234.56 ms ← JIT compilation, cuDNN benchmarking +Iteration 1: 987.43 ms ← Still some overhead +Iteration 2: 102.34 ms ← Stabilized +Iteration 3: 101.89 ms ← Stable +Iteration 4: 102.12 ms ← Stable +... +""" +``` + +**✅ Correct warmup methodology:** + +```python +def benchmark_with_warmup(model, sample_input, warmup=5, iterations=100): + """Proper benchmarking with warmup""" + + model.eval() + sample_input = sample_input.cuda() + + # Warmup iterations (CRITICAL!) + with torch.no_grad(): + for _ in range(warmup): + _ = model(sample_input) + + # Ensure warmup completed + torch.cuda.synchronize() + + # Actual measurement + times = [] + with torch.no_grad(): + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(sample_input) + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + # Report statistics + import numpy as np + times = np.array(times) + + print(f"Mean: {times.mean():.2f} ms") + print(f"Std: {times.std():.2f} ms") + print(f"Median: {np.median(times):.2f} ms") + print(f"Min: {times.min():.2f} ms") + print(f"Max: {times.max():.2f} ms") + print(f"P95: {np.percentile(times, 95):.2f} ms") + print(f"P99: {np.percentile(times, 99):.2f} ms") + + return times +``` + +**Warmup rules:** +- Minimum 3 iterations, recommend 5-10 +- More complex models need more warmup +- Dynamic control flow needs extra warmup +- Report statistics (mean, std, percentiles), not just average + + +## Bottleneck Identification Patterns + +### CPU-Bound Bottlenecks + +**Symptoms:** +- Low GPU utilization (<70%) +- High CPU usage +- Data loading time > computation time +- `nvidia-smi` shows low GPU usage + +**Diagnostic code:** + +```python +def diagnose_cpu_bottleneck(model, dataloader): + """Check if training is CPU-bound""" + + # Check GPU utilization + import subprocess + result = subprocess.run( + ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'], + capture_output=True, text=True + ) + gpu_util = int(result.stdout.strip()) + + print(f"GPU Utilization: {gpu_util}%") + + if gpu_util < 70: + print("⚠️ LOW GPU UTILIZATION - likely CPU-bound") + + # Profile data loading vs compute + data_time, compute_time = profile_dataloader_vs_model(model, dataloader) + + if data_time > compute_time: + print("\n🎯 BOTTLENECK: Data loading") + print("Solutions:") + print(" 1. Increase num_workers in DataLoader") + print(f" Current: {dataloader.num_workers}, try: {os.cpu_count()}") + print(" 2. Enable pin_memory=True") + print(" 3. Move data augmentation to GPU (use kornia)") + print(" 4. Cache preprocessed data if dataset is small") + print(" 5. Use faster storage (SSD instead of HDD)") + + else: + print("\n🎯 BOTTLENECK: CPU preprocessing") + print("Solutions:") + print(" 1. Move preprocessing to GPU") + print(" 2. Reduce preprocessing complexity") + print(" 3. Batch preprocessing operations") + + else: + print("✅ GPU utilization is healthy") + print(" Bottleneck is likely in GPU computation") + + return gpu_util +``` + +**Common solutions:** + +```python +# Solution 1: Increase num_workers +dataloader = DataLoader( + dataset, + batch_size=32, + num_workers=8, # Increase from default 0 + pin_memory=True, # Enable for faster GPU transfer + persistent_workers=True # Keep workers alive between epochs +) + +# Solution 2: Move augmentation to GPU +import kornia + +class GPUAugmentation(nn.Module): + def __init__(self): + super().__init__() + self.augment = nn.Sequential( + kornia.augmentation.RandomHorizontalFlip(p=0.5), + kornia.augmentation.ColorJitter(0.2, 0.2, 0.2, 0.1), + kornia.augmentation.RandomResizedCrop((224, 224)), + ) + + def forward(self, x): + return self.augment(x) + +# Apply on GPU +gpu_augment = GPUAugmentation().cuda() +for data, target in dataloader: + data = data.cuda() + data = gpu_augment(data) # Augment on GPU + output = model(data) +``` + + +### GPU-Bound Bottlenecks + +**Symptoms:** +- High GPU utilization (>90%) +- Computation time > data loading time +- High CUDA time in profiler + +**Diagnostic code:** + +```python +def diagnose_gpu_bottleneck(model, sample_input): + """Identify GPU bottleneck operations""" + + with profile(activities=[ProfilerActivity.CUDA]) as prof: + output = model(sample_input) + loss = criterion(output, target) + loss.backward() + + # Find top GPU operations + events = prof.key_averages() + cuda_events = [(evt.key, evt.cuda_time_total) for evt in events + if evt.cuda_time_total > 0] + cuda_events.sort(key=lambda x: x[1], reverse=True) + + print("Top 10 GPU operations:") + total_time = sum(time for _, time in cuda_events) + for i, (name, time) in enumerate(cuda_events[:10], 1): + percentage = (time / total_time) * 100 + print(f"{i:2d}. {name:40s} {time/1000:8.2f} ms ({percentage:5.1f}%)") + + # Check for optimization opportunities + top_op = cuda_events[0][0] + if 'conv' in top_op.lower(): + print("\n🎯 Bottleneck: Convolution operations") + print("Solutions:") + print(" 1. Check input dimensions for Tensor Core alignment (multiples of 8)") + print(" 2. Use mixed precision (torch.cuda.amp)") + print(" 3. Consider depthwise separable convolutions") + print(" 4. Profile with different batch sizes") + + elif 'mm' in top_op.lower() or 'matmul' in top_op.lower(): + print("\n🎯 Bottleneck: Matrix multiplication") + print("Solutions:") + print(" 1. Ensure dimensions are multiples of 8 (FP16) or 16 (BF16)") + print(" 2. Use mixed precision") + print(" 3. Check for unnecessary transposes") + + elif 'copy' in top_op.lower(): + print("\n🎯 Bottleneck: Memory copies") + print("Solutions:") + print(" 1. Check device placement (CPU ↔ GPU transfers)") + print(" 2. Ensure tensors are contiguous") + print(" 3. Reduce explicit .cuda() or .cpu() calls") + + return cuda_events +``` + +**Common solutions:** + +```python +# Solution 1: Mixed precision +from torch.cuda.amp import autocast + +with autocast(): + output = model(data) + loss = criterion(output, target) +# 2-3x speedup for large models + +# Solution 2: Tensor Core alignment +# Ensure dimensions are multiples of 8 (FP16) or 16 (BF16) +# BAD: (batch=31, seq_len=127, hidden=509) +# GOOD: (batch=32, seq_len=128, hidden=512) + +# Solution 3: torch.compile (PyTorch 2.0+) +model = torch.compile(model) +# Automatic kernel fusion and optimization +``` + + +### Memory-Bound Bottlenecks + +**Symptoms:** +- Low GPU utilization despite high memory usage +- Large tensor operations dominating time +- Memory bandwidth saturated + +**Diagnostic code:** + +```python +def diagnose_memory_bottleneck(model, sample_input): + """Check if operations are memory-bandwidth limited""" + + # Profile memory and compute + with profile( + activities=[ProfilerActivity.CUDA], + profile_memory=True + ) as prof: + output = model(sample_input) + + # Analyze operations + for evt in prof.key_averages(): + if evt.cuda_time_total > 0 and evt.self_cuda_memory_usage > 0: + # Rough FLOP/s estimate + # Memory-bound: low FLOP/s despite high memory usage + # Compute-bound: high FLOP/s + + memory_gb = evt.self_cuda_memory_usage / 1e9 + time_s = evt.cuda_time_total / 1e6 # µs to s + + if memory_gb > 1.0 and time_s > 0.01: + bandwidth = memory_gb / time_s # GB/s + print(f"{evt.key:40s}: {bandwidth:.1f} GB/s") + + print("\nIf bandwidth < 500 GB/s, likely memory-bound") + print("Solutions:") + print(" 1. Reduce intermediate tensor sizes") + print(" 2. Use in-place operations where safe") + print(" 3. Tile large operations") + print(" 4. Increase arithmetic intensity (more compute per byte)") +``` + + +### I/O-Bound Bottlenecks + +**Symptoms:** +- Low CPU and GPU utilization +- Long pauses between batches +- Slow disk I/O + +**Solutions:** + +```python +# Solution 1: Cache dataset in RAM +class CachedDataset(Dataset): + def __init__(self, dataset): + self.cache = [dataset[i] for i in range(len(dataset))] + + def __getitem__(self, idx): + return self.cache[idx] + + def __len__(self): + return len(self.cache) + +# Solution 2: Use SSD storage or RAM disk +# Solution 3: Prefetch data +dataloader = DataLoader( + dataset, + num_workers=8, + prefetch_factor=4, # Prefetch 4 batches per worker + pin_memory=True +) +``` + + +## Common Profiling Mistakes + +### Mistake 1: Profiling Too Many Iterations + +**❌ WRONG:** + +```python +# Profiling 100 epochs - output is massive, unusable +with profile(activities=[ProfilerActivity.CUDA]) as prof: + for epoch in range(100): + for batch in dataloader: + # ... training ... + pass +``` + +**✅ CORRECT:** + +```python +# Profile just a few iterations +with profile( + activities=[ProfilerActivity.CUDA], + schedule=schedule(wait=1, warmup=2, active=3, repeat=1) +) as prof: + for epoch in range(1): + for step, batch in enumerate(dataloader): + if step >= 10: + break + # ... training ... + prof.step() +``` + + +### Mistake 2: No Warmup Before Timing + +**❌ WRONG:** + +```python +# Including JIT compilation in timing +start = time.time() +output = model(data) # First call - includes JIT overhead +end = time.time() +``` + +**✅ CORRECT:** + +```python +# Warmup first +for _ in range(5): + _ = model(data) + +torch.cuda.synchronize() + +# Now measure +start = torch.cuda.Event(enable_timing=True) +end = torch.cuda.Event(enable_timing=True) +start.record() +output = model(data) +end.record() +torch.cuda.synchronize() +``` + + +### Mistake 3: Synchronization Overhead + +**❌ WRONG:** + +```python +# Synchronizing in training loop - kills performance! +for batch in dataloader: + output = model(batch) + torch.cuda.synchronize() # ❌ Breaks pipelining! + loss = criterion(output, target) + torch.cuda.synchronize() # ❌ Unnecessary! + loss.backward() +``` + +**✅ CORRECT:** + +```python +# Only synchronize for timing/profiling, not in production +for batch in dataloader: + output = model(batch) + loss = criterion(output, target) + loss.backward() + # No synchronization - let GPU pipeline work +``` + +**When to synchronize:** +- Profiling/timing measurements +- Before memory measurements +- Debugging CUDA errors +- NEVER in production training loop + + +### Mistake 4: Wrong Profiling Granularity + +**❌ WRONG:** + +```python +# Profiling entire model - too coarse, can't identify bottleneck +with profile() as prof: + output = model(data) +# "Model takes 100ms" - not actionable! +``` + +**✅ CORRECT:** + +```python +# Iterative narrowing: +# 1. Profile whole step +# 2. Identify slow phase (forward, backward, optimizer) +# 3. Profile that phase in detail +# 4. Identify specific operation + +# Phase 1: Coarse +with record_function("forward"): + output = model(data) +with record_function("backward"): + loss.backward() + +# Phase 2: Found forward is slow, profile in detail +with profile() as prof: + output = model(data) +# Now see which layer is slow + +# Phase 3: Found layer X is slow, profile that layer +with profile() as prof: + output = model.layer_x(data) +# Now see which operation in layer X is slow +``` + + +### Mistake 5: Ignoring Memory While Profiling Compute + +**❌ WRONG:** + +```python +# Only looking at time, ignoring memory +with profile(activities=[ProfilerActivity.CUDA]) as prof: + output = model(data) +``` + +**✅ CORRECT:** + +```python +# Profile both compute AND memory +with profile( + activities=[ProfilerActivity.CUDA], + profile_memory=True +) as prof: + output = model(data) + +# Check both time and memory +print(prof.key_averages().table(sort_by="cuda_time_total")) +print(prof.key_averages().table(sort_by="self_cuda_memory_usage")) +``` + + +### Mistake 6: Profiling in Wrong Mode + +**❌ WRONG:** + +```python +# Profiling in eval mode when you care about training speed +model.eval() +with torch.no_grad(): + with profile() as prof: + output = model(data) +# ❌ This doesn't include backward pass! +``` + +**✅ CORRECT:** + +```python +# Profile in the mode you actually use +model.train() +with profile() as prof: + output = model(data) + loss = criterion(output, target) + loss.backward() # ✅ Include backward if profiling training +``` + + +## Red Flags - Stop and Profile Systematically + +**If you catch yourself thinking ANY of these, STOP and follow methodology:** + +| Red Flag Thought | Reality | What to Do Instead | +|------------------|---------|-------------------| +| "I can see the bottleneck" | 90% of the time your guess is wrong | Profile to confirm, don't guess | +| "User says X is slow, so X is the bottleneck" | User might be wrong about the cause | Verify with profiling | +| "This loop looks inefficient" | Intuition about performance often wrong | Measure it, don't assume | +| "Profiling takes too long" | Profiling saves hours of guessing | 10 minutes of profiling > hours of guessing | +| "Let me just try this optimization" | Premature optimization wastes time | Measure first, optimize second | +| "It's obviously a GPU problem" | Could be CPU, data loading, or I/O | Check GPU utilization first | +| "I'll reduce batch size" | Doesn't address root cause | Diagnose memory bottleneck first | +| "Skip warmup, it's just one iteration" | First iterations have 10-100x overhead | Always warmup, no exceptions | + +**Critical rules:** +1. NEVER optimize before profiling +2. ALWAYS use warmup iterations +3. ALWAYS check GPU utilization before assuming GPU bottleneck +4. ALWAYS profile data loading separately from computation +5. ALWAYS report statistics (mean, std, percentiles), not just average +6. ALWAYS use CUDA Events for GPU timing, never `time.time()` + + +## Common Rationalizations (Don't Do These) + +| Excuse | What Really Happens | Correct Approach | +|--------|-------------------|------------------| +| "User seems rushed, skip profiling" | Guessing wastes MORE time than profiling | 10 min profiling saves hours | +| "I already profiled once" | Might have used wrong tool or granularity | Re-profile with systematic methodology | +| "Profiling overhead will skew results" | Use schedule to minimize overhead | `schedule(wait=1, warmup=2, active=3)` | +| "This worked on another model" | Different models have different bottlenecks | Profile THIS model, not assumptions | +| "Documentation says X is slow" | Depends on context, hardware, data | Verify with profiling on YOUR setup | +| "Just trust the profiler output" | Must interpret correctly | Understand what metrics mean | +| "The model is the bottleneck" | Often it's data loading | Always check data loading vs compute | + + +## Complete Profiling Example + +```python +import torch +import torch.nn as nn +from torch.profiler import profile, ProfilerActivity, schedule +from torch.utils.data import DataLoader +import numpy as np + +def comprehensive_profiling(model, dataloader, device='cuda'): + """ + Complete profiling workflow following systematic methodology + """ + + print("=" * 80) + print("PHASE 1: ESTABLISH BASELINE") + print("=" * 80) + + # Step 1: Measure baseline performance + model = model.to(device) + model.train() + + # Warmup + print("\nWarming up (5 iterations)...") + for i, (data, target) in enumerate(dataloader): + if i >= 5: + break + data, target = data.to(device), target.to(device) + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + torch.cuda.synchronize() + + # Measure baseline + print("\nMeasuring baseline (10 iterations)...") + times = [] + for i, (data, target) in enumerate(dataloader): + if i >= 10: + break + + data, target = data.to(device), target.to(device) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + times = np.array(times) + print(f"\nBaseline Performance:") + print(f" Mean: {times.mean():.2f} ms/iteration") + print(f" Std: {times.std():.2f} ms") + print(f" Median: {np.median(times):.2f} ms") + print(f" P95: {np.percentile(times, 95):.2f} ms") + + # ------------------------------------------------------------------------- + print("\n" + "=" * 80) + print("PHASE 2: IDENTIFY BOTTLENECK TYPE") + print("=" * 80) + + # Check GPU utilization + import subprocess + result = subprocess.run( + ['nvidia-smi', '--query-gpu=utilization.gpu,memory.used', + '--format=csv,noheader,nounits'], + capture_output=True, text=True + ) + gpu_util, mem_used = result.stdout.strip().split(',') + print(f"\nGPU Utilization: {gpu_util}%") + print(f"GPU Memory Used: {mem_used} MB") + + if int(gpu_util) < 70: + print("⚠️ LOW GPU UTILIZATION - likely CPU-bound") + else: + print("✅ GPU utilization healthy - likely GPU-bound") + + # Profile data loading vs computation + print("\nProfiling data loading vs computation...") + data_times = [] + compute_times = [] + + batch_iter = iter(dataloader) + for i in range(20): + import time + + # Data loading time + data_start = time.time() + data, target = next(batch_iter) + data_time = time.time() - data_start + data_times.append(data_time * 1000) # to ms + + # Computation time + data, target = data.to(device), target.to(device) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + end.record() + + torch.cuda.synchronize() + compute_times.append(start.elapsed_time(end)) + + avg_data = np.mean(data_times) + avg_compute = np.mean(compute_times) + + print(f"\nData loading: {avg_data:.2f} ms") + print(f"Computation: {avg_compute:.2f} ms") + + if avg_data > avg_compute: + print("🎯 BOTTLENECK: Data loading (CPU-bound)") + else: + print("🎯 BOTTLENECK: Model computation (GPU-bound)") + + # ------------------------------------------------------------------------- + print("\n" + "=" * 80) + print("PHASE 3: NARROW TO COMPONENT") + print("=" * 80) + + # Profile training phases + print("\nProfiling training phases...") + + data, target = next(iter(dataloader)) + data, target = data.to(device), target.to(device) + + # Forward + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = model(data) + loss = criterion(output, target) + end.record() + torch.cuda.synchronize() + forward_time = start.elapsed_time(end) + + # Backward + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + loss.backward() + end.record() + torch.cuda.synchronize() + backward_time = start.elapsed_time(end) + + # Optimizer + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + optimizer.step() + optimizer.zero_grad() + end.record() + torch.cuda.synchronize() + optimizer_time = start.elapsed_time(end) + + total = forward_time + backward_time + optimizer_time + + print(f"\nPhase breakdown:") + print(f" Forward: {forward_time:7.2f} ms ({forward_time/total*100:5.1f}%)") + print(f" Backward: {backward_time:7.2f} ms ({backward_time/total*100:5.1f}%)") + print(f" Optimizer: {optimizer_time:7.2f} ms ({optimizer_time/total*100:5.1f}%)") + + # ------------------------------------------------------------------------- + print("\n" + "=" * 80) + print("PHASE 4: IDENTIFY OPERATION") + print("=" * 80) + + # Detailed profiling + print("\nRunning detailed profiler...") + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=1, warmup=2, active=3, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs'), + record_shapes=True, + profile_memory=True, + with_stack=True + ) as prof: + for step, (data, target) in enumerate(dataloader): + if step >= 10: + break + + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + prof.step() + + # Print summary + print("\nTop operations by CUDA time:") + print(prof.key_averages().table( + sort_by="cuda_time_total", + row_limit=15, + max_src_column_width=60 + )) + + print("\nTop operations by memory:") + print(prof.key_averages().table( + sort_by="self_cuda_memory_usage", + row_limit=10, + max_src_column_width=60 + )) + + prof.export_chrome_trace("detailed_trace.json") + + print("\n" + "=" * 80) + print("PROFILING COMPLETE") + print("=" * 80) + print("\nNext steps:") + print(" 1. Review top operations in table above") + print(" 2. Open chrome://tracing and load detailed_trace.json") + print(" 3. Or view in TensorBoard: tensorboard --logdir=./profiler_logs") + print(" 4. Focus optimization on identified bottleneck") + print(" 5. Re-run this profiling after optimization to verify improvement") + + return { + 'baseline_ms': times.mean(), + 'gpu_utilization': int(gpu_util), + 'data_loading_ms': avg_data, + 'computation_ms': avg_compute, + 'forward_ms': forward_time, + 'backward_ms': backward_time, + 'optimizer_ms': optimizer_time, + } +``` + + +## Memory Profiling Complete Example + +```python +def profile_memory_usage(model, sample_batch, sample_target): + """Comprehensive memory profiling""" + + print("=" * 80) + print("MEMORY PROFILING") + print("=" * 80) + + device = next(model.parameters()).device + + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + def print_memory(stage): + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + peak = torch.cuda.max_memory_allocated() / 1e9 + print(f"{stage:30s} | Alloc: {allocated:5.2f} GB | " + f"Reserved: {reserved:5.2f} GB | Peak: {peak:5.2f} GB") + + print("\nMemory tracking:") + print("-" * 80) + + print_memory("Initial") + + # Move data to GPU + data = sample_batch.to(device) + target = sample_target.to(device) + print_memory("After data to GPU") + + # Forward pass + output = model(data) + print_memory("After forward") + + # Loss + loss = criterion(output, target) + print_memory("After loss") + + # Backward + loss.backward() + print_memory("After backward") + + # Optimizer + optimizer.step() + print_memory("After optimizer step") + + # Zero grad + optimizer.zero_grad() + print_memory("After zero_grad") + + # Final peak + peak_memory = torch.cuda.max_memory_allocated() / 1e9 + print("-" * 80) + print(f"\nPeak memory usage: {peak_memory:.2f} GB") + + # Detailed summary + print("\n" + "=" * 80) + print("DETAILED MEMORY SUMMARY") + print("=" * 80) + print(torch.cuda.memory_summary()) + + # Memory breakdown + print("\n" + "=" * 80) + print("MEMORY OPTIMIZATION SUGGESTIONS") + print("=" * 80) + + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + + if reserved > allocated * 1.5: + print("⚠️ Memory fragmentation detected") + print(f" Reserved: {reserved:.2f} GB, Allocated: {allocated:.2f} GB") + print(" Suggestion: Call torch.cuda.empty_cache() periodically") + + # Estimate memory components + param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9 + print(f"\nModel parameters: {param_memory:.2f} GB") + + # Estimate gradients (same size as parameters) + print(f"Gradients (estimate): {param_memory:.2f} GB") + + # Optimizer states (Adam: 2x parameters) + if isinstance(optimizer, torch.optim.Adam): + optimizer_memory = param_memory * 2 + print(f"Optimizer states (Adam): {optimizer_memory:.2f} GB") + + # Activations (peak - parameters - gradients - optimizer) + activation_memory = peak_memory - param_memory - param_memory - optimizer_memory + if activation_memory > 0: + print(f"Activations (estimate): {activation_memory:.2f} GB") + + if activation_memory > peak_memory * 0.5: + print("\n🎯 Activations dominate memory usage") + print(" Suggestions:") + print(" 1. Use gradient checkpointing") + print(" 2. Reduce batch size") + print(" 3. Use mixed precision (FP16/BF16)") + + return { + 'peak_gb': peak_memory, + 'parameters_gb': param_memory, + 'fragmentation_ratio': reserved / allocated if allocated > 0 else 1.0 + } +``` + + +## Profiling Checklist + +Before claiming you've profiled the code, verify: + +- [ ] **Baseline established** + - [ ] Defined performance metric (throughput/latency/memory) + - [ ] Measured with CUDA Events (not time.time()) + - [ ] Used 5+ warmup iterations + - [ ] Reported statistics (mean, std, percentiles) + - [ ] Documented measurement conditions + +- [ ] **Bottleneck type identified** + - [ ] Checked GPU utilization (nvidia-smi) + - [ ] Profiled data loading vs computation separately + - [ ] Categorized as CPU-bound, GPU-bound, memory-bound, or I/O-bound + - [ ] Verified category with profiling data (not guessing) + +- [ ] **Component identified** + - [ ] Profiled training phases (forward/backward/optimizer) + - [ ] Identified which phase is slowest + - [ ] Used iterative narrowing approach + - [ ] Examined both table and trace view + +- [ ] **Operation identified** + - [ ] Profiled bottleneck component in detail + - [ ] Found specific operation or pattern + - [ ] Understand WHY it's slow (not just WHAT is slow) + - [ ] Have actionable optimization target + +- [ ] **Verification ready** + - [ ] Saved baseline measurements + - [ ] Know how to re-run profiling after optimization + - [ ] Can verify if optimization actually helped + - [ ] Have profiling artifacts (traces, summaries) + + +## References + +**PyTorch Profiling Documentation:** +- torch.profiler: https://pytorch.org/docs/stable/profiler.html +- Profiling recipe: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html +- Performance tuning guide: https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html + +**Related Skills:** +- tensor-operations-and-memory (memory leak debugging, operation optimization) +- mixed-precision-and-optimization (AMP profiling, Tensor Core utilization) +- distributed-training-strategies (multi-GPU profiling) + +**Tools:** +- Chrome tracing: chrome://tracing +- TensorBoard profiler: tensorboard --logdir= +- NVIDIA Nsight Systems: nsys profile python train.py +- PyTorch Memory Visualizer: python -m torch.cuda._memory_viz diff --git a/skills/using-pytorch-engineering/tensor-operations-and-memory.md b/skills/using-pytorch-engineering/tensor-operations-and-memory.md new file mode 100644 index 0000000..8ed40d2 --- /dev/null +++ b/skills/using-pytorch-engineering/tensor-operations-and-memory.md @@ -0,0 +1,1029 @@ + +# Tensor Operations and Memory Management + +## Overview + +**Core Principle:** PyTorch memory issues stem from understanding tensor lifecycle, operation efficiency, and device management. Fix at the operation level, not by adding more RAM. + +Memory leaks, OOM errors, and slow operations are symptoms. Root causes are: gradient retention, inefficient operations, device inconsistency, or Python reference cycles. Systematic diagnosis beats guessing. + +## When to Use + +**Use this skill when:** +- "CUDA out of memory" or GPU memory growing over time +- Training/inference slower than expected on GPU +- "CUDA error: device-side assert triggered" +- Memory usage doesn't decrease after batch/epoch +- Tensor operations causing performance bottlenecks +- Mixed precision training causing crashes + +**Don't use when:** +- Model architecture design (use neural-architectures) +- Training convergence issues (use training-optimization) +- Multi-GPU distributed training strategy (use distributed-training-strategies) + +**Symptoms triggering this skill:** +- "Memory keeps growing each epoch" +- "GPU utilization low but training slow" +- "Random CUDA crashes" +- "Tensor operations taking too long" + + +## Memory Leak Diagnosis Methodology + +### Systematic Debugging Steps + +**1. Identify When Memory Grows** +```python +import torch +import gc + +def diagnose_memory_growth(): + """Track memory at key points""" + print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB") + print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB") + print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") +``` + +**Call after:** +- Each forward pass +- Each backward pass +- Each optimizer step +- Each epoch + +**What to look for:** +- Reserved memory grows → fragmentation or retention +- Allocated grows → tensors not released +- Max allocated grows → peak usage increasing + + +**2. Check Gradient Accumulation** + +```python +# ❌ WRONG: Gradients accumulate indefinitely +for epoch in range(num_epochs): + for batch in dataloader: + outputs = model(batch) + loss = criterion(outputs, targets) + loss.backward() # Gradients ACCUMULATE without clearing! + # Missing: optimizer.zero_grad() + +# ✅ CORRECT: Clear gradients each iteration +for epoch in range(num_epochs): + for batch in dataloader: + optimizer.zero_grad() # or model.zero_grad(set_to_none=True) + outputs = model(batch) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() +``` + +**Why this matters:** Each `.backward()` call adds to existing gradients. Without clearing, memory grows unbounded. + +**set_to_none=True advantage:** More memory efficient than zero-filling. + + +**3. Check Tensor Detachment** + +```python +# ❌ WRONG: Retaining computation graph +train_losses = [] +for batch in dataloader: + loss = compute_loss(batch) + train_losses.append(loss) # Keeps entire graph! + loss.backward() + +# ✅ CORRECT: Detach from graph +train_losses = [] +for batch in dataloader: + loss = compute_loss(batch) + train_losses.append(loss.item()) # .item() extracts scalar, breaks graph + loss.backward() + +# ✅ ALSO CORRECT: Detach explicitly +intermediate_results = [] +for batch in dataloader: + output = model(batch) + intermediate_results.append(output.detach()) # Breaks gradient tracking +``` + +**Why this matters:** Storing tensors with gradients keeps entire computation graph in memory. Use `.item()` for scalars, `.detach()` for tensors you need. + + +**4. Check Hidden State Retention (RNNs/Attention)** + +```python +# ❌ WRONG: Hidden states retain gradients across batches +hidden = None +for batch in dataloader: + output, hidden = rnn(batch, hidden) # hidden retains graph from ALL previous batches! + loss = criterion(output, target) + loss.backward() + +# ✅ CORRECT: Detach hidden state each batch +hidden = None +for batch in dataloader: + if hidden is not None: + hidden = hidden.detach() # Break gradient chain + output, hidden = rnn(batch, hidden) + loss = criterion(output, target) + loss.backward() +``` + +**Why this matters:** RNN/LSTM/GRU hidden states chain gradients across batches, causing unbounded memory growth. + + +**5. Check Evaluation Context** + +```python +# ❌ WRONG: Evaluation builds computation graphs +model.eval() +for batch in val_loader: + outputs = model(batch) # Still tracks gradients! + val_loss = criterion(outputs, targets) + +# ✅ CORRECT: Disable gradient tracking +model.eval() +with torch.no_grad(): # Critical for memory efficiency + for batch in val_loader: + outputs = model(batch) + val_loss = criterion(outputs, targets) + +# ✅ ALSO USEFUL: Inference mode (even more efficient) +with torch.inference_mode(): + for batch in val_loader: + outputs = model(batch) +``` + +**Why this matters:** `torch.no_grad()` prevents graph building. `torch.inference_mode()` additionally disables autograd metadata for maximum efficiency. + + +**6. Check for Python Reference Cycles** + +```python +# ❌ WRONG: Closure captures self +class Trainer: + def __init__(self): + self.callbacks = [] + + def add_callback(self): + # Lambda captures 'self', creates cycle + self.callbacks.append(lambda: self.model.train()) + +# ✅ CORRECT: Use weak references or explicit cleanup +class Trainer: + def __init__(self): + self.callbacks = [] + + def clear_callbacks(self): + self.callbacks.clear() # Explicit cleanup + gc.collect() # Force garbage collection +``` + +**When to check:** If memory not freed after training loop completes. + +**Tools:** +```python +import gc +gc.collect() # Force collection +torch.cuda.empty_cache() # Release cached memory to OS +``` + + +## Efficient Tensor Operations + +### Memory-Efficient Operation Patterns + +**1. In-Place Operations** + +```python +# ❌ WRONG: Creates new tensor each time +x = torch.randn(1000, 1000) +x = x + 1 # Allocates new memory +x = x * 2 # Allocates new memory +x = torch.relu(x) # Allocates new memory + +# ✅ CORRECT: In-place operations +x = torch.randn(1000, 1000) +x += 1 # In-place, no new allocation +x *= 2 # In-place +x.relu_() # In-place (note underscore) + +# ⚠️ CAUTION: Don't use in-place on tensors needing gradients +x.requires_grad = True +x += 1 # ❌ Error! In-place on tensor with gradients +``` + +**When to use:** Loop iterations, activations in eval mode, preprocessing. + +**When NOT to use:** Tensors in computation graph (breaks autograd). + + +**2. Contiguous Tensors** + +```python +# Problem: Non-contiguous tensors slow down operations +x = torch.randn(100, 100) +x_t = x.t() # Transpose creates VIEW, not contiguous + +# Check contiguity +print(x_t.is_contiguous()) # False + +# ❌ SLOW: Operations on non-contiguous tensors +result = x_t + 1 # Has to handle strided memory access + +# ✅ FAST: Make contiguous first +x_t = x_t.contiguous() +result = x_t + 1 # Sequential memory access, much faster +``` + +**Common sources of non-contiguous tensors:** +- `.transpose()`, `.permute()`, `.view()` (sometimes), indexing + +**Rule of thumb:** If doing many operations on a tensor, call `.contiguous()` once upfront. + + +**3. Device Placement Efficiency** + +```python +# ❌ VERY SLOW: Repeated CPU-GPU transfers +for batch in dataloader: + batch = batch.cuda() # Transfer every iteration + output = model(batch) + loss = criterion(output, target.cuda()) # Another transfer! + +# ✅ FAST: Transfer once, keep on GPU +model = model.cuda() +criterion = criterion.cuda() +for batch in dataloader: + batch = batch.cuda() # Only data transfer needed + output = model(batch) + loss = criterion(output, target) + +# ✅ EVEN BETTER: Pin memory for async transfer +dataloader = DataLoader(dataset, pin_memory=True, num_workers=4) +for batch in dataloader: + batch = batch.cuda(non_blocking=True) # Async transfer + output = model(batch) +``` + +**Why this matters:** CPU↔GPU transfers are slow (PCIe bandwidth). Minimize transfers. + + +**4. Broadcasting Awareness** + +```python +# Broadcasting can create large intermediate tensors +x = torch.randn(1000, 1000, 100) # 400 MB +y = torch.randn(100) # 400 bytes + +# ❌ MEMORY INEFFICIENT: Broadcasting creates full tensor +result = x + y # y broadcasts to (1000, 1000, 100) temporarily + +# ✅ MORE EFFICIENT: Explicit broadcasting with memory awareness +result = x.add(y) # Same operation, PyTorch optimizes + +# ✅ BEST: Fused operations when possible +result = torch.addcmul(x, y, value=1.0) # Fused multiply-add +``` + +**Profiling broadcasting:** +```python +import torch.utils.benchmark as benchmark + +t = benchmark.Timer( + stmt='x + y', + globals={'x': x, 'y': y} +) +print(t.timeit(100)) +``` + + +**5. Memory Pooling and Allocation** + +```python +# ❌ WRONG: Allocating inside loop +for epoch in range(100): + for batch in dataloader: + temp = torch.zeros(1024, 1024).cuda() # Allocate every iteration! + result = process(batch, temp) + +# ✅ CORRECT: Pre-allocate reusable buffers +temp_buffer = torch.zeros(1024, 1024).cuda() # Allocate once +for epoch in range(100): + for batch in dataloader: + temp_buffer.zero_() # Reset in-place + result = process(batch, temp_buffer) +``` + +**Why this matters:** Memory allocation/deallocation has overhead. Reuse buffers when size is fixed. + + +## Device Management Best Practices + +### Systematic Device Consistency + +**1. Device Checking Methodology** + +```python +def check_device_consistency(model, data, target): + """Systematic device checking""" + print(f"Model on: {next(model.parameters()).device}") + print(f"Data on: {data.device}") + print(f"Target on: {target.device}") + + # Check all model parameters on same device + devices = {p.device for p in model.parameters()} + if len(devices) > 1: + print(f"⚠️ Model parameters on multiple devices: {devices}") + + # Check all buffers + buffer_devices = {b.device for b in model.buffers()} + if len(buffer_devices) > 1: + print(f"⚠️ Model buffers on multiple devices: {buffer_devices}") + +# Use before training starts +check_device_consistency(model, batch['input'], batch['target']) +``` + +**When to check:** +- After model initialization +- After loading checkpoint +- Before training starts +- When debugging device-side asserts + + +**2. Mixed Precision Context Management** + +```python +from torch.cuda.amp import autocast, GradScaler + +# ❌ WRONG: Inconsistent autocast usage +scaler = GradScaler() +for batch in dataloader: + with autocast(): + output = model(batch) + loss = criterion(output, target) # ❌ Loss computed outside autocast! + scaler.scale(loss).backward() + +# ✅ CORRECT: Consistent autocast context +scaler = GradScaler() +for batch in dataloader: + with autocast(): + output = model(batch) + loss = criterion(output, target) # ✅ Loss inside autocast + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() +``` + +**Critical rules:** +- Forward pass + loss computation in `autocast()` context +- `scaler.scale()` before `.backward()` +- `scaler.step()` instead of `optimizer.step()` +- `scaler.update()` after step + + +**3. Multi-GPU Device Placement** + +```python +# ❌ WRONG: Implicit device assumptions +model = nn.DataParallel(model) # Wraps model +output = model(batch) # Output on device 0, but which one? +loss = criterion(output, target) # ❌ Target might be on wrong device! + +# ✅ CORRECT: Explicit device management +device = torch.device("cuda:0") +model = nn.DataParallel(model).to(device) +for batch in dataloader: + batch = batch.to(device) + target = target.to(device) + output = model(batch) # Output on device 0 + loss = criterion(output, target) # All on same device +``` + +**Device placement hierarchy:** +1. Pin device at start: `device = torch.device("cuda:0")` +2. Move model once: `model.to(device)` +3. Move data each batch: `batch.to(device)` + + +## Performance Profiling + +### Identifying Bottlenecks + +**1. Memory Profiling** + +```python +import torch.cuda + +# Profile memory usage +torch.cuda.reset_peak_memory_stats() + +# Run operation +output = model(batch) +loss = criterion(output, target) +loss.backward() + +# Check memory stats +print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") +print(f"Current memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB") + +# Memory summary (detailed) +print(torch.cuda.memory_summary()) +``` + +**2. Operation Profiling** + +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + output = model(batch) + loss = criterion(output, target) + loss.backward() + +# Print top operations by time +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + +# Export for visualization +prof.export_chrome_trace("trace.json") # View in chrome://tracing +``` + +**What to look for:** +- Operations taking >10% of time +- Unexpected memory allocations +- CPU-GPU synchronization overhead + + +**3. Memory Snapshot (PyTorch 2.0+)** + +```python +import torch.cuda + +# Record memory snapshots +torch.cuda.memory._record_memory_history() + +# Run training iteration +output = model(batch) +loss = criterion(output, target) +loss.backward() + +# Save snapshot +torch.cuda.memory._dump_snapshot("memory_snapshot.pickle") +torch.cuda.memory._record_memory_history(enabled=None) + +# Analyze with: +# python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle +``` + + +## Red Flags - Stop and Diagnose + +**If you catch yourself thinking ANY of these, STOP and follow systematic methodology:** + +| Red Flag Thought | Reality | What to Do Instead | +|------------------|---------|-------------------| +| "I'll just reduce batch size" | Avoids root cause, wastes GPU capacity | Follow memory leak diagnosis first | +| "I'll add more GPU memory" | Expensive non-solution if there's a leak | Diagnose leak, don't throw hardware at it | +| "Memory leaks are normal in PyTorch" | FALSE - PyTorch has excellent memory management | There IS a bug in your code, find it | +| "Too complex to debug, I'll refactor" | Avoidance - same bug will appear in new code | Debug systematically, learn the issue | +| "Skip profiling, I know what's slow" | Guessing wastes time, profiling gives facts | Always profile before optimizing | +| "Mixed precision is broken" | AMP works when used correctly | Check autocast context boundaries | +| "Quick fix: just call empty_cache()" | Doesn't fix leaks, just masks symptoms | Find and fix the leak | + +**Critical rule:** Memory issues have root causes. Systematic diagnosis ALWAYS faster than guessing. + + +## Common Rationalizations (Don't Do These) + +| Excuse | What Really Happens | Correct Approach | +|--------|-------------------|------------------| +| "User seems rushed, skip methodology" | Guessing wastes MORE time than systematic diagnosis | 5 minutes of diagnosis saves hours of guessing | +| "I already tried profiling" | May have looked at wrong metrics or misinterpreted | Re-profile with specific focus from methodology | +| "This worked on smaller model" | Scaling exposes hidden issues | Same methodology applies, just reveals different bugs | +| "Documentation says to do X" | May be misunderstanding context or outdated | Check PyTorch version, verify applicability | +| "I'll optimize later" | Memory issues prevent finishing now | Fix memory first, then optimize if still needed | +| "It's a CUDA bug" | 99.9% of time it's your code | Assume your bug until proven otherwise | + + +## Common Pitfalls + +### Consolidated Pitfall Table + +| # | Pitfall | Symptom | Root Cause | Fix | +|---|---------|---------|------------|-----| +| 1 | Accumulating metrics without detachment | Memory grows linearly with iterations | Storing tensors retains computation graph | Use `.item()` for scalars, `.detach()` for tensors | +| 2 | Hidden state chaining (RNNs) | Memory grows across batches | Hidden states chain gradients indefinitely | Detach hidden states between batches | +| 3 | Missing `torch.no_grad()` in eval | High memory usage during validation | Evaluation builds unnecessary graphs | Wrap evaluation in `torch.no_grad()` | +| 4 | Repeated CPU-GPU transfers | Low GPU utilization, slow training | PCIe bandwidth bottleneck | Move to GPU once, keep there | +| 5 | Non-contiguous tensor operations | Unexpectedly slow operations | Strided memory access inefficiency | Call `.contiguous()` before repeated ops | +| 6 | Allocations in loops | Slow iterations, fragmentation | Memory allocation overhead | Pre-allocate and reuse buffers | +| 7 | Gradient accumulation without clearing | OOM after few iterations | Gradients accumulate unbounded | `optimizer.zero_grad()` every iteration | +| 8 | Mixed precision context boundaries | Intermittent crashes, NaN values | Loss computed outside autocast | Keep forward + loss inside `autocast()` | +| 9 | Device inconsistency | "device-side assert" errors | Tensors on different devices | Systematic device checking | +| 10 | Logging with tensors instead of scalars | Memory growth during training | Retaining graphs for logging | Always use `.item()` for logging | + + +### Memory Leak Pitfalls + +❌ **Pitfall 1: Accumulating Metrics Without Detachment** +```python +# WRONG +losses = [] +for batch in dataloader: + loss = criterion(output, target) + losses.append(loss) # Retains graph! + +# CORRECT +losses = [] +for batch in dataloader: + loss = criterion(output, target) + losses.append(loss.item()) # Extract scalar +``` +**Symptom:** Memory grows linearly with iterations +**Fix:** Use `.item()` or `.detach()` before storing + + +❌ **Pitfall 2: Hidden State Chaining (RNNs)** +```python +# WRONG +hidden = None +for batch in dataloader: + output, hidden = lstm(batch, hidden) + +# CORRECT +hidden = None +for batch in dataloader: + if hidden is not None: + hidden = tuple(h.detach() for h in hidden) # LSTM returns tuple + output, hidden = lstm(batch, hidden) +``` +**Symptom:** Memory grows across batches in RNN training +**Fix:** Detach hidden states between batches + + +❌ **Pitfall 3: Missing torch.no_grad() in Evaluation** +```python +# WRONG +model.eval() +for batch in val_loader: + output = model(batch) # Builds graph! + +# CORRECT +model.eval() +with torch.no_grad(): + for batch in val_loader: + output = model(batch) +``` +**Symptom:** High memory usage during evaluation +**Fix:** Always wrap evaluation in `torch.no_grad()` + + +### Performance Pitfalls + +❌ **Pitfall 4: Repeated CPU-GPU Transfers** +```python +# WRONG +for batch in dataloader: + batch = batch.cpu().numpy() # Transfer to CPU + batch = preprocess(batch) # Process on CPU + batch = torch.from_numpy(batch).cuda() # Back to GPU + +# CORRECT +for batch in dataloader: + batch = batch.cuda() + batch = preprocess_gpu(batch) # Keep on GPU +``` +**Symptom:** Low GPU utilization, slow training +**Fix:** Minimize CPU↔GPU transfers, use GPU operations + + +❌ **Pitfall 5: Non-Contiguous Tensor Operations** +```python +# WRONG +x = x.transpose(0, 1) # Creates view +for _ in range(1000): + x = x + 1 # Slow on non-contiguous tensor + +# CORRECT +x = x.transpose(0, 1).contiguous() # Make contiguous +for _ in range(1000): + x = x + 1 # Fast on contiguous tensor +``` +**Symptom:** Unexpectedly slow operations +**Fix:** Call `.contiguous()` before repeated operations + + +❌ **Pitfall 6: Unnecessary Memory Allocations in Loops** +```python +# WRONG +for _ in range(1000): + temp = torch.zeros(1024, 1024).cuda() + result = process(temp) + +# CORRECT +temp = torch.zeros(1024, 1024).cuda() +for _ in range(1000): + temp.zero_() # Reuse buffer + result = process(temp) +``` +**Symptom:** Slow iteration, memory fragmentation +**Fix:** Pre-allocate and reuse buffers + + +## Debugging Methodology + +### When You Get CUDA OOM + +**Step 1: Get current memory state** +```python +print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB") +print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB") +print(torch.cuda.memory_summary()) +``` + +**Step 2: Check the obvious** +- [ ] `optimizer.zero_grad()` called every iteration? +- [ ] Evaluation wrapped in `torch.no_grad()`? +- [ ] Storing tensors instead of `.item()`? +- [ ] RNN hidden states detached? + +**Step 3: Binary search for leak** +```python +# Gradually disable parts of training loop +# 1. Just forward pass +# 2. Forward + backward +# 3. Forward + backward + optimizer step +# Find where memory grows +``` + +**Step 4: Profile memory** +```python +# Use memory profiler to find exact allocation +torch.cuda.reset_peak_memory_stats() +# Run one iteration +print(torch.cuda.memory_summary()) +``` + +**Step 5: Check for fragmentation** +```python +# If reserved >> allocated, you have fragmentation +allocated = torch.cuda.memory_allocated() +reserved = torch.cuda.memory_reserved() +if reserved > allocated * 1.5: + print("⚠️ Memory fragmentation detected") + torch.cuda.empty_cache() # May help, but not guaranteed +``` + + +### When Training is Slow + +**Step 1: Profile CUDA time** +```python +import torch.utils.benchmark as benchmark + +# Profile one iteration +def train_step(): + output = model(batch) + loss = criterion(output, target) + loss.backward() + optimizer.step() + +t = benchmark.Timer(stmt='train_step()', globals=globals()) +print(t.timeit(10)) +``` + +**Step 2: Check GPU utilization** +```bash +# In terminal +nvidia-smi -l 1 # Update every second +``` +If GPU utilization < 80%, bottleneck is likely: +- Data loading (use more `num_workers`) +- CPU preprocessing (move to GPU) +- CPU-GPU transfers (pin memory) + +**Step 3: Profile operations** +```python +# Use PyTorch profiler to identify slow operations +with torch.profiler.profile() as prof: + train_step() +print(prof.key_averages().table(sort_by="cuda_time_total")) +``` + +**Step 4: Check for synchronization** +```python +# Frequent CPU-GPU synchronization kills performance +# Common causes: +# - .item() calls in training loop +# - .cpu() calls +# - print() statements with tensor values +# - Assertions on tensor values +``` + + +## Edge Cases and Advanced Scenarios + +### Edge Case 1: Gradient Checkpointing Interaction + +**Scenario:** Using gradient checkpointing for large models but still getting OOM + +```python +from torch.utils.checkpoint import checkpoint + +# Gradient checkpointing trades compute for memory +# But you can still leak memory! + +# ❌ WRONG: Leaking even with checkpointing +class Model(nn.Module): + def forward(self, x): + # Checkpointing helps, but if you retain intermediate results... + intermediate = checkpoint(self.layer1, x) + self.cached_intermediate = intermediate # ❌ Retains graph! + return checkpoint(self.layer2, intermediate) + +# ✅ CORRECT: Don't cache checkpointed results +class Model(nn.Module): + def forward(self, x): + intermediate = checkpoint(self.layer1, x) + return checkpoint(self.layer2, intermediate) + # No caching, memory saved +``` + +**Key insight:** Gradient checkpointing recomputes forward pass during backward. Caching checkpointed results defeats the purpose and leaks memory. + + +### Edge Case 2: Dynamic Computation Graphs + +**Scenario:** Graph structure changes each iteration (e.g., different sequence lengths, variable branches) + +```python +# Dynamic graphs can cause memory issues if not careful + +# ❌ WRONG: Accumulating different graphs +graph_stats = [] +for batch in dataloader: + # Sequence length varies each batch + output = model(batch) # Different graph each time + graph_stats.append(output.grad_fn) # ❌ Retains ALL graphs! + +# ✅ CORRECT: Don't retain grad_fn, detach appropriately +for batch in dataloader: + output = model(batch) + loss = criterion(output, target) + loss.backward() + # Don't store anything with .grad_fn + optimizer.step() + optimizer.zero_grad() +``` + +**Key insight:** Dynamic graphs are fine, but don't accumulate references to different graphs across iterations. + + +### Edge Case 3: DistributedDataParallel (DDP) Memory Management + +**Scenario:** DDP training with memory issues + +```python +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +# ❌ WRONG: Multiple issues with DDP +model = MyModel().cuda() +model = DDP(model) # Missing device_ids! + +for batch in dataloader: + output = model(batch) + loss = criterion(output, target) + loss.backward() + # Missing: optimizer.zero_grad() BEFORE backward in some DDP scenarios + +# ✅ CORRECT: Proper DDP setup +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") + +model = MyModel().to(device) +model = DDP(model, device_ids=[local_rank], output_device=local_rank) + +for batch in dataloader: + batch = batch.to(device) + target = target.to(device) + + optimizer.zero_grad(set_to_none=True) # Before forward in DDP + output = model(batch) + loss = criterion(output, target) + loss.backward() + optimizer.step() +``` + +**Key insights:** +- Specify `device_ids` and `output_device` explicitly +- `zero_grad()` placement critical in DDP +- Ensure all data on correct local device + + +### Edge Case 4: Custom CUDA Kernels Memory Management + +**Scenario:** Using custom CUDA kernels or third-party extensions + +```python +# Custom CUDA kernels may not play nice with PyTorch's memory management + +import custom_cuda_kernel # Hypothetical extension + +# ❌ WRONG: Not checking tensor lifetime +def forward(x): + y = custom_cuda_kernel.process(x) # Allocates CUDA memory + # If kernel doesn't register with PyTorch, memory not tracked! + return y + +# ✅ CORRECT: Verify kernel registers memory with PyTorch +def forward(x): + y = custom_cuda_kernel.process(x) + + # Check if memory is tracked + print(f"PyTorch tracked: {torch.cuda.memory_allocated()}") + # If custom kernel allocated memory not shown, manual cleanup needed + + return y +``` + +**Key insight:** Custom CUDA code may bypass PyTorch's memory tracking. Use `torch.cuda.memory_allocated()` to verify, and ensure custom kernels use PyTorch's allocator. + + +### Edge Case 5: Nested Autocast Contexts + +**Scenario:** Nested autocast contexts (e.g., custom training loop with autocast, calling library that also uses autocast) + +```python +from torch.cuda.amp import autocast + +# ❌ POTENTIAL ISSUE: Nested autocast with different settings +with autocast(): # Outer context + output1 = model1(x) + with autocast(enabled=False): # Inner context disables + output2 = model2(output1) # Back to float32 + # output1 is float16, output2 is float32 + loss = criterion(output1, output2) # Type mismatch possible! + +# ✅ CORRECT: Be aware of autocast nesting +with autocast(): + output1 = model1(x) + # If you need float32 for specific operation: + with autocast(enabled=False): + output2_float32 = model2(output1.float()) # Explicit cast + loss = criterion(output1, output2_float32.half()) # Explicit cast back +``` + +**Key insight:** Autocast contexts can nest. Be explicit about dtype when mixing precision contexts. + + +### Edge Case 6: Memory Fragmentation with Varying Batch Sizes + +**Scenario:** Training with variable batch sizes causing fragmentation + +```python +# Variable batch sizes can fragment memory over time + +# ❌ WRONG: Varying allocations fragment memory pool +for batch in dataloader: # Batch sizes: 32, 64, 32, 128, 32... + output = model(batch) # Different allocations each time + # CUDA memory becomes fragmented + # Reserved >> Allocated + +# ✅ BETTER: Use gradient accumulation with fixed effective batch size +accumulation_steps = 4 +effective_batch_size = 32 + +for i, batch in enumerate(dataloader): + # Always process fixed size mini-batches + mini_batch = batch[:effective_batch_size] # Fixed size + output = model(mini_batch) + loss = criterion(output, target) / accumulation_steps + loss.backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# ✅ OR: Periodically defragment +if epoch % 10 == 0: + torch.cuda.empty_cache() # Release fragmented memory + gc.collect() +``` + +**Key insight:** Variable batch sizes fragment CUDA memory pool. Use fixed sizes or periodic cleanup. + + +## Quick Reference: Memory & Performance Checklist + +### Before Training Starts +- [ ] Model on correct device +- [ ] All data moved to device once +- [ ] Pin memory enabled in DataLoader +- [ ] Gradient checkpointing if needed for large models + +### Training Loop Must-Haves +- [ ] `optimizer.zero_grad()` at start of iteration +- [ ] Only store `.item()` for scalars, not tensors +- [ ] Detach RNN/LSTM hidden states between batches +- [ ] Use `torch.no_grad()` for validation + +### Performance Optimization +- [ ] Pre-allocate buffers for fixed-size tensors +- [ ] Call `.contiguous()` before repeated operations on views +- [ ] Use in-place operations where safe +- [ ] Minimize CPU-GPU transfers +- [ ] Use mixed precision (`autocast`) if appropriate + +### When Debugging +- [ ] Check memory stats: `torch.cuda.memory_allocated()` +- [ ] Profile with `torch.profiler` +- [ ] Check GPU utilization: `nvidia-smi` +- [ ] Look for non-contiguous tensors: `.is_contiguous()` +- [ ] Check device consistency across all tensors + + +## Example: Complete Memory-Efficient Training Loop + +```python +import torch +import torch.nn as nn +from torch.cuda.amp import autocast, GradScaler + +# Setup +device = torch.device("cuda:0") +model = MyModel().to(device) +optimizer = torch.optim.Adam(model.parameters()) +scaler = GradScaler() # For mixed precision +dataloader = DataLoader(dataset, pin_memory=True, num_workers=4) + +# Pre-allocate if using fixed-size buffers (example) +# temp_buffer = torch.zeros(batch_size, hidden_dim).to(device) + +# Training loop +model.train() +for epoch in range(num_epochs): + for batch_idx, (data, target) in enumerate(dataloader): + # 1. Move data to device (async with pin_memory) + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # 2. Zero gradients (set_to_none=True for efficiency) + optimizer.zero_grad(set_to_none=True) + + # 3. Forward pass with mixed precision + with autocast(): + output = model(data) + loss = criterion(output, target) + + # 4. Backward pass with gradient scaling + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + # 5. Logging (use .item() to avoid retaining graph!) + if batch_idx % 100 == 0: + print(f"Loss: {loss.item():.4f}") + + # 6. Validation (critical: no_grad context) + model.eval() + val_loss = 0 + with torch.no_grad(): + for data, target in val_loader: + data = data.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(data) + val_loss += criterion(output, target).item() + + model.train() + print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader):.4f}") + + # Optional: Check memory usage + if epoch % 10 == 0: + print(f"Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB") +``` + +**Why this is memory-efficient:** +1. ✅ Async data transfer with `pin_memory` and `non_blocking=True` +2. ✅ `set_to_none=True` for gradient zeroing (more efficient) +3. ✅ Mixed precision with `autocast` (reduces memory) +4. ✅ Only `.item()` stored for logging (no graph retention) +5. ✅ Validation wrapped in `torch.no_grad()` +6. ✅ All tensors on same device (no implicit transfers) + + +## References + +**PyTorch Documentation:** +- Memory Management: https://pytorch.org/docs/stable/notes/cuda.html +- Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html +- Mixed Precision: https://pytorch.org/docs/stable/amp.html + +**Related Skills:** +- performance-profiling (deeper profiling techniques) +- distributed-training-strategies (multi-GPU memory management) +- mixed-precision-and-optimization (detailed autocast usage) +- debugging-techniques (systematic PyTorch debugging)