From c1039300101902481e12c5a4d637d0bfd1823078 Mon Sep 17 00:00:00 2001 From: Zhongwei Li Date: Sun, 30 Nov 2025 09:00:11 +0800 Subject: [PATCH] Initial commit --- .claude-plugin/plugin.json | 12 + README.md | 3 + plugin.lock.json | 85 + skills/using-training-optimization/SKILL.md | 493 +++ .../batch-size-and-memory-tradeoffs.md | 1651 ++++++++++ .../data-augmentation-strategies.md | 1483 +++++++++ .../experiment-tracking.md | 1942 ++++++++++++ .../gradient-management.md | 2442 +++++++++++++++ .../hyperparameter-tuning.md | 1635 ++++++++++ .../learning-rate-scheduling.md | 2723 +++++++++++++++++ .../loss-functions-and-objectives.md | 2138 +++++++++++++ .../optimization-algorithms.md | 1832 +++++++++++ .../overfitting-prevention.md | 1464 +++++++++ .../training-loop-architecture.md | 882 ++++++ 14 files changed, 18785 insertions(+) create mode 100644 .claude-plugin/plugin.json create mode 100644 README.md create mode 100644 plugin.lock.json create mode 100644 skills/using-training-optimization/SKILL.md create mode 100644 skills/using-training-optimization/batch-size-and-memory-tradeoffs.md create mode 100644 skills/using-training-optimization/data-augmentation-strategies.md create mode 100644 skills/using-training-optimization/experiment-tracking.md create mode 100644 skills/using-training-optimization/gradient-management.md create mode 100644 skills/using-training-optimization/hyperparameter-tuning.md create mode 100644 skills/using-training-optimization/learning-rate-scheduling.md create mode 100644 skills/using-training-optimization/loss-functions-and-objectives.md create mode 100644 skills/using-training-optimization/optimization-algorithms.md create mode 100644 skills/using-training-optimization/overfitting-prevention.md create mode 100644 skills/using-training-optimization/training-loop-architecture.md diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json new file mode 100644 index 0000000..2722ab8 --- /dev/null +++ b/.claude-plugin/plugin.json @@ -0,0 +1,12 @@ +{ + "name": "yzmir-training-optimization", + "description": "Training stability - optimizers, learning rates, convergence, debugging - 11 skills", + "version": "1.0.1", + "author": { + "name": "tachyon-beep", + "url": "https://github.com/tachyon-beep" + }, + "skills": [ + "./skills" + ] +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..40c6a3c --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# yzmir-training-optimization + +Training stability - optimizers, learning rates, convergence, debugging - 11 skills diff --git a/plugin.lock.json b/plugin.lock.json new file mode 100644 index 0000000..2d0b216 --- /dev/null +++ b/plugin.lock.json @@ -0,0 +1,85 @@ +{ + "$schema": "internal://schemas/plugin.lock.v1.json", + "pluginId": "gh:tachyon-beep/skillpacks:plugins/yzmir-training-optimization", + "normalized": { + "repo": null, + "ref": "refs/tags/v20251128.0", + "commit": "1a6b6c440fc2093056207eab655a4e9dd2ec4c4a", + "treeHash": "9d0b9b47ce83e27e7b927eb0c701097a264827013ffe91f0f3bc10b7e63a5904", + "generatedAt": "2025-11-28T10:28:35.048054Z", + "toolVersion": "publish_plugins.py@0.2.0" + }, + "origin": { + "remote": "git@github.com:zhongweili/42plugin-data.git", + "branch": "master", + "commit": "aa1497ed0949fd50e99e70d6324a29c5b34f9390", + "repoRoot": "/Users/zhongweili/projects/openmind/42plugin-data" + }, + "manifest": { + "name": "yzmir-training-optimization", + "description": "Training stability - optimizers, learning rates, convergence, debugging - 11 skills", + "version": "1.0.1" + }, + "content": { + "files": [ + { + "path": "README.md", + "sha256": "62fadbf1552b3992391990d6217a2bdaacb0ae8ef43496cfe45b6ef294d445f7" + }, + { + "path": ".claude-plugin/plugin.json", + "sha256": "ab9e676b0f5af41d7375a19ab91f6b20b485c83a63871b346d74ff9b34105bc9" + }, + { + "path": "skills/using-training-optimization/gradient-management.md", + "sha256": "067d4b2e5a007fc0b9d002e8c7b250a910bdd0044872bf5e0894cc301ed52147" + }, + { + "path": "skills/using-training-optimization/hyperparameter-tuning.md", + "sha256": "bc6e5972545fb3f60f64732f3e326c8f4e28e490d46f3258fbf8aa09dca04e5c" + }, + { + "path": "skills/using-training-optimization/optimization-algorithms.md", + "sha256": "4fd6a5c71e5b5d429c47ae3e40fa212aff4dffb8fa4fffa6811030feba186ba3" + }, + { + "path": "skills/using-training-optimization/overfitting-prevention.md", + "sha256": "eaebe6e390799bdbd1023a73ba60df98a421cb3cd909e3047b3dcfde6b5643db" + }, + { + "path": "skills/using-training-optimization/learning-rate-scheduling.md", + "sha256": "dc940fcd9516036cab41bac67731ae220e3741055d7684702987bf1b23e2db98" + }, + { + "path": "skills/using-training-optimization/data-augmentation-strategies.md", + "sha256": "b585581efe16abdaa359d8136c4e546d3283635857b8b031fa863e58ad079b6e" + }, + { + "path": "skills/using-training-optimization/batch-size-and-memory-tradeoffs.md", + "sha256": "feeb5b24feb77c4c816ab6442ce69e35fe0352ab6f209bda9fd82c278765f5f2" + }, + { + "path": "skills/using-training-optimization/loss-functions-and-objectives.md", + "sha256": "6e953263309eed649fbf8e4aad5cbcc5161507158e20f8d67532b576837deb7b" + }, + { + "path": "skills/using-training-optimization/SKILL.md", + "sha256": "e171600f2a1c9c126b1c4677118ef566e17a3181227993fac4ee901996f72964" + }, + { + "path": "skills/using-training-optimization/training-loop-architecture.md", + "sha256": "ea96698c4a8f21d77dc559728b669a78a1d9a04fd73c5a73ad7fa4ad97259f62" + }, + { + "path": "skills/using-training-optimization/experiment-tracking.md", + "sha256": "90c1626f4fff358148da196d1a0d7da5c915de5e79b795611b9b19906358a049" + } + ], + "dirSha256": "9d0b9b47ce83e27e7b927eb0c701097a264827013ffe91f0f3bc10b7e63a5904" + }, + "security": { + "scannedAt": null, + "scannerVersion": null, + "flags": [] + } +} \ No newline at end of file diff --git a/skills/using-training-optimization/SKILL.md b/skills/using-training-optimization/SKILL.md new file mode 100644 index 0000000..669e658 --- /dev/null +++ b/skills/using-training-optimization/SKILL.md @@ -0,0 +1,493 @@ +--- +name: using-training-optimization +description: Router to training optimization skills based on symptoms and training problems +mode: true +--- + +# Using Training Optimization + +## Overview + +This meta-skill routes you to the right training optimization specialist based on symptoms. Training issues often have multiple potential causes—this skill helps diagnose symptoms and route to the appropriate specialist. Load this skill when you encounter training problems but aren't sure which specific technique to apply. + +**Core Principle**: Diagnose before routing. Training issues often have multiple causes. Ask clarifying questions to understand symptoms before routing to specific skills. Wrong diagnosis wastes time—systematic routing saves it. + +## When to Use + +Load this skill when: +- Model not learning (loss stuck, not decreasing, poor accuracy) +- Training instability (loss spikes, NaN values, divergence) +- Overfitting (large train/val gap, poor generalization) +- Training too slow (throughput issues, time constraints) +- Hyperparameter selection (optimizer, learning rate, batch size, regularization) +- Experiment management (tracking runs, comparing configurations) +- Convergence issues (slow learning, plateaus, local minima) +- Setting up new training pipeline + +**Don't use for**: PyTorch implementation bugs (use pytorch-engineering), model architecture selection (use neural-architectures), production deployment (use ml-production), RL-specific training (use deep-rl), LLM fine-tuning specifics (use llm-specialist) + +--- + +## Routing by Primary Symptom + +### Symptom: "Model Not Learning" / "Loss Not Decreasing" + +**Keywords**: stuck, flat loss, not improving, not learning, accuracy not increasing + +**Diagnostic questions (ask BEFORE routing):** +1. "Is loss completely flat from the start, decreasing very slowly, or was it learning then stopped?" +2. "Any NaN or Inf values in your loss?" +3. "What optimizer and learning rate are you using?" +4. "What does your loss curve look like over time?" + +**Route based on answers:** + +| Loss Behavior | Likely Cause | Route To | Why | +|---------------|--------------|----------|-----| +| Flat from epoch 0 | LR too low OR wrong optimizer OR inappropriate loss | **learning-rate-scheduling** + **optimization-algorithms** | Need to diagnose starting conditions | +| Was learning, then plateaued | Local minima OR LR too high OR overfitting | **learning-rate-scheduling** (scheduler) + check validation loss for overfitting | Adaptation needed during training | +| Oscillating wildly | LR too high OR gradient instability | **learning-rate-scheduling** + **gradient-management** | Instability issues | +| NaN or Inf | Gradient explosion OR numerical instability | **gradient-management** (PRIMARY) + **loss-functions-and-objectives** | Stability critical | +| Loss function doesn't match task | Wrong objective | **loss-functions-and-objectives** | Fundamental mismatch | + +**Multi-skill routing**: Often need **optimization-algorithms** (choose optimizer) + **learning-rate-scheduling** (choose LR/schedule) together for "not learning" issues. + +--- + +### Symptom: "Training Unstable" / "Loss Spikes" / "NaN Values" + +**Keywords**: NaN, Inf, exploding, diverging, unstable, spikes + +**Diagnostic questions:** +1. "When do NaN values appear - immediately at start, or after N epochs?" +2. "What's your learning rate and schedule?" +3. "Are you using mixed precision training?" +4. "What loss function are you using?" + +**Route to (in priority order):** + +1. **gradient-management** (PRIMARY) + - When: Gradient explosion, NaN gradients + - Why: Gradient issues cause instability. Must stabilize gradients first. + - Techniques: Gradient clipping, gradient scaling, NaN debugging + +2. **learning-rate-scheduling** (SECONDARY) + - When: LR too high causes instability + - Why: High LR can cause divergence even with clipped gradients + - Check: If NaN appears later in training, LR schedule might be increasing too much + +3. **loss-functions-and-objectives** (if numerical issues) + - When: Loss computation has numerical instability (log(0), division by zero) + - Why: Numerical instability in loss propagates to gradients + - Check: Custom loss functions especially prone to this + +**Cross-pack note**: If using mixed precision (AMP), also check pytorch-engineering (mixed-precision-and-optimization) for gradient scaling issues. + +--- + +### Symptom: "Model Overfits" / "Train/Val Gap Large" + +**Keywords**: overfitting, train/val gap, poor generalization, memorizing training data + +**Diagnostic questions:** +1. "How large is your dataset (number of training examples)?" +2. "What's the train accuracy vs validation accuracy?" +3. "What regularization are you currently using?" +4. "What's your model size (parameters) relative to dataset size?" + +**Route to (multi-skill approach):** + +1. **overfitting-prevention** (PRIMARY) + - Techniques: Dropout, weight decay, early stopping, L1/L2 regularization + - When: Always the first stop for overfitting + - Why: Comprehensive regularization strategy + +2. **data-augmentation-strategies** (HIGHLY RECOMMENDED) + - When: Dataset is small (< 10K examples) or moderate (10K-100K) + - Why: Increases effective dataset size, teaches invariances + - Priority: Higher priority for smaller datasets + +3. **hyperparameter-tuning** + - When: Need to find optimal regularization strength + - Why: Balance between underfitting and overfitting + - Use: After implementing regularization techniques + +**Decision factors:** +- Small dataset (< 1K): data-augmentation is CRITICAL + overfitting-prevention +- Medium dataset (1K-10K): overfitting-prevention + data-augmentation +- Large dataset (> 10K) but still overfitting: overfitting-prevention + check model capacity +- Model too large: Consider neural-architectures for model size discussion + +--- + +### Symptom: "Training Too Slow" / "Low Throughput" + +**Keywords**: slow training, low GPU utilization, takes too long, time per epoch + +**CRITICAL: Diagnose bottleneck before routing** + +**Diagnostic questions:** +1. "Is it slow per-step (low throughput) or just need many steps?" +2. "What's your GPU utilization percentage?" +3. "Are you using data augmentation? How heavy?" +4. "What's your current batch size?" + +**Route based on bottleneck:** + +| GPU Utilization | Likely Cause | Route To | Why | +|-----------------|--------------|----------|-----| +| < 50% consistently | Data loading bottleneck OR CPU preprocessing | **pytorch-engineering** (data loading, profiling) | Not compute-bound, infrastructure issue | +| High (> 80%) but still slow | Batch size too small OR need distributed training | **batch-size-and-memory-tradeoffs** + possibly pytorch-engineering (distributed) | Compute-bound, need scaling | +| High + heavy augmentation | Augmentation overhead | **data-augmentation-strategies** (optimization) + pytorch-engineering (profiling) | Augmentation CPU cost | +| Memory-limited batch size | Can't increase batch size due to OOM | **batch-size-and-memory-tradeoffs** (gradient accumulation) + pytorch-engineering (memory) | Memory constraints limiting throughput | + +**Cross-pack boundaries:** +- Data loading issues → **pytorch-engineering** (DataLoader, prefetching, num_workers) +- Distributed training setup → **pytorch-engineering** (distributed-training-strategies) +- Batch size optimization for speed/memory → **training-optimization** (batch-size-and-memory-tradeoffs) +- Profiling to identify bottleneck → **pytorch-engineering** (performance-profiling) + +**Key principle**: Profile FIRST before optimizing. Low GPU utilization = wrong optimization target. + +--- + +### Symptom: "Which X Should I Use?" (Direct Questions) + +**Direct hyperparameter questions route to specific skills:** + +| Question | Route To | Examples | +|----------|----------|----------| +| "Which optimizer?" | **optimization-algorithms** | SGD vs Adam vs AdamW, momentum, betas | +| "Which learning rate?" | **learning-rate-scheduling** | Initial LR, warmup, schedule type | +| "Which batch size?" | **batch-size-and-memory-tradeoffs** | Batch size effects on convergence and speed | +| "Which loss function?" | **loss-functions-and-objectives** | Cross-entropy vs focal vs custom | +| "How to prevent overfitting?" | **overfitting-prevention** | Dropout, weight decay, early stopping | +| "Which augmentation?" | **data-augmentation-strategies** | Type and strength of augmentation | +| "How to tune hyperparameters?" | **hyperparameter-tuning** | Search strategies, AutoML | +| "How to track experiments?" | **experiment-tracking** | MLflow, W&B, TensorBoard | + +**For new project setup, route to MULTIPLE in sequence:** +1. **optimization-algorithms** - Choose optimizer +2. **learning-rate-scheduling** - Choose initial LR and schedule +3. **batch-size-and-memory-tradeoffs** - Choose batch size +4. **experiment-tracking** - Set up tracking +5. **training-loop-architecture** - Design training loop + +--- + +### Symptom: "Need to Track Experiments" / "Compare Configurations" + +**Keywords**: experiment tracking, MLflow, wandb, tensorboard, compare runs, log metrics + +**Route to:** + +1. **experiment-tracking** (PRIMARY) + - Tools: MLflow, Weights & Biases, TensorBoard, Neptune + - When: Setting up tracking, comparing runs, organizing experiments + - Why: Systematic experiment management + +2. **hyperparameter-tuning** (if systematic search) + - When: Running many configurations systematically + - Why: Automated hyperparameter search integrates with tracking + - Combined: Hyperparameter search + experiment tracking workflow + +3. **training-loop-architecture** (for integration) + - When: Need to integrate tracking into training loop + - Why: Proper callback and logging design + - Combined: Training loop + tracking integration + +--- + +## Cross-Cutting Multi-Skill Scenarios + +### Scenario: New Training Setup (First Time) + +**Route to (in order):** +1. **optimization-algorithms** - Select optimizer for your task +2. **learning-rate-scheduling** - Choose LR and warmup strategy +3. **batch-size-and-memory-tradeoffs** - Determine batch size +4. **loss-functions-and-objectives** - Verify loss function appropriate for task +5. **experiment-tracking** - Set up experiment logging +6. **training-loop-architecture** - Design training loop with checkpointing + +**Why this order**: Foundation (optimizer/LR/batch) → Objective (loss) → Infrastructure (tracking/loop) + +--- + +### Scenario: Convergence Issues + +**Route to (diagnose first, then in order):** +1. **gradient-management** - Check gradients not vanishing/exploding (use gradient checking) +2. **learning-rate-scheduling** - Adjust LR schedule (might be too high/low/wrong schedule) +3. **optimization-algorithms** - Consider different optimizer if current one unsuitable + +**Why this order**: Stability (gradients) → Adaptation (LR) → Algorithm (optimizer) + +**Common mistakes:** +- Jumping to change optimizer without checking gradients +- Blaming optimizer when LR is the issue +- Not using gradient monitoring to diagnose + +--- + +### Scenario: Overfitting Issues + +**Route to (multi-pronged approach):** +1. **overfitting-prevention** - Implement regularization (dropout, weight decay, early stopping) +2. **data-augmentation-strategies** - Increase effective dataset size +3. **hyperparameter-tuning** - Find optimal regularization strength + +**Why all three**: Overfitting needs comprehensive strategy, not single technique. + +**Prioritization:** +- Small dataset (< 1K): data-augmentation is MOST critical +- Medium dataset (1K-10K): overfitting-prevention + augmentation balanced +- Large dataset but still overfitting: overfitting-prevention + check model size + +--- + +### Scenario: Training Speed + Memory Constraints + +**Route to:** +1. **batch-size-and-memory-tradeoffs** (PRIMARY) - Gradient accumulation to simulate larger batch +2. **pytorch-engineering/tensor-operations-and-memory** - Memory profiling and optimization +3. **data-augmentation-strategies** - Reduce augmentation overhead if bottleneck + +**Why**: Speed and memory often coupled—need to balance batch size, memory, and throughput. + +--- + +### Scenario: Multi-Task Learning or Custom Loss + +**Route to:** +1. **loss-functions-and-objectives** (PRIMARY) - Multi-task loss design, uncertainty weighting +2. **gradient-management** - Check gradients per task, gradient balancing +3. **hyperparameter-tuning** - Tune task weights, loss coefficients + +**Why**: Custom losses need careful design + gradient analysis + tuning. + +--- + +## Ambiguous Queries - Clarification Protocol + +When symptom unclear, ASK ONE diagnostic question before routing: + +| Vague Query | Clarifying Question | Why | +|-------------|---------------------|-----| +| "Fix my training" | "What specific issue? Not learning? Unstable? Overfitting? Too slow?" | 4+ different routing paths | +| "Improve model" | "Improve what? Training speed? Accuracy? Generalization?" | Different optimization targets | +| "Training not working well" | "What's 'not working'? Loss behavior? Accuracy? Convergence speed?" | Need specific symptoms | +| "Optimize hyperparameters" | "Which hyperparameters? All of them? Specific ones like LR?" | Specific vs broad search | +| "Model performs poorly" | "Training accuracy poor or validation accuracy poor or both?" | Underfitting vs overfitting | + +**Never guess when ambiguous. Ask once, route accurately.** + +--- + +## Common Routing Mistakes + +| Symptom | Wrong Route | Correct Route | Why | +|---------|-------------|---------------|-----| +| "Training slow" | batch-size-and-memory | ASK: Check GPU utilization first | Might be data loading, not compute | +| "Not learning" | optimization-algorithms | ASK: Diagnose loss behavior | Could be LR, gradients, loss function | +| "Loss NaN" | learning-rate-scheduling | gradient-management FIRST | Gradient explosion most common cause | +| "Overfitting" | overfitting-prevention only | overfitting-prevention + data-augmentation | Need multi-pronged approach | +| "Need to speed up training" | optimization-algorithms | Profile first (pytorch-engineering) | Don't optimize without measuring | +| "Which optimizer for transformer" | neural-architectures | optimization-algorithms | Optimizer choice, not architecture | + +**Key principle**: Diagnosis before solutions, clarification before routing, multi-skill for complex issues. + +--- + +## When NOT to Use Training-Optimization Pack + +**Skip training-optimization when:** + +| Symptom | Wrong Pack | Correct Pack | Why | +|---------|------------|--------------|-----| +| "CUDA out of memory" | training-optimization | pytorch-engineering | Infrastructure issue, not training algorithm | +| "DDP not working" | training-optimization | pytorch-engineering | Distributed setup, not hyperparameters | +| "Which architecture to use" | training-optimization | neural-architectures | Architecture choice precedes training | +| "Model won't load" | training-optimization | pytorch-engineering | Checkpointing/serialization issue | +| "Inference too slow" | training-optimization | ml-production | Production optimization, not training | +| "How to deploy model" | training-optimization | ml-production | Deployment concern | +| "RL exploration issues" | training-optimization | deep-rl | RL-specific training concern | +| "RLHF for LLM" | training-optimization | llm-specialist | LLM-specific technique | + +**Training-optimization pack is for**: Framework-agnostic training algorithms, hyperparameters, optimization techniques, and training strategies that apply across architectures. + +**Boundaries:** +- PyTorch implementation/infrastructure → **pytorch-engineering** +- Architecture selection → **neural-architectures** +- Production/inference → **ml-production** +- RL-specific algorithms → **deep-rl** +- LLM-specific techniques → **llm-specialist** + +--- + +## Red Flags - Stop and Reconsider + +If you catch yourself about to: +- ❌ Suggest specific optimizer without routing → Route to **optimization-algorithms** +- ❌ Suggest learning rate value without diagnosis → Route to **learning-rate-scheduling** +- ❌ Say "add dropout" without comprehensive strategy → Route to **overfitting-prevention** +- ❌ Suggest "reduce batch size" without profiling → Check if memory or speed issue, route appropriately +- ❌ Give trial-and-error fixes for NaN → Route to **gradient-management** for systematic debugging +- ❌ Provide generic training advice → Identify specific symptom and route to specialist +- ❌ Accept user's self-diagnosis without verification → Ask diagnostic questions first + +**All of these mean: You're about to give incomplete advice. Route to specialist instead.** + +--- + +## Common Rationalizations (Don't Do These) + +| Rationalization | Reality | What To Do Instead | +|-----------------|---------|-------------------| +| "User is rushed, skip diagnostic questions" | Diagnosis takes 30 seconds, wrong route wastes 10+ minutes | Ask ONE quick diagnostic question: "Is loss flat, oscillating, or NaN?" | +| "Symptoms are obvious, route immediately" | Symptoms often have multiple causes | Ask clarifying question to eliminate ambiguity | +| "User suggested optimizer change" | User self-diagnosis can be wrong | "What loss behavior are you seeing?" to verify root cause | +| "Expert user doesn't need routing" | Expert users benefit from specialist skills too | Route based on symptoms, not user sophistication | +| "Just a quick question" | Quick questions deserve correct answers | Route to specialist—they have quick diagnostics too | +| "Single solution will fix it" | Training issues often multi-causal | Consider multi-skill routing for complex symptoms | +| "Time pressure means guess quickly" | Wrong guess wastes MORE time | Fast systematic diagnosis faster than trial-and-error | +| "They already tried X" | Maybe tried X wrong or X wasn't the issue | Route to specialist to verify X was done correctly | +| "Too complex to route" | Complex issues need specialists MORE | Use multi-skill routing for complex scenarios | +| "Direct answer is helpful" | Wrong direct answer wastes time | Routing IS the helpful answer | + +**If you catch yourself thinking ANY of these, STOP and route to specialist or ask diagnostic question.** + +--- + +## Pressure Resistance - Critical Discipline + +### Time/Emergency Pressure + +| Pressure | Wrong Response | Correct Response | +|----------|----------------|------------------| +| "Demo tomorrow, need quick fix" | Give untested suggestions | "Fast systematic diagnosis ensures demo success: [question]" | +| "Production training failing" | Panic and guess | "Quick clarification prevents longer outage: [question]" | +| "Just tell me which optimizer" | "Use Adam" | "30-second clarification ensures right choice: [question]" | + +**Emergency protocol**: Fast clarification (30 sec) → Correct routing (60 sec) → Specialist handles efficiently + +--- + +### Authority/Hierarchy Pressure + +| Pressure | Wrong Response | Correct Response | +|----------|----------------|------------------| +| "Senior said use SGD" | Accept without verification | "To apply SGD effectively, let me verify the symptoms: [question]" | +| "PM wants optimizer change" | Change without diagnosis | "Let's diagnose to ensure optimizer is the issue: [question]" | + +**Authority protocol**: Acknowledge → Verify symptoms → Route based on evidence, not opinion + +--- + +### User Self-Diagnosis Pressure + +| Pressure | Wrong Response | Correct Response | +|----------|----------------|------------------| +| "I think it's the optimizer" | Discuss optimizer choice | "What loss behavior makes you think optimizer? [diagnostic]" | +| "Obviously need to clip gradients" | Implement clipping | "What symptoms suggest gradient issues? [verify]" | + +**Verification protocol**: User attribution is hypothesis, not diagnosis. Verify with symptoms. + +--- + +## Red Flags Checklist - Self-Check Before Routing + +Before giving ANY training advice or routing, ask yourself: + +1. ❓ **Did I identify specific symptoms?** + - If no → Ask clarifying question + - If yes → Proceed + +2. ❓ **Is this symptom in my routing table?** + - If yes → Route to specialist skill + - If no → Ask diagnostic question + +3. ❓ **Am I about to give direct advice?** + - If yes → STOP. Why am I not routing? + - Check rationalization table—am I making excuses? + +4. ❓ **Could this symptom have multiple causes?** + - If yes → Ask diagnostic question to narrow down + - If no → Route confidently + +5. ❓ **Is this training-optimization or another pack?** + - PyTorch errors → pytorch-engineering + - Architecture choice → neural-architectures + - Deployment → ml-production + +6. ❓ **Am I feeling pressure to skip routing?** + - Time pressure → Route anyway (faster overall) + - Authority pressure → Verify symptoms anyway + - User self-diagnosis → Confirm with questions anyway + - Expert user → Route anyway (specialists help experts too) + +**If you failed ANY check, do NOT give direct advice. Route to specialist or ask clarifying question.** + +--- + +## Training Optimization Specialist Skills + +After routing, load the appropriate specialist skill for detailed guidance: + +1. [optimization-algorithms.md](optimization-algorithms.md) - Optimizer selection (SGD, Adam, AdamW, momentum), hyperparameter tuning, optimizer comparison +2. [learning-rate-scheduling.md](learning-rate-scheduling.md) - LR schedulers (step, cosine, exponential), warmup strategies, cyclical learning rates +3. [loss-functions-and-objectives.md](loss-functions-and-objectives.md) - Custom losses, multi-task learning, weighted objectives, numerical stability +4. [gradient-management.md](gradient-management.md) - Gradient clipping, accumulation, scaling, vanishing/exploding gradient diagnosis +5. [batch-size-and-memory-tradeoffs.md](batch-size-and-memory-tradeoffs.md) - Batch size effects on convergence, gradient accumulation, memory optimization +6. [data-augmentation-strategies.md](data-augmentation-strategies.md) - Augmentation techniques (geometric, color, mixing), policy design, AutoAugment +7. [overfitting-prevention.md](overfitting-prevention.md) - Regularization (L1/L2, dropout, weight decay), early stopping, generalization techniques +8. [training-loop-architecture.md](training-loop-architecture.md) - Training loop design, monitoring, logging, checkpointing integration, callbacks +9. [hyperparameter-tuning.md](hyperparameter-tuning.md) - Search strategies (grid, random, Bayesian), AutoML, Optuna, Ray Tune +10. [experiment-tracking.md](experiment-tracking.md) - MLflow, Weights & Biases, TensorBoard, Neptune, run comparison + +--- + +## Quick Reference: Symptom → Skills + +| Symptom | Primary Skill | Secondary Skills | Diagnostic Question | +|---------|---------------|------------------|---------------------| +| Loss flat/stuck | learning-rate-scheduling | optimization-algorithms | "Flat from start or plateaued later?" | +| Loss NaN/Inf | gradient-management | learning-rate-scheduling, loss-functions | "When does NaN appear?" | +| Overfitting | overfitting-prevention | data-augmentation, hyperparameter-tuning | "Dataset size? Current regularization?" | +| Training slow | batch-size-and-memory OR pytorch-engineering | data-augmentation | "GPU utilization percentage?" | +| Oscillating loss | learning-rate-scheduling | gradient-management | "What's your current LR?" | +| Which optimizer | optimization-algorithms | learning-rate-scheduling | Task type, architecture, dataset | +| Which LR | learning-rate-scheduling | optimization-algorithms | Optimizer, task, current symptoms | +| Track experiments | experiment-tracking | hyperparameter-tuning, training-loop | Tools preference, scale of experiments | +| Poor generalization | overfitting-prevention | data-augmentation | Train vs val accuracy gap | +| Convergence issues | gradient-management | learning-rate-scheduling, optimization-algorithms | Gradient norms, loss curve | + +--- + +## Integration Notes + +**Phase 1 - Standalone**: Training-optimization skills are self-contained and framework-agnostic. + +**Cross-references with other packs:** +- **pytorch-engineering**: Infrastructure, implementation, profiling, distributed training +- **neural-architectures**: Architecture selection (precedes training optimization) +- **deep-rl**: RL-specific training algorithms (policy gradients, Q-learning) +- **llm-specialist**: LLM-specific techniques (RLHF, LoRA) +- **ml-production**: Production/inference optimization (post-training) + +**Current focus**: Route within training-optimization for framework-agnostic training concerns. Other packs handle implementation, architecture, and domain-specific issues. + +--- + +## Summary + +**This meta-skill's job**: +1. Identify symptoms through diagnostic questions +2. Map symptoms to appropriate specialist skills +3. Route to one or multiple skills based on diagnosis +4. Resist pressure to give direct advice +5. Maintain clear boundaries with other packs + +**Remember**: Diagnose before routing. Training issues often have multiple causes. Clarify symptoms, ask diagnostic questions, then route to specialist skills. Wrong routing wastes more time than asking one clarifying question. + +**Route based on symptoms, not guesses. Let specialists do their job.** diff --git a/skills/using-training-optimization/batch-size-and-memory-tradeoffs.md b/skills/using-training-optimization/batch-size-and-memory-tradeoffs.md new file mode 100644 index 0000000..eccd56c --- /dev/null +++ b/skills/using-training-optimization/batch-size-and-memory-tradeoffs.md @@ -0,0 +1,1651 @@ + +# Batch Size and Memory Tradeoffs + +## Overview + +Batch size is one of the most misunderstood hyperparameters. Most engineers think: "larger batch = faster training = better". Wrong. Batch size affects convergence speed, generalization, memory usage, and actual wall-clock training time in complex ways. **Larger batch size is NOT always better.** + +**Core principle**: Batch size selection is a system optimization problem, not a memory constraint problem. Choose batch size based on computational speed, convergence requirements, and generalization targets - not just what fits in memory. + + +## When to Use This Skill + +**Use this skill when:** +- Choosing batch size for new training +- Training is slow and considering larger batches +- Out-of-memory errors during training +- Learning rate needs adjustment after batch size change +- Distributed training needs batch size scaling +- Gradient accumulation considerations +- User asks "what batch size should I use?" +- Training accuracy varies widely between batch sizes +- Convergence takes too long or is unstable +- Memory per sample calculation needed +- Comparing training speed: iterations vs epochs vs wall-clock time +- Fine-tuning with different batch sizes than pre-training + +**Symptoms you need this skill:** +- "I have memory, what's the maximum batch size?" (wrong question) +- "Larger batches train faster, so use 512?" (incomplete) +- "Batch size doesn't affect accuracy, only speed?" (false) +- "Gradient accumulation is a workaround for small memory?" (misconception) +- "Just scale learning rate by 2x when doubling batch size?" (incomplete) +- "We get OOM at batch 256, so use 128 forever" (not optimized) + +**Don't use when:** +- User has pure memory/infrastructure questions (use pytorch-engineering) +- User asks about optimizer selection (use optimizer-selection-framework) +- User asks about learning rate scheduling (use learning-rate-scheduling) +- User has general training failure (not batch-size specific) + + +## Core Patterns + +### Pattern 1: The Batch Size Tradeoff Space + +**The critical insight**: Batch size affects FOUR independent dimensions simultaneously. Optimize one = impact others. + +**The four dimensions:** + +``` +1. TRAINING SPEED (iterations to converge) + ├─ Larger batch → fewer iterations to convergence ✓ + ├─ BUT: Gradient variance decreases (noisier gradients are better) + └─ Result: Mixed - can't just maximize batch + +2. COMPUTATIONAL EFFICIENCY (wall-clock time) + ├─ Larger batch → amortize overhead per sample ✓ + ├─ BUT: Larger batch → need larger LR (unstable) + ├─ AND: Gradient accumulation = repeated backward (slow) + └─ Result: Optimal ≠ Maximum + +3. GENERALIZATION (test accuracy) + ├─ Smaller batch → noisier gradients → better regularization ✓ + ├─ Larger batch → cleaner gradient → overfit risk ✗ + ├─ BUT: Can compensate with stronger regularization + └─ Result: Batch size ↔ regularization coupling + +4. MEMORY USAGE (GPU memory required) + ├─ Larger batch → linear increase in activation memory + ├─ Parameters constant regardless of batch + ├─ Optimizer state constant regardless of batch + └─ Result: Memory ∝ batch size (linear only for activations) +``` + +**The mental model:** +``` +LARGER BATCH: + ✓ Fewer iterations to convergence + ✓ Better computational efficiency (up to point) + ✗ Worse generalization (harder to regularize) + ✗ Requires larger learning rate (instability risk) + ✗ Higher memory usage + +SMALLER BATCH: + ✗ More iterations to convergence + ✗ Worse computational efficiency + ✓ Better generalization (noise helps) + ✓ Smaller learning rates are stable + ✓ Lower memory usage +``` + +**Finding the sweet spot:** +- Start with batch size that uses ~80% GPU memory +- Adjust learning rate using linear scaling rule +- Monitor validation accuracy +- If validation accuracy drops → batch too large, reduce or regularize +- If training is slow → may need gradient accumulation, not larger batch + + +### Pattern 2: Linear Learning Rate Scaling Rule + +**The rule that changes everything:** + +If you increase batch size by factor K, increase learning rate by factor K. + +``` +New LR = Old LR × (New Batch Size / Old Batch Size) +``` + +**Why this works (the math):** + +``` +Gradient Descent Update: param = param - lr * gradient + +With Batch Size B, gradient is average of B samples: + gradient_B = (1/B) * sum(gradients from B samples) + update_B = lr * gradient_B + +With Batch Size 2B, gradient is average of 2B samples: + gradient_2B = (1/(2B)) * sum(gradients from 2B samples) + +Variance drops by 2x when averaging 2x more samples. +If variance drops 2x, gradient magnitude is √2x smaller. +To keep update magnitude constant: lr should increase by 2x. + +Empirically validated: Goyal et al. (2017) "Accurate, Large Batch Training" +``` + +**Implementation:** + +```python +# Pattern 1: Direct scaling +original_lr = 0.001 +original_batch_size = 32 +new_batch_size = 128 + +scaling_factor = new_batch_size / original_batch_size # 4x +new_lr = original_lr * scaling_factor # 0.004 + +# Pattern 2: When changing both batch AND learning rate +def compute_scaled_lr(base_lr, base_batch_size, current_batch_size): + """ + Compute learning rate for new batch size using linear scaling rule. + + Args: + base_lr: Learning rate at reference batch size + base_batch_size: Batch size where base_lr was tuned (usually 32 or 256) + current_batch_size: New batch size + + Returns: + Scaled learning rate + + WHY: Linear scaling rule keeps update magnitude constant + """ + scale_factor = current_batch_size / base_batch_size + return base_lr * scale_factor + +# Example: ResNet-50 training (ImageNet baseline) +# Reference: batch=256, lr=0.1 +# Now training at: batch=1024 +scaled_lr = compute_scaled_lr(0.1, 256, 1024) # 0.4 +print(f"Batch 256 with lr=0.1 → Batch 1024 with lr={scaled_lr}") +``` + +**When linear scaling works:** + +```python +# CASE 1: Scaling works well +# Batch: 32 → 256 (8x increase) +# Learning rate: 0.001 → 0.008 (8x) +# Training: ✓ Converges normally, same final accuracy +# Wall-clock: ✓ Faster (fewer iterations, better hardware utilization) + +# CASE 2: Scaling doesn't work +# Batch: 32 → 1024 (32x increase!) +# Learning rate: 0.001 → 0.032 (32x) +# Problem: Learning rate too large, training diverges +# Solution: Need warmup phase +``` + +**The Critical Caveat: WARMUP IS REQUIRED** + +```python +# WRONG: Apply full scaled LR immediately +optimizer = torch.optim.SGD(model.parameters(), lr=0.032) # Too large! +for epoch in range(100): + for batch in train_loader: + loss = criterion(model(batch), targets) + loss.backward() + optimizer.step() # Loss diverges on first iteration! + +# CORRECT: Warmup phase before scaled LR +def warmup_lr_schedule(base_lr, current_batch_size, reference_batch_size, + current_step, warmup_steps): + """ + Linear warmup from 0 to scaled LR. + + WHY: Large LR jumps can cause divergence. + Gradual warmup lets model adapt to larger updates. + """ + scaled_lr = base_lr * (current_batch_size / reference_batch_size) + + if current_step < warmup_steps: + # Linear warmup: ramp from 0 to scaled_lr + return scaled_lr * (current_step / warmup_steps) + else: + # Full scaled LR after warmup + return scaled_lr + +# Implementation with PyTorch scheduler +from torch.optim.lr_scheduler import LambdaLR + +def get_warmup_scheduler(optimizer, warmup_steps): + base_lrs = [param_group['lr'] for param_group in optimizer.param_groups] + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda) + +# Training loop +optimizer = torch.optim.SGD(model.parameters(), lr=0.032) +scheduler = get_warmup_scheduler(optimizer, warmup_steps=1000) + +for epoch in range(100): + for step, batch in enumerate(train_loader): + loss = criterion(model(batch), targets) + loss.backward() + optimizer.step() + scheduler.step() # Gradually increase LR +``` + +**Practical guidelines:** + +``` +BATCH SIZE INCREASE LEARNING RATE SCALE WARMUP NEEDED? WHY +2x (64→128) 2x (0.001→0.002) No Safe, gradual +4x (64→256) 4x (0.001→0.004) Maybe Starting to matter +8x (64→512) 8x (0.001→0.008) YES Risky without warmup +16x+ (64→1024) 16x+ (0.001→0.016) CRITICAL Risk of divergence +``` + + +### Pattern 3: Gradient Accumulation - The Alternative to Large Batches + +**What gradient accumulation does:** + +Gradient accumulation simulates large batch size without large GPU memory. Instead of 1 forward+backward of batch 256, do 8 forward+backwardsof batch 32. Same effective batch, 1/8th memory. + +**How it works:** + +```python +# SIMPLE APPROACH (without accumulation) +batch_size = 256 +effective_batch_size = 256 # Process full batch at once +memory_required = HIGH # Can't fit in GPU + +for batch in train_loader: # batch.size() = 256 + output = model(batch) + loss = criterion(output, target) + loss.backward() + optimizer.step() + +# GRADIENT ACCUMULATION APPROACH +batch_size = 32 +accumulation_steps = 8 +effective_batch_size = 32 * 8 = 256 # Same as above! +memory_required = LOW # Only batch 32 in memory at once + +optimizer.zero_grad() +for accumulation_step in range(accumulation_steps): + batch = next(iter(train_loader)) # batch.size() = 32 + output = model(batch) + loss = criterion(output, target) + loss.backward() # Accumulate gradients (don't zero!) + # Don't call optimizer.step() yet! + +optimizer.step() # Update weights after accumulation complete +# Effect: Updated weights as if we processed batch_size=256 +``` + +**When to use gradient accumulation:** + +```python +# CASE 1: Model too large to fit large batch +# Model: GPT-2 (124M parameters) +# Available GPU: 24GB +# Desired batch: 512 per GPU +# Fits in memory: No, only 32 fits +# Solution: Accumulate 16 steps of batch 32 = effective 512 + +model_params = 124_000_000 # 124M +param_memory = model_params * 4 # bytes (FP32) +optimizer_memory = model_params * 8 # Adam state (8x parameters) +batch_size = 32 +sequence_length = 512 +activation_memory_per_sample = param_memory / 10 # Rough estimate +total_memory = param_memory + optimizer_memory + (batch_size * activation_memory_per_sample) +# ~2GB memory per step +# 16 accumulation steps still << 24GB + +# CASE 2: Distributed training across 8 GPUs +# Per-GPU batch: 32 +# Number of GPUs: 8 +# Local accumulation: 4 steps +# Total effective: 32 * 8 * 4 = 1024 (synchronized across 8 GPUs) + +# Accumulation enables large total batch without massive per-GPU batch +``` + +**The memory math:** + +``` +Memory with Gradient Accumulation: + +Without accumulation (batch_size = 256): + - Parameters: Fixed + - Optimizer state: Fixed (8x params for Adam) + - Activations: O(batch_size) = O(256) + - Gradients: O(batch_size) = O(256) + - Total ≈ 1.0x baseline memory + +With accumulation (batch_size = 32, steps = 8): + - Parameters: Fixed (same) + - Optimizer state: Fixed (same) + - Activations: O(batch_size) = O(32) = 8x SMALLER + - Gradients: O(batch_size) = O(32) = 8x SMALLER + - Total ≈ 0.15x baseline memory (for activations+gradients) + +Savings: ~85% memory reduction! +Cost: 8x longer (8 backward passes instead of 1) +Net wall-clock: ~1.5-2x slower (overhead, synchronization) +``` + +**Implementation patterns:** + +```python +# Pattern 1: Manual gradient accumulation +num_accumulation_steps = 8 +optimizer.zero_grad() + +for step, (batch, target) in enumerate(train_loader): + output = model(batch) + loss = criterion(output, target) + + # Scale loss by accumulation steps + # WHY: Otherwise gradient magnitudes stack up across steps + loss = loss / num_accumulation_steps + + loss.backward() # Accumulate gradients + + if (step + 1) % num_accumulation_steps == 0: + optimizer.step() # Update after accumulation complete + optimizer.zero_grad() + +# Pattern 2: With learning rate adjustment +# IMPORTANT: Don't adjust learning rate just because of accumulation! +# Accumulation is transparent to optimizer. +# Scale is: effective_batch = batch_size * num_accumulation_steps +# So LR should match effective_batch, NOT per-GPU batch + +original_lr = 0.1 # Tuned for batch_size = 32 +num_accumulation_steps = 8 +effective_batch = 32 * 8 # 256 + +# Linear scaling rule based on effective batch: +# Batch 32 → 256 is 8x increase +# So LR: 0.1 → 0.8 (8x) +new_lr = original_lr * 8 # 0.8 +optimizer = torch.optim.SGD(model.parameters(), lr=new_lr) + +# Pattern 3: Distributed training with gradient accumulation +# Per-GPU batch: 32 +# Number of GPUs: 8 +# Accumulation steps: 4 +# Effective batch: 32 * 8 * 4 = 1024 + +from torch.nn.parallel import DistributedDataParallel as DDP + +model = DDP(model) + +num_accumulation_steps = 4 +optimizer.zero_grad() + +for step, (batch, target) in enumerate(train_loader): + output = model(batch) + loss = criterion(output, target) + loss = loss / num_accumulation_steps + + loss.backward() + + if (step + 1) % num_accumulation_steps == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + +# Pattern 4: With synchronization (distributed) +class GradientAccumulator: + def __init__(self, model, num_accumulation_steps, sync_gradients_every=1): + self.model = model + self.num_steps = num_accumulation_steps + self.sync_every = sync_gradients_every + self.step_count = 0 + + def should_sync_gradients(self): + """ + In DDP, only sync gradients when we're about to do optimizer.step(). + This reduces communication overhead. + """ + return (self.step_count + 1) % self.sync_every == 0 + + def backward(self, loss): + loss = loss / self.num_steps + + # Only sync if we're about to step + if self.should_sync_gradients(): + loss.backward() + else: + with self.model.no_sync(): # Skip gradient sync in DDP + loss.backward() + + self.step_count += 1 +``` + +**Gradient accumulation vs large batch - when to choose:** + +```python +# When gradient accumulation is GOOD choice: +# 1. Memory-constrained (can't fit large batch) +# 2. Need large effective batch for convergence +# 3. Can tolerate ~1.5-2x slowdown +# 4. Training wall-clock time not critical + +# When gradient accumulation is BAD choice: +# 1. Can fit desired batch size in memory +# 2. Training speed is critical (wall-clock matters) +# 3. Already have good convergence with smaller batches +# 4. Reduced gradient noise is important for task + +# Comparison table: +# LARGE BATCH GRADIENT ACCUMULATION +# Memory High Low (1/accumulation) +# Wall-clock time Fast ~1.5-2x slower +# Convergence speed Good Same (effective batch same) +# Implementation Simple Requires manual loop +# Memory savings None ~85% (with 8x accumulation) +# When to use When memory OK When memory constrained +``` + + +### Pattern 4: Memory Estimation and Optimization + +**Understanding memory components:** + +``` +Total GPU Memory = Parameters + Optimizer State + Activations + Gradients + +Example: Training BERT-base (110M params) with batch_size=32, seq_len=512 + +1. PARAMETERS (Fixed) + - BERT: 110M × 4 bytes (FP32) = 440 MB + - Or 110M × 2 bytes (FP16) = 220 MB + +2. OPTIMIZER STATE (Fixed) + - SGD: No extra state = 0 MB + - Adam: m + v = 2 × params = 880 MB (FP32) or 440 MB (FP16) + - AdamW: Same as Adam + +3. ACTIVATIONS (Linear in batch_size, seq_len) + - Stored during forward pass (for backward) + - BERT layer: ~batch × seq_len × hidden_dim × 4 + - = 32 × 512 × 768 × 4 bytes + - = ~320 MB per layer + - × 12 layers = ~3.8 GB + +4. GRADIENTS (Linear in batch_size) + - Stored after backward, until optimizer.step() + - Same size as parameters = 440 MB + +TOTAL MEMORY = 440 + 880 + 3800 + 440 = ~5.6 GB +Typical budget: Use ~80% GPU = 19 GB with 24GB GPU +Room for more: Can increase batch from 32 → 128 safely +``` + +**Memory calculation framework:** + +```python +def estimate_memory_usage( + num_params: int, + batch_size: int, + seq_length: int, + hidden_dim: int, + num_layers: int, + dtype_bytes: int = 4, # 4 for FP32, 2 for FP16 + optimizer: str = "adam", # or "sgd" + use_gradient_checkpointing: bool = False, +): + """ + Estimate memory for training a transformer model. + + Args: + num_params: Total parameters + batch_size: Batch size + seq_length: Sequence length + hidden_dim: Hidden dimension (for activation estimation) + num_layers: Number of transformer layers + dtype_bytes: 4 for FP32, 2 for FP16, 1 for INT8 + optimizer: "sgd" (no state), "adam" (8x params) + use_gradient_checkpointing: If True, reduce activation memory + + Returns: + Memory in GB + + WHY: Helps choose batch size without trial-and-error OOM + """ + + # 1. Parameter memory + param_memory = num_params * dtype_bytes + + # 2. Optimizer state + if optimizer.lower() == "adam": + opt_memory = 2 * num_params * dtype_bytes # m + v + elif optimizer.lower() == "adamw": + opt_memory = 2 * num_params * dtype_bytes # m + v + else: # SGD + opt_memory = 0 + + # 3. Activation memory (transformer-specific) + # Activations = hidden states + attention weights stored during forward + # Per layer: batch × seq_len × hidden_dim × 4 bytes + # × num_layers + activation_memory_per_layer = batch_size * seq_length * hidden_dim * dtype_bytes + total_activation_memory = activation_memory_per_layer * num_layers + + if use_gradient_checkpointing: + # With checkpointing: only save activations for last layer + # (recompute others during backward) + total_activation_memory = activation_memory_per_layer # Only 1 layer + + # 4. Gradient memory (same as parameter memory) + gradient_memory = num_params * dtype_bytes + + # Total + total_bytes = param_memory + opt_memory + total_activation_memory + gradient_memory + total_gb = total_bytes / (1024**3) + + return total_gb + +# Example: BERT training +memory_gb = estimate_memory_usage( + num_params=110_000_000, # BERT-base + batch_size=32, + seq_length=512, + hidden_dim=768, + num_layers=12, + dtype_bytes=4, # FP32 + optimizer="adam", + use_gradient_checkpointing=False, +) +print(f"Memory: {memory_gb:.1f} GB") # ~5.6 GB + +# Optimize by reducing batch +memory_gb_batch16 = estimate_memory_usage( + num_params=110_000_000, + batch_size=16, # 2x smaller + seq_length=512, + hidden_dim=768, + num_layers=12, + dtype_bytes=4, + optimizer="adam", + use_gradient_checkpointing=False, +) +print(f"Memory with batch 16: {memory_gb_batch16:.1f} GB") # ~3.8 GB + +# Optimize by mixed precision +memory_gb_fp16 = estimate_memory_usage( + num_params=110_000_000, + batch_size=32, + seq_length=512, + hidden_dim=768, + num_layers=12, + dtype_bytes=2, # FP16 instead of FP32 + optimizer="adam", + use_gradient_checkpointing=False, +) +print(f"Memory with FP16: {memory_gb_fp16:.1f} GB") # ~2.8 GB + +# Optimize with checkpointing +memory_gb_ckpt = estimate_memory_usage( + num_params=110_000_000, + batch_size=32, + seq_length=512, + hidden_dim=768, + num_layers=12, + dtype_bytes=4, + optimizer="adam", + use_gradient_checkpointing=True, # Save only last layer activations +) +print(f"Memory with checkpointing: {memory_gb_ckpt:.1f} GB") # ~1.0 GB +``` + +**Memory optimization techniques:** + +```python +# Technique 1: Gradient Checkpointing +# Recompute activations instead of storing them +# Memory: O(sqrt(num_layers)) instead of O(num_layers) +# Cost: ~30% slower training (recompute activations during backward) + +from torch.utils.checkpoint import checkpoint + +class TransformerBlock(nn.Module): + def forward(self, x): + # Forward: compute and store activations + # Backward: recompute activations during backward + return checkpoint(self._forward, x, use_reentrant=False) + + def _forward(self, x): + x = self.attention(x) + x = self.feedforward(x) + return x + +# Technique 2: Mixed Precision (FP16) +# Use FP16 for forward+backward (2x memory) +# Use FP32 for weights (don't accumulate errors) +# Memory: ~50% reduction +# Speed: 1.3-2x faster on modern GPUs + +from torch.cuda.amp import autocast, GradScaler + +scaler = GradScaler() + +for batch, target in train_loader: + optimizer.zero_grad() + + with autocast(): # Automatic FP16 casting + output = model(batch) + loss = criterion(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + +# Technique 3: Quantization-Aware Training +# Store weights in INT8 or FP8 +# Requires special hardware support +# Memory: 75-90% reduction +# Speed: 2-4x faster + +# Technique 4: Batch Size Scheduling +# Start with small batch, increase during training +# Reason: Large batch early = poor generalization +# Large batch late = good generalization +# Memory: Gradually increases as needed + +def get_adaptive_batch_size(epoch, total_epochs): + """Increase batch size as training progresses""" + base_batch = 32 + max_batch = 256 + + # Linear increase: start small, end large + scale_factor = base_batch + (max_batch - base_batch) * (epoch / total_epochs) + return int(scale_factor) +``` + + +### Pattern 5: Batch Size Effects on Convergence and Generalization + +**The generalization gap - why bigger batch = worse accuracy:** + +``` +Generalization Gap = Test Accuracy (Large Batch) - Test Accuracy (Small Batch) + +Why small batch generalizes better: +1. Gradient Noise: Small batch = noisy gradients + - Noise acts as regularization + - Forces model to find robust minima + - Larger noise → larger generalization margin + +2. Loss Landscape: SGD with noise explores landscape differently + - Large batch: Gradient descent (exact gradient) + - Small batch: Stochastic gradient (noisy) + - Noise helps escape sharp minima (bad generalization) + - Leads to flat minima (good generalization) + +3. Batch Normalization Interaction: + - BN computes statistics per batch + - Larger batch → more stable statistics + - More stable → less regularization effect + - Less regularization → worse generalization + +Real numbers (ResNet-50 on ImageNet): +- Batch 256: 76.0% accuracy +- Batch 1024: 74.8% accuracy (1.2% gap!) +- Batch 4096: 72.0% accuracy (4% gap!) +``` + +**The sharp minima problem:** + +``` +SMALL BATCH (32): + Loss landscape: Finds FLAT minima + - Small change in weights → loss increases slowly + - Generalizes well (robust to input variations) + - Test accuracy ≈ Train accuracy + - Variance: Higher (gradient noise) + +LARGE BATCH (1024): + Loss landscape: Finds SHARP minima + - Small change in weights → loss increases quickly + - Generalizes poorly (sensitive to input variations) + - Test accuracy << Train accuracy (overfitting) + - Variance: Lower (stable gradients) + +SOLUTION: Add regularization to large batch training +- L2 regularization (weight decay) +- Dropout +- Data augmentation +- Label smoothing +``` + +**Batch size effects on different architectures:** + +```python +# Architecture 1: ResNets (well-studied) +# Batch 256: 76.0% top-1 accuracy (ImageNet) +# Batch 1024: 74.8% (-1.2%) +# Batch 4096: 72.0% (-4%) +# Conclusion: Batch size matters, gap grows exponentially + +# Architecture 2: Vision Transformers +# Batch 512: 82.0% accuracy +# Batch 1024: 81.8% (-0.2%) +# Batch 4096: 81.0% (-1%) +# Conclusion: Less sensitive to batch size (more robust) + +# Architecture 3: BERT (Language) +# Batch 128: 89.0% GLUE score +# Batch 256: 88.8% (-0.2%) +# Batch 512: 88.2% (-0.8%) +# Conclusion: Moderate sensitivity + +# WHY THE DIFFERENCES? +# - ResNets: Simple architecture, sharp minima +# - Vision Transformers: Attention provides regularization +# - BERT: Pre-training + fine-tuning, already regularized +``` + +**Empirical guidelines for batch size vs generalization:** + +```python +# Rule 1: Start with batch 128-256 +# Most tasks achieve good accuracy at this range +# Memory reasonable on modern GPUs +# Generalization gap minimal + +# Rule 2: If increasing batch size - add regularization +def add_regularization_for_large_batch(batch_size, base_batch=256): + """Adjust regularization strength for larger batch size""" + + # Start from base: batch 256, weight_decay 0.0001 + # Double batch → increase regularization + scale_factor = batch_size / base_batch + + weight_decay = 0.0001 * (scale_factor ** 0.5) # sqrt scale + dropout = 0.1 # Add dropout + label_smoothing = 0.1 # Label smoothing helps + + return { + 'weight_decay': weight_decay, + 'dropout': dropout, + 'label_smoothing': label_smoothing, + } + +# Rule 3: Validate on validation set +# Don't assume scaling rule works for accuracy +# Larger batch might need different epochs/learning rate schedule + +# Rule 4: Gradient accumulation doesn't help generalization +# Accumulation ≠ large batch for gradient statistics +# Gradient accumulation has same gradient per parameter +# Just takes longer (multiple backward passes) +# Generalization benefit same as if you had memory for full batch +``` + + +### Pattern 6: Finding Optimal Batch Size (Not Just Maximum) + +**The batch size selection framework:** + +``` +Step 1: Calculate memory budget + → Max memory available (e.g., 24GB GPU) + → Estimate parameters + optimizer state + → Available for batch = Total - (params + opt state) + +Step 2: Estimate per-sample memory + → Run small batch (8), measure memory + → Divide by 8 to get per-sample + → Max batch = Available Memory / per-sample + +Step 3: Find memory-safe batch + → Use 80% of max (leaves margin) + → This is maximum batch that's safe + +Step 4: Check convergence at maximum batch + → Train model with maximum safe batch + → Compare accuracy to smaller batches + → If >2% accuracy drop: reduce batch or add regularization + +Step 5: Optimize for wall-clock time + → Profile training time at different batch sizes + → Wall-clock = (iterations) × (time per iteration) + → Iterations = (samples / batch) × epochs + → Find batch that minimizes wall-clock time + → Often NOT the maximum batch! + +Step 6: Select based on task requirements + → If convergence matters more: smaller batch + → If speed matters more: larger batch + → If memory constrained: gradient accumulation + → If fine-tuning: smaller batch (preserve pre-training) +``` + +**Implementation:** + +```python +def find_optimal_batch_size( + model, + train_loader, + criterion, + device, + target_accuracy=None, + time_budget_seconds=None, +): + """ + Find optimal batch size by profiling at different sizes. + + Args: + model: PyTorch model to profile + train_loader: DataLoader + criterion: Loss function + device: torch.device + target_accuracy: If set, find batch that achieves this + time_budget_seconds: If set, find fastest batch within budget + + Returns: + Optimal batch size, profiling results + + WHY: Maximum batch ≠ optimal batch + """ + + batch_sizes = [32, 64, 128, 256, 512] + results = {} + + for batch_size in batch_sizes: + # Measure memory for this batch size + try: + batch, target = next(iter(train_loader)) + batch = batch[:batch_size].to(device) + target = target[:batch_size].to(device) + + torch.cuda.reset_peak_memory_stats(device) + with torch.cuda.device(device): + output = model(batch) + loss = criterion(output, target) + loss.backward() + + memory_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2) + + # Measure iteration time + import time + start = time.time() + for _ in range(10): + output = model(batch) + loss = criterion(output, target) + loss.backward() + iteration_time = (time.time() - start) / 10 + + # Calculate total training time + # Assume 100 epochs, 50k samples + iterations_per_epoch = 50000 // batch_size + total_iterations = iterations_per_epoch * 100 + total_time = total_iterations * iteration_time + + results[batch_size] = { + 'memory_mb': memory_mb, + 'iteration_time_ms': iteration_time * 1000, + 'total_time_hours': total_time / 3600, + } + + except RuntimeError as e: + results[batch_size] = {'error': str(e)} + + # Find optimal based on criteria + if target_accuracy is not None: + # Choose smallest batch that achieves target accuracy + return min(results.keys()) + elif time_budget_seconds is not None: + # Choose largest batch within time budget + valid = {bs: r for bs, r in results.items() + if 'error' not in r and r['total_time_hours'] * 3600 < time_budget_seconds} + return max(valid.keys()) if valid else None + else: + # Default: choose largest batch within 80% memory limit + memory_limit = 0.8 * torch.cuda.get_device_properties(device).total_memory / (1024**2) + valid = {bs: r for bs, r in results.items() + if 'error' not in r and r['memory_mb'] < memory_limit} + return max(valid.keys()) if valid else None + +# Batch size discovery loop +def discover_optimal_batch_size(model, train_loader, criterion, device): + """ + Progressive batch size search starting from small. + + Pattern: Double batch size until OOM, then back off. + """ + batch_size = 8 + + while True: + try: + # Try current batch size + batch, target = next(iter(train_loader)) + batch = batch[:batch_size].to(device) + target = target[:batch_size].to(device) + + output = model(batch) + loss = criterion(output, target) + loss.backward() + + print(f"✓ Batch {batch_size} works") + + # Try 2x + prev_batch = batch_size + batch_size *= 2 + + except RuntimeError as e: + if "out of memory" in str(e).lower(): + # OOM: go back to last working batch + optimal_batch = prev_batch + print(f"✗ Batch {batch_size} OOM") + print(f"→ Use batch size {optimal_batch} (safe margin)") + + # But check if we can use 1.5x + test_batch = int(optimal_batch * 1.5) + try: + batch = batch[:test_batch].to(device) + output = model(batch) + loss = criterion(output, target) + loss.backward() + print(f"✓ Batch {test_batch} also works, use this") + return test_batch + except: + return optimal_batch + else: + raise +``` + +**Batch size selection by use case:** + +```python +# Use Case 1: Maximum accuracy matters (research, publication) +# → Choose smaller batch (128-256) +# → More gradient noise = better generalization +# → Willing to train longer if accuracy is better + +optimal_batch_size = 128 + +# Use Case 2: Training speed matters (prototyping, iteration) +# → Choose larger batch (512-1024) +# → Trade some accuracy for wall-clock speed +# → Need to add regularization to reduce generalization gap + +optimal_batch_size = 512 +regularization_strength = 'strong' # weight_decay, dropout + +# Use Case 3: Memory severely constrained (mobile, edge) +# → Choose small batch (16-32) +# → Use gradient accumulation to simulate larger batch +# → Accept lower accuracy if necessary + +optimal_batch_size = 16 +accumulation_steps = 8 # Simulate batch 128 + +# Use Case 4: Fine-tuning small dataset +# → Choose small batch (16-32) +# → Preserve pre-training (smaller updates) +# → Larger batch risks forgetting pre-trained knowledge + +optimal_batch_size = 16 + +# Use Case 5: Large model, large dataset +# → Choose medium-large batch (256-512) +# → Gradient accumulation for effective larger batch +# → Mixed precision for memory savings + +optimal_batch_size = 256 +use_mixed_precision = True +use_gradient_accumulation = False # Fits with mixed precision + +# Use Case 6: Distributed training (multiple GPUs/TPUs) +# → Per-GPU batch: 32-64 +# → Accumulation: 4-8 steps +# → Total effective: per_gpu * num_gpus * accumulation +# → Large total effective batch, small per-GPU batch + +per_gpu_batch = 64 +num_gpus = 8 +accumulation_steps = 4 +effective_batch = 64 * 8 * 4 # 2048 +``` + + +## Common Pitfalls + +❌ **Pitfall 1: Confusing Maximum Batch with Optimal Batch** + +→ **Symptom**: "I have 24GB memory, so I should use the largest batch that fits" +→ **Why it breaks**: Larger batch = worse generalization. Maximum batch might achieve 2-3% lower accuracy. +→ **Fix**: Use 80% of maximum batch size, validate accuracy, adjust if needed. + +```python +# WRONG +max_batch = find_max_batch_that_fits(model, memory=24_000_000_000) +train(model, batch_size=max_batch) # Likely overfit + +# CORRECT +safe_batch = int(max_batch * 0.8) # 80% of maximum +train(model, batch_size=safe_batch) +validate_accuracy(model) # Check if acceptable +if accuracy_drop > 2%: + reduce_batch_size(safe_batch * 0.8) + add_regularization() +``` + + +❌ **Pitfall 2: Ignoring Learning Rate Scaling** + +→ **Symptom**: "I doubled my batch size, training diverges now" +→ **Why it breaks**: Gradient magnitudes decrease with larger batch, so learning rate must increase proportionally. +→ **Fix**: Use linear scaling rule: new_lr = old_lr × (new_batch / old_batch) + +```python +# WRONG +batch_size = 64 +learning_rate = 0.001 + +# Increase batch without adjusting LR +batch_size = 256 +# Learning rate still 0.001 - too small! +# Gradient updates too conservative, very slow convergence + +# CORRECT +batch_size = 64 +learning_rate = 0.001 + +batch_size = 256 +learning_rate = 0.001 * (256 / 64) # Scale by 4x +# = 0.004 +``` + + +❌ **Pitfall 3: Using Huge Learning Rate Without Warmup** + +→ **Symptom**: "I scaled my learning rate by 10x and now training diverges immediately" +→ **Why it breaks**: Very large learning rate jumps cause instability. Model can't adapt. +→ **Fix**: Add linear warmup phase: gradually increase LR from 0 to scaled value. + +```python +# WRONG +scaled_lr = 0.001 * 10 # 0.01 +optimizer = SGD(model, lr=0.01) +for epoch in range(100): + for batch in train_loader: + loss = criterion(model(batch), target) + loss.backward() + optimizer.step() # Diverges on first iteration! + +# CORRECT +base_lr = 0.001 +scaled_lr = 0.001 * 10 # 0.01 +warmup_steps = 1000 + +def lr_lambda(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) * 10 # 0 to 10x over warmup + return 1.0 # 10x after warmup + +optimizer = SGD(model, lr=base_lr) +scheduler = LambdaLR(optimizer, lr_lambda) + +for epoch in range(100): + for batch in train_loader: + loss = criterion(model(batch), target) + loss.backward() + optimizer.step() + scheduler.step() +``` + + +❌ **Pitfall 4: Gradient Accumulation Without LR Adjustment** + +→ **Symptom**: "I added gradient accumulation but training is much slower to converge" +→ **Why it breaks**: Accumulation itself doesn't require LR change, but if effective batch increased, LR should too. +→ **Fix**: Adjust LR based on effective batch size, not per-GPU batch size. + +```python +# WRONG +batch_size = 32 # Per-GPU +num_accumulation = 8 +# Learning rate still tuned for batch 32 + +# Effective batch = 32 × 8 = 256 +# But LR not scaled for batch 256 +# Convergence slower because LR too conservative + +# CORRECT +batch_size = 32 +num_accumulation = 8 +effective_batch = batch_size * num_accumulation # 256 + +# Get LR for batch 32 +base_lr_batch32 = 0.001 + +# Scale for batch 256 +lr_batch256 = base_lr_batch32 * (256 / 32) # 0.008 +optimizer = SGD(model, lr=lr_batch256) +``` + + +❌ **Pitfall 5: Assuming Batch Size Doesn't Affect Accuracy** + +→ **Symptom**: "Batch size only affects speed, not accuracy" +→ **Why it breaks**: Batch size strongly affects generalization (1-4% gap is common). +→ **Fix**: Always validate final accuracy at different batch sizes. Larger batch might need different hyperparameters. + +```python +# WRONG - assume accuracy independent of batch +batch_sizes = [64, 256, 1024] +for batch_size in batch_sizes: + model = train(learning_rate=0.001) # Same LR for all! + accuracy = evaluate(model) + # Accuracy will differ significantly! + +# CORRECT - adjust hyperparameters per batch +for batch_size in batch_sizes: + lr = 0.001 * (batch_size / 64) # Scale LR + weight_decay = 0.0001 * (batch_size / 64) ** 0.5 # Increase regularization + model = train(learning_rate=lr, weight_decay=weight_decay) + accuracy = evaluate(model) + # More consistent accuracy across batch sizes +``` + + +❌ **Pitfall 6: Not Considering Synchronous vs Asynchronous Batch Norm** + +→ **Symptom**: "My distributed training accuracy is much worse than single-GPU" +→ **Why it breaks**: Batch norm computes statistics per batch. Distributed training with small per-GPU batch = incorrect statistics. +→ **Fix**: Use SyncBatchNorm for correct statistics across all GPUs. + +```python +# WRONG - Synchronous data parallel, asynchronous BN +from torch.nn.parallel import DataParallel + +model = DataParallel(model, device_ids=[0, 1, 2, 3]) +# Each GPU has batch_size=32 +# BN computes stats from only its 32 samples +# Stats unstable, training broken + +# CORRECT - Synchronous batch norm +from torch.nn.modules.batchnorm import SyncBatchNorm + +model = SyncBatchNorm.convert_sync_batchnorm(model) +model = DistributedDataParallel(model, find_unused_parameters=False) +# Each GPU: batch 32, but BN aggregates across all 4 GPUs = 128 +# Stats computed from all 128 samples, stable +``` + + +❌ **Pitfall 7: Gradient Accumulation Too Large (>16x)** + +→ **Symptom**: "I'm accumulating gradients over 32 steps but training diverges" +→ **Why it breaks**: Large accumulation means many iterations of gradient computation before update. Gradients become stale, divergence risk. +→ **Fix**: Keep accumulation ≤ 16x. Use distributed training for larger effective batches. + +```python +# WRONG - excessive accumulation +batch_size = 4 +accumulation_steps = 32 # 128x effective batch! +# Gradients from step 1 are way out of date by step 32 +# Large variance in gradient estimates, divergence + +# CORRECT - reasonable accumulation +batch_size = 32 +accumulation_steps = 8 # 256x effective batch, acceptable +# Gradients only 8 iterations old by update time +# Variance manageable + +# OR use distributed training instead +per_gpu_batch = 32 +num_gpus = 8 +effective_batch = 32 * 8 = 256 # Same as above, but no accumulation +# Better convergence properties +``` + + +❌ **Pitfall 8: Mixing Gradient Accumulation with Exponential Moving Average (EMA)** + +→ **Symptom**: "I'm using gradient accumulation with learning rate scheduler and EMA, but training is unstable" +→ **Why it breaks**: EMA expects one update per step. With accumulation, multiple backward passes → stale momentum terms. +→ **Fix**: Update EMA only when you call optimizer.step(), not every backward pass. + +```python +# WRONG - updating EMA every backward pass +ema_model = ExponentialMovingAverage(model.parameters(), decay=0.999) + +for step, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + loss.backward() + + ema_model.update(model.parameters()) # Called every iteration! + + if (step + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# CORRECT - update EMA only on optimizer.step() +for step, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + loss.backward() + + if (step + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + ema_model.update(model.parameters()) # Only here! +``` + + +❌ **Pitfall 9: Batch Size Doubling Without Validation** + +→ **Symptom**: "I increased batch from 64 to 128 based on linear scaling rule, but accuracy dropped 2%" +→ **Why it breaks**: Linear scaling rule gives convergence rate, not accuracy guarantee. Generalization gap widens. +→ **Fix**: Always validate on holdout set when changing batch size. Accept accuracy drop or add regularization. + +```python +# WRONG - assume linear scaling guarantees accuracy +original_batch = 64 +original_lr = 0.001 +original_accuracy = 0.85 + +new_batch = 128 +new_lr = 0.001 * (128 / 64) # 0.002 +new_accuracy = 0.83 # Dropped 2%! Should have validated first + +# CORRECT - validate and adjust regularization if needed +new_batch = 128 +new_lr = 0.001 * (128 / 64) +model = train(lr=new_lr, batch=new_batch) +val_accuracy = validate(model) + +if val_accuracy < 0.84: # Acceptable drop? + # Add regularization for larger batch + model = train( + lr=new_lr, + batch=new_batch, + weight_decay=0.0002, # Increase + dropout=0.2, # Add/increase + ) + val_accuracy = validate(model) +``` + + +❌ **Pitfall 10: Using Maximum Batch in Fine-tuning** + +→ **Symptom**: "I fine-tuned with large batch size and catastrophically forgot pre-training" +→ **Why it breaks**: Large batch = large updates. Pre-trained weights overwritten too quickly. +→ **Fix**: Fine-tuning requires SMALLER batch size (32-64) and smaller learning rate than pre-training. + +```python +# WRONG - fine-tuning with large batch +pretrained_model = load_pretrained_bert() +batch_size = 512 # Large! +learning_rate = 0.001 # Too large! + +model = fine_tune(pretrained_model, batch_size=512, lr=0.001) +# Overfit to task, forget pre-trained knowledge +# Pre-training lost! + +# CORRECT - conservative fine-tuning +pretrained_model = load_pretrained_bert() +batch_size = 32 # Small, conservative +learning_rate = 0.00001 # Tiny, preserve pre-training + +model = fine_tune( + pretrained_model, + batch_size=batch_size, + lr=learning_rate, + weight_decay=0.001, # Strong L2 regularization +) +# Preserves pre-training knowledge, adapts carefully +``` + + +## Practical Decision Framework + +### Quick Batch Size Decision Tree + +``` +1. How much GPU memory do you have? + ├─ < 8 GB: Start with batch 16-32 + ├─ 8-16 GB: Start with batch 32-64 + ├─ 16-24 GB: Start with batch 64-128 + └─ 24+ GB: Start with batch 128-256 + +2. Can you fit your target batch in memory? + ├─ Yes: Use it (with LR scaling) + ├─ No, by <2x: Use gradient accumulation + └─ No, by >2x: Use smaller batch + stronger regularization + +3. Is accuracy your priority or speed? + ├─ Accuracy: Use smaller batch (32-128) + ├─ Speed: Use larger batch (256-1024) + └─ Both: Gradient accumulation + mixed precision + +4. Are you fine-tuning or training from scratch? + ├─ Fine-tuning: Use small batch (16-32), small LR + └─ From scratch: Use medium batch (64-256), scale LR + +5. Are you using distributed training? + ├─ Yes: Per-GPU batch 32-64, accumulate for effective 256-512 + └─ No: Single GPU batch 64-256 +``` + + +## Red Flags - Stop and Clarify + +| Excuse | Reality | What To Do | +|--------|---------|-----------| +| "Just use the maximum batch that fits" | Worse generalization likely. Need to validate accuracy. | Measure accuracy at 80% of max, validate trade-offs. | +| "Linear scaling rule means I don't need to validate" | Rule gives convergence rate, not accuracy guarantee. Generalization gap exists. | Always validate final accuracy with new batch size. | +| "Gradient accumulation is just for memory-constrained settings" | It's a legitimate technique with trade-offs (slowness) worth understanding. | Use when memory constrained; understand slowdown cost. | +| "Batch size only affects speed, not accuracy" | Incorrect. Batch size strongly affects final accuracy (1-4% typical gap). | Always measure accuracy, expect gap, add regularization. | +| "I'll use the batch size from a paper, it should work" | Different model, data, hardware - need to validate. | Use paper as starting point, but validate and adjust. | +| "Larger batch = faster training" | Depends on what you measure (iterations vs epochs vs wall-clock). | Measure actual wall-clock time at different batch sizes. | +| "Just double the learning rate when doubling batch" | Linear scaling rule requires warmup for large increases. | Add warmup phase, measure convergence. | +| "Fine-tuning works same as pre-training, just different data" | Fine-tuning needs much smaller batch and LR (preserve pre-training). | Use batch 16-32, LR 10-100x smaller than pre-training. | + + +## Advanced Patterns: Batch Size Optimization in Production + +### Pattern 7: Batch Size Scheduling During Training + +**Increasing batch size as training progresses - when and why:** + +```python +# Intuition: Start with small batch (good generalization), +# increase later (finish training faster) + +def get_scheduled_batch_size(epoch, total_epochs, base_batch=32, max_batch=256): + """ + Increase batch size linearly with epochs. + + WHY: Start small for generalization, increase for speed later. + Research shows this works well for long training. + """ + # Linear increase: 0 → 100% over training + scale = epoch / total_epochs + return int(base_batch + (max_batch - base_batch) * scale) + +# Usage in training loop +for epoch in range(total_epochs): + batch_size = get_scheduled_batch_size(epoch, total_epochs) + + for batch, target in get_data_loader(batch_size=batch_size): + # Adjust learning rate dynamically + lr = 0.001 * (batch_size / 32) # Scale with batch + update_learning_rate(optimizer, lr) + + output = model(batch) + loss = criterion(output, target) + loss.backward() + optimizer.step() + +# Alternative: exponential schedule +def get_exponential_batch_schedule(epoch, base_batch=32, max_batch=256): + """Exponential increase instead of linear""" + scale = (epoch / total_epochs) + return int(base_batch * (max_batch / base_batch) ** scale) +``` + +**When batch size scheduling is valuable:** + +``` +GOOD FIT: +- Long training (100+ epochs) +- Starting generalization is important +- Speed only matters at end +- Example: ResNet on ImageNet + +NOT NEEDED: +- Short training (10-20 epochs) +- Already regularized enough (BERT fine-tuning) +- Batch size well-chosen from start +``` + + +### Pattern 8: Batch Size vs Other Hyperparameters + +**Understanding interactions with other hyperparameters:** + +```python +# Interaction 1: Batch size ↔ Learning rate +# Already covered: linear scaling rule + +# Interaction 2: Batch size ↔ Weight decay +# Larger batch → worse generalization +# Solution: Increase weight decay when increasing batch +# Typical: weight_decay ~ sqrt(batch_size) + +def adjust_weight_decay(base_wd=0.0001, base_batch=256, new_batch=512): + """Scale weight decay with batch size""" + return base_wd * (new_batch / base_batch) ** 0.5 + +# Example +wd_batch_256 = 0.0001 +wd_batch_512 = adjust_weight_decay(wd_batch_256, 256, 512) # 0.000141 + +# Interaction 3: Batch size ↔ Dropout +# Larger batch → add/increase dropout +# Dropout magnitude depends on layer, typically 0.1-0.5 + +def adjust_dropout(base_dropout=0.1, base_batch=256, new_batch=512): + """Increase dropout for larger batches""" + # Dropout strength ~ sqrt(batch_size) + scale = (new_batch / base_batch) ** 0.5 + return min(base_dropout * scale, 0.5) # Cap at 0.5 + +# Interaction 4: Batch size ↔ Number of epochs +# Larger batch → more epochs needed to converge +# Typical: iterations constant ≈ samples/batch × epochs +# If batch 4x → epochs 1.5-2x to match convergence + +base_batch = 64 +base_epochs = 100 +base_iterations = (50000 / base_batch) * base_epochs # Total iterations + +new_batch = 256 +# To maintain same iterations: +new_epochs = base_iterations / (50000 / new_batch) # ~25 epochs +# Wall-clock faster (fewer iterations) but need fewer epochs + +# Interaction 5: Batch size ↔ Optimizer choice +# SGD: works well at all batch sizes +# Momentum: accumulates larger steps, works best with smaller batch +# Adam: adaptive, less sensitive to batch size +# RMSprop: similar to Adam + +# Recommendation: +# - Small batch (32-128): SGD with momentum or Adam +# - Large batch (512+): Adam (more stable) or SGD with warmup + large LR + +# Interaction 6: Batch size ↔ Normalization technique +# Batch Norm: statistics from batch, larger batch = better stats +# Layer Norm: independent of batch size +# Group Norm: middle ground, works well with any batch size + +# If using BatchNorm with small batch (< 16): +# → Use SyncBatchNorm across devices +# → Or use GroupNorm instead + +# If using BatchNorm with large batch (> 1024): +# → Standard BatchNorm fine +# → May want to reduce BN momentum (accumulate stats slower) +``` + + +## Rationalization Table: Common Excuses About Batch Size + +| Rationalization | Why It's Wrong | Correct Approach | +|---|---|---| +| "Larger batch is always better for speed" | Wall-clock time depends on iterations AND time-per-iteration. Larger batch may have lower throughput. | Profile wall-clock time at different batch sizes, choose fastest. | +| "I'll tune batch size last, it's not important" | Batch size affects convergence rate, generalization, and stability early. Tuning last wastes time. | Choose good batch size early (based on memory), validate accuracy. | +| "Maximum batch that fits = optimal batch" | Generalization gap widens with batch size (1-4% typical). Maximum might hit accuracy target. | Use 80% of max, validate on validation set, adjust if needed. | +| "Linear scaling rule means I don't validate" | Scaling rule gives convergence rate. Accuracy still varies with batch size due to generalization gap. | Always validate test/validation accuracy with new batch. | +| "Gradient accumulation is slow, don't use it" | True, it's slower (1.5-2x). But if memory is bottleneck, only alternative. Choose based on constraints. | Use when memory constrained. Accept slowdown. Don't use if memory OK. | +| "I don't need warmup, I'll just use scaled LR" | Large LR jumps cause divergence. Warmup prevents this. | Add linear warmup phase for scaled LR. | +| "My paper used batch X, I'll use that" | Different model, data, hardware converge differently. Paper batch might not be optimal for you. | Use paper as starting point. Validate and adjust for your setup. | +| "Fine-tuning uses same batch as pre-training" | Fine-tuning needs much smaller batch (preserve knowledge). Using pre-training batch erases pre-training. | Use batch 10-20x smaller than pre-training. Use tiny LR. | +| "Batch size only affects speed, not accuracy" | Batch size strongly affects generalization (1-4% gap common). Different final accuracy with different batch. | Expect accuracy variation with batch. Validate at each batch size. | +| "I increased batch, why is training slower?" | Fewer iterations (good) but longer per-iteration (bad). Total wall-clock = iterations × time-per-iteration. | Profile actual wall-clock time. May need gradient accumulation. | +| "I'll start with large batch to save memory" | Large batch → bad generalization early → harder to recover later. Start small, increase if needed. | Start with batch 32-64, increase during training if memory allows. | + + +## Comprehensive Example: Training a Vision Transformer + +Let's put it all together with a real example: + +```python +import torch +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.optim import AdamW +from torchvision import models, datasets, transforms + +def train_vision_transformer_optimized(): + """ + Complete example: training Vision Transformer with batch size optimization. + """ + + # Step 1: Model and data + device = torch.device("cuda:0") + model = models.vit_b_16(pretrained=False).to(device) + criterion = torch.nn.CrossEntropyLoss() + + # Dataset (ImageNet-scale) + transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ]) + + # Step 2: Determine batch size + # ViT-Base: 86M parameters + # GPU: 40GB A100 + # Memory estimate: params (344MB) + optimizer (688MB) + activations + # Can fit batch 256-512 + + base_batch = 256 + num_accumulation_steps = 1 # Can fit directly + effective_batch = base_batch + + # Step 3: Initialize optimizer with scaled LR + # Base LR tuned for batch 256 + base_lr = 1e-4 + scaled_lr = base_lr * (effective_batch / 256) # 1e-4 (no scaling needed) + + optimizer = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.05) + + # Step 4: Warmup scheduler + warmup_steps = 1000 + total_steps = 100 * len(dataset) // effective_batch + + def warmup_cosine_schedule(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + return 0.5 * (1.0 + torch.cos( + torch.tensor(3.14159) * + (step - warmup_steps) / (total_steps - warmup_steps) + )).item() + + scheduler = LambdaLR(optimizer, warmup_cosine_schedule) + + # Step 5: Training loop with gradient accumulation (even though not needed) + # Good practice for larger models + model.train() + optimizer.zero_grad() + + for epoch in range(100): + for step, (images, labels) in enumerate(train_loader): + images = images.to(device) + labels = labels.to(device) + + # Forward + backward + logits = model(images) + loss = criterion(logits, labels) + loss = loss / num_accumulation_steps + loss.backward() + + # Update on accumulation step + if (step + 1) % num_accumulation_steps == 0: + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + if step % 100 == 0: + print(f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, " + f"lr {optimizer.param_groups[0]['lr']:.2e}") + + # Validate every epoch + val_accuracy = validate(model, device) + print(f"Epoch {epoch} validation accuracy: {val_accuracy:.2%}") + + return model + +# Key patterns demonstrated: +# 1. Batch size chosen based on memory (80% of max) +# 2. Learning rate scaled for batch size +# 3. Warmup phase for gradual LR increase +# 4. Cosine annealing for LR decay +# 5. Gradient accumulation structure (even if not needed) +# 6. Gradient clipping for stability +# 7. Regular validation to monitor accuracy +``` + + +## Summary: Batch Size and Memory Decision Making + +**The core principle:** Batch size is a system design choice affecting convergence, generalization, speed, and memory simultaneously. There is no universal "right" batch size - it depends on your constraints and priorities. + +**The decision process:** + +1. **Memory constraint**: Start with 80% of maximum batch +2. **Convergence**: Scale learning rate 1:1 with batch increase (with warmup) +3. **Generalization**: Validate accuracy, reduce if gap >2% (or add regularization) +4. **Performance**: Profile wall-clock time at different batch sizes +5. **Architecture**: Different models have different optimal batches + +**The key insights:** + +- Larger batch = faster iterations but worse generalization +- Linear scaling rule requires warmup for large increases +- Gradient accumulation is a legitimate technique (understand slowdown cost) +- Fine-tuning requires smaller batch than pre-training +- Distributed training needs care with batch norm and gradient updates +- Always measure, validate, and adjust - don't assume rules apply to your case + +**The testing approach:** + +When pressure-tested, this skill should: +- Explain why maximum batch ≠ optimal batch (generalization gap) +- Provide concrete examples of linear scaling rule with warmup +- Address gradient accumulation systematically (when, why, cost) +- Discuss memory estimation and optimization techniques +- Help select batch size based on constraints AND priorities +- Resist rationalizations and always recommend validation + + +## References and Further Reading + +**Key papers:** +- Goyal et al. (2017) "Accurate, Large Batch Training" - Linear scaling rule +- You et al. (2019) "Large Batch Optimization for Deep Learning" - Theory +- Smith et al. (2017) "Don't Decay the Learning Rate" - Learning rate schedules + +**Techniques mentioned:** +- Batch Normalization: Ioffe & Szegedy (2015) +- Layer Normalization: Ba et al. (2016) +- Mixed Precision Training: Micikevicius et al. (2017) +- Gradient Checkpointing: Chen et al. (2016) + +**Related Yzmir Skills:** +- `learning-rate-scheduling` - LR schedule choices beyond linear scaling +- `gradient-management` - Gradient clipping and accumulation for stability +- `optimization-algorithms` - Optimizer selection and hyperparameter tuning + diff --git a/skills/using-training-optimization/data-augmentation-strategies.md b/skills/using-training-optimization/data-augmentation-strategies.md new file mode 100644 index 0000000..91bed2d --- /dev/null +++ b/skills/using-training-optimization/data-augmentation-strategies.md @@ -0,0 +1,1483 @@ + +# Data Augmentation Strategies + +## Overview + +Data augmentation artificially increases training data diversity by applying transformations that preserve labels. This is one of the most cost-effective ways to improve model robustness and reduce overfitting, but it requires domain knowledge and careful strength tuning. + +**Core Principle**: Augmentation is NOT a universal technique. The right augmentations depend on your domain, task, data distribution, and model capacity. Wrong augmentations can hurt more than help. + +**Critical Rule**: Augment ONLY training data. Validation and test data must remain unaugmented to provide accurate performance estimates. + +**Why Augmentation Matters**: +- Creates label-preserving variations, teaching invariance +- Reduces overfitting by preventing memorization +- Improves robustness to distribution shift +- Essentially "free" data—no labeling cost +- Can outperform adding more labeled data in some domains + + +## When to Use This Skill + +Load this skill when: +- Training on limited dataset (< 10,000 examples) and seeing overfitting +- Addressing distribution shift or robustness concerns +- Selecting augmentations for vision, NLP, audio, or tabular tasks +- Designing augmentation pipelines and strength tuning +- Troubleshooting training issues (accuracy drop with augmentation) +- Implementing test-time augmentation (TTA) or augmentation policies +- Choosing between weak augmentation (100% prob) vs strong (lower prob) + +**Don't use for**: General training debugging (use using-training-optimization), optimization algorithm selection (use optimization-algorithms), regularization without domain context (augmentation is domain-specific) + + +## Part 1: Augmentation Decision Framework + +### The Core Question: "When should I augment?" + +**WRONG ANSWER**: "Use augmentation for all datasets." + +**RIGHT APPROACH**: Use this decision framework. + +### Clarifying Questions + +1. **"How much training data do you have?"** + - < 1,000 examples → Strong augmentation needed + - 1,000-10,000 examples → Medium augmentation + - 10,000-100,000 examples → Light augmentation often sufficient + - > 100,000 examples → Augmentation helps but not critical + - Rule: Smaller dataset = more aggressive augmentation + +2. **"What's your train/validation accuracy gap?"** + - Train 90%, val 70% (20% gap) → Overfitting, augmentation will help + - Train 85%, val 83% (2% gap) → Well-regularized, augmentation optional + - Train 60%, val 58% (2% gap) → Underfitting, augmentation won't help (need more capacity) + - Rule: Large gap indicates augmentation will help + +3. **"How much distribution shift is expected at test time?"** + - Same domain, clean images → Light augmentation (rotation ±15°, crop 90%, brightness ±10%) + - Real-world conditions → Medium augmentation (rotation ±30°, crop 75%, brightness ±20%) + - Extreme conditions (weather, blur) → Strong augmentation + robust architectures + - Rule: Augment for expected shift, not beyond + +4. **"What's your domain?"** + - Vision → Rich augmentation toolkit available + - NLP → Limited augmentations (preserve syntax/semantics) + - Audio → Time/frequency domain transforms + - Tabular → SMOTE, feature dropout, noise injection + - Rule: Domain determines augmentation types + +5. **"Do you have compute budget for increased training time?"** + - Yes → Stronger augmentation possible + - No → Lighter augmentation to save training time + - Rule: Online augmentation adds ~10-20% training time + +### Decision Tree + +``` +START: Should I augment? + +├─ Is your training data < 10,000 examples? +│ ├─ YES → Augmentation will likely help. Go to Part 2 (domain selection). +│ │ +│ └─ NO → Check train/validation gap... + +├─ Is your train-validation accuracy gap > 10%? +│ ├─ YES → Augmentation will likely help. Go to Part 2. +│ │ +│ └─ NO → Continue... + +├─ Are you in a domain where distribution shift is expected? +│ │ (medical imaging varies by scanner, autonomous driving weather varies, +│ │ satellite imagery has seasonal changes, etc.) +│ ├─ YES → Augmentation will help. Go to Part 2. +│ │ +│ └─ NO → Continue... + +├─ Do you have compute budget for 10-20% extra training time? +│ ├─ YES, but data is ample → Optional: light augmentation helps margins +│ │ May improve generalization even with large data. +│ │ +│ └─ NO → Skip augmentation or use very light augmentation. + +└─ DEFAULT: Apply light-to-medium augmentation for target domain. + Start with conservative parameters. + Measure impact before increasing strength. +``` + + +## Part 2: Domain-Specific Augmentation Catalogs + +### Vision Augmentations (Image Classification, Detection, Segmentation) + +**Key Principle**: Preserve semantic content while varying appearance and geometry. + +#### Geometric Transforms (Preserve Class) + +**Rotation**: +```python +from torchvision import transforms +transform = transforms.RandomRotation(degrees=15) +# ±15° for most tasks (natural objects rotate ±15°) +# ±30° for synthetic/manufactured objects +# ±45° for symmetric objects (digits, logos) +# Avoid: ±180° (completely unrecognizable) +``` + +**When to use**: All vision tasks. Rotation-invariance is common. + +**Strength tuning**: +- Light: ±5° to ±15° (most conservative) +- Medium: ±15° to ±30° +- Strong: ±30° to ±45° (only for symmetric classes) +- Never: ±180° (makes label ambiguous) + +**Domain exceptions**: +- Medical imaging: ±10° maximum (anatomy is not rotation-invariant) +- Satellite: ±5° maximum (geographic north is meaningful) +- Handwriting: ±15° okay (natural variation) +- OCR: ±10° maximum (upside-down is different class) + + +**Crop (Random Crop + Resize)**: +```python +transform = transforms.RandomResizedCrop(224, scale=(0.8, 1.0)) +# Crops 80-100% of original, resizes to 224x224 +# Teaches invariance to framing and zoom +``` + +**When to use**: Classification, detection (with care), segmentation. + +**Strength tuning**: +- Light: scale=(0.9, 1.0) - crop 90-100% +- Medium: scale=(0.8, 1.0) - crop 80-100% +- Strong: scale=(0.5, 1.0) - crop 50-100% (can lose important features) + +**Domain considerations**: +- Detection: Minimum scale should keep objects ≥50px +- Segmentation: Crops must preserve mask validity +- Medical: Center-biased crops (avoid cutting off pathology) + + +**Horizontal Flip**: +```python +transform = transforms.RandomHorizontalFlip(p=0.5) +# Mirrors image left-right +``` + +**When to use**: Most vision tasks WHERE LEFT-RIGHT SYMMETRY IS NATURAL. + +**CRITICAL EXCEPTION**: +- ❌ Medical imaging (L/R markers mean something) +- ❌ Text/documents (flipped text is unreadable) +- ❌ Objects with semantic left/right (cars facing direction) +- ❌ Faces (though some datasets use it) + +**Safe domains**: +- ✅ Natural scene classification +- ✅ Animal classification (except directional animals) +- ✅ Generic object detection (not vehicles) + + +**Vertical Flip** (Use Rarely): +```python +transform = transforms.RandomVerticalFlip(p=0.5) +``` + +**VERY LIMITED USE**: Most natural objects are not up-down symmetric. +- ❌ Most natural images (horizon has direction) +- ❌ Medical imaging (anatomical direction matters) +- ✅ Texture classification (some textures rotationally symmetric) + + +**Perspective Transform (Affine)**: +```python +transform = transforms.RandomAffine( + degrees=0, + translate=(0.1, 0.1), # ±10% translation + scale=(0.9, 1.1), # ±10% scaling + shear=(-15, 15) # ±15° shear +) +``` + +**When to use**: Scene understanding, 3D object detection, autonomous driving. + +**Caution**: Shear and extreme perspective can make images unrecognizable. Use conservatively. + + +#### Color and Brightness Transforms (Appearance Variance) + +**Color Jitter**: +```python +transform = transforms.ColorJitter( + brightness=0.2, # ±20% brightness + contrast=0.2, # ±20% contrast + saturation=0.2, # ±20% saturation + hue=0.1 # ±10% hue shift +) +``` + +**When to use**: All vision tasks (teaches color-invariance). + +**Strength tuning**: +- Light: brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05 +- Medium: brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1 +- Strong: brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3 + +**Domain exceptions**: +- Medical imaging: brightness/contrast only (color is artificial) +- Satellite: All channels safe (handles weather/season) +- Thermal imaging: Only brightness meaningful + + +**Gaussian Blur**: +```python +from torchvision.transforms.functional import gaussian_blur +transform = transforms.GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 2.0)) +``` + +**When to use**: Makes model robust to soft focus, mimics unfocused camera. + +**Strength tuning**: +- Light: sigma=(0.1, 0.5) +- Medium: sigma=(0.1, 1.0) +- Strong: sigma=(0.5, 2.0) + +**Domain consideration**: Don't blur medical/satellite (loses diagnostic/geographic detail). + + +**Grayscale**: +```python +transform = transforms.Grayscale(p=0.2) # 20% probability +``` + +**When to use**: When color information is redundant or unreliable. + +**Domain exceptions**: +- Medical imaging: Apply selectively (preserve when color is diagnostic) +- Satellite: Don't apply (multi-spectral bands are essential) +- Natural scene: Safe to apply + + +#### Mixing Augmentations (Mixup, Cutmix, Cutout) + +**Mixup**: Linear interpolation of images and labels + +```python +def mixup(x, y, alpha=1.0): + """Mixup augmentation: blend two images and labels.""" + batch_size = x.size(0) + index = torch.randperm(batch_size) + + lam = np.random.beta(alpha, alpha) # Sample mixing ratio + mixed_x = lam * x + (1 - lam) * x[index] + y_a, y_b = y, y[index] + + return mixed_x, y_a, y_b, lam + +# Use with soft labels during training: +# loss = lam * loss_fn(pred, y_a) + (1-lam) * loss_fn(pred, y_b) +``` + +**When to use**: All image classification tasks. + +**Strength tuning**: +- Light: alpha=2.0 (blends close to original) +- Medium: alpha=1.0 (uniform blending) +- Strong: alpha=0.2 (extreme blends) + +**Effectiveness**: One of the best modern augmentations, ~1-2% accuracy improvement typical. + + +**Cutmix**: Replace rectangular region with another image + +```python +def cutmix(x, y, alpha=1.0): + """CutMix augmentation: replace rectangular patch.""" + batch_size = x.size(0) + index = torch.randperm(batch_size) + + lam = np.random.beta(alpha, alpha) + height, width = x.size(2), x.size(3) + + # Sample patch coordinates + cut_ratio = np.sqrt(1.0 - lam) + cut_h = int(height * cut_ratio) + cut_w = int(width * cut_ratio) + + cx = np.random.randint(0, width) + cy = np.random.randint(0, height) + + bbx1 = np.clip(cx - cut_w // 2, 0, width) + bby1 = np.clip(cy - cut_h // 2, 0, height) + bbx2 = np.clip(cx + cut_w // 2, 0, width) + bby2 = np.clip(cy + cut_h // 2, 0, height) + + x[index, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2] + + lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1)) / (height * width) + + return x, y, y[index], lam +``` + +**When to use**: Image classification (especially effective). + +**Advantage over Mixup**: Preserves spatial structure better, more realistic. + +**Typical improvement**: 1-3% accuracy increase. + + +**Cutout**: Remove rectangular patch (fill with zero/mean) + +```python +def cutout(x, patch_size=32, p=0.5): + """Cutout: remove rectangular region.""" + if np.random.rand() > p: + return x + + batch_size, _, height, width = x.size() + + for i in range(batch_size): + cx = np.random.randint(0, width) + cy = np.random.randint(0, height) + + x1 = np.clip(cx - patch_size // 2, 0, width) + y1 = np.clip(cy - patch_size // 2, 0, height) + x2 = np.clip(cx + patch_size // 2, 0, width) + y2 = np.clip(cy + patch_size // 2, 0, height) + + x[i, :, y1:y2, x1:x2] = 0 + + return x +``` + +**When to use**: Regularization effect, teaches local invariance. + +**Typical improvement**: 0.5-1% accuracy increase. + + +#### AutoAugment and Learned Policies + +**RandAugment**: Random selection from augmentation space + +```python +from torchvision.transforms import RandAugment + +transform = RandAugment(num_ops=2, magnitude=9) +# Apply 2 random augmentations from 14 operation space +# Magnitude 0-30 controls strength +``` + +**When to use**: When unsure about augmentation selection. + +**Advantage**: Removes manual hyperparameter tuning. + +**Typical improvement**: 1-2% accuracy compared to manual selection. + + +**AutoAugment**: Data-dependent learned policy + +```python +from torchvision.transforms import AutoAugment, AutoAugmentPolicy + +transform = AutoAugment(AutoAugmentPolicy.IMAGENET) +# Predefined policy for ImageNet-like tasks +# Policies: IMAGENET, CIFAR10, SVHN +``` + +**Pre-trained policies**: +- IMAGENET: General-purpose, vision tasks +- CIFAR10: Smaller images (32x32), high regularization +- SVHN: Street view house numbers + +**Typical improvement**: 0.5-1% accuracy. + + +### NLP Augmentations (Text Classification, QA, Generation) + +**Key Principle**: Preserve meaning while varying surface form. Syntax and semantics must be preserved. + +#### Rule-Based Augmentations + +**Back-Translation**: +```python +def back_translate(text: str, src_lang='en', inter_lang='fr') -> str: + """Translate to intermediate language and back to create paraphrase.""" + # English -> French -> English + # Example: "The cat sat on mat" -> "Le chat s'assit sur le tapis" -> "The cat sat on the mat" + + # Use library like transformers or marian-mt + from transformers import MarianMTModel, MarianTokenizer + + # Translate en->fr + model_name = f"Helsinki-NLP/Opus-MT-{src_lang}-{inter_lang}" + tokenizer = MarianTokenizer.from_pretrained(model_name) + model = MarianMTModel.from_pretrained(model_name) + + inputs = tokenizer(text, return_tensors="pt") + outputs = model.generate(**inputs) + intermediate = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + + # Translate fr->en + model_name_back = f"Helsinki-NLP/Opus-MT-{inter_lang}-{src_lang}" + tokenizer_back = MarianTokenizer.from_pretrained(model_name_back) + model_back = MarianMTModel.from_pretrained(model_name_back) + + inputs_back = tokenizer_back(intermediate, return_tensors="pt") + outputs_back = model_back.generate(**inputs_back) + result = tokenizer_back.batch_decode(outputs_back, skip_special_tokens=True)[0] + + return result +``` + +**When to use**: Text classification, sentiment analysis, intent detection. + +**Strength tuning**: +- Use 1-2 intermediate languages +- Probability 0.3-0.5 (paraphrases, not all data) + +**Advantage**: Creates natural paraphrases. + +**Disadvantage**: Slow (requires neural translation model). + + +**Synonym Replacement (EDA)**: +```python +import nltk +from nltk.corpus import wordnet + +def synonym_replacement(text: str, n=2): + """Replace n random words with synonyms.""" + words = text.split() + new_words = words.copy() + + random_word_list = list(set([word for word in words if word.isalnum()])) + random.shuffle(random_word_list) + + num_replaced = 0 + for random_word in random_word_list: + synonyms = get_synonyms(random_word) + if len(synonyms) > 0: + synonym = random.choice(synonyms) + new_words = [synonym if word == random_word else word for word in new_words] + num_replaced += 1 + if num_replaced >= n: + break + + return ' '.join(new_words) + +def get_synonyms(word): + """Find synonyms using WordNet.""" + synonyms = set() + for syn in wordnet.synsets(word): + for lemma in syn.lemmas(): + synonyms.add(lemma.name()) + return list(synonyms - {word}) +``` + +**When to use**: Text classification, low-resource languages. + +**Strength tuning**: +- n=1-3 synonyms per sentence +- Probability 0.5 (replace in half of training data) + +**Typical improvement**: 1-2% for small datasets. + + +**Random Insertion**: +```python +def random_insertion(text: str, n=2): + """Insert n random synonyms of random words.""" + words = text.split() + new_words = words.copy() + + for _ in range(n): + add_word(new_words) + + return ' '.join(new_words) + +def add_word(new_words): + synonyms = [] + counter = 0 + while len(synonyms) < 1: + if counter >= 10: + return + random_word = new_words[random.randint(0, len(new_words)-1)] + synonyms = get_synonyms(random_word) + counter += 1 + + random_synonym = synonyms[random.randint(0, len(synonyms)-1)] + random_idx = random.randint(0, len(new_words)-1) + new_words.insert(random_idx, random_synonym) +``` + +**When to use**: Text classification, paraphrase detection. + + +**Random Swap**: +```python +def random_swap(text: str, n=2): + """Randomly swap positions of n word pairs.""" + words = text.split() + new_words = words.copy() + + for _ in range(n): + new_words = swap_word(new_words) + + return ' '.join(new_words) + +def swap_word(new_words): + random_idx_1 = random.randint(0, len(new_words)-1) + random_idx_2 = random_idx_1 + + counter = 0 + while random_idx_2 == random_idx_1: + random_idx_2 = random.randint(0, len(new_words)-1) + counter += 1 + if counter > 3: + return new_words + + new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] + return new_words +``` + +**When to use**: Robustness to word order variations. + + +**Random Deletion**: +```python +def random_deletion(text: str, p=0.2): + """Randomly delete words with probability p.""" + if len(text.split()) == 1: + return text + + words = text.split() + new_words = [word for word in words if random.uniform(0, 1) > p] + + if len(new_words) == 0: + return random.choice(words) + + return ' '.join(new_words) +``` + +**When to use**: Robustness to missing/incomplete input. + + +#### Sentence-Level Augmentations + +**Paraphrase Generation**: +```python +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +def paraphrase(text: str): + """Generate paraphrase using pretrained model.""" + model_name = "Vamsi/T5_Paraphrase_Paws" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + + input_ids = tokenizer.encode(text, return_tensors="pt") + outputs = model.generate(input_ids) + paraphrase = tokenizer.decode(outputs[0], skip_special_tokens=True) + + return paraphrase +``` + +**When to use**: Text classification with limited data. + +**Advantage**: High-quality semantic paraphrases. + +**Disadvantage**: Model-dependent, can be slow. + + +### Audio Augmentations (Speech Recognition, Music) + +**Key Principle**: Preserve content while varying acoustic conditions. + +**Pitch Shift**: +```python +import librosa +import numpy as np + +def pitch_shift(waveform: np.ndarray, sr: int, steps: int): + """Shift pitch without changing speed.""" + # Shift by ±2-4 semitones typical + return librosa.effects.pitch_shift(waveform, sr=sr, n_steps=steps) + +# Usage: +audio, sr = librosa.load('audio.wav') +augmented = pitch_shift(audio, sr, steps=np.random.randint(-4, 5)) +``` + +**When to use**: Speech recognition (speaker variation). + +**Strength tuning**: +- Light: ±2 semitones +- Medium: ±4 semitones +- Strong: ±8 semitones (avoid, changes phone identity) + + +**Time Stretching**: +```python +def time_stretch(waveform: np.ndarray, rate: float): + """Speed up/slow down without changing pitch.""" + return librosa.effects.time_stretch(waveform, rate=rate) + +# Usage: +augmented = time_stretch(audio, rate=np.random.uniform(0.9, 1.1)) # ±10% speed +``` + +**When to use**: Speech recognition (speech rate variation). + +**Strength tuning**: +- Light: 0.95-1.05 (±5% speed) +- Medium: 0.9-1.1 (±10% speed) +- Strong: 0.8-1.2 (±20% speed, too aggressive) + + +**Background Noise Injection**: +```python +def add_background_noise(waveform: np.ndarray, noise: np.ndarray, snr_db: float): + """Add noise at specified SNR (signal-to-noise ratio).""" + signal_power = np.mean(waveform ** 2) + snr_linear = 10 ** (snr_db / 10) + noise_power = signal_power / snr_linear + + noise_scaled = noise * np.sqrt(noise_power / np.mean(noise ** 2)) + + # Mix only first len(waveform) samples of noise + augmented = waveform + noise_scaled[:len(waveform)] + return np.clip(augmented, -1, 1) # Prevent clipping + +# Usage: +noise, _ = librosa.load('background_noise.wav', sr=sr) +augmented = add_background_noise(audio, noise, snr_db=np.random.uniform(15, 30)) +``` + +**When to use**: Speech recognition, robustness to noisy environments. + +**Strength tuning**: +- Light: SNR 30-40 dB (minimal noise) +- Medium: SNR 20-30 dB (moderate noise) +- Strong: SNR 10-20 dB (very noisy, challenging) + + +**SpecAugment**: Augmentation in spectrogram space + +```python +def spec_augment(mel_spec: np.ndarray, freq_mask_width: int, time_mask_width: int): + """Apply frequency and time masking to mel-spectrogram.""" + freq_axis_size = mel_spec.shape[0] + time_axis_size = mel_spec.shape[1] + + # Frequency masking + f0 = np.random.randint(0, freq_axis_size - freq_mask_width) + mel_spec[f0:f0+freq_mask_width, :] = 0 + + # Time masking + t0 = np.random.randint(0, time_axis_size - time_mask_width) + mel_spec[:, t0:t0+time_mask_width] = 0 + + return mel_spec + +# Usage: +mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr) +augmented = spec_augment(mel_spec, freq_mask_width=30, time_mask_width=40) +``` + +**When to use**: Speech recognition (standard for ASR). + + +### Tabular Augmentations (Regression, Classification on Structured Data) + +**Key Principle**: Preserve relationships between features while adding noise/variation. + +**SMOTE (Synthetic Minority Over-sampling)**: +```python +from imblearn.over_sampling import SMOTE + +# Balance imbalanced classification +X_train = your_features # shape: (n_samples, n_features) +y_train = your_labels + +smote = SMOTE(random_state=42) +X_resampled, y_resampled = smote.fit_resample(X_train, y_train) + +# Now X_resampled has balanced classes with synthetic minority examples +``` + +**When to use**: Imbalanced classification (rare class oversampling). + +**Advantage**: Addresses class imbalance by creating synthetic examples. + + +**Feature-wise Noise Injection**: +```python +def add_noise_to_features(X: np.ndarray, noise_std: float): + """Add Gaussian noise to features (percentage of feature std).""" + noise = np.random.normal(0, noise_std, X.shape) + # Scale noise to percentage of feature std + feature_stds = np.std(X, axis=0) + scaled_noise = noise * (feature_stds * noise_std) + return X + scaled_noise +``` + +**When to use**: Robustness to measurement noise. + +**Strength tuning**: +- Light: noise_std=0.01 (1% of feature std) +- Medium: noise_std=0.05 (5% of feature std) +- Strong: noise_std=0.1 (10% of feature std) + + +**Feature Dropout**: +```python +def feature_dropout(X: np.ndarray, p: float): + """Randomly set features to zero.""" + mask = np.random.binomial(1, 1-p, X.shape) + return X * mask +``` + +**When to use**: Robustness to missing/unavailable features. + +**Strength tuning**: +- p=0.1 (drop 10% of features) +- p=0.2 (drop 20%) +- Avoid p>0.3 (too much information loss) + + +**Mixup for Tabular Data**: +```python +def mixup_tabular(X: np.ndarray, y: np.ndarray, alpha: float = 1.0): + """Apply mixup to tabular features.""" + batch_size = X.shape[0] + index = np.random.permutation(batch_size) + lam = np.random.beta(alpha, alpha) + + X_mixed = lam * X + (1 - lam) * X[index] + y_a, y_b = y, y[index] + + return X_mixed, y_a, y_b, lam +``` + +**When to use**: Regression and classification on tabular data. + + +## Part 3: Augmentation Strength Tuning + +### Conservative vs Aggressive Augmentation + +**Principle**: Start conservative, increase gradually. Test impact. + +#### Weak Augmentation (100% probability) + +Apply light augmentation to ALL training data, EVERY epoch. + +```python +weak_augmentation = transforms.Compose([ + transforms.RandomRotation(degrees=10), + transforms.ColorJitter(brightness=0.1, contrast=0.1), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), +]) +``` + +**Typical improvement**: +1-2% accuracy. + +**Pros**: +- Consistent, no randomness in augmentation strength +- Easier to reproduce +- Less prone to catastrophic augmentation + +**Cons**: +- Each image same number of times +- Less diversity per image + + +#### Strong Augmentation (Lower Probability) + +Apply strong augmentations with 30-50% probability. + +```python +strong_augmentation = transforms.Compose([ + transforms.RandomRotation(degrees=45), + transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3), + transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), shear=(15, 15)), + transforms.RandomPerspective(distortion_scale=0.3), +]) + +class StrongAugmentationWrapper: + def __init__(self, transform, p=0.3): + self.transform = transform + self.p = p + + def __call__(self, x): + if np.random.rand() < self.p: + return self.transform(x) + return x + +aug_wrapper = StrongAugmentationWrapper(strong_augmentation, p=0.3) +``` + +**Typical improvement**: +2-3% accuracy. + +**Pros**: +- More diversity +- Better robustness to extreme conditions + +**Cons**: +- Risk of too-aggressive augmentation +- Requires careful strength tuning + + +### Finding Optimal Strength + +**Algorithm**: + +1. Start with weak augmentation (parameters at 50% of expected range) +2. Train for 1 epoch, measure validation accuracy +3. Keep weak augmentation for full training +4. Increase strength by 25% and retrain +5. Compare final accuracies +6. If accuracy improved, increase further; if hurt, decrease +7. Stop when accuracy plateaus or decreases + +**Example**: + +```python +# Start: rotation ±10°, brightness ±0.1 +# After test 1: accuracy improves, try rotation ±15°, brightness ±0.15 +# After test 2: accuracy improves, try rotation ±20°, brightness ±0.2 +# After test 3: accuracy decreases, revert to rotation ±15°, brightness ±0.15 +``` + + +## Part 4: Test-Time Augmentation (TTA) + +**Definition**: Apply augmentation at inference time, average predictions. + +```python +def predict_with_tta(model, image, num_augmentations=8): + """Make predictions with test-time augmentation.""" + predictions = [] + + for _ in range(num_augmentations): + # Apply light augmentation + augmented = augmentation(image) + with torch.no_grad(): + pred = model(augmented.unsqueeze(0)) + predictions.append(pred.softmax(dim=1)) + + # Average predictions + final_pred = torch.stack(predictions).mean(dim=0) + return final_pred +``` + +**When to use**: +- Final evaluation (test set submission) +- Robustness testing +- Post-training calibration + +**Don't use for**: +- Validation (metrics must reflect single-pass performance) +- Production inference (too slow, accuracy not worth inference latency) + +**Typical improvement**: +0.5-1% accuracy. + +**Computational cost**: 8-10x slower inference. + + +## Part 5: Common Pitfalls and Rationalization + +### Pitfall 1: Augmenting Validation/Test Data + +**Symptom**: Validation accuracy inflated, test performance poor. + +**User Says**: "More diversity helps, so augment everywhere" + +**Why It Fails**: Validation measures true performance on ORIGINAL data, not augmented. + +**Fix**: +```python +# WRONG: +val_transform = transforms.Compose([ + transforms.RandomRotation(20), + transforms.ToTensor(), +]) + +# RIGHT: +val_transform = transforms.Compose([ + transforms.ToTensor(), +]) +``` + + +### Pitfall 2: Over-Augmentation (Unrecognizable Images) + +**Symptom**: Training loss doesn't decrease, accuracy worse with augmentation. + +**User Says**: "More augmentation = more robustness" + +**Why It Fails**: If image unrecognizable, model cannot learn the class. + +**Fix**: Start conservative. Test incrementally. + + +### Pitfall 3: Wrong Domain Augmentations + +**Symptom**: Accuracy drops with augmentation. + +**User Says**: "These augmentations work for images, why not text?" + +**Why It Fails**: Flipped text is unreadable. Domain-specific invariances differ. + +**Fix**: Use augmentations designed for your domain. + + +### Pitfall 4: Augmentation Inconsistency Across Train/Val + +**Symptom**: Model overfits, ignores augmentation benefit. + +**User Says**: "I normalize images, so different augmentation pipelines okay" + +**Why It Fails**: Train augmentation must be intentional; val must not have it. + +**Fix**: Explicitly separate training and validation transforms. + + +### Pitfall 5: Ignoring Label Semantics + +**Symptom**: Model predicts wrong class after augmentation. + +**User Says**: "The label is preserved, so any transformation okay" + +**Why It Fails**: Extreme transformations obscure discriminative features. + +**Example**: Medical image rotated 180° may have artifacts that change diagnosis. + +**Fix**: Consider label semantics, not just label preservation. + + +### Pitfall 6: No Augmentation on Small Dataset + +**Symptom**: Severe overfitting, poor generalization. + +**User Says**: "My data is unique, standard augmentations won't help" + +**Why It Fails**: Overfitting still happens, augmentation reduces it. + +**Fix**: Use domain-appropriate augmentations even on small datasets. + + +### Pitfall 7: Augmentation Not Reproducible + +**Symptom**: Different training runs give different results. + +**User Says**: "Random augmentation is fine, natural variation" + +**Why It Fails**: Makes debugging impossible, non-reproducible research. + +**Fix**: Set random seeds for reproducible augmentation. + +```python +import random +import numpy as np +import torch + +random.seed(42) +np.random.seed(42) +torch.manual_seed(42) +``` + + +### Pitfall 8: Using One Augmentation Policy for All Tasks + +**Symptom**: Augmentation works for classification, hurts for detection. + +**User Says**: "Augmentation is general, works everywhere" + +**Why It Fails**: Detection needs different augmentations (preserve boxes). + +**Fix**: Domain AND task-specific augmentation selection. + + +### Pitfall 9: Augmentation Overhead Too High + +**Symptom**: Training 2x slower, minimal accuracy improvement. + +**User Says**: "Augmentation is worth the overhead" + +**Why It Fails**: Sometimes it is, sometimes not. Measure impact. + +**Fix**: Profile training time. Balance overhead vs accuracy gain. + + +### Pitfall 10: Mixing Incompatible Augmentations + +**Symptom**: Unexpected behavior, degraded performance. + +**User Says**: "Combining augmentations = better diversity" + +**Why It Fails**: Some augmentations conflict or overlap. + +**Example**: CutMix + random crop can create strange patches. + +**Fix**: Design augmentation pipelines carefully, test combinations. + + +## Part 6: Augmentation Policy Design + +### Step-by-Step Augmentation Design + +**Step 1: Identify invariances in your domain** + +What transformations preserve the class label? + +- Vision: Rotation ±15° (natural), flip (depends), color jitter (yes) +- Text: Synonym replacement (yes), flip sentence (no) +- Audio: Pitch shift ±4 semitones (yes), time stretch ±20% (yes) +- Tabular: Feature noise (yes), feature permutation (no) + +**Step 2: Select weak augmentations** + +Choose conservative parameters. + +```python +weak_aug = transforms.Compose([ + transforms.RandomRotation(degrees=15), + transforms.ColorJitter(brightness=0.1), +]) +``` + +**Step 3: Measure impact** + +Train with/without augmentation, compare validation accuracy. + +```python +# Without augmentation +model_no_aug = train(no_aug_transforms, epochs=10) +val_acc_no_aug = evaluate(model_no_aug, val_loader) + +# With weak augmentation +model_weak_aug = train(weak_aug, epochs=10) +val_acc_weak_aug = evaluate(model_weak_aug, val_loader) + +print(f"Without augmentation: {val_acc_no_aug}") +print(f"With weak augmentation: {val_acc_weak_aug}") +``` + +**Step 4: Increase gradually if beneficial** + +If augmentation helped, increase strength 25%. + +```python +medium_aug = transforms.Compose([ + transforms.RandomRotation(degrees=20), # ±20° vs ±15° + transforms.ColorJitter(brightness=0.15), # 0.15 vs 0.1 +]) + +model_medium = train(medium_aug, epochs=10) +val_acc_medium = evaluate(model_medium, val_loader) +``` + +**Step 5: Stop when improvement plateaus** + +When accuracy no longer improves, use previous best parameters. + + +### Augmentation for Different Dataset Sizes + +**< 1,000 examples**: Heavy augmentation needed +```python +heavy_aug = transforms.Compose([ + transforms.RandomRotation(degrees=30), + transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), + transforms.ColorJitter(brightness=0.3, contrast=0.3), + transforms.RandomAffine(degrees=0, shear=15), + transforms.RandomHorizontalFlip(p=0.5), +]) +``` + +**1,000-10,000 examples**: Medium augmentation +```python +medium_aug = transforms.Compose([ + transforms.RandomRotation(degrees=15), + transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + transforms.RandomHorizontalFlip(p=0.5), +]) +``` + +**10,000-100,000 examples**: Light augmentation +```python +light_aug = transforms.Compose([ + transforms.RandomRotation(degrees=10), + transforms.ColorJitter(brightness=0.1), + transforms.RandomHorizontalFlip(p=0.3), +]) +``` + +**> 100,000 examples**: Minimal augmentation (optional) +```python +minimal_aug = transforms.Compose([ + transforms.ColorJitter(brightness=0.05), +]) +``` + + +## Part 7: Augmentation Composition Strategies + +### Sequential vs Compound Augmentation + +**Sequential** (Apply transforms in sequence, each has independent probability): + +```python +# Sequential: each transform independent +sequential = transforms.Compose([ + transforms.RandomRotation(degrees=15), # 100% probability + transforms.ColorJitter(brightness=0.2), # 100% probability + transforms.RandomHorizontalFlip(p=0.5), # 50% probability +]) +# Result: Always rotate and color jitter, sometimes flip +# Most common approach +``` + +**Compound** (Random selection of augmentation combinations): + +```python +# Compound: choose one from alternatives +def compound_augmentation(image): + choice = np.random.choice(['light', 'medium', 'heavy']) + + if choice == 'light': + return light_aug(image) + elif choice == 'medium': + return medium_aug(image) + else: + return heavy_aug(image) +``` + +**When to use compound**: +- When augmentations conflict +- When you want balanced diversity +- When computational resources limited + + +### Augmentation Order Matters + +Some augmentations should be applied in specific order: + +**Optimal order**: +1. Geometric transforms first (rotation, shear, perspective) +2. Cropping (RandomResizedCrop) +3. Flipping (horizontal, vertical) +4. Color/intensity transforms (brightness, contrast, hue) +5. Final normalization + +```python +optimal_order = transforms.Compose([ + transforms.RandomRotation(15), + transforms.RandomAffine(degrees=0, shear=10), + transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), +]) +``` + +**Why**: Geometric first (operate on pixel coordinates), then color (invariant to coordinate changes). + + +### Probability-Based Augmentation Control + +**Weak augmentation** (apply to all data): + +```python +# Weak: always apply +weak = transforms.Compose([ + transforms.RandomRotation(degrees=10), + transforms.ColorJitter(brightness=0.1), + transforms.RandomHorizontalFlip(p=0.5), +]) + +# Apply to every training image +for epoch in range(epochs): + for images, labels in train_loader: + images = weak(images) + # ... train +``` + +**Strong augmentation with probability**: + +```python +class ProbabilisticAugmentation: + def __init__(self, transform, p: float): + self.transform = transform + self.p = p + + def __call__(self, x): + if np.random.rand() < self.p: + return self.transform(x) + return x + +# Use strong augmentation with 30% probability +strong = transforms.Compose([ + transforms.RandomRotation(degrees=45), + transforms.ColorJitter(brightness=0.4), +]) +probabilistic = ProbabilisticAugmentation(strong, p=0.3) + +# Each image: 70% unaugmented (training signal), 30% strongly augmented +``` + + +## Part 8: Augmentation for Specific Tasks + +### Augmentation for Object Detection + +**Challenge**: Must preserve bounding boxes after augmentation. + +**Strategy**: Use augmentations that preserve geometry or can remap boxes. + +```python +from albumentations import ( + HorizontalFlip, VerticalFlip, Rotate, ColorJitter, Resize, Compose +) + +# Albumentations handles box remapping automatically +detection_augmentation = Compose([ + HorizontalFlip(p=0.5), + Rotate(limit=15, p=0.5), + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.5), +], bbox_params=BboxParams(format='pascal_voc', label_fields=['labels'])) + +# Usage: +image, boxes, labels = detection_sample +augmented = detection_augmentation( + image=image, + bboxes=boxes, + labels=labels +) +``` + +**Safe augmentations**: +- ✅ Horizontal flip (adjust box x-coordinates) +- ✅ Crop (clip boxes to cropped region) +- ✅ Rotate ±15° (remaps box corners) +- ✅ Color jitter (no box changes) + +**Avoid**: +- ❌ Vertical flip (semantic meaning changes for many objects) +- ❌ Perspective distortion (complex box remapping) +- ❌ Large rotation (hard to remap boxes) + + +### Augmentation for Semantic Segmentation + +**Challenge**: Masks must be transformed identically to images. + +**Strategy**: Apply same transform to both image and mask. + +```python +from albumentations import ( + HorizontalFlip, RandomCrop, Rotate, ColorJitter, Compose +) + +segmentation_augmentation = Compose([ + HorizontalFlip(p=0.5), + Rotate(limit=15, p=0.5), + RandomCrop(height=256, width=256), + ColorJitter(brightness=0.2, contrast=0.2, p=0.5), +], keypoint_params=KeypointParams(format='xy')) + +# Usage: +image, mask = segmentation_sample +augmented = segmentation_augmentation(image=image, mask=mask) +image_aug, mask_aug = augmented['image'], augmented['mask'] +``` + +**Key requirement**: Image and mask transformed identically. + + +### Augmentation for Fine-Grained Classification + +**Challenges**: Small objects, subtle differences between classes. + +**Strategy**: Use conservative geometric transforms, aggressive color/texture. + +```python +# Fine-grained: preserve structure, vary appearance +fine_grained = transforms.Compose([ + transforms.RandomRotation(degrees=5), # Conservative rotation + transforms.RandomResizedCrop(224, scale=(0.9, 1.0)), # Minimal crop + transforms.ColorJitter(brightness=0.3, contrast=0.3), # Aggressive color + transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), +]) +``` + +**Avoid**: +- Large crops (lose discriminative details) +- Extreme rotations (change object orientation) +- Perspective distortion (distorts fine structures) + + +### Augmentation for Medical Imaging + +**Critical requirements**: Domain-specific, label-preserving, anatomically valid. + +```python +# Medical imaging augmentation (conservative) +medical_aug = transforms.Compose([ + transforms.RandomRotation(degrees=10), # Max ±10° + transforms.ColorJitter(brightness=0.1, contrast=0.1), + # Avoid: vertical flip (anatomical direction), excessive crop +]) + +# Never apply: +# - Vertical flip (anatomy has direction) +# - Random crops cutting off pathology +# - Extreme color transforms (diagnostic colors matter) +# - Perspective distortion (can distort anatomy) +``` + +**Domain-specific augmentations for medical**: +- ✅ Elastic deformation (models anatomical variation) +- ✅ Rotation ±10° (patient positioning variation) +- ✅ Small brightness/contrast (scanner variation) +- ✅ Gaussian blur (image quality variation) + + +### Augmentation for Time Series / Sequences + +**For 1D sequences** (signal processing, ECG, EEG): + +```python +def jitter(x: np.ndarray, std: float = 0.01): + """Add small random noise to sequence.""" + return x + np.random.normal(0, std, x.shape) + +def scaling(x: np.ndarray, scale: float = 0.1): + """Scale magnitude of sequence.""" + return x * np.random.uniform(1 - scale, 1 + scale) + +def rotation(x: np.ndarray): + """Rotate in 2D space (for multivariate sequences).""" + theta = np.random.uniform(-np.pi/4, np.pi/4) + rotation_matrix = np.array([ + [np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)] + ]) + return x @ rotation_matrix.T + +def magnitude_warping(x: np.ndarray, sigma: float = 0.2): + """Apply smooth scaling variations.""" + knots = np.linspace(0, len(x), 5) + values = np.random.normal(1, sigma, len(knots)) + from scipy.interpolate import interp1d + smooth_scale = interp1d(knots, values, kind='cubic')(np.arange(len(x))) + return x * smooth_scale[:, np.newaxis] + +def window_slicing(x: np.ndarray, window_ratio: float = 0.1): + """Reduce window size, then scale back to original length.""" + window_size = int(len(x) * window_ratio) + start = np.random.randint(0, len(x) - window_size) + x_sliced = x[start:start + window_size] + # Interpolate back to original length + from scipy.interpolate import interp1d + f = interp1d(np.arange(len(x_sliced)), x_sliced, axis=0, kind='linear', + fill_value='extrapolate') + return f(np.linspace(0, len(x_sliced)-1, len(x))) +``` + + +## Part 9: Augmentation Red Flags and Troubleshooting + +### Red Flags: When Augmentation Is Hurting + +1. **Validation accuracy DECREASES with augmentation** + - Likely: Too aggressive augmentation + - Solution: Reduce augmentation strength by 50%, retrain + +2. **Training loss doesn't decrease** + - Likely: Images too distorted to learn + - Solution: Visualize augmented images, check if recognizable + +3. **Test accuracy much worse than validation** + - Likely: Validation data accidentally augmented + - Solution: Check transform pipelines, ensure validation/test unaugmented + +4. **High variance in results across runs** + - Likely: Augmentation randomness not seeded + - Solution: Set random seeds for reproducibility + +5. **Specific class performance drops with augmentation** + - Likely: Augmentation inappropriate for that class + - Solution: Design class-specific augmentation (or disable for that class) + +6. **Memory usage doubled** + - Likely: Applying augmentation twice (in data loader and training) + - Solution: Remove duplicate augmentation pipeline + +7. **Model never converges to baseline** + - Likely: Augmentation too strong, label semantics lost + - Solution: Use weak augmentation first, increase gradually + +8. **Overfitting still severe despite augmentation** + - Likely: Augmentation too weak or wrong type + - Solution: Increase strength, try different augmentations, use regularization too + + +### Troubleshooting Checklist + +Before concluding augmentation doesn't help: + +- [ ] Validation transform pipeline has NO augmentations +- [ ] Training transform pipeline has only desired augmentations +- [ ] Random seed set for reproducibility +- [ ] Augmented images are visually recognizable (not noise) +- [ ] Augmentation applied consistently across epochs +- [ ] Baseline training tested (no augmentation) for comparison +- [ ] Accuracy impact measured on same hardware/compute +- [ ] Computational cost justified by accuracy improvement + + +## Part 10: Rationalization Table (What Users Say vs Reality) + +| User Statement | Reality | Evidence | Fix | +|----------------|---------|----------|-----| +| "Augmentation is overhead, skip it" | Augmentation prevents overfitting on small data | +5-10% accuracy on <5K examples | Enable augmentation, measure impact | +| "Use augmentation on validation too" | Validation measures true performance on original data | Metrics misleading if augmented | Remove augmentation from val transforms | +| "More augmentation always better" | Extreme augmentation creates label noise | Accuracy drops with too-aggressive transforms | Start conservative, increase gradually | +| "Same augmentation for all domains" | Each domain has different invariances | Text upside-down ≠ same class | Use domain-specific augmentations | +| "Augmentation takes too long" | ~10-20% training overhead, usually worth it | Depends on accuracy gain vs compute cost | Profile: measure accuracy/time tradeoff | +| "Flip works for everything" | Vertical flip changes anatomy/semantics | Medical imaging, some objects not symmetric | Know when flip is appropriate | +| "Random augmentation same as fixed" | Randomness prevents memorization, fixed is repetitive | Stochastic variation teaches invariance | Use random, not fixed transforms | +| "My data is too unique for standard augmentations" | Even unique data benefits from domain-appropriate augmentation | Overfitting still happens with small unique datasets | Adapt augmentations to your domain | +| "Augmentation is regularization" | Augmentation and regularization different; both help together | Dropout+BatchNorm+Augmentation > any single one | Use augmentation AND regularization | +| "TTA means augment validation" | TTA is optional post-training, not validation practice | TTA averaged over multiple forward passes | Use TTA only at final inference | + + +## Summary: Quick Reference + +| Domain | Light Augmentations | Medium Augmentations | Strong Augmentations | +|--------|-------------------|----------------------|----------------------| +| Vision | ±10° rotation, ±10% brightness, 0.5 H-flip | ±20° rotation, ±20% brightness, CutMix | ±45° rotation, ±30% jitter, strong perspective | +| NLP | Synonym replacement (1 word) | Back-translation, EDA | Multiple paraphrases, sentence reordering | +| Audio | Pitch ±2 semitones, noise SNR 30dB | Pitch ±4, noise SNR 20dB | Pitch ±8, noise SNR 10dB | +| Tabular | Feature noise 1%, SMOTE | Feature noise 5%, feature dropout | Feature noise 10%, heavy SMOTE | + + +## Critical Rules + +1. **Augment training data ONLY**. Validation and test data must be unaugmented. +2. **Start conservative, increase gradually**. Measure impact at each step. +3. **Domain matters**. No universal augmentation strategy exists. +4. **Preserve labels**. Do not apply transformations that change the class. +5. **Test incrementally**. Add one augmentation at a time, measure impact. +6. **Reproducibility**. Set random seeds for ablation studies. +7. **Avoid extremes**. If images/text unrecognizable, augmentation too strong. +8. **Know your domain**. Understand what invariances matter for your task. +9. **Measure impact**. Profile training time and accuracy improvement. +10. **Combine with regularization**. Augmentation works best with dropout, batch norm, weight decay. + diff --git a/skills/using-training-optimization/experiment-tracking.md b/skills/using-training-optimization/experiment-tracking.md new file mode 100644 index 0000000..244f7ac --- /dev/null +++ b/skills/using-training-optimization/experiment-tracking.md @@ -0,0 +1,1942 @@ + +# Experiment Tracking Skill + +## When to Use This Skill + +Use this skill when: +- User starts training a model and asks "should I track this experiment?" +- User wants to reproduce a previous result but doesn't remember settings +- Training runs overnight and user needs persistent logs +- User asks "which tool should I use: TensorBoard, W&B, or MLflow?" +- Multiple experiments running and user can't compare results +- User wants to share results with teammates or collaborators +- Model checkpoints accumulating with no organization or versioning +- User asks "what should I track?" or "how do I make experiments reproducible?" +- Debugging training issues and needs historical data (metrics, gradients) +- User wants to visualize training curves or compare hyperparameters +- Working on a research project that requires tracking many experiments +- User lost their best result and can't reproduce it + +Do NOT use when: +- User is doing quick prototyping with throwaway code (<5 minutes) +- Only running inference on pre-trained models (no training) +- Single experiment that's already tracked and working +- User is asking about hyperparameter tuning strategy (not tracking) +- Discussing model architecture design (not experiment management) + + +## Core Principles + +### 1. Track Before You Need It (Can't Add Retroactively) + +The BIGGEST mistake: waiting to track until results are worth saving. + +**The Reality**: +- The best result is ALWAYS the one you didn't track +- Can't add tracking after the experiment completes +- Human memory fails within hours (let alone days/weeks) +- Print statements disappear when terminal closes +- Code changes between experiments (git state matters) + +**When Tracking Matters**: +``` +Experiment value curve: + ^ + | ╱─ Peak result (untracked = lost forever) + | ╱ + | ╱ + | ╱ + | ╱ + | ╱ + | ╱ + | ╱ + |____╱________________________________> + Start Time + +If you wait to track "important" experiments, you've already lost them. +``` + +**Track From Day 1**: +- First experiment (even if "just testing") +- Every hyperparameter change +- Every model architecture variation +- Every data preprocessing change + +**Decision Rule**: If you're running `python train.py`, you should be tracking. No exceptions. + + +### 2. Complete Tracking = Hyperparameters + Metrics + Artifacts + Environment + +Reproducibility requires tracking EVERYTHING that affects the result. + +**The Five Categories**: + +``` +┌─────────────────────────────────────────────────────────┐ +│ 1. HYPERPARAMETERS (what you're tuning) │ +├─────────────────────────────────────────────────────────┤ +│ • Learning rate, batch size, optimizer type │ +│ • Model architecture (width, depth, activation) │ +│ • Regularization (weight decay, dropout) │ +│ • Training length (epochs, steps) │ +│ • Data augmentation settings │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ 2. METRICS (how you're doing) │ +├─────────────────────────────────────────────────────────┤ +│ • Training loss (every step or epoch) │ +│ • Validation loss (every epoch) │ +│ • Evaluation metrics (accuracy, F1, mAP, etc.) │ +│ • Learning rate schedule (actual LR each step) │ +│ • Gradient norms (for debugging) │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ 3. ARTIFACTS (what you're saving) │ +├─────────────────────────────────────────────────────────┤ +│ • Model checkpoints (with epoch/step metadata) │ +│ • Training plots (loss curves, confusion matrices) │ +│ • Predictions on validation set │ +│ • Logs (stdout, stderr) │ +│ • Config files (for reproducibility) │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ 4. CODE VERSION (what you're running) │ +├─────────────────────────────────────────────────────────┤ +│ • Git commit hash │ +│ • Git branch name │ +│ • Dirty status (uncommitted changes) │ +│ • Code diff (if uncommitted) │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ 5. ENVIRONMENT (where you're running) │ +├─────────────────────────────────────────────────────────┤ +│ • Python version, PyTorch version │ +│ • CUDA version, GPU type │ +│ • Random seeds (Python, NumPy, PyTorch, CUDA) │ +│ • Data version (if dataset changes) │ +│ • Hardware (CPU, RAM, GPU count) │ +└─────────────────────────────────────────────────────────┘ +``` + +**Reproducibility Test**: +> Can someone else (or future you) reproduce the result with ONLY the tracked information? + +If NO, you're not tracking enough. + + +### 3. Tool Selection: Local vs Team vs Production + +Different tools for different use cases. Choose based on your needs. + +**Tool Comparison**: + +| Feature | TensorBoard | Weights & Biases | MLflow | Custom | +|---------|-------------|------------------|--------|--------| +| **Setup Complexity** | Low | Low | Medium | High | +| **Local Only** | Yes | No (cloud) | Yes | Yes | +| **Team Collaboration** | Limited | Excellent | Good | Custom | +| **Cost** | Free | Free tier + paid | Free | Free | +| **Scalability** | Medium | High | High | Low | +| **Visualization** | Good | Excellent | Good | Custom | +| **Integration** | PyTorch, TF | Everything | Everything | Manual | +| **Best For** | Solo projects | Team research | Production | Specific needs | + +**Decision Tree**: +``` +Do you need team collaboration? +├─ YES → Need to share results with teammates? +│ ├─ YES → Weights & Biases (best team features) +│ └─ NO → MLflow (self-hosted, more control) +│ +└─ NO → Solo project? + ├─ YES → TensorBoard (simplest, local) + └─ NO → MLflow (scales to production) + +Budget constraints? +├─ FREE only → TensorBoard or MLflow +└─ Can pay → W&B (worth it for teams) + +Production deployment? +├─ YES → MLflow (production-ready) +└─ NO → TensorBoard or W&B (research) +``` + +**Recommendation**: +- **Starting out / learning**: TensorBoard (easiest, free, local) +- **Research team / collaboration**: Weights & Biases (best UX, sharing) +- **Production ML / enterprise**: MLflow (self-hosted, model registry) +- **Specific needs / customization**: Custom logging (CSV + Git) + + +### 4. Minimal Overhead, Maximum Value + +Tracking should cost 1-5% overhead, not 50%. + +**What to Track at Different Frequencies**: + +```python +# Every step (high frequency, small data): +log_every_step = { + "train_loss": loss.item(), + "learning_rate": optimizer.param_groups[0]['lr'], + "step": global_step, +} + +# Every epoch (medium frequency, medium data): +log_every_epoch = { + "train_loss_avg": train_losses.mean(), + "val_loss": val_loss, + "val_accuracy": val_acc, + "epoch": epoch, +} + +# Once per experiment (low frequency, large data): +log_once = { + "hyperparameters": config, + "git_commit": get_git_hash(), + "environment": { + "python_version": sys.version, + "torch_version": torch.__version__, + "cuda_version": torch.version.cuda, + }, +} + +# Only on improvement (conditional): +if val_loss < best_val_loss: + save_checkpoint(model, optimizer, epoch, val_loss) + log_artifact("best_model.pt") +``` + +**Overhead Guidelines**: +- Logging scalars (loss, accuracy): <0.1% overhead (always do) +- Logging images/plots: 1-2% overhead (do every epoch) +- Logging checkpoints: 5-10% overhead (do only on improvement) +- Logging gradients: 10-20% overhead (do only for debugging) + +**Don't Track**: +- Raw training data (too large, use data versioning instead) +- Every intermediate activation (use profiling tools instead) +- Full model weights every step (only on improvement) + + +### 5. Experiment Organization: Naming, Tagging, Grouping + +With 100+ experiments, organization is survival. + +**Naming Convention**: +```python +# GOOD: Descriptive, sortable, parseable +experiment_name = f"{model}_{dataset}_{timestamp}_{hyperparams}" +# Examples: +# "resnet18_cifar10_20241030_lr0.01_bs128" +# "bert_squad_20241030_lr3e-5_warmup1000" +# "gpt2_wikitext_20241030_ctx512_layers12" + +# BAD: Uninformative +experiment_name = "test" +experiment_name = "final" +experiment_name = "model_v2" +experiment_name = "test_again_actually_final" +``` + +**Tagging Strategy**: +```python +# Tags for filtering and grouping +tags = { + "model": "resnet18", + "dataset": "cifar10", + "experiment_type": "hyperparameter_search", + "status": "completed", + "goal": "beat_baseline", + "author": "john", +} + +# Can filter later: +# - Show me all "hyperparameter_search" experiments +# - Show me all "resnet18" on "cifar10" +# - Show me experiments by "john" +``` + +**Grouping Related Experiments**: +```python +# Group by goal/project +project = "cifar10_sota" +group = "learning_rate_search" +experiment_name = f"{project}/{group}/lr_{lr}" + +# Hierarchy: +# cifar10_sota/ +# ├─ learning_rate_search/ +# │ ├─ lr_0.001 +# │ ├─ lr_0.01 +# │ └─ lr_0.1 +# ├─ architecture_search/ +# │ ├─ resnet18 +# │ ├─ resnet34 +# │ └─ resnet50 +# └─ regularization_search/ +# ├─ dropout_0.1 +# ├─ dropout_0.3 +# └─ dropout_0.5 +``` + + +## Tool-Specific Integration + +### TensorBoard (Local, Simple) + +**Setup**: +```python +from torch.utils.tensorboard import SummaryWriter + +# Create writer +writer = SummaryWriter(f"runs/{experiment_name}") + +# Log hyperparameters +hparams = { + "learning_rate": 0.01, + "batch_size": 128, + "optimizer": "adam", +} +metrics = { + "best_val_acc": 0.0, +} +writer.add_hparams(hparams, metrics) +``` + +**During Training**: +```python +for epoch in range(num_epochs): + for batch_idx, (data, target) in enumerate(train_loader): + # Training step + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Log every N steps + global_step = epoch * len(train_loader) + batch_idx + if global_step % log_interval == 0: + writer.add_scalar("train/loss", loss.item(), global_step) + writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step) + + # Validation + val_loss, val_acc = evaluate(model, val_loader) + writer.add_scalar("val/loss", val_loss, epoch) + writer.add_scalar("val/accuracy", val_acc, epoch) + + # Log images (confusion matrix, etc.) + if epoch % 10 == 0: + fig = plot_confusion_matrix(model, val_loader) + writer.add_figure("val/confusion_matrix", fig, epoch) + +writer.close() +``` + +**View Results**: +```bash +tensorboard --logdir=runs +# Opens web UI at http://localhost:6006 +``` + +**Pros**: +- Simple setup (2 lines of code) +- Local (no cloud dependency) +- Good visualizations (scalars, images, graphs) +- Integrated with PyTorch + +**Cons**: +- No hyperparameter comparison table +- Limited team collaboration +- No artifact storage (checkpoints) +- Manual experiment management + + +### Weights & Biases (Team, Cloud) + +**Setup**: +```python +import wandb + +# Initialize experiment +wandb.init( + project="cifar10-sota", + name=experiment_name, + config={ + "learning_rate": 0.01, + "batch_size": 128, + "optimizer": "adam", + "model": "resnet18", + "dataset": "cifar10", + }, + tags=["hyperparameter_search", "resnet"], +) + +# Config is automatically tracked +config = wandb.config +``` + +**During Training**: +```python +for epoch in range(num_epochs): + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Log metrics + wandb.log({ + "train/loss": loss.item(), + "train/lr": optimizer.param_groups[0]['lr'], + "epoch": epoch, + }) + + # Validation + val_loss, val_acc = evaluate(model, val_loader) + wandb.log({ + "val/loss": val_loss, + "val/accuracy": val_acc, + "epoch": epoch, + }) + + # Save checkpoint + if val_acc > best_val_acc: + best_val_acc = val_acc + torch.save(model.state_dict(), "best_model.pt") + wandb.save("best_model.pt") # Upload to cloud + +# Log final results +wandb.log({"best_val_accuracy": best_val_acc}) +wandb.finish() +``` + +**Advanced Features**: +```python +# Log images +wandb.log({"examples": [wandb.Image(img, caption=f"Pred: {pred}") for img, pred in samples]}) + +# Log plots +fig = plot_confusion_matrix(model, val_loader) +wandb.log({"confusion_matrix": wandb.Image(fig)}) + +# Log tables (for result analysis) +table = wandb.Table(columns=["epoch", "train_loss", "val_loss", "val_acc"]) +for epoch, tl, vl, va in zip(epochs, train_losses, val_losses, val_accs): + table.add_data(epoch, tl, vl, va) +wandb.log({"results": table}) + +# Log model architecture +wandb.watch(model, log="all", log_freq=100) # Logs gradients + weights +``` + +**View Results**: +- Web UI: https://wandb.ai/your-username/cifar10-sota +- Compare experiments side-by-side +- Share links with teammates +- Filter by tags, hyperparameters + +**Pros**: +- Excellent team collaboration (share links) +- Beautiful visualizations +- Hyperparameter comparison (parallel coordinates) +- Artifact versioning (models, data) +- Integration with everything (PyTorch, TF, JAX) + +**Cons**: +- Cloud-based (requires internet) +- Free tier limits (100GB storage) +- Data leaves your machine (privacy concern) + + +### MLflow (Production, Self-Hosted) + +**Setup**: +```python +import mlflow +import mlflow.pytorch + +# Start experiment +mlflow.set_experiment("cifar10-sota") + +# Start run +with mlflow.start_run(run_name=experiment_name): + # Log hyperparameters + mlflow.log_param("learning_rate", 0.01) + mlflow.log_param("batch_size", 128) + mlflow.log_param("optimizer", "adam") + mlflow.log_param("model", "resnet18") + + # Training loop + for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer) + val_loss, val_acc = evaluate(model, val_loader) + + # Log metrics + mlflow.log_metric("train_loss", train_loss, step=epoch) + mlflow.log_metric("val_loss", val_loss, step=epoch) + mlflow.log_metric("val_accuracy", val_acc, step=epoch) + + # Log final metrics + mlflow.log_metric("best_val_accuracy", best_val_acc) + + # Log model + mlflow.pytorch.log_model(model, "model") + + # Log artifacts + mlflow.log_artifact("config.yaml") + mlflow.log_artifact("best_model.pt") +``` + +**View Results**: +```bash +mlflow ui +# Opens web UI at http://localhost:5000 +``` + +**Model Registry** (for production): +```python +# Register model +model_uri = f"runs:/{run_id}/model" +mlflow.register_model(model_uri, "cifar10-resnet18") + +# Load registered model +model = mlflow.pytorch.load_model("models:/cifar10-resnet18/production") +``` + +**Pros**: +- Self-hosted (full control, privacy) +- Model registry (production deployment) +- Scales to large teams +- Integration with deployment tools + +**Cons**: +- More complex setup (need server) +- Visualization not as good as W&B +- Less intuitive UI + + +## Reproducibility Patterns + +### 1. Seed Everything + +```python +import random +import numpy as np +import torch + +def set_seed(seed): + """Set all random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Deterministic operations (slower but reproducible) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +# At start of training +set_seed(42) + +# Log seed +config = {"seed": 42} +``` + +**Warning**: Deterministic mode can be 10-20% slower. Trade-off between speed and reproducibility. + + +### 2. Capture Git State + +```python +import subprocess + +def get_git_info(): + """Capture current git state.""" + try: + commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() + + # Check for uncommitted changes + status = subprocess.check_output(['git', 'status', '--porcelain']).decode('ascii').strip() + is_dirty = len(status) > 0 + + # Get diff if dirty + diff = None + if is_dirty: + diff = subprocess.check_output(['git', 'diff']).decode('ascii') + + return { + "commit": commit, + "branch": branch, + "is_dirty": is_dirty, + "diff": diff, + } + except Exception as e: + return {"error": str(e)} + +# Log git info +git_info = get_git_info() +if git_info.get("is_dirty"): + print("WARNING: Uncommitted changes detected!") + print("Experiment may not be reproducible without the diff.") +``` + + +### 3. Environment Capture + +```python +import sys +import torch + +def get_environment_info(): + """Capture environment details.""" + return { + "python_version": sys.version, + "torch_version": torch.__version__, + "cuda_version": torch.version.cuda, + "cudnn_version": torch.backends.cudnn.version(), + "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, + "gpu_count": torch.cuda.device_count(), + } + +# Save requirements.txt +# pip freeze > requirements.txt + +# Or use pip-tools +# pip-compile requirements.in +``` + + +### 4. Config Files for Reproducibility + +```python +# config.yaml +model: + name: resnet18 + num_classes: 10 + +training: + learning_rate: 0.01 + batch_size: 128 + num_epochs: 100 + optimizer: adam + weight_decay: 0.0001 + +data: + dataset: cifar10 + augmentation: true + normalize: true + +# Load config +import yaml +with open("config.yaml") as f: + config = yaml.safe_load(f) + +# Save config alongside results +import shutil +shutil.copy("config.yaml", f"results/{experiment_name}/config.yaml") +``` + + +## Experiment Comparison + +### 1. Comparing Metrics + +```python +# TensorBoard: Compare multiple runs +# tensorboard --logdir=runs --port=6006 +# Select multiple runs in UI + +# W&B: Filter and compare +# Go to project page, select runs, click "Compare" + +# MLflow: Query experiments +import mlflow + +# Get all runs from an experiment +experiment = mlflow.get_experiment_by_name("cifar10-sota") +runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id]) + +# Filter by metric +best_runs = runs[runs["metrics.val_accuracy"] > 0.85] + +# Sort by metric +best_runs = runs.sort_values("metrics.val_accuracy", ascending=False) + +# Analyze hyperparameter impact +import pandas as pd +import seaborn as sns + +# Plot learning rate vs accuracy +sns.scatterplot(data=runs, x="params.learning_rate", y="metrics.val_accuracy") +``` + + +### 2. Hyperparameter Analysis + +```python +# W&B: Parallel coordinates plot +# Shows which hyperparameter combinations lead to best results +# UI: Click "Parallel Coordinates" in project view + +# MLflow: Custom analysis +import matplotlib.pyplot as plt + +# Group by hyperparameter +for lr in [0.001, 0.01, 0.1]: + lr_runs = runs[runs["params.learning_rate"] == str(lr)] + accuracies = lr_runs["metrics.val_accuracy"] + plt.scatter([lr] * len(accuracies), accuracies, alpha=0.5, label=f"LR={lr}") + +plt.xlabel("Learning Rate") +plt.ylabel("Validation Accuracy") +plt.xscale("log") +plt.legend() +plt.title("Learning Rate vs Accuracy") +plt.show() +``` + + +### 3. Comparing Artifacts + +```python +# Compare model checkpoints +from torchvision.models import resnet18 + +# Load two models +model_a = resnet18() +model_a.load_state_dict(torch.load("experiments/exp_a/best_model.pt")) + +model_b = resnet18() +model_b.load_state_dict(torch.load("experiments/exp_b/best_model.pt")) + +# Compare on validation set +acc_a = evaluate(model_a, val_loader) +acc_b = evaluate(model_b, val_loader) + +print(f"Model A: {acc_a:.2%}") +print(f"Model B: {acc_b:.2%}") + +# Compare predictions +preds_a = model_a(val_data) +preds_b = model_b(val_data) +agreement = (preds_a.argmax(1) == preds_b.argmax(1)).float().mean() +print(f"Prediction agreement: {agreement:.2%}") +``` + + +## Collaboration Workflows + +### 1. Sharing Results (W&B) + +```python +# Share experiment link +# https://wandb.ai/your-username/cifar10-sota/runs/run-id + +# Create report +# W&B UI: Click "Create Report" → Add charts, text, code + +# Export results +# W&B UI: Click "Export" → CSV, JSON, or API + +# API access for programmatic sharing +import wandb +api = wandb.Api() +runs = api.runs("your-username/cifar10-sota") + +for run in runs: + print(f"{run.name}: {run.summary['val_accuracy']}") +``` + + +### 2. Team Experiment Dashboard + +```python +# MLflow: Shared tracking server +# Server machine: +mlflow server --host 0.0.0.0 --port 5000 + +# Team members: +import mlflow +mlflow.set_tracking_uri("http://shared-server:5000") + +# Everyone logs to same server +with mlflow.start_run(): + mlflow.log_metric("val_accuracy", 0.87) +``` + + +### 3. Experiment Handoff + +```python +# Package experiment for reproducibility +experiment_package = { + "code": "git_commit_hash", + "config": "config.yaml", + "model": "best_model.pt", + "results": "results.csv", + "logs": "training.log", + "environment": "requirements.txt", +} + +# Create reproducibility script +# reproduce.sh +""" +#!/bin/bash +git checkout +pip install -r requirements.txt +python train.py --config config.yaml +""" +``` + + +## Complete Tracking Example + +Here's a production-ready tracking setup: + +```python +import torch +import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter +import wandb +import yaml +import subprocess +from pathlib import Path +from datetime import datetime + +class ExperimentTracker: + """Complete experiment tracking wrapper.""" + + def __init__(self, config, experiment_name=None, use_wandb=True, use_tensorboard=True): + self.config = config + self.use_wandb = use_wandb + self.use_tensorboard = use_tensorboard + + # Generate experiment name + if experiment_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + experiment_name = f"{config['model']}_{config['dataset']}_{timestamp}" + self.experiment_name = experiment_name + + # Create experiment directory + self.exp_dir = Path(f"experiments/{experiment_name}") + self.exp_dir.mkdir(parents=True, exist_ok=True) + + # Initialize tracking tools + if self.use_tensorboard: + self.tb_writer = SummaryWriter(self.exp_dir / "tensorboard") + + if self.use_wandb: + wandb.init( + project=config.get("project", "default"), + name=experiment_name, + config=config, + dir=self.exp_dir, + ) + + # Save config + with open(self.exp_dir / "config.yaml", "w") as f: + yaml.dump(config, f) + + # Capture environment + self._log_environment() + + # Capture git state + self._log_git_state() + + # Setup logging + self._setup_logging() + + self.global_step = 0 + self.best_metric = float('-inf') + + def _log_environment(self): + """Log environment information.""" + import sys + env_info = { + "python_version": sys.version, + "torch_version": torch.__version__, + "cuda_version": torch.version.cuda if torch.cuda.is_available() else None, + "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, + "gpu_count": torch.cuda.device_count(), + } + + # Save to file + with open(self.exp_dir / "environment.yaml", "w") as f: + yaml.dump(env_info, f) + + # Log to W&B + if self.use_wandb: + wandb.config.update({"environment": env_info}) + + def _log_git_state(self): + """Log git commit and status.""" + try: + commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() + status = subprocess.check_output(['git', 'status', '--porcelain']).decode('ascii').strip() + is_dirty = len(status) > 0 + + git_info = { + "commit": commit, + "branch": branch, + "is_dirty": is_dirty, + } + + # Save to file + with open(self.exp_dir / "git_info.yaml", "w") as f: + yaml.dump(git_info, f) + + # Save diff if dirty + if is_dirty: + diff = subprocess.check_output(['git', 'diff']).decode('ascii') + with open(self.exp_dir / "git_diff.patch", "w") as f: + f.write(diff) + print("WARNING: Uncommitted changes detected! Saved to git_diff.patch") + + # Log to W&B + if self.use_wandb: + wandb.config.update({"git": git_info}) + + except Exception as e: + print(f"Failed to capture git state: {e}") + + def _setup_logging(self): + """Setup file logging.""" + import logging + self.logger = logging.getLogger(self.experiment_name) + self.logger.setLevel(logging.INFO) + + # File handler + fh = logging.FileHandler(self.exp_dir / "training.log") + fh.setLevel(logging.INFO) + + # Console handler + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + + # Formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + fh.setFormatter(formatter) + ch.setFormatter(formatter) + + self.logger.addHandler(fh) + self.logger.addHandler(ch) + + def log_metrics(self, metrics, step=None): + """Log metrics to all tracking backends.""" + if step is None: + step = self.global_step + + # TensorBoard + if self.use_tensorboard: + for key, value in metrics.items(): + if isinstance(value, (int, float)): + self.tb_writer.add_scalar(key, value, step) + + # W&B + if self.use_wandb: + wandb.log(metrics, step=step) + + # File + self.logger.info(f"Step {step}: {metrics}") + + self.global_step = step + 1 + + def save_checkpoint(self, model, optimizer, epoch, metric_value, metric_name="val_accuracy"): + """Save model checkpoint with metadata.""" + checkpoint = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + metric_name: metric_value, + "config": self.config, + } + + # Save latest checkpoint + checkpoint_path = self.exp_dir / "checkpoints" / f"checkpoint_epoch_{epoch}.pt" + checkpoint_path.parent.mkdir(exist_ok=True) + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if metric_value > self.best_metric: + self.best_metric = metric_value + best_path = self.exp_dir / "checkpoints" / "best_model.pt" + torch.save(checkpoint, best_path) + + self.logger.info(f"New best model saved: {metric_name}={metric_value:.4f}") + + # Log to W&B + if self.use_wandb: + wandb.log({f"best_{metric_name}": metric_value}) + wandb.save(str(best_path)) + + return checkpoint_path + + def log_figure(self, name, figure, step=None): + """Log matplotlib figure.""" + if step is None: + step = self.global_step + + # TensorBoard + if self.use_tensorboard: + self.tb_writer.add_figure(name, figure, step) + + # W&B + if self.use_wandb: + wandb.log({name: wandb.Image(figure)}, step=step) + + # Save to disk + fig_path = self.exp_dir / "figures" / f"{name}_step_{step}.png" + fig_path.parent.mkdir(exist_ok=True) + figure.savefig(fig_path) + + def finish(self): + """Clean up and close tracking backends.""" + if self.use_tensorboard: + self.tb_writer.close() + + if self.use_wandb: + wandb.finish() + + self.logger.info("Experiment tracking finished.") + + +# Usage example +if __name__ == "__main__": + config = { + "project": "cifar10-sota", + "model": "resnet18", + "dataset": "cifar10", + "learning_rate": 0.01, + "batch_size": 128, + "num_epochs": 100, + "optimizer": "adam", + "seed": 42, + } + + # Initialize tracker + tracker = ExperimentTracker(config) + + # Training loop + model = create_model(config) + optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"]) + + for epoch in range(config["num_epochs"]): + train_loss = train_epoch(model, train_loader, optimizer) + val_loss, val_acc = evaluate(model, val_loader) + + # Log metrics + tracker.log_metrics({ + "train/loss": train_loss, + "val/loss": val_loss, + "val/accuracy": val_acc, + "epoch": epoch, + }) + + # Save checkpoint + tracker.save_checkpoint(model, optimizer, epoch, val_acc) + + # Log figure (every 10 epochs) + if epoch % 10 == 0: + fig = plot_confusion_matrix(model, val_loader) + tracker.log_figure("confusion_matrix", fig) + + # Finish + tracker.finish() +``` + + +## Pitfalls and Anti-Patterns + +### Pitfall 1: Tracking Metrics But Not Config + +**Symptom**: Have CSV with 50 experiments' metrics, but no idea what hyperparameters produced them. + +**Why It Happens**: +- User focuses on "what matters" (the metric) +- Assumes they'll remember settings +- Doesn't realize metrics without context are useless + +**Fix**: +```python +# WRONG: Only metrics +with open("results.csv", "a") as f: + f.write(f"{epoch},{train_loss},{val_loss}\n") + +# RIGHT: Metrics + config +experiment_id = f"exp_{timestamp}" +with open(f"{experiment_id}_config.yaml", "w") as f: + yaml.dump(config, f) +with open(f"{experiment_id}_results.csv", "w") as f: + f.write(f"{epoch},{train_loss},{val_loss}\n") +``` + + +### Pitfall 2: Overwriting Checkpoints Without Versioning + +**Symptom**: Always saving to "best_model.pt", can't recover earlier checkpoints. + +**Why It Happens**: +- Disk space concerns (misguided) +- Only care about "best" model +- Don't anticipate evaluation bugs + +**Fix**: +```python +# WRONG: Overwriting +torch.save(model.state_dict(), "best_model.pt") + +# RIGHT: Versioned checkpoints +torch.save(model.state_dict(), f"checkpoints/model_epoch_{epoch}.pt") +torch.save(model.state_dict(), f"checkpoints/best_model_val_acc_{val_acc:.4f}.pt") +``` + + +### Pitfall 3: Using Print Instead of Logging + +**Symptom**: Training crashes, all print output lost, can't debug. + +**Why It Happens**: +- Print is simpler than logging +- Works for short scripts +- Doesn't anticipate crashes + +**Fix**: +```python +# WRONG: Print statements +print(f"Epoch {epoch}: loss={loss}") + +# RIGHT: Proper logging +import logging +logging.basicConfig(filename="training.log", level=logging.INFO) +logging.info(f"Epoch {epoch}: loss={loss}") +``` + + +### Pitfall 4: No Git Tracking for Code Changes + +**Symptom**: Can't reproduce result because code changed between experiments. + +**Why It Happens**: +- Rapid iteration (uncommitted changes) +- "I'll commit later" +- Don't realize code version matters + +**Fix**: +```python +# Log git commit at start of training +git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() +config["git_commit"] = git_hash + +# Better: Require clean git state +status = subprocess.check_output(['git', 'status', '--porcelain']).decode('ascii').strip() +if status: + print("ERROR: Uncommitted changes detected!") + print("Commit your changes before running experiments.") + sys.exit(1) +``` + + +### Pitfall 5: Not Tracking Random Seeds + +**Symptom**: Same code, same hyperparameters, different results every time. + +**Why It Happens**: +- Forget to set seed +- Set seed in one place but not others (PyTorch, NumPy, CUDA) +- Don't log seed value + +**Fix**: +```python +# Set all seeds +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +# Use seed from config +set_seed(config["seed"]) + +# Log seed +tracker.log_metrics({"seed": config["seed"]}) +``` + + +### Pitfall 6: Tracking Too Much Data (Storage Bloat) + +**Symptom**: 100GB of logs for 50 experiments, can't store more. + +**Why It Happens**: +- Logging every step (not just epoch) +- Saving all checkpoints (not just best) +- Logging high-resolution images + +**Fix**: +```python +# Log at appropriate frequency +if global_step % 100 == 0: # Every 100 steps, not every step + tracker.log_metrics({"train/loss": loss}) + +# Save only best checkpoints +if val_acc > best_val_acc: # Only when improving + tracker.save_checkpoint(model, optimizer, epoch, val_acc) + +# Downsample images +img_low_res = F.interpolate(img, size=(64, 64)) # Don't log 224x224 +``` + + +### Pitfall 7: No Experiment Naming Convention + +**Symptom**: experiments/test, experiments/test2, experiments/final, experiments/final_final + +**Why It Happens**: +- No planning for multiple experiments +- Naming feels unimportant +- "I'll organize later" + +**Fix**: +```python +# Good naming convention +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +experiment_name = f"{config['model']}_{config['dataset']}_{timestamp}_lr{config['lr']}" +# Example: "resnet18_cifar10_20241030_120000_lr0.01" +``` + + +### Pitfall 8: Not Tracking Evaluation Metrics + +**Symptom**: Saved best model by training loss, but validation loss was actually increasing (overfitting). + +**Why It Happens**: +- Only tracking training metrics +- Assuming training loss = model quality +- Not validating frequently enough + +**Fix**: +```python +# Track both training and validation +tracker.log_metrics({ + "train/loss": train_loss, + "val/loss": val_loss, # Don't forget validation! + "val/accuracy": val_acc, +}) + +# Save best model by validation metric, not training +if val_acc > best_val_acc: + tracker.save_checkpoint(model, optimizer, epoch, val_acc) +``` + + +### Pitfall 9: Local-Only Tracking for Team Projects + +**Symptom**: Team members can't see each other's experiments, duplicate work. + +**Why It Happens**: +- TensorBoard is local by default +- Don't realize collaboration tools exist +- Privacy concerns (unfounded) + +**Fix**: +```python +# Use team-friendly tool +wandb.init(project="team-project") # Everyone can see + +# Or: Share TensorBoard logs +# scp -r runs/ shared-server:/path/ +# tensorboard --logdir=/path/runs --host=0.0.0.0 +``` + + +### Pitfall 10: No Tracking Until "Important" Experiment + +**Symptom**: First 20 experiments untracked, realize they had valuable insights. + +**Why It Happens**: +- "Just testing" mentality +- Tracking feels like overhead +- Don't realize importance until later + +**Fix**: +```python +# Track from experiment 1 +# Even if "just testing", it takes 30 seconds to set up tracking +tracker = ExperimentTracker(config) + +# Future you will thank past you +``` + + +## Rationalization vs Reality Table + +| User Rationalization | Reality | Recommendation | +|----------------------|---------|----------------| +| "I'll remember what I tried" | You won't (memory fails in hours) | Track from day 1, always | +| "Print statements are enough" | Lost on crash or terminal close | Use proper logging to file | +| "Only track final metrics" | Can't debug without intermediate data | Track every epoch minimum | +| "Just save best model" | Need checkpoints for analysis | Version all important checkpoints | +| "Tracking adds too much overhead" | <1% overhead for scalars | Log metrics, not raw data | +| "I only need the model file" | Need hyperparameters to understand it | Save config + model + metrics | +| "TensorBoard is too complex" | 2 lines of code to set up | Start simple, expand later | +| "I'll organize experiments later" | Never happens, chaos ensues | Use naming convention from start | +| "Git commits slow me down" | Uncommitted code = irreproducible | Commit before experiments | +| "Cloud tracking costs money" | Free tiers are generous | W&B free: 100GB, unlimited experiments | +| "I don't need reproducibility" | Your future self will | Track environment + seed + git | +| "Tracking is for production, not research" | Research needs it more (exploration) | Research = more experiments = more tracking | + + +## Red Flags (Likely to Fail) + +1. **"I'll track it later"** + - Reality: Later = never; best results are always untracked + - Action: Track from experiment 1 + +2. **"Just using print statements"** + - Reality: Lost on crash/close; can't analyze later + - Action: Use logging framework or tracking tool + +3. **"Only tracking the final metric"** + - Reality: Can't debug convergence issues; no training curves + - Action: Track every epoch at minimum + +4. **"Saving to best_model.pt (overwriting)"** + - Reality: Can't recover earlier checkpoints; evaluation bugs = disaster + - Action: Version checkpoints with epoch/metric + +5. **"Don't need to track hyperparameters"** + - Reality: Metrics without config are meaningless + - Action: Log config alongside metrics + +6. **"Not tracking git commit"** + - Reality: Code changes = irreproducible + - Action: Log git hash, check for uncommitted changes + +7. **"Random seed doesn't matter"** + - Reality: Can cause 5%+ variance in results + - Action: Set and log all seeds + +8. **"TensorBoard/W&B is overkill for me"** + - Reality: Setup takes 2 minutes, saves hours later + - Action: Use simplest tool (TensorBoard), expand if needed + +9. **"I'm just testing, don't need tracking"** + - Reality: Best results come from "tests" + - Action: Track everything, including tests + +10. **"Team doesn't need to see my experiments"** + - Reality: Collaboration requires transparency + - Action: Use shared tracking (W&B, MLflow server) + + +## When This Skill Applies + +**Strong Signals** (definitely use): +- Starting a new ML project (even "quick prototype") +- User asks "should I track this?" +- User lost their best result and can't reproduce +- Multiple experiments running (need comparison) +- Team collaboration (need to share results) +- User asks about TensorBoard, W&B, or MLflow +- Training crashes and user needs debugging data + +**Weak Signals** (maybe use): +- User has tracking but it's incomplete +- Asking about reproducibility +- Discussing hyperparameter tuning (needs tracking) +- Long-running training (overnight, multi-day) + +**Not Applicable**: +- Pure inference (no training) +- Single experiment already tracked +- Discussing model architecture only +- Data preprocessing questions (pre-training) + + +## Success Criteria + +You've successfully applied this skill when: + +1. **Complete Tracking**: Hyperparameters + metrics + artifacts + git + environment all logged +2. **Reproducibility**: Someone else (or future you) can reproduce the result from tracked info +3. **Tool Choice**: Selected appropriate tool (TensorBoard, W&B, MLflow) for use case +4. **Organization**: Experiments have clear naming, tagging, grouping +5. **Comparison**: Can compare experiments side-by-side, analyze hyperparameter impact +6. **Collaboration**: Team can see and discuss results (if team project) +7. **Minimal Overhead**: Tracking adds <5% runtime overhead +8. **Persistence**: Logs survive crashes, terminal closes, reboots +9. **Historical Analysis**: Can go back to any experiment and understand what was done +10. **Best Practices**: Git commits before experiments, seeds set, evaluation bugs impossible + +**Final Test**: Can you reproduce the best result from 6 months ago using only the tracked information? + +If YES: Excellent tracking! If NO: Gaps remain. + + +## Advanced Tracking Patterns + +### 1. Multi-Run Experiments (Hyperparameter Sweeps) + +When running many experiments systematically: + +```python +# W&B Sweeps +sweep_config = { + "method": "random", + "metric": {"name": "val_accuracy", "goal": "maximize"}, + "parameters": { + "learning_rate": {"values": [0.001, 0.01, 0.1]}, + "batch_size": {"values": [32, 64, 128]}, + "optimizer": {"values": ["adam", "sgd"]}, + }, +} + +sweep_id = wandb.sweep(sweep_config, project="cifar10-sweep") + +def train(): + run = wandb.init() + config = wandb.config + + model = create_model(config) + # ... training code ... + wandb.log({"val_accuracy": val_acc}) + +wandb.agent(sweep_id, train, count=10) + +# MLflow with Optuna +import optuna +import mlflow + +def objective(trial): + with mlflow.start_run(nested=True): + lr = trial.suggest_loguniform("learning_rate", 1e-5, 1e-1) + batch_size = trial.suggest_categorical("batch_size", [32, 64, 128]) + + mlflow.log_params({"learning_rate": lr, "batch_size": batch_size}) + + val_acc = train_and_evaluate(lr, batch_size) + mlflow.log_metric("val_accuracy", val_acc) + + return val_acc + +with mlflow.start_run(): + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=20) + + mlflow.log_param("best_params", study.best_params) + mlflow.log_metric("best_accuracy", study.best_value) +``` + + +### 2. Distributed Training Tracking + +When training on multiple GPUs or machines: + +```python +import torch.distributed as dist + +def setup_distributed_tracking(rank, world_size): + """Setup tracking for distributed training.""" + + # Only rank 0 logs to avoid duplicates + if rank == 0: + tracker = ExperimentTracker(config) + else: + tracker = None + + return tracker + +def train_distributed(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + tracker = setup_distributed_tracking(rank, world_size) + + model = DistributedDataParallel(model, device_ids=[rank]) + + for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer) + + # Gather metrics from all ranks + train_loss_tensor = torch.tensor(train_loss).cuda() + dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM) + avg_train_loss = train_loss_tensor.item() / world_size + + # Only rank 0 logs + if rank == 0 and tracker: + tracker.log_metrics({ + "train/loss": avg_train_loss, + "epoch": epoch, + }) + + if rank == 0 and tracker: + tracker.finish() + + dist.destroy_process_group() +``` + + +### 3. Experiment Resumption + +Tracking setup for resumable experiments: + +```python +class ResumableExperimentTracker(ExperimentTracker): + """Experiment tracker with resume support.""" + + def __init__(self, config, checkpoint_path=None): + super().__init__(config) + + self.checkpoint_path = checkpoint_path + + if checkpoint_path and os.path.exists(checkpoint_path): + self.resume_from_checkpoint() + + def resume_from_checkpoint(self): + """Resume tracking from saved checkpoint.""" + checkpoint = torch.load(self.checkpoint_path) + + self.global_step = checkpoint.get("global_step", 0) + self.best_metric = checkpoint.get("best_metric", float('-inf')) + + self.logger.info(f"Resumed from checkpoint: step={self.global_step}") + + def save_checkpoint(self, model, optimizer, epoch, metric_value, metric_name="val_accuracy"): + """Save checkpoint with tracker state.""" + checkpoint = { + "epoch": epoch, + "global_step": self.global_step, + "best_metric": self.best_metric, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + metric_name: metric_value, + "config": self.config, + } + + checkpoint_path = self.exp_dir / "checkpoints" / "latest.pt" + checkpoint_path.parent.mkdir(exist_ok=True) + torch.save(checkpoint, checkpoint_path) + + # Also save best + if metric_value > self.best_metric: + self.best_metric = metric_value + best_path = self.exp_dir / "checkpoints" / "best.pt" + torch.save(checkpoint, best_path) + + return checkpoint_path + +# Usage +tracker = ResumableExperimentTracker(config, checkpoint_path="checkpoints/latest.pt") + +# Training continues from where it left off +for epoch in range(start_epoch, num_epochs): + # ... training ... + tracker.save_checkpoint(model, optimizer, epoch, val_acc) +``` + + +### 4. Experiment Comparison and Analysis + +Programmatic experiment analysis: + +```python +def analyze_experiments(project_name): + """Analyze all experiments in a project.""" + + # W&B + import wandb + api = wandb.Api() + runs = api.runs(project_name) + + # Extract data + data = [] + for run in runs: + data.append({ + "name": run.name, + "learning_rate": run.config.get("learning_rate"), + "batch_size": run.config.get("batch_size"), + "val_accuracy": run.summary.get("val_accuracy"), + "train_time": run.summary.get("_runtime"), + }) + + df = pd.DataFrame(data) + + # Analysis + print("Top 5 experiments by accuracy:") + print(df.nlargest(5, "val_accuracy")) + + # Hyperparameter impact + print("\nAverage accuracy by learning rate:") + print(df.groupby("learning_rate")["val_accuracy"].mean()) + + # Visualization + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 2, figsize=(12, 4)) + + # Learning rate vs accuracy + axes[0].scatter(df["learning_rate"], df["val_accuracy"]) + axes[0].set_xlabel("Learning Rate") + axes[0].set_ylabel("Validation Accuracy") + axes[0].set_xscale("log") + + # Batch size vs accuracy + axes[1].scatter(df["batch_size"], df["val_accuracy"]) + axes[1].set_xlabel("Batch Size") + axes[1].set_ylabel("Validation Accuracy") + + plt.tight_layout() + plt.savefig("experiment_analysis.png") + + return df + +# Run analysis +df = analyze_experiments("team/cifar10-sota") +``` + + +### 5. Data Versioning Integration + +Tracking data versions alongside experiments: + +```python +import hashlib + +def hash_dataset(dataset_path): + """Compute hash of dataset for versioning.""" + hasher = hashlib.sha256() + + # Hash dataset files + for file in sorted(Path(dataset_path).rglob("*")): + if file.is_file(): + with open(file, "rb") as f: + hasher.update(f.read()) + + return hasher.hexdigest() + +# Track data version +data_version = hash_dataset("data/cifar10") +config["data_version"] = data_version + +tracker = ExperimentTracker(config) + +# Or use DVC +""" +# Initialize DVC +dvc init + +# Track data +dvc add data/cifar10 +git add data/cifar10.dvc + +# Log DVC hash in experiment +with open("data/cifar10.dvc") as f: + dvc_config = yaml.safe_load(f) + data_hash = dvc_config["outs"][0]["md5"] + config["data_hash"] = data_hash +""" +``` + + +### 6. Artifact Management Best Practices + +Organizing and managing experiment artifacts: + +```python +class ArtifactManager: + """Manages experiment artifacts (models, plots, logs).""" + + def __init__(self, experiment_dir): + self.exp_dir = Path(experiment_dir) + + # Create subdirectories + self.checkpoints_dir = self.exp_dir / "checkpoints" + self.figures_dir = self.exp_dir / "figures" + self.logs_dir = self.exp_dir / "logs" + + for d in [self.checkpoints_dir, self.figures_dir, self.logs_dir]: + d.mkdir(parents=True, exist_ok=True) + + def save_checkpoint(self, checkpoint, name): + """Save checkpoint with automatic cleanup.""" + path = self.checkpoints_dir / f"{name}.pt" + torch.save(checkpoint, path) + + # Keep only last N checkpoints (except best) + self._cleanup_checkpoints(keep_n=5) + + return path + + def _cleanup_checkpoints(self, keep_n=5): + """Keep only recent checkpoints to save space.""" + checkpoints = sorted( + self.checkpoints_dir.glob("checkpoint_epoch_*.pt"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + + # Delete old checkpoints (keep best + last N) + for ckpt in checkpoints[keep_n:]: + if "best" not in ckpt.name: + ckpt.unlink() + + def save_figure(self, fig, name, step=None): + """Save matplotlib figure with metadata.""" + if step is not None: + filename = f"{name}_step_{step}.png" + else: + filename = f"{name}.png" + + path = self.figures_dir / filename + fig.savefig(path, dpi=150, bbox_inches="tight") + + return path + + def get_artifact_summary(self): + """Get summary of stored artifacts.""" + summary = { + "num_checkpoints": len(list(self.checkpoints_dir.glob("*.pt"))), + "num_figures": len(list(self.figures_dir.glob("*.png"))), + "total_size_mb": sum( + f.stat().st_size for f in self.exp_dir.rglob("*") if f.is_file() + ) / (1024 * 1024), + } + return summary + +# Usage +artifacts = ArtifactManager(experiment_dir) +artifacts.save_checkpoint(checkpoint, f"checkpoint_epoch_{epoch}") +artifacts.save_figure(fig, "training_curve") +print(artifacts.get_artifact_summary()) +``` + + +### 7. Real-Time Monitoring and Alerts + +Setup alerts for experiment issues: + +```python +# W&B Alerts +import wandb + +wandb.init(project="cifar10") + +for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer) + + wandb.log({"train/loss": train_loss, "epoch": epoch}) + + # Alert on divergence + if math.isnan(train_loss) or train_loss > 10: + wandb.alert( + title="Training Diverged", + text=f"Loss is {train_loss} at epoch {epoch}", + level=wandb.AlertLevel.ERROR, + ) + break + + # Alert on milestone + if val_acc > 0.90: + wandb.alert( + title="90% Accuracy Reached!", + text=f"Validation accuracy: {val_acc:.2%}", + level=wandb.AlertLevel.INFO, + ) + +# Slack integration +def send_slack_alert(message, webhook_url): + """Send alert to Slack.""" + import requests + requests.post(webhook_url, json={"text": message}) + +# Email alerts +def send_email_alert(subject, body, to_email): + """Send email alert.""" + import smtplib + from email.message import EmailMessage + + msg = EmailMessage() + msg["Subject"] = subject + msg["To"] = to_email + msg.set_content(body) + + # Send via SMTP + with smtplib.SMTP("localhost") as s: + s.send_message(msg) +``` + + +## Common Integration Patterns + +### Pattern 1: Training Script with Complete Tracking + +```python +#!/usr/bin/env python3 +""" +Complete training script with experiment tracking. +""" + +import argparse +import yaml +from pathlib import Path +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from experiment_tracker import ExperimentTracker +from models import create_model +from data import load_dataset + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Config file") + parser.add_argument("--resume", type=str, help="Resume from checkpoint") + parser.add_argument("--name", type=str, help="Experiment name") + return parser.parse_args() + +def main(): + args = parse_args() + + # Load config + with open(args.config) as f: + config = yaml.safe_load(f) + + # Initialize tracking + tracker = ExperimentTracker( + config=config, + experiment_name=args.name, + use_wandb=True, + use_tensorboard=True, + ) + + # Setup training + model = create_model(config) + optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"]) + criterion = nn.CrossEntropyLoss() + + train_loader, val_loader = load_dataset(config) + + # Resume if checkpoint provided + start_epoch = 0 + if args.resume: + checkpoint = torch.load(args.resume) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + start_epoch = checkpoint["epoch"] + 1 + tracker.logger.info(f"Resumed from epoch {start_epoch}") + + # Training loop + best_val_acc = 0.0 + for epoch in range(start_epoch, config["num_epochs"]): + # Train + model.train() + train_losses = [] + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + train_losses.append(loss.item()) + + # Log every N batches + if batch_idx % config.get("log_interval", 100) == 0: + tracker.log_metrics({ + "train/loss": loss.item(), + "train/lr": optimizer.param_groups[0]['lr'], + }, step=epoch * len(train_loader) + batch_idx) + + # Validate + model.eval() + val_losses = [] + correct = 0 + total = 0 + with torch.no_grad(): + for data, target in val_loader: + output = model(data) + loss = criterion(output, target) + val_losses.append(loss.item()) + + pred = output.argmax(dim=1) + correct += (pred == target).sum().item() + total += target.size(0) + + train_loss = sum(train_losses) / len(train_losses) + val_loss = sum(val_losses) / len(val_losses) + val_acc = correct / total + + # Log epoch metrics + tracker.log_metrics({ + "train/loss_epoch": train_loss, + "val/loss": val_loss, + "val/accuracy": val_acc, + "epoch": epoch, + }) + + # Save checkpoint + tracker.save_checkpoint(model, optimizer, epoch, val_acc) + + # Update best + if val_acc > best_val_acc: + best_val_acc = val_acc + tracker.logger.info(f"New best accuracy: {val_acc:.4f}") + + # Early stopping + if epoch > 50 and val_acc < 0.5: + tracker.logger.warning("Model not improving, stopping early") + break + + # Log final results + tracker.log_metrics({"best_val_accuracy": best_val_acc}) + tracker.logger.info(f"Training completed. Best accuracy: {best_val_acc:.4f}") + + # Cleanup + tracker.finish() + +if __name__ == "__main__": + main() +``` + +**Usage**: +```bash +# Train new model +python train.py --config configs/resnet18.yaml --name resnet18_baseline + +# Resume training +python train.py --config configs/resnet18.yaml --resume experiments/resnet18_baseline/checkpoints/latest.pt +``` + + +## Further Reading + +- **Papers**: + - "Hidden Technical Debt in Machine Learning Systems" (Sculley et al., 2015) + - "Reproducibility in Machine Learning" (Pineau et al., 2020) + - "A Step Toward Quantifying Independently Reproducible ML Research" (Dodge et al., 2019) + +- **Tool Documentation**: + - TensorBoard: https://www.tensorflow.org/tensorboard + - Weights & Biases: https://docs.wandb.ai/ + - MLflow: https://mlflow.org/docs/latest/index.html + - DVC (Data Version Control): https://dvc.org/doc + - Hydra (Config Management): https://hydra.cc/docs/intro/ + +- **Best Practices**: + - Papers With Code (Reproducibility): https://paperswithcode.com/ + - ML Code Completeness Checklist: https://github.com/paperswithcode/releasing-research-code + - Experiment Management Guide: https://neptune.ai/blog/experiment-management + +- **Books**: + - "Designing Machine Learning Systems" by Chip Huyen (Chapter on Experiment Tracking) + - "Machine Learning Engineering" by Andriy Burkov (Chapter on MLOps) + + +**Remember**: Experiment tracking is insurance. It costs 1% overhead but saves 100% when disaster strikes. Track from day 1, track everything, and your future self will thank you. diff --git a/skills/using-training-optimization/gradient-management.md b/skills/using-training-optimization/gradient-management.md new file mode 100644 index 0000000..e4d9c0b --- /dev/null +++ b/skills/using-training-optimization/gradient-management.md @@ -0,0 +1,2442 @@ + +# Gradient Management Skill + +## When to Use This Skill + +Use this skill when: +- Loss becomes NaN or Inf during training +- Training is unstable with loss spikes +- User asks about gradient clipping +- User wants larger batch size but has OOM issues +- User mentions "exploding gradients" or "vanishing gradients" +- Gradients are very large (>100) or very small (<1e-8) +- Implementing gradient accumulation +- Using mixed precision (AMP) with gradient clipping +- User asks "why is my training unstable?" +- Training Transformers, RNNs, or very deep networks +- User implements gradient accumulation without loss scaling (RED FLAG) +- User clips gradients after optimizer.step() (RED FLAG) +- User doesn't unscale before clipping with AMP (RED FLAG) +- Reinforcement learning (policy gradients often explode) +- Distributed training with gradient synchronization questions +- User says "just lower learning rate" for NaN loss (may need clipping) + +Do NOT use when: +- Training is stable with no gradient issues +- User has architecture questions unrelated to gradients +- User only asks about learning rate (use learning-rate-scheduling skill) +- User asks about data issues (different problem space) + + +## Core Principles + +### 1. The Critical Importance of Gradient Management + +**Gradients are the foundation of neural network training:** +- Backpropagation computes gradients of loss w.r.t. parameters +- Optimizer uses gradients to update parameters +- Gradient magnitude determines update size +- Gradient stability determines training stability +- Wrong gradient handling → training failure (NaN, no convergence) + +**Common Impact:** +- Gradient clipping: Difference between training and NaN loss +- Gradient accumulation: Train with 8x larger effective batch size on same hardware +- Proper diagnosis: 1-2 hours to fix vs days of confusion +- Mixed precision integration: 2x speedup without breaking training + +**This is NOT optional:** +- Every Transformer paper uses gradient clipping +- Gradient accumulation is standard for large models +- Production training code always monitors gradients +- Ignoring gradients → fragile, unreliable training + + +### 2. Gradient Flow in Training + +**Understanding the training loop gradient flow:** + +```python +# Step 1: Zero gradients from previous iteration +optimizer.zero_grad() + +# Step 2: Forward pass (compute loss) +output = model(input) +loss = criterion(output, target) + +# Step 3: Backward pass (compute gradients) +# This computes: param.grad = ∂loss/∂param for all parameters +loss.backward() + +# Step 4: [OPTIONAL] Modify gradients (clipping, scaling, etc.) +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# Step 5: Optimizer step (update parameters using gradients) +# This does: param = param - lr * param.grad (simplified) +optimizer.step() +``` + +**Critical ordering:** +1. Gradients are computed by `backward()` +2. Gradients can be modified between `backward()` and `step()` +3. Gradients are consumed by `step()` to update parameters +4. Gradient modifications MUST happen after `backward()`, before `step()` + +**Mental model:** +- `backward()` produces gradients +- Your code can inspect/modify gradients +- `step()` consumes gradients to update parameters +- Modifications after `step()` are useless (gradients already consumed) +- Modifications before `backward()` are useless (gradients don't exist yet) + + +## Gradient Clipping + +### Why Gradient Clipping Matters + +**The exploding gradients problem:** +- Deep networks multiply gradients through chain rule +- Each layer multiplies gradient by weights and activation derivatives +- If these multiplications are >1, gradients grow exponentially +- Large gradients → large parameter updates → training instability +- Extremely large gradients → NaN or Inf loss + +**Real-world symptoms:** +- Loss suddenly jumps to NaN after normal training +- Loss oscillates wildly between iterations +- Training is stable initially, then diverges +- Parameters become NaN or Inf +- Gradient norms >100 or >1000 + +**Why it happens:** +- Transformers: Attention mechanism can amplify gradients +- RNNs: Backpropagation through time multiplies gradients across timesteps +- Very deep networks: Many layers multiply gradients +- Poor initialization: Large initial weights amplify gradients +- High learning rates: Amplify already-large gradients + +### Norm-Based Gradient Clipping (Primary Method) + +**The standard solution:** + +```python +# Clip gradients by global norm +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# Complete training loop: +optimizer.zero_grad() +loss.backward() +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +optimizer.step() +``` + +**What it does:** +1. Computes total gradient norm: `total_norm = sqrt(sum(g^2 for g in all gradients))` +2. If `total_norm > max_norm`: + - Scaling factor = `max_norm / total_norm` + - All gradients multiplied by this factor +3. Result: Gradient direction preserved, magnitude limited + +**Why this is good:** +- Preserves gradient direction (doesn't distort signal) +- Only scales when needed (if total_norm ≤ max_norm, no change) +- Global view (considers all parameters together) +- Mathematically elegant (scales gradient vector to unit ball) + +**Typical values for max_norm:** + +```python +# Transformers (BERT, GPT, T5) +max_norm = 1.0 # Most common +max_norm = 5.0 # Sometimes used for very large models + +# RNNs/LSTMs +max_norm = 0.5 # More aggressive clipping +max_norm = 1.0 # Also common + +# Reinforcement Learning (policy gradients) +max_norm = 0.5 # RL gradients are particularly unstable + +# CNNs (ResNets, etc.) +# Usually DON'T clip - residual connections provide stability +# Only clip if you observe instability + +# Very deep networks (>100 layers) +max_norm = 1.0 # Helps with stability +``` + +**When to use norm-based clipping:** +✅ Training Transformers (almost always needed) +✅ Training RNNs/LSTMs (essential for long sequences) +✅ Reinforcement learning (policy gradients) +✅ Any time you see loss → NaN +✅ Loss spikes or wild oscillations +✅ Very deep networks (>50 layers) + +**When NOT to use:** +❌ Stable CNN training (ResNet on ImageNet) +❌ Training is already stable with no issues +❌ As a preemptive measure without evidence of need + +### Value-Based Gradient Clipping (Rare) + +**Clips each gradient element individually:** + +```python +# Clip each gradient value to [-clip_value, +clip_value] +torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5) + +# What it does: +for param in model.parameters(): + if param.grad is not None: + param.grad.clamp_(-clip_value, clip_value) +``` + +**Difference from norm-based:** +- Norm-based: Scales entire gradient vector to limit total magnitude +- Value-based: Clamps each gradient element independently +- Value-based is MORE aggressive (can change gradient direction) +- Value-based treats all parameters equally (ignores scale differences) + +**When to use value-based clipping:** +- Debugging: Identify which specific parameters have large gradients +- Extreme outliers: Some parameters have huge gradients while others are normal +- Legacy code: Some old papers use this + +**Usually prefer norm-based:** +- Norm-based is standard in modern deep learning +- Preserves gradient direction +- Better theoretical properties +- Used in all major Transformer implementations + +### Complete Clipping Implementation + +```python +import torch +import torch.nn as nn + +# Model and optimizer +model = TransformerModel() +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +# Training loop with gradient clipping +for epoch in range(num_epochs): + for batch in train_loader: + # 1. Zero gradients + optimizer.zero_grad() + + # 2. Forward pass + output = model(batch['input']) + loss = criterion(output, batch['target']) + + # 3. Backward pass (compute gradients) + loss.backward() + + # 4. Clip gradients (CRITICAL: after backward, before step) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # 5. Optimizer step (update parameters) + optimizer.step() +``` + +**Common mistakes - WRONG ORDER:** + +```python +# WRONG: Clipping after optimizer.step() +loss.backward() +optimizer.step() +clip_grad_norm_(model.parameters(), 1.0) # ❌ Too late! Already updated. + +# WRONG: Clipping before backward() +optimizer.zero_grad() +clip_grad_norm_(model.parameters(), 1.0) # ❌ No gradients exist yet! +loss.backward() +optimizer.step() + +# RIGHT: Clipping after backward(), before step() +loss.backward() # Compute gradients +clip_grad_norm_(model.parameters(), 1.0) # Modify gradients +optimizer.step() # Use modified gradients +``` + +### How to Choose max_norm Value + +**Start with standard values:** + +```python +# Default starting point for Transformers +max_norm = 1.0 + +# If still unstable (loss spikes) +max_norm = 0.5 # More aggressive clipping + +# If training seems too constrained (slow convergence) +max_norm = 2.0 # Less aggressive clipping +``` + +**Systematic tuning:** + +1. **Monitor gradient norms WITHOUT clipping:** + ```python + # Check typical gradient magnitudes + total_norm = 0.0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + print(f"Gradient norm: {total_norm:.4f}") + ``` + +2. **Set max_norm based on typical norms:** + - If typical norms are 0.5-2.0, set max_norm=2.0 or 3.0 + - If typical norms are 5-10, set max_norm=5.0 or 10.0 + - Goal: Clip outliers without affecting normal gradients + +3. **Verify clipping is helping:** + ```python + # Log how often clipping activates + grad_norm_before = compute_grad_norm(model) + clip_grad_norm_(model.parameters(), max_norm=1.0) + grad_norm_after = compute_grad_norm(model) + + if grad_norm_before > max_norm: + print(f"Clipped: {grad_norm_before:.4f} -> {grad_norm_after:.4f}") + ``` + +**Signs you need clipping:** +- Gradient norms occasionally >10 or >100 +- Loss occasionally spikes or becomes NaN +- Training is initially stable then diverges +- Gradient norms grow over time + +**Signs your max_norm is too low:** +- Clipping activates on EVERY iteration +- Training converges very slowly +- Gradient norm is always exactly max_norm (always clipping) + +**Signs your max_norm is too high:** +- Still getting NaN or loss spikes +- Clipping never activates +- Not solving the stability problem + + +## Gradient Accumulation + +### Why Gradient Accumulation Matters + +**The memory vs batch size problem:** +- Larger batch sizes often improve training (more stable gradients) +- Larger batches require more GPU memory +- Memory is limited (GPU VRAM) +- Can't always fit desired batch size in memory + +**Example scenario:** +- Want batch size 256 for stable training +- Only fit batch size 32 in GPU memory +- Can't afford bigger GPU +- Solution: Gradient accumulation + +**What gradient accumulation does:** +- Accumulate gradients over multiple small batches +- Update parameters once with accumulated gradients +- Equivalent to training with one large batch +- Same results, but fits in memory + +**Real-world impact:** +- Train models 4-8x larger batch size on same hardware +- Standard technique in production training +- Used in all large model training (GPT, BERT, etc.) +- Essential for competitive performance on limited hardware + +### Correct Gradient Accumulation Implementation + +**The critical implementation:** + +```python +# Want effective batch size 256, but can only fit 64 in memory +# Solution: Accumulate over 4 steps (256 = 64 * 4) + +accumulation_steps = 4 +optimizer.zero_grad() + +for i, (data, target) in enumerate(train_loader): + # Forward pass + output = model(data) + loss = criterion(output, target) + + # Backward pass with CRITICAL loss scaling + # MUST divide loss by accumulation_steps! + (loss / accumulation_steps).backward() + + # Update weights every accumulation_steps + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() +``` + +**Why scale loss by accumulation_steps?** + +```python +# Without scaling: +loss.backward() # Adds gradients: param.grad += ∂loss/∂param + +# After 4 accumulation steps: +# param.grad = ∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param +# This is 4x larger than single batch! + +# With scaling: +(loss / 4).backward() # Adds: param.grad += (∂loss/∂param) / 4 + +# After 4 accumulation steps: +# param.grad = (∂loss1/∂param + ∂loss2/∂param + ∂loss3/∂param + ∂loss4/∂param) / 4 +# This is the AVERAGE gradient - equivalent to single large batch! +``` + +**Mathematical equivalence:** +- Large batch loss: `L = (l1 + l2 + l3 + l4) / 4` (mean over samples) +- Large batch gradient: `∂L/∂param = (∂l1/∂param + ∂l2/∂param + ∂l3/∂param + ∂l4/∂param) / 4` +- Accumulated gradient: Same result! + +### Common Gradient Accumulation Mistakes + +**WRONG: Not scaling loss** + +```python +# ❌ WRONG - Gradients are accumulation_steps times too large! +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + loss.backward() # ❌ Not scaled! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# Result: Equivalent to learning_rate * accumulation_steps +# Acts like LR is 4x too high → unstable training +``` + +**WRONG: Scaling gradients instead of loss** + +```python +# ❌ WRONG - Inefficient and error-prone! +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + loss.backward() + + # Manually scale gradients + for param in model.parameters(): + if param.grad is not None: + param.grad /= accumulation_steps # ❌ Inefficient! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# Why wrong: +# - More code, more error-prone +# - Less efficient (iterates all parameters) +# - Easy to forget or do incorrectly +# - Scaling loss is cleaner and standard +``` + +**WRONG: Forgetting to zero_grad() after update** + +```python +# ❌ WRONG - Gradients keep accumulating forever! +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + # ❌ Missing optimizer.zero_grad()! + # Next accumulation will add to these gradients! + +# Result: Gradients never reset, accumulate across updates +# Acts like accumulation_steps grows over time +``` + +**WRONG: Zeroing gradients inside accumulation loop** + +```python +# ❌ WRONG - Clears gradients before accumulating! +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + optimizer.zero_grad() # ❌ Clears previous accumulation! + + loss = criterion(model(batch), target) + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + +# Result: Only last batch's gradients are used (no accumulation!) +``` + +### Complete Gradient Accumulation Implementation + +```python +import torch +import torch.nn as nn + +# Configuration +batch_size_per_step = 64 # What fits in memory +accumulation_steps = 4 # Accumulate over 4 steps +effective_batch_size = batch_size_per_step * accumulation_steps # = 256 + +# DataLoader with smaller batch size +train_loader = DataLoader(dataset, batch_size=batch_size_per_step) + +# Model and optimizer +model = TransformerModel() +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +# Training loop +optimizer.zero_grad() # Zero once before accumulation loop + +for epoch in range(num_epochs): + for i, (data, target) in enumerate(train_loader): + # Forward pass + output = model(data) + loss = criterion(output, target) + + # Backward pass with scaled loss + (loss / accumulation_steps).backward() + + # Update every accumulation_steps + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + + # Handle remaining batches at end of epoch + # (if total batches not divisible by accumulation_steps) + if len(train_loader) % accumulation_steps != 0: + optimizer.step() + optimizer.zero_grad() +``` + +### Gradient Accumulation with Gradient Clipping + +**Correct order:** + +```python +accumulation_steps = 4 +optimizer.zero_grad() + +for i, (data, target) in enumerate(train_loader): + output = model(data) + loss = criterion(output, target) + + # Scale loss and backward + (loss / accumulation_steps).backward() + + # Update every accumulation_steps + if (i + 1) % accumulation_steps == 0: + # Clip BEFORE optimizer step (on accumulated gradients) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + optimizer.step() + optimizer.zero_grad() +``` + +**Why this order?** +- Gradients accumulate over `accumulation_steps` iterations +- After accumulation, gradients are ready for clipping +- Clip once on the full accumulated gradients +- Then update parameters with clipped gradients + +**WRONG: Clipping on each accumulation step:** + +```python +# ❌ WRONG - Clips partial gradients! +for i, (data, target) in enumerate(train_loader): + (loss / accumulation_steps).backward() + + # ❌ Clipping partial gradients! + clip_grad_norm_(model.parameters(), max_norm=1.0) + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# Why wrong: +# - Clipping partial gradients distorts accumulation +# - Each partial gradient is ~1/4 of final gradient +# - Clipping these small gradients has wrong threshold +# - Clip ONCE on final accumulated gradient +``` + +### Gradient Accumulation with Learning Rate Scheduling + +**Correct implementation:** + +```python +accumulation_steps = 4 +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) + +optimizer.zero_grad() + +for i, (data, target) in enumerate(train_loader): + output = model(data) + loss = criterion(output, target) + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + + # Step scheduler AFTER optimizer step (once per update) + scheduler.step() +``` + +**Key points:** +- Scheduler steps once per parameter update (not per batch) +- Matches the effective batch size timing +- Scheduler sees `num_batches / accumulation_steps` total steps + + +## Gradient Diagnosis + +### Why Diagnosis Matters + +**Don't guess - measure:** +- "Training isn't working" could be many issues +- Gradient issues have specific symptoms +- Measuring gradients identifies the problem +- Diagnosis guides the solution + +**What to diagnose:** +1. Gradient magnitudes (too large? too small?) +2. Gradient distribution across layers (vanishing in early layers?) +3. NaN or Inf gradients (numerical issues?) +4. Gradient patterns over time (growing? shrinking?) + +### Checking Gradient Magnitudes + +**Basic gradient checking:** + +```python +def check_gradients(model): + """Check gradient magnitudes for all parameters""" + total_norm = 0.0 + param_norms = {} + + for name, param in model.named_parameters(): + if param.grad is not None: + # Compute gradient norm for this parameter + param_norm = param.grad.data.norm(2).item() + param_norms[name] = param_norm + total_norm += param_norm ** 2 + + total_norm = total_norm ** 0.5 + + print(f"Total gradient norm: {total_norm:.4f}") + + # Show top 5 largest gradients + print("\nLargest gradients:") + for name, norm in sorted(param_norms.items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" {name}: {norm:.4f}") + + # Show top 5 smallest gradients + print("\nSmallest gradients:") + for name, norm in sorted(param_norms.items(), key=lambda x: x[1])[:5]: + print(f" {name}: {norm:.4e}") + + return total_norm + +# Usage in training loop: +loss.backward() +grad_norm = check_gradients(model) +optimizer.step() +``` + +**What to look for:** + +```python +# Healthy gradients: +# Total norm: 0.1 to 10 +# Layer norms: Similar order of magnitude across layers +# No NaN or Inf values + +# Exploding gradients: +# Total norm: >100 or >1000 +# Some layers have huge gradients (>10) +# → Solution: Gradient clipping + +# Vanishing gradients: +# Total norm: <1e-6 +# Early layers have much smaller gradients than late layers +# → Solution: Better activation/initialization/architecture + +# NaN gradients: +# Any gradient is NaN or Inf +# → Solution: Check for numerical instability in loss or model +``` + +### Comprehensive Gradient Diagnostics + +```python +def diagnose_gradients(model, threshold_low=1e-8, threshold_high=100): + """ + Comprehensive gradient diagnostics with automatic issue detection + + Args: + model: PyTorch model + threshold_low: Threshold for vanishing gradients + threshold_high: Threshold for exploding gradients + + Returns: + dict with diagnostic information + """ + diagnostics = { + 'total_norm': 0.0, + 'param_norms': {}, + 'has_nan': False, + 'has_inf': False, + 'vanishing': [], + 'exploding': [], + } + + total_norm = 0.0 + + for name, param in model.named_parameters(): + if param.grad is not None: + grad = param.grad.data + + # Check for NaN or Inf + if torch.isnan(grad).any(): + diagnostics['has_nan'] = True + print(f"⚠️ NaN gradient detected in {name}") + + if torch.isinf(grad).any(): + diagnostics['has_inf'] = True + print(f"⚠️ Inf gradient detected in {name}") + + # Compute norm + param_norm = grad.norm(2).item() + diagnostics['param_norms'][name] = param_norm + total_norm += param_norm ** 2 + + # Check for vanishing + if param_norm < threshold_low: + diagnostics['vanishing'].append((name, param_norm)) + + # Check for exploding + if param_norm > threshold_high: + diagnostics['exploding'].append((name, param_norm)) + + total_norm = total_norm ** 0.5 + diagnostics['total_norm'] = total_norm + + # Print diagnosis + print(f"\n{'='*60}") + print(f"GRADIENT DIAGNOSTICS") + print(f"{'='*60}") + print(f"Total gradient norm: {total_norm:.4f}") + + if diagnostics['has_nan']: + print("\n🚨 CRITICAL: NaN gradients detected!") + print(" Possible causes:") + print(" - Division by zero in loss or model") + print(" - Log of zero or negative number") + print(" - Numerical overflow") + print(" - Already-NaN parameters or inputs") + + if diagnostics['has_inf']: + print("\n🚨 CRITICAL: Inf gradients detected!") + print(" Possible causes:") + print(" - Numerical overflow (very large values)") + print(" - Division by very small number") + print(" - Exponential of very large number") + + if total_norm > threshold_high: + print(f"\n⚠️ EXPLODING GRADIENTS: Total norm {total_norm:.2f} > {threshold_high}") + print(" Solution: Add gradient clipping") + print(f" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm={threshold_high/10:.1f})") + if diagnostics['exploding']: + print(f"\n Top exploding layers:") + for name, norm in sorted(diagnostics['exploding'], key=lambda x: x[1], reverse=True)[:5]: + print(f" - {name}: {norm:.2f}") + + if total_norm < threshold_low: + print(f"\n⚠️ VANISHING GRADIENTS: Total norm {total_norm:.2e} < {threshold_low}") + print(" Possible solutions:") + print(" - Use ReLU/GELU instead of sigmoid/tanh") + print(" - Check weight initialization (use He/Xavier)") + print(" - Add batch normalization") + print(" - Add residual connections") + print(" - Increase learning rate (after other fixes)") + if diagnostics['vanishing']: + print(f"\n Layers with vanishing gradients:") + for name, norm in sorted(diagnostics['vanishing'], key=lambda x: x[1])[:5]: + print(f" - {name}: {norm:.2e}") + + print(f"{'='*60}\n") + + return diagnostics + +# Usage: +loss.backward() +diagnostics = diagnose_gradients(model) + +if diagnostics['has_nan'] or diagnostics['has_inf']: + # Stop training, fix the issue + raise RuntimeError("NaN or Inf gradients detected!") +``` + +### Gradient Monitoring and Logging + +**Log gradient statistics during training:** + +```python +import wandb # or tensorboard + +def log_gradient_stats(model, logger, step): + """Log gradient statistics for monitoring""" + + total_norm = 0.0 + layer_norms = {} + + for name, param in model.named_parameters(): + if param.grad is not None: + # Gradient norm + grad_norm = param.grad.data.norm(2).item() + layer_norms[name] = grad_norm + total_norm += grad_norm ** 2 + + # Parameter norm (for ratio calculation) + param_norm = param.data.norm(2).item() + + # Log individual layer stats + logger.log({ + f"gradients/{name}/norm": grad_norm, + f"gradients/{name}/mean": param.grad.data.mean().item(), + f"gradients/{name}/std": param.grad.data.std().item(), + f"gradients/{name}/max": param.grad.data.abs().max().item(), + }, step=step) + + # Log ratio of gradient norm to parameter norm + # Healthy ratio is typically 0.001 to 0.01 + if param_norm > 0: + ratio = grad_norm / param_norm + logger.log({f"gradients/{name}/ratio": ratio}, step=step) + + total_norm = total_norm ** 0.5 + + # Log total gradient norm + logger.log({"gradients/total_norm": total_norm}, step=step) + + return total_norm + +# Usage in training loop: +for step, batch in enumerate(train_loader): + optimizer.zero_grad() + loss = model(batch) + loss.backward() + + # Log gradients (before clipping to see true magnitudes) + grad_norm = log_gradient_stats(model, wandb, step) + + # Clip and update + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() +``` + +**What to watch in gradient logs:** + +```python +# Healthy training: +# - Total gradient norm: Relatively stable (0.1 to 10) +# - Layer norms: Similar across layers (no huge disparities) +# - Ratios: ~0.001 (gradients much smaller than parameters) +# - No sudden spikes or drops to zero + +# Warning signs: +# - Total norm suddenly spikes (>100) → exploding gradients +# - Total norm gradually decreases to near-zero → vanishing gradients +# - Early layers have much smaller norms than late layers → vanishing +# - Ratios > 0.1 → updates are too large relative to parameters +# - Sudden drop to zero → dead neurons or broken gradient flow +``` + + +## Vanishing Gradients + +### Recognizing Vanishing Gradients + +**Symptoms:** +1. Training loss decreases very slowly or not at all +2. Validation metrics don't improve +3. Gradient norms are extremely small (<1e-6) +4. Early layers have much smaller gradients than later layers +5. Training seems "stuck" after initialization + +**How to confirm:** + +```python +# Check gradient magnitudes by layer depth +loss.backward() + +print("Layer-wise gradient norms:") +for name, param in model.named_parameters(): + if param.grad is not None: + norm = param.grad.norm(2).item() + print(f"{name}: {norm:.2e}") + +# Example output showing vanishing gradients: +# layer1.weight: 1.23e-02 ← Early layer +# layer5.weight: 3.45e-04 +# layer10.weight: 8.91e-06 +# layer15.weight: 2.34e-07 +# layer20.weight: 5.67e-09 ← Late layer + +# Pattern: Gradients shrink exponentially with depth +# This is vanishing gradients! +``` + +### Causes of Vanishing Gradients + +**1. Too many layers (very deep networks):** +- Each layer multiplies gradient by weights during backprop +- If multiplication factor <1, gradients shrink exponentially +- More layers = more multiplication = smaller gradients + +**2. Saturating activation functions:** +- Sigmoid: `σ'(x) ≈ 0` when `|x|` is large (saturates) +- Tanh: `tanh'(x) ≈ 0` when `|x|` is large +- Gradient flows through: `grad = grad * activation'(x)` +- If `activation'(x) ≈ 0`, gradient vanishes + +**3. Poor weight initialization:** +- Weights too small → activations too small → gradients too small +- Weights initialized uniformly → improper scaling across layers + +**4. Learning rate too low:** +- Not a root cause, but can make problem worse +- Tiny gradients * tiny LR = no learning + +### Solutions for Vanishing Gradients + +**Solution 1: Use Better Activation Functions** + +```python +# AVOID: Sigmoid and Tanh (saturate easily) +class BadModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(100, 100), + nn.Sigmoid(), # ❌ Saturates, kills gradients + nn.Linear(100, 100), + nn.Sigmoid(), # ❌ Even worse with depth + nn.Linear(100, 10) + ) + +# PREFER: ReLU, GELU, or other non-saturating activations +class GoodModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(100, 100), + nn.ReLU(), # ✅ Doesn't saturate (for x>0) + nn.Linear(100, 100), + nn.GELU(), # ✅ Smooth, non-saturating + nn.Linear(100, 10) + ) + +# Why it helps: +# ReLU: grad = 1 for x>0, doesn't shrink gradient +# GELU: Smooth version of ReLU, widely used in Transformers +# Both avoid saturation that kills gradients +``` + +**Solution 2: Proper Weight Initialization** + +```python +# Use He initialization for ReLU networks +def init_weights(m): + if isinstance(m, nn.Linear): + # He initialization: optimal for ReLU + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + +model = GoodModel() +model.apply(init_weights) + +# Use Xavier initialization for Tanh/Sigmoid (if you must use them) +def init_weights_xavier(m): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + +# Why it helps: +# Proper initialization ensures gradients have appropriate scale +# He init accounts for ReLU's effect on variance +# Xavier init maintains variance across layers for symmetric activations +``` + +**Solution 3: Batch Normalization** + +```python +# Add BatchNorm between layers +class ModelWithBatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(100, 100), + nn.BatchNorm1d(100), # ✅ Normalizes activations + nn.ReLU(), + nn.Linear(100, 100), + nn.BatchNorm1d(100), # ✅ Helps gradient flow + nn.ReLU(), + nn.Linear(100, 10) + ) + +# For CNNs: +class CNNWithBatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1), + nn.BatchNorm2d(64), # ✅ After conv, before activation + nn.ReLU(), + nn.Conv2d(64, 128, 3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(), + ) + +# Why it helps: +# BatchNorm normalizes activations to have mean=0, std=1 +# Prevents activations from getting too small or too large +# Helps maintain gradient scale through network +# Widely used in modern architectures +``` + +**Solution 4: Residual Connections (Skip Connections)** + +```python +# Add skip connections (ResNet-style) +class ResidualBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(dim, dim), + nn.ReLU(), + nn.Linear(dim, dim) + ) + + def forward(self, x): + # Skip connection: add input to output + return x + self.layers(x) # ✅ Gradient flows through skip connection + +class ResidualNetwork(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.Sequential( + ResidualBlock(100), + ResidualBlock(100), + ResidualBlock(100), + # Can stack many blocks without vanishing gradients! + ) + self.output = nn.Linear(100, 10) + + def forward(self, x): + x = self.blocks(x) + return self.output(x) + +# Why it helps: +# Gradients can flow directly through skip connections +# Backprop path: grad flows through addition (no multiplication) +# Allows training very deep networks (ResNet-152, ResNet-200) +# Essential for modern deep architectures +``` + +**Solution 5: Layer Normalization (for Transformers)** + +```python +# Transformers use Layer Normalization +class TransformerBlock(nn.Module): + def __init__(self, d_model): + super().__init__() + self.attention = MultiHeadAttention(d_model) + self.norm1 = nn.LayerNorm(d_model) # ✅ Layer norm + self.ffn = FeedForward(d_model) + self.norm2 = nn.LayerNorm(d_model) # ✅ Layer norm + + def forward(self, x): + # Pre-norm architecture (modern standard) + x = x + self.attention(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + +# Why Layer Norm: +# BatchNorm doesn't work well for sequences (different lengths) +# LayerNorm normalizes across features (not batch) +# Standard in Transformers (BERT, GPT, etc.) +``` + +**Solution 6: Gradient Checkpointing (if memory-constrained)** + +```python +# Trade computation for memory (from pytorch-engineering pack) +from torch.utils.checkpoint import checkpoint + +class DeepModel(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([ + ResidualBlock(100) for _ in range(50) # Very deep! + ]) + + def forward(self, x): + for block in self.blocks: + # Use checkpointing to save memory + x = checkpoint(block, x, use_reentrant=False) + return x + +# Why it helps: +# Allows training deeper networks in same memory +# Doesn't directly solve vanishing gradients +# But removes memory constraint that prevents using deeper models +# Compatible with all other solutions (BN, residuals, etc.) +``` + +### Systematic Approach to Vanishing Gradients + +**Step 1: Confirm diagnosis** +```python +# Check gradient magnitudes +loss.backward() +for name, param in model.named_parameters(): + if param.grad is not None: + print(f"{name}: {param.grad.norm(2).item():.2e}") + +# Look for: Early layers << Late layers +``` + +**Step 2: Apply architectural fixes (priority order)** +1. Switch to ReLU/GELU activations (highest impact) +2. Add proper weight initialization (He/Xavier) +3. Add BatchNorm or LayerNorm +4. Add residual connections if very deep (>20 layers) + +**Step 3: Verify improvement** +```python +# After fixes, check gradients again +# Should see more uniform gradient magnitudes across layers +``` + +**Step 4: Adjust learning rate if needed** +```python +# Only AFTER architectural fixes +# May need slightly higher LR with better gradient flow +``` + +**IMPORTANT NOTE: When Small Gradients Are Actually OK** + +Don't blindly "fix" small gradients if training is working well: + +```python +# Scenario: Gradients are small (1e-7) but training is progressing +# Epoch 1: Loss 2.34, Grad norm: 3.45e-07 +# Epoch 2: Loss 1.89, Grad norm: 2.91e-07 ← Loss decreasing! +# Epoch 3: Loss 1.52, Grad norm: 2.34e-07 ← Still improving! + +# This is OK! Don't fix what isn't broken. +``` + +**Healthy small gradients:** +- Training progressing (loss decreasing, metrics improving) ✓ +- Gradients relatively uniform across layers +- Gradients stable over time + +**Unhealthy vanishing gradients:** +- Training stuck (loss not decreasing) +- Early layers << late layers (1000x difference) +- Gradients decreasing over time + +**Key insight:** Absolute gradient magnitude depends on parameter scale, loss scale, and learning rate. What matters is: **Is the model learning?** + +```python +# Better diagnostic: Check relative gradients across layers +grad_norms = {} +for name, param in model.named_parameters(): + if param.grad is not None: + grad_norms[name] = param.grad.norm(2).item() + +# Check ratio: Are early layers much smaller than late layers? +early_layers = [v for k, v in grad_norms.items() if 'layer0' in k or 'layer1' in k] +late_layers = [v for k, v in grad_norms.items() if 'layer19' in k or 'layer20' in k] + +if early_layers and late_layers: + ratio = np.mean(late_layers) / np.mean(early_layers) + if ratio > 1000: + print(f"⚠️ Vanishing gradients: late/early ratio = {ratio:.0f}") + else: + print(f"✅ Gradient flow OK: late/early ratio = {ratio:.0f}") +``` + +**Decision rule:** +- Training working well + gradients stable → No action needed +- Training stuck + early << late → Apply architectural fixes +- Training working + improving over time → Monitor but don't change + + +## Exploding Gradients + +### Recognizing Exploding Gradients + +**Symptoms:** +1. Loss suddenly becomes NaN or Inf during training +2. Loss oscillates wildly (jumps up and down) +3. Parameters become very large or NaN +4. Gradient norms >100 or >1000 +5. Training is stable initially then suddenly diverges + +**How to confirm:** + +```python +# Check gradient magnitudes +loss.backward() + +total_norm = 0.0 +for param in model.parameters(): + if param.grad is not None: + param_norm = param.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + +total_norm = total_norm ** 0.5 +print(f"Total gradient norm: {total_norm:.4f}") + +# If total_norm > 100: Exploding gradients! +# If any parameter grad norm > 100: Exploding gradients! +``` + +### Causes of Exploding Gradients + +**1. Learning rate too high:** +- Large gradients * large LR = huge parameter updates +- Updates overshoot optimal values +- Can cause oscillation or divergence + +**2. Poor weight initialization:** +- Weights too large → activations too large → gradients too large +- Random initialization without proper scaling + +**3. Lack of gradient clipping:** +- Occasional gradient spikes are normal in some architectures +- Without clipping, one spike can break training + +**4. Numerical instability in model:** +- Division by very small numbers +- Exponential of large numbers +- Log of numbers close to zero + +**5. Architecture-specific issues:** +- Transformers: Attention mechanism can amplify gradients +- RNNs: Backprop through time multiplies gradients across timesteps +- Very deep networks: Many layers multiply gradients + +### Solutions for Exploding Gradients + +**Solution 1: Gradient Clipping (Primary Solution)** + +```python +# Add gradient clipping - THE solution for exploding gradients +optimizer.zero_grad() +loss.backward() + +# Clip gradients to maximum norm +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +optimizer.step() + +# Why this works: +# Limits gradient magnitude while preserving direction +# Prevents huge parameter updates +# Standard practice for Transformers, RNNs, RL +``` + +**Solution 2: Lower Learning Rate** + +```python +# If gradients are consistently large, try lower LR +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Was 1e-3 + +# But NOTE: +# Gradient clipping is usually BETTER than just lowering LR +# Clipping handles occasional spikes without limiting normal gradients +# Lowering LR slows down ALL learning, even when gradients are normal + +# Best approach: Use both +# - Gradient clipping for stability (handles spikes) +# - Reasonable learning rate for speed (not too high or too low) +``` + +**Solution 3: Better Weight Initialization** + +```python +# Use proper initialization +def init_weights(m): + if isinstance(m, nn.Linear): + # He initialization for ReLU + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.zeros_(m.bias) + +model.apply(init_weights) + +# Why it helps: +# Proper initialization ensures weights are appropriate scale +# Prevents initial gradients from being too large +# Particularly important for very deep networks +``` + +**Solution 4: Batch Normalization** + +```python +# Add BatchNorm to stabilize training +class StableModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(100, 100), + nn.BatchNorm1d(100), # ✅ Stabilizes gradients + nn.ReLU(), + nn.Linear(100, 100), + nn.BatchNorm1d(100), + nn.ReLU(), + nn.Linear(100, 10) + ) + +# Why it helps: +# Normalizes activations, which stabilizes gradients +# Reduces internal covariate shift +# Makes training more robust to hyperparameter choices +``` + +**Solution 5: Check for Numerical Issues** + +```python +# AVOID: Operations that can cause numerical instability + +# ❌ Division by small numbers +loss = 1.0 / (predictions + eps) # If predictions ≈ 0, loss explodes + +# ✅ Add epsilon for stability +eps = 1e-8 +loss = 1.0 / (predictions + eps) + +# ❌ Log of values close to zero +loss = -torch.log(predictions) # If predictions ≈ 0, loss → -∞ + +# ✅ Add epsilon +loss = -torch.log(predictions + eps) + +# ❌ Exp of large values +loss = torch.exp(logits) # If logits are large, exp explodes + +# ✅ Use log-sum-exp trick or built-in stable functions +loss = F.cross_entropy(logits, targets) # Handles numerics internally + +# ❌ Custom loss without stability +def unstable_loss(pred, target): + return ((pred - target) / pred).pow(2).mean() # Division can explode + +# ✅ Add stability +def stable_loss(pred, target): + return ((pred - target) / (pred.abs() + eps)).pow(2).mean() +``` + +**Solution 6: Use Residual Connections** + +```python +# Residual connections help stability +class ResidualBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(dim, dim), + nn.ReLU(), + nn.Linear(dim, dim) + ) + + def forward(self, x): + return x + self.layers(x) # ✅ Skip connection provides stable path + +# Why it helps: +# Gradients can flow through skip connections +# Prevents gradients from exploding through many layers +# Used in all modern deep architectures (ResNet, Transformer, etc.) +``` + +### Systematic Approach to Exploding Gradients + +**Step 1: Confirm diagnosis** +```python +# Monitor gradient norms +loss.backward() +total_norm = sum(p.grad.data.norm(2).item() ** 2 + for p in model.parameters() if p.grad is not None) ** 0.5 +print(f"Gradient norm: {total_norm:.4f}") + +# If norm > 100 or training diverges: Exploding gradients +``` + +**Step 2: Apply fixes (priority order)** +1. **Add gradient clipping** (highest priority, most effective) + ```python + clip_grad_norm_(model.parameters(), max_norm=1.0) + ``` + +2. **Check learning rate** (if still unstable after clipping) + ```python + optimizer = Adam(model.parameters(), lr=1e-4) # Try lower + ``` + +3. **Verify initialization** (if problems from start of training) + ```python + model.apply(init_weights) # Use He/Xavier init + ``` + +4. **Check for numerical issues** (if NaN appears) + ```python + # Add epsilon to divisions, logs, etc. + ``` + +**Step 3: Verify improvement** +```python +# Monitor gradient norms during training +# Should stay in reasonable range (0.1 to 10) +# No sudden spikes to >100 +# No NaN or Inf +``` + +### When Clipping Doesn't Fix NaN + +**If you've added gradient clipping but still get NaN loss:** + +The problem may be in your loss function, not gradients. Diagnose systematically: + +```python +# Step 1: Check if loss is NaN BEFORE backward() +optimizer.zero_grad() +output = model(batch) +loss = custom_loss(output, target) + +# Check loss BEFORE backward +if torch.isnan(loss): + print("❌ Loss is NaN BEFORE backward - problem is in loss function!") + print(f" Output range: {output.min():.4f} to {output.max():.4f}") + print(f" Target range: {target.min():.4f} to {target.max():.4f}") + # Don't proceed with backward - fix loss function first +else: + print("✅ Loss is valid before backward") + loss.backward() + + # Check gradients after backward + for name, param in model.named_parameters(): + if param.grad is not None and torch.isnan(param.grad).any(): + print(f"❌ NaN gradient in {name} - gradient issue") +``` + +**Common loss function numerical issues:** + +```python +# ❌ UNSTABLE: Log of zero or negative +def bad_loss(pred, target): + return -torch.log(pred).mean() # NaN if pred <= 0! + +# ✅ STABLE: Add epsilon +def good_loss(pred, target): + eps = 1e-8 + return -torch.log(pred + eps).mean() + + +# ❌ UNSTABLE: Division by zero or very small number +def bad_loss2(pred, target): + return (target / pred).mean() # Explodes if pred ≈ 0 + +# ✅ STABLE: Add epsilon +def good_loss2(pred, target): + eps = 1e-8 + return (target / (pred + eps)).mean() + + +# ❌ UNSTABLE: Sqrt of negative (can happen with numerical errors) +def bad_loss3(pred, target): + diff = pred - target + return torch.sqrt(diff ** 2).mean() # Can get negative from rounding + +# ✅ STABLE: Use abs or clamp +def good_loss3(pred, target): + diff = pred - target + return torch.sqrt(torch.clamp(diff ** 2, min=0)).mean() + + +# ❌ UNSTABLE: Exp of large values +def bad_loss4(logits): + return torch.exp(logits).sum() # Explodes if logits > 100 + +# ✅ STABLE: Use built-in stable functions +def good_loss4(logits, targets): + return F.cross_entropy(logits, targets) # Handles log-sum-exp internally +``` + +**Diagnostic order when NaN appears:** + +1. **Check loss before backward()**: `if torch.isnan(loss): ...` + - If NaN here → fix loss function (add epsilon, clamp, use stable functions) + - If not NaN → gradient issue + +2. **Check gradients after backward()**: + - If gradients are NaN → clipping placement correct? Unscaling (AMP)? + - If gradients OK → parameters NaN from previous update? + +3. **Check parameters**: + ```python + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"❌ NaN in parameter {name} - previous update caused NaN") + ``` + +**Summary decision tree:** + +``` +Loss becomes NaN +│ +├─ Check: Is loss NaN before backward()? +│ │ +│ ├─ YES → Problem in loss function +│ │ • Add epsilon to divisions +│ │ • Add epsilon to logs +│ │ • Clamp inputs to sqrt +│ │ • Use stable built-in functions +│ │ +│ └─ NO → Problem in backward/gradients +│ • Check gradient clipping is correctly placed +│ • Check unscaling if using AMP +│ • Check for numerical instability in model +│ • Verify proper initialization +``` + + +## Mixed Precision Training Integration + +### Gradient Clipping with AMP + +**The critical interaction:** + +```python +from torch.cuda.amp import autocast, GradScaler + +scaler = GradScaler() +model = TransformerModel().cuda() +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +for batch in train_loader: + optimizer.zero_grad() + + # Forward pass with autocast (mixed precision) + with autocast(): + output = model(batch['input']) + loss = criterion(output, batch['target']) + + # Backward pass (gradients are SCALED) + scaler.scale(loss).backward() + + # CRITICAL: Unscale before clipping! + scaler.unscale_(optimizer) + + # Now clip (on unscaled gradients) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Optimizer step (scaler handles it) + scaler.step(optimizer) + scaler.update() +``` + +**Why unscale before clipping?** + +```python +# Understanding the problem: + +# GradScaler multiplies gradients by large factor (e.g., 2^16 = 65536) +# This prevents underflow in fp16 gradients +# But clipping should happen on TRUE gradient values, not scaled values + +# WITHOUT unscaling: +scaler.scale(loss).backward() # Gradients are now 65536x larger +clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Clips at 1.0 +# But gradients are scaled! Effective clip threshold is 65536, not 1.0 +# Clipping does nothing - gradients are rarely >65536 + +# WITH unscaling: +scaler.scale(loss).backward() # Gradients are 65536x larger +scaler.unscale_(optimizer) # Gradients back to true values +clip_grad_norm_(model.parameters(), max_norm=1.0) # ✅ Clips at true 1.0 +# Clipping works correctly on true gradient magnitudes +``` + +**The flow:** + +``` +1. Forward pass with autocast() → activations in fp16 +2. Compute loss (in fp16 or fp32 depending on operation) +3. scaler.scale(loss).backward() → multiply gradients by scale factor +4. scaler.unscale_(optimizer) → divide gradients by scale factor (back to true values) +5. clip_grad_norm_() → clip true gradient values +6. scaler.step(optimizer) → check for inf/NaN, update parameters if safe +7. scaler.update() → adjust scale factor for next iteration +``` + +**Complete AMP + Clipping + Accumulation:** + +```python +from torch.cuda.amp import autocast, GradScaler + +scaler = GradScaler() +accumulation_steps = 4 + +optimizer.zero_grad() + +for i, batch in enumerate(train_loader): + # Forward pass with autocast + with autocast(): + output = model(batch['input']) + loss = criterion(output, batch['target']) + + # Scale loss for accumulation + scaled_loss = loss / accumulation_steps + + # Backward pass (scaled) + scaler.scale(scaled_loss).backward() + + # Update every accumulation_steps + if (i + 1) % accumulation_steps == 0: + # Unscale before clipping + scaler.unscale_(optimizer) + + # Clip gradients + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Optimizer step with scaler + scaler.step(optimizer) + scaler.update() + + optimizer.zero_grad() +``` + +### Common AMP + Gradient Mistakes + +**WRONG: Not unscaling before clipping** + +```python +# ❌ WRONG - Clipping scaled gradients +scaler.scale(loss).backward() +clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ On scaled gradients! +scaler.step(optimizer) +scaler.update() + +# Result: Clipping doesn't work, training may diverge +``` + +**WRONG: Unscaling multiple times** + +```python +# ❌ WRONG - Unscaling twice +scaler.scale(loss).backward() +scaler.unscale_(optimizer) # Unscale once +clip_grad_norm_(model.parameters(), max_norm=1.0) +scaler.unscale_(optimizer) # ❌ Unscale again! Gradients now too small +scaler.step(optimizer) + +# Result: Gradients become too small, slow training +``` + +**WRONG: Calling step() directly instead of scaler.step()** + +```python +# ❌ WRONG - Bypassing scaler +scaler.scale(loss).backward() +scaler.unscale_(optimizer) +clip_grad_norm_(model.parameters(), max_norm=1.0) +optimizer.step() # ❌ Should use scaler.step()! +scaler.update() + +# Result: Scaler can't skip updates when inf/NaN detected +# Training may diverge from inf/NaN gradients +``` + + +## Advanced Topics + +### Per-Layer Gradient Clipping + +**When global clipping isn't enough:** + +```python +def clip_grad_norm_per_layer(model, max_norm): + """ + Clip each layer's gradients independently + + Use when some layers have much larger gradients than others + and global clipping is too aggressive or not aggressive enough + """ + for name, module in model.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)): + # Get parameters for this layer + params = [p for p in module.parameters() if p.grad is not None] + + if params: + # Clip this layer's gradients + layer_norm = torch.nn.utils.clip_grad_norm_(params, max_norm) + + # Log if clipping was applied + if layer_norm > max_norm: + print(f"Clipped {name}: {layer_norm:.4f} -> {max_norm}") + +# Usage: +loss.backward() +clip_grad_norm_per_layer(model, max_norm=1.0) +optimizer.step() + +# When to use: +# - Attention layers have much larger gradients than FFN layers +# - Some task heads have huge gradients while backbone is normal +# - Global clipping clips too much for some layers, too little for others + +# Trade-off: +# ✅ More fine-grained control +# ❌ More complex, harder to tune +# ❌ Less common in literature (harder to compare) +``` + +### Gradient Noise and Stability + +**Adding noise to gradients (advanced technique):** + +```python +def add_gradient_noise(model, noise_scale=1e-3): + """ + Add Gaussian noise to gradients + + Can help with: + - Escaping sharp minima (better generalization) + - Privacy (differential privacy) + - Exploration in RL + """ + for param in model.parameters(): + if param.grad is not None: + noise = torch.randn_like(param.grad) * noise_scale + param.grad.add_(noise) + +# Usage: +loss.backward() +add_gradient_noise(model, noise_scale=1e-3) +clip_grad_norm_(model.parameters(), max_norm=1.0) # Clip after adding noise +optimizer.step() + +# When to use: +# - Research setting (exploring new techniques) +# - Differential privacy requirements +# - NOT recommended for standard training (adds complexity) +``` + +### Gradient Checkpointing Interaction + +**Gradient checkpointing compatibility:** + +```python +from torch.utils.checkpoint import checkpoint + +# Gradient checkpointing (from pytorch-engineering pack) +# Trades computation for memory by recomputing activations during backward + +class CheckpointedModel(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(dim=512) for _ in range(24) + ]) + + def forward(self, x): + for block in self.blocks: + # Checkpoint each block + x = checkpoint(block, x, use_reentrant=False) + return x + +# Training with checkpointing + clipping + accumulation: +model = CheckpointedModel() +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) +accumulation_steps = 4 + +optimizer.zero_grad() + +for i, batch in enumerate(train_loader): + output = model(batch) # Uses checkpointing internally + loss = criterion(output, target) + (loss / accumulation_steps).backward() # Recomputes activations + + if (i + 1) % accumulation_steps == 0: + # Clipping works normally (no special handling needed) + clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + +# Compatibility: +# ✅ Gradient clipping: Works normally after backward() +# ✅ Gradient accumulation: No special handling needed +# ✅ Mixed precision: Combine with AMP as usual +# ✅ All gradient management techniques: Fully compatible + +# Performance note: +# Checkpointing increases backward pass time by ~30-50% +# But enables training much larger models or batch sizes +# Trade computation for memory +``` + +### Distributed Training Considerations + +**Gradient clipping in DDP (DistributedDataParallel):** + +```python +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +# Setup DDP +model = TransformerModel().cuda() +model = DDP(model, device_ids=[local_rank]) + +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +for batch in train_loader: + optimizer.zero_grad() + + output = model(batch) + loss = criterion(output, target) + loss.backward() + + # Gradient clipping in DDP + # IMPORTANT: Clip AFTER backward() (gradients are already synchronized) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + optimizer.step() + +# How DDP works: +# 1. Forward pass: Each GPU computes independently +# 2. Backward pass: Gradients computed on each GPU +# 3. Gradient synchronization: DDP averages gradients across GPUs (automatic) +# 4. Clipping: Happens AFTER synchronization (on averaged gradients) +# 5. Optimizer step: Each GPU updates identically (same gradients) + +# Key points: +# ✅ Clip after backward() as usual - DDP handles synchronization automatically +# ✅ All GPUs see same averaged gradients, so clipping is consistent +# ❌ DON'T manually synchronize gradients (DDP does this) +# ❌ DON'T clip before backward() (gradients don't exist yet) +``` + +**Gradient accumulation with DDP (Optimized):** + +**IMPORTANT:** DDP synchronizes gradients on every backward() by default. With accumulation, this is wasteful - we only need to sync ONCE per update. Use `no_sync()` to optimize. + +```python +from contextlib import nullcontext +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +# Setup DDP +model = TransformerModel().cuda() +model = DDP(model, device_ids=[local_rank]) + +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) +accumulation_steps = 4 + +optimizer.zero_grad() + +for i, batch in enumerate(train_loader): + # Disable gradient synchronization for accumulation steps + # Only sync on the last accumulation step + is_accumulation_step = (i + 1) % accumulation_steps != 0 + + # Context manager: no_sync() when accumulating, normal when updating + with model.no_sync() if is_accumulation_step else nullcontext(): + output = model(batch) + loss = criterion(output, target) + (loss / accumulation_steps).backward() + + # Update on last accumulation step (gradients are now synchronized) + if (i + 1) % accumulation_steps == 0: + # Gradients are synchronized across all GPUs + clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() +``` + +**How this works:** + +``` +WITHOUT no_sync() (inefficient): +Step 1: backward() → sync gradients across GPUs (communication!) +Step 2: backward() → sync gradients across GPUs (communication!) +Step 3: backward() → sync gradients across GPUs (communication!) +Step 4: backward() → sync gradients across GPUs (communication!) + optimizer.step() → update parameters +Total: 4 synchronizations per update + +WITH no_sync() (optimized): +Step 1: backward() with no_sync() → no communication +Step 2: backward() with no_sync() → no communication +Step 3: backward() with no_sync() → no communication +Step 4: backward() without no_sync() → sync accumulated gradients (communication!) + optimizer.step() → update parameters +Total: 1 synchronization per update + +Performance improvement: 3x less communication overhead +``` + +**Why no_sync() is necessary:** +- DDP normally synchronizes gradients on every backward() (default behavior) +- With accumulation, we only want to sync ONCE (on last step) +- no_sync() temporarily disables DDP's all-reduce operation +- On last step (without no_sync()), DDP performs normal synchronization +- Result: Accumulated gradients are synchronized once and correctly averaged + +**Complete DDP + Accumulation + Clipping + AMP:** + +```python +from torch.cuda.amp import autocast, GradScaler +from contextlib import nullcontext + +model = DDP(model, device_ids=[local_rank]) +scaler = GradScaler() +accumulation_steps = 4 + +optimizer.zero_grad() + +for i, batch in enumerate(train_loader): + is_accumulation_step = (i + 1) % accumulation_steps != 0 + + # Disable sync on accumulation steps + with model.no_sync() if is_accumulation_step else nullcontext(): + # Mixed precision forward + with autocast(): + output = model(batch) + loss = criterion(output, target) + + # Scale and backward + scaled_loss = loss / accumulation_steps + scaler.scale(scaled_loss).backward() + + # Update after accumulation + if (i + 1) % accumulation_steps == 0: + # Gradients now synchronized across GPUs + scaler.unscale_(optimizer) # Unscale for clipping + clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + +# This combines ALL techniques correctly: +# ✅ DDP distributed training +# ✅ Gradient accumulation (with loss scaling) +# ✅ Mixed precision (with proper unscaling) +# ✅ Gradient clipping (on correct values) +# ✅ Optimized communication (no_sync()) +``` + +**Performance comparison:** + +```python +# Measure with and without no_sync() + +# WITHOUT no_sync(): ~40 seconds per epoch (excessive communication) +# WITH no_sync(): ~12 seconds per epoch (optimized communication) +# Speedup: 3.3x faster with accumulation_steps=4 + +# The more GPUs you have, the more important no_sync() becomes +# 2 GPUs: ~2x speedup +# 4 GPUs: ~3x speedup +# 8 GPUs: ~4x speedup +``` + +**Common mistake:** + +```python +# ❌ WRONG - Synchronizing on every step (slow!) +model = DDP(model) +accumulation_steps = 4 + +for i, batch in enumerate(train_loader): + (loss / accumulation_steps).backward() # Syncs every time! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + +# Result: Correct results but 3-4x slower than necessary +``` + + +## Common Gradient Pitfalls + +### Pitfall 1: Not Clipping When Needed + +**Symptom:** Training becomes NaN after few epochs, loss spikes + +**WRONG:** +```python +# User sees NaN loss and thinks: "Must be learning rate" +optimizer = Adam(model.parameters(), lr=1e-5) # ❌ Lower LR to "fix" it + +# Result: Training is slow and may still diverge +# Root cause (exploding gradients) not addressed +``` + +**RIGHT:** +```python +# Recognize exploding gradients, add clipping +loss.backward() +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +optimizer.step() + +# Result: Training is stable, no NaN +# This is THE solution for exploding gradients +``` + +### Pitfall 2: Wrong Gradient Accumulation Scaling + +**Symptom:** Gradient accumulation gives worse results than small batch + +**WRONG:** +```python +# ❌ Not scaling loss +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + loss.backward() # ❌ Gradients are 4x too large! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() +``` + +**RIGHT:** +```python +# ✅ Scale loss by accumulation_steps +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + loss = criterion(model(batch), target) + (loss / accumulation_steps).backward() # ✅ Correct scaling + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() +``` + +### Pitfall 3: Clipping After optimizer.step() + +**Symptom:** Clipping doesn't help, training still unstable + +**WRONG:** +```python +# ❌ Clipping after step (useless!) +loss.backward() +optimizer.step() +clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Too late! +``` + +**RIGHT:** +```python +# ✅ Clipping after backward, before step +loss.backward() +clip_grad_norm_(model.parameters(), max_norm=1.0) # ✅ Correct timing +optimizer.step() +``` + +### Pitfall 4: Not Unscaling Before Clipping (AMP) + +**Symptom:** Mixed precision training diverges, regular training works + +**WRONG:** +```python +# ❌ Clipping scaled gradients +scaler.scale(loss).backward() +clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Wrong scale! +scaler.step(optimizer) +scaler.update() +``` + +**RIGHT:** +```python +# ✅ Unscale before clipping +scaler.scale(loss).backward() +scaler.unscale_(optimizer) # ✅ Unscale first! +clip_grad_norm_(model.parameters(), max_norm=1.0) +scaler.step(optimizer) +scaler.update() +``` + +### Pitfall 5: Forgetting to zero_grad() After Accumulation + +**Symptom:** Loss decreases then increases, training unstable + +**WRONG:** +```python +# ❌ Missing zero_grad() after update +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + # ❌ Missing optimizer.zero_grad()! +``` + +**RIGHT:** +```python +# ✅ Zero gradients after update +accumulation_steps = 4 +for i, batch in enumerate(train_loader): + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() # ✅ Clear gradients for next accumulation +``` + +### Pitfall 6: Using Value Clipping Instead of Norm Clipping + +**Symptom:** Training works but slower convergence than expected + +**SUBOPTIMAL:** +```python +# Value clipping changes gradient direction +clip_grad_value_(model.parameters(), clip_value=0.5) # Can distort gradients +``` + +**BETTER:** +```python +# Norm clipping preserves direction +clip_grad_norm_(model.parameters(), max_norm=1.0) # Preferred method +``` + +### Pitfall 7: Applying Clipping to All Models + +**Symptom:** Unnecessarily slow training, limiting gradient flow + +**WRONG:** +```python +# ❌ Clipping when not needed (ResNet on ImageNet) +model = ResNet50() +optimizer = SGD(model.parameters(), lr=0.1) + +for batch in train_loader: + loss.backward() + clip_grad_norm_(model.parameters(), max_norm=1.0) # ❌ Not needed! + optimizer.step() + +# Result: Limits gradient flow, may slow convergence +``` + +**RIGHT:** +```python +# ✅ Only clip when needed (training is unstable) +model = ResNet50() +optimizer = SGD(model.parameters(), lr=0.1) + +for batch in train_loader: + loss.backward() + # No clipping - ResNets are naturally stable + optimizer.step() + +# Only add clipping if you observe: +# - Loss becomes NaN +# - Loss spikes +# - Training instability +``` + +### Pitfall 8: Not Monitoring Gradients + +**Symptom:** Training fails, no visibility into why + +**WRONG:** +```python +# ❌ No gradient monitoring +for batch in train_loader: + loss = train_step(batch) + # Training fails, no idea why +``` + +**RIGHT:** +```python +# ✅ Monitor gradient norms +for step, batch in enumerate(train_loader): + optimizer.zero_grad() + loss = criterion(model(batch), target) + loss.backward() + + # Monitor gradients + if step % 100 == 0: + total_norm = 0.0 + for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.data.norm(2).item() ** 2 + total_norm = total_norm ** 0.5 + print(f"Step {step}, Loss: {loss.item():.4f}, Grad norm: {total_norm:.4f}") + + clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + +# Now you can see: +# - When gradients explode (norm suddenly large) +# - When gradients vanish (norm goes to zero) +# - How clipping affects training +``` + +### Pitfall 9: Wrong DDP Gradient Synchronization + +**Symptom:** DDP with accumulation slower than expected or wrong results + +**WRONG:** +```python +# ❌ DDP synchronizes on every backward (wasteful with accumulation) +model = DDP(model) +accumulation_steps = 4 + +for i, batch in enumerate(train_loader): + (loss / accumulation_steps).backward() # ❌ Syncs every time! + + if (i + 1) % accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() +``` + +**RIGHT:** +```python +# ✅ Disable sync except on last accumulation step +model = DDP(model) +accumulation_steps = 4 + +for i, batch in enumerate(train_loader): + is_accumulation_step = (i + 1) % accumulation_steps != 0 + + with model.no_sync() if is_accumulation_step else nullcontext(): + (loss / accumulation_steps).backward() + + if (i + 1) % accumulation_steps == 0: + clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() +``` + +### Pitfall 10: Clipping Too Aggressively + +**Symptom:** Training converges very slowly, gradient norm always at max_norm + +**WRONG:** +```python +# ❌ max_norm too low, clipping every iteration +clip_grad_norm_(model.parameters(), max_norm=0.01) # Way too aggressive! + +# Result: All gradients clipped, learning very slow +``` + +**RIGHT:** +```python +# ✅ Monitor and tune max_norm appropriately +# Check typical gradient norms without clipping +total_norm = compute_grad_norm(model) +print(f"Gradient norm: {total_norm:.4f}") + +# Set max_norm to clip outliers, not normal gradients +# If typical norms are 0.5-2.0, set max_norm=5.0 +clip_grad_norm_(model.parameters(), max_norm=5.0) # Clips outliers only +``` + + +## Rationalization Prevention Table + +| When Agent Wants To Say | STOP - Say This Instead | +|-------------------------|-------------------------| +| "Just lower the learning rate" | "This is likely exploding gradients. Add gradient clipping: `clip_grad_norm_(model.parameters(), max_norm=1.0)` BEFORE optimizer.step(). Then adjust LR if still needed." | +| "Try a smaller model to save memory" | "Use gradient accumulation to train with larger effective batch size: Scale loss by `accumulation_steps` and update every N batches. This is standard practice." | +| "Gradient accumulation is complicated" | "It's actually simple: `(loss / accumulation_steps).backward()` to accumulate, `optimizer.step()` every N batches. MUST scale loss - this is critical." | +| "Mixed precision doesn't work with clipping" | "AMP + clipping work together perfectly. You MUST unscale before clipping: `scaler.unscale_(optimizer)` then `clip_grad_norm_()`. This is documented and standard." | +| "Your gradients are too small, just increase LR" | "This is vanishing gradients. Architectural fixes are needed: Use ReLU/GELU activations, proper initialization (He/Xavier), BatchNorm, and residual connections. Increasing LR alone won't fix it." | +| "Clipping is a hack, don't use it" | "Clipping is standard practice in Transformers, RNNs, and RL. Every major paper (BERT, GPT, etc.) uses gradient clipping. It's essential for training stability, not a hack." | +| "The paper didn't use clipping, so you shouldn't" | "Papers don't always document all techniques. Clipping may have been used but not mentioned. If you observe instability (NaN, spikes), add clipping regardless of what paper says." | +| "Try different optimizer, maybe SGD works better" | "Switching optimizer doesn't fix exploding gradients. Add gradient clipping first, then compare optimizers. Clipping works with any optimizer." | +| "Gradient issues are mysterious and hard to debug" | "Gradient issues are systematic: Check gradient norms. >100 = exploding (clip). <1e-6 = vanishing (fix activations/init). NaN = numerical instability (check loss/model). Clear diagnosis → clear solution." | +| "You can clip anytime in the training loop" | "Clipping MUST happen after backward(), before step(). Timing is critical: backward() creates gradients, clip() modifies them, step() consumes them. Wrong order = useless clipping." | +| "Scale gradients instead of loss for accumulation" | "Scale LOSS, not gradients: `(loss / accumulation_steps).backward()`. Scaling gradients manually is error-prone and inefficient. Loss scaling is the standard, correct way." | +| "Batch norm is optional for deep networks" | "BatchNorm is critical for very deep networks (>20 layers). It normalizes activations and stabilizes gradients. Essential for training stability. Use unless you have specific reason not to." | +| "Residual connections are just a fancy trick" | "Residual connections are fundamental to training deep networks (>50 layers). They provide direct gradient path, preventing vanishing gradients. ResNet, Transformer - all use residuals." | +| "Just clip more aggressively (max_norm=0.01)" | "Too-aggressive clipping limits all gradients, slowing learning. Monitor typical gradient norms. Set max_norm to clip outliers (>100) without affecting normal gradients (1-10)." | +| "DDP handles everything automatically" | "DDP synchronizes gradients on backward(). For accumulation, use `model.no_sync()` on intermediate steps to avoid unnecessary synchronization. Only sync on final accumulation step." | +| "Your model is too complex, that's why training fails" | "Model complexity alone doesn't cause training failure. Gradient issues do. Diagnose gradients first. Most complex models (GPT-3, etc.) train successfully with proper gradient management." | +| "Checkpointing and clipping don't work together" | "They're fully compatible. Checkpoint affects forward/backward computation. Clipping affects gradients after backward(). No interaction - use both together freely." | +| "You need expensive GPUs for large batches" | "Use gradient accumulation for larger effective batches on any GPU. Accumulate over N steps = N× batch size, same memory. Standard technique for training large models on consumer hardware." | +| "Loss → NaN means your data has NaN" | "Usually exploding gradients, not data. Check gradient norms. If >100, add clipping. Data NaN would cause issues immediately, not after several epochs." | +| "Transformers just don't train stably" | "Transformers train extremely well with proper gradient management. BERT, GPT, T5 - all trained successfully. Use gradient clipping (max_norm=1.0), proper LR, and you'll have stable training." | +| "Clipping is expensive, will slow training significantly" | "Clipping overhead is <1%, not 5-10%. It's computing gradient norms (one pass) then scaling. Much cheaper than backward pass. 1% cost to prevent catastrophic training failure is excellent trade-off." | +| "I added clipping but still get NaN, it doesn't work" | "Check if loss is NaN BEFORE backward(). If yes, problem is in loss function (add epsilon to divisions/logs), not gradients. If no, check clipping placement and AMP unscaling." | +| "Accumulation scaling depends on batch content" | "NO. Accumulation scaling is ALWAYS `accumulation_steps` (constant). Sample weighting is separate concern (handled in loss function). Don't confuse these two independent concepts." | +| "Paper doesn't mention clipping, so I shouldn't use it" | "Papers don't document all implementation details. Clipping may have been used but not mentioned. Check official code if available. If your training is unstable, add clipping - stable training is prerequisite for valid comparison." | +| "Different sources give conflicting advice on clipping" | "Context matters. Transformers/RNNs usually need clipping. CNNs usually don't. Decide based on YOUR architecture and stability. Monitor gradient norms. If you see spikes >100 or NaN, add clipping. Empiricism over dogma." | +| "Use PyTorch Lightning so I need to manually add clipping" | "Lightning has built-in clipping: `Trainer(gradient_clip_val=1.0, gradient_clip_algorithm='norm')`. No manual code needed. Check your framework docs - most have built-in gradient management features." | +| "My model is complex so gradients will always be problematic" | "Model complexity doesn't determine gradient behavior. GPT-3 (175B parameters) trains successfully. Proper gradient management (clipping, architecture, initialization) enables training ANY size model. Complexity is not the issue." | +| "Small gradients mean training is broken" | "Only if training is stuck. If loss is decreasing and metrics improving, small absolute gradient values are OK. What matters: relative gradients across layers and whether learning is happening. Don't fix what isn't broken." | + + +## Red Flags Checklist + +When you see these patterns, IMMEDIATELY suggest gradient management solutions: + +### Critical Red Flags (Fix Immediately) + +🚨 **Loss becomes NaN during training** +- Cause: Exploding gradients or numerical instability +- Solution: Add gradient clipping + check for numerical issues in loss/model + +🚨 **User implements gradient accumulation without scaling loss** +```python +# ❌ RED FLAG +loss.backward() # Should be: (loss / accumulation_steps).backward() +``` +- Impact: Gradients are accumulation_steps times too large +- Solution: Scale loss by accumulation_steps + +🚨 **User clips gradients after optimizer.step()** +```python +# ❌ RED FLAG +optimizer.step() +clip_grad_norm_(...) # Too late! +``` +- Impact: Clipping does nothing (gradients already consumed) +- Solution: Move clipping between backward() and step() + +🚨 **User uses AMP + clipping without unscaling** +```python +# ❌ RED FLAG +scaler.scale(loss).backward() +clip_grad_norm_(...) # Should unscale first! +``` +- Impact: Clipping wrong magnitude (on scaled gradients) +- Solution: Add scaler.unscale_(optimizer) before clipping + +### Warning Signs (Suggest Improvements) + +⚠️ **Training transformers/RNNs without gradient clipping** +- Likely to hit exploding gradients eventually +- Suggest preemptive clipping (max_norm=1.0) + +⚠️ **Very deep network (>20 layers) with sigmoid/tanh activations** +- Vanishing gradients likely +- Suggest ReLU/GELU + BatchNorm + residual connections + +⚠️ **User says "want larger batch but OOM"** +- Perfect use case for gradient accumulation +- Explain technique and correct implementation + +⚠️ **Gradient norms consistently >10 or <1e-6** +- Exploding or vanishing gradients +- Diagnose and suggest appropriate solution + +⚠️ **User lowers learning rate to fix NaN loss** +- Treating symptom, not cause +- Likely exploding gradients - suggest clipping + +⚠️ **DDP training with gradient accumulation, no no_sync()** +- Inefficient (synchronizing unnecessarily) +- Suggest no_sync() on accumulation steps + +⚠️ **User asks "is gradient clipping necessary?"** +- Depends on architecture and stability +- Provide decision criteria (Transformers: yes, CNNs: maybe not) + +⚠️ **Custom loss function with divisions or logs** +- Potential numerical instability +- Check for epsilon additions and proper handling + +### Optimization Opportunities (Mention If Relevant) + +💡 **User monitors loss but not gradients** +- Suggest logging gradient norms for better visibility + +💡 **User training large model on single GPU with small batch** +- Suggest gradient accumulation for better results + +💡 **Gradient clipping activates every iteration** +- max_norm might be too low +- Suggest monitoring and tuning threshold + +💡 **Using value clipping instead of norm clipping** +- Suggest norm clipping (preserves direction) + + +## Summary + +**Gradient management is essential for reliable training:** + +1. **Gradient Clipping** + - PRIMARY solution for exploding gradients (NaN, spikes) + - Use norm-based clipping: `clip_grad_norm_(model.parameters(), max_norm=1.0)` + - Place after backward(), before step() + - Standard for Transformers, RNNs, RL + +2. **Gradient Accumulation** + - Train with larger effective batch size on same hardware + - MUST scale loss: `(loss / accumulation_steps).backward()` + - Update every N steps, zero_grad() after update + - Standard technique in production training + +3. **Gradient Diagnosis** + - Don't guess - measure gradient norms + - >100: Exploding (clip) + - <1e-6: Vanishing (fix architecture) + - NaN: Numerical issues (check loss/model) + +4. **Vanishing Gradients** + - Use ReLU/GELU activations (not sigmoid/tanh) + - Proper initialization (He for ReLU, Xavier for tanh) + - Add BatchNorm/LayerNorm + - Add residual connections for deep networks + +5. **Exploding Gradients** + - Add gradient clipping (primary solution) + - Check learning rate (secondary) + - Verify initialization + - Check for numerical issues + +6. **Mixed Precision Integration** + - MUST unscale before clipping: `scaler.unscale_(optimizer)` + - Then clip on true gradient values + - Standard pattern in AMP training + +7. **Common Pitfalls** + - Not scaling loss in accumulation (gradients too large) + - Clipping after step() (useless) + - Not unscaling before clipping in AMP + - Forgetting zero_grad() after accumulation + - Not monitoring gradients (no visibility) + +**This is NOT optional:** +- Gradient management determines training success or failure +- Every production training system handles gradients properly +- The difference between reliable training and mysterious failures + +**Master these techniques and you'll have stable, efficient training.** diff --git a/skills/using-training-optimization/hyperparameter-tuning.md b/skills/using-training-optimization/hyperparameter-tuning.md new file mode 100644 index 0000000..3b27c21 --- /dev/null +++ b/skills/using-training-optimization/hyperparameter-tuning.md @@ -0,0 +1,1635 @@ + +# Hyperparameter Tuning Skill + +## When to Use This Skill + +Use this skill when: +- User wants to improve model accuracy but not sure what to tune +- Training plateaus or performance is suboptimal (70% → 75%?) +- User asks "should I tune hyperparameters?" or "what should I tune first?" +- User wants to implement hyperparameter search (grid search, random search, Bayesian optimization) +- Deciding between Optuna, Ray Tune, W&B Sweeps, or manual tuning +- User asks "how many hyperparameters should I try?" or "how long will search take?" +- Model is underfitting (high train and val loss) vs overfitting (high train loss, low val loss) +- User is copying a paper's hyperparameters but results don't match +- Budget allocation question: "Should I train longer or try more configs?" +- User wants to understand learning rate importance relative to other hyperparameters + +Do NOT use when: +- User has specific bugs unrelated to hyperparameters (training crashes, NaN losses) +- Only discussing optimizer choice without tuning questions +- Model is already converging well with current hyperparameters +- User is asking about data preprocessing or feature engineering +- Hyperparameter search is already set up and running (just report results) + + +## Core Principles + +### 1. Hyperparameter Importance Hierarchy (NOT All Equal) + +The BIGGEST mistake users make: treating all hyperparameters as equally important. + +**Importance Ranking** (for typical supervised learning): + +``` +Tier 1 - Critical (10x impact): +├─ Learning Rate (most important) +└─ Learning Rate Schedule + +Tier 2 - High Impact (5x impact): +├─ Batch Size (affects LR, gradient noise, memory) +└─ Optimizer Type (Adam vs SGD affects LR ranges) + +Tier 3 - Medium Impact (2x impact): +├─ Weight Decay (L2 regularization) +├─ Optimizer Parameters (momentum, beta) +└─ Warmup (critical for transformers) + +Tier 4 - Low-Medium Impact (1.5x impact): +├─ Model Width/Depth (architectural) +├─ Dropout Rate (regularization) +└─ Gradient Clipping (stability) + +Tier 5 - Low Impact (<1.2x): +├─ Activation Functions (ReLU vs GELU) +├─ LayerNorm Epsilon +└─ Adam Epsilon +``` + +**What This Means**: +- Learning rate alone can change accuracy from 50% → 80% +- Model width change from 128 → 256 typically gives 2-3% improvement +- Dropout 0.1 → 0.5 might give 2-4% improvement (if overfitting) +- Optimizer epsilon has almost no impact + +**Quantitative Example** (CIFAR-10, ResNet18): +``` +Effect on accuracy of individual changes: +LR from 0.001 → 0.01: 70% → 84% (+14%) ← HUGE +Batch size from 32 → 128: 84% → 82% (-2%) ← small impact +Width from 64 → 128: 84% → 86% (+2%) ← small impact +Dropout 0.0 → 0.3: 86% → 85% (-1%) ← tiny impact + +Total tuning time SHOULD be allocated: +- 40% to learning rate (most important) +- 30% to learning rate schedule +- 15% to batch size and optimizer choice +- 10% to regularization (dropout, weight decay) +- 5% to everything else +``` + +**Decision Rule**: Tune in order of importance. Only move to next tier if current tier is optimized. + + +### 2. When to Tune vs When to Leave Defaults + +**Don't Tune When**: +- ✗ Model converges well (val loss decreasing, no plateau) +- ✗ Time budget is <1 hour (manual tuning likely faster) +- ✗ Model underfits (both train and val loss are high) - add capacity instead +- ✗ Data is tiny (<1000 examples) - data collection beats tuning +- ✗ Using pre-trained models for fine-tuning - defaults often work + +**DO Tune When**: +- ✓ Training plateaus early (loss stops improving by epoch 30) +- ✓ Train/val gap is large (overfitting, need better hyperparameters) +- ✓ Time budget is >1 hour and compute available +- ✓ Model has capacity but not using it (convergence too slow) +- ✓ Targeting SOTA or competition results (last 2-5% squeeze) + +**Diagnostic Tree**: +``` +Is performance acceptable? +├─ YES → Don't tune. Tuning won't help much. +└─ NO → Check the problem: + ├─ High train loss, high val loss? → UNDERFITTING + │ └─ Solution: Increase model capacity, train longer + │ (Not a tuning problem) + │ + ├─ Low train loss, high val loss? → OVERFITTING + │ └─ Solution: Tune weight decay, dropout, LR schedule + │ + ├─ Training converging too slowly? → BAD LR + │ └─ Solution: Tune learning rate (critical!) + │ + └─ Training unstable (losses spike)? → LR too high or batch too small + └─ Solution: Lower LR, increase batch size, add gradient clipping +``` + + +### 3. Learning Rate is THE Hyperparameter to Tune First + +Learning rate matters more than ANYTHING else. Here's why: + +**Impact on Training**: +- LR too small: Glacial convergence, never reaches good minima (underfitting effect) +- LR too large: Oscillation or divergence, never converges (instability) +- LR just right: Fast convergence to good minima (optimal learning) + +**Typical LR Impact**: +``` +LR = 0.0001: Loss = 0.5, Acc = 60% (too small, underfitting) +LR = 0.001: Loss = 0.3, Acc = 75% (getting better) +LR = 0.01: Loss = 0.2, Acc = 85% (optimal) +LR = 0.1: Loss = 0.4, Acc = 70% (too large, oscillating) +LR = 1.0: Loss = NaN, Acc = 0% (diverging) +``` + +**When to Tune LR First**: +- Always. Before ANYTHING else. +- Even if you don't tune anything else, tune learning rate. +- Proper LR gives 5-10% improvement alone. +- Everything else: 2-5% improvement. + +**Default LR Ranges by Optimizer**: +``` +SGD with momentum: 0.01 - 0.1 (start at 0.01) +Adam: 0.0001 - 0.001 (start at 0.001) +AdamW: 0.0001 - 0.001 (start at 0.0005) +RMSprop: 0.0001 - 0.01 (start at 0.0005) + +For transformers: usually 0.00005 - 0.0005 (MUCH smaller) +For fine-tuning: usually 0.0001 - 0.001 (smaller than training) +``` + +**Pro Tip**: Use learning rate finder (LRFinder, lr_find in fastai) to get good starting range in 1 epoch. + + +## Decision Framework: Which Search Strategy to Use + +### Criterion 1: Number of Hyperparameters to Tune + +``` +1-2 parameters → Grid search is fine + Example: Tuning just learning rate and weight decay + Effort: 5-25 configurations + Best tool: Manual or simple loop + +3-4 parameters → Random search + Example: LR, batch size, weight decay, warmup + Effort: 50-200 configurations + Best tool: Optuna or Ray Tune + +5+ parameters → Bayesian optimization (Optuna) + Example: LR, batch size, weight decay, warmup, dropout, LR schedule type + Effort: 100-500 configurations + Best tool: Optuna (required) or Ray Tune + +When you don't know → Always use Random Search as default +``` + +### Criterion 2: Time Budget Available + +``` +Budget = (GPU time available) / (Training time per epoch) + +< 10 hours budget: + - Tune ONLY learning rate (1-2 hours search) + - Use learning rate finder + manual exploration + - 5-10 LR values, 1 seed each + +10-100 hours budget: + - Random search over 3-4 hyperparameters + - 50-100 configurations + - Use Optuna or Ray Tune + - 1 seed per config (save repeats for later) + +100-1000 hours budget: + - Bayesian optimization (Optuna) over 4-5 parameters + - 200-300 configurations + - Use ensembling: multiple runs of top 5 configs + - 2-3 seeds for final configs + +1000+ hours budget: + - Full Bayesian optimization with early stopping + - 500+ configurations + - Can afford to try many promising directions + - 3+ seeds for final configs, ensemble for SOTA +``` + +### Criterion 3: Search Strategy Decision Matrix + +``` + | Few Params | Many Params | Unknown Params + | (1-3) | (4-6) | (Uncertain) +──────────────┼────────────┼─────────────┼────────────── +Short time | Manual | Random | Random Search +(<10 hrs) | Grid | Search | (narrow scope) + | | | +Medium time | Grid or | Random | Bayesian +(10-100 hrs) | Random | Search | (Optuna) + | | (Optuna) | +──────────────┼────────────┼─────────────┼────────────── +Long time | Grid or | Bayesian | Bayesian +(100+ hrs) | Random | (Optuna) | (Optuna) +``` + + +## Search Strategy Details + +### Strategy 1: Grid Search (When to Use, When NOT to Use) + +**Grid Search**: Try all combinations of predefined values. + +**PROS**: +- Simple to understand and implement +- Guarantees checking all points in search space +- Results easily interpretable (best point is in grid) +- Good for visualization and analysis + +**CONS**: +- Exponential complexity: O(k^n) where k=values, n=dimensions +- 5 params × 5 values each = 3,125 configurations (130 days compute!) +- Poor for high-dimensional spaces (5+ parameters) +- Wastes compute on unimportant dimensions + +**When to Use**: +- ✓ 1-2 hyperparameters only +- ✓ <50 total configurations +- ✓ Quick experiments (1-10 hour budget) +- ✓ Final refinement near known good point + +**When NOT to Use**: +- ✗ 4+ hyperparameters +- ✗ High-dimensional spaces +- ✗ Unknown optimal ranges +- ✗ Limited compute budget + +**Example: Grid Search (Good Use)**: +```python +# GOOD: Only 2 parameters, 3×4=12 configurations +import itertools + +learning_rates = [0.001, 0.01, 0.1] +weight_decays = [0.0, 0.0001, 0.001, 0.01] + +best_acc = 0 +for lr, wd in itertools.product(learning_rates, weight_decays): + model = create_model() + acc = train_and_evaluate(model, lr=lr, weight_decay=wd) + if acc > best_acc: + best_acc = acc + best_config = {"lr": lr, "wd": wd} + +print(f"Best accuracy: {best_acc}") +print(f"Best config: {best_config}") + +# 12 configurations × 30 min each = 6 hours total +# Very reasonable! +``` + +**Anti-Example: Grid Search (Bad Use)**: +```python +# WRONG: 5 parameters, 5^5=3,125 configurations +# This is 130 days of compute - completely impractical + +learning_rates = [0.0001, 0.001, 0.01, 0.1, 1.0] +batch_sizes = [16, 32, 64, 128, 256] +weight_decays = [0.0, 0.0001, 0.001, 0.01, 0.1] +dropouts = [0.0, 0.2, 0.4, 0.6, 0.8] +warmup_steps = [0, 100, 500, 1000, 5000] + +# DO NOT DO THIS - grid explosion is real +``` + + +### Strategy 2: Random Search (Default Choice for Most Cases) + +**Random Search**: Sample hyperparameters randomly from search space. + +**PROS**: +- Much better than grid in 4+ dimensions (Bergstra & Bengio 2012) +- 100-200 random samples often better than 100 grid points +- Easy to implement and parallelize +- Can sample continuous spaces naturally +- Efficient use of limited compute budget + +**CONS**: +- Not systematic (might miss obvious points) +- Requires defining search space ranges (hard part) +- No exploitation of promising regions (unlike Bayesian) +- Results less deterministic than grid + +**When to Use**: +- ✓ 3-5 hyperparameters +- ✓ 50-300 configurations available +- ✓ Unknown optimal ranges +- ✓ Want simple, efficient method +- ✓ Default choice when unsure + +**When NOT to Use**: +- ✗ 1-2 hyperparameters (grid is simpler) +- ✗ Very large budgets (1000+ hrs, use Bayesian) +- ✗ Need guaranteed convergence to local optimum + +**Example: Random Search (Recommended)**: +```python +# GOOD: 4 parameters, random sampling, efficient +import numpy as np +from scipy.stats import loguniform, uniform + +# Define search space with proper scales +learning_rate_dist = loguniform(a=0.00001, b=0.1) # Log scale! +batch_size_dist = [16, 32, 64, 128, 256] +weight_decay_dist = loguniform(a=0.0, b=0.1) # Log scale! +dropout_dist = uniform(loc=0.0, scale=0.8) + +best_acc = 0 +for trial in range(100): # 100 configurations, not 3,125 + lr = learning_rate_dist.rvs() + batch_size = np.random.choice(batch_size_dist) + wd = weight_decay_dist.rvs() + dropout = dropout_dist.rvs() + + model = create_model(dropout=dropout) + acc = train_and_evaluate( + model, + lr=lr, + batch_size=batch_size, + weight_decay=wd + ) + + if acc > best_acc: + best_acc = acc + best_config = { + "lr": lr, + "batch_size": batch_size, + "weight_decay": wd, + "dropout": dropout + } + +print(f"Best accuracy: {best_acc}") +print(f"Best config: {best_config}") +# 100 configurations × 30 min each = 50 hours total +# 100 trials >> 5^4=625 grid points, but MUCH better scaling +``` + + +### Strategy 3: Bayesian Optimization (Best for Limited Budget) + +**Bayesian Optimization**: Build probabilistic model of function, use to guide search. + +**How It Works**: +1. Start with 5-10 random trials (exploratory phase) +2. Build surrogate model (Gaussian Process) of performance vs hyperparameters +3. Use acquisition function to select next promising region to sample +4. Train model, update surrogate, repeat +5. Balance exploration (new regions) vs exploitation (known good regions) + +**PROS**: +- Uses all prior information to guide next trial selection +- 2-10x more efficient than random search +- Handles many parameters well (5-10+) +- Built-in uncertainty estimates + +**CONS**: +- More complex to implement and understand +- Surrogate model overhead (negligible vs training time) +- Requires tool like Optuna or Ray Tune +- Less interpretable than grid/random (can't show "grid") + +**When to Use**: +- ✓ 5+ hyperparameters +- ✓ 200+ configurations budget +- ✓ Each trial is expensive (>1 hour) +- ✓ Want best results with limited budget +- ✓ Will use Optuna, Ray Tune, or W&B Sweeps + +**When NOT to Use**: +- ✗ <20 configurations (overhead not worth it) +- ✗ Very cheap trials where random is simpler +- ✗ Need to explain exactly what was tested (use grid) + +**Example: Bayesian with Optuna (Industry Standard)**: +```python +# GOOD: Professional hyperparameter search with Optuna +import optuna +from optuna.pruners import MedianPruner + +def objective(trial): + # Suggest hyperparameters from search space + learning_rate = trial.suggest_float( + "learning_rate", + 1e-5, 1e-1, + log=True # Log scale (CRITICAL!) + ) + batch_size = trial.suggest_categorical( + "batch_size", + [16, 32, 64, 128, 256] + ) + weight_decay = trial.suggest_float( + "weight_decay", + 1e-5, 1e-1, + log=True # Log scale! + ) + dropout = trial.suggest_float( + "dropout", + 0.0, 0.8 # Linear scale + ) + + # Create and train model + model = create_model(dropout=dropout) + + best_val_acc = 0 + for epoch in range(100): + train(model, lr=learning_rate, batch_size=batch_size, + weight_decay=weight_decay) + val_acc = validate(model) + + # CRITICAL: Early stopping in search (prune bad trials) + trial.report(val_acc, epoch) + if trial.should_prune(): # Stops bad trials early! + raise optuna.TrialPruned() + + if val_acc > best_val_acc: + best_val_acc = val_acc + + return best_val_acc + +# Create study with pruning (saves 70% compute) +study = optuna.create_study( + direction="maximize", + pruner=MedianPruner() +) + +# Run search: 200 trials with Bayesian guidance +study.optimize(objective, n_trials=200, n_jobs=4) + +print(f"Best accuracy: {study.best_value}") +print(f"Best config: {study.best_params}") +# Early stopping + Bayesian optimization saves massive compute +# 200 trials × 30 epochs on average = vs 200 × 100 without pruning +``` + + +## Search Space Design (Critical Details Often Missed) + +### 1. Scale Selection for Continuous Parameters + +**Learning Rate and Weight Decay: USE LOG SCALE** + +```python +# WRONG: Linear scale for learning rate +learning_rates_linear = [0.0001, 0.002, 0.004, 0.006, 0.008, 0.01] +# Ranges: 0.0001→0.002 is 20x, but only uses 1/5 of range +# Ranges: 0.008→0.01 is 1.25x, but uses 1/5 of range +# BROKEN: Unequal coverage of important ranges + +# CORRECT: Log scale for learning rate +import numpy as np +learning_rates_log = np.logspace(-4, -2, 6) # 10^-4 to 10^-2 +# [0.0001, 0.000215, 0.000464, 0.001, 0.00215, 0.00464, 0.01] +# Each step is ~2.15x (equal importance) +# GOOD: Even coverage across exponential range +``` + +**Why Log Scale for LR**: +- Effect on loss is exponential, not linear +- 10x change in LR has similar impact anywhere in range +- Linear scale bunches tiny values together, wastes space on large values +- Log scale: 0.0001 to 0.01 gets fair representation + +**Parameters That Need Log Scale**: +- Learning rate (most critical) +- Weight decay +- Learning rate schedule decay (gamma in step decay) +- Regularization strength +- Any parameter spanning >1 order of magnitude + +**Dropout, Warmup, Others: USE LINEAR SCALE** + +```python +# CORRECT: Linear scale for dropout (0.0 to 0.8) +dropout_values = np.linspace(0.0, 0.8, 5) +# [0.0, 0.2, 0.4, 0.6, 0.8] +# GOOD: Each increase is meaningful + +# CORRECT: Linear scale for warmup steps +warmup_steps = [0, 250, 500, 750, 1000] +# Linear relationships make sense here +``` + +### 2. Search Space Ranges (Common Mistakes) + +**Learning Rate Range Often Too Small**: +```python +# WRONG: Too narrow range +lr_range = [0.001, 0.0015, 0.002, 0.0025, 0.003] +# Optimal might be 0.01 or 0.0001, both outside range! + +# CORRECT: Wider range covering multiple orders of magnitude +lr_range = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1] # Or use loguniform(1e-5, 1e-1) +``` + +**Batch Size Range Considerations**: +```python +# Batch size affects memory AND gradient noise +# Small batch (16-32): Noisy gradients, good regularization, needs low LR +# Large batch (256+): Stable gradients, less regularization, can use high LR + +# CORRECT: Include range of batch sizes +batch_sizes = [16, 32, 64, 128, 256] + +# INTERACTION: Large batch + same LR usually worse than small batch +# This is WHY you need to search both together (not separately) +``` + +**Weight Decay Range**: +```python +# Log scale, typically 0 to 0.1 +# For well-regularized models: 1e-5 to 1e-1 +# For barely regularized: 0.0 to 1e-3 + +# CORRECT: Use log scale +weight_decays = [0.0, 1e-5, 1e-4, 1e-3, 1e-2, 0.1] +``` + + +## Budget Allocation: Seeds vs Configurations + +**Key Decision**: Should you train many configurations once or few configurations multiple times? + +**Answer**: MANY CONFIGURATIONS, SINGLE SEED + +**Why**: +``` +Budget = 100 hours + +Option A: Many configurations, 1 seed each +├─ 100 configurations × 1 seed = 100 trials +├─ Find best at 85% accuracy +└─ Top 5 can be rerun with 5 seeds for ensemble + +Option B: Few configurations, 5 seeds each +├─ 20 configurations × 5 seeds = 100 trials +├─ Find best at 83% accuracy +└─ Know best is 82-84%, but suboptimal choice + +Option A is ALWAYS better because: +- Finding good configuration is harder than averaging noise +- Top configuration with 1 seed > random configuration averaged 5x +- Can always rerun top 5 with multiple seeds if needed +- Larger exploration space finds fundamentally better hyperparameters +``` + +**Recommended Allocation**: +``` +Total budget: 200 configurations × 30 min = 100 hours + +Phase 1: Wide exploration (100 configurations, 1 seed each) +├─ Random or Bayesian over full search space +└─ Find top 10 candidates + +Phase 2: Refinement (50 configurations, 1 seed each) +├─ Search near best from Phase 1 +├─ Explore unexplored neighbors +└─ Find top 5 refined candidates + +Phase 3: Validation (5 configurations, 3 seeds each) +├─ Run best from Phase 2 with multiple seeds +├─ Report mean ± std +└─ Ensemble predictions from 3 models + +Total: 100 + 50 + 15 = 165 trials (realistic) +``` + + +## Early Stopping in Hyperparameter Search (Critical for Efficiency) + +**Key Concept**: During hyperparameter search, stop trials that are clearly bad early. + +**NOT the Same As**: +- Early stopping during training (regularization technique) - still do this! +- Stopping tuning when results plateau (quit tuning) - different concept + +**Early Stopping in Search**: Abandon bad hyperparameter configurations before full training. + +**How It Works**: +```python +# With early stopping in search +for trial in range(100): + model = create_model() + for epoch in range(100): + train(model, epoch) + val_acc = validate(model) + + # Check if this trial is hopeless + if val_acc < best_val_acc - 10: # Way worse than best + break # Stop and try next configuration! + + # Or use automated pruning (Optuna does this) + +# Result: 100 trials × ~30 epochs on average = 3000 epoch-trials +# vs 100 trials × 100 epochs = 10000 epoch-trials +# Saves 70% compute, finds same best configuration! +``` + +**When to Prune**: +``` +Trial accuracy worse than best by: + +Epoch 5: >15% → PRUNE (hopeless, try next) +Epoch 10: >10% → PRUNE +Epoch 30: >5% → PRUNE +Epoch 50: >2% → DON'T PRUNE (still might recover) +Epoch 80+: Never prune (almost done training) +``` + +**Optuna's Pruning Strategy**: +```python +import optuna + +study = optuna.create_study( + direction="maximize", + pruner=optuna.pruners.MedianPruner( + n_startup_trials=5, # First 5 trials always complete + n_warmup_steps=10, # No pruning until epoch 10 + interval_steps=1, # Check every epoch + ) +) +# MedianPruner removes trials worse than median at each epoch +# Automatically saves ~50-70% compute +``` + + +## Tools and Frameworks Comparison + +### 1. Manual Grid Search (DIY) + +```python +# Pros: Full control, simple, good for 1-2 parameters +# Cons: Doesn't scale to many parameters + +import itertools + +configs = itertools.product( + [0.001, 0.01, 0.1], + [0.0, 0.0001, 0.001] +) + +best = None +for lr, wd in configs: + acc = train_and_evaluate(lr=lr, weight_decay=wd) + if best is None or acc > best['acc']: + best = {'lr': lr, 'wd': wd, 'acc': acc} +``` + +**When to Use**: <50 configurations, quick experiments + + +### 2. Optuna (Industry Standard) + +```python +# Pros: Bayesian optimization, pruning, very popular +# Cons: Slightly more complex + +import optuna + +def objective(trial): + lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True) + wd = trial.suggest_float("wd", 1e-5, 1e-1, log=True) + + model = create_model() + for epoch in range(100): + train(model, lr=lr, weight_decay=wd) + val_acc = validate(model) + + trial.report(val_acc, epoch) + if trial.should_prune(): + raise optuna.TrialPruned() + + return val_acc + +study = optuna.create_study(direction="maximize") +study.optimize(objective, n_trials=200) +``` + +**When to Use**: 5+ parameters, 200+ trials, need efficiency + +**Why It's Best**: +- Bayesian optimization guides search efficiently +- Pruning saves 50-70% compute +- Handles many parameters well +- Simple API once you understand it + + +### 3. Ray Tune (For Distributed Search) + +```python +# Pros: Distributed search, good for many trials in parallel +# Cons: More setup needed + +from ray import tune + +def train_model(config): + model = create_model() + for epoch in range(100): + train(model, lr=config['lr'], batch_size=config['batch_size']) + val_acc = validate(model) + tune.report(accuracy=val_acc) + +analysis = tune.run( + train_model, + config={ + "lr": tune.loguniform(1e-5, 1e-1), + "batch_size": tune.choice([16, 32, 64, 128]), + }, + num_samples=200, + scheduler=tune.ASHAScheduler( + time_attr="training_iteration", + metric="accuracy", + mode="max", + max_t=100, + ), + verbose=1, +) +``` + +**When to Use**: Distributed setup (multiple GPUs/machines), 500+ trials + + +### 4. Weights & Biases (W&B) Sweeps (For Collaboration) + +```python +# Pros: Visual dashboard, team collaboration, easy integration +# Cons: Requires W&B account, less control than Optuna + +# sweep_config.yaml: +program: train.py +method: bayes +metric: + name: val_accuracy + goal: maximize +parameters: + learning_rate: + min: 0.00001 + max: 0.1 + distribution: log_uniform + weight_decay: + min: 0.00001 + max: 0.1 + distribution: log_uniform + +# Then run: wandb sweep sweep_config.yaml +``` + +**When to Use**: Team settings, want visual results, corporate environment + + +## When to Use Manual Tuning vs Automated Search + +### Manual Tuning (Sometimes Better Than You'd Think) + +**Process**: +1. Set learning rate with learning rate finder (1 epoch) +2. Train with this LR, watch training curves +3. If loss oscillates → lower LR by 2x → retrain +4. If loss plateaus → lower LR by 3x → retrain +5. Repeat until training stable and converging well +6. Done! + +**When It's Actually Faster**: +- Total experiments: 3-5 (vs 50+ for search) +- Time: 1-2 hours (vs 20+ hours for automated) +- Result: Often 80-85% (vs 85%+ for search) + +```python +# Manual tuning example +learning_rates = [0.0001] # Start low and safe + +for lr in learning_rates: + model = create_model() + losses = train(model, lr=lr) + + # If oscillating, reduce LR + if losses[-10:].std() > losses[-50:-10].std(): + learning_rates.append(lr * 0.5) + + # If plateauing, reduce LR + elif losses[-10:].mean() - losses[-50:-10].mean() < 0.01: + learning_rates.append(lr * 0.5) + + # If good convergence, done! + else: + print(f"Good LR found: {lr}") + break +``` + +**Pros**: +- Fast for 1-2 hyperparameters +- Understand the hyperparameters better +- Good when compute is limited +- Better for quick iteration + +**Cons**: +- Doesn't explore systematically +- Easy to get stuck in local view +- Not reproducible (depends on your intuition) +- Doesn't find global optimum + +**Use Manual When**: +- ✓ Tuning only learning rate +- ✓ Quick experiments (< 1 hour) +- ✓ Testing ideas rapidly +- ✓ Compute very limited +- ✓ New problem/dataset (explore first) + +**Use Automated When**: +- ✓ Tuning 3+ hyperparameters +- ✓ Targeting SOTA results +- ✓ Compute available (10+ hours) +- ✓ Want reproducible results +- ✓ Need best possible configuration + + +## Common Pitfalls and How to Avoid Them + +### Pitfall 1: Not Using Log Scale for Learning Rate +**Problem**: Linear scale [0.0001, 0.002, 0.004, 0.006, 0.008, 0.01] misses optimal +**Fix**: Use logarithmic scale np.logspace(-4, -2, 6) +**Impact**: Can miss 3-5% accuracy improvement + +### Pitfall 2: Tuning Too Many Hyperparameters at Once +**Problem**: 5 parameters × 5 values = 3,125 configs, impractical +**Fix**: Prioritize - tune LR first, then batch size, then others +**Impact**: Saves 100x compute while finding better results + +### Pitfall 3: Using Grid Search in High Dimensions +**Problem**: Grid search is O(k^n), explodes quickly +**Fix**: Use random search for 4+ parameters, Bayesian for 5+ +**Impact**: Random search is 10x more efficient + +### Pitfall 4: Training All Trials to Completion +**Problem**: Bad trials waste compute (no early stopping in search) +**Fix**: Use Optuna with MedianPruner to prune bad trials +**Impact**: Save 50-70% compute, same best result + +### Pitfall 5: Searching Over Architecture Before Optimizing Learning Rate +**Problem**: Model width 128→256 with bad LR gives noisy results +**Fix**: Fix learning rate first, then tune architecture +**Impact**: Avoid confounding, find LR gives 5-10%, width gives 2% + +### Pitfall 6: Single Seed for Final Configuration +**Problem**: One training run, variance unknown +**Fix**: Run top 5 configs with 3+ seeds +**Impact**: Know confidence intervals, can ensemble + +### Pitfall 7: Search Space Too Narrow +**Problem**: LR range [0.005, 0.01] misses better values outside +**Fix**: Start with wide range (1e-5 to 1e-1), narrow after +**Impact**: Find better optima, can always refine later + +### Pitfall 8: Not Checking for Interactions Between Hyperparameters +**Problem**: Assumes hyperparameters are independent +**Reality**: Batch size and LR interact, warmup and scheduler interact +**Fix**: Bayesian optimization naturally handles interactions +**Impact**: Find better combined configurations + +### Pitfall 9: Stopping Search Too Early +**Problem**: First 20 trials don't converge, give up +**Fix**: Run at least 50-100 trials (Bayesian gets better with more) +**Impact**: Bayesian optimization needs warm-up, improves significantly + +### Pitfall 10: Not Comparing to Baseline +**Problem**: Find best config is 82%, don't know if better than default +**Fix**: Include default hyperparameters as explicit trial +**Impact**: Know if search is even helping (sometimes default is good) + + +## Hyperparameter Importance Empirical Results (Case Studies) + +### Case Study 1: CIFAR-10 ResNet-18 + +| Change | Accuracy Shift | Relative Importance | +|--------|---|---| +| LR: 0.001 → 0.01 | +14% | 100% ← CRITICAL | +| Batch size: 32 → 128 | -2% | Low (but affects LR) | +| Weight decay: 0 → 0.0001 | +2% | 15% | +| Dropout: 0 → 0.3 | +1% | 7% | +| Model width: 64 → 128 | +2% | 15% | + +**Lesson**: LR is 7-20x more important than individual architectural changes + + +### Case Study 2: ImageNet Fine-tuning (Pretrained ResNet-50) + +| Change | Accuracy Shift | Relative Importance | +|--------|---|---| +| LR: 0.01 → 0.001 | +3% | 100% ← CRITICAL | +| Warmup: 0 → 1000 steps | +0.5% | 15% | +| Weight decay: 0 → 0.001 | +0.5% | 15% | +| Frozen layers: 0 → 3 | +1% | 30% | + +**Lesson**: Fine-tuning is LR-dominated; architecture matters less for pretrained + + +## Rationalization Table: How to Handle Common Arguments + +| User Says | What They Mean | Reality | What to Do | +|-----------|---|---|---| +| "Grid search is most thorough" | Should check all combinations | Grid is O(k^n), explodes | Show random search beats grid in 5+ dims | +| "More hyperparameters = more flexibility" | Want to tune everything | Most don't matter | Show importance hierarchy, tune LR first | +| "I'll tune architecture first" | Want to find model size | Bad LR confounds results | Insist on fixing LR first | +| "Linear spacing is uniform" | Want equal coverage | Effect is exponential | Show log scale finds optimal 3-5% better | +| "Longer training gives better results" | Can't prune early | Bad config won't improve | Show early stopping pruning saves 70% | +| "I ran 5 configs and found best" | Early results seem good | Variance of 5 runs is high | Need 20+ to be confident | +| "This LR seems good" | One training run looks ok | Might just be lucky run | Run 3 seeds, report mean ± std | +| "My compute is limited" | Can't do full search | Limited budget favors random | Allocate to many configs × 1 seed | + + +## Red Flags: When Something is Wrong + +🚩 **Red Flag 1**: Training loss is extremely noisy (spikes up and down) +- Likely cause: Learning rate too high +- Fix: Reduce learning rate by 10x, try again + +🚩 **Red Flag 2**: All trials have similar accuracy (within 0.5%) +- Likely cause: Search space too narrow or search space overlaps +- Fix: Expand search space, verify random sampling is working + +🚩 **Red Flag 3**: Best trial is at edge of search space +- Likely cause: Search space is too small, optimal is outside +- Fix: Expand bounds in that direction + +🚩 **Red Flag 4**: Early stopping pruned 95% of trials +- Likely cause: Initial configuration space very poor +- Fix: Expand search space, adjust pruning thresholds + +🚩 **Red Flag 5**: Trial finished in 1 epoch (model crashed or diverged) +- Likely cause: Learning rate way too high or batch size incompatible +- Fix: Check LR bounds are reasonable, verify code works + +🚩 **Red Flag 6**: Default hyperparameters beat tuned ones +- Likely cause: Search space poorly designed, not enough trials +- Fix: Expand search space, run more trials, check for bugs + +🚩 **Red Flag 7**: Same "best" configuration found in two independent searches +- Positive indicator: Robust result, likely good hyperparameter +- Action: Can be confident in this configuration + + +## Quick Reference: Decision Tree + +``` +Need to improve model performance? +│ +├─ Model underfits (high train + val loss)? +│ └─ → Add capacity or train longer (not a tuning problem) +│ +├─ Training converges too slowly? +│ └─ → Tune learning rate first (critical!) +│ +├─ Training is unstable (losses spike)? +│ └─ → Lower learning rate or increase batch size +│ +├─ Overfitting (low train loss, high val loss)? +│ └─ → Tune weight decay, dropout, learning rate schedule +│ +├─ How many hyperparameters to tune? +│ ├─ 1-2 params → Use manual tuning or grid search +│ ├─ 3-4 params → Use random search +│ └─ 5+ params → Use Bayesian optimization (Optuna) +│ +├─ How much compute available? +│ ├─ <10 hours → Tune only learning rate +│ ├─ 10-100 hours → Random search over 3-4 params +│ └─ 100+ hours → Bayesian optimization, multiple seeds +│ +└─ Should you run multiple seeds? + ├─ During search: NO (use compute for many configs instead) + └─ For final configs: YES (1-3 seeds per top-5 candidates) +``` + + +## Advanced Topics + +### Learning Rate Warmup (Critical for Transformers) + +**What It Is**: Start with very small LR, gradually increase to target over N steps, then decay. + +**Why It Matters**: +- Transformers are unstable without warmup +- Initial gradients can be very large (unstable) +- Gradual increase lets model stabilize +- Warmup is ESSENTIAL for BERT, GPT, ViT, etc. + +**Typical Warmup Schedule**: +```python +# Linear warmup then cosine decay +# Common: 10% of total steps for warmup + +import math + +def get_lr(step, total_steps, warmup_steps, max_lr): + if step < warmup_steps: + # Linear warmup: 0 → max_lr + return max_lr * (step / warmup_steps) + else: + # Cosine decay: max_lr → 0.1 * max_lr + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * max_lr * (1 + math.cos(math.pi * progress)) + +# Example: +total_steps = 10000 +warmup_steps = 1000 # 10% warmup +max_lr = 0.001 + +for step in range(total_steps): + lr = get_lr(step, total_steps, warmup_steps, max_lr) + # use lr for this step +``` + +**When to Tune Warmup**: +- Essential for transformers (BERT, GPT, ViT) +- Important for large models (ResNet-50+) +- Can skip for small models (ResNet-18) +- Typical: 5-10% of total steps + +**Warmup Parameters to Consider**: +- `warmup_steps`: How many steps to warm up (10% of total) +- `warmup_schedule`: Linear vs exponential warmup +- Interaction with learning rate: Must tune together! + + +### Batch Size and Learning Rate Interaction (Critical) + +**Key Finding**: Batch size and learning rate are NOT independent. + +**The Relationship**: +``` +Large batch size → Less gradient noise → Can use larger LR +Small batch size → More gradient noise → Need smaller LR + +Rule of thumb: LR ∝ sqrt(batch_size) +Doubling batch size → can increase LR by ~1.4x +``` + +**Example: CIFAR-10 ResNet18**: +``` +Batch Size 32, LR 0.01: Accuracy 84% +Batch Size 32, LR 0.05: Accuracy 81% (too high) + +Batch Size 128, LR 0.01: Accuracy 82% (too low for large batch) +Batch Size 128, LR 0.02: Accuracy 84% (recovered!) +Batch Size 128, LR 0.03: Accuracy 85% (slightly better, larger batch benefits) +``` + +**What This Means**: +- Can't tune batch size and LR independently +- Must tune them together +- This is why Bayesian optimization is better (handles interactions) +- Grid search would need to search all combinations + +**Implication for Search**: +- Include both batch size AND LR in search space +- Don't fix batch size, then tune LR +- Don't tune LR, then change batch size +- Search them together for best results + + +### Momentum and Optimizer-Specific Parameters + +**SGD with Momentum**: +```python +# Momentum: accelerates gradient descent +# High momentum (0.9): Faster convergence, but overshoots minima +# Low momentum (0.5): Slower, but more stable + +learning_rates = [0.01, 0.1] # Higher for SGD +momentums = [0.8, 0.9, 0.95] + +# SGD works best with moderate LR + high momentum +# Default: momentum=0.9 +``` + +**Adam Parameters**: +```python +# Adam is more forgiving (less sensitive to hyperparameters) +# But still worth tuning learning rate + +# Beta1 (exponential decay for 1st moment): usually 0.9 (don't change) +# Beta2 (exponential decay for 2nd moment): usually 0.999 (don't change) +# Epsilon: usually 1e-8 (don't bother tuning) + +learning_rates = [0.0001, 0.001, 0.01] # Lower for Adam +weight_decays = [0.0, 0.0001, 0.001] # Adam needs this + +# Adam is more robust, good default optimizer +``` + +**Which Optimizer to Choose**: +``` +SGD + Momentum: + Pros: Better generalization, well-understood + Cons: More sensitive to LR, slower convergence + Use for: Vision (CNN), competitive results + +Adam: + Pros: Faster convergence, less tuning, robust + Cons: Slightly worse generalization, adaptive complexity + Use for: NLP, transformers, quick experiments + +AdamW: + Pros: Better weight decay, all advantages of Adam + Cons: None really + Use for: Modern default, transformers, NLP + +RMSprop: + Pros: Good for RNNs, good convergence + Cons: Less popular, fewer resources + Use for: RNNs, rarely these days +``` + + +### Weight Decay and L2 Regularization + +**What's the Difference**: +- L2 regularization (added to loss): Works with all optimizers +- Weight decay (parameter update): Works correctly only with SGD +- AdamW: Fixes Adam's weight decay issue + +**Impact on Regularization**: +```python +# High weight decay: Strong regularization, lower capacity +weight_decay = 0.01 + +# Low weight decay: Weak regularization, higher capacity +weight_decay = 0.0001 + +# For overfitting: Start with weight_decay = 1e-4 to 1e-3 +# For underfitting: Reduce to 1e-5 or 0.0 +``` + +**Tuning Weight Decay**: +``` +If overfitting (low train loss, high val loss): + ├─ Try increasing weight decay (0.0001 → 0.001 → 0.01) + └─ Or reduce model capacity + └─ Or collect more data + +If underfitting (high train loss): + └─ Reduce weight decay to 0.0 +``` + +**Typical Values**: +``` +Vision models (ResNet, etc): 1e-4 to 1e-3 +Transformers (BERT, GPT): 0.01 to 0.1 +Small networks: 1e-5 to 1e-4 +Huge models (1B+): 0.0 or very small +``` + + +### Learning Rate Schedules Worth Tuning + +**Constant LR** (no schedule): +- Pros: Simple, good for comparison baseline +- Cons: Suboptimal convergence +- Use when: Testing new architecture quickly + +**Step Decay** (multiply by 0.1 every N epochs): +```python +# Divide LR by 10 at specific epochs +milestones = [30, 60, 90] # For 100 epoch training +for epoch in range(100): + if epoch in milestones: + lr *= 0.1 +``` + +**Exponential Decay** (multiply by factor each epoch): +```python +# Gradual decay, smoother than step +decay_rate = 0.96 +for epoch in range(100): + lr = initial_lr * (decay_rate ** epoch) +``` + +**Cosine Annealing** (cosine decay from max to min): +```python +# Best for convergence, used in SOTA papers +import math + +def cosine_annealing(epoch, total_epochs, min_lr, max_lr): + return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs)) + +# Smooth decay, no discontinuities +``` + +**OneCycleLR** (up then down): +```python +# Used in FastAI, very effective +# Goes: max_lr → min_lr → max_lr/25 +# Over entire training +``` + +**Which to Choose**: +``` +Vision (CNN): Step decay or cosine annealing +Transformers: Warmup then cosine or constant +Fine-tuning: Linear decay (slowly reduce) +Quick experiments: Constant LR +SOTA results: Cosine annealing with warmup +``` + + +### Hyperparameter Interactions: Complex Cases + +**Interaction 1: Batch Size × Learning Rate** +``` +Already covered above - MUST tune together +``` + +**Interaction 2: Model Capacity × Regularization** +``` +Large model + weak regularization → Overfitting +Large model + strong regularization → Good generalization +Small model + strong regularization → Underfitting + +Don't increase regularization for small models! +``` + +**Interaction 3: Warmup × Learning Rate** +``` +High LR needs more warmup steps +Low LR needs less warmup + +For LR=0.001: warmup_steps = 500 +For LR=0.1: warmup_steps = 5000 (higher LR = more warmup) +``` + +**Interaction 4: Weight Decay × Optimizer** +``` +SGD: Weight decay works as specified +Adam: Weight decay doesn't work properly (use AdamW!) +AdamW: Weight decay works correctly +``` + + +### When Model Capacity is the Real Problem + +**Underfitting Signs**: +``` +Training accuracy: 50% +Validation accuracy: 48% +Gap: Small (not overfitting) + +→ Model doesn't have capacity to learn +→ Add more parameters (wider/deeper) +→ Tuning hyperparameters won't help much +``` + +**Fix for Underfitting** (not tuning): +```python +# WRONG: Tuning hyperparameters +for lr in learning_rates: + model = SmallModel() # Too small! + train(model, lr=lr) # Still won't converge + +# CORRECT: Add model capacity +model = LargeModel() # More parameters +train(model, lr=0.01) # Now it converges well +``` + +**Capacity Sizing Rules**: +``` +Dataset size 10K images: Small model ok (100K parameters) +Dataset size 100K images: Medium model (1M parameters) +Dataset size 1M+ images: Large model (10M+ parameters) + +If training data < 10K: Use pre-trained, don't train from scratch +If training data > 1M: Larger models generally better +``` + + +### Debugging Hyperparameter Search + +**Debugging Checklist**: + +1. **Are trials actually different?** + ```python + # Check that suggested values are being used + for trial in study.trials[:5]: + print(f"LR: {trial.params['lr']}") + print(f"Batch size: {trial.params['batch_size']}") + # If all same, check suggest_* calls + ``` + +2. **Are results being recorded?** + ```python + # Verify accuracy improving or worsening meaningfully + for trial in study.trials: + print(f"Params: {trial.params}, Value: {trial.value}") + # Should see range of values, not all same + ``` + +3. **Is pruning too aggressive?** + ```python + # Check how many trials got pruned + n_pruned = sum(1 for t in study.trials if t.state == optuna.trial.TrialState.PRUNED) + print(f"Pruned {n_pruned}/{len(study.trials)}") + + # If >90% pruned: Expand search space or adjust pruning thresholds + ``` + +4. **Are hyperparameters in right range?** + ```python + # Check if best trial is at boundary + best = study.best_params + search_space = {...} # Your defined space + + for param, value in best.items(): + if value == search_space[param][0] or value == search_space[param][-1]: + print(f"WARNING: {param} at boundary!") + ``` + +5. **Is search space reasonable?** + ```python + # Quick sanity check: Run 5 random configs + # Should see different accuracies (not all 50%, not all 95%) + ``` + + +### Complete Optuna Workflow Example (Production Ready) + +**Full Example from Start to Finish**: + +```python +import optuna +from optuna.pruners import MedianPruner +from optuna.samplers import TPESampler +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset + +# Step 1: Define the objective function +def objective(trial): + # Suggest hyperparameters + learning_rate = trial.suggest_float( + "learning_rate", + 1e-5, 1e-1, + log=True # CRITICAL: Log scale for LR + ) + batch_size = trial.suggest_categorical( + "batch_size", + [16, 32, 64, 128] + ) + weight_decay = trial.suggest_float( + "weight_decay", + 1e-6, 1e-2, + log=True # Log scale for weight decay + ) + dropout_rate = trial.suggest_float( + "dropout_rate", + 0.0, 0.5 # Linear scale for dropout + ) + optimizer_type = trial.suggest_categorical( + "optimizer", + ["adam", "sgd"] + ) + + # Build model with suggested hyperparameters + model = create_model(dropout=dropout_rate) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + # Create optimizer + if optimizer_type == "adam": + optimizer = torch.optim.Adam( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay + ) + else: # sgd + optimizer = torch.optim.SGD( + model.parameters(), + lr=learning_rate, + momentum=0.9, + weight_decay=weight_decay + ) + + # Learning rate scheduler + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100 + ) + + # Training loop with pruning + best_val_acc = 0 + for epoch in range(100): + # Train + model.train() + train_loss = 0 + for batch_x, batch_y in train_loader: + batch_x, batch_y = batch_x.to(device), batch_y.to(device) + + optimizer.zero_grad() + logits = model(batch_x) + loss = nn.CrossEntropyLoss()(logits, batch_y) + loss.backward() + optimizer.step() + + train_loss += loss.item() + + # Validate + model.eval() + val_correct = 0 + val_total = 0 + with torch.no_grad(): + for batch_x, batch_y in val_loader: + batch_x, batch_y = batch_x.to(device), batch_y.to(device) + logits = model(batch_x) + predictions = logits.argmax(dim=1) + val_correct += (predictions == batch_y).sum().item() + val_total += batch_y.size(0) + + val_acc = val_correct / val_total + if val_acc > best_val_acc: + best_val_acc = val_acc + + # Step scheduler + scheduler.step() + + # Report to trial and prune if needed (CRITICAL!) + trial.report(val_acc, epoch) + if trial.should_prune(): + raise optuna.TrialPruned() + + return best_val_acc + +# Step 2: Create study with optimization +# TPESampler: Tree-structured Parzen Estimator (better than default) +sampler = TPESampler(seed=42) +study = optuna.create_study( + direction="maximize", + sampler=sampler, + pruner=MedianPruner( + n_startup_trials=5, # First 5 trials always complete + n_warmup_steps=10, # No pruning until epoch 10 + interval_steps=1 # Check every epoch + ) +) + +# Step 3: Optimize (run search) +study.optimize( + objective, + n_trials=200, # Run 200 configurations + n_jobs=4, # Parallel execution on 4 GPUs if available + show_progress_bar=True +) + +# Step 4: Analyze results +print(f"Best accuracy: {study.best_value:.4f}") +print(f"Best hyperparameters: {study.best_params}") + +# Step 5: Visualize results (optional but useful) +try: + import matplotlib.pyplot as plt + fig = optuna.visualization.plot_optimization_history(study).show() +except: + pass + +# Step 6: Run final validation with best config +# (With 3 seeds, report mean ± std) +best_params = study.best_params +final_accuracies = [] + +for seed in range(3): + model = create_model(dropout=best_params['dropout_rate']) + # ... train with best_params ... + final_acc = validate(model) # Your validation function + final_accuracies.append(final_acc) + +print(f"Final result: {np.mean(final_accuracies):.4f} ± {np.std(final_accuracies):.4f}") +``` + +**Key Points in This Example**: +1. Log scale for learning rate and weight decay (CRITICAL) +2. Linear scale for dropout (CORRECT) +3. Trial pruning to save compute (ESSENTIAL) +4. LR scheduler with optimizer +5. Running final validation with multiple seeds +6. Clear reporting of best config + + +### Grid Search at Scale: When It Breaks Down + +**Small Grid (Works Fine)**: +``` +3 params × 3 values each = 27 configs +Time: 27 × 30 min = 810 minutes = 13.5 hours +Practical? YES +``` + +**Medium Grid (Getting Expensive)**: +``` +4 params × 4 values each = 256 configs +Time: 256 × 30 min = 7680 minutes = 128 hours = 5.3 days +Practical? MAYBE (if you have the compute) +``` + +**Large Grid (Impractical)**: +``` +5 params × 5 values each = 3,125 configs +Time: 3,125 × 30 min = 93,750 minutes = 65 days +Practical? NO +Random search: 200 configs = 6,400 minutes = 4.4 days +→ 15x FASTER, BETTER RESULTS +``` + +**Always Use Random When Grid > 100 Configs** + + +### Common Search Space Mistakes (With Fixes) + +**Mistake 1: LR range too narrow** +```python +# WRONG: Only covers small range +lr_values = [0.008, 0.009, 0.01, 0.011, 0.012] + +# CORRECT: Covers multiple orders of magnitude +lr_values = np.logspace(-4, -1, 6) # [1e-4, 1e-3, 1e-2, 1e-1] +``` + +**Mistake 2: Batch size without corresponding LR adjustment** +```python +# WRONG: Searches batch size but LR fixed at 0.001 +batch_sizes = [32, 64, 128, 256] +learning_rate = 0.001 # Fixed! + +# CORRECT: Search both batch size AND LR together +# Large batch needs larger LR +batch_sizes = [32, 64, 128, 256] +learning_rates = [0.001, 0.002, 0.003, 0.005, 0.01] +``` + +**Mistake 3: Linear spacing for exponential parameters** +```python +# WRONG: Linear spacing for weight decay +wd_values = [0.0, 0.025, 0.05, 0.075, 0.1] + +# CORRECT: Log spacing for weight decay +wd_values = np.logspace(-5, -1, 6) # [1e-5, 1e-4, 1e-3, 1e-2, 1e-1] +``` + +**Mistake 4: Dropout range that's too wide** +```python +# WRONG: Including 0.9 dropout (destroys model) +dropout_values = [0.0, 0.3, 0.6, 0.9] + +# CORRECT: Reasonable regularization range +dropout_values = [0.0, 0.2, 0.4, 0.6] +``` + + +### When to Stop Searching and Go With What You Have + +**Stop Conditions**: + +1. **Diminishing Returns** + - First 50 trials: Found 80% of best accuracy + - Next 50 trials: Found 15% improvement + - Next 50 trials: Found 4% improvement + - → Stop when improvement/trial drops below 0.1% + +2. **Time Budget Exhausted** + - Planned for 100 hours, used 100 hours + - → Run final validation and ship results + +3. **Best Config Appears Stable** + - Same best configuration in last 20 trials + - Different search random seeds find same optimum + - → Confidence in result, safe to stop + +4. **No Config Improvement** + - Last 30 trials all worse than current best + - Pruning catching most trials + - → Search converged, time to stop + +**Decision Rule**: +``` +Number of trials = min( + total_budget // cost_per_trial, + until_improvement < 0.1%, + until_same_best_for_20_trials +) +``` + + +## Summary: Best Practices + +1. **Prioritize Learning Rate** - Most important hyperparameter by far (7-20x impact) +2. **Use Log Scale** - For LR, weight decay, regularization strength +3. **Avoid Grid Search** - Exponential complexity O(k^n), use random for 4+ params +4. **Allocate for Many Configs** - Broad exploration > Multiple runs of few configs (5-10x better) +5. **Enable Early Stopping** - In search itself (pruning bad trials), saves 50-70% compute +6. **Use Optuna** - Industry standard with Bayesian optimization + pruning +7. **Run Multiple Seeds** - Only for final top-5 candidates (3 seeds), not all trials +8. **Start With Defaults** - Only tune if underperforming (don't waste compute) +9. **Check for Interactions** - Batch size and LR interact strongly (tune together) +10. **Compare to Baseline** - Include default config to verify search helps +11. **Tune Warmup with LR** - Critical for transformers, must co-tune +12. **Match Optimizer to Task** - SGD for vision/SOTA, Adam/AdamW for NLP/transformers +13. **Use Log Scale for Exponential Parameters** - Critical for finding optimal +14. **Stop When Returns Diminish** - Once improvement <0.1% per trial, stop searching +15. **Debug Search Systematically** - Check bounds, pruning rates, parameter suggestions + diff --git a/skills/using-training-optimization/learning-rate-scheduling.md b/skills/using-training-optimization/learning-rate-scheduling.md new file mode 100644 index 0000000..fc3c0e5 --- /dev/null +++ b/skills/using-training-optimization/learning-rate-scheduling.md @@ -0,0 +1,2723 @@ + +# Learning Rate Scheduling Skill + +## When to Use This Skill + +Use this skill when: +- User asks "should I use a learning rate scheduler?" +- Training plateaus or loss stops improving +- Training transformers or large models (warmup critical) +- User wants to implement OneCycleLR or specific scheduler +- Training is unstable in early epochs +- User asks "what learning rate should I use?" +- Deciding between constant LR and scheduled LR +- User is copying a paper's training recipe +- Implementing modern training pipelines (vision, NLP, RL) +- User suggests "just use constant LR" (rationalization) + +Do NOT use when: +- User has specific bugs unrelated to scheduling +- Only discussing optimizer choice (no schedule questions) +- Training already working well and no LR questions asked + + +## Core Principles + +### 1. Why Learning Rate Scheduling Matters + +Learning rate scheduling is one of the MOST IMPACTFUL hyperparameters: + +**High LR Early (Exploration):** +- Fast initial progress through parameter space +- Escape poor local minima +- Rapid loss reduction in early epochs + +**Low LR Late (Exploitation):** +- Fine-tune to sharper, better minima +- Improve generalization (test accuracy) +- Stable convergence without oscillation + +**Quantitative Impact:** +- Proper scheduling improves final test accuracy by 2-5% (SIGNIFICANT) +- Standard practice in all SOTA papers (ResNet, EfficientNet, ViT, BERT, GPT) +- Not optional for competitive performance + +**When Constant LR Fails:** +- Can't explore quickly AND converge precisely +- Either too high (never converges) or too low (too slow) +- Leaves 2-5% performance on the table + + +### 2. Decision Framework: When to Schedule vs Constant LR + +## Use Scheduler When: + +✅ **Long training (>30 epochs)** +- Scheduling essential for multi-stage training +- Different LR regimes needed across training +- Example: 90-epoch ImageNet training + +✅ **Large model on large dataset** +- Training from scratch on ImageNet, COCO, etc. +- Benefits from exploration → exploitation strategy +- Example: ResNet-50 on ImageNet + +✅ **Training plateaus or loss stops improving** +- Current LR too high for current parameter regime +- Reducing LR breaks plateau +- Example: Validation loss stuck for 10+ epochs + +✅ **Following established training recipes** +- Papers publish schedules for reproducibility +- Vision models typically use MultiStepLR or Cosine +- Example: ResNet paper specifies drop at epochs 30, 60, 90 + +✅ **Want competitive SOTA performance** +- Squeezing out last 2-5% accuracy +- Required for benchmarks and competitions +- Example: Targeting SOTA on CIFAR-10 + +## Maybe Don't Need Scheduler When: + +❌ **Very short training (<10 epochs)** +- Not enough time for multi-stage scheduling +- Constant LR or simple linear decay sufficient +- Example: Quick fine-tuning for 5 epochs + +❌ **OneCycle is the strategy itself** +- OneCycleLR IS the training strategy (not separate) +- Don't combine OneCycle with another scheduler +- Example: FastAI-style 20-epoch training + +❌ **Hyperparameter search phase** +- Constant LR simpler to compare across runs +- Add scheduling after finding good architecture/optimizer +- Example: Running 50 architecture trials + +❌ **Transfer learning fine-tuning** +- Small number of epochs on pretrained model +- Constant small LR often sufficient +- Example: Fine-tuning BERT for 3 epochs + +❌ **Reinforcement learning** +- RL typically uses constant LR (exploration/exploitation balance different) +- Some exceptions (PPO sometimes uses linear decay) +- Example: DQN, A3C usually constant LR + +## Default Recommendation: + +**For >30 epoch training:** USE A SCHEDULER (typically CosineAnnealingLR) +**For <10 epoch training:** Constant LR usually fine +**For 10-30 epochs:** Try both, scheduler usually wins + + +### 3. Major Scheduler Types - Complete Comparison + +## StepLR / MultiStepLR (Classic Vision) + +**Use When:** +- Training CNNs (ResNet, VGG, etc.) +- Following established recipe from paper +- Want simple, interpretable schedule + +**How It Works:** +- Drop LR by constant factor at specific epochs +- StepLR: every N epochs +- MultiStepLR: at specified milestone epochs + +**Implementation:** + +```python +# StepLR: Drop every 30 epochs +scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=30, # Drop every 30 epochs + gamma=0.1 # Multiply LR by 0.1 (10x reduction) +) + +# MultiStepLR: Drop at specific milestones (more control) +scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[30, 60, 90], # Drop at these epochs + gamma=0.1 # Multiply by 0.1 each time +) + +# Training loop +for epoch in range(100): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() # Call AFTER each epoch +``` + +**Example Schedule (initial_lr=0.1):** +- Epochs 0-29: LR = 0.1 +- Epochs 30-59: LR = 0.01 (dropped by 10x) +- Epochs 60-89: LR = 0.001 (dropped by 10x again) +- Epochs 90-99: LR = 0.0001 + +**Pros:** +- Simple and interpretable +- Well-established in papers (easy to reproduce) +- Works well for vision models + +**Cons:** +- Requires manual milestone selection +- Sharp LR drops can cause temporary instability +- Need to know total training epochs in advance + +**Best For:** Classical CNN training (ResNet, VGG) following paper recipes + + +## CosineAnnealingLR (Modern Default) + +**Use When:** +- Training modern vision models (ViT, EfficientNet) +- Want smooth decay without manual milestones +- Don't want to tune milestone positions + +**How It Works:** +- Smooth cosine curve from initial_lr to eta_min +- Gradual decay, no sharp drops +- LR = eta_min + (initial_lr - eta_min) * (1 + cos(π * epoch / T_max)) / 2 + +**Implementation:** + +```python +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, # Total epochs (LR reaches eta_min at epoch 100) + eta_min=1e-5 # Minimum LR (default: 0) +) + +# Training loop +for epoch in range(100): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() # Call AFTER each epoch +``` + +**Example Schedule (initial_lr=0.1, eta_min=1e-5):** +- Epoch 0: LR = 0.1 +- Epoch 25: LR ≈ 0.075 +- Epoch 50: LR ≈ 0.05 +- Epoch 75: LR ≈ 0.025 +- Epoch 100: LR = 0.00001 + +**Pros:** +- No milestone tuning required +- Smooth decay (no instability from sharp drops) +- Widely used in modern papers +- Works well across many domains + +**Cons:** +- Must know total epochs in advance +- Can't adjust schedule during training + +**Best Practice: ALWAYS COMBINE WITH WARMUP for large models:** + +```python +# Warmup for 5 epochs, then cosine for 95 epochs +warmup = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, # Start at 1% of base LR + end_factor=1.0, # Ramp to 100% + total_iters=5 # Over 5 epochs +) + +cosine = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=95, # 95 epochs after warmup + eta_min=1e-5 +) + +scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[5] # Switch to cosine after 5 epochs +) +``` + +**Best For:** Modern vision models, transformers, default choice for most problems + + +## ReduceLROnPlateau (Adaptive) + +**Use When:** +- Don't know optimal schedule in advance +- Want adaptive approach based on validation performance +- Training plateaus and you want automatic LR reduction + +**How It Works:** +- Monitors validation metric (loss or accuracy) +- Reduces LR when metric stops improving +- Requires passing metric to scheduler.step() + +**Implementation:** + +```python +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', # 'min' for loss, 'max' for accuracy + factor=0.1, # Reduce LR by 10x when plateau detected + patience=10, # Wait 10 epochs before reducing + threshold=1e-4, # Minimum change to count as improvement + threshold_mode='rel', # 'rel' or 'abs' + cooldown=0, # Epochs to wait after LR reduction + min_lr=1e-6, # Don't reduce below this + verbose=True # Print when LR reduced +) + +# Training loop +for epoch in range(100): + train_loss = train_one_epoch(model, train_loader, optimizer) + val_loss = validate(model, val_loader) + + # IMPORTANT: Pass validation metric to step() + scheduler.step(val_loss) # NOT scheduler.step() alone! +``` + +**Example Behavior (patience=10, factor=0.1):** +- Epochs 0-30: Val loss improving, LR = 0.001 +- Epochs 31-40: Val loss plateaus at 0.15, patience counting +- Epoch 41: Plateau detected, LR reduced to 0.0001 +- Epochs 42-60: Val loss improving again with lower LR +- Epoch 61: Plateau again, LR reduced to 0.00001 + +**Pros:** +- Adaptive - no manual tuning required +- Based on actual training progress +- Good for unknown optimal schedule + +**Cons:** +- Can be too conservative (waits long before reducing) +- Requires validation metric (can't use train loss alone) +- May reduce LR too late or not enough + +**Tuning Tips:** +- Smaller patience (5-10) for faster adaptation +- Larger patience (10-20) for more conservative +- Factor of 0.1 (10x) is standard, but 0.5 (2x) more gradual + +**Best For:** Exploratory training, unknown optimal schedule, adaptive pipelines + + +## OneCycleLR (Fast Training) + +**Use When:** +- Limited compute budget (want fast convergence) +- Training for relatively few epochs (10-30) +- Following FastAI-style training +- Want aggressive schedule for quick results + +**How It Works:** +- Ramps UP from low LR to max_lr (first 30% by default) +- Ramps DOWN from max_lr to very low LR (remaining 70%) +- Steps EVERY BATCH (not every epoch) - CRITICAL DIFFERENCE + +**Implementation:** + +```python +scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=0.1, # Peak learning rate (TUNE THIS!) + steps_per_epoch=len(train_loader), # Batches per epoch + epochs=20, # Total epochs + pct_start=0.3, # Ramp up for first 30% + anneal_strategy='cos', # 'cos' or 'linear' + div_factor=25, # initial_lr = max_lr / 25 + final_div_factor=10000 # final_lr = max_lr / 10000 +) + +# Training loop - NOTE: step() EVERY BATCH +for epoch in range(20): + for batch in train_loader: + loss = train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # CALL EVERY BATCH, NOT EVERY EPOCH! +``` + +**Example Schedule (max_lr=0.1, 20 epochs, 400 batches/epoch):** +- Batches 0-2400 (epochs 0-6): LR ramps from 0.004 → 0.1 +- Batches 2400-8000 (epochs 6-20): LR ramps from 0.1 → 0.00001 + +**CRITICAL: Tuning max_lr:** + +OneCycleLR is VERY sensitive to max_lr choice. Too high = instability. + +**Method 1 - LR Finder (RECOMMENDED):** +```python +# Run LR finder first (see LR Finder section) +optimal_lr = find_lr(model, train_loader, optimizer) # e.g., 0.01 +max_lr = optimal_lr * 10 # Use 10x optimal as max_lr +``` + +**Method 2 - Manual tuning:** +- Start with max_lr = 0.1 +- If training unstable, try 0.03, 0.01 +- If training too slow, try 0.3, 1.0 + +**Pros:** +- Very fast convergence (fewer epochs needed) +- Strong final performance +- Popular in FastAI community + +**Cons:** +- Sensitive to max_lr (requires tuning) +- Steps every batch (easy to mess up) +- Not ideal for very long training (>50 epochs) + +**Common Mistakes:** +1. Calling scheduler.step() per epoch instead of per batch +2. Not tuning max_lr (using default blindly) +3. Using for very long training (OneCycle designed for shorter cycles) + +**Best For:** FastAI-style training, limited compute budget, 10-30 epoch training + + +## Advanced OneCycleLR Tuning + +If lowering max_lr doesn't resolve instability, try these advanced tuning options: + +**1. Adjust pct_start (warmup fraction):** + +```python +# Default: 0.3 (30% warmup, 70% cooldown) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + pct_start=0.3) # Default + +# If unstable at peak: Increase to 0.4 or 0.5 (longer warmup) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + pct_start=0.5) # Gentler ramp to peak + +# If unstable in cooldown: Decrease to 0.2 (shorter warmup, gentler descent) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + pct_start=0.2) +``` + +**2. Adjust div_factor (initial LR):** + +```python +# Default: 25 (initial_lr = max_lr / 25) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + div_factor=25) # Start at 0.004 + +# If unstable at start: Increase to 50 or 100 (start even lower) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + div_factor=100) # Start at 0.001 +``` + +**3. Adjust final_div_factor (final LR):** + +```python +# Default: 10000 (final_lr = max_lr / 10000) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + final_div_factor=10000) # End at 0.00001 + +# If unstable at end: Decrease to 1000 (end at higher LR) +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, + steps_per_epoch=len(train_loader), + final_div_factor=1000) # End at 0.0001 +``` + +**4. Add gradient clipping:** + +```python +# In training loop +for batch in train_loader: + loss = train_step(model, batch, optimizer) + loss.backward() + + # Clip gradients to prevent instability + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + optimizer.step() + scheduler.step() +``` + +**5. Consider OneCycle may not be right for your problem:** + +- **Very deep networks (>100 layers):** May be too unstable for OneCycle's aggressive schedule +- **Large models (>100M params):** May need gentler schedule (Cosine + warmup) +- **Sensitive architectures (some transformers):** OneCycle's rapid LR changes can destabilize + +**Alternative:** Use CosineAnnealing + warmup for more stable training: + +```python +# More stable alternative to OneCycle +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +``` + + +## LinearLR (Warmup) + +**Use When:** +- Need warmup at training start +- Ramping up LR gradually over first few epochs +- Combining with another scheduler (SequentialLR) + +**How It Works:** +- Linearly interpolates LR from start_factor to end_factor +- Typically used for warmup: start_factor=0.01, end_factor=1.0 + +**Implementation:** + +```python +# Standalone linear warmup +scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, # Start at 1% of base LR + end_factor=1.0, # End at 100% of base LR + total_iters=5 # Over 5 epochs +) + +# More common: Combine with main scheduler +warmup = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, + total_iters=5 +) + +main = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=95 +) + +scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup, main], + milestones=[5] # Switch after 5 epochs +) + +# Training loop +for epoch in range(100): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Example Schedule (base_lr=0.1):** +- Epoch 0: LR = 0.001 (1%) +- Epoch 1: LR = 0.0208 (20.8%) +- Epoch 2: LR = 0.0406 (40.6%) +- Epoch 3: LR = 0.0604 (60.4%) +- Epoch 4: LR = 0.0802 (80.2%) +- Epoch 5: LR = 0.1 (100%, then switch to main scheduler) + +**Best For:** Warmup phase for transformers and large models + + +## ExponentialLR (Continuous Decay) + +**Use When:** +- Want smooth, continuous decay +- Simpler alternative to Cosine +- Prefer exponential over linear decay + +**How It Works:** +- Multiply LR by gamma every epoch +- LR(epoch) = initial_lr * gamma^epoch + +**Implementation:** + +```python +scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, + gamma=0.95 # Multiply by 0.95 each epoch +) + +# Training loop +for epoch in range(100): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Example Schedule (initial_lr=0.1, gamma=0.95):** +- Epoch 0: LR = 0.1 +- Epoch 10: LR = 0.0599 +- Epoch 50: LR = 0.0077 +- Epoch 100: LR = 0.0059 + +**Tuning gamma:** +- Want 10x decay over 100 epochs: gamma = 0.977 (0.1^(1/100)) +- Want 100x decay over 100 epochs: gamma = 0.955 (0.01^(1/100)) +- General formula: gamma = (target_lr / initial_lr)^(1/epochs) + +**Pros:** +- Very smooth decay +- Simple to implement + +**Cons:** +- Hard to intuit gamma value for desired final LR +- Less popular than Cosine (Cosine is better default) + +**Best For:** Cases where you want exponential decay specifically + + +## LambdaLR (Custom Schedules) + +**Use When:** +- Need custom schedule not provided by standard schedulers +- Implementing paper-specific schedule +- Advanced use cases (e.g., transformer inverse sqrt schedule) + +**How It Works:** +- Provide function that computes LR multiplier for each epoch +- LR(epoch) = initial_lr * lambda(epoch) + +**Implementation:** + +```python +# Example: Warmup then constant +def warmup_lambda(epoch): + if epoch < 5: + return (epoch + 1) / 5 # Linear warmup + else: + return 1.0 # Constant after warmup + +scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=warmup_lambda +) + +# Example: Transformer inverse square root schedule +def transformer_schedule(epoch): + warmup_steps = 4000 + step = epoch + 1 + return min(step ** (-0.5), step * warmup_steps ** (-1.5)) + +scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=transformer_schedule +) + +# Example: Polynomial decay +def polynomial_decay(epoch): + return (1 - epoch / 100) ** 0.9 # Decay to 0 at epoch 100 + +scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=polynomial_decay +) +``` + +**Best For:** Custom schedules, implementing specific papers, advanced users + + +### 4. Warmup Strategies - CRITICAL FOR TRANSFORMERS + +## Why Warmup is Essential + +**Problem at Training Start:** +- Weights are randomly initialized +- Gradients can be very large and unstable +- BatchNorm statistics are uninitialized +- High LR can cause immediate divergence (NaN loss) + +**Solution: Gradual LR Increase** +- Start with very low LR (1% of target) +- Linearly increase to target LR over first few epochs +- Allows model to stabilize before aggressive learning + +**Quantitative Impact:** +- Transformers WITHOUT warmup: Often diverge or train very unstably +- Transformers WITH warmup: Stable training, better final performance +- Vision models: Warmup improves stability, sometimes +0.5-1% accuracy + + +## When Warmup is MANDATORY + +**ALWAYS use warmup when:** + +✅ **Training transformers (ViT, BERT, GPT, T5, etc.)** +- Transformers REQUIRE warmup - not optional +- Without warmup, training often diverges +- Standard practice in all transformer papers + +✅ **Large batch sizes (>512)** +- Large batches → larger effective learning rate +- Warmup prevents early instability +- Standard for distributed training + +✅ **High initial learning rates** +- If starting with LR > 0.001, use warmup +- Warmup allows higher peak LR safely + +✅ **Training from scratch (not fine-tuning)** +- Random initialization needs gentle start +- Fine-tuning can often skip warmup (weights already good) + +**Usually use warmup when:** + +✅ Large models (>100M parameters) +✅ Using AdamW optimizer (common with transformers) +✅ Following modern training recipes + +**May skip warmup when:** + +❌ Fine-tuning pretrained models (weights already trained) +❌ Small learning rates (< 0.0001) +❌ Small models (<10M parameters) +❌ Established recipe without warmup (e.g., some CNN papers) + + +## Warmup Implementation Patterns + +### Pattern 1: Linear Warmup + Cosine Decay (Most Common) + +```python +import torch.optim.lr_scheduler as lr_scheduler + +# Warmup for 5 epochs +warmup = lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, # Start at 1% of base LR + end_factor=1.0, # End at 100% of base LR + total_iters=5 # Over 5 epochs +) + +# Cosine decay for remaining 95 epochs +cosine = lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=95, # 95 epochs after warmup + eta_min=1e-5 # Final LR +) + +# Combine sequentially +scheduler = lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[5] # Switch to cosine after epoch 5 +) + +# Training loop +for epoch in range(100): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Schedule Visualization (base_lr=0.001):** +- Epochs 0-4: Linear ramp from 0.00001 → 0.001 (warmup) +- Epochs 5-99: Cosine decay from 0.001 → 0.00001 + +**Use For:** Vision transformers, modern CNNs, most large-scale training + + +### Pattern 2: Linear Warmup + MultiStepLR + +```python +# Warmup for 5 epochs +warmup = lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, + total_iters=5 +) + +# Step decay at 30, 60, 90 +steps = lr_scheduler.MultiStepLR( + optimizer, + milestones=[30, 60, 90], + gamma=0.1 +) + +# Combine +scheduler = lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup, steps], + milestones=[5] +) +``` + +**Use For:** Classical CNN training with warmup + + +### Pattern 3: Manual Warmup (More Control) + +```python +def get_lr_schedule(epoch, total_epochs, base_lr, warmup_epochs=5): + """ + Custom schedule with warmup and cosine decay. + """ + if epoch < warmup_epochs: + # Linear warmup + return base_lr * (epoch + 1) / warmup_epochs + else: + # Cosine decay + progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs) + return base_lr * 0.5 * (1 + math.cos(math.pi * progress)) + +# Training loop +for epoch in range(100): + lr = get_lr_schedule(epoch, total_epochs=100, base_lr=0.001) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + train_one_epoch(model, train_loader, optimizer) +``` + +**Use For:** Custom schedules, research, maximum control + + +### Pattern 4: Transformer-Style Warmup (Inverse Square Root) + +```python +def transformer_lr_schedule(step, d_model, warmup_steps): + """ + Transformer schedule from "Attention is All You Need". + LR increases during warmup, then decreases proportionally to inverse sqrt of step. + """ + step = step + 1 # 1-indexed + return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5) + +scheduler = lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda step: transformer_lr_schedule(step, d_model=512, warmup_steps=4000) +) + +# Training loop - NOTE: step every BATCH for this schedule +for epoch in range(epochs): + for batch in train_loader: + train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # Step every batch +``` + +**Use For:** Transformer models (BERT, GPT), following original papers + + +## Warmup Duration Guidelines + +**How many warmup epochs?** + +- **Transformers:** 5-20 epochs (or 5-10% of total training) +- **Vision models:** 5-10 epochs +- **Very large models (>1B params):** 10-20 epochs +- **Small models:** 3-5 epochs + +**Rule of thumb:** 5-10% of total training epochs + +**Examples:** +- 100-epoch training: 5-10 epoch warmup +- 20-epoch training: 2-3 epoch warmup +- 300-epoch training: 15-30 epoch warmup + + +## "But My Transformer Trained Fine Without Warmup" + +Some users report training transformers without warmup successfully. Here's the reality: + +**What "fine" actually means:** +- Training didn't diverge (NaN loss) - that's a low bar +- Got reasonable accuracy - but NOT optimal accuracy +- One successful run doesn't mean it's optimal or reliable + +**What you're missing without warmup:** + +**1. Performance gap (1-3% accuracy):** + +``` +Without warmup: Training works, achieves 85% accuracy +With warmup: Same model achieves 87-88% accuracy +``` + +That 2-3% is SIGNIFICANT: +- Difference between competitive and SOTA +- Difference between accepted and rejected paper +- Difference between passing and failing business metrics + +**2. Training stability:** + +``` +Without warmup: +- Some runs diverge → need to restart with lower LR +- Sensitive to initialization seed +- Requires careful LR tuning +- Success rate: 60-80% of runs + +With warmup: +- Stable training → consistent results +- Robust to initialization +- Wider stable LR range +- Success rate: 95-100% of runs +``` + +**3. Hyperparameter sensitivity:** + +Without warmup: +- Very sensitive to initial LR choice (0.001 works, 0.0015 diverges) +- Sensitive to batch size +- Sensitive to optimizer settings + +With warmup: +- More forgiving LR range (0.0005-0.002 all work) +- Less sensitive to batch size +- Robust optimizer configuration + +**Empirical Evidence - Published Papers:** + +Check transformer papers - ALL use warmup: + +| Model | Paper | Warmup | +|-------|-------|--------| +| ViT | Dosovitskiy et al., 2020 | ✅ Linear, 10k steps | +| DeiT | Touvron et al., 2021 | ✅ Linear, 5 epochs | +| Swin | Liu et al., 2021 | ✅ Linear, 20 epochs | +| BERT | Devlin et al., 2018 | ✅ Linear, 10k steps | +| GPT-2 | Radford et al., 2019 | ✅ Linear warmup | +| GPT-3 | Brown et al., 2020 | ✅ Linear warmup | +| T5 | Raffel et al., 2020 | ✅ Inverse sqrt warmup | + +**Every competitive transformer model uses warmup - there's a reason.** + +**"But I got 85% accuracy without warmup!"** + +Great! Now try with warmup and see if you get 87-88%. You probably will. + +**The cost-benefit analysis:** + +```python +# Cost: One line of code +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +scheduler = SequentialLR(optimizer, [warmup, main], [5]) + +# Benefit: +# - 1-3% better accuracy +# - More stable training +# - Higher success rate +# - Wider stable hyperparameter range +``` + +**Recommendation:** + +1. Run ablation study: Train your model with and without warmup +2. Compare: Final test accuracy, training stability, number of failed runs +3. You'll find warmup gives better results with minimal complexity + +**Bottom line:** Just because something "works" doesn't mean it's optimal. Warmup is standard practice for transformers because it consistently improves results. + + +### 5. LR Finder - Finding Optimal Initial LR + +## What is LR Finder? + +**Method from Leslie Smith (2015):** Cyclical Learning Rates paper + +**Core Idea:** +1. Start with very small LR (1e-8) +2. Gradually increase LR (multiply by ~1.1 each batch) +3. Train for a few hundred steps, record loss at each LR +4. Plot loss vs LR +5. Choose LR where loss decreases fastest (steepest descent) + +**Why It Works:** +- Too low LR: Loss decreases very slowly +- Optimal LR: Loss decreases rapidly (steepest slope) +- Too high LR: Loss plateaus or increases (instability) + +**Typical Findings:** +- Loss decreases fastest at some LR (e.g., 0.01) +- Loss starts increasing at higher LR (e.g., 0.1) +- Choose LR slightly below fastest descent point (e.g., 0.003-0.01) + + +## LR Finder Implementation + +```python +import torch +import matplotlib.pyplot as plt +import numpy as np + +def find_lr(model, train_loader, optimizer, loss_fn, device, + start_lr=1e-8, end_lr=10, num_iter=100, smooth_f=0.05): + """ + LR Finder: Sweep learning rates and plot loss curve. + + Args: + model: PyTorch model + train_loader: Training data loader + optimizer: Optimizer (will be modified) + loss_fn: Loss function + device: Device to train on + start_lr: Starting learning rate (default: 1e-8) + end_lr: Ending learning rate (default: 10) + num_iter: Number of iterations (default: 100) + smooth_f: Smoothing factor for loss (default: 0.05) + + Returns: + lrs: List of learning rates tested + losses: List of losses at each LR + """ + # Save initial model state to restore later + model.train() + initial_state = model.state_dict() + + # Calculate LR multiplier for exponential increase + lr_mult = (end_lr / start_lr) ** (1 / num_iter) + + lrs = [] + losses = [] + best_loss = float('inf') + avg_loss = 0 + + lr = start_lr + + # Iterate through training data + iterator = iter(train_loader) + for iteration in range(num_iter): + try: + data, target = next(iterator) + except StopIteration: + # Restart iterator if we run out of data + iterator = iter(train_loader) + data, target = next(iterator) + + # Set learning rate + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + # Forward pass + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + + # Compute smoothed loss (exponential moving average) + if iteration == 0: + avg_loss = loss.item() + else: + avg_loss = smooth_f * loss.item() + (1 - smooth_f) * avg_loss + + # Record + lrs.append(lr) + losses.append(avg_loss) + + # Track best loss + if avg_loss < best_loss: + best_loss = avg_loss + + # Stop if loss explodes (>4x best loss) + if avg_loss > 4 * best_loss: + print(f"Stopping early at iteration {iteration}: loss exploded") + break + + # Backward pass + loss.backward() + optimizer.step() + + # Increase learning rate + lr *= lr_mult + if lr > end_lr: + break + + # Restore model to initial state + model.load_state_dict(initial_state) + + # Plot results + plt.figure(figsize=(10, 6)) + plt.plot(lrs, losses) + plt.xscale('log') + plt.xlabel('Learning Rate (log scale)') + plt.ylabel('Loss') + plt.title('LR Finder') + plt.grid(True, alpha=0.3) + + # Mark suggested LR (10x below minimum loss) + min_loss_idx = np.argmin(losses) + suggested_lr = lrs[max(0, min_loss_idx - 5)] # A bit before minimum + plt.axvline(suggested_lr, color='red', linestyle='--', + label=f'Suggested LR: {suggested_lr:.2e}') + plt.legend() + plt.show() + + print(f"\nLR Finder Results:") + print(f" Minimum loss at LR: {lrs[min_loss_idx]:.2e}") + print(f" Suggested starting LR: {suggested_lr:.2e}") + print(f" (Choose LR where loss decreases fastest, before minimum)") + + return lrs, losses + + +def suggest_lr_from_finder(lrs, losses): + """ + Suggest optimal learning rate from LR finder results. + + Strategy: Find LR where loss gradient is steepest (fastest decrease). + """ + # Compute gradient of loss w.r.t. log(LR) + log_lrs = np.log10(lrs) + gradients = np.gradient(losses, log_lrs) + + # Find steepest descent (most negative gradient) + steepest_idx = np.argmin(gradients) + + # Suggested LR is at steepest point or slightly before + suggested_lr = lrs[steepest_idx] + + return suggested_lr +``` + + +## Using LR Finder + +### Basic Usage: + +```python +# Setup model, optimizer, loss +model = YourModel().to(device) +optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # LR will be overridden +loss_fn = torch.nn.CrossEntropyLoss() + +# Run LR finder +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) + +# Manually inspect plot and choose LR +# Look for: steepest descent point (fastest loss decrease) +# Typically: 10x lower than loss minimum + +# Example: If minimum is at 0.1, choose 0.01 as starting LR +base_lr = 0.01 # Based on plot inspection +``` + +### Automated LR Selection: + +```python +# Run LR finder +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) + +# Get suggested LR +suggested_lr = suggest_lr_from_finder(lrs, losses) + +# Use suggested LR +optimizer = torch.optim.SGD(model.parameters(), lr=suggested_lr) +``` + +### Using with OneCycleLR: + +```python +# Find optimal LR +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) +optimal_lr = suggest_lr_from_finder(lrs, losses) # e.g., 0.01 + +# OneCycleLR: Use 5-10x optimal as max_lr +max_lr = optimal_lr * 10 # e.g., 0.1 + +scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=max_lr, + steps_per_epoch=len(train_loader), + epochs=20 +) +``` + + +## Interpreting LR Finder Results + +**Typical Plot Patterns:** + +``` +Loss +| +| X <-- Loss explodes (LR too high) +| X +| X +| X <-- Loss minimum (still too high) +| X +| X <-- CHOOSE HERE (steepest descent) +| X +| X +| X +|X___________ + 1e-8 1e-4 1e-2 0.1 1.0 10 + Learning Rate +``` + +**How to Choose:** + +1. **Steepest Descent (BEST):** + - Find where loss decreases fastest (steepest downward slope) + - This is optimal LR for rapid convergence + - Example: If steepest at 0.01, choose 0.01 + +2. **Before Minimum (SAFE):** + - Find minimum loss LR (e.g., 0.1) + - Choose 10x lower (e.g., 0.01) + - More conservative, safer choice + +3. **Avoid:** + - Don't choose minimum itself (often too high) + - Don't choose where loss is flat (too low, slow progress) + - Don't choose where loss increases (way too high) + +**Guidelines:** +- For SGD: Choose at steepest descent +- For Adam: Choose 10x below steepest (Adam more sensitive) +- For OneCycle: Use steepest as optimal, 5-10x as max_lr + + +## When to Use LR Finder + +**Use LR Finder When:** + +✅ Starting new project (unknown optimal LR) +✅ New architecture or dataset +✅ Tuning OneCycleLR (finding max_lr) +✅ Transitioning between optimizers +✅ Having training instability issues + +**Can Skip When:** + +❌ Following established paper recipe (LR already known) +❌ Fine-tuning (small LR like 1e-5 typically works) +❌ Very constrained time/resources +❌ Using adaptive methods (ReduceLROnPlateau) + +**Best Practice:** +- Run LR finder once at project start +- Use found LR for all subsequent runs +- Re-run if changing optimizer, architecture, or batch size significantly + + +### 6. Scheduler Selection Guide + +## Selection Flowchart + +**1. What's your training duration?** + +- **<10 epochs:** Constant LR or simple linear decay +- **10-30 epochs:** OneCycleLR (fast) or CosineAnnealingLR +- **>30 epochs:** CosineAnnealingLR or MultiStepLR + +**2. What's your model type?** + +- **Transformer (ViT, BERT, GPT):** CosineAnnealing + WARMUP (mandatory) +- **CNN (ResNet, EfficientNet):** MultiStepLR or CosineAnnealing + optional warmup +- **Small model:** Simpler schedulers (StepLR) or constant LR + +**3. Do you know optimal schedule?** + +- **Yes (from paper):** Use paper's schedule (MultiStepLR usually) +- **No (exploring):** ReduceLROnPlateau or CosineAnnealing +- **Want fast results:** OneCycleLR + LR finder + +**4. What's your compute budget?** + +- **High budget (100+ epochs):** CosineAnnealing or MultiStepLR +- **Low budget (10-20 epochs):** OneCycleLR +- **Adaptive budget:** ReduceLROnPlateau (stops when plateau) + + +## Paper Recipe vs Modern Best Practices + +**If goal is EXACT REPRODUCTION:** + +Use paper's exact schedule (down to every detail): + +```python +# Example: Reproducing ResNet paper (He et al., 2015) +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1) +# No warmup (paper didn't use it) +# Train for 100 epochs +``` + +**Rationale:** +- Reproduce results exactly +- Enable apples-to-apples comparison +- Validate paper's claims +- Establish baseline before improvements + +**If goal is BEST PERFORMANCE:** + +Use modern recipe (benefit from years of community learning): + +```python +# Modern equivalent: ResNet with modern practices +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +# Train for 100 epochs +``` + +**Rationale:** +- Typically +0.5-2% better accuracy than original paper +- More stable training +- Reflects 5-10 years of community improvements +- SOTA competitive performance + +**Evolution of LR Scheduling Practices:** + +**Early Deep Learning (2012-2016):** +- Scheduler: StepLR with manual milestones +- Warmup: Not used (not yet discovered) +- Optimizer: SGD with momentum +- Examples: AlexNet, VGG, ResNet, Inception + +**Mid Period (2017-2019):** +- Scheduler: CosineAnnealing introduced, OneCycleLR popular +- Warmup: Starting to be used for large batch training +- Optimizer: SGD still dominant, Adam increasingly common +- Examples: ResNeXt, DenseNet, MobileNet + +**Modern Era (2020-2025):** +- Scheduler: CosineAnnealing default, OneCycle for fast training +- Warmup: Standard practice (mandatory for transformers) +- Optimizer: AdamW increasingly preferred for transformers +- Examples: ViT, EfficientNet, ConvNeXt, Swin, DeiT + +**Practical Workflow:** + +**Step 1: Reproduce paper recipe** +```python +# Use exact paper settings +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) +scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1) +# Should match paper's reported accuracy (e.g., 76.5%) +``` + +**Step 2: Validate reproduction** +- If you get 76.5% (matches paper): ✅ Reproduction successful +- If you get 74% (2% worse): ❌ Implementation bug, fix first +- If you get 78% (2% better): ✅ Great! Proceed to modern recipe + +**Step 3: Try modern recipe** +```python +# Add warmup + cosine +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +# Expect +0.5-2% improvement (e.g., 77-78.5%) +``` + +**Step 4: Compare results** + +| Version | Accuracy | Notes | +|---------|----------|-------| +| Paper recipe | 76.5% | Baseline (reproduces paper) | +| Modern recipe | 78.0% | +1.5% from warmup + cosine | + +**When to Use Which:** + +**Use Paper Recipe:** +- Publishing reproduction study +- Comparing to paper's baseline +- Validating implementation correctness +- Research requiring exact reproducibility + +**Use Modern Recipe:** +- Building production system (want best performance) +- Competing in benchmark (need SOTA results) +- Publishing new method (should use modern baseline) +- Limited compute (modern practices more efficient) + +**Trade-off Table:** + +| Aspect | Paper Recipe | Modern Recipe | +|--------|--------------|---------------| +| Reproducibility | ✅ Exact | ⚠️ Better but different | +| Performance | ⚠️ Good (for its time) | ✅ Better (+0.5-2%) | +| Comparability | ✅ To paper | ✅ To SOTA | +| Compute efficiency | ⚠️ May be suboptimal | ✅ Modern optimizations | +| Training stability | ⚠️ Variable | ✅ More stable (warmup) | + +**Bottom Line:** + +Both are valid depending on your goal: +- **Research/reproduction:** Start with paper recipe +- **Production/competition:** Use modern recipe +- **Best practice:** Validate with paper recipe, deploy with modern recipe + + +## Domain-Specific Recommendations + +### Image Classification (CNNs) + +**Standard Recipe (ResNet, VGG):** +```python +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1) +# Train for 100 epochs +``` + +**Modern Recipe (EfficientNet, RegNet):** +```python +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +# Train for 100 epochs +``` + +### Vision Transformers (ViT, Swin, DeiT) + +**Standard Recipe:** +```python +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=10) +cosine = CosineAnnealingLR(optimizer, T_max=290, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [10]) +# Train for 300 epochs +# WARMUP IS MANDATORY +``` + +### NLP Transformers (BERT, GPT, T5) + +**Standard Recipe:** +```python +optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01) + +# Linear warmup + linear decay +def lr_lambda(step): + warmup_steps = 10000 + total_steps = 100000 + if step < warmup_steps: + return step / warmup_steps + else: + return max(0.0, (total_steps - step) / (total_steps - warmup_steps)) + +scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) +# Step every batch, not epoch +# WARMUP IS MANDATORY +``` + +### Object Detection (Faster R-CNN, YOLO) + +**Standard Recipe:** +```python +optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=1e-4) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[16, 22], gamma=0.1) +# Train for 26 epochs +``` + +### Fast Training (Limited Compute) + +**FastAI Recipe:** +```python +# Run LR finder first +optimal_lr = find_lr(model, train_loader, optimizer, loss_fn, device) +max_lr = optimal_lr * 10 + +optimizer = torch.optim.SGD(model.parameters(), lr=max_lr) +scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=max_lr, + steps_per_epoch=len(train_loader), + epochs=20, + pct_start=0.3 +) +# Train for 20 epochs +# Step every batch +``` + + +### 7. Common Scheduling Pitfalls + +## Pitfall 1: No Warmup for Transformers + +**WRONG:** +```python +# Training Vision Transformer +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) +# ❌ No warmup - training will be very unstable or diverge +``` + +**RIGHT:** +```python +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +# ✅ Warmup prevents early instability +``` + +**Why It Matters:** +- Transformers with high LR at start → NaN loss, divergence +- Random initialization needs gradual LR ramp +- 5-10 epoch warmup is STANDARD practice + +**How to Detect:** +- Loss is NaN or explodes in first few epochs +- Training very unstable early, stabilizes later +- Gradients extremely large at start + + +## Pitfall 2: Wrong scheduler.step() Placement + +**WRONG (Most Schedulers):** +```python +for epoch in range(epochs): + for batch in train_loader: + loss = train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # ❌ Stepping every batch, not every epoch +``` + +**RIGHT:** +```python +for epoch in range(epochs): + for batch in train_loader: + loss = train_step(model, batch, optimizer) + optimizer.step() + + scheduler.step() # ✅ Step AFTER each epoch +``` + +**EXCEPTION (OneCycleLR):** +```python +for epoch in range(epochs): + for batch in train_loader: + loss = train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # ✅ OneCycle steps EVERY BATCH +``` + +**Why It Matters:** +- CosineAnnealing with T_max=100 expects 100 steps (epochs) +- Stepping every batch: If 390 batches/epoch, LR decays in <1 epoch +- LR reaches minimum way too fast + +**How to Detect:** +- LR decays to minimum in first epoch +- Print LR each step: `print(optimizer.param_groups[0]['lr'])` +- Check if LR changes every batch (wrong) vs every epoch (right) + +**Rule:** +- **Most schedulers (Step, Cosine, Exponential):** Step per epoch +- **OneCycleLR only:** Step per batch +- **ReduceLROnPlateau:** Step per epoch with validation metric + + +## Pitfall 3: scheduler.step() Before optimizer.step() + +**WRONG:** +```python +loss.backward() +scheduler.step() # ❌ Wrong order +optimizer.step() +``` + +**RIGHT:** +```python +loss.backward() +optimizer.step() # ✅ Update weights first +scheduler.step() # Then update LR +``` + +**Why It Matters:** +- Scheduler updates LR based on current epoch/step +- Should update weights with current LR, THEN move to next LR +- Wrong order = off-by-one error in schedule + +**How to Detect:** +- Usually subtle, hard to notice +- Best practice: always optimizer.step() then scheduler.step() + + +## Pitfall 4: Not Passing Metric to ReduceLROnPlateau + +**WRONG:** +```python +scheduler = ReduceLROnPlateau(optimizer) +for epoch in range(epochs): + train_loss = train_one_epoch(model, train_loader, optimizer) + scheduler.step() # ❌ No metric passed +``` + +**RIGHT:** +```python +scheduler = ReduceLROnPlateau(optimizer, mode='min') +for epoch in range(epochs): + train_loss = train_one_epoch(model, train_loader, optimizer) + val_loss = validate(model, val_loader) + scheduler.step(val_loss) # ✅ Pass validation metric +``` + +**Why It Matters:** +- ReduceLROnPlateau NEEDS metric to detect plateau +- Without metric, scheduler doesn't know when to reduce LR +- Will get error or incorrect behavior + +**How to Detect:** +- Error message: "ReduceLROnPlateau needs a metric" +- LR never reduces even when training plateaus + + +## Pitfall 5: Using OneCycle for Long Training + +**SUBOPTIMAL:** +```python +# Training for 200 epochs +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=200, steps_per_epoch=len(train_loader)) +# ❌ OneCycle designed for shorter training (10-30 epochs) +``` + +**BETTER:** +```python +# For long training, use Cosine +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=10) +cosine = CosineAnnealingLR(optimizer, T_max=190, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [10]) +# ✅ Cosine better suited for long training +``` + +**Why It Matters:** +- OneCycle's aggressive up-then-down profile works for short training +- For long training, gentler cosine decay more stable +- OneCycle typically used for 10-30 epochs in FastAI style + +**When to Use Each:** +- **OneCycle:** 10-30 epochs, limited compute, want fast results +- **Cosine:** 50+ epochs, full training, want best final performance + + +## Pitfall 6: Not Tuning max_lr for OneCycle + +**WRONG:** +```python +# Just guessing max_lr +scheduler = OneCycleLR(optimizer, max_lr=0.1, epochs=20, steps_per_epoch=len(train_loader)) +# ❌ Random max_lr without tuning +# Might be too high (unstable) or too low (slow) +``` + +**RIGHT:** +```python +# Step 1: Run LR finder +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) +optimal_lr = suggest_lr_from_finder(lrs, losses) # e.g., 0.01 + +# Step 2: Use 5-10x optimal as max_lr +max_lr = optimal_lr * 10 # e.g., 0.1 + +scheduler = OneCycleLR(optimizer, max_lr=max_lr, epochs=20, steps_per_epoch=len(train_loader)) +# ✅ Tuned max_lr based on LR finder +``` + +**Why It Matters:** +- OneCycle is VERY sensitive to max_lr +- Too high: Training unstable, loss explodes +- Too low: Slow training, underperforms +- LR finder finds optimal, use 5-10x as max_lr + +**How to Tune:** +1. Run LR finder (see LR Finder section) +2. Find optimal LR (steepest descent point) +3. Use 5-10x optimal as max_lr for OneCycle +4. If still unstable, reduce max_lr (try 3x, 2x) + + +## Pitfall 7: Forgetting to Adjust T_max After Adding Warmup + +**WRONG:** +```python +# Want 100 epoch training +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=100) # ❌ Should be 95 +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +``` + +**RIGHT:** +```python +# Want 100 epoch training +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95) # ✅ 100 - 5 = 95 +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) +``` + +**Why It Matters:** +- Total training is warmup + main schedule +- If warmup is 5 epochs and cosine is 100, total is 105 epochs +- T_max should be (total_epochs - warmup_epochs) + +**How to Calculate:** +```python +total_epochs = 100 +warmup_epochs = 5 +T_max = total_epochs - warmup_epochs # 95 +``` + + +## Pitfall 8: Using Same LR for All Param Groups + +**SUBOPTIMAL:** +```python +# Fine-tuning: applying same LR to all layers +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +# ❌ Backbone and head both use 1e-3 +``` + +**BETTER:** +```python +# Fine-tuning: lower LR for pretrained backbone, higher for new head +optimizer = torch.optim.Adam([ + {'params': model.backbone.parameters(), 'lr': 1e-4}, # Lower LR for pretrained + {'params': model.head.parameters(), 'lr': 1e-3} # Higher LR for random init +]) +scheduler = CosineAnnealingLR(optimizer, T_max=100) +# ✅ Scheduler applies to all param groups proportionally +``` + +**Why It Matters:** +- Pretrained layers need smaller LR (already trained) +- New layers need higher LR (random initialization) +- Schedulers work with param groups automatically + +**Note:** Schedulers multiply all param groups by same factor, preserving relative ratios + + +## Pitfall 9: Not Monitoring LR During Training + +**PROBLEM:** +- Schedule not behaving as expected +- Hard to debug without visibility into LR + +**SOLUTION:** +```python +# Log LR every epoch +for epoch in range(epochs): + current_lr = optimizer.param_groups[0]['lr'] + print(f"Epoch {epoch}: LR = {current_lr:.6f}") + + train_one_epoch(model, train_loader, optimizer) + scheduler.step() + +# Or use TensorBoard +from torch.utils.tensorboard import SummaryWriter +writer = SummaryWriter() + +for epoch in range(epochs): + current_lr = optimizer.param_groups[0]['lr'] + writer.add_scalar('Learning Rate', current_lr, epoch) + + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Best Practice:** +- Always log LR to console or TensorBoard +- Plot LR schedule before training (see next section) +- Verify schedule matches expectations + + +## Pitfall 10: Not Validating Schedule Before Training + +**PROBLEM:** +- Run full training, discover schedule was wrong +- Waste compute on incorrect schedule + +**SOLUTION: Dry-run the schedule:** +```python +def plot_schedule(scheduler_fn, num_epochs): + """ + Plot LR schedule before training to verify it's correct. + """ + # Create dummy model and optimizer + model = torch.nn.Linear(1, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + scheduler = scheduler_fn(optimizer) + + lrs = [] + for epoch in range(num_epochs): + lrs.append(optimizer.param_groups[0]['lr']) + optimizer.step() # Dummy step + scheduler.step() + + # Plot + plt.figure(figsize=(10, 6)) + plt.plot(lrs) + plt.xlabel('Epoch') + plt.ylabel('Learning Rate') + plt.title('LR Schedule') + plt.grid(True, alpha=0.3) + plt.show() + +# Usage +def my_scheduler(opt): + warmup = LinearLR(opt, start_factor=0.01, total_iters=5) + cosine = CosineAnnealingLR(opt, T_max=95) + return SequentialLR(opt, [warmup, cosine], [5]) + +plot_schedule(my_scheduler, num_epochs=100) +# Verify plot looks correct BEFORE training +``` + +**Best Practice:** +- Plot schedule before every major training run +- Verify warmup duration, decay shape, final LR +- Catch mistakes early (T_max wrong, step placement, etc.) + + +### 8. Modern Best Practices (2024-2025) + +## Vision Models (CNNs, ResNets, ConvNeXt) + +**Standard Recipe:** +```python +# Optimizer +optimizer = torch.optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4 +) + +# Scheduler: MultiStepLR or CosineAnnealing +# Option 1: MultiStepLR (classical) +scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1) + +# Option 2: CosineAnnealing (modern) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) + +# Training +epochs = 100 +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Key Points:** +- SGD with momentum (0.9) is standard for CNNs +- LR = 0.1 for batch size 256 (scale linearly for other batch sizes) +- Warmup optional but beneficial (5 epochs) +- CosineAnnealing increasingly preferred over MultiStepLR + + +## Vision Transformers (ViT, Swin, DeiT) + +**Standard Recipe:** +```python +# Optimizer +optimizer = torch.optim.AdamW( + model.parameters(), + lr=1e-3, + weight_decay=0.05, + betas=(0.9, 0.999) +) + +# Scheduler: MUST include warmup +warmup_epochs = 10 +cosine_epochs = 290 +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs) +cosine = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [warmup_epochs]) + +# Training +epochs = 300 +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Key Points:** +- AdamW optimizer (not SGD) +- Warmup is MANDATORY (10-20 epochs) +- Long training (300 epochs typical) +- LR = 1e-3 for batch size 512 (scale for other sizes) +- Cosine decay to very small LR (1e-5) + +**Why Warmup is Critical for ViT:** +- Self-attention layers highly sensitive to initialization +- High LR at start causes gradient explosion +- Warmup allows attention patterns to stabilize + + +## NLP Transformers (BERT, GPT, T5) + +**Standard Recipe:** +```python +# Optimizer +optimizer = torch.optim.AdamW( + model.parameters(), + lr=5e-4, + weight_decay=0.01, + betas=(0.9, 0.999) +) + +# Scheduler: Linear warmup + linear decay (or inverse sqrt) +total_steps = len(train_loader) * epochs +warmup_steps = int(0.1 * total_steps) # 10% warmup + +def lr_lambda(step): + if step < warmup_steps: + return step / warmup_steps + else: + return max(0.0, (total_steps - step) / (total_steps - warmup_steps)) + +scheduler = LambdaLR(optimizer, lr_lambda) + +# Training: step EVERY BATCH +for epoch in range(epochs): + for batch in train_loader: + train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # Step every batch, not epoch +``` + +**Key Points:** +- AdamW optimizer +- Warmup is MANDATORY (typically 10% of training) +- Linear warmup + linear decay (BERT, GPT-2 style) +- Step scheduler EVERY BATCH (not every epoch) +- LR typically 1e-4 to 5e-4 + +**Alternative: Inverse Square Root (Original Transformer):** +```python +def transformer_schedule(step): + warmup_steps = 4000 + step = step + 1 + return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5) + +scheduler = LambdaLR(optimizer, transformer_schedule) +``` + + +## Object Detection (Faster R-CNN, YOLO, DETR) + +**Standard Recipe (Two-stage detectors):** +```python +# Optimizer +optimizer = torch.optim.SGD( + model.parameters(), + lr=0.02, + momentum=0.9, + weight_decay=1e-4 +) + +# Scheduler: MultiStepLR with short schedule +scheduler = MultiStepLR(optimizer, milestones=[16, 22], gamma=0.1) + +# Training +epochs = 26 # Shorter than classification +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Standard Recipe (Transformer detectors like DETR):** +```python +# Optimizer +optimizer = torch.optim.AdamW( + [ + {'params': model.backbone.parameters(), 'lr': 1e-5}, # Lower for backbone + {'params': model.transformer.parameters(), 'lr': 1e-4} # Higher for transformer + ], + weight_decay=1e-4 +) + +# Scheduler: Step decay +scheduler = MultiStepLR(optimizer, milestones=[200], gamma=0.1) + +# Training: Long schedule for DETR +epochs = 300 +``` + +**Key Points:** +- Detection typically shorter training than classification +- Lower LR (0.02 vs 0.1) due to task difficulty +- DETR needs very long training (300 epochs) + + +## Semantic Segmentation (U-Net, DeepLab, SegFormer) + +**Standard Recipe (CNN-based):** +```python +# Optimizer +optimizer = torch.optim.SGD( + model.parameters(), + lr=0.01, + momentum=0.9, + weight_decay=1e-4 +) + +# Scheduler: Polynomial decay (common in segmentation) +def poly_lr_lambda(epoch): + return (1 - epoch / total_epochs) ** 0.9 + +scheduler = LambdaLR(optimizer, poly_lr_lambda) + +# Training +epochs = 100 +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Key Points:** +- Polynomial decay common in segmentation (DeepLab papers) +- Lower initial LR (0.01) than classification +- Power of 0.9 standard + + +## Fast Training / Limited Compute (FastAI Style) + +**OneCycle Recipe:** +```python +# Step 1: Find optimal LR +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) +optimal_lr = suggest_lr_from_finder(lrs, losses) # e.g., 0.01 +max_lr = optimal_lr * 10 # e.g., 0.1 + +# Step 2: OneCycleLR +optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=0.9) +scheduler = OneCycleLR( + optimizer, + max_lr=max_lr, + steps_per_epoch=len(train_loader), + epochs=20, + pct_start=0.3, # 30% warmup, 70% cooldown + anneal_strategy='cos' +) + +# Step 3: Train (step every batch) +for epoch in range(20): + for batch in train_loader: + train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # Every batch +``` + +**Key Points:** +- Use LR finder to tune max_lr (CRITICAL) +- Train for fewer epochs (10-30) +- Step scheduler every batch +- Often achieves 90-95% of full training performance in 20-30% of time + + +## Fine-Tuning Pretrained Models + +**Standard Recipe:** +```python +# Optimizer: Different LRs for backbone vs head +optimizer = torch.optim.AdamW([ + {'params': model.backbone.parameters(), 'lr': 1e-5}, # Very low for pretrained + {'params': model.head.parameters(), 'lr': 1e-3} # Higher for new head +]) + +# Scheduler: Simple cosine or even constant +# Option 1: Constant LR (fine-tuning often doesn't need scheduling) +scheduler = None + +# Option 2: Gentle cosine decay +scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6) + +# Training: Short duration +epochs = 10 # Fine-tuning is quick +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + if scheduler: + scheduler.step() +``` + +**Key Points:** +- Much lower LR for pretrained parts (1e-5) +- Higher LR for new/random parts (1e-3) +- Short training (3-10 epochs) +- Scheduling often optional (constant LR works) +- No warmup needed (weights already good) + + +## Large Batch Training (Batch Size > 1024) + +**Standard Recipe:** +```python +# Linear LR scaling rule: LR scales with batch size +base_lr = 0.1 # For batch size 256 +batch_size = 2048 +scaled_lr = base_lr * (batch_size / 256) # 0.8 for batch 2048 + +# Optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr, momentum=0.9) + +# Scheduler: MUST include warmup (critical for large batch) +warmup_epochs = 5 +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [warmup_epochs]) + +# Training +epochs = 100 +for epoch in range(epochs): + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Key Points:** +- Scale LR linearly with batch size (LR = base_lr * batch_size / base_batch_size) +- Warmup is MANDATORY for large batch (5-10 epochs minimum) +- Longer warmup for very large batches (>4096: use 10-20 epochs) + +**Why Warmup Critical for Large Batch:** +- Large batch = larger effective LR +- High effective LR at start causes instability +- Warmup prevents divergence + + +## Modern Defaults by Domain (2025) + +| Domain | Optimizer | Scheduler | Warmup | Epochs | +|--------|-----------|-----------|--------|--------| +| Vision (CNN) | SGD (0.9) | Cosine or MultiStep | Optional (5) | 100-200 | +| Vision (ViT) | AdamW | Cosine | MANDATORY (10-20) | 300 | +| NLP (BERT/GPT) | AdamW | Linear | MANDATORY (10%) | Varies | +| Detection | SGD | MultiStep | Optional | 26-300 | +| Segmentation | SGD | Polynomial | Optional | 100 | +| Fast/OneCycle | SGD | OneCycle | Built-in | 10-30 | +| Fine-tuning | AdamW | Constant/Cosine | No | 3-10 | +| Large Batch | SGD | Cosine | MANDATORY (5-20) | 100-200 | + + +### 9. Debugging Scheduler Issues + +## Issue: Training Unstable / Loss Spikes + +**Symptoms:** +- Loss increases suddenly during training +- NaN or Inf loss +- Training was stable, then becomes unstable + +**Likely Causes:** + +1. **No warmup (transformers, large models)** + - Solution: Add 5-10 epoch warmup + +2. **LR too high at start** + - Solution: Lower initial LR or extend warmup + +3. **LR drop too sharp (MultiStepLR)** + - Solution: Use gentler scheduler (Cosine) or smaller gamma + +**Debugging Steps:** + +```python +# 1. Print LR every epoch +for epoch in range(epochs): + current_lr = optimizer.param_groups[0]['lr'] + print(f"Epoch {epoch}: LR = {current_lr:.6e}") + + # 2. Check if loss spike correlates with LR change + loss = train_one_epoch(model, train_loader, optimizer) + print(f" Loss = {loss:.4f}") + + scheduler.step() + +# 3. Plot LR and loss together +import matplotlib.pyplot as plt +plt.figure(figsize=(12, 5)) +plt.subplot(1, 2, 1) +plt.plot(lr_history) +plt.xlabel('Epoch') +plt.ylabel('Learning Rate') +plt.subplot(1, 2, 2) +plt.plot(loss_history) +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.show() +``` + +**Solutions:** + +- Add/extend warmup: `LinearLR(optimizer, start_factor=0.01, total_iters=10)` +- Lower initial LR: `lr = 0.01` instead of `lr = 0.1` +- Gentler scheduler: `CosineAnnealingLR` instead of `MultiStepLR` +- Gradient clipping: `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` + + +## Issue: Training Plateaus Too Early + +**Symptoms:** +- Loss stops decreasing after 20-30 epochs +- Validation accuracy flat +- Training seems stuck + +**Likely Causes:** + +1. **Not using scheduler (constant LR too high for current regime)** + - Solution: Add scheduler (CosineAnnealing or ReduceLROnPlateau) + +2. **Scheduler reducing LR too early** + - Solution: Push back milestones or increase patience + +3. **LR already too low** + - Solution: Check current LR, may need to restart with higher initial LR + +**Debugging Steps:** + +```python +# Check current LR +current_lr = optimizer.param_groups[0]['lr'] +print(f"Current LR: {current_lr:.6e}") + +# If LR very low (<1e-6), plateau might be due to other issues (architecture, data, etc.) +# If LR still high (>1e-3), should reduce LR to break plateau +``` + +**Solutions:** + +- Add ReduceLROnPlateau: Automatically reduces when plateau detected + ```python + scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10) + ``` + +- Manual LR reduction: If at epoch 30 and plateaued, reduce LR by 10x now + ```python + for param_group in optimizer.param_groups: + param_group['lr'] *= 0.1 + ``` + +- Use scheduler from start next time: + ```python + scheduler = CosineAnnealingLR(optimizer, T_max=100) + ``` + + +## Issue: Poor Final Performance (Train > Val Gap) + +**Symptoms:** +- Training accuracy high (95%), validation lower (88%) +- Model overfitting +- Test performance disappointing + +**Likely Causes (Scheduling Related):** + +1. **LR not low enough at end** + - Solution: Lower eta_min or extend training + +2. **Not using scheduler (constant LR doesn't fine-tune)** + - Solution: Add scheduler to reduce LR in late training + +3. **Scheduler ending too early** + - Solution: Extend training or adjust T_max + +**Debugging Steps:** + +```python +# Check final LR +final_lr = optimizer.param_groups[0]['lr'] +print(f"Final LR: {final_lr:.6e}") + +# Final LR should be very low (1e-5 to 1e-6) +# If final LR still high (>1e-3), model didn't fine-tune properly +``` + +**Solutions:** + +- Lower eta_min: `CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)` +- Extend training: Train for more epochs to allow LR to decay further +- Add late-stage fine-tuning: + ```python + # After main training, do 10 more epochs with very low LR + for param_group in optimizer.param_groups: + param_group['lr'] = 1e-5 + for epoch in range(10): + train_one_epoch(model, train_loader, optimizer) + ``` + +**Note:** If train-val gap large, may also need regularization (not scheduling issue) + + +## Issue: LR Decays Too Fast + +**Symptoms:** +- LR reaches minimum in first few epochs +- Training very slow after initial epochs +- Looks like constant very low LR + +**Likely Causes:** + +1. **scheduler.step() called every batch instead of epoch** + - Solution: Move scheduler.step() outside batch loop + +2. **T_max too small (e.g., T_max=10 but training for 100 epochs)** + - Solution: Set T_max = total_epochs + +3. **Using OneCycle unintentionally** + - Solution: Verify scheduler type + +**Debugging Steps:** + +```python +# Print LR first few epochs +for epoch in range(10): + print(f"Epoch {epoch}: LR = {optimizer.param_groups[0]['lr']:.6e}") + for batch in train_loader: + train_step(model, batch, optimizer) + # scheduler.step() # ❌ If this is here, that's the bug + scheduler.step() # ✅ Should be here +``` + +**Solutions:** + +- Move scheduler.step() to correct location (after epoch, not after batch) +- Fix T_max: `T_max = total_epochs` or `T_max = total_epochs - warmup_epochs` +- Verify scheduler type: `print(type(scheduler))` + + +## Issue: OneCycleLR Not Working + +**Symptoms:** +- Training with OneCycle becomes unstable around peak LR +- Loss increases during ramp-up phase +- Worse performance than expected + +**Likely Causes:** + +1. **max_lr too high** + - Solution: Run LR finder, use lower max_lr + +2. **scheduler.step() placement wrong (should be per batch)** + - Solution: Call scheduler.step() every batch + +3. **Not tuning max_lr** + - Solution: Use LR finder to find optimal, use 5-10x as max_lr + +**Debugging Steps:** + +```python +# Plot LR schedule +lrs = [] +for epoch in range(epochs): + for batch in train_loader: + lrs.append(optimizer.param_groups[0]['lr']) + scheduler.step() + +plt.plot(lrs) +plt.xlabel('Batch') +plt.ylabel('Learning Rate') +plt.title('OneCycle LR Schedule') +plt.show() + +# Should see: ramp up to max_lr, then ramp down +# If doesn't look like that, scheduler.step() placement wrong +``` + +**Solutions:** + +- Run LR finder first: + ```python + optimal_lr = find_lr(model, train_loader, optimizer, loss_fn, device) + max_lr = optimal_lr * 10 # Or try 5x, 3x if 10x unstable + ``` + +- Lower max_lr manually: + ```python + # If max_lr=0.1 unstable, try 0.03 or 0.01 + scheduler = OneCycleLR(optimizer, max_lr=0.03, ...) + ``` + +- Verify step() every batch: + ```python + for epoch in range(epochs): + for batch in train_loader: + train_step(model, batch, optimizer) + optimizer.step() + scheduler.step() # ✅ Every batch + ``` + + +## Issue: Warmup Not Working + +**Symptoms:** +- Training still unstable in first few epochs despite warmup +- Loss spikes even with warmup +- NaN loss at start + +**Likely Causes:** + +1. **Warmup too short (need longer ramp-up)** + - Solution: Extend warmup from 5 to 10-20 epochs + +2. **start_factor too high (not starting low enough)** + - Solution: Use start_factor=0.001 instead of 0.01 + +3. **Warmup not actually being used (SequentialLR bug)** + - Solution: Verify warmup scheduler is active early + +**Debugging Steps:** + +```python +# Print LR first 10 epochs +for epoch in range(10): + current_lr = optimizer.param_groups[0]['lr'] + print(f"Epoch {epoch}: LR = {current_lr:.6e}") + # Should see gradual increase from low to high + # If jumps immediately to high, warmup not working + + train_one_epoch(model, train_loader, optimizer) + scheduler.step() +``` + +**Solutions:** + +- Extend warmup: + ```python + warmup = LinearLR(optimizer, start_factor=0.01, total_iters=20) # 20 epochs + ``` + +- Lower start_factor: + ```python + warmup = LinearLR(optimizer, start_factor=0.001, total_iters=5) # Start at 0.1% + ``` + +- Verify SequentialLR milestone: + ```python + # Milestone should match warmup duration + scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[20]) + ``` + +- Add gradient clipping as additional safeguard: + ```python + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + ``` + + +## Issue: ReduceLROnPlateau Never Reduces LR + +**Symptoms:** +- Using ReduceLROnPlateau for 50+ epochs +- Validation loss clearly plateaued +- Learning rate never reduces + +**Debugging Steps:** + +**1. Verify metric is being passed:** + +```python +val_loss = validate(model, val_loader) +print(f"Epoch {epoch}: val_loss = {val_loss:.6f}") # Print metric +scheduler.step(val_loss) # Ensure passing metric +``` + +**2. Check mode is correct:** + +```python +# For loss (want to minimize): +scheduler = ReduceLROnPlateau(optimizer, mode='min') + +# For accuracy (want to maximize): +scheduler = ReduceLROnPlateau(optimizer, mode='max') +``` + +Wrong mode means scheduler waits for opposite direction (loss increasing instead of decreasing). + +**3. Check threshold isn't too strict:** + +```python +# Default threshold=1e-4 (0.01% improvement threshold) +# If val_loss 0.5000 → 0.4999 (0.02% improvement), counts as improvement +# If threshold too high, tiny improvements prevent reduction + +# Solution: Lower threshold to be more sensitive +scheduler = ReduceLROnPlateau(optimizer, threshold=1e-5) + +# Or remove threshold entirely +scheduler = ReduceLROnPlateau(optimizer, threshold=0) +``` + +**4. Enable verbose logging:** + +```python +scheduler = ReduceLROnPlateau(optimizer, verbose=True) +# Prints: "Epoch 00042: reducing learning rate of group 0 to 1.0000e-04" +# when it reduces +``` + +**5. Verify plateau is real:** + +```python +# Plot validation loss over time +import matplotlib.pyplot as plt +plt.figure(figsize=(10, 6)) +plt.plot(val_losses) +plt.xlabel('Epoch') +plt.ylabel('Validation Loss') +plt.title('Validation Loss Over Time') +plt.grid(True, alpha=0.3) +plt.show() + +# Check: Is loss truly flat, or still slowly improving? +# Tiny improvements (0.4500 → 0.4499) count as progress +``` + +**6. Check cooldown isn't preventing reduction:** + +```python +# Default cooldown=0, but if set higher, prevents reduction after recent reduction +scheduler = ReduceLROnPlateau(optimizer, cooldown=0) # No cooldown +``` + +**Common Causes Table:** + +| Problem | Symptom | Solution | +|---------|---------|----------| +| Not passing metric | Error or no reduction | `scheduler.step(val_loss)` | +| Wrong mode | Never reduces | `mode='min'` for loss, `mode='max'` for accuracy | +| Threshold too strict | Ignores small improvements | Lower to `threshold=1e-5` or `0` | +| Metric still improving | Not actually plateaued | Increase patience or accept slow progress | +| Cooldown active | Reducing but waiting | Set `cooldown=0` | +| Min_lr reached | Can't reduce further | Check current LR, may be at min_lr | + +**Example Fix:** + +```python +scheduler = ReduceLROnPlateau( + optimizer, + mode='min', # For loss minimization + factor=0.1, # Reduce by 10x + patience=10, # Wait 10 epochs + threshold=0, # Accept any improvement (most sensitive) + threshold_mode='rel', + cooldown=0, # No cooldown period + min_lr=1e-6, # Minimum LR allowed + verbose=True # Print when reducing +) + +# Training loop +for epoch in range(epochs): + train_loss = train_one_epoch(model, train_loader, optimizer) + val_loss = validate(model, val_loader) + + print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}") + + scheduler.step(val_loss) # Pass validation loss + + # Print current LR + current_lr = optimizer.param_groups[0]['lr'] + print(f" Current LR: {current_lr:.6e}") +``` + +**Advanced Debugging:** + +If still not reducing, manually check scheduler logic: + +```python +# Get scheduler state +print(f"Best metric so far: {scheduler.best}") +print(f"Epochs without improvement: {scheduler.num_bad_epochs}") +print(f"Patience: {scheduler.patience}") + +# If num_bad_epochs < patience, it's still waiting +# If num_bad_epochs >= patience, should reduce next step +``` + + +### 10. Rationalization Table + +When users rationalize away proper LR scheduling, counter with: + +| Rationalization | Reality | Counter-Argument | +|-----------------|---------|------------------| +| "Constant LR is simpler" | Leaves 2-5% performance on table | "One line of code for 2-5% better accuracy is excellent ROI" | +| "Warmup seems optional" | MANDATORY for transformers | "Without warmup, transformers diverge or train unstably" | +| "I don't know which scheduler to use" | CosineAnnealing is great default | "CosineAnnealingLR works well for most cases, zero tuning" | +| "Scheduling is too complicated" | Modern frameworks make it trivial | "scheduler = CosineAnnealingLR(optimizer, T_max=100) - that's it" | +| "Papers don't mention scheduling" | They do, in implementation details | "Check paper's code repo or appendix - scheduling always there" | +| "My model is too small to need scheduling" | Even small models benefit | "Scheduling helps all models converge to better minima" | +| "Just use Adam, it adapts automatically" | Adam still benefits from scheduling | "SOTA transformers use AdamW + scheduling (BERT, GPT, ViT)" | +| "I'll tune it later" | Scheduling should be from start | "Scheduling is core hyperparameter, not optional add-on" | +| "OneCycle always best" | Only for specific scenarios | "OneCycle great for fast training (<30 epochs), not long training" | +| "I don't have time to run LR finder" | Takes 5 minutes, saves hours | "LR finder runs in minutes, prevents wasted training runs" | +| "Warmup adds complexity" | One extra line of code | "SequentialLR([warmup, cosine], [5]) - that's the complexity" | +| "My training is already good enough" | Could be 2-5% better | "SOTA papers all use scheduling - it's standard practice" | +| "Reducing LR will slow training" | Reduces LR when high LR hurts | "High LR early (fast), low LR late (fine-tune) = best of both" | +| "I don't know what T_max to use" | T_max = total_epochs | "Just set T_max to your total training epochs" | + + +### 11. Red Flags Checklist + +Watch for these warning signs that indicate scheduling problems: + +**Critical Red Flags (Fix Immediately):** + +🚨 Training transformer without warmup + - **Impact:** High risk of divergence, NaN loss + - **Fix:** Add 5-10 epoch warmup immediately + +🚨 Loss NaN or exploding in first few epochs + - **Impact:** Training failed + - **Fix:** Add warmup, lower initial LR, gradient clipping + +🚨 scheduler.step() called every batch for Cosine/Step schedulers + - **Impact:** LR decays 100x too fast + - **Fix:** Move scheduler.step() outside batch loop + +🚨 Not passing metric to ReduceLROnPlateau + - **Impact:** Scheduler doesn't work at all + - **Fix:** scheduler.step(val_loss) + +**Important Red Flags (Should Fix):** + +⚠️ Training >30 epochs without scheduler + - **Impact:** Leaving 2-5% performance on table + - **Fix:** Add CosineAnnealingLR or MultiStepLR + +⚠️ OneCycle with random max_lr (not tuned) + - **Impact:** Unstable training or suboptimal performance + - **Fix:** Run LR finder, tune max_lr + +⚠️ Large batch (>512) without warmup + - **Impact:** Training instability + - **Fix:** Add 5-10 epoch warmup + +⚠️ Vision transformer with constant LR + - **Impact:** Poor convergence, unstable training + - **Fix:** Add warmup + cosine schedule + +⚠️ Training plateaus but no scheduler to reduce LR + - **Impact:** Stuck at local minimum + - **Fix:** Add ReduceLROnPlateau or manually reduce LR + +**Minor Red Flags (Consider Fixing):** + +⚡ CNN training without any scheduling + - **Impact:** Missing 1-3% accuracy + - **Fix:** Add MultiStepLR or CosineAnnealingLR + +⚡ Not monitoring LR during training + - **Impact:** Hard to debug schedule issues + - **Fix:** Log LR every epoch + +⚡ T_max doesn't match training duration + - **Impact:** Schedule ends too early/late + - **Fix:** Set T_max = total_epochs - warmup_epochs + +⚡ Using same LR for pretrained and new layers (fine-tuning) + - **Impact:** Suboptimal fine-tuning + - **Fix:** Use different LRs for param groups + +⚡ Not validating schedule before full training + - **Impact:** Risk wasting compute on wrong schedule + - **Fix:** Plot schedule dry-run before training + + +### 12. Quick Reference + +## Scheduler Selection Cheatsheet + +``` +Q: What should I use for... + +Vision CNN (100 epochs)? +→ CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5) + +Vision Transformer? +→ LinearLR(warmup 5) + CosineAnnealingLR(T_max=95) [WARMUP MANDATORY] + +NLP Transformer? +→ LinearLR(warmup 10%) + LinearLR(decay) [WARMUP MANDATORY] + +Fast training (<30 epochs)? +→ OneCycleLR(max_lr=tune_with_LR_finder) + +Don't know optimal schedule? +→ ReduceLROnPlateau(mode='min', patience=10) + +Training plateaued? +→ Add ReduceLROnPlateau or manually reduce LR by 10x now + +Following paper recipe? +→ Use paper's exact schedule (usually MultiStepLR) + +Fine-tuning pretrained model? +→ Constant low LR (1e-5) or gentle CosineAnnealing + +Large batch (>512)? +→ LinearLR(warmup 5-10) + CosineAnnealingLR [WARMUP MANDATORY] +``` + + +## Step Placement Quick Reference + +```python +# Most schedulers (Step, Cosine, Exponential) +for epoch in range(epochs): + for batch in train_loader: + train_step(...) + scheduler.step() # AFTER epoch + +# OneCycleLR (EXCEPTION) +for epoch in range(epochs): + for batch in train_loader: + train_step(...) + scheduler.step() # AFTER each batch + +# ReduceLROnPlateau (pass metric) +for epoch in range(epochs): + for batch in train_loader: + train_step(...) + val_loss = validate(...) + scheduler.step(val_loss) # Pass metric +``` + + +## Warmup Quick Reference + +```python +# Pattern: Warmup + Cosine (most common) +warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5) +cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-5) +scheduler = SequentialLR(optimizer, [warmup, cosine], [5]) + +# When warmup is MANDATORY: +# ✅ Transformers (ViT, BERT, GPT) +# ✅ Large batch (>512) +# ✅ High initial LR +# ✅ Training from scratch + +# When warmup is optional: +# ❌ Fine-tuning +# ❌ Small LR (<1e-4) +# ❌ Small models +``` + + +## LR Finder Quick Reference + +```python +# Run LR finder +lrs, losses = find_lr(model, train_loader, optimizer, loss_fn, device) + +# Find optimal (steepest descent) +optimal_lr = suggest_lr_from_finder(lrs, losses) + +# Use cases: +# - Direct use: optimizer = SGD(params, lr=optimal_lr) +# - OneCycle: max_lr = optimal_lr * 10 +# - Conservative: base_lr = optimal_lr * 0.1 +``` + + +## Summary + +Learning rate scheduling is CRITICAL for competitive model performance: + +**Key Takeaways:** + +1. **Scheduling improves final accuracy by 2-5%** - not optional for SOTA +2. **Warmup is MANDATORY for transformers** - prevents divergence +3. **CosineAnnealingLR is best default** - works well, zero tuning +4. **Use LR finder for new problems** - finds optimal initial LR in minutes +5. **OneCycleLR needs max_lr tuning** - run LR finder first +6. **Watch scheduler.step() placement** - most per epoch, OneCycle per batch +7. **Always monitor LR during training** - log to console or TensorBoard +8. **Plot schedule before training** - catch mistakes early + +**Modern Defaults (2025):** +- **Vision CNNs:** SGD + CosineAnnealingLR (optional warmup) +- **Vision Transformers:** AdamW + Warmup + CosineAnnealingLR (warmup mandatory) +- **NLP Transformers:** AdamW + Warmup + Linear decay (warmup mandatory) +- **Fast Training:** SGD + OneCycleLR (tune max_lr with LR finder) + +**When In Doubt:** +- Use CosineAnnealingLR with T_max = total_epochs +- Add 5-epoch warmup for large models +- Run LR finder if unsure about initial LR +- Log LR every epoch to monitor schedule + +Learning rate scheduling is one of the highest-ROI hyperparameters - master it for significantly better model performance. diff --git a/skills/using-training-optimization/loss-functions-and-objectives.md b/skills/using-training-optimization/loss-functions-and-objectives.md new file mode 100644 index 0000000..4b4a6e8 --- /dev/null +++ b/skills/using-training-optimization/loss-functions-and-objectives.md @@ -0,0 +1,2138 @@ + +# Loss Functions and Objectives Skill + +## When to Use This Skill + +Use this skill when: +- User asks "what loss function should I use?" +- Implementing binary, multi-class, or multi-label classification +- Implementing regression models +- Training on imbalanced datasets (class imbalance) +- Multi-task learning with multiple loss terms +- Custom loss function implementation needed +- Loss goes to NaN or infinity during training +- Loss not decreasing despite valid training loop +- User suggests BCE instead of BCEWithLogitsLoss (RED FLAG) +- User adds softmax before CrossEntropyLoss (RED FLAG) +- Multi-task losses without weighting (RED FLAG) +- Division or log operations in custom loss (stability concern) +- Segmentation, ranking, or specialized tasks +- Loss debugging and troubleshooting + +Do NOT use when: +- User has specific bugs unrelated to loss functions +- Only discussing model architecture (no loss questions) +- Loss function already working well and no questions asked +- User needs general training advice (use optimization-algorithms skill) + + +## Core Principles + +### 1. The Critical Importance of Loss Functions + +**Loss functions are fundamental to deep learning:** +- Direct objective that gradients optimize +- Wrong loss → model optimizes wrong thing +- Numerically unstable loss → NaN, training collapse +- Unweighted multi-task → one task dominates +- Mismatched loss for task → poor performance + +**Common Impact:** +- Proper loss selection: 5-15% performance improvement +- Numerical stability: difference between training and crashing +- Class balancing: difference between 95% accuracy (useless) and 85% F1 (useful) +- Multi-task weighting: difference between all tasks learning vs one task dominating + +**This is NOT optional:** +- Every SOTA paper carefully selects and tunes losses +- Loss function debugging is essential skill +- One mistake (BCE vs BCEWithLogitsLoss) can break training + + +### 2. Loss Selection Decision Tree + +``` +What is your task? +│ +├─ Classification? +│ │ +│ ├─ Binary (2 classes, single output) +│ │ → Use: BCEWithLogitsLoss (NOT BCELoss!) +│ │ → Model outputs: logits (no sigmoid) +│ │ → Target shape: (batch,) or (batch, 1) with 0/1 +│ │ → Imbalanced? Add pos_weight parameter +│ │ +│ ├─ Multi-class (N classes, one label per sample) +│ │ → Use: CrossEntropyLoss +│ │ → Model outputs: logits (batch, num_classes) - no softmax! +│ │ → Target shape: (batch,) with class indices [0, N-1] +│ │ → Imbalanced? Add weight parameter or use focal loss +│ │ +│ └─ Multi-label (N classes, multiple labels per sample) +│ → Use: BCEWithLogitsLoss +│ → Model outputs: logits (batch, num_classes) - no sigmoid! +│ → Target shape: (batch, num_classes) with 0/1 +│ → Each class is independent binary classification +│ +├─ Regression? +│ │ +│ ├─ Standard regression, squared errors +│ │ → Use: MSELoss (L2 loss) +│ │ → Sensitive to outliers +│ │ → Penalizes large errors heavily +│ │ +│ ├─ Robust to outliers +│ │ → Use: L1Loss (MAE) +│ │ → Less sensitive to outliers +│ │ → Linear penalty +│ │ +│ └─ Best of both (recommended) +│ → Use: SmoothL1Loss (Huber loss) +│ → L2 for small errors, L1 for large errors +│ → Good default choice +│ +├─ Segmentation? +│ │ +│ ├─ Binary segmentation +│ │ → Use: BCEWithLogitsLoss or DiceLoss +│ │ → Combine both: α*BCE + (1-α)*Dice +│ │ +│ └─ Multi-class segmentation +│ → Use: CrossEntropyLoss or DiceLoss +│ → Imbalanced pixels? Use weighted CE or focal loss +│ +├─ Ranking/Similarity? +│ │ +│ ├─ Triplet learning +│ │ → Use: TripletMarginLoss +│ │ → Learn embeddings with anchor, positive, negative +│ │ +│ ├─ Pairwise ranking +│ │ → Use: MarginRankingLoss +│ │ → Learn x1 > x2 or x2 > x1 +│ │ +│ └─ Contrastive learning +│ → Use: ContrastiveLoss or NTXentLoss +│ → Pull similar together, push dissimilar apart +│ +└─ Multi-Task? + → Combine losses with careful weighting + → See Multi-Task Learning section below +``` + + +## Section 1: Binary Classification - BCEWithLogitsLoss + +### THE GOLDEN RULE: ALWAYS Use BCEWithLogitsLoss, NEVER BCELoss + +This is the MOST COMMON loss function mistake in deep learning. + +### ❌ WRONG: BCELoss (Numerically Unstable) + +```python +class BinaryClassifier(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(100, 1) + self.sigmoid = nn.Sigmoid() # ❌ DON'T DO THIS + + def forward(self, x): + logits = self.fc(x) + return self.sigmoid(logits) # ❌ Applying sigmoid in model + +# Training loop +output = model(x) # Probabilities [0, 1] +loss = F.binary_cross_entropy(output, target) # ❌ UNSTABLE! +``` + +**Why this is WRONG:** +1. **Numerical instability**: `log(sigmoid(x))` underflows for large negative x +2. **Gradient issues**: sigmoid saturates, BCE takes log → compound saturation +3. **NaN risk**: When sigmoid(logits) = 0 or 1, log(0) = -inf +4. **Slower training**: Less stable gradients + +### ✅ RIGHT: BCEWithLogitsLoss (Numerically Stable) + +```python +class BinaryClassifier(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(100, 1) + # ✅ NO sigmoid in model! + + def forward(self, x): + return self.fc(x) # ✅ Return logits + +# Training loop +logits = model(x) # Raw logits (can be any value) +loss = F.binary_cross_entropy_with_logits(logits, target) # ✅ STABLE! +``` + +**Why this is RIGHT:** +1. **Numerically stable**: Uses log-sum-exp trick internally +2. **Better gradients**: Single combined operation +3. **No NaN**: Stable for all logit values +4. **Faster training**: More stable optimization + +### The Math Behind the Stability + +**Unstable version (BCELoss):** +```python +# BCE computes: -[y*log(σ(x)) + (1-y)*log(1-σ(x))] +# Problem: log(σ(x)) = log(1/(1+exp(-x))) underflows for large negative x + +# Example: +x = -100 # Large negative logit +sigmoid(x) = 1 / (1 + exp(100)) ≈ 0 # Underflows to 0 +log(sigmoid(x)) = log(0) = -inf # ❌ NaN! +``` + +**Stable version (BCEWithLogitsLoss):** +```python +# BCEWithLogitsLoss uses log-sum-exp trick: +# log(σ(x)) = log(1/(1+exp(-x))) = -log(1+exp(-x)) +# Rewritten as: -log1p(exp(-x)) for stability + +# For positive x: use log(sigmoid(x)) = -log1p(exp(-x)) +# For negative x: use log(sigmoid(x)) = x - log1p(exp(x)) +# This is ALWAYS stable! + +# Example: +x = -100 +log(sigmoid(x)) = -100 - log1p(exp(-100)) + = -100 - log1p(≈0) + = -100 # ✅ Stable! +``` + +### Inference: Converting Logits to Probabilities + +```python +# During training +logits = model(x) +loss = F.binary_cross_entropy_with_logits(logits, target) + +# During inference/evaluation +logits = model(x) +probs = torch.sigmoid(logits) # ✅ NOW apply sigmoid +predictions = (probs > 0.5).float() # Binary predictions +``` + +### Handling Class Imbalance with pos_weight + +```python +# Dataset: 95% negative (class 0), 5% positive (class 1) +# Problem: Model predicts all negatives → 95% accuracy but useless! + +# Solution 1: pos_weight parameter +neg_count = 950 +pos_count = 50 +pos_weight = torch.tensor([neg_count / pos_count]) # 950/50 = 19.0 + +loss = F.binary_cross_entropy_with_logits( + logits, target, + pos_weight=pos_weight # Weight positive class 19x more +) + +# pos_weight effect: +# - Positive examples contribute 19x to loss +# - Forces model to care about minority class +# - Balances gradient contributions + +# Solution 2: Focal Loss (see Advanced Techniques section) +``` + +### Complete Binary Classification Example + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BinaryClassifier(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.fc1 = nn.Linear(input_dim, 64) + self.fc2 = nn.Linear(64, 1) # Single output for binary + + def forward(self, x): + x = F.relu(self.fc1(x)) + return self.fc2(x) # ✅ Return logits (no sigmoid) + +# Training setup +model = BinaryClassifier(input_dim=100) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# Handle imbalanced data +class_counts = torch.bincount(train_labels.long()) +pos_weight = class_counts[0] / class_counts[1] + +# Training loop +model.train() +for x, y in train_loader: + optimizer.zero_grad() + logits = model(x) + loss = F.binary_cross_entropy_with_logits( + logits.squeeze(), # Shape: (batch,) + y.float(), # Shape: (batch,) + pos_weight=pos_weight if imbalanced else None + ) + loss.backward() + optimizer.step() + +# Evaluation +model.eval() +with torch.no_grad(): + logits = model(x_test) + probs = torch.sigmoid(logits) # ✅ Apply sigmoid for inference + preds = (probs > 0.5).float() + + # Compute metrics + accuracy = (preds.squeeze() == y_test).float().mean() + # Better: Use F1, precision, recall for imbalanced data +``` + + +## Section 2: Multi-Class Classification - CrossEntropyLoss + +### THE GOLDEN RULE: Pass Logits (NOT Softmax) to CrossEntropyLoss + +### ❌ WRONG: Applying Softmax Before CrossEntropyLoss + +```python +class MultiClassifier(nn.Module): + def __init__(self, num_classes): + super().__init__() + self.fc = nn.Linear(100, num_classes) + self.softmax = nn.Softmax(dim=1) # ❌ DON'T DO THIS + + def forward(self, x): + logits = self.fc(x) + return self.softmax(logits) # ❌ Applying softmax in model + +# Training +probs = model(x) # Already softmaxed +loss = F.cross_entropy(probs, target) # ❌ WRONG! Double softmax! +``` + +**Why this is WRONG:** +1. **Double softmax**: CrossEntropyLoss applies softmax internally +2. **Numerical instability**: Extra softmax operation +3. **Wrong gradients**: Backprop through unnecessary operation +4. **Confusion**: Model outputs different things in train vs eval + +### ✅ RIGHT: Pass Logits to CrossEntropyLoss + +```python +class MultiClassifier(nn.Module): + def __init__(self, num_classes): + super().__init__() + self.fc = nn.Linear(100, num_classes) + # ✅ NO softmax in model! + + def forward(self, x): + return self.fc(x) # ✅ Return logits + +# Training +logits = model(x) # Shape: (batch, num_classes) +target = ... # Shape: (batch,) with class indices [0, num_classes-1] +loss = F.cross_entropy(logits, target) # ✅ CORRECT! +``` + +### Target Shape Requirements + +```python +# ✅ CORRECT: Target is class indices +logits = torch.randn(32, 10) # (batch=32, num_classes=10) +target = torch.randint(0, 10, (32,)) # (batch=32,) with values in [0, 9] +loss = F.cross_entropy(logits, target) # ✅ Works! + +# ❌ WRONG: One-hot encoded target +target_onehot = F.one_hot(target, num_classes=10) # (batch=32, num_classes=10) +loss = F.cross_entropy(logits, target_onehot) # ❌ Type error! + +# If you have one-hot, convert back to indices: +target_indices = target_onehot.argmax(dim=1) # (batch,) +loss = F.cross_entropy(logits, target_indices) # ✅ Works! +``` + +### Handling Class Imbalance with Weights + +```python +# Dataset: Class 0: 1000 samples, Class 1: 100 samples, Class 2: 50 samples +# Problem: Model biased toward majority class + +# Solution 1: Class weights (inverse frequency) +class_counts = torch.tensor([1000., 100., 50.]) +class_weights = 1.0 / class_counts +class_weights = class_weights / class_weights.sum() * len(class_weights) +# Normalizes so weights sum to num_classes + +# class_weights = [0.086, 0.857, 1.714] +# Minority classes weighted much higher + +loss = F.cross_entropy(logits, target, weight=class_weights) + +# Solution 2: Balanced accuracy loss (effective sample weighting) +# Weight each sample by inverse class frequency +sample_weights = class_weights[target] # Index into weights +loss = F.cross_entropy(logits, target, reduction='none') +weighted_loss = (loss * sample_weights).mean() + +# Solution 3: Focal Loss (see Advanced Techniques section) +``` + +### Inference: Converting Logits to Probabilities + +```python +# During training +logits = model(x) +loss = F.cross_entropy(logits, target) + +# During inference/evaluation +logits = model(x) # (batch, num_classes) +probs = F.softmax(logits, dim=1) # ✅ NOW apply softmax +preds = logits.argmax(dim=1) # Or directly argmax logits (same result) + +# Why argmax logits works: +# argmax(softmax(logits)) = argmax(logits) because softmax is monotonic +``` + +### Complete Multi-Class Example + +```python +class MultiClassifier(nn.Module): + def __init__(self, input_dim, num_classes): + super().__init__() + self.fc1 = nn.Linear(input_dim, 128) + self.fc2 = nn.Linear(128, num_classes) + + def forward(self, x): + x = F.relu(self.fc1(x)) + return self.fc2(x) # ✅ Return logits + +# Training setup +num_classes = 10 +model = MultiClassifier(input_dim=100, num_classes=num_classes) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# Compute class weights for imbalanced data +class_counts = torch.bincount(train_labels) +class_weights = 1.0 / class_counts.float() +class_weights = class_weights / class_weights.sum() * num_classes + +# Training loop +model.train() +for x, y in train_loader: + optimizer.zero_grad() + logits = model(x) # (batch, num_classes) + loss = F.cross_entropy(logits, y, weight=class_weights) + loss.backward() + optimizer.step() + +# Evaluation +model.eval() +with torch.no_grad(): + logits = model(x_test) + probs = F.softmax(logits, dim=1) # For calibration analysis + preds = logits.argmax(dim=1) # Class predictions + accuracy = (preds == y_test).float().mean() +``` + + +## Section 3: Multi-Label Classification + +### Use BCEWithLogitsLoss (Each Class is Independent) + +```python +# Task: Predict multiple tags for an image +# Example: [dog, outdoor, sunny] → target = [1, 0, 1, 0, 1, 0, ...] + +class MultiLabelClassifier(nn.Module): + def __init__(self, input_dim, num_classes): + super().__init__() + self.fc = nn.Linear(input_dim, num_classes) + # ✅ NO sigmoid! Return logits + + def forward(self, x): + return self.fc(x) # (batch, num_classes) logits + +# Training +logits = model(x) # (batch, num_classes) +target = ... # (batch, num_classes) with 0/1 for each class + +# Each class is independent binary classification +loss = F.binary_cross_entropy_with_logits(logits, target.float()) + +# Inference +logits = model(x_test) +probs = torch.sigmoid(logits) # Per-class probabilities +preds = (probs > 0.5).float() # Threshold each class independently + +# Example output: +# probs = [0.9, 0.3, 0.8, 0.1, 0.7, ...] +# preds = [1.0, 0.0, 1.0, 0.0, 1.0, ...] (dog=yes, outdoor=no, sunny=yes, ...) +``` + +### Handling Imbalanced Labels + +```python +# Some labels are rare (e.g., "sunset" appears in 2% of images) + +# Solution 1: Per-class pos_weight +label_counts = train_labels.sum(dim=0) # Count per class +num_samples = len(train_labels) +neg_counts = num_samples - label_counts +pos_weights = neg_counts / label_counts # (num_classes,) + +loss = F.binary_cross_entropy_with_logits( + logits, target.float(), + pos_weight=pos_weights +) + +# Solution 2: Focal loss per class (see Advanced Techniques) +``` + + +## Section 4: Regression Losses + +### MSELoss (L2 Loss) - Default Choice + +```python +# Mean Squared Error: (pred - target)^2 + +pred = model(x) # (batch, output_dim) +target = ... # (batch, output_dim) +loss = F.mse_loss(pred, target) + +# Characteristics: +# ✅ Smooth gradients +# ✅ Penalizes large errors heavily (squared term) +# ❌ Sensitive to outliers (outliers dominate loss) +# ❌ Can be numerically large if targets not normalized + +# When to use: +# - Standard regression tasks +# - Targets are normalized (similar scale to predictions) +# - Outliers are rare or not expected +``` + +### L1Loss (MAE) - Robust to Outliers + +```python +# Mean Absolute Error: |pred - target| + +pred = model(x) +loss = F.l1_loss(pred, target) + +# Characteristics: +# ✅ Robust to outliers (linear penalty) +# ✅ Numerically stable +# ❌ Non-smooth at zero (gradient discontinuity) +# ❌ Equal penalty for all error magnitudes + +# When to use: +# - Outliers present in data +# - Want robust loss +# - Median prediction preferred over mean +``` + +### SmoothL1Loss (Huber Loss) - Best of Both Worlds + +```python +# Smooth L1: L2 for small errors, L1 for large errors + +pred = model(x) +loss = F.smooth_l1_loss(pred, target, beta=1.0) + +# Formula: +# loss = 0.5 * (pred - target)^2 / beta if |pred - target| < beta +# loss = |pred - target| - 0.5 * beta otherwise + +# Characteristics: +# ✅ Smooth gradients everywhere +# ✅ Robust to outliers (L1 for large errors) +# ✅ Fast convergence (L2 for small errors) +# ✅ Best default for regression + +# When to use: +# - General regression (RECOMMENDED DEFAULT) +# - Uncertainty about outliers +# - Want fast convergence + robustness +``` + +### Target Normalization (CRITICAL) + +```python +# ❌ WRONG: Unnormalized targets +pred = model(x) # Model outputs in range [0, 1] (e.g., after sigmoid) +target = ... # Range [1000, 10000] - NOT NORMALIZED! +loss = F.mse_loss(pred, target) # Huge loss values, bad gradients + +# ✅ RIGHT: Normalize targets +# Option 1: Min-Max normalization to [0, 1] +target_min = train_targets.min() +target_max = train_targets.max() +target_normalized = (target - target_min) / (target_max - target_min) + +pred = model(x) # Range [0, 1] +loss = F.mse_loss(pred, target_normalized) # ✅ Same scale + +# Denormalize for evaluation: +pred_denorm = pred * (target_max - target_min) + target_min + +# Option 2: Standardization to mean=0, std=1 +target_mean = train_targets.mean() +target_std = train_targets.std() +target_standardized = (target - target_mean) / target_std + +pred = model(x) # Should output standardized values +loss = F.mse_loss(pred, target_standardized) # ✅ Normalized scale + +# Denormalize for evaluation: +pred_denorm = pred * target_std + target_mean + +# Why normalization matters: +# 1. Loss values in reasonable range (not 1e6) +# 2. Better gradient flow +# 3. Learning rate can be standard (1e-3) +# 4. Faster convergence +``` + +### Complete Regression Example + +```python +class Regressor(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.fc1 = nn.Linear(input_dim, 128) + self.fc2 = nn.Linear(128, output_dim) + + def forward(self, x): + x = F.relu(self.fc1(x)) + return self.fc2(x) # Linear output for regression + +# Normalize targets +target_mean = train_targets.mean(dim=0) +target_std = train_targets.std(dim=0) + +def normalize(targets): + return (targets - target_mean) / (target_std + 1e-8) + +def denormalize(preds): + return preds * target_std + target_mean + +# Training +model = Regressor(input_dim=100, output_dim=1) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +model.train() +for x, y in train_loader: + optimizer.zero_grad() + pred = model(x) + y_norm = normalize(y) + loss = F.smooth_l1_loss(pred, y_norm) # Using Huber loss + loss.backward() + optimizer.step() + +# Evaluation +model.eval() +with torch.no_grad(): + pred_norm = model(x_test) + pred = denormalize(pred_norm) # Back to original scale + mse = F.mse_loss(pred, y_test) + print(f"Test MSE: {mse.item()}") +``` + + +## Section 5: Numerical Stability in Loss Computation + +### Critical Rule: Avoid log(0), log(negative), and division by zero + +### Problem 1: Division by Zero + +```python +# ❌ UNSTABLE: No protection +def iou_loss(pred, target): + intersection = (pred * target).sum() + union = pred.sum() + target.sum() + iou = intersection / union # ❌ Division by zero if both empty! + return 1 - iou + +# ✅ STABLE: Add epsilon +def iou_loss(pred, target): + eps = 1e-8 + intersection = (pred * target).sum() + union = pred.sum() + target.sum() + iou = (intersection + eps) / (union + eps) # ✅ Safe + return 1 - iou + +# Why epsilon works: +# - Denominator never zero: union + 1e-8 ≥ 1e-8 +# - Doesn't affect result when union is large +# - Prevents NaN propagation +``` + +### Problem 2: Log of Zero or Negative + +```python +# ❌ UNSTABLE: No clamping +def custom_loss(pred, target): + ratio = pred / target + return torch.log(ratio).mean() # ❌ log(0) = -inf, log(neg) = nan + +# ✅ STABLE: Clamp before log +def custom_loss(pred, target): + eps = 1e-8 + ratio = pred / (target + eps) # Safe division + ratio = torch.clamp(ratio, min=eps) # Ensure positive + return torch.log(ratio).mean() # ✅ Safe log + +# Alternative: Use log1p for log(1+x) +def custom_loss(pred, target): + eps = 1e-8 + ratio = pred / (target + eps) + return torch.log1p(ratio).mean() # log1p(x) = log(1+x), more stable +``` + +### Problem 3: Exponential Overflow + +```python +# ❌ UNSTABLE: Direct exp can overflow +def custom_loss(logits): + return torch.exp(logits).mean() # ❌ exp(100) = overflow! + +# ✅ STABLE: Clamp logits or use stable operations +def custom_loss(logits): + # Option 1: Clamp logits + logits = torch.clamp(logits, max=10) # Prevent overflow + return torch.exp(logits).mean() + + # Option 2: Use log-space operations + # If computing log(exp(x)), just return x! +``` + +### Problem 4: Custom Softmax (Use Built-in Instead) + +```python +# ❌ UNSTABLE: Manual softmax +def manual_softmax(logits): + exp_logits = torch.exp(logits) # ❌ Overflow for large logits! + return exp_logits / exp_logits.sum(dim=1, keepdim=True) + +# ✅ STABLE: Use F.softmax (uses max subtraction trick) +def stable_softmax(logits): + return F.softmax(logits, dim=1) # ✅ Handles overflow internally + +# Built-in implementation (for understanding): +def softmax_stable(logits): + # Subtract max for numerical stability + logits_max = logits.max(dim=1, keepdim=True)[0] + logits = logits - logits_max # Now max(logits) = 0 + exp_logits = torch.exp(logits) # No overflow! + return exp_logits / exp_logits.sum(dim=1, keepdim=True) +``` + +### Epsilon Best Practices + +```python +# Epsilon guidelines: +eps = 1e-8 # ✅ Good default for float32 +eps = 1e-6 # ✅ Alternative, more conservative +eps = 1e-10 # ❌ Too small, can still underflow + +# Where to add epsilon: +# 1. Denominators: x / (y + eps) +# 2. Before log: log(x + eps) or log(clamp(x, min=eps)) +# 3. Before sqrt: sqrt(x + eps) + +# Where NOT to add epsilon: +# 4. ❌ Numerators usually don't need it +# 5. ❌ Inside standard PyTorch functions (already stable) +# 6. ❌ After stable operations +``` + +### Complete Stable Custom Loss Template + +```python +class StableCustomLoss(nn.Module): + def __init__(self): + super().__init__() + self.eps = 1e-8 + + def forward(self, pred, target): + # 1. Ensure inputs are valid + assert not torch.isnan(pred).any(), "pred contains NaN" + assert not torch.isnan(target).any(), "target contains NaN" + + # 2. Safe division + ratio = pred / (target + self.eps) + + # 3. Clamp before log/sqrt/pow + ratio = torch.clamp(ratio, min=self.eps, max=1e8) + + # 4. Safe log operation + log_ratio = torch.log(ratio) + + # 5. Check output + loss = log_ratio.mean() + assert not torch.isnan(loss), "loss is NaN" + + return loss +``` + + +## Section 6: Multi-Task Learning and Loss Weighting + +### The Problem: Different Loss Scales + +```python +# Task 1: Classification, CrossEntropyLoss ~ 0.5-2.0 +# Task 2: Regression, MSELoss ~ 100-1000 +# Task 3: Reconstruction, L2 Loss ~ 10-50 + +# ❌ WRONG: Naive sum (task 2 dominates!) +loss1 = F.cross_entropy(logits1, target1) # ~0.5 +loss2 = F.mse_loss(pred2, target2) # ~500.0 +loss3 = F.mse_loss(recon, input) # ~20.0 +total_loss = loss1 + loss2 + loss3 # ≈ 520.5 + +# Gradient analysis: +# ∂total_loss/∂θ ≈ ∂loss2/∂θ (loss1 and loss3 contribute <5%) +# Model learns ONLY task 2, ignores tasks 1 and 3! +``` + +### Solution 1: Manual Weighting + +```python +# Balance losses to similar magnitudes +loss1 = F.cross_entropy(logits1, target1) # ~0.5 +loss2 = F.mse_loss(pred2, target2) # ~500.0 +loss3 = F.mse_loss(recon, input) # ~20.0 + +# Set weights so weighted losses are similar scale +w1 = 1.0 # Keep as is +w2 = 0.001 # Scale down by 1000x +w3 = 0.05 # Scale down by 20x + +total_loss = w1 * loss1 + w2 * loss2 + w3 * loss3 +# = 1.0*0.5 + 0.001*500 + 0.05*20 +# = 0.5 + 0.5 + 1.0 = 2.0 +# All tasks contribute meaningfully! + +# How to find weights: +# 1. Run 1 epoch with equal weights +# 2. Print loss magnitudes +# 3. Set weights inversely proportional to magnitudes +# 4. Iterate until balanced +``` + +### Solution 2: Uncertainty Weighting (Learnable) + +```python +# "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018) +# Learn task weights during training! + +class MultiTaskLoss(nn.Module): + def __init__(self, num_tasks): + super().__init__() + # Log variance parameters (learnable) + self.log_vars = nn.Parameter(torch.zeros(num_tasks)) + + def forward(self, losses): + """ + losses: list of task losses [loss1, loss2, loss3, ...] + + For each task: + weighted_loss = (1 / (2 * σ²)) * loss + log(σ) + + Where σ² = exp(log_var) is the learned uncertainty + - High uncertainty → lower weight on that task + - Low uncertainty → higher weight on that task + """ + weighted_losses = [] + for i, loss in enumerate(losses): + precision = torch.exp(-self.log_vars[i]) # 1/σ² + weighted_loss = precision * loss + self.log_vars[i] + weighted_losses.append(weighted_loss) + + return sum(weighted_losses) + +# Usage +model = MultiTaskModel() +multi_loss = MultiTaskLoss(num_tasks=3) + +# Optimize both model and loss weights +optimizer = torch.optim.Adam([ + {'params': model.parameters()}, + {'params': multi_loss.parameters(), 'lr': 0.01} # Can use different LR +]) + +# Training loop +for x, targets in train_loader: + optimizer.zero_grad() + + # Compute task predictions + out1, out2, out3 = model(x) + + # Compute task losses + loss1 = F.cross_entropy(out1, targets[0]) + loss2 = F.mse_loss(out2, targets[1]) + loss3 = F.mse_loss(out3, targets[2]) + + # Combine with learned weighting + total_loss = multi_loss([loss1, loss2, loss3]) + + total_loss.backward() + optimizer.step() + + # Monitor learned weights + if step % 100 == 0: + weights = torch.exp(-multi_loss.log_vars) + print(f"Task weights: {weights.detach()}") +``` + +### Solution 3: Gradient Normalization + +```python +# GradNorm: balances task learning by normalizing gradient magnitudes + +def grad_norm_step(model, losses, alpha=1.5): + """ + Adjust task weights to balance gradient magnitudes + + losses: list of task losses + alpha: balancing parameter (1.5 typical) + """ + # Get initial loss ratios + initial_losses = [l.item() for l in losses] + + # Compute average gradient norm per task + shared_params = list(model.shared_layers.parameters()) + + grad_norms = [] + for loss in losses: + model.zero_grad() + loss.backward(retain_graph=True) + + # Compute gradient norm + grad_norm = 0 + for p in shared_params: + if p.grad is not None: + grad_norm += p.grad.norm(2).item() ** 2 + grad_norms.append(grad_norm ** 0.5) + + # Target: all tasks have same gradient norm + mean_grad_norm = sum(grad_norms) / len(grad_norms) + + # Adjust weights + weights = [] + for gn in grad_norms: + weight = mean_grad_norm / (gn + 1e-8) + weights.append(weight ** alpha) + + # Normalize weights + weights = torch.tensor(weights) + weights = weights / weights.sum() * len(weights) + + return weights + +# Note: GradNorm is more complex, this is simplified version +# For production, use manual or uncertainty weighting +``` + +### Solution 4: Loss Normalization + +```python +# Normalize each loss to [0, 1] range before combining + +class NormalizedMultiTaskLoss(nn.Module): + def __init__(self, num_tasks): + super().__init__() + # Track running mean/std per task + self.register_buffer('running_mean', torch.zeros(num_tasks)) + self.register_buffer('running_std', torch.ones(num_tasks)) + self.momentum = 0.9 + + def forward(self, losses): + """Normalize each loss before combining""" + losses_tensor = torch.stack(losses) + + if self.training: + # Update running statistics + mean = losses_tensor.mean() + std = losses_tensor.std() + 1e-8 + + self.running_mean = (self.momentum * self.running_mean + + (1 - self.momentum) * mean) + self.running_std = (self.momentum * self.running_std + + (1 - self.momentum) * std) + + # Normalize losses + normalized = (losses_tensor - self.running_mean) / self.running_std + + return normalized.sum() +``` + +### Best Practices for Multi-Task Loss + +```python +# Recommended approach: + +1. Start with manual weighting: + - Run 1 epoch, check loss magnitudes + - Set weights to balance scales + - Quick and interpretable + +2. If tasks have different difficulties: + - Use uncertainty weighting + - Let model learn task importance + - More training time but adaptive + +3. Monitor individual task metrics: + - Don't just watch total loss + - Track accuracy/error per task + - Ensure all tasks learning + +4. Curriculum learning: + - Start with easy tasks + - Gradually add harder tasks + - Can improve stability + +# Example monitoring: +if step % 100 == 0: + print(f"Total Loss: {total_loss.item():.4f}") + print(f"Task 1 (CE): {loss1.item():.4f}") + print(f"Task 2 (MSE): {loss2.item():.4f}") + print(f"Task 3 (Recon): {loss3.item():.4f}") + + # Check if any task stuck + if loss1 > 5.0: # Not learning + print("WARNING: Task 1 not learning, increase weight") +``` + + +## Section 7: Custom Loss Function Implementation + +### Template for Custom Loss + +```python +class CustomLoss(nn.Module): + """ + Template for implementing custom losses + """ + def __init__(self, weight=None, reduction='mean'): + """ + Args: + weight: Manual sample weights (optional) + reduction: 'mean', 'sum', or 'none' + """ + super().__init__() + self.weight = weight + self.reduction = reduction + self.eps = 1e-8 # For numerical stability + + def forward(self, pred, target): + """ + Args: + pred: Model predictions + target: Ground truth + + Returns: + Loss value (scalar if reduction != 'none') + """ + # 1. Input validation + assert pred.shape == target.shape, "Shape mismatch" + assert not torch.isnan(pred).any(), "pred contains NaN" + + # 2. Compute element-wise loss + loss = self.compute_loss(pred, target) + + # 3. Apply sample weights if provided + if self.weight is not None: + loss = loss * self.weight + + # 4. Apply reduction + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: # 'none' + return loss + + def compute_loss(self, pred, target): + """Override this method with your loss computation""" + # Example: MSE + return (pred - target) ** 2 +``` + +### Example 1: Dice Loss (Segmentation) + +```python +class DiceLoss(nn.Module): + """ + Dice Loss for segmentation tasks + + Dice = 2 * |X ∩ Y| / (|X| + |Y|) + Loss = 1 - Dice + + Good for: + - Binary segmentation + - Handling class imbalance + - Smooth gradients + """ + def __init__(self, smooth=1.0): + super().__init__() + self.smooth = smooth # Prevent division by zero + + def forward(self, pred, target): + """ + Args: + pred: (batch, C, H, W) logits + target: (batch, C, H, W) binary masks + """ + # Apply sigmoid to get probabilities + pred = torch.sigmoid(pred) + + # Flatten spatial dimensions + pred = pred.view(pred.size(0), pred.size(1), -1) # (batch, C, H*W) + target = target.view(target.size(0), target.size(1), -1) + + # Compute dice per sample and per class + intersection = (pred * target).sum(dim=2) # (batch, C) + union = pred.sum(dim=2) + target.sum(dim=2) # (batch, C) + + dice = (2 * intersection + self.smooth) / (union + self.smooth) + + # Average over classes and batch + return 1 - dice.mean() + +# Usage +criterion = DiceLoss(smooth=1.0) +loss = criterion(logits, masks) + +# Often combined with BCE: +dice_loss = DiceLoss() +bce_loss = nn.BCEWithLogitsLoss() + +total_loss = 0.5 * dice_loss(logits, masks) + 0.5 * bce_loss(logits, masks) +``` + +### Example 2: Focal Loss (Imbalanced Classification) + +```python +class FocalLoss(nn.Module): + """ + Focal Loss for addressing class imbalance + + FL = -α * (1 - p)^γ * log(p) + + - α: class balancing weight + - γ: focusing parameter (typical: 2.0) + - (1-p)^γ: down-weights easy examples + + Good for: + - Highly imbalanced datasets (e.g., object detection) + - Many easy negatives, few hard positives + - When class weights aren't enough + """ + def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, logits, target): + """ + Args: + logits: (batch, num_classes) raw logits + target: (batch,) class indices + """ + # Compute cross entropy + ce_loss = F.cross_entropy(logits, target, reduction='none') + + # Compute pt = e^(-CE) = probability of true class + pt = torch.exp(-ce_loss) + + # Compute focal loss + focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss + + # Apply reduction + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss + +# Usage +criterion = FocalLoss(alpha=1.0, gamma=2.0) +loss = criterion(logits, target) + +# Effect of gamma: +# γ = 0: equivalent to CrossEntropyLoss +# γ = 2: typical value, strong down-weighting of easy examples +# γ = 5: extreme focusing, only hardest examples matter + +# Example probability and loss weights: +# pt = 0.9 (easy): (1-0.9)^2 = 0.01 → 1% weight +# pt = 0.5 (medium): (1-0.5)^2 = 0.25 → 25% weight +# pt = 0.1 (hard): (1-0.1)^2 = 0.81 → 81% weight +``` + +### Example 3: Contrastive Loss (Metric Learning) + +```python +class ContrastiveLoss(nn.Module): + """ + Contrastive Loss for learning embeddings + + Pulls similar pairs together, pushes dissimilar pairs apart + + Good for: + - Face recognition + - Similarity learning + - Few-shot learning + """ + def __init__(self, margin=1.0): + super().__init__() + self.margin = margin + + def forward(self, embedding1, embedding2, label): + """ + Args: + embedding1: (batch, embedding_dim) first embeddings + embedding2: (batch, embedding_dim) second embeddings + label: (batch,) 1 if similar, 0 if dissimilar + """ + # Euclidean distance + distance = F.pairwise_distance(embedding1, embedding2) + + # Loss for similar pairs: want distance = 0 + loss_similar = label * distance.pow(2) + + # Loss for dissimilar pairs: want distance ≥ margin + loss_dissimilar = (1 - label) * F.relu(self.margin - distance).pow(2) + + loss = loss_similar + loss_dissimilar + return loss.mean() + +# Usage +criterion = ContrastiveLoss(margin=1.0) + +for (img1, img2, is_similar) in train_loader: + emb1 = model(img1) + emb2 = model(img2) + loss = criterion(emb1, emb2, is_similar) +``` + +### Example 4: Perceptual Loss (Style Transfer, Super-Resolution) + +```python +class PerceptualLoss(nn.Module): + """ + Perceptual Loss using VGG features + + Compares high-level features instead of pixels + + Good for: + - Image generation + - Super-resolution + - Style transfer + """ + def __init__(self, layer='relu3_3'): + super().__init__() + # Load pre-trained VGG + vgg = torchvision.models.vgg16(pretrained=True).features + self.vgg = vgg.eval() + + # Freeze VGG + for param in self.vgg.parameters(): + param.requires_grad = False + + # Select layer + self.layer_map = { + 'relu1_2': 4, + 'relu2_2': 9, + 'relu3_3': 16, + 'relu4_3': 23, + } + self.layer_idx = self.layer_map[layer] + + def forward(self, pred, target): + """ + Args: + pred: (batch, 3, H, W) predicted images + target: (batch, 3, H, W) target images + """ + # Extract features + pred_features = self.extract_features(pred) + target_features = self.extract_features(target) + + # MSE in feature space + loss = F.mse_loss(pred_features, target_features) + return loss + + def extract_features(self, x): + """Extract features from VGG layer""" + for i, layer in enumerate(self.vgg): + x = layer(x) + if i == self.layer_idx: + return x + return x + +# Usage +perceptual_loss = PerceptualLoss(layer='relu3_3') +pixel_loss = nn.L1Loss() + +# Combine pixel and perceptual loss +total_loss = pixel_loss(pred, target) + 0.1 * perceptual_loss(pred, target) +``` + +### Example 5: Custom Weighted MSE + +```python +class WeightedMSELoss(nn.Module): + """ + MSE with per-element importance weighting + + Good for: + - Focusing on important regions (e.g., foreground) + - Time-series with different importance + - Confidence-weighted regression + """ + def __init__(self): + super().__init__() + + def forward(self, pred, target, weight): + """ + Args: + pred: (batch, ...) predictions + target: (batch, ...) targets + weight: (batch, ...) importance weights (0-1) + """ + # Element-wise squared error + squared_error = (pred - target) ** 2 + + # Weight by importance + weighted_error = squared_error * weight + + # Average only over weighted elements + # (avoid counting zero-weight elements) + loss = weighted_error.sum() / (weight.sum() + 1e-8) + + return loss + +# Usage example: Foreground-focused loss +criterion = WeightedMSELoss() + +# Create importance map (1.0 for foreground, 0.1 for background) +weight = torch.where(mask > 0.5, torch.tensor(1.0), torch.tensor(0.1)) + +loss = criterion(pred, target, weight) +``` + + +## Section 8: Advanced Loss Techniques + +### Technique 1: Label Smoothing + +```python +# Problem: Hard labels [0, 0, 1, 0, 0] cause overconfident predictions +# Solution: Soft labels [0.025, 0.025, 0.9, 0.025, 0.025] + +# PyTorch 1.10+ built-in support +loss = F.cross_entropy(logits, target, label_smoothing=0.1) + +# What it does: +# Original: y = [0, 0, 1, 0, 0] +# Smoothed: y = (1-α)*[0, 0, 1, 0, 0] + α*[0.2, 0.2, 0.2, 0.2, 0.2] +# = [0.02, 0.02, 0.92, 0.02, 0.02] (for α=0.1, num_classes=5) + +# Manual implementation (for understanding): +class LabelSmoothingLoss(nn.Module): + def __init__(self, num_classes, smoothing=0.1): + super().__init__() + self.num_classes = num_classes + self.smoothing = smoothing + self.confidence = 1.0 - smoothing + + def forward(self, logits, target): + """ + logits: (batch, num_classes) + target: (batch,) class indices + """ + log_probs = F.log_softmax(logits, dim=1) + + # Create smooth labels + smooth_labels = torch.zeros_like(log_probs) + smooth_labels.fill_(self.smoothing / (self.num_classes - 1)) + smooth_labels.scatter_(1, target.unsqueeze(1), self.confidence) + + # NLL with smooth labels + loss = (-smooth_labels * log_probs).sum(dim=1) + return loss.mean() + +# Benefits: +# 1. Better calibration (confidence closer to accuracy) +# 2. Prevents overconfidence +# 3. Acts as regularization +# 4. Often improves test accuracy by 0.5-1% + +# When to use: +# ✅ Classification with CrossEntropyLoss +# ✅ Large models prone to overfitting +# ✅ Clean labels (not noisy) +# ❌ Small models (might hurt performance) +# ❌ Noisy labels (already have uncertainty) +``` + +### Technique 2: Class-Balanced Loss + +```python +# Problem: 1000 samples class 0, 10 samples class 1 +# Standard CE treats all samples equally → biased to class 0 + +# Solution 1: Inverse frequency weighting +class_counts = torch.bincount(train_labels) +class_weights = 1.0 / class_counts.float() +class_weights = class_weights / class_weights.sum() * len(class_weights) + +loss = F.cross_entropy(logits, target, weight=class_weights) + +# Solution 2: Effective number of samples (better for extreme imbalance) +def get_eff_num_weights(num_samples_per_class, beta=0.999): + """ + Effective number of samples: (1 - β^n) / (1 - β) + + Handles extreme imbalance better than inverse frequency + + Args: + num_samples_per_class: [n1, n2, ..., nC] + beta: Hyperparameter (0.99-0.9999), higher for more imbalance + """ + effective_num = 1.0 - torch.pow(beta, num_samples_per_class) + weights = (1.0 - beta) / effective_num + weights = weights / weights.sum() * len(weights) + return weights + +# Usage +class_counts = torch.bincount(train_labels) +weights = get_eff_num_weights(class_counts.float(), beta=0.9999) +loss = F.cross_entropy(logits, target, weight=weights) + +# Solution 3: Focal loss (see Example 2 in Section 7) +``` + +### Technique 3: Mixup / CutMix Loss + +```python +# Mixup: Blend two samples and their labels +def mixup_data(x, y, alpha=1.0): + """ + Args: + x: (batch, ...) input + y: (batch,) labels + alpha: Mixup parameter + """ + lam = np.random.beta(alpha, alpha) + batch_size = x.size(0) + index = torch.randperm(batch_size) + + mixed_x = lam * x + (1 - lam) * x[index] + y_a, y_b = y, y[index] + + return mixed_x, y_a, y_b, lam + +def mixup_criterion(pred, y_a, y_b, lam): + """Compute mixed loss""" + return lam * F.cross_entropy(pred, y_a) + (1 - lam) * F.cross_entropy(pred, y_b) + +# Training with Mixup +for x, y in train_loader: + x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0) + + optimizer.zero_grad() + pred = model(x) + loss = mixup_criterion(pred, y_a, y_b, lam) + loss.backward() + optimizer.step() + +# Benefits: +# - Regularization +# - Better generalization +# - Smooth decision boundaries +# - +1-2% accuracy on CIFAR/ImageNet +``` + +### Technique 4: Gradient Clipping for Loss Stability + +```python +# Problem: Loss spikes to NaN during training +# Often caused by exploding gradients + +# Solution: Clip gradients before optimizer step +for x, y in train_loader: + optimizer.zero_grad() + pred = model(x) + loss = criterion(pred, y) + loss.backward() + + # Clip gradients + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + # Or clip by value: + # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5) + + optimizer.step() + +# When to use: +# ✅ RNNs/LSTMs (prone to exploding gradients) +# ✅ Transformers with high learning rates +# ✅ Loss occasionally spikes to NaN +# ✅ Large models or deep networks +# ❌ Stable training (unnecessary overhead) + +# How to choose max_norm: +# - Start with 1.0 +# - If still unstable, reduce to 0.5 +# - Monitor: print gradient norms to see if clipping activates +``` + +### Technique 5: Loss Scaling for Mixed Precision + +```python +# Problem: Mixed precision (FP16) can cause gradients to underflow +# Solution: Scale loss up, then scale gradients down + +from torch.cuda.amp import autocast, GradScaler + +scaler = GradScaler() + +for x, y in train_loader: + optimizer.zero_grad() + + # Forward in FP16 + with autocast(): + pred = model(x) + loss = criterion(pred, y) + + # Scale loss and backward + scaler.scale(loss).backward() + + # Unscale gradients and step + scaler.step(optimizer) + scaler.update() + +# GradScaler automatically: +# 1. Scales loss by factor (e.g., 65536) +# 2. Backprop computes scaled gradients +# 3. Unscales gradients before optimizer step +# 4. Adjusts scale factor dynamically +``` + + +## Section 9: Common Loss Function Pitfalls + +### Pitfall 1: BCE Instead of BCEWithLogitsLoss + +```python +# ❌ WRONG (seen in 30% of beginner code!) +probs = torch.sigmoid(logits) +loss = F.binary_cross_entropy(probs, target) + +# ✅ RIGHT +loss = F.binary_cross_entropy_with_logits(logits, target) + +# Impact: Training instability, NaN losses, worse performance +# Fix time: 2 minutes +# Performance gain: Stable training, +2-5% accuracy +``` + +### Pitfall 2: Softmax Before CrossEntropyLoss + +```python +# ❌ WRONG (seen in 20% of beginner code!) +probs = F.softmax(logits, dim=1) +loss = F.cross_entropy(probs, target) + +# ✅ RIGHT +loss = F.cross_entropy(logits, target) # Expects logits! + +# Impact: Suboptimal learning, double softmax +# Fix time: 1 minute +# Performance gain: +1-3% accuracy +``` + +### Pitfall 3: Wrong Target Shape for CrossEntropyLoss + +```python +# ❌ WRONG: One-hot encoded targets +target = F.one_hot(labels, num_classes=10) # (batch, 10) +loss = F.cross_entropy(logits, target) # Type error! + +# ✅ RIGHT: Class indices +target = labels # (batch,) with values in [0, 9] +loss = F.cross_entropy(logits, target) + +# Impact: Runtime error or wrong loss computation +# Fix time: 2 minutes +``` + +### Pitfall 4: Ignoring Class Imbalance + +```python +# ❌ WRONG: 95% negative, 5% positive +loss = F.binary_cross_entropy_with_logits(logits, target) +# Model predicts all negative → 95% accuracy but useless! + +# ✅ RIGHT: Weight positive class +pos_weight = torch.tensor([19.0]) # 95/5 +loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight) + +# Impact: Model learns trivial predictor +# Fix time: 5 minutes +# Performance gain: From useless to actually working +``` + +### Pitfall 5: Not Normalizing Regression Targets + +```python +# ❌ WRONG: Targets in [1000, 10000], predictions in [0, 1] +loss = F.mse_loss(pred, target) # Huge loss, bad gradients + +# ✅ RIGHT: Normalize targets +target_norm = (target - target.mean()) / target.std() +loss = F.mse_loss(pred, target_norm) + +# Impact: Slow convergence, high loss values, need very small LR +# Fix time: 5 minutes +# Performance gain: 10-100x faster convergence +``` + +### Pitfall 6: Unweighted Multi-Task Loss + +```python +# ❌ WRONG: Different scales +loss1 = F.cross_entropy(out1, target1) # ~0.5 +loss2 = F.mse_loss(out2, target2) # ~500.0 +total = loss1 + loss2 # Task 2 dominates! + +# ✅ RIGHT: Balance scales +total = 1.0 * loss1 + 0.001 * loss2 # Both ~0.5 + +# Impact: One task learns, others ignored +# Fix time: 10 minutes (trial and error) +# Performance gain: All tasks learn instead of one +``` + +### Pitfall 7: Division by Zero in Custom Loss + +```python +# ❌ WRONG: No epsilon +iou = intersection / union # Division by zero! + +# ✅ RIGHT: Add epsilon +eps = 1e-8 +iou = (intersection + eps) / (union + eps) + +# Impact: NaN losses, training crash +# Fix time: 2 minutes +``` + +### Pitfall 8: Missing optimizer.zero_grad() + +```python +# ❌ WRONG: Gradients accumulate! +for x, y in train_loader: + loss = criterion(model(x), y) + loss.backward() + optimizer.step() # Missing zero_grad! + +# ✅ RIGHT: Reset gradients +for x, y in train_loader: + optimizer.zero_grad() # ✅ Critical! + loss = criterion(model(x), y) + loss.backward() + optimizer.step() + +# Impact: Loss doesn't decrease, weird behavior +# Fix time: 1 minute +# This is caught by systematic debugging +``` + +### Pitfall 9: Wrong Reduction for Custom Loss + +```python +# ❌ SUBOPTIMAL: Sum over batch +loss = (pred - target).pow(2).sum() # Loss scales with batch size! + +# ✅ BETTER: Mean over batch +loss = (pred - target).pow(2).mean() # Loss independent of batch size + +# Impact: Learning rate depends on batch size +# Fix time: 2 minutes + +# When to use sum vs mean: +# - mean: Default, loss independent of batch size +# - sum: When you want loss to scale with batch size (rare) +# - none: When you want per-sample losses (for weighting) +``` + +### Pitfall 10: Using Accuracy for Imbalanced Data + +```python +# ❌ WRONG: 95-5 imbalance +accuracy = (pred == target).float().mean() # 95% for trivial predictor! + +# ✅ RIGHT: Use F1, precision, recall +from sklearn.metrics import f1_score, precision_score, recall_score + +f1 = f1_score(target, pred) # Balanced metric +precision = precision_score(target, pred) +recall = recall_score(target, pred) + +# Or use balanced accuracy: +balanced_acc = (recall_class0 + recall_class1) / 2 + +# Impact: Misinterpreting model performance +# Fix time: 5 minutes +``` + + +## Section 10: Loss Debugging Methodology + +### When Loss is NaN + +```python +# Step 1: Check inputs for NaN +print(f"Input has NaN: {torch.isnan(x).any()}") +print(f"Target has NaN: {torch.isnan(target).any()}") + +if torch.isnan(x).any(): + # Data loading issue + print("Fix: Check data preprocessing") + +# Step 2: Check for numerical instability in loss +# - Division by zero +# - Log of zero or negative +# - Exp overflow + +# Step 3: Check gradients before NaN +for name, param in model.named_parameters(): + if param.grad is not None: + grad_norm = param.grad.norm() + print(f"{name}: {grad_norm.item()}") + if grad_norm > 1000: + print(f"Exploding gradient in {name}") + +# Step 4: Add gradient clipping +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# Step 5: Lower learning rate +# LR too high can cause NaN + +# Step 6: Check loss computation +# Add assertions in custom loss: +def custom_loss(pred, target): + loss = compute_loss(pred, target) + assert not torch.isnan(loss), f"Loss is NaN, pred range: [{pred.min()}, {pred.max()}]" + assert not torch.isinf(loss), f"Loss is inf" + return loss +``` + +### When Loss Not Decreasing + +```python +# Systematic debugging checklist: + +# 1. Check loss value +print(f"Loss: {loss.item()}") +# - Is it reasonable? (CE should be ~ln(num_classes) initially) +# - Is it constant? (optimizer not stepping) +# - Is it very high? (wrong scale) + +# 2. Check gradients +for name, param in model.named_parameters(): + if param.grad is not None: + print(f"{name} grad: mean={param.grad.abs().mean():.6f}, max={param.grad.abs().max():.6f}") + +# If all gradients ~ 0: +# → Vanishing gradients (check activation functions, initialization) +# If gradients very large (>10): +# → Exploding gradients (add gradient clipping, lower LR) +# If no gradients printed: +# → Missing loss.backward() or parameters not requiring grad + +# 3. Check predictions +print(f"Pred range: [{pred.min():.4f}, {pred.max():.4f}]") +print(f"Target range: [{target.min():.4f}, {target.max():.4f}]") +print(f"Pred mean: {pred.mean():.4f}, Target mean: {target.mean():.4f}") + +# If predictions are constant: +# → Model not learning (check optimizer.step(), zero_grad()) +# If predictions are random: +# → Model learning but task too hard or wrong loss +# If pred/target ranges very different: +# → Normalization issue + +# 4. Verify training setup +print(f"Model training mode: {model.training}") # Should be True +print(f"Requires grad: {next(model.parameters()).requires_grad}") # Should be True + +# Check optimizer.zero_grad() is called +# Check loss.backward() is called +# Check optimizer.step() is called + +# 5. Check learning rate +print(f"Learning rate: {optimizer.param_groups[0]['lr']}") +# Too low (< 1e-6): Won't learn +# Too high (> 1e-2): Unstable + +# 6. Verify loss function matches task +# Classification → CrossEntropyLoss +# Regression → MSELoss or SmoothL1Loss +# Binary classification → BCEWithLogitsLoss + +# 7. Check data +# Visualize a batch: +print(f"Batch input shape: {x.shape}") +print(f"Batch target shape: {target.shape}") +print(f"Target unique values: {target.unique()}") + +# Are labels correct? +# Is data normalized? +# Any NaN in data? + +# 8. Overfit single batch +# Can model fit one batch perfectly? +single_x, single_y = next(iter(train_loader)) + +for i in range(1000): + optimizer.zero_grad() + pred = model(single_x) + loss = criterion(pred, single_y) + loss.backward() + optimizer.step() + + if i % 100 == 0: + print(f"Step {i}: Loss = {loss.item():.4f}") + +# If can't overfit single batch: +# → Model architecture issue +# → Loss function wrong +# → Bug in training loop +``` + +### When Loss Stuck at Same Value + +```python +# Scenario: Loss stays at 0.693 for binary classification (ln(2)) + +# Diagnosis: Model predicting 0.5 probability for all samples + +# Possible causes: + +# 1. Learning rate too low +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # Try 1e-3, 1e-4 + +# 2. Dead neurons (all ReLU outputs are 0) +# Check activations: +activations = model.fc1(x) +print(f"Activations: {activations.abs().mean()}") +if activations.abs().mean() < 0.01: + print("Dead neurons! Try:") + print("- Different initialization") + print("- LeakyReLU instead of ReLU") + print("- Lower learning rate") + +# 3. Gradient flow blocked +# Check each layer: +for name, param in model.named_parameters(): + if param.grad is not None: + print(f"{name}: {param.grad.abs().mean():.6f}") + else: + print(f"{name}: NO GRADIENT!") + +# 4. Wrong optimizer state (if resuming training) +# Solution: Create fresh optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# 5. Model too simple for task +# Try: Larger model, more layers, more parameters + +# 6. Task is actually random +# Check: Can humans solve this task? +# Check: Is there signal in the data? +``` + +### When Loss Oscillating / Unstable + +```python +# Scenario: Loss jumps around: 0.5 → 2.0 → 0.3 → 5.0 → ... + +# Diagnosis: Unstable training + +# Possible causes: + +# 1. Learning rate too high +# Solution: Lower LR by 10x +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Down from 1e-3 + +# 2. Batch size too small +# Solution: Increase batch size (more stable gradients) +train_loader = DataLoader(dataset, batch_size=64) # Up from 32 + +# 3. No gradient clipping +# Solution: Clip gradients +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + +# 4. Numerical instability in loss +# Solution: Use stable loss functions +# BCEWithLogitsLoss instead of BCE +# Add epsilon to custom losses + +# 5. Data outliers +# Solution: +# - Remove outliers +# - Use robust loss (L1, SmoothL1, Huber) +# - Clip targets to reasonable range + +# 6. Exploding gradients +# Check and clip: +total_norm = 0 +for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.norm().item() ** 2 +total_norm = total_norm ** 0.5 +print(f"Gradient norm: {total_norm}") + +if total_norm > 10: + print("Exploding gradients! Add gradient clipping.") +``` + + +## Rationalization Prevention Table + +| Rationalization | Why It's Wrong | What You Must Do | +|----------------|----------------|------------------| +| "BCE is simpler than BCEWithLogitsLoss" | BCE is numerically unstable, causes NaN | **ALWAYS use BCEWithLogitsLoss**. Non-negotiable. | +| "Loss weighting is just extra hyperparameter tuning" | Unweighted multi-task losses fail completely | **Check loss scales, weight them**. One task will dominate otherwise. | +| "The optimizer will figure out the scale differences" | Optimizers don't balance losses, they follow gradients | **Manual balance required**. SGD sees gradient magnitude, not task importance. | +| "95% accuracy is great!" | With 95-5 imbalance, this is trivial predictor | **Check F1/precision/recall**. Accuracy misleading for imbalanced data. | +| "Data is clean, no need for epsilon" | Even clean data can hit edge cases (empty masks, zeros) | **Add epsilon anyway**. Cost is negligible, prevents NaN. | +| "Softmax before CE makes output clearer" | CE applies softmax internally, this causes double softmax | **Pass logits to CE**. Never apply softmax first. | +| "One-hot encoding is more standard" | CrossEntropyLoss expects class indices, not one-hot | **Use class indices**. Shape must be (batch,) not (batch, C). | +| "Reduction parameter is optional" | Controls how loss aggregates, affects training dynamics | **Understand and choose**: mean (default), sum (rare), none (per-sample). | +| "Just lower LR to fix NaN" | NaN usually from numerical instability, not LR | **Fix root cause first**: epsilon, clipping, stable loss. Then adjust LR. | +| "Papers use different loss, I should too" | Papers don't always use optimal losses, context matters | **Evaluate if appropriate** for your data/task. Don't blindly copy. | +| "Custom loss is more flexible" | Built-in losses are optimized and tested | **Use built-ins when possible**. Only custom when necessary. | +| "Loss function doesn't matter much" | Loss is THE OBJECTIVE your model optimizes | **Loss choice is critical**. Wrong loss = optimizing wrong thing. | +| "I'll tune loss later" | Loss should match task from the start | **Choose correct loss immediately**. Tuning won't fix fundamentally wrong loss. | +| "Focal loss is always better for imbalance" | Focal loss has hyperparameters, can hurt if tuned wrong | **Try class weights first** (simpler, fewer hyperparameters). | +| "Division by zero won't happen in practice" | Edge cases happen: empty batches, all-zero masks | **Defensive programming**: always add epsilon to denominators. | + + +## Red Flags Checklist + +When reviewing loss function code, watch for these RED FLAGS: + +### Critical (Fix Immediately): + +- [ ] Using `F.binary_cross_entropy` instead of `F.binary_cross_entropy_with_logits` +- [ ] Applying `sigmoid` or `softmax` before stable loss (BCEWithLogitsLoss, CrossEntropyLoss) +- [ ] Division without epsilon: `x / y` instead of `x / (y + 1e-8)` +- [ ] Log without clamping: `torch.log(x)` instead of `torch.log(torch.clamp(x, min=1e-8))` +- [ ] Missing `optimizer.zero_grad()` in training loop +- [ ] Multi-task losses added without weighting (different scales) +- [ ] Loss goes to NaN during training + +### Important (Fix Soon): + +- [ ] Class imbalance ignored (no `weight` or `pos_weight` parameter) +- [ ] Regression targets not normalized (huge loss values) +- [ ] Wrong target shape for CrossEntropyLoss (one-hot instead of indices) +- [ ] Custom loss without numerical stability checks +- [ ] Using accuracy metric for highly imbalanced data +- [ ] No gradient clipping for RNNs/Transformers +- [ ] Reduction not specified in custom loss + +### Best Practices (Improve): + +- [ ] No label smoothing for classification (consider adding) +- [ ] No focal loss for extreme imbalance (>100:1 ratio) +- [ ] Not monitoring individual task losses in multi-task learning +- [ ] Not checking gradient norms during training +- [ ] No assertions in custom loss for debugging +- [ ] Not testing loss function on toy data first + + +## Summary: Loss Function Selection Flowchart + +``` +START + | + ├─ Binary Classification? + | → BCEWithLogitsLoss + pos_weight for imbalance + | + ├─ Multi-Class Classification? + | → CrossEntropyLoss + class weights for imbalance + | → Consider Focal Loss if extreme imbalance (>100:1) + | + ├─ Multi-Label Classification? + | → BCEWithLogitsLoss + per-class pos_weight + | + ├─ Regression? + | → SmoothL1Loss (good default) + | → MSELoss if no outliers + | → L1Loss if robust to outliers needed + | → ALWAYS normalize targets! + | + ├─ Segmentation? + | → BCEWithLogitsLoss + DiceLoss (combine both) + | → Consider Focal Loss for imbalanced pixels + | + ├─ Ranking/Similarity? + | → TripletMarginLoss or ContrastiveLoss + | + └─ Multi-Task? + → Combine with careful weighting + → Start with manual balance (check scales!) + → Consider uncertainty weighting + +ALWAYS: + ✅ Use logits (no sigmoid/softmax before stable losses) + ✅ Add epsilon to divisions and before log/sqrt + ✅ Check for class/label imbalance + ✅ Normalize regression targets + ✅ Monitor loss values and gradients + ✅ Test loss on toy data first + +NEVER: + ❌ Use BCE instead of BCEWithLogitsLoss + ❌ Apply softmax before CrossEntropyLoss + ❌ Ignore different scales in multi-task + ❌ Divide without epsilon + ❌ Trust accuracy alone for imbalanced data +``` + + +## Final Checklist Before Training + +Before starting training, verify: + +1. **Loss Function Matches Task:** + - [ ] Binary classification → BCEWithLogitsLoss + - [ ] Multi-class → CrossEntropyLoss + - [ ] Regression → SmoothL1Loss or MSE with normalized targets + +2. **Numerical Stability:** + - [ ] Using stable loss (BCEWithLogitsLoss, not BCE) + - [ ] Epsilon in divisions: `x / (y + 1e-8)` + - [ ] Clamp before log: `torch.log(torch.clamp(x, min=1e-8))` + +3. **Class Imbalance Handled:** + - [ ] Checked class distribution + - [ ] Added `weight` or `pos_weight` if imbalanced + - [ ] Using F1/precision/recall metrics, not just accuracy + +4. **Multi-Task Weighting:** + - [ ] Checked loss scales (printed first batch) + - [ ] Added manual weights or uncertainty weighting + - [ ] Monitoring individual task metrics + +5. **Target Preparation:** + - [ ] CrossEntropyLoss: targets are class indices (batch,) + - [ ] BCEWithLogitsLoss: targets are 0/1 floats + - [ ] Regression: targets normalized to similar scale as predictions + +6. **Training Loop:** + - [ ] `optimizer.zero_grad()` before backward + - [ ] `loss.backward()` to compute gradients + - [ ] `optimizer.step()` to update parameters + - [ ] Gradient clipping if using RNN/Transformer + +7. **Debugging Setup:** + - [ ] Can print loss value: `loss.item()` + - [ ] Can print gradient norms + - [ ] Can visualize predictions vs targets + - [ ] Have tested overfitting single batch + + +## When to Seek Help + +If after following this skill you still have loss issues: + +1. **Loss is NaN:** + - Checked all numerical stability issues? + - Added epsilon everywhere? + - Tried gradient clipping? + - Lowered learning rate? + → If still NaN, may need architecture change or data investigation + +2. **Loss not decreasing:** + - Verified training loop is correct (zero_grad, backward, step)? + - Checked gradients are flowing? + - Tried overfitting single batch? + - Verified loss function matches task? + → If still not decreasing, may be model capacity or data issue + +3. **Loss decreasing but metrics poor:** + - Is loss the right objective for your metric? + - Example: CE minimizes NLL, not accuracy + - Consider metric-aware loss or post-hoc calibration + +4. **Multi-task learning not working:** + - Tried multiple weighting strategies? + - Monitored individual task losses? + - Ensured all tasks getting gradient signal? + → May need task-specific heads or curriculum learning + +Remember: Loss function is the heart of deep learning. Get this right first before tuning everything else. + + +## Additional Resources + +**Key Papers:** +- Focal Loss: "Focal Loss for Dense Object Detection" (Lin et al., 2017) +- Label Smoothing: "Rethinking the Inception Architecture" (Szegedy et al., 2016) +- Uncertainty Weighting: "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018) +- Class-Balanced Loss: "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019) + +**PyTorch Documentation:** +- Loss Functions: https://pytorch.org/docs/stable/nn.html#loss-functions +- Numerical Stability: Use built-in combined operations (BCEWithLogitsLoss, etc.) + +**Common Loss Functions Quick Reference:** +```python +# Classification +F.binary_cross_entropy_with_logits(logits, target, pos_weight=...) +F.cross_entropy(logits, target, weight=..., label_smoothing=...) +F.nll_loss(log_probs, target) # If you already have log_probs + +# Regression +F.mse_loss(pred, target) +F.l1_loss(pred, target) +F.smooth_l1_loss(pred, target, beta=1.0) +F.huber_loss(pred, target, delta=1.0) # PyTorch 1.10+ + +# Ranking +F.margin_ranking_loss(input1, input2, target, margin=0.0) +F.triplet_margin_loss(anchor, positive, negative, margin=1.0) +F.cosine_embedding_loss(input1, input2, target) + +# Other +F.kl_div(log_probs, target_probs) # KL divergence +F.poisson_nll_loss(log_input, target) # Poisson regression +``` + + +**END OF SKILL** + +When you use this skill, you become an expert in loss function selection and implementation. You will: +- Choose the correct loss for any task +- Ensure numerical stability +- Handle class imbalance appropriately +- Weight multi-task losses correctly +- Debug loss issues systematically +- Avoid all common loss function pitfalls + +Remember: The loss function IS your model's objective. Get this right, and everything else follows. diff --git a/skills/using-training-optimization/optimization-algorithms.md b/skills/using-training-optimization/optimization-algorithms.md new file mode 100644 index 0000000..0ef0c06 --- /dev/null +++ b/skills/using-training-optimization/optimization-algorithms.md @@ -0,0 +1,1832 @@ + +# Optimization Algorithms + +## Overview + +This skill provides systematic guidance for selecting and configuring neural network optimizers. There is NO single "best" optimizer—the choice depends on your task, model architecture, batch size, and performance goals. This skill teaches you how to make informed optimizer decisions and avoid common pitfalls like using Adam with weight decay (use AdamW instead). + +**Core Principle**: Optimizer selection is a decision, not a default. Different optimizers have different convergence properties, final performance characteristics, and hyperparameter requirements. Use systematic decision frameworks, not cargo cult defaults. + +**CRITICAL**: Adam and AdamW are NOT interchangeable. AdamW implements correct weight decay; Adam's weight_decay parameter is broken. Always use AdamW when you need weight decay regularization. + +## When to Use This Skill + +Load this skill when: +- Selecting optimizer for new training pipeline (SGD vs Adam vs AdamW vs others) +- Configuring optimizer hyperparameters (learning rate, momentum, betas, weight decay) +- Training not working with current optimizer (loss not decreasing, instability, NaN values) +- Comparing optimizer options for specific task (vision vs NLP vs RL) +- Understanding Adam vs AdamW difference (weight decay handling) +- Debugging optimizer-related issues (wrong LR range, incorrect parameters) +- Switching from one optimizer to another (requires re-tuning) +- Setting up distributed/large-batch training (LAMB, LARS considerations) +- Reproducing paper results with different optimizer + +**Don't use for**: Learning rate schedules (use learning-rate-scheduling), gradient issues (use gradient-management), general training debugging (use using-training-optimization to route), neural architecture selection (use neural-architectures) + + +## Optimizer Selection Decision Framework + +### The Core Question: "Which optimizer should I use?" + +**WRONG ANSWER**: "Use Adam/AdamW, it's the best." + +**RIGHT APPROACH**: Ask clarifying questions and use decision framework. + +### Clarifying Questions to Ask + +Before recommending an optimizer, ask: + +1. **"What type of model are you training?"** + - CNN/ResNet/ConvNet → SGD often better + - Transformer/BERT/GPT → AdamW standard + - RNN/LSTM → Adam/RMSprop + - Reinforcement learning → Adam (usually) + +2. **"What's your batch size and hardware setup?"** + - Large batch (>512) → SGD works well + - Small batch (<32) → Adam often better + - Distributed (>64 GPUs, batch >8K) → Consider LAMB/LARS + +3. **"What matters more: fast initial convergence or best final performance?"** + - Fast convergence → Adam/AdamW + - Best final performance → SGD (often, but slower) + - Balanced → Try both, tune properly + +4. **"Do you need weight decay regularization?"** + - Yes → AdamW (NOT Adam) + - No → Adam or SGD fine + +5. **"How much time do you have for hyperparameter tuning?"** + - Limited → Adam (more forgiving) + - Extensive → SGD (can achieve better results with tuning) + +### Decision Tree + +``` +START: Selecting optimizer for training + +├─ Are you training a TRANSFORMER (BERT, GPT, ViT, T5)? +│ ├─ YES → **AdamW** (99% of the time) +│ │ LR: 1e-4 to 5e-4, weight_decay: 0.01-0.1 +│ │ Betas: (0.9, 0.999) or (0.9, 0.98) for very long training +│ │ This is the modern standard. +│ │ +│ └─ NO → Continue... +│ +├─ Are you training a CNN for VISION (ResNet, EfficientNet, ConvNeXt)? +│ ├─ YES → **SGD with Nesterov momentum** (recommended) +│ │ LR: 0.1 with cosine decay, momentum: 0.9, weight_decay: 1e-4 to 5e-4 +│ │ nesterov=True +│ │ Better final performance than Adam for vision tasks. +│ │ Alternative: AdamW if training time is limited or batch size is small. +│ │ +│ └─ NO → Continue... +│ +├─ Are you training a VISION TRANSFORMER (ViT, Swin, DeiT)? +│ ├─ YES → **AdamW** +│ │ LR: 1e-3 to 5e-4, weight_decay: 0.05-0.1 +│ │ Vision transformers follow transformer best practices. +│ │ +│ └─ NO → Continue... +│ +├─ Are you training an RNN or LSTM? +│ ├─ YES → **Adam** or **RMSprop** +│ │ Adam LR: 1e-3 to 3e-4 +│ │ RMSprop LR: 1e-3 (historical choice, less common now) +│ │ Adam more common in modern work. +│ │ +│ └─ NO → Continue... +│ +├─ Are you training a REINFORCEMENT LEARNING policy? +│ ├─ YES → **Adam** +│ │ LR: 3e-4 (standard in RL) +│ │ Betas: (0.9, 0.999) +│ │ weight_decay: 0 (usually) +│ │ +│ └─ NO → Continue... +│ +├─ Are you doing LARGE-BATCH distributed training (batch size > 8K)? +│ ├─ YES → Consider **LAMB** (for transformers) or **LARS** (for vision) +│ │ These are specialized optimizers for very large batch training. +│ │ Most users won't need these. +│ │ Still need linear LR scaling and warmup. +│ │ +│ └─ NO → Continue... +│ +├─ Do you just need a QUICK BASELINE? +│ ├─ YES → **Adam** or **AdamW** +│ │ LR: 1e-3 (starting point) +│ │ Fast initial convergence, easy to get started. +│ │ AdamW if you want weight decay. +│ │ +│ └─ NO → Continue... +│ +└─ DEFAULT: Start with **AdamW** + LR: 1e-3, weight_decay: 0.01 + Tune from there based on results. + If training vision and have time, try SGD for potentially better final performance. +``` + + +## Major Optimizers: Deep Dive + +### SGD (Stochastic Gradient Descent) + +**Algorithm**: Basic gradient descent with optional momentum. + +```python +optimizer = torch.optim.SGD( + params, + lr=0.1, # Learning rate (higher than Adam) + momentum=0.9, # Momentum coefficient + weight_decay=1e-4, # L2 regularization + nesterov=True # Use Nesterov momentum (recommended) +) +``` + +**When to Use:** +- ✅ Training CNNs (ResNet, EfficientNet, ConvNeXt) +- ✅ Large batch training (batch size > 512) +- ✅ When best final performance matters (often beats Adam) +- ✅ Training transformers (competitive with Adam when properly tuned) +- ✅ Have compute budget for longer training +- ✅ Classical computer vision tasks + +**When to Avoid:** +- ❌ Small batch training (batch size < 32) +- ❌ Very deep networks without good initialization +- ❌ Need fast initial progress (Adam converges faster early on) +- ❌ Sparse gradients (NLP with large vocab, embeddings) +- ❌ Limited time for hyperparameter tuning + +**Typical Hyperparameters:** +- **Learning rate**: 0.01 to 0.1 (with warmup and decay) + - Start with 0.1 for vision + - Use learning rate finder to find optimal range + - Always pair with LR scheduler (cosine, step decay) +- **Momentum**: 0.9 (standard) + - Higher (0.99) for very small batches or noisy gradients + - Lower (0.5-0.8) rarely used, but can help with instability +- **Weight decay**: 1e-4 to 5e-4 (for CNNs) + - 1e-4: Standard for many vision tasks + - 5e-4: More regularization, prevents overfitting +- **Nesterov**: True (almost always better than vanilla momentum) + +**Characteristics:** +- **Convergence speed**: Slow to medium +- **Final performance**: Excellent (often best) +- **Memory**: Low (only momentum buffer) +- **Sensitivity to LR**: High (needs careful tuning) +- **Generalization**: Often better than Adam + +**Why SGD Still Matters (2024):** +Despite being "old", SGD remains competitive: +- Often achieves better test accuracy than Adam on vision tasks +- More stable for very long training runs +- Better generalization (flatter minima) +- Standard in vision competitions and state-of-the-art models + +**Pro tip**: Don't dismiss SGD as "old-fashioned". Modern CNNs still achieve best results with SGD. + + +### Adam (Adaptive Moment Estimation) + +**Algorithm**: Adaptive learning rates with momentum for both first and second moments. + +```python +optimizer = torch.optim.Adam( + params, + lr=1e-3, # Learning rate (lower than SGD) + betas=(0.9, 0.999), # Coefficients for moving averages + eps=1e-8, # Numerical stability epsilon + weight_decay=0 # DO NOT USE - use AdamW instead! +) +``` + +**When to Use:** +- ✅ Quick baseline needed (fast initial convergence) +- ✅ Sparse gradients (NLP, embeddings, large vocab) +- ✅ Small batch training (batch size < 32) +- ✅ RNNs, LSTMs, attention models +- ✅ RL policy networks +- ✅ When you need results quickly without extensive tuning + +**When to Avoid:** +- ❌ When you need weight decay (use AdamW instead) +- ❌ Large batch training (consider LAMB/LARS for > 8K) +- ❌ When best generalization matters (SGD often better) +- ❌ Training vision models where SGD is known to be better + +**Typical Hyperparameters:** +- **Learning rate**: 1e-4 to 3e-3 + - Default: 1e-3 (good starting point) + - Transformers: 1e-4 to 5e-4 + - RNNs: 3e-4 to 1e-3 + - Much lower than SGD (10-100x) +- **Betas**: (0.9, 0.999) [standard] + - beta1: First moment momentum (mean) + - beta2: Second moment momentum (variance) + - Usually don't need to change + - Lower beta2 (0.98, 0.95) for very long training or instability +- **Epsilon**: 1e-8 (rarely need to change) + - Numerical stability term in denominator + - Increase to 1e-7 or 1e-6 if numerical issues +- **Weight decay**: **0** (DON'T USE - this is broken, use AdamW) + +**Characteristics:** +- **Convergence speed**: Fast (especially early training) +- **Final performance**: Good (but often not best) +- **Memory**: High (stores first and second moments) +- **Sensitivity to LR**: Medium (more forgiving than SGD) +- **Generalization**: Good (but SGD often better) + +**CRITICAL WARNING: Adam's Weight Decay is Broken** + +```python +# WRONG: Don't do this! +optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.01) +# This implements L2 penalty in the loss, which interacts incorrectly +# with adaptive learning rates. NOT true weight decay. + +# RIGHT: Use AdamW instead +optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01) +# AdamW implements decoupled weight decay (correct implementation) +``` + +**Why Adam is Popular:** +- Fast initial convergence (great for exploration) +- Works reasonably well out-of-the-box +- Handles sparse gradients well +- Less sensitive to learning rate than SGD +- Good for quick prototyping + +**When Adam Fails:** +- Final performance often lags behind well-tuned SGD (especially vision) +- Can be unstable in some settings +- Weight decay doesn't work correctly (use AdamW) + + +### AdamW (Adam with Decoupled Weight Decay) + +**Algorithm**: Adam with correct weight decay implementation. + +```python +optimizer = torch.optim.AdamW( + params, + lr=1e-3, # Learning rate + betas=(0.9, 0.999), # Coefficients for moving averages + eps=1e-8, # Numerical stability + weight_decay=0.01 # NOW THIS ACTUALLY WORKS! +) +``` + +**When to Use:** +- ✅ Training transformers (BERT, GPT, T5, ViT) - STANDARD CHOICE +- ✅ When you need weight decay regularization +- ✅ Modern vision transformers +- ✅ Most deep learning tasks (2020+ default) +- ✅ Whenever you would use Adam + need regularization + +**When to Avoid:** +- ❌ When weight decay is not needed (Adam is fine, slightly faster) +- ❌ Vision CNNs where SGD is known to work better (but AdamW is reasonable alternative) + +**Typical Hyperparameters:** +- **Learning rate**: 1e-4 to 5e-4 (transformers), 1e-3 (general) + - BERT/GPT: 1e-4 to 5e-4 + - Vision transformers: 1e-3 to 5e-4 + - Same range as Adam +- **Betas**: (0.9, 0.999) or (0.9, 0.98) for transformers + - Lower beta2 for very long training runs + - (0.9, 0.95) seen in some long transformer training +- **Weight decay**: 0.01 to 0.1 + - CNNs: 1e-4 to 5e-4 (if using AdamW for vision) + - Transformers: 0.01 to 0.1 (much higher!) + - This parameter NOW ACTUALLY WORKS (unlike Adam) + - Tune as hyperparameter +- **Epsilon**: 1e-8 (standard) + +**Characteristics:** +- **Convergence speed**: Fast (same as Adam) +- **Final performance**: Good to Excellent +- **Memory**: High (same as Adam) +- **Sensitivity to LR**: Medium (same as Adam) +- **Generalization**: Better than Adam (due to correct weight decay) + +**Why AdamW > Adam:** + +The key difference is how weight decay is applied: + +**Adam (WRONG):** +``` +# Pseudocode +gradient = compute_gradient(loss) +gradient += weight_decay * param # L2 penalty added to gradient +# Then adaptive LR applied → weight decay gets scaled by adaptive LR +# This breaks the regularization effect! +``` + +**AdamW (RIGHT):** +``` +# Pseudocode +gradient = compute_gradient(loss) +# Adaptive LR applied to gradient +param_update = adaptive_lr(gradient) +# Weight decay applied AFTER, directly to parameters +param = param - lr * param_update - weight_decay * param +# Weight decay is decoupled from gradient → works correctly! +``` + +**Paper Reference**: "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, ICLR 2019) + +**Modern Best Practice (2024):** +- Default to AdamW for transformers (not Adam) +- AdamW is the standard in modern NLP and vision transformers +- Use weight_decay=0.01 as starting point +- Adam only when weight decay not needed + +**Migration from Adam:** +```python +# If you have this: +optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.01) + +# Change to this: +optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01) + +# Everything else stays the same! +# But now weight decay actually works correctly. +``` + + +### RMSprop (Root Mean Square Propagation) + +**Algorithm**: Adaptive learning rate based on moving average of squared gradients. + +```python +optimizer = torch.optim.RMSprop( + params, + lr=1e-3, # Learning rate + alpha=0.99, # Smoothing constant + eps=1e-8, # Numerical stability + weight_decay=0, # L2 regularization + momentum=0 # Optional momentum +) +``` + +**When to Use:** +- ✅ Training RNNs (historically popular choice) +- ✅ Non-stationary objectives (reinforcement learning) +- ✅ When Adam doesn't work well (rare) + +**When to Avoid:** +- ❌ Most modern tasks (Adam/AdamW have largely superseded it) +- ❌ Transformers (use AdamW) +- ❌ CNNs (use SGD) + +**Typical Hyperparameters:** +- **Learning rate**: 1e-3 to 1e-4 +- **Alpha**: 0.99 (standard, controls exponential moving average decay) +- **Momentum**: 0 (usually not used) + +**Characteristics:** +- **Convergence speed**: Fast +- **Final performance**: Good +- **Memory**: Medium +- **Sensitivity to LR**: Medium + +**Historical Note:** +RMSprop was popular for RNNs before Adam became standard. Adam can be seen as RMSprop + momentum. Most use cases now covered by Adam/AdamW. + + +### AdaGrad (Adaptive Gradient) + +**Algorithm**: Adapts learning rate based on historical gradient magnitude (accumulates squared gradients). + +```python +optimizer = torch.optim.Adagrad( + params, + lr=1e-2, # Learning rate (higher than Adam) + lr_decay=0, # Learning rate decay + weight_decay=0, # L2 regularization + eps=1e-10 # Numerical stability +) +``` + +**When to Use:** +- ✅ Sparse features (extremely sparse gradients) +- ✅ NLP with very large vocabularies (legacy) +- ✅ When features have very different scales/frequencies + +**When to Avoid:** +- ❌ Most modern tasks (Adam/AdamW are better) +- ❌ Non-sparse problems +- ❌ Deep learning (learning rate decays too aggressively) + +**Characteristics:** +- **Convergence speed**: Medium +- **Final performance**: Good for sparse problems +- **Memory**: Medium +- **Learning rate behavior**: Continuously decreases (can be too aggressive) + +**Historical Note:** +AdaGrad was innovative for sparse problems but has issues with aggressive learning rate decay. Adam fixes this with exponential moving averages instead of accumulation. + + +### Advanced Optimizers (Specialized Use Cases) + +#### LAMB (Layer-wise Adaptive Moments optimizer for Batch training) + +**When to use**: Very large batch training (> 8K) for transformers + +**Example use case**: BERT pretraining with batch size 32K-64K + +```python +# Not in PyTorch by default, need external library +# from apex.optimizers import FusedLAMB + +optimizer = FusedLAMB( + params, + lr=1e-3, + betas=(0.9, 0.999), + weight_decay=0.01 +) +``` + +**Why LAMB:** +- Enables very large batch training without degradation +- Layer-wise adaptation prevents issues with different gradient scales +- Used in BERT large-scale pretraining + +**Note**: Most users don't need LAMB. Only for distributed training with very large batches. + + +#### LARS (Layer-wise Adaptive Rate Scaling) + +**When to use**: Very large batch training (> 8K) for vision models + +**Example use case**: ResNet training with batch size 32K + +```python +# Not in PyTorch by default, need external library +# Similar interface to LAMB but designed for vision +``` + +**Why LARS:** +- Enables large batch training for CNNs +- Layer-wise learning rates prevent convergence issues +- Used in large-scale vision training + +**Note**: Most users don't need LARS. Only for distributed vision training with very large batches. + + +#### Lookahead + +**When to use**: Wrap any optimizer for more stable convergence + +**How it works**: Maintains slow and fast weights, periodically synchronizes + +```python +# Wrapper around another optimizer +from torch_optimizer import Lookahead + +base_optimizer = torch.optim.Adam(params, lr=1e-3) +optimizer = Lookahead(base_optimizer, k=5, alpha=0.5) +``` + +**Why Lookahead:** +- More stable convergence +- Can improve final performance +- Works with any base optimizer + +**Note**: Adds complexity and computation. Try standard optimizers first. + + +## Hyperparameter Deep Dive + +### Learning Rate (THE Most Important Hyperparameter) + +**Effect of Learning Rate:** + +``` +LR too high → Training unstable, loss oscillates, divergence, NaN +LR optimal → Smooth loss decrease, good convergence, best performance +LR too low → Very slow convergence, stuck in local minima, wasted time +``` + +**Learning Rate Ranges by Optimizer:** + +| Optimizer | Typical LR Range | Starting Point | +|-----------|-----------------|----------------| +| SGD | 0.01 - 0.1 | 0.1 (with decay) | +| SGD (small batch) | 0.001 - 0.01 | 0.01 | +| Adam | 1e-4 - 3e-3 | 1e-3 | +| AdamW | 1e-4 - 3e-3 | 1e-3 | +| AdamW (transformers) | 1e-4 - 5e-4 | 3e-4 | +| RMSprop | 1e-4 - 1e-3 | 1e-3 | + +**CRITICAL**: SGD needs 10-100x higher learning rate than Adam/AdamW! + +**Learning Rate Tuning Strategy:** + +1. **Start with defaults**: + - SGD: 0.1 + - Adam/AdamW: 1e-3 + +2. **Use learning rate finder**: + ```python + # Increase LR exponentially, plot loss + # Choose LR just before loss minimum + # This finds the "sweet spot" + ``` + +3. **Monitor training curves**: + - Loss oscillating wildly → LR too high, reduce by 3-10x + - Loss decreasing very slowly → LR might be too low, increase by 2-3x + - Loss smooth and decreasing → LR about right + +4. **Use learning rate scheduler**: + - Cosine annealing (modern default) + - Step decay (traditional) + - Reduce on plateau (automatic) + - See learning-rate-scheduling skill for details + +**Common LR Mistakes:** + +```python +# MISTAKE 1: Using SGD LR with Adam +optimizer = torch.optim.Adam(params, lr=0.1) # WAY TOO HIGH +# Fix: +optimizer = torch.optim.Adam(params, lr=1e-3) # Correct range + +# MISTAKE 2: Using Adam LR with SGD +optimizer = torch.optim.SGD(params, lr=1e-3) # Too low +# Fix: +optimizer = torch.optim.SGD(params, lr=0.1) # Correct range + +# MISTAKE 3: Same LR when switching optimizers +# Was using SGD with lr=0.1 +# Switch to Adam, but keep lr=0.1 → WRONG, will diverge +# Must re-tune LR for each optimizer +``` + + +### Momentum (SGD-specific) + +**What momentum does:** +Accumulates exponentially weighted moving average of gradients, helping accelerate in relevant directions and dampen oscillations. + +**Effect of Momentum Values:** + +``` +momentum = 0.0 → Vanilla SGD (noisy updates, slow convergence) +momentum = 0.9 → Smooth updates, faster convergence (STANDARD) +momentum = 0.99 → Very smooth, good for noisy/small batch, can overshoot +``` + +**Best Practices:** + +- **Default**: Start with 0.9 (works for most cases) +- **Small batch / noisy gradients**: Increase to 0.95 or 0.99 +- **Very large batch**: 0.9 is fine, sometimes lower (0.85) +- **Debugging**: Set to 0 to see if momentum is causing issues (rare) + +**Nesterov Momentum:** + +```python +# Vanilla momentum (standard) +optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, nesterov=False) + +# Nesterov momentum (RECOMMENDED) +optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, nesterov=True) +``` + +**Why Nesterov is better:** +- Looks ahead before computing gradient (better gradient estimate) +- Often converges faster and to better solution +- Minimal cost, easy win +- Standard in modern vision training + +**Pro tip**: Always use `nesterov=True` with SGD unless you have a specific reason not to. + + +### Betas (Adam/AdamW-specific) + +**What betas control:** + +- **beta1**: Exponential decay rate for first moment (mean of gradients) +- **beta2**: Exponential decay rate for second moment (variance of gradients) + +**Standard Values: (0.9, 0.999)** + +```python +optimizer = torch.optim.AdamW( + params, + lr=1e-3, + betas=(0.9, 0.999) # (beta1, beta2) +) +``` + +**Effect of betas:** + +``` +beta1 (first moment momentum): + Higher (0.9-0.99) → Smoother gradient estimates + Lower (0.5-0.8) → More responsive to current gradient (rare) + +beta2 (second moment): + 0.999 → Standard, stable for most training + 0.98 → More responsive, better for transformers (long training) + 0.95 → Very responsive, for very long training runs +``` + +**When to Adjust Betas:** + +**1. Training Instability:** +```python +# If training is unstable with (0.9, 0.999) +betas = (0.9, 0.98) # Lower beta2 for more stability +``` + +**2. Very Long Training (> 100K steps):** +```python +# For very long transformer training +betas = (0.9, 0.95) # Lower beta2 prevents too much smoothing +``` + +**3. Transformer-specific:** +```python +# Many transformer papers use +betas = (0.9, 0.98) # or (0.9, 0.999) +# Both work, tune if needed +``` + +**When NOT to Adjust:** +- If training is working well → don't change +- Most tasks → (0.9, 0.999) is fine +- Don't cargo-cult different values without understanding + +**Pro tip**: Start with (0.9, 0.999). Only adjust if you have training instability or following proven transformer recipes. + + +### Weight Decay + +**What weight decay does:** +Shrinks weights toward zero each step, preventing overfitting and encouraging simpler models. + +**CRITICAL: Adam vs AdamW Weight Decay Difference** + +```python +# ❌ WRONG: Adam with weight_decay +optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.01) +# This adds L2 penalty to loss, which interacts incorrectly with +# adaptive learning rates. NOT true weight decay. DON'T USE. + +# ✅ RIGHT: AdamW with weight_decay +optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01) +# This implements decoupled weight decay (applied directly to params) +# Works correctly with adaptive learning rates. USE THIS. +``` + +**Weight Decay Values by Task:** + +| Task/Model | Weight Decay Range | Typical Value | +|-----------|-------------------|---------------| +| CNNs (ResNet, etc.) | 1e-4 - 5e-4 | 1e-4 | +| Vision Transformers | 0.05 - 0.1 | 0.05 | +| Language Transformers (BERT, GPT) | 0.01 - 0.1 | 0.01 | +| Small models (< 10M params) | 0 - 1e-4 | 1e-5 | +| RL policies | 0 | 0 | + +**Effect of Weight Decay:** + +``` +weight_decay = 0 → No regularization (may overfit) +weight_decay = 1e-4 → Light regularization (CNNs) +weight_decay = 0.01 → Medium regularization (transformers) +weight_decay = 0.1 → Strong regularization (may underfit) +``` + +**Signs of Incorrect Weight Decay:** + +``` +Too high (overfitting): + → Training loss doesn't decrease well + → Model underfits + → Slow convergence + → Poor training AND validation accuracy + +Too low (underfitting): + → Large gap between train and validation accuracy + → Model overfits training data + → Good training accuracy, poor validation +``` + +**Best Practices:** + +1. **Use AdamW when you need weight decay** (not Adam) +2. **Start with task-appropriate defaults**: + - CNNs: 1e-4 + - Transformers: 0.01 +3. **Tune as hyperparameter** (search over [1e-5, 1e-4, 1e-3, 0.01, 0.1]) +4. **Monitor train/val gap** to adjust + + +### Epsilon (Adam/AdamW) + +**What epsilon does:** +Small constant added to denominator for numerical stability (prevents division by zero). + +```python +optimizer = torch.optim.AdamW( + params, + lr=1e-3, + eps=1e-8 # Numerical stability term +) +``` + +**Default: 1e-8 (almost always fine)** + +**When to Change Epsilon:** + +``` +Numerical instability (very rare): + → Gradients becoming NaN + → Very small gradients (< 1e-8) + → Increase to 1e-7 or 1e-6 + +Half precision training (FP16): + → May need larger epsilon (1e-7 or 1e-6) + → FP16 has less numerical precision + +Normal training: + → Keep default 1e-8 +``` + +**Pro tip**: Don't change epsilon unless you have numerical stability issues. This is a rare adjustment. + + +## Optimizer Comparison Table + +| Optimizer | Convergence Speed | Final Performance | Memory Usage | LR Range | Best For | +|-----------|------------------|-------------------|--------------|----------|----------| +| **SGD** | Slow-Medium | ★★★★★ Excellent | Low | 0.01-0.1 | Vision (CNNs), large batch | +| **SGD+Momentum** | Medium | ★★★★★ Excellent | Low | 0.01-0.1 | Vision, standard choice | +| **SGD+Nesterov** | Medium | ★★★★★ Excellent | Low | 0.01-0.1 | Vision, best SGD variant | +| **Adam** | ★★★★★ Fast | ★★★★☆ Good | High | 1e-4 to 3e-3 | Quick baselines, RNNs, small batch | +| **AdamW** | ★★★★★ Fast | ★★★★★ Good-Excellent | High | 1e-4 to 3e-3 | Transformers, modern default | +| **RMSprop** | Fast | ★★★★☆ Good | Medium | 1e-4 to 1e-3 | RNNs (legacy), RL | +| **AdaGrad** | Medium | ★★★☆☆ Good (sparse) | Medium | 1e-2 | Sparse features (legacy) | +| **LAMB** | Fast | ★★★★★ Excellent | High | 1e-3 | Large-batch transformers (>8K) | +| **LARS** | Fast | ★★★★★ Excellent | High | 1e-3 | Large-batch vision (>8K) | + + +## Modern Best Practices (2024) + +### Vision - CNNs (ResNet, EfficientNet, ConvNeXt) + +```python +optimizer = torch.optim.SGD( + model.parameters(), + lr=0.1, # Start high, use scheduler + momentum=0.9, + weight_decay=1e-4, # Standard for vision + nesterov=True # Always use Nesterov +) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=epochs +) + +# Batch size: 256-512 (scale LR linearly with batch size) +# Training time: 100-300 epochs +# Best final performance: SGD with proper tuning +``` + +**Why this works:** +- SGD achieves best test accuracy for vision +- Nesterov momentum improves convergence +- Cosine annealing standard in modern vision +- Weight decay 1e-4 is well-tuned default + + +### Vision Transformers (ViT, Swin, DeiT) + +```python +optimizer = torch.optim.AdamW( + model.parameters(), + lr=1e-3, # Higher than language transformers + betas=(0.9, 0.999), + weight_decay=0.05 # Higher than CNNs +) + +# Use warmup + cosine schedule +scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=10000, + num_training_steps=total_steps +) + +# Batch size: 512-4096 (often very large with gradient accumulation) +# Training time: 300 epochs typical +``` + +**Why this works:** +- Vision transformers follow transformer best practices +- AdamW standard for transformers +- Higher weight decay than CNNs (0.05 vs 1e-4) +- Large batch sizes typical + + +### Language Models (BERT, GPT, T5, LLaMA) + +```python +optimizer = torch.optim.AdamW( + model.parameters(), + lr=5e-4, # Lower than vision transformers + betas=(0.9, 0.98), # Lower beta2 for long training + weight_decay=0.01, + eps=1e-8 +) + +# Warmup crucial for transformers +scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=10000, # 10% of total steps typical + num_training_steps=total_steps +) + +# Batch size: Large (2048-4096+, often with gradient accumulation) +# Training time: 100K-1M+ steps +``` + +**Why this works:** +- AdamW is standard for all modern transformers +- Lower beta2 (0.98) for stability in very long training +- Warmup critical for transformer convergence +- Weight decay 0.01-0.1 range + + +### RNNs / LSTMs (Less Common Now) + +```python +optimizer = torch.optim.Adam( + model.parameters(), + lr=1e-3, + betas=(0.9, 0.999) +) + +# Gradient clipping usually needed for RNNs +# See gradient-management skill +``` + +**Why this works:** +- Adam handles RNN gradient issues better than SGD +- RNNs less common now (transformers dominate) + + +### Reinforcement Learning Policies + +```python +optimizer = torch.optim.Adam( + policy.parameters(), + lr=3e-4, # Standard in RL (very robust) + betas=(0.9, 0.999), + weight_decay=0 # Usually no weight decay in RL +) + +# Learning rate 3e-4 is remarkably robust across RL algorithms +``` + +**Why this works:** +- 3e-4 is empirically robust default for RL +- Adam handles noisy RL gradients well +- No weight decay (policies often benefit from flexibility) + + +## Common Optimizer Pitfalls + +### Pitfall 1: Using Adam Instead of AdamW for Weight Decay + +```python +# ❌ WRONG +optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.01) +# Adam's weight decay is broken - adds L2 to loss, doesn't work with adaptive LR + +# ✅ RIGHT +optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01) +# AdamW implements correct decoupled weight decay +``` + +**Why this is critical:** +- Adam's weight_decay doesn't work as intended +- AdamW fixes this with decoupled weight decay +- Modern transformers ALL use AdamW, not Adam +- Paper: Loshchilov & Hutter, "Decoupled Weight Decay Regularization" + +**Red flag**: Any recommendation to use `torch.optim.Adam` with `weight_decay > 0`. + + +### Pitfall 2: Same Learning Rate for Different Optimizers + +```python +# ❌ WRONG: Switching optimizer without changing LR +# Was using: +optimizer = torch.optim.SGD(params, lr=0.1) +# Switch to: +optimizer = torch.optim.Adam(params, lr=0.1) # WILL DIVERGE + +# ✅ RIGHT: Different LR ranges for different optimizers +optimizer = torch.optim.SGD(params, lr=0.1) # SGD: 0.01-0.1 +optimizer = torch.optim.Adam(params, lr=1e-3) # Adam: 1e-4 to 3e-3 +``` + +**Why this happens:** +- SGD needs 10-100x higher learning rate than Adam +- Adaptive LR in Adam means smaller nominal LR +- Switching optimizers requires re-tuning ALL hyperparameters + +**Red flag**: Using lr=0.1 with Adam or lr=1e-3 with SGD. + + +### Pitfall 3: Not Using Nesterov with SGD + +```python +# ❌ SUBOPTIMAL +optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9) + +# ✅ BETTER +optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, nesterov=True) +``` + +**Why Nesterov matters:** +- Usually converges faster +- Often reaches better final solution +- Minimal cost (nearly free improvement) +- Standard in modern vision training + +**Red flag**: Using SGD with momentum but nesterov=False (or not specified). + + +### Pitfall 4: Comparing Optimizers Without Proper Tuning + +```python +# ❌ WRONG: Unfair comparison +sgd = torch.optim.SGD(params, lr=0.001) # LR too low for SGD +adam = torch.optim.Adam(params, lr=1e-3) # LR appropriate +# Train both, Adam wins → "SGD doesn't work" + +# ✅ RIGHT: Fair comparison with tuned LRs +sgd = torch.optim.SGD(params, lr=0.1) # Tuned for SGD +adam = torch.optim.Adam(params, lr=1e-3) # Tuned for Adam +# Now both have fair chance +``` + +**Why this is critical:** +- "Optimizer X doesn't work" often means "LR not tuned" +- Each optimizer needs separate hyperparameter tuning +- Use learning rate finder for each optimizer independently + +**Red flag**: Concluding one optimizer is better without tuning LR for each. + + +### Pitfall 5: Forgetting Bias Correction in Custom Adam Implementation + +**If using PyTorch built-in Adam/AdamW**: Don't worry, bias correction is automatic. + +**If implementing custom Adam**: Remember bias correction for first few steps! + +```python +# Bias correction needed because moving averages start at 0 +m_hat = m / (1 - beta1**t) # Bias-corrected first moment +v_hat = v / (1 - beta2**t) # Bias-corrected second moment +``` + +**Red flag**: Custom Adam implementation without bias correction. + + +### Pitfall 6: One-Size-Fits-All Optimizer Choice + +```python +# ❌ WRONG MINDSET: "I always use AdamW for everything" + +# ✅ RIGHT MINDSET: "I choose optimizer based on task" +if task == "vision_cnn": + optimizer = SGD_with_nesterov +elif task == "transformer": + optimizer = AdamW +elif task == "RL": + optimizer = Adam +# Decision based on context +``` + +**Why this matters:** +- No single "best" optimizer +- Task-specific performance varies significantly +- SGD often better for vision, AdamW for transformers + +**Red flag**: Always recommending same optimizer regardless of task. + + +### Pitfall 7: Not Adjusting Optimizer for Distributed Training + +```python +# ❌ WRONG: Same setup for 1 GPU and 64 GPUs +# 1 GPU: batch_size=32, lr=1e-3 +# 64 GPUs: batch_size=2048 (32*64), lr=1e-3 # LR too low! + +# ✅ RIGHT: Scale learning rate with batch size +# 1 GPU: batch_size=32, lr=1e-3 +# 64 GPUs: batch_size=2048, lr=1e-3 * (2048/32) = 0.064 +# Linear scaling rule (with warmup) +``` + +**Why this matters:** +- Larger batch size → larger effective learning rate needed +- Linear scaling rule: lr_new = lr_base * (batch_new / batch_base) +- Warmup crucial when scaling LR +- Very large batches (>8K) may need LAMB/LARS + +**Red flag**: Not adjusting LR when scaling to many GPUs. + + +### Pitfall 8: Ignoring Optimizer State When Fine-tuning + +```python +# ❌ POTENTIAL ISSUE: Loading checkpoint but not optimizer state +checkpoint = torch.load('model.pt') +model.load_state_dict(checkpoint['model']) +# Optimizer state not loaded → starts with fresh momentum buffers + +# ✅ BETTER: Load optimizer state too (if continuing training) +checkpoint = torch.load('model.pt') +model.load_state_dict(checkpoint['model']) +optimizer.load_state_dict(checkpoint['optimizer']) +# Momentum buffers preserved + +# ✅ ALSO VALID: Fresh optimizer for fine-tuning (different task) +# When fine-tuning on new task, fresh optimizer often better +``` + +**When to load optimizer state:** +- Resuming interrupted training → YES, load it +- Fine-tuning on same task → YES, load it +- Fine-tuning on different task → NO, fresh optimizer better + + +### Pitfall 9: Using Weight Decay on Bias Terms + +**Modern best practice**: Often exclude bias and normalization parameters from weight decay. + +```python +# ✅ BETTER: Separate parameter groups +params_with_decay = [] +params_without_decay = [] + +for name, param in model.named_parameters(): + if 'bias' in name or 'bn' in name or 'norm' in name: + params_without_decay.append(param) + else: + params_with_decay.append(param) + +optimizer = torch.optim.AdamW([ + {'params': params_with_decay, 'weight_decay': 0.01}, + {'params': params_without_decay, 'weight_decay': 0.0} +], lr=1e-3) +``` + +**Why this matters:** +- Bias terms often don't benefit from weight decay +- Normalization parameters (BN, LayerNorm) shouldn't be decayed +- Common in modern transformer training + +**Note**: Not always critical, but modern best practice for transformers. + + +### Pitfall 10: Not Monitoring Gradient Norms with Different Optimizers + +```python +# Monitor gradient norms during training +for param in model.parameters(): + if param.grad is not None: + grad_norm = param.grad.norm().item() + # Log this to tensorboard/wandb +``` + +**Why this helps:** +- Detect gradient explosion (norm >> 1.0) +- Detect vanishing gradients (norm << 0.01) +- Different optimizers have different gradient scale sensitivities +- SGD more sensitive to gradient scale than Adam + +**Red flag**: Training issues without checking gradient norms. + + +## Debugging Optimizer Issues + +### Issue 0: Multiple Simultaneous Problems (Prioritization) + +**Symptoms:** +- User reports many issues at once +- Multiple potential causes +- Unclear what to fix first + +**CRITICAL: Prioritize fixes** + +When multiple issues present, fix in this order: + +1. **Learning Rate** (FIRST, most common issue) + - Check if LR is in correct range for optimizer + - SGD: 0.01-0.1, Adam: 1e-4 to 3e-3 + - Wrong LR makes everything else irrelevant + +2. **Numerical Stability** (if NaN/Inf present) + - Gradient explosion + - Mixed precision issues + - Division by zero in loss + +3. **Batch Size** (if very small or very large) + - batch < 8: Very noisy, affects stability + - batch > 8K: Needs special handling (LR scaling, warmup) + +4. **Gradient Issues** (if mentioned or suspected) + - Gradient clipping + - Gradient accumulation + +5. **Optimizer Choice** (LAST) + - Only change optimizer after fixing above + - Often optimizer isn't the problem + +**Example:** + +``` +User: "Not working. Using Adam lr=0.1, batch=8, mixed precision, loss oscillates" +``` + +**Wrong response:** "Switch to SGD" + +**Right response:** +1. Fix LR (lr=0.1 is 100x too high for Adam) → lr=1e-3 +2. Try FP32 to isolate mixed precision issue +3. Consider gradient accumulation (batch=8 is small) +4. THEN evaluate if optimizer needs changing (probably not) + +**Principle**: Fix root causes systematically. Don't change optimizer to "fix" bad hyperparameters. + + +### Issue 1: Training Unstable (Loss Spikes, NaN Values) + +**Symptoms:** +- Loss occasionally spikes +- NaN or Inf values +- Loss oscillates wildly + +**Debugging checklist:** + +1. **Check learning rate (MOST COMMON)**: + ```python + # LR too high → instability + # Try reducing by 3-10x + lr = lr / 3 + ``` + +2. **Try gradient clipping**: + ```python + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + # See gradient-management skill + ``` + +3. **Switch optimizer**: + ```python + # Adam sometimes more stable than SGD + # Or try AdamW with lower LR + ``` + +4. **Add warmup**: + ```python + # Gradual LR increase at start + # Especially important for Adam/AdamW + ``` + +5. **Check for numerical issues**: + - Division by zero in loss function + - log(0) in loss computation + - Mixed precision issues (try FP32) + +6. **Lower beta2 (Adam/AdamW)**: + ```python + # From (0.9, 0.999) to (0.9, 0.98) + # More responsive, sometimes more stable + ``` + + +### Issue 2: Training Too Slow (Loss Decreasing Very Slowly) + +**Symptoms:** +- Loss decreasing but very slowly +- Will take forever to converge +- Not hitting good accuracy + +**Debugging checklist:** + +1. **Increase learning rate**: + ```python + # Try 2-3x higher LR + lr = lr * 3 + # Monitor for instability + ``` + +2. **Check if LR is in the right range**: + ```python + # SGD: should be 0.01-0.1 + # Adam: should be 1e-4 to 3e-3 + ``` + +3. **Try different optimizer**: + ```python + # SGD slow initially → try Adam for faster early progress + # Then can switch back to SGD later + ``` + +4. **Use learning rate finder**: + ```python + # Find optimal LR empirically + # Plot loss vs LR, choose before minimum + ``` + +5. **Check batch size**: + ```python + # Very small batch → noisy gradients, slower + # Increase batch size if possible + ``` + +6. **Verify model is learning at all**: + ```python + # Overfit single batch first + # If can't overfit, problem is model/data, not optimizer + ``` + + +### Issue 3: Switching Optimizer Breaks Training + +**Scenario**: Training works with optimizer A, but fails with optimizer B. + +**Debugging checklist:** + +1. **Re-tune learning rate (CRITICAL)**: + ```python + # SGD → Adam: reduce LR by 10-100x + # Adam → SGD: increase LR by 10-100x + ``` + +2. **Check hyperparameter ranges**: + ```python + # Adam: lr=1e-3, betas=(0.9, 0.999) + # SGD: lr=0.1, momentum=0.9 + # These are DIFFERENT + ``` + +3. **Give optimizer time**: + ```python + # Adam converges fast initially + # SGD slower initially but better final performance + # Don't judge in first 5 epochs + ``` + +4. **Use appropriate scheduler**: + ```python + # SGD often needs aggressive decay (cosine) + # Adam often needs warmup + ``` + +5. **Consider task characteristics**: + ```python + # Vision CNNs → SGD often better + # Transformers → AdamW standard + # Small batch → Adam often better + # Large batch → SGD works well + ``` + + +### Issue 4: Overfitting (Train/Val Gap) + +**Symptoms:** +- Training accuracy high, validation low +- Model memorizing training data + +**Optimizer-related solutions:** + +1. **Increase weight decay**: + ```python + # Try 3-10x higher weight decay + weight_decay = weight_decay * 3 + # Use AdamW, not Adam! + ``` + +2. **Try SGD instead of Adam**: + ```python + # SGD often generalizes better + # Flatter minima → better test performance + ``` + +3. **Lower learning rate toward end**: + ```python + # Use cosine schedule or reduce on plateau + # Helps find better minimum + ``` + +**Note**: Overfitting is multi-faceted. Also see overfitting-prevention and data-augmentation-strategies skills. + + +### Issue 5: "Best Optimizer" Not Working + +**Scenario**: "Paper said use AdamW but my results are worse than SGD" + +**Debugging checklist:** + +1. **Check task/model match**: + ```python + # AdamW best for transformers + # SGD often better for CNNs + # Paper's task might differ from yours + ``` + +2. **Tune hyperparameters**: + ```python + # Don't just copy LR from paper + # Tune for YOUR specific setup + ``` + +3. **Compare fairly**: + ```python + # Give both optimizers proper LR tuning + # Same number of epochs might not be fair + # (Adam converges faster initially) + ``` + +4. **Check batch size**: + ```python + # Paper used large batch → SGD works well + # You use small batch → Adam might be better + ``` + +5. **Consider training budget**: + ```python + # Limited epochs → Adam (fast convergence) + # Many epochs → SGD (better final performance) + ``` + + +## Rationalization Resistance + +This table lists common rationalizations agents make when bypassing systematic optimizer selection, along with the correct responses. + +| Rationalization | Why It's Wrong | Correct Response | +|----------------|----------------|------------------| +| "Adam is the modern standard, use it" | Adam superseded by AdamW; SGD still best for vision | Task-dependent: AdamW for transformers, SGD for CNNs. No universal best. | +| "AdamW and Adam are basically the same" | Weight decay implementation is fundamentally different | AdamW has decoupled weight decay (correct), Adam's is broken. Always use AdamW for weight decay. | +| "Just use default hyperparameters" | Defaults need tuning for specific problems | Defaults are starting points. Tune LR at minimum. Different tasks need different values. | +| "User requested Adam, so use Adam" | User may not know about AdamW advantage | If they need weight decay, recommend AdamW. Explain the critical difference. | +| "SGD is old-fashioned, Adam is better" | SGD still achieves best results for many vision tasks | SGD often outperforms Adam on vision. Modern CNNs use SGD. It's not outdated. | +| "Optimizer doesn't matter much" | Optimizer choice significantly affects results | Optimizer matters. SGD vs Adam can mean 2-5% accuracy difference on vision tasks. | +| "Same LR works for different optimizers" | Different optimizers need different LR ranges | SGD needs 10-100x higher LR than Adam. Must re-tune when switching. | +| "This worked in a paper, so it's optimal" | Papers don't always use best settings; context differs | Papers have different constraints. Use paper as starting point, tune for your case. | +| "Adam is easier to tune, recommend it" | Both need tuning; easy ≠ best | Adam more forgiving initially, but SGD often better final performance. Don't sacrifice quality for ease. | +| "User's training failed with SGD, so SGD doesn't work" | Likely LR or other hyperparameter issue | Debug: probably LR too low/high, not optimizer fault. Try LR finder. | +| "Let's try all optimizers and see" | Unfair comparison without tuning each | Each optimizer needs proper LR tuning. Comparing with same LR is meaningless. | +| "Weight decay is weight decay" | Adam and AdamW implement it differently | Adam: L2 penalty (broken). AdamW: decoupled weight decay (correct). Fundamental difference. | +| "Nesterov doesn't matter" | Nesterov usually improves convergence | Use nesterov=True with SGD. Nearly free improvement. Standard practice. | +| "Just use what's popular" | Popular ≠ optimal for your task | Transformers use AdamW (popular there). CNNs use SGD (popular there). Context matters. | +| "Optimizer isn't the problem" | Optimizer might not be the problem, but often is | Check LR first (most common). But wrong optimizer choice does matter. | +| "User said fast training, so Adam" | Fast depends on many factors | Adam faster initial convergence, but might need more epochs for same quality as SGD. Define "fast". | +| "BERT uses AdamW, so always use it" | BERT is a transformer; not all models are transformers | AdamW is best for transformers. CNNs often do better with SGD. Task-dependent. | +| "Copy hyperparameters from successful project" | Different tasks/models need different hyperparameters | Use as starting point, but tune for your specific case. Context differs. | +| "Learning rate is more important than optimizer" | Both are important | LR is often more important, but optimizer choice still matters significantly. | +| "Users don't care about Adam vs AdamW" | Users care about results; correct optimizer gives better results | AdamW gives better results when weight decay needed. Technical correctness matters. | +| "User explicitly requested X, so use X" | User request doesn't override technical correctness | Acknowledge request, explain technical cost, offer better solution. Help user make informed choice. | +| "Time pressure, just give quick answer" | Fast doesn't mean incomplete | Be concise but technically correct. Fast + wrong wastes more time than brief + right. | +| "Popular framework does Y, so Y is best" | Frameworks optimize for different goals | Explain framework design tradeoffs. Different priorities (ease vs performance). | +| "Paper did Z, so Z is optimal" | Papers have errors and different constraints | Critical evaluation. Papers don't always use best settings. Context may differ. | +| "It's working so don't change it" | Sometimes true, but need to evaluate | Ask about current performance. If working well, maybe don't change. If issues, investigate. | +| "Too complicated, simplify it" | Complexity reflects real tradeoffs | Can't simplify away fundamental differences (Adam vs AdamW). Explain clearly but don't oversimplify. | + + +## Red Flags Checklist + +Watch for these red flags indicating incorrect optimizer usage: + +### Critical Red Flags (Fix Immediately) + +- ❌ **Using `torch.optim.Adam` with `weight_decay > 0`** + - Fix: Use `torch.optim.AdamW` instead + +- ❌ **Same LR for SGD and Adam (e.g., both using 0.1 or both using 1e-3)** + - Fix: SGD needs 10-100x higher LR than Adam + +- ❌ **Using Adam for transformers instead of AdamW** + - Fix: Modern transformers always use AdamW + +- ❌ **Not using Nesterov momentum with SGD** + - Fix: Add `nesterov=True` + +### Major Red Flags (High Priority) + +- ⚠️ **Recommending one optimizer for all tasks without analysis** + - Fix: Use decision framework based on task/model + +- ⚠️ **Saying "Adam and AdamW are the same"** + - Fix: Explain decoupled weight decay difference + +- ⚠️ **Claiming "optimizer doesn't work" without LR tuning** + - Fix: Tune LR first, then evaluate optimizer + +- ⚠️ **Comparing optimizers without tuning LR for each** + - Fix: Fair comparison requires tuning each optimizer + +### Medium Red Flags (Important) + +- ⚠️ **Not asking about task/model before recommending optimizer** + - Fix: Ask clarifying questions (vision? NLP? batch size?) + +- ⚠️ **Using default hyperparameters without considering task** + - Fix: Provide task-specific hyperparameter guidance + +- ⚠️ **Not mentioning learning rate range when suggesting optimizer** + - Fix: Always specify LR range with optimizer choice + +- ⚠️ **Ignoring batch size when recommending optimizer** + - Fix: Small batch → Adam, large batch → SGD often better + +### Lower Priority Red Flags + +- ⚠️ **Not mentioning weight decay best practices** + - Fix: Explain weight decay ranges by task + +- ⚠️ **Cargo-culting beta values without understanding** + - Fix: Explain what betas do and when to adjust + +- ⚠️ **Not considering convergence speed vs final performance tradeoff** + - Fix: Discuss Adam (fast) vs SGD (better final) tradeoff + + +## Cross-Skill Boundaries + +### When to Route to Other Skills + +**learning-rate-scheduling**: +- User asks about LR value or schedule +- Optimizer chosen but need LR strategy +- Training not working, check LR before changing optimizer +- Use: "See learning-rate-scheduling for LR finder and scheduling strategies" + +**gradient-management**: +- Training unstable with NaN/Inf (try gradient clipping before changing optimizer) +- Gradient explosion issues +- Very deep networks with gradient issues +- Use: "See gradient-management for gradient clipping and explosion handling" + +**hyperparameter-tuning**: +- Need to systematically search optimizer hyperparameters +- Comparing multiple optimizer configurations +- Using AutoML or hyperparameter search +- Use: "See hyperparameter-tuning for systematic search strategies" + +**overfitting-prevention**: +- Overfitting despite weight decay +- Need regularization beyond weight decay +- Use: "See overfitting-prevention for dropout, early stopping, and other techniques" + +**batch-size-and-memory-tradeoffs**: +- Asking about batch size effects on optimizer +- Need to scale batch size for distributed training +- Memory constraints limiting batch size (gradient accumulation) +- Use: "See batch-size-and-memory-tradeoffs for batch size selection and scaling" + +**pytorch-engineering (distributed-training-strategies)**: +- Distributed training setup (DDP, FSDP) +- Very large batch training (>8K) needing LAMB/LARS +- Multi-GPU LR scaling +- Use: "See pytorch-engineering for distributed training implementation" + +### Multi-Skill Workflows + +**Common workflow: New training pipeline** +1. **optimization-algorithms** → Choose optimizer (SGD vs Adam vs AdamW) +2. **learning-rate-scheduling** → Choose LR and schedule +3. **batch-size-and-memory-tradeoffs** → Choose batch size +4. **experiment-tracking** → Set up tracking + +**Common workflow: Training not working** +1. **using-training-optimization** → Diagnose symptom +2. **learning-rate-scheduling** → Check LR first (most common issue) +3. **optimization-algorithms** → Consider optimizer change if LR tuned +4. **gradient-management** → If NaN/instability issues + +**Common workflow: Overfitting** +1. **overfitting-prevention** → Primary techniques (dropout, early stopping) +2. **optimization-algorithms** → Increase weight decay (AdamW) +3. **data-augmentation-strategies** → Increase data variety +4. **hyperparameter-tuning** → Find optimal regularization strength + + +## Advanced Topics + +### Large Batch Training + +**Challenges:** +- Generalization gap (large batch trains well, but generalizes worse) +- Need higher learning rates (linear scaling rule) +- Warmup becomes critical + +**Solutions:** + +1. **Linear LR scaling**: + ```python + # Base: batch=256, lr=0.1 + # Scaled: batch=2048, lr=0.1 * (2048/256) = 0.8 + ``` + +2. **Warmup**: + ```python + # Gradually increase LR over first N steps + # Critical for large batch training + ``` + +3. **Specialized optimizers (batch > 8K)**: + ```python + # LAMB for transformers + # LARS for vision + ``` + +4. **Gradient accumulation (alternative)**: + ```python + # Simulate large batch with limited memory + # Accumulate gradients over N steps + ``` + +**See**: batch-size-and-memory-tradeoffs for detailed guidance. + + +### Optimizer Switching During Training + +**Scenario**: Start with Adam for fast convergence, switch to SGD for better final performance. + +**How to do it:** + +```python +# Train first 50 epochs with Adam +optimizer_adam = torch.optim.Adam(model.parameters(), lr=1e-3) +# ... train ... + +# Switch to SGD for final 50 epochs +optimizer_sgd = torch.optim.SGD( + model.parameters(), + lr=0.01, # DIFFERENT LR RANGE + momentum=0.9, + nesterov=True +) +# ... continue training ... +``` + +**Considerations:** +- Fresh optimizer state (no momentum from Adam) +- Different LR range required +- May need new warmup period +- Not always better (try it empirically) + +**When this helps:** +- Time-constrained training (fast start with Adam) +- Want best final performance (finish with SGD) +- Large vision models + + +### Per-Layer Learning Rates + +**Scenario**: Different learning rates for different parts of the model. + +```python +optimizer = torch.optim.AdamW([ + {'params': model.backbone.parameters(), 'lr': 1e-5}, # Pre-trained, small LR + {'params': model.head.parameters(), 'lr': 1e-3} # New head, larger LR +]) +``` + +**When this helps:** +- Fine-tuning pre-trained models +- Transfer learning +- Some layers need different learning rates + +**Caution**: Adds complexity. Try uniform LR first. + + +### Learning Rate Warmup (Critical for Adam/AdamW) + +**Why warmup:** +- Adam/AdamW momentum buffers start at zero (biased) +- Large LR at start can be unstable +- Gradual increase to target LR improves stability + +**Implementation:** + +```python +from transformers import get_linear_schedule_with_warmup + +optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4) + +scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=10000, # 10% of total steps typical + num_training_steps=100000 +) + +# Call scheduler.step() after each batch +``` + +**Warmup guidelines:** +- Transformers: Always use warmup (10-20% of steps) +- Vision CNNs with SGD: Sometimes helpful +- Adam/AdamW: Highly recommended +- Large batch: Critical + +**See**: learning-rate-scheduling for warmup implementation details. + + +### Working with Constraints + +Real projects have constraints: legacy codebases, infrastructure limits, time pressure. Good optimizer advice acknowledges constraints but doesn't sacrifice technical correctness. + +**Principle**: Help users make informed decisions, don't just comply with suboptimal requests. + +#### Scenario: User insists on suboptimal choice + +``` +User: "I want to use Adam with weight_decay, can't change to AdamW" +``` + +**Response pattern:** +1. ✅ **Acknowledge**: "I understand you have constraints" +2. ✅ **Explain cost**: "Adam's weight_decay is broken, here's the technical cost..." +3. ✅ **Show easy fix**: "Migration is one line: Adam → AdamW" +4. ✅ **Offer help**: "What specific constraint prevents this? I can help migrate" +5. ⚠️ **Only if impossible**: Provide workaround with clear caveats + +**Never say**: "Sure, Adam is fine" when it's not technically correct. + +**Example response:** +``` +I understand you have constraints, but I need to be clear about the technical +cost of using Adam with weight_decay: + +Adam's weight_decay is broken - it adds L2 penalty to loss, which interacts +incorrectly with adaptive learning rates. You won't get the regularization +benefits you expect. + +The fix is actually trivial (one line): +# From: +optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.01) +# To: +optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01) + +Same API, same hyperparameters, just change the class name. + +If you MUST use Adam: +- Set weight_decay=0 in optimizer +- Add L2 regularization manually in loss if needed +- Understand this still won't work as well as AdamW + +What specific constraint prevents the change? I can help migrate. +``` + +#### Scenario: Time pressure + +``` +User: "URGENT, just tell me what optimizer to use!" +``` + +**Response pattern:** +1. ✅ **Be concise** (respect time pressure) +2. ✅ **Still ask minimum essential questions** (task/model) +3. ✅ **Provide fast fix first, backup option second** +4. ✅ **Brief but technically correct** + +**Never**: Sacrifice correctness for brevity. Fast + wrong wastes more time. + +**Example response:** +``` +Quick answer needs context - what model? (CNN, transformer, RNN) + +IF CNN → SGD lr=0.1, momentum=0.9, nesterov=True +IF transformer → AdamW lr=5e-4, weight_decay=0.01 +IF RNN/baseline → Adam lr=1e-3 + +If current optimizer "not working" - check learning rate first: +- SGD needs lr=0.01-0.1 +- Adam needs lr=1e-4 to 3e-3 (10-100x lower!) + +Wrong LR is more common than wrong optimizer. +``` + +#### Scenario: Popular framework uses different approach + +``` +User: "But timm/transformers/fastai does X differently" +``` + +**Response pattern:** +1. ✅ **Don't dismiss framework**: "Framework X isn't wrong, but context matters" +2. ✅ **Explain framework design decisions**: "They optimize for ease of use" +3. ✅ **Distinguish different contexts**: "For CNNs... For ViTs..." +4. ✅ **Help user choose**: "For your case, I recommend..." + +**Never**: Blindly defer to framework or blindly contradict it. + +**Example response:** +``` +timm isn't wrong - but it depends on what you're training. + +For Vision CNNs (ResNet, EfficientNet): +- Best performance: SGD (~1-2% better accuracy) +- timm default: AdamW (easier, "good enough") +- timm optimizes for ease of use across all models + +For Vision Transformers (ViT, Swin): +- AdamW is correct (timm is right here) + +timm's design choice: +- AdamW works "pretty well" for everything +- Reasonable tradeoff (ease vs optimal performance) + +Your choice: +- Want best CNN accuracy → SGD +- Want easy baseline → timm's AdamW +- Training ViT → Keep AdamW +``` + +#### Key Principles for Constraints + +**Always**: +- Acknowledge the constraint (show you heard them) +- Explain technical cost clearly +- Provide easiest migration path +- Only give workaround as last resort + +**Never**: +- Say "it's fine" when it's not +- Skip explanation due to time pressure +- Blindly comply with suboptimal choice +- Sacrifice correctness for convenience + +**Remember**: Your job is to help users make informed decisions, not just to comply with their requests. + + +## Summary + +### Key Takeaways + +1. **No universal best optimizer**: SGD for vision, AdamW for transformers, Adam for quick baselines +2. **Adam vs AdamW is critical**: Always use AdamW for weight decay (Adam's implementation is broken) +3. **LR ranges differ**: SGD needs 10-100x higher LR than Adam/AdamW +4. **Nesterov momentum**: Always use with SGD (nesterov=True) +5. **Decision framework**: Ask about task, model, batch size before recommending +6. **Tune hyperparameters**: Defaults are starting points, tune for your case +7. **Fair comparisons**: Tune LR separately for each optimizer + +### Quick Reference + +**Vision CNNs**: SGD with Nesterov, lr=0.1, momentum=0.9, weight_decay=1e-4 +**Transformers**: AdamW, lr=5e-4, betas=(0.9, 0.98), weight_decay=0.01 +**Vision Transformers**: AdamW, lr=1e-3, weight_decay=0.05 +**RNNs**: Adam, lr=1e-3 +**RL**: Adam, lr=3e-4 +**Quick baseline**: Adam or AdamW, lr=1e-3 + +### Remember + +- Use decision framework, not defaults +- AdamW > Adam for weight decay +- Tune LR for each optimizer +- Task-specific optimizer selection +- Check red flags checklist + + +## Additional Resources + +**Key Papers:** +- "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, ICLR 2019) - AdamW +- "Adam: A Method for Stochastic Optimization" (Kingma & Ba, ICLR 2015) - Adam +- "On the Convergence of Adam and Beyond" (Reddi et al., ICLR 2018) - Adam issues + +**Related Skills:** +- learning-rate-scheduling: LR values and schedules +- gradient-management: Gradient clipping and stability +- hyperparameter-tuning: Systematic hyperparameter search +- batch-size-and-memory-tradeoffs: Batch size selection +- overfitting-prevention: Regularization techniques + +**Cross-pack:** +- pytorch-engineering: Distributed training, performance profiling +- neural-architectures: Model selection and architecture + + +*End of optimization-algorithms skill* diff --git a/skills/using-training-optimization/overfitting-prevention.md b/skills/using-training-optimization/overfitting-prevention.md new file mode 100644 index 0000000..9c5237c --- /dev/null +++ b/skills/using-training-optimization/overfitting-prevention.md @@ -0,0 +1,1464 @@ + +# Overfitting Prevention + +## Overview + +Overfitting is the most common training failure: your model memorizes training data instead of learning generalizable patterns. It shows as **high training accuracy paired with low validation accuracy**. This skill teaches you how to detect overfitting early, diagnose its root cause, and fix it using the right combination of techniques. + +**Core Principle**: Overfitting has multiple causes (high capacity, few examples, long training, high learning rate) and no single-technique fix. You must measure, diagnose, then apply the right combination of solutions. + +**CRITICAL**: Do not fight overfitting blindly. Measure train/val gap first. Different gaps have different fixes. + +## When to Use This Skill + +Load this skill when: +- Training loss decreasing but validation loss increasing (classic overfitting) +- Train accuracy 95% but validation accuracy 75% (26% gap = serious overfitting) +- Model performs well on training data but fails on unseen data +- You want to prevent overfitting before it happens (architecture selection) +- Selecting regularization technique (dropout vs L2 vs early stopping) +- Combining multiple regularization techniques +- Unsure if overfitting or underfitting +- Debugging training that doesn't generalize + +**Don't use for**: Learning rate scheduling (use learning-rate-scheduling), data augmentation policy (use data-augmentation-strategies), optimizer selection (use optimization-algorithms), gradient clipping (use gradient-management) + + +## Part 1: Overfitting Detection Framework + +### The Core Question: "Is My Model Overfitting?" + +**CRITICAL FIRST STEP**: Always monitor BOTH training and validation accuracy. One metric alone is useless. + +### Clarifying Questions to Ask + +Before diagnosing overfitting, ask: + +1. **"What's your train accuracy and validation accuracy?"** + - Train 95%, Val 95% → No overfitting (good!) + - Train 95%, Val 85% → Mild overfitting (10% gap, manageable) + - Train 95%, Val 75% → Moderate overfitting (20% gap, needs attention) + - Train 95%, Val 55% → Severe overfitting (40% gap, critical) + +2. **"What does the learning curve show?"** + - Both train and val loss decreasing together → Good generalization + - Train loss decreasing, val loss increasing → Overfitting (classic sign) + - Both loss curves plateaued → Check if at best point + - Train loss drops but val loss flat → Model not learning useful patterns + +3. **"How much training data do you have?"** + - < 1,000 examples → Very prone to overfitting + - 1,000-10,000 examples → Prone to overfitting + - 10,000-100,000 examples → Moderate risk + - > 100,000 examples → Lower risk (but still possible) + +4. **"How many parameters does your model have?"** + - Model parameters >> training examples → Almost guaranteed overfitting + - Model parameters = training examples → Possible overfitting + - Model parameters < training examples (e.g., 10x smaller) → Less likely to overfit + +5. **"How long have you been training?"** + - 5 epochs on 100K data → Probably underfitting + - 50 epochs on 100K data → Likely good + - 500 epochs on 100K data → Probably overfitting by now + +### Overfitting Diagnosis Tree + +``` +START: Checking for overfitting + +├─ Are you monitoring BOTH training AND validation accuracy? +│ ├─ NO → STOP. Set up validation monitoring first. +│ │ You cannot diagnose without this metric. +│ │ +│ └─ YES → Continue... +│ +├─ What's the train vs validation accuracy gap? +│ ├─ Gap < 3% (train 95%, val 94%) → No overfitting, model is generalizing +│ ├─ Gap 3-10% (train 95%, val 87%) → Mild overfitting, can accept or prevent +│ ├─ Gap 10-20% (train 95%, val 80%) → Moderate overfitting, needs prevention +│ ├─ Gap > 20% (train 95%, val 70%) → Severe overfitting, immediate action needed +│ │ +│ └─ Continue... +│ +├─ Is validation accuracy still increasing or has it plateaued? +│ ├─ Still increasing with train → Good, no overfitting signal yet +│ ├─ Validation plateaued, train increasing → Overfitting starting +│ ├─ Validation decreasing while train increasing → Overfitting in progress +│ │ +│ └─ Continue... +│ +├─ How does your train/val gap change over epochs? +│ ├─ Gap constant or decreasing → Improving generalization +│ ├─ Gap increasing → Overfitting worsening (stop training soon) +│ ├─ Gap increasing exponentially → Severe overfitting +│ │ +│ └─ Continue... +│ +└─ Based on gap size: [from above] + ├─ Gap < 3% → **No action needed**, monitor for worsening + ├─ Gap 3-10% → **Mild**: Consider data augmentation or light regularization + ├─ Gap 10-20% → **Moderate**: Apply regularization + early stopping + └─ Gap > 20% → **Severe**: Model capacity reduction + strong regularization + early stopping +``` + +### Red Flags: Overfitting is Happening NOW + +Watch for these signs: + +1. **"Training loss smooth and decreasing, validation loss suddenly jumping"** → Overfitting spike +2. **"Model was working, then started failing on validation"** → Overfitting starting +3. **"Small improvement in train accuracy, large drop in validation"** → Overfitting increasing +4. **"Model performs 95% on training, 50% on test"** → Severe overfitting already happened +5. **"Tiny model (< 1M params) on tiny dataset (< 10K examples), 500+ epochs"** → Almost certainly overfitting +6. **"Train/val gap widening in recent epochs"** → Overfitting trend is negative +7. **"Validation accuracy peaked 50 epochs ago, still training"** → Training past the good point +8. **"User hasn't checked validation accuracy in 10 epochs"** → Blind to overfitting + + +## Part 2: Regularization Techniques Deep Dive + +### Technique 1: Early Stopping (Stop Training at Right Time) + +**What it does**: Stops training when validation accuracy stops improving. Prevents training past the optimal point. + +**When to use**: +- ✅ When validation loss starts increasing (classic overfitting signal) +- ✅ As first line of defense (cheap, always helpful) +- ✅ When you have validation set +- ✅ For all training tasks (vision, NLP, RL) + +**When to skip**: +- ❌ If no validation set (can't measure) +- ❌ If validation is noisier than loss (use loss-based early stopping instead) + +**Implementation (PyTorch)**: +```python +class EarlyStoppingCallback: + def __init__(self, patience=10, min_delta=0): + """ + patience: Stop if validation accuracy doesn't improve for N epochs + min_delta: Minimum change to count as improvement + """ + self.patience = patience + self.min_delta = min_delta + self.best_val_acc = -float('inf') + self.patience_counter = 0 + self.should_stop = False + + def __call__(self, val_acc): + if val_acc - self.best_val_acc > self.min_delta: + self.best_val_acc = val_acc + self.patience_counter = 0 + else: + self.patience_counter += 1 + if self.patience_counter >= self.patience: + self.should_stop = True + +# Usage: +early_stop = EarlyStoppingCallback(patience=10) + +for epoch in range(500): + train_acc = train_one_epoch() + val_acc = validate() + early_stop(val_acc) + + if early_stop.should_stop: + print(f"Early stopping at epoch {epoch}, best val_acc {early_stop.best_val_acc}") + break +``` + +**Key Parameters**: +- **Patience**: How many epochs without improvement before stopping + - patience=5: Very aggressive, stops quickly + - patience=10: Moderate, standard choice + - patience=20: Tolerant, waits longer + - patience=100+: Not really early stopping anymore +- **min_delta**: Minimum improvement to count (0.0001 = 0.01% improvement) + +**Typical Improvements**: +- Prevents training 50+ epochs past the good point +- 5-10% accuracy improvement by using best checkpoint instead of last +- Saves 30-50% compute (train to epoch 100 instead of 200) + +**Anti-Pattern**: patience=200, 300 epochs - this defeats the purpose! + + +### Technique 2: L2 Regularization / Weight Decay (Penalize Large Weights) + +**What it does**: Adds penalty to loss function based on weight magnitude. Larger weights → larger penalty. Keeps weights small and prevents them from overfitting to training data. + +**When to use**: +- ✅ When model is overparameterized (more params than examples) +- ✅ For most optimization algorithms (Adam, SGD, AdamW) +- ✅ When training time is limited (can't use more data) +- ✅ With any network architecture + +**When to skip**: +- ❌ When model is already underfitting +- ❌ With momentum-based optimizers using L2 incorrectly (use AdamW, not Adam) + +**Implementation**: +```python +# PyTorch with AdamW (recommended) +optimizer = torch.optim.AdamW( + model.parameters(), + lr=1e-4, + weight_decay=0.01 # L2 regularization strength +) + +# Typical training loop (weight decay applied automatically) +for epoch in range(100): + for images, labels in train_loader: + outputs = model(images) + loss = criterion(outputs, labels) # Weight decay already in optimizer + loss.backward() + optimizer.step() + +# How it works internally: +# loss_with_l2 = original_loss + weight_decay * sum(w^2 for w in weights) +``` + +**Key Parameters**: +- **weight_decay** (L2 strength) + - 0.00: No regularization + - 0.0001: Light regularization (small dataset, high risk of overfit) + - 0.001: Standard for large models + - 0.01: Medium regularization (common for transformers) + - 0.1: Strong regularization (small dataset or very large model) + - 1.0: Extreme, probably too much + +**Typical Improvements**: +- Small dataset (1K examples): +2-5% accuracy +- Medium dataset (10K examples): +0.5-2% accuracy +- Large dataset (100K examples): +0.1-0.5% accuracy + +**CRITICAL WARNING**: Do NOT use Adam with weight_decay. Adam's weight decay implementation is broken. Use AdamW instead! + +```python +# WRONG +optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01) + +# CORRECT +optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.01) +``` + + +### Technique 3: Dropout (Random Neuron Silencing) + +**What it does**: During training, randomly drops (silences) neurons with probability p. This prevents co-adaptation of neurons and reduces overfitting. At test time, all neurons are active but outputs are scaled. + +**When to use**: +- ✅ For fully connected layers (MLP heads) +- ✅ When model has many parameters +- ✅ When you want adaptive regularization +- ✅ For RNNs and LSTMs (often essential) + +**When to skip**: +- ❌ In CNNs on large datasets (less effective) +- ❌ Before batch normalization (BN makes dropout redundant) +- ❌ On small models (dropout is regularization, small models don't need it) +- ❌ On very large datasets (overfitting unlikely) + +**Implementation**: +```python +class SimpleDropoutModel(nn.Module): + def __init__(self, dropout_rate=0.5): + super().__init__() + self.fc1 = nn.Linear(784, 512) + self.dropout1 = nn.Dropout(dropout_rate) + self.fc2 = nn.Linear(512, 256) + self.dropout2 = nn.Dropout(dropout_rate) + self.fc3 = nn.Linear(256, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.dropout1(x) # Drop ~50% of neurons + x = F.relu(self.fc2(x)) + x = self.dropout2(x) # Drop ~50% of neurons + x = self.fc3(x) + return x + + # At test time, just call model.eval(): + # model.eval() # Disables dropout, uses all neurons + # predictions = model(test_data) +``` + +**Key Parameters**: +- **dropout_rate** (probability of dropping) + - 0.0: No dropout + - 0.2: Light (10% impact) + - 0.5: Standard (strong regularization) + - 0.7: Heavy (very strong, probably too much for most tasks) + - 0.9: Extreme (only for very specific cases) + +**Where to Apply**: +- After fully connected layers (yes) +- After RNN/LSTM layers (yes, critical) +- After convolutional layers (rarely, less effective) +- Before batch normalization (no, remove dropout) +- On output layer (no, use only hidden layers) + +**Typical Improvements**: +- On MLPs with 10K examples: +3-8% accuracy +- On RNNs: +2-5% accuracy +- On CNNs: +0.5-2% accuracy (less effective) + +**Anti-Pattern**: dropout=0.5 everywhere, in all layer types, on all architectures. This is cargo cult programming. + + +### Technique 4: Batch Normalization (Normalize Activations) + +**What it does**: Normalizes each layer's activations to mean=0, std=1. This stabilizes training and acts as a regularizer (reduces internal covariate shift). + +**When to use**: +- ✅ For deep networks (> 10 layers) +- ✅ For CNNs (standard in modern architectures) +- ✅ When training is unstable +- ✅ For accelerating convergence + +**When to skip**: +- ❌ On tiny models (< 3 layers) +- ❌ When using layer normalization already +- ❌ In RNNs (use layer norm instead) +- ❌ With very small batch sizes (< 8) + +**Implementation**: +```python +class ModelWithBatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(64) # After conv layer + self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(128) # After conv layer + + def forward(self, x): + x = self.bn1(F.relu(self.conv1(x))) # Conv → BN → ReLU + x = self.bn2(F.relu(self.conv2(x))) # Conv → BN → ReLU + return x +``` + +**How it Regularizes**: +- During training: Normalizes based on batch statistics +- At test time: Uses running mean/variance from training +- Effect: Reduces dependency on weight magnitude, allows higher learning rates +- Mild regularization effect (not strong, don't rely on it alone) + +**Typical Improvements**: +- Training stability: Huge (allows 10x higher LR without instability) +- Generalization: +1-3% accuracy (mild regularization) +- Convergence speed: 2-3x faster training + + +### Technique 5: Label Smoothing (Soften Targets) + +**What it does**: Instead of hard targets (0, 1), use soft targets (0.05, 0.95). Model doesn't become overconfident on training data. + +**When to use**: +- ✅ For classification with many classes (100+ classes) +- ✅ When model becomes overconfident (99.9% train acc, 70% val acc) +- ✅ When you want calibrated predictions +- ✅ For knowledge distillation + +**When to skip**: +- ❌ For regression tasks +- ❌ For highly noisy labels (already uncertain) +- ❌ For ranking/metric learning + +**Implementation**: +```python +class LabelSmoothingLoss(nn.Module): + def __init__(self, smoothing=0.1): + super().__init__() + self.smoothing = smoothing + self.confidence = 1.0 - smoothing + + def forward(self, logits, targets): + """ + logits: Model output, shape (batch_size, num_classes) + targets: Target class indices, shape (batch_size,) + """ + log_probs = F.log_softmax(logits, dim=-1) + + # Create smooth labels + # Instead of: [0, 0, 1, 0] for class 2 + # Use: [0.03, 0.03, 0.93, 0.03] for class 2 + with torch.no_grad(): + smooth_targets = torch.full_like(log_probs, self.smoothing / (logits.size(-1) - 1)) + smooth_targets.scatter_(1, targets.unsqueeze(1), self.confidence) + + return torch.mean(torch.sum(-smooth_targets * log_probs, dim=-1)) + +# Usage: +criterion = LabelSmoothingLoss(smoothing=0.1) +loss = criterion(logits, targets) +``` + +**Key Parameters**: +- **smoothing** (how much to smooth) + - 0.0: No smoothing (standard cross-entropy) + - 0.1: Light smoothing (10% probability spread to other classes) + - 0.2: Medium smoothing (20% spread) + - 0.5: Heavy smoothing (50% spread, probably too much) + +**Typical Improvements**: +- Overconfidence reduction: Prevents 99.9% train accuracy +- Generalization: +0.5-1.5% accuracy +- Calibration: Much better confidence estimates + +**Side Effect**: Slightly reduces train accuracy (0.5-1%) but improves generalization. + + +### Technique 6: Data Augmentation (Expand Training Diversity) + +**What it does**: Creates new training examples by transforming existing ones (rotate, crop, flip, add noise). Model sees more diverse data, learns generalizability instead of memorization. + +**When to use**: +- ✅ For small datasets (< 10K examples) - essential +- ✅ For image classification, detection, segmentation +- ✅ For any domain where natural transformations preserve labels +- ✅ When overfitting is due to limited data diversity + +**When to skip**: +- ❌ When you have massive dataset (1M+ examples) +- ❌ For tasks where transformations change meaning (e.g., medical imaging) +- ❌ When augmentation pipeline is not domain-specific + +**Example**: +```python +from torchvision import transforms + +# For CIFAR-10: Small images need conservative augmentation +train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), # 32×32 → random crop + transforms.RandomHorizontalFlip(p=0.5), # 50% chance to flip + transforms.ColorJitter(brightness=0.2, contrast=0.2), # Mild color + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), +]) + +train_loader = DataLoader(train_dataset, transform=train_transform) +``` + +**Typical Improvements**: +- Small dataset (1K examples): +5-10% accuracy +- Medium dataset (10K examples): +2-4% accuracy +- Large dataset (100K examples): +0.5-1% accuracy + +**See data-augmentation-strategies skill for comprehensive augmentation guidance.** + + +### Technique 7: Reduce Model Capacity (Smaller Model = Less Overfitting) + +**What it does**: Use smaller network (fewer layers, fewer neurons) so model has less capacity to memorize. Fundamental solution when model is overparameterized. + +**When to use**: +- ✅ When model has way more parameters than training examples +- ✅ When training data is small (< 1K examples) +- ✅ When regularization alone doesn't fix overfitting +- ✅ For mobile/edge deployment anyway + +**When to skip**: +- ❌ When model is already underfitting +- ❌ When you need high accuracy on large dataset + +**Example**: +```python +# ORIGINAL: Overparameterized for small dataset +class OverkillModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 512) # Too large + self.fc2 = nn.Linear(512, 256) # Too large + self.fc3 = nn.Linear(256, 128) # Too large + self.fc4 = nn.Linear(128, 10) + # Total: ~450K parameters for 1K training examples! + +# REDUCED: Appropriate for small dataset +class AppropriateModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 128) # Smaller + self.fc2 = nn.Linear(128, 64) # Smaller + self.fc3 = nn.Linear(64, 10) + # Total: ~55K parameters (10x reduction) +``` + +**Typical Improvements**: +- Small dataset with huge model: +5-15% accuracy +- Prevents overfitting before it happens +- Faster training and inference + + +### Technique 8: Cross-Validation (Train Multiple Models on Different Data Splits) + +**What it does**: Trains K models, each on different subset of data, then averages predictions. Gives more reliable estimate of generalization error. + +**When to use**: +- ✅ For small datasets (< 10K examples) where single train/val split is noisy +- ✅ When you need reliable performance estimates +- ✅ For hyperparameter selection +- ✅ For ensemble methods + +**When to skip**: +- ❌ For large datasets (single train/val split is sufficient) +- ❌ When compute is limited (K-fold is K times more expensive) + +**Implementation**: +```python +from sklearn.model_selection import StratifiedKFold + +skf = StratifiedKFold(n_splits=5, shuffle=True) +fold_scores = [] + +for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)): + X_train, X_val = X[train_idx], X[val_idx] + y_train, y_val = y[train_idx], y[val_idx] + + model = create_model() + model.fit(X_train, y_train) + score = model.evaluate(X_val, y_val) + fold_scores.append(score) + +mean_score = np.mean(fold_scores) +std_score = np.std(fold_scores) +print(f"Mean: {mean_score:.4f}, Std: {std_score:.4f}") +``` + + +## Part 3: Combining Multiple Techniques + +### The Balancing Act + +Overfitting rarely has single-technique fix. Most effective approach combines 2-4 techniques based on diagnosis. + +**Decision Framework**: + +``` +START: Choosing regularization combination + +├─ What's the PRIMARY cause of overfitting? +│ ├─ Model too large (params >> examples) +│ │ → **Primary fix**: Reduce model capacity +│ │ → **Secondary**: L2 regularization +│ │ → **Tertiary**: Early stopping +│ │ +│ ├─ Dataset too small (< 5K examples) +│ │ → **Primary fix**: Data augmentation +│ │ → **Secondary**: Strong L2 (weight_decay=0.01-0.1) +│ │ → **Tertiary**: Early stopping +│ │ +│ ├─ Training too long (still training past best point) +│ │ → **Primary fix**: Early stopping +│ │ → **Secondary**: Learning rate schedule (decay) +│ │ → **Tertiary**: L2 regularization +│ │ +│ ├─ High learning rate (weights changing too fast) +│ │ → **Primary fix**: Reduce learning rate / learning rate schedule +│ │ → **Secondary**: Early stopping +│ │ → **Tertiary**: Batch normalization +│ │ +│ └─ Overconfident predictions (99% train acc) +│ → **Primary fix**: Label smoothing +│ → **Secondary**: Dropout (for MLPs) +│ → **Tertiary**: L2 regularization + +└─ Then check: + ├─ Measure improvement after each addition + ├─ Don't add conflicting techniques (dropout + batch norm together) + ├─ Tune regularization strength systematically +``` + +### Anti-Patterns: What NOT to Do + +**Anti-Pattern 1: Throwing Everything at the Problem** + +```python +# WRONG: All techniques at max strength simultaneously +model = MyModel(dropout=0.5) # Heavy dropout +batch_norm = True # Maximum regularization +optimizer = AdamW(weight_decay=0.1) # Strong L2 +augmentation = aggressive_augment() # Strong augmentation +early_stop = EarlyStop(patience=5) # Aggressive stopping +label_smooth = 0.5 # Heavy smoothing + +# Result: Model underfits, train accuracy 60%, val accuracy 58% +# You've over-regularized! +``` + +**Anti-Pattern 2: Wrong Combinations** + +```python +# Problematic: Batch norm + Dropout in sequence +class BadModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 512) + self.bn1 = nn.BatchNorm1d(512) + self.dropout1 = nn.Dropout(0.5) # Problem: applies AFTER normalization + # Batch norm already stabilizes, dropout destabilizes + # Interaction: Complex, unpredictable + +# Better: Do either BN or Dropout, not both for same layer +# Even better: BN in early layers, Dropout in late layers +``` + +**Anti-Pattern 3: Over-Tuning on Validation Set** + +```python +# WRONG: Trying so many hyperparameter combinations that you overfit to val set +for lr in [1e-4, 5e-4, 1e-3, 5e-3]: + for weight_decay in [0, 1e-5, 1e-4, 1e-3, 1e-2, 0.1]: + for dropout in [0.0, 0.2, 0.5, 0.7]: + for patience in [5, 10, 15, 20]: + # 4 * 6 * 4 * 4 = 384 combinations! + # Training 384 models on same validation set overfits to validation + +# Better: Random grid search, use held-out test set for final eval +``` + +### Systematic Combination Strategy + +**Step 1: Measure Baseline (No Regularization)** + +```python +# Record: train accuracy, val accuracy, train/val gap +# Epoch 0: train=52%, val=52%, gap=0% +# Epoch 10: train=88%, val=80%, gap=8% +# Epoch 20: train=92%, val=75%, gap=17% ← Overfitting visible +# Epoch 30: train=95%, val=68%, gap=27% ← Severe overfitting +``` + +**Step 2: Add ONE Technique** + +```python +# Add early stopping, measure alone +early_stop = EarlyStoppingCallback(patience=10) +# Train same model with early stopping +# Result: train=92%, val=80%, gap=12% ← 5% improvement + +# Improvement: +5% val accuracy, reduced overfitting +# Cost: None, actually saves compute +# Decision: Keep it, add another if needed +``` + +**Step 3: Add SECOND Technique (Differently Targeted)** + +```python +# Add L2 regularization to target weight magnitude +optimizer = AdamW(weight_decay=0.001) +# Train same model with early stop + L2 +# Result: train=91%, val=82%, gap=9% ← Another 2% improvement + +# Improvement: +2% additional val accuracy +# Cost: Tiny compute overhead +# Decision: Keep it +``` + +**Step 4: Check for Conflicts** + +```python +# If you added both, check that: +# - Val accuracy improved (it did: 75% → 82%) +# - Train accuracy only slightly reduced (92% → 91%, acceptable) +# - Training is still stable (no weird loss spikes) + +# If train accuracy dropped > 3%, you've over-regularized +# If val accuracy didn't improve, technique isn't helping (remove it) +``` + +**Step 5: Optional - Add THIRD Technique** + +```python +# If still overfitting (gap > 10%), add one more technique +# But only if previous two helped and didn't conflict + +# Options at this point: +# - Data augmentation (if dataset small) +# - Dropout (if fully connected layers) +# - Reduce model capacity (fundamental fix) +``` + + +## Part 4: Architecture-Specific Strategies + +### CNNs (Computer Vision) + +**Typical overfitting patterns**: +- Train 98%, Val 75% on CIFAR-10 with small dataset +- Overfitting on small datasets with large pre-trained models + +**Recommended fixes (in order)**: +1. **Early stopping** (always, essential) +2. **L2 regularization** (weight_decay=0.0001 to 0.001) +3. **Data augmentation** (rotation ±15°, flip, crop, jitter) +4. **Reduce model capacity** (smaller ResNet if possible) +5. **Dropout** (rarely needed, not as effective as above) + +**Anti-pattern for CNNs**: Dropout after conv layers (not effective). Use batch norm instead. + +### Transformers (NLP, Vision) + +**Typical overfitting patterns**: +- Large model (100M+ parameters) on small dataset (5K examples) +- Overconfident predictions after few epochs + +**Recommended fixes (in order)**: +1. **Early stopping** (critical, prevents training to overfitting) +2. **L2 regularization** (weight_decay=0.01 to 0.1) +3. **Label smoothing** (0.1 recommended) +4. **Data augmentation** (back-translation for NLP, mixup for vision) +5. **Reduce model capacity** (use smaller transformer) + +**Anti-pattern for Transformers**: Dropout (modern transformers don't use it much). Use batch norm + layer norm already included. + +### RNNs/LSTMs (Sequences) + +**Typical overfitting patterns**: +- Train loss decreasing, val loss increasing after epoch 50 +- Small dataset (< 10K sequences) + +**Recommended fixes (in order)**: +1. **Early stopping** (essential for sequences) +2. **Dropout** (critical for RNNs, 0.2-0.5) +3. **L2 regularization** (weight_decay=0.0001) +4. **Data augmentation** (if applicable to domain) +5. **Recurrent dropout** (specific for RNNs, drops same neurons across timesteps) + +**Anti-pattern for RNNs**: Using standard dropout (neurons drop differently each timestep). Use recurrent dropout instead. + + +## Part 5: Common Pitfalls & Rationalizations + +### Pitfall 1: "Higher training accuracy = better model" + +**User's Rationalization**: "My training accuracy reached 99%, so the model is learning well." + +**Reality**: High training accuracy means nothing without validation accuracy. Model could be 99% accurate on training and 50% on validation (overfitting). + +**Fix**: Always report both train and validation accuracy. Gap of > 5% is concerning. + + +### Pitfall 2: "Dropout solves all overfitting problems" + +**User's Rationalization**: "I heard dropout is the best regularization, so I'll add dropout=0.5 everywhere." + +**Reality**: Dropout is regularization, not a cure-all. Effectiveness depends on: +- Architecture (works great for MLPs, less for CNNs) +- Where it's placed (after FC layers yes, after conv layers no) +- Strength (0.5 is standard, but 0.3 might be better for your case) + +**Fix**: Use early stopping + L2 first. Only add dropout if others insufficient. + + +### Pitfall 3: "More regularization is always better" + +**User's Rationalization**: "One regularization technique helped, so let me add five more!" + +**Reality**: Multiple regularization techniques can conflict: +- Dropout + batch norm together have complex interaction +- L2 + large batch size interact weirdly +- Over-regularization causes underfitting (60% train, 58% val) + +**Fix**: Add one technique at a time. Measure improvement. Stop when improvement plateaus. + + +### Pitfall 4: "I'll fix overfitting with more data" + +**User's Rationalization**: "My model overfits on 5K examples, so I need 50K examples to fix it." + +**Reality**: More data helps, but regularization is faster and cheaper. You can fix overfitting with 5K examples + good regularization. + +**Fix**: Use data augmentation (cheap), regularization, and early stopping before collecting more data. + + +### Pitfall 5: "Early stopping is for amateurs" + +**User's Rationalization**: "Real practitioners train full epochs, not early stopping." + +**Reality**: Every competitive model uses early stopping. It's not about "early stopping at epoch 10", it's about "stop when validation peaks". + +**Fix**: Use early stopping with patience=10-20. It saves compute and improves accuracy. + + +### Pitfall 6: "Validation set is luxury I can't afford" + +**User's Rationalization**: "I only have 10K examples, can't spare 2K for validation." + +**Reality**: You can't diagnose overfitting without validation set. You're flying blind. + +**Fix**: Use at least 10% validation set. With 10K examples, that's 1K for validation, 9K for training. Acceptable tradeoff. + + +### Pitfall 7: "Model overfits, so I'll disable batch norm" + +**User's Rationalization**: "Batch norm acts as regularization, maybe it's causing overfitting?" + +**Reality**: Batch norm is usually good. It stabilizes training and is mild regularization. Removing it won't help overfitting much. + +**Fix**: Keep batch norm. If overfitting, add stronger regularization (early stopping, L2). + + +### Pitfall 8: "I'll augment validation data for fairness" + +**User's Rationalization**: "I augment training data, so I should augment validation too for consistency." + +**Reality**: Validation data should be augmentation-free. Otherwise your validation accuracy is misleading. + +**Fix**: Augment training data only. Validation and test data stay original. + + +### Pitfall 9: "Regularization will slow down my training" + +**User's Rationalization**: "Adding early stopping and L2 will complicate my training pipeline." + +**Reality**: Early stopping saves compute (train to epoch 100 instead of 200). Regularization adds negligible overhead. + +**Fix**: Early stopping actually makes training FASTER. Add it. + + +### Pitfall 10: "My overfitting is unavoidable with this small dataset" + +**User's Rationalization**: "5K examples is too small, I can't prevent overfitting." + +**Reality**: With proper regularization (data augmentation, L2, early stopping), you can achieve 85-90% accuracy on 5K examples. + +**Fix**: Combine augmentation + L2 + early stopping. This combination is very effective on small datasets. + + +## Part 6: Red Flags & Troubleshooting + +### Red Flag 1: "Validation loss increasing while training loss decreasing" + +**What it means**: Classic overfitting. Model is memorizing training data, not learning patterns. + +**Immediate action**: Enable early stopping if not already enabled. Set patience=10 and retrain. + +**Diagnosis checklist**: +- [ ] Is training data too small? (< 5K examples) +- [ ] Is model too large? (more parameters than examples) +- [ ] Is training too long? (epoch 100 when best was epoch 20) +- [ ] Is learning rate too high? (weights changing too fast) + + +### Red Flag 2: "Training accuracy increased from 85% to 92%, but validation decreased from 78% to 73%" + +**What it means**: Overfitting is accelerating. Model is moving away from good generalization. + +**Immediate action**: Stop training now. Use checkpoint from earlier epoch (when val was 78%). + +**Diagnosis checklist**: +- [ ] Do you have early stopping enabled? +- [ ] Is patience too high? (should be 10-15, not 100) +- [ ] Did you collect more data or change something? + + +### Red Flag 3: "Training unstable, loss spiking randomly" + +**What it means**: Likely cause: learning rate too high, or poorly set batch norm in combo with dropout. + +**Immediate action**: Reduce learning rate by 10x. If still unstable, check batch norm + dropout interaction. + +**Diagnosis checklist**: +- [ ] Is learning rate too high? (try 0.1x) +- [ ] Is batch size too small? (< 8) +- [ ] Is batch norm + dropout used together badly? + + +### Red Flag 4: "Model performs well on training set, catastrophically bad on test" + +**What it means**: Severe overfitting or distribution shift. Model learned training set patterns that don't generalize. + +**Immediate action**: Check if test set is different distribution from training. If same distribution, severe overfitting. + +**Fix for overfitting**: +- Reduce model capacity significantly (20-50% reduction) +- Add strong L2 (weight_decay=0.1) +- Add strong augmentation +- Collect more training data + + +### Red Flag 5: "Validation accuracy plateaued but still training" + +**What it means**: Model has reached its potential with current hyperparameters. Training past this point is wasting compute. + +**Immediate action**: Enable early stopping. Set patience=20 and retrain. + +**Diagnosis checklist**: +- [ ] Has validation accuracy been flat for 20+ epochs? +- [ ] Could learning rate schedule help? (try cosine annealing) +- [ ] Is model capacity sufficient? (or too limited) + + +### Red Flag 6: "Train loss very low, but validation loss very high" + +**What it means**: Severe overfitting. Model is extremely confident on training examples but clueless on validation. + +**Immediate action**: Model capacity too high. Reduce significantly (30-50% fewer parameters). + +**Other actions**: +- Enable strong L2 (weight_decay=0.1) +- Add aggressive data augmentation +- Reduce learning rate +- Collect more data + + +### Red Flag 7: "Small changes in hyperparameters cause huge validation swings" + +**What it means**: Model is very sensitive to hyperparameters. Sign of small dataset or poor regularization. + +**Immediate action**: Use cross-validation (K-fold) to get more stable estimates. + +**Diagnosis checklist**: +- [ ] Dataset < 10K examples? (Small dataset, high variance) +- [ ] Validation set too small? (< 1K examples) +- [ ] Regularization too weak? (no L2, no augmentation, no early stop) + + +### Red Flag 8: "Training seems to work, but model fails in production" + +**What it means**: Validation data distribution differs from production. Or validation set too small to catch overfitting. + +**Immediate action**: Analyze production data. Is it different from validation? If so, that's a distribution shift problem, not overfitting. + +**Diagnosis checklist**: +- [ ] Is test data representative of production? +- [ ] Are there label differences? (example: validation = clean images, production = blurry images) +- [ ] Did you collect more data that changed distribution? + + +### Troubleshooting Flowchart + +``` +START: Model is overfitting (train > val by > 5%) + +├─ Is validation accuracy still increasing with training? +│ ├─ YES: Not yet severe overfitting, can continue +│ │ Add early stopping as safety net +│ │ +│ └─ NO: Validation has plateaued or declining +│ ↓ +│ +├─ Enable early stopping if not present +│ ├─ Setting: patience=10-20 +│ ├─ Retrain and measure +│ ├─ Expected improvement: 5-15% in final validation accuracy +│ │ +│ └─ Did validation improve? +│ ├─ YES: Problem partially solved, may need more +│ └─ NO: Early stopping not main issue, continue... +│ +├─ Check model capacity vs data size +│ ├─ Model parameters > 10x data size → Reduce capacity (50% smaller) +│ ├─ Model parameters = data size → Add regularization +│ ├─ Model parameters < data size → Regularization may be unnecessary +│ │ +│ └─ Continue... +│ +├─ Add L2 regularization if not present +│ ├─ Small dataset (< 5K): weight_decay=0.01-0.1 +│ ├─ Medium dataset (5K-50K): weight_decay=0.001-0.01 +│ ├─ Large dataset (> 50K): weight_decay=0.0001-0.001 +│ │ +│ └─ Retrain and measure +│ ├─ YES: Val improved +1-3% → Keep it +│ └─ NO: Wasn't the bottleneck, continue... +│ +├─ Add data augmentation if applicable +│ ├─ Image data: Rotation, flip, crop, color +│ ├─ Text data: Back-translation, synonym replacement +│ ├─ Tabular data: SMOTE, noise injection +│ │ +│ └─ Retrain and measure +│ ├─ YES: Val improved +2-5% → Keep it +│ └─ NO: Augmentation not applicable or too aggressive +│ +├─ Only if gap still > 10%: Consider reducing model capacity +│ ├─ 20-50% fewer parameters +│ ├─ Fewer layers or narrower layers +│ │ +│ └─ Retrain and measure +│ +└─ If STILL overfitting: Collect more training data +``` + + +## Part 7: Rationalization Table (Diagnosis & Correction) + +| User's Belief | What's Actually True | Evidence | Fix | +|---------------|---------------------|----------|-----| +| "Train acc 95% means model is working" | High train acc without validation is meaningless | Train 95%, val 65% is common in overfitting | Check validation accuracy immediately | +| "More training always helps" | Training past best point increases overfitting | Val loss starts increasing at epoch 50, worsens by epoch 200 | Use early stopping with patience=10 | +| "I need more data to fix overfitting" | Regularization is faster and cheaper | Can achieve 85% val with 5K+augment vs 90% with 50K | Try regularization first | +| "Dropout=0.5 is standard" | Standard depends on architecture and task | Works for MLPs, less effective for CNNs | Start with 0.3, tune based on results | +| "Batch norm and dropout together is fine" | They can conflict, reducing overall regularization | Empirically unstable together | Use one or the other, not both | +| "I'll augment validation for fairness" | Validation must measure true performance | Augmented validation gives misleading metrics | Never augment validation/test data | +| "L2 with weight_decay in Adam works" | Adam's weight_decay is broken, use AdamW | Adam and AdamW have different weight decay implementations | Switch to AdamW | +| "Early stopping defeats the purpose of training" | Early stopping is how you optimize generalization | Professional models always use early stopping | Enable it, set patience=10-20 | +| "Overfitting is unavoidable with small data" | Proper regularization prevents overfitting effectively | 5K examples + augment + L2 + early stop = 80%+ val | Combine multiple techniques | +| "Model with 1M params on 1K examples is fine" | 1000x parameter/example ratio guarantees overfitting | Impossible to prevent without extreme regularization | Reduce capacity to 10-100K params | + + +## Part 8: Complete Example: Diagnosing & Fixing Overfitting + +### Scenario: Image Classification on Small Dataset + +**Initial Setup**: +- Dataset: 5,000 images, 10 classes +- Model: ResNet50 (23M parameters) +- Observation: Train acc 97%, Val acc 62%, Gap 35% + +**Step 1: Diagnose Root Causes** + +| Factor | Assessment | +|--------|-----------| +| Model size | 23M params for 5K examples = 4600x ratio → **TOO LARGE** | +| Dataset size | 5K is small → **HIGH OVERFITTING RISK** | +| Regularization | No early stopping, no L2, no augmentation → **INADEQUATE** | +| Learning rate | Default 1e-4, not high → **PROBABLY OK** | + +**Conclusion**: Primary cause = model too large. Secondary = insufficient regularization. + +**Step 2: Apply Fixes in Order** + +**Fix 1: Early Stopping** (Cost: free, compute savings) +```python +early_stop = EarlyStoppingCallback(patience=15) +# Retrain: Train acc 94%, Val acc 76%, Gap 18% +# ✓ Improved by 14% (62% → 76%) +``` + +**Fix 2: Reduce Model Capacity** (Cost: lower max capacity, but necessary) +```python +# Use ResNet18 instead of ResNet50 +# 11M → 11M parameters (already smaller than ResNet50) +# Actually, use even smaller: ResNet10-like +# 2M parameters for 5K examples = 400x ratio (better but still high) +# Retrain with ResNet18 + early stopping +# Train acc 88%, Val acc 79%, Gap 9% +# ✓ Improved by 3% (76% → 79%), and reduced overfitting gap +``` + +**Fix 3: L2 Regularization** (Cost: negligible) +```python +optimizer = AdamW(model.parameters(), weight_decay=0.01) +# Retrain: Train acc 86%, Val acc 80%, Gap 6% +# ✓ Improved by 1% (79% → 80%), reduced overfitting further +``` + +**Fix 4: Data Augmentation** (Cost: 10-15% training time) +```python +train_transform = transforms.Compose([ + transforms.RandomCrop(224, padding=8), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), +]) +# Retrain: Train acc 84%, Val acc 82%, Gap 2% +# ✓ Improved by 2% (80% → 82%), overfitting gap now minimal +``` + +**Final Results**: +- Started: Train 97%, Val 62%, Gap 35% (severe overfitting) +- Ended: Train 84%, Val 82%, Gap 2% (healthy generalization) +- Trade: 13% train accuracy loss for 20% val accuracy gain = **net +20% on real task** + +**Lesson**: Fixing overfitting sometimes requires accepting lower training accuracy. That's the point—you're no longer memorizing. + + +## Part 9: Advanced Topics + +### Mixup and Cutmix (Advanced Augmentation as Regularization) + +**What they do**: Create synthetic training examples by mixing two examples. + +**Mixup**: Blend images and labels +```python +class MixupAugmentation: + def __init__(self, alpha=0.2): + self.alpha = alpha + + def __call__(self, images, targets): + """ + Randomly blend two training batches + """ + batch_size = images.size(0) + index = torch.randperm(batch_size) + + # Sample lambda from Beta distribution + lam = np.random.beta(self.alpha, self.alpha) + + # Mix images + mixed_images = lam * images + (1 - lam) * images[index, :] + + # Mix targets (soft targets) + target_a, target_b = targets, targets[index] + + return mixed_images, target_a, target_b, lam + +# In training loop: +mixup = MixupAugmentation(alpha=0.2) +mixed_images, target_a, target_b, lam = mixup(images, targets) +output = model(mixed_images) +loss = lam * criterion(output, target_a) + (1 - lam) * criterion(output, target_b) +``` + +**When to use**: For image classification on moderate+ datasets (10K+). Effective regularization. + +**Typical improvement**: +1-3% accuracy + + +### Class Imbalance as Overfitting Factor + +**Scenario**: Model overfits to majority class. Minority class appears only 100 times out of 10,000. + +**Solution 1: Weighted Sampling** +```python +from torch.utils.data import WeightedRandomSampler + +# Compute class weights +class_counts = torch.bincount(train_labels) +class_weights = 1.0 / class_counts +sample_weights = class_weights[train_labels] + +# Create sampler that balances classes +sampler = WeightedRandomSampler( + weights=sample_weights, + num_samples=len(sample_weights), + replacement=True +) + +train_loader = DataLoader( + train_dataset, + batch_size=32, + sampler=sampler # Replaces shuffle=True +) + +# Result: Each batch has balanced class representation +# Prevents model from ignoring minority class +``` + +**Solution 2: Loss Weighting** +```python +# Compute class weights +class_counts = torch.bincount(train_labels) +class_weights = len(train_labels) / (len(class_counts) * class_counts) +class_weights = class_weights.to(device) + +criterion = nn.CrossEntropyLoss(weight=class_weights) +# Cross-entropy automatically weights loss by class + +# Result: Minority class has higher loss weight +# Model pays more attention to getting minority class right +``` + +**Which to use**: Weighted sampler (adjusts data distribution) + weighted loss (adjusts loss). + + +### Handling Validation Set Leakage + +**Problem**: Using validation set performance to decide hyperparameters creates implicit overfitting to validation set. + +**Example of Leakage**: +```python +# WRONG: Using val performance to select model +best_val_acc = 0 +for lr in [1e-4, 1e-3, 1e-2]: + train_model(lr) + val_acc = validate() + if val_acc > best_val_acc: + best_val_acc = val_acc + best_lr = lr + +# You've now tuned hyperparameters to maximize validation accuracy +# Your validation accuracy estimate is optimistic (overfitted to val set) +``` + +**Proper Solution: Use Hold-Out Test Set** +```python +# Split: Train (60%), Validation (20%), Test (20%) +# 1. Train and select hyperparameters using train + val +# 2. Report final metrics using test set only +# 3. Never tune on test set + +test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) +for X_test, y_test in test_loader: + predictions = model(X_test) + test_acc = (predictions.argmax(1) == y_test).float().mean() + +# Report: Test accuracy 78.5% (this is your honest estimate) +``` + +**Or Use Cross-Validation**: +```python +from sklearn.model_selection import StratifiedKFold + +skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) +cv_scores = [] + +for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)): + X_train, X_val = X[train_idx], X[val_idx] + y_train, y_val = y[train_idx], y[val_idx] + + model = create_model() + model.fit(X_train, y_train) + val_acc = model.evaluate(X_val, y_val) + cv_scores.append(val_acc) + +mean_cv_score = np.mean(cv_scores) +std_cv_score = np.std(cv_scores) +print(f"CV Score: {mean_cv_score:.4f} ± {std_cv_score:.4f}") + +# This is more robust estimate than single train/val split +``` + + +### Monitoring Metric: Learning Curves + +**What to track**: +```python +history = { + 'train_loss': [], + 'val_loss': [], + 'train_acc': [], + 'val_acc': [], +} + +for epoch in range(100): + train_loss, train_acc = train_one_epoch() + val_loss, val_acc = validate() + + history['train_loss'].append(train_loss) + history['val_loss'].append(val_loss) + history['train_acc'].append(train_acc) + history['val_acc'].append(val_acc) + +# Plot +import matplotlib.pyplot as plt + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + +# Loss curves +ax1.plot(history['train_loss'], label='Train Loss') +ax1.plot(history['val_loss'], label='Val Loss') +ax1.set_xlabel('Epoch') +ax1.set_ylabel('Loss') +ax1.legend() +ax1.grid() + +# Accuracy curves +ax2.plot(history['train_acc'], label='Train Acc') +ax2.plot(history['val_acc'], label='Val Acc') +ax2.set_xlabel('Epoch') +ax2.set_ylabel('Accuracy') +ax2.legend() +ax2.grid() + +plt.tight_layout() +plt.show() + +# Interpretation: +# - Both curves decreasing together → Good generalization +# - Train decreasing, val increasing → Overfitting +# - Both plateaued at different levels → Possible underfitting (gap exists at plateau) +``` + +**What good curves look like**: +- Both loss curves decrease smoothly +- Curves stay close together (gap < 5%) +- Loss curves flatten out (convergence) +- Accuracy curves increase together and plateau + +**What bad curves look like**: +- Validation loss spikes or increases sharply +- Large and growing gap between train and validation +- Loss curves diverge after certain point +- Validation accuracy stops improving but training continues + + +### Hyperparameter Tuning Strategy + +**Recommended approach**: Grid search with cross-validation, not random search. + +```python +param_grid = { + 'weight_decay': [0.0001, 0.001, 0.01, 0.1], + 'dropout_rate': [0.1, 0.3, 0.5], + 'learning_rate': [1e-4, 5e-4, 1e-3], +} + +best_score = -float('inf') +best_params = None + +for weight_decay in param_grid['weight_decay']: + for dropout_rate in param_grid['dropout_rate']: + for lr in param_grid['learning_rate']: + # Train with these parameters + scores = cross_validate( + model, + X_train, + y_train, + params={'weight_decay': weight_decay, + 'dropout_rate': dropout_rate, + 'lr': lr} + ) + + mean_score = np.mean(scores) + if mean_score > best_score: + best_score = mean_score + best_params = { + 'weight_decay': weight_decay, + 'dropout_rate': dropout_rate, + 'lr': lr + } + +print(f"Best params: {best_params}") +print(f"Best cross-val score: {best_score:.4f}") + +# Train final model on all training data with best params +final_model = create_model(**best_params) +final_model.fit(X_train, y_train) +test_score = final_model.evaluate(X_test, y_test) +print(f"Test score: {test_score:.4f}") +``` + + +### Debugging Checklist + +When your model overfits, go through this checklist: + +- [ ] Monitoring BOTH train AND validation accuracy? +- [ ] Train/val gap is clear and objective? +- [ ] Using proper validation set (10% of data minimum)? +- [ ] Validation set from SAME distribution as training? +- [ ] Early stopping enabled with patience 5-20? +- [ ] L2 regularization strength appropriate for dataset size? +- [ ] Data augmentation applied to TRAINING only (not validation)? +- [ ] Model capacity reasonable for data size (params < 100x examples)? +- [ ] Learning rate schedule used (decay or warmup)? +- [ ] Batch normalization or layer normalization present? +- [ ] Not adding conflicting regularization (e.g., too much dropout + too strong L2)? +- [ ] Loss curve showing training progress (not stuck)? +- [ ] Validation loss actually used for stopping (not just epoch limit)? + +If you've checked all these and still overfitting, the issue is likely: +1. **Data too small or too hard** → Collect more data +2. **Model fundamentally wrong** → Try different architecture +3. **Distribution shift** → Validation data different from training + + +### Common Code Patterns + +**Pattern 1: Proper Training Loop with Early Stopping** +```python +early_stop = EarlyStoppingCallback(patience=15) +best_model = None + +for epoch in range(500): + # Train + train_loss = 0 + for X_batch, y_batch in train_loader: + logits = model(X_batch) + loss = criterion(logits, y_batch) + loss.backward() + optimizer.step() + optimizer.zero_grad() + train_loss += loss.item() + + train_loss /= len(train_loader) + + # Validate + val_loss = 0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + logits = model(X_batch) + loss = criterion(logits, y_batch) + val_loss += loss.item() + + val_loss /= len(val_loader) + + # Check early stopping + early_stop(val_loss) + if val_loss < early_stop.best_val_loss: + best_model = copy.deepcopy(model) + + if early_stop.should_stop: + print(f"Stopping at epoch {epoch}") + model = best_model + break +``` + +**Pattern 2: Regularization Combination** +```python +# Setup with multiple regularization techniques +model = MyModel(dropout=0.3) # Mild dropout +model = model.to(device) + +# L2 regularization via weight decay +optimizer = torch.optim.AdamW(model.parameters(), + lr=1e-4, + weight_decay=0.001) + +# Learning rate schedule for decay +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + +# Early stopping +early_stop = EarlyStoppingCallback(patience=20) + +for epoch in range(200): + # Train with data augmentation + train_acc = 0 + for X_batch, y_batch in augmented_train_loader: + logits = model(X_batch) + loss = criterion(logits, y_batch) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + train_acc += (logits.argmax(1) == y_batch).float().mean().item() + + train_acc /= len(train_loader) + scheduler.step() + + # Validate (NO augmentation on validation) + val_acc = 0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: # Clean val loader + logits = model(X_batch) + val_acc += (logits.argmax(1) == y_batch).float().mean().item() + + val_acc /= len(val_loader) + + early_stop(val_acc) + + print(f"Epoch {epoch}: Train {train_acc:.4f}, Val {val_acc:.4f}") + + if early_stop.should_stop: + break +``` + + +## Summary + +**Overfitting is detectable, diagnosable, and fixable.** + +1. **Detect**: Monitor both train and validation accuracy. Gap > 5% is warning. +2. **Diagnose**: Root causes = large model, small data, long training, high learning rate, class imbalance +3. **Fix**: Combine techniques (early stopping + L2 + augmentation + capacity reduction) +4. **Measure**: Check improvement after each addition +5. **Avoid**: Single-technique fixes, blindly tuning regularization, ignoring validation +6. **Remember: Proper validation set and test set are essential** - Without them, you're optimizing blindly + +**Remember**: The goal is not maximum training accuracy. The goal is maximum generalization. Sometimes that means accepting lower training accuracy to achieve higher validation accuracy. + +**One more thing**: Different problems have different fixes: +- High capacity on small data → Reduce capacity, data augmentation +- Training too long → Early stopping +- High learning rate → LR schedule or reduce LR +- Class imbalance → Weighted sampling or weighted loss +- Overconfidence → Label smoothing + +Choose the fix that matches your diagnosis, not your intuition. + diff --git a/skills/using-training-optimization/training-loop-architecture.md b/skills/using-training-optimization/training-loop-architecture.md new file mode 100644 index 0000000..d4979af --- /dev/null +++ b/skills/using-training-optimization/training-loop-architecture.md @@ -0,0 +1,882 @@ + +# Training Loop Architecture + +## Overview + +**Core Principle:** A properly structured training loop is the foundation of all successful deep learning projects. Success requires: (1) correct train/val/test data separation, (2) validation after EVERY epoch (not just once), (3) complete checkpoint state (model + optimizer + scheduler), (4) comprehensive logging/monitoring, and (5) graceful error handling. Poor loop structure causes: silent overfitting, broken resume functionality, undetectable training issues, and memory leaks. + +Training loop failures manifest as: overfitting with good metrics, crashes on resume, unexplained loss spikes, or out-of-memory errors. These stem from misunderstanding when validation runs, what state must be saved, or how to manage GPU memory. Systematic architecture beats trial-and-error fixes. + +## When to Use + +**Use this skill when:** +- Implementing a new training loop from scratch +- Training loop is crashing unexpectedly +- Can't resume training from checkpoint correctly +- Model overfits but validation metrics look good +- Out-of-memory errors during training +- Unsure about train/val/test data split +- Need to monitor training progress properly +- Implementing early stopping or checkpoint selection +- Training loops show loss spikes or divergence on resume +- Adding logging/monitoring to training + +**Don't use when:** +- Debugging single backward pass (use gradient-management skill) +- Tuning learning rate (use learning-rate-scheduling skill) +- Fixing specific loss function (use loss-functions-and-objectives skill) +- Data loading issues (use data-augmentation-strategies skill) + +**Symptoms triggering this skill:** +- "Training loss decreases but validation loss increases (overfitting)" +- "Training crashes when resuming from checkpoint" +- "Out of memory errors after epoch 20" +- "I validated on training data and didn't realize" +- "Can't detect overfitting because I don't validate" +- "Training loss spikes when resuming" +- "My checkpoint doesn't load correctly" + + +## Complete Training Loop Structure + +### 1. The Standard Training Loop (The Reference) + +```python +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import logging +from pathlib import Path + +# Setup logging (ALWAYS do this first) +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +class TrainingLoop: + """Complete training loop with validation, checkpointing, and monitoring.""" + + def __init__(self, model, optimizer, scheduler, criterion, device='cuda'): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.criterion = criterion + self.device = device + + # Tracking metrics + self.train_losses = [] + self.val_losses = [] + self.best_val_loss = float('inf') + self.epochs_without_improvement = 0 + + def train_epoch(self, train_loader): + """Train for one epoch.""" + self.model.train() + total_loss = 0.0 + num_batches = 0 + + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(self.device), target.to(self.device) + + # Forward pass + self.optimizer.zero_grad() + output = self.model(data) + loss = self.criterion(output, target) + + # Backward pass with gradient clipping (if needed) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.optimizer.step() + + # Accumulate loss + total_loss += loss.item() + num_batches += 1 + + # Log progress every 10 batches + if batch_idx % 10 == 0: + logger.debug(f"Batch {batch_idx}: loss={loss.item():.4f}") + + avg_loss = total_loss / num_batches + return avg_loss + + def validate_epoch(self, val_loader): + """Validate on validation set (AFTER each epoch, not during).""" + self.model.eval() + total_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): # ✅ CRITICAL: No gradients during validation + for data, target in val_loader: + data, target = data.to(self.device), target.to(self.device) + + output = self.model(data) + loss = self.criterion(output, target) + + total_loss += loss.item() + num_batches += 1 + + avg_loss = total_loss / num_batches + return avg_loss + + def save_checkpoint(self, epoch, val_loss, checkpoint_dir='checkpoints'): + """Save complete checkpoint (model + optimizer + scheduler).""" + Path(checkpoint_dir).mkdir(exist_ok=True) + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'val_loss': val_loss, + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + } + + # Save last checkpoint + torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt') + + # Save best checkpoint + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_best.pt') + logger.info(f"New best validation loss: {val_loss:.4f}") + + def load_checkpoint(self, checkpoint_path): + """Load checkpoint and resume training correctly.""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # ✅ CRITICAL ORDER: Load model, optimizer, scheduler (in that order) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + # Restore metrics history + self.train_losses = checkpoint['train_losses'] + self.val_losses = checkpoint['val_losses'] + self.best_val_loss = min(self.val_losses) if self.val_losses else float('inf') + + epoch = checkpoint['epoch'] + logger.info(f"Loaded checkpoint from epoch {epoch}") + return epoch + + def train(self, train_loader, val_loader, num_epochs, checkpoint_dir='checkpoints'): + """Full training loop with validation and checkpointing.""" + start_epoch = 0 + + # Try to resume from checkpoint if it exists + checkpoint_path = f'{checkpoint_dir}/checkpoint_latest.pt' + if Path(checkpoint_path).exists(): + start_epoch = self.load_checkpoint(checkpoint_path) + logger.info(f"Resuming training from epoch {start_epoch}") + + for epoch in range(start_epoch, num_epochs): + try: + # Train for one epoch + train_loss = self.train_epoch(train_loader) + self.train_losses.append(train_loss) + + # ✅ CRITICAL: Validate after every epoch + val_loss = self.validate_epoch(val_loader) + self.val_losses.append(val_loss) + + # Step scheduler (after epoch) + self.scheduler.step() + + # Log metrics + logger.info( + f"Epoch {epoch}: train_loss={train_loss:.4f}, " + f"val_loss={val_loss:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}" + ) + + # Checkpoint every epoch + self.save_checkpoint(epoch, val_loss, checkpoint_dir) + + # Early stopping (optional) + if val_loss < self.best_val_loss: + self.epochs_without_improvement = 0 + else: + self.epochs_without_improvement += 1 + if self.epochs_without_improvement >= 10: + logger.info(f"Early stopping at epoch {epoch}") + break + + except KeyboardInterrupt: + logger.info("Training interrupted by user") + self.save_checkpoint(epoch, val_loss, checkpoint_dir) + break + except RuntimeError as e: + logger.error(f"Error in epoch {epoch}: {e}") + raise + + logger.info("Training complete") + return self.model +``` + +### 2. Data Split: Train/Val/Test Separation (CRITICAL) + +```python +from sklearn.model_selection import train_test_split +from torch.utils.data import Subset, Dataset + +# ✅ CORRECT: Proper three-way split with NO data leakage +class DataSplitter: + """Ensures clean train/val/test splits without data leakage.""" + + @staticmethod + def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42): + """ + Split dataset into train/val/test. + + CRITICAL: Split indices first, then create loaders. + This prevents any data leakage. + """ + assert train_ratio + val_ratio + test_ratio == 1.0 + + n = len(dataset) + indices = list(range(n)) + + # First split: train vs (val + test) + train_size = int(train_ratio * n) + train_indices = indices[:train_size] + remaining_indices = indices[train_size:] + + # Second split: val vs test + remaining_size = len(remaining_indices) + val_size = int(val_ratio / (val_ratio + test_ratio) * remaining_size) + val_indices = remaining_indices[:val_size] + test_indices = remaining_indices[val_size:] + + # Create subset datasets (same transforms, different data) + train_dataset = Subset(dataset, train_indices) + val_dataset = Subset(dataset, val_indices) + test_dataset = Subset(dataset, test_indices) + + logger.info( + f"Dataset split: train={len(train_dataset)}, " + f"val={len(val_dataset)}, test={len(test_dataset)}" + ) + + return train_dataset, val_dataset, test_dataset + +# Usage +train_dataset, val_dataset, test_dataset = DataSplitter.split_dataset( + full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15 +) + +train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) +val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) +test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) + +# ✅ CRITICAL: Validate that splits are actually different +print(f"Train samples: {len(train_loader.dataset)}") +print(f"Val samples: {len(val_loader.dataset)}") +print(f"Test samples: {len(test_loader.dataset)}") + +# ✅ CRITICAL: Never mix splits (don't re-shuffle or combine) +``` + +### 3. Monitoring and Logging (Reproducibility) + +```python +import json +from datetime import datetime + +class TrainingMonitor: + """Track all metrics for reproducibility and debugging.""" + + def __init__(self, log_dir='logs'): + self.log_dir = Path(log_dir) + self.log_dir.mkdir(exist_ok=True) + + # Metrics to track + self.metrics = { + 'timestamp': datetime.now().isoformat(), + 'epochs': [], + 'train_losses': [], + 'val_losses': [], + 'learning_rates': [], + 'gradient_norms': [], + 'batch_times': [], + } + + def log_epoch(self, epoch, train_loss, val_loss, lr, gradient_norm=None, batch_time=None): + """Log metrics for one epoch.""" + self.metrics['epochs'].append(epoch) + self.metrics['train_losses'].append(train_loss) + self.metrics['val_losses'].append(val_loss) + self.metrics['learning_rates'].append(lr) + if gradient_norm is not None: + self.metrics['gradient_norms'].append(gradient_norm) + if batch_time is not None: + self.metrics['batch_times'].append(batch_time) + + def save_metrics(self): + """Save metrics to JSON for post-training analysis.""" + metrics_path = self.log_dir / f'metrics_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json' + with open(metrics_path, 'w') as f: + json.dump(self.metrics, f, indent=2) + logger.info(f"Metrics saved to {metrics_path}") + + def plot_metrics(self): + """Plot training curves.""" + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(2, 2, figsize=(12, 8)) + + # Loss curves + axes[0, 0].plot(self.metrics['epochs'], self.metrics['train_losses'], label='Train') + axes[0, 0].plot(self.metrics['epochs'], self.metrics['val_losses'], label='Val') + axes[0, 0].set_xlabel('Epoch') + axes[0, 0].set_ylabel('Loss') + axes[0, 0].legend() + axes[0, 0].set_title('Training and Validation Loss') + + # Learning rate schedule + axes[0, 1].plot(self.metrics['epochs'], self.metrics['learning_rates']) + axes[0, 1].set_xlabel('Epoch') + axes[0, 1].set_ylabel('Learning Rate') + axes[0, 1].set_title('Learning Rate Schedule') + axes[0, 1].set_yscale('log') + + # Gradient norms (if available) + if self.metrics['gradient_norms']: + axes[1, 0].plot(self.metrics['epochs'], self.metrics['gradient_norms']) + axes[1, 0].set_xlabel('Epoch') + axes[1, 0].set_ylabel('Gradient Norm') + axes[1, 0].set_title('Gradient Norms') + + # Batch times (if available) + if self.metrics['batch_times']: + axes[1, 1].plot(self.metrics['epochs'], self.metrics['batch_times']) + axes[1, 1].set_xlabel('Epoch') + axes[1, 1].set_ylabel('Time (seconds)') + axes[1, 1].set_title('Batch Processing Time') + + plt.tight_layout() + plot_path = self.log_dir / f'training_curves_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png' + plt.savefig(plot_path) + logger.info(f"Plot saved to {plot_path}") +``` + +### 4. Checkpointing and Resuming (Complete State) + +```python +class CheckpointManager: + """Properly save and load ALL training state.""" + + def __init__(self, checkpoint_dir='checkpoints'): + self.checkpoint_dir = Path(checkpoint_dir) + self.checkpoint_dir.mkdir(exist_ok=True) + + def save_full_checkpoint(self, epoch, model, optimizer, scheduler, metrics, path_suffix=''): + """Save COMPLETE state for resuming training.""" + checkpoint = { + # Model and optimizer state + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + + # Training metrics (for monitoring) + 'train_losses': metrics['train_losses'], + 'val_losses': metrics['val_losses'], + 'learning_rates': metrics['learning_rates'], + + # Timestamp for recovery + 'timestamp': datetime.now().isoformat(), + } + + # Save as latest + latest_path = self.checkpoint_dir / f'checkpoint_latest{path_suffix}.pt' + torch.save(checkpoint, latest_path) + + # Save periodically (every 10 epochs) + if epoch % 10 == 0: + periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt' + torch.save(checkpoint, periodic_path) + + logger.info(f"Checkpoint saved: {latest_path}") + return latest_path + + def load_full_checkpoint(self, model, optimizer, scheduler, checkpoint_path): + """Load COMPLETE state correctly.""" + if not Path(checkpoint_path).exists(): + logger.warning(f"Checkpoint not found: {checkpoint_path}") + return 0, None + + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # ✅ CRITICAL ORDER: Model first, then optimizer, then scheduler + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + epoch = checkpoint['epoch'] + metrics = { + 'train_losses': checkpoint.get('train_losses', []), + 'val_losses': checkpoint.get('val_losses', []), + 'learning_rates': checkpoint.get('learning_rates', []), + } + + logger.info( + f"Loaded checkpoint from epoch {epoch}, " + f"saved at {checkpoint.get('timestamp', 'unknown')}" + ) + return epoch, metrics + + def get_best_checkpoint(self): + """Find checkpoint with best validation loss.""" + checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt')) + if not checkpoints: + return None + + best_loss = float('inf') + best_path = None + + for ckpt_path in checkpoints: + checkpoint = torch.load(ckpt_path, map_location='cpu') + val_losses = checkpoint.get('val_losses', []) + if val_losses and min(val_losses) < best_loss: + best_loss = min(val_losses) + best_path = ckpt_path + + return best_path +``` + +### 5. Memory Management (Prevent Leaks) + +```python +class MemoryManager: + """Prevent out-of-memory errors during long training.""" + + def __init__(self, device='cuda'): + self.device = device + + def clear_cache(self): + """Clear GPU cache between epochs.""" + if self.device.startswith('cuda'): + torch.cuda.empty_cache() + # Optional: clear CUDA graphs + torch.cuda.synchronize() + + def check_memory(self): + """Log GPU memory usage.""" + if self.device.startswith('cuda'): + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + logger.info(f"GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB") + + def training_loop_with_memory_management(self, model, train_loader, optimizer, criterion): + """Training loop with proper memory management.""" + model.train() + total_loss = 0.0 + + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(self.device), target.to(self.device) + + # Forward and backward + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # ✅ Clear temporary tensors (data and target go out of scope) + # ✅ Don't hold onto loss or output after using them + + # Periodically check memory + if batch_idx % 100 == 0: + self.check_memory() + + # Clear cache between epochs + self.clear_cache() + + return total_loss / len(train_loader) +``` + + +## Error Handling and Recovery + +```python +class RobustTrainingLoop: + """Training loop with proper error handling.""" + + def train_with_error_handling(self, model, train_loader, val_loader, optimizer, + scheduler, criterion, num_epochs, checkpoint_dir): + """Training with error recovery.""" + checkpoint_manager = CheckpointManager(checkpoint_dir) + memory_manager = MemoryManager() + + # Resume from last checkpoint if available + start_epoch, metrics = checkpoint_manager.load_full_checkpoint( + model, optimizer, scheduler, f'{checkpoint_dir}/checkpoint_latest.pt' + ) + + for epoch in range(start_epoch, num_epochs): + try: + # Train + train_loss = self.train_epoch(model, train_loader, optimizer, criterion) + + # Validate + val_loss = self.validate_epoch(model, val_loader, criterion) + + # Update scheduler + scheduler.step() + + # Log + logger.info( + f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, " + f"lr={optimizer.param_groups[0]['lr']:.2e}" + ) + + # Checkpoint + checkpoint_manager.save_full_checkpoint( + epoch, model, optimizer, scheduler, + {'train_losses': [train_loss], 'val_losses': [val_loss]} + ) + + # Memory management + memory_manager.clear_cache() + + except KeyboardInterrupt: + logger.warning("Training interrupted - checkpoint saved") + checkpoint_manager.save_full_checkpoint( + epoch, model, optimizer, scheduler, + {'train_losses': [train_loss], 'val_losses': [val_loss]} + ) + break + + except RuntimeError as e: + if 'out of memory' in str(e).lower(): + logger.error("Out of memory error") + memory_manager.clear_cache() + # Try to continue (reduce batch size in real scenario) + raise + else: + logger.error(f"Runtime error: {e}") + raise + + except Exception as e: + logger.error(f"Unexpected error in epoch {epoch}: {e}") + checkpoint_manager.save_full_checkpoint( + epoch, model, optimizer, scheduler, + {'train_losses': [train_loss], 'val_losses': [val_loss]} + ) + raise + + return model +``` + + +## Common Pitfalls and How to Avoid Them + +### Pitfall 1: Validating on Training Data +```python +# ❌ WRONG +val_loader = train_loader # Same loader! + +# ✅ CORRECT +train_dataset, val_dataset = split_dataset(full_dataset) +val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) +``` + +### Pitfall 2: Missing Optimizer State in Checkpoint +```python +# ❌ WRONG +torch.save({'model': model.state_dict()}, 'ckpt.pt') + +# ✅ CORRECT +torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), +}, 'ckpt.pt') +``` + +### Pitfall 3: Not Validating During Training +```python +# ❌ WRONG +for epoch in range(100): + train_epoch() +final_val = evaluate() # Only at the end! + +# ✅ CORRECT +for epoch in range(100): + train_epoch() + validate_epoch() # After every epoch +``` + +### Pitfall 4: Holding Onto Tensor References +```python +# ❌ WRONG +all_losses = [] +for data, target in loader: + loss = criterion(model(data), target) + all_losses.append(loss) # Memory leak! + +# ✅ CORRECT +total_loss = 0.0 +for data, target in loader: + loss = criterion(model(data), target) + total_loss += loss.item() # Scalar value +``` + +### Pitfall 5: Forgetting torch.no_grad() in Validation +```python +# ❌ WRONG +model.eval() +for data, target in val_loader: + output = model(data) # Gradients still computed! + loss = criterion(output, target) + +# ✅ CORRECT +model.eval() +with torch.no_grad(): + for data, target in val_loader: + output = model(data) # No gradients + loss = criterion(output, target) +``` + +### Pitfall 6: Resetting Scheduler on Resume +```python +# ❌ WRONG +checkpoint = torch.load('ckpt.pt') +model.load_state_dict(checkpoint['model']) +scheduler = CosineAnnealingLR(optimizer, T_max=100) # Fresh scheduler! +# Now at epoch 50, scheduler thinks it's epoch 0 + +# ✅ CORRECT +checkpoint = torch.load('ckpt.pt') +model.load_state_dict(checkpoint['model']) +optimizer.load_state_dict(checkpoint['optimizer']) +scheduler.load_state_dict(checkpoint['scheduler']) # Resume scheduler state +``` + +### Pitfall 7: Not Handling Early Stopping Correctly +```python +# ❌ WRONG +best_loss = float('inf') +for epoch in range(100): + val_loss = validate() + if val_loss < best_loss: + best_loss = val_loss + # No checkpoint! Can't recover best model + +# ✅ CORRECT +best_loss = float('inf') +patience = 10 +patience_counter = 0 +for epoch in range(100): + val_loss = validate() + if val_loss < best_loss: + best_loss = val_loss + patience_counter = 0 + save_checkpoint(model, optimizer, scheduler, epoch) # Save best + else: + patience_counter += 1 + if patience_counter >= patience: + break # Stop early +``` + +### Pitfall 8: Mixing Train and Validation Mode +```python +# ❌ WRONG +for epoch in range(100): + for data, target in train_loader: + output = model(data) # Is model in train or eval mode? + loss = criterion(output, target) + +# ✅ CORRECT +model.train() +for epoch in range(100): + for data, target in train_loader: + output = model(data) # Definitely in train mode + loss = criterion(output, target) + +model.eval() +with torch.no_grad(): + for data, target in val_loader: + output = model(data) # Definitely in eval mode +``` + +### Pitfall 9: Loading Checkpoint on Wrong Device +```python +# ❌ WRONG +checkpoint = torch.load('ckpt.pt') # Loads on GPU if saved on GPU +model.load_state_dict(checkpoint['model']) # Might be on wrong device + +# ✅ CORRECT +checkpoint = torch.load('ckpt.pt', map_location='cuda:0') # Specify device +model.load_state_dict(checkpoint['model']) +model.to('cuda:0') # Move to device +``` + +### Pitfall 10: Not Clearing GPU Cache +```python +# ❌ WRONG +for epoch in range(100): + train_epoch() + validate_epoch() + # GPU cache growing every epoch + +# ✅ CORRECT +for epoch in range(100): + train_epoch() + validate_epoch() + torch.cuda.empty_cache() # Clear cache +``` + + +## Integration with Optimization Techniques + +### Complete Training Loop with All Techniques + +```python +class FullyOptimizedTrainingLoop: + """Integrates: gradient clipping, mixed precision, learning rate scheduling.""" + + def train_with_all_techniques(self, model, train_loader, val_loader, + num_epochs, checkpoint_dir='checkpoints'): + """Training with all optimization techniques integrated.""" + + # Setup + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) + criterion = nn.CrossEntropyLoss() + + # Mixed precision (if using AMP) + scaler = torch.cuda.amp.GradScaler() + + # Training loop + for epoch in range(num_epochs): + model.train() + total_loss = 0.0 + + for data, target in train_loader: + data, target = data.to(device), target.to(device) + + optimizer.zero_grad() + + # Mixed precision forward pass + with torch.autocast('cuda'): + output = model(data) + loss = criterion(output, target) + + # Gradient scaling for mixed precision + scaler.scale(loss).backward() + + # Gradient clipping (unscale first!) + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Optimizer step + scaler.step(optimizer) + scaler.update() + + total_loss += loss.item() + + train_loss = total_loss / len(train_loader) + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for data, target in val_loader: + data, target = data.to(device), target.to(device) + output = model(data) + val_loss += criterion(output, target).item() + + val_loss /= len(val_loader) + + # Scheduler step + scheduler.step() + + logger.info( + f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, " + f"lr={optimizer.param_groups[0]['lr']:.2e}" + ) + + return model +``` + + +## Rationalization Table: When to Deviate from Standard + +| Situation | Standard Practice | Deviation | Rationale | +|-----------|-------------------|-----------|-----------| +| Validate only at end | Validate every epoch | ✗ Never | Can't detect overfitting | +| Save only model | Save model + optimizer + scheduler | ✗ Never | Resume training breaks | +| Mixed train/val | Separate datasets completely | ✗ Never | Data leakage and false metrics | +| Constant batch size | Fix batch size for reproducibility | ✓ Sometimes | May need dynamic batching for memory | +| Single LR | Use scheduler | ✓ Sometimes | <10 epoch training or hyperparameter search | +| No early stopping | Implement early stopping | ✓ Sometimes | If training time unlimited | +| Log every batch | Log every 10-100 batches | ✓ Often | Reduces I/O overhead | +| GPU cache every epoch | Clear GPU cache periodically | ✓ Sometimes | Only if OOM issues | + + +## Red Flags: Immediate Warning Signs + +1. **Training loss much lower than validation loss** (>2x) → Overfitting +2. **Loss spikes on resume** → Optimizer state not loaded +3. **GPU memory grows over time** → Memory leak, likely tensor accumulation +4. **Validation never runs** → Check if validation is in loop +5. **Best model not saved** → Check checkpoint logic +6. **Different results on resume** → Scheduler not loaded +7. **Early stopping not working** → Checkpoint not at best model +8. **OOM during training** → Clear GPU cache, check for accumulated tensors + + +## Testing Your Training Loop + +```python +def test_training_loop(): + """Quick test to verify training loop is correct.""" + + # Create dummy data + X_train = torch.randn(100, 10) + y_train = torch.randint(0, 2, (100,)) + X_val = torch.randn(20, 10) + y_val = torch.randint(0, 2, (20,)) + + train_loader = DataLoader( + list(zip(X_train, y_train)), batch_size=16 + ) + val_loader = DataLoader( + list(zip(X_val, y_val)), batch_size=16 + ) + + # Simple model + model = nn.Sequential( + nn.Linear(10, 64), + nn.ReLU(), + nn.Linear(64, 2) + ) + + # Training + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) + criterion = nn.CrossEntropyLoss() + + loop = TrainingLoop(model, optimizer, scheduler, criterion) + + # Should complete without errors + loop.train(train_loader, val_loader, num_epochs=5, checkpoint_dir='test_ckpts') + + # Check outputs + assert len(loop.train_losses) == 5 + assert len(loop.val_losses) == 5 + assert all(isinstance(l, float) for l in loop.train_losses) + + print("✓ Training loop test passed") + +if __name__ == '__main__': + test_training_loop() +``` +