Initial commit
This commit is contained in:
493
skills/using-training-optimization/SKILL.md
Normal file
493
skills/using-training-optimization/SKILL.md
Normal file
@@ -0,0 +1,493 @@
|
||||
---
|
||||
name: using-training-optimization
|
||||
description: Router to training optimization skills based on symptoms and training problems
|
||||
mode: true
|
||||
---
|
||||
|
||||
# Using Training Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
This meta-skill routes you to the right training optimization specialist based on symptoms. Training issues often have multiple potential causes—this skill helps diagnose symptoms and route to the appropriate specialist. Load this skill when you encounter training problems but aren't sure which specific technique to apply.
|
||||
|
||||
**Core Principle**: Diagnose before routing. Training issues often have multiple causes. Ask clarifying questions to understand symptoms before routing to specific skills. Wrong diagnosis wastes time—systematic routing saves it.
|
||||
|
||||
## When to Use
|
||||
|
||||
Load this skill when:
|
||||
- Model not learning (loss stuck, not decreasing, poor accuracy)
|
||||
- Training instability (loss spikes, NaN values, divergence)
|
||||
- Overfitting (large train/val gap, poor generalization)
|
||||
- Training too slow (throughput issues, time constraints)
|
||||
- Hyperparameter selection (optimizer, learning rate, batch size, regularization)
|
||||
- Experiment management (tracking runs, comparing configurations)
|
||||
- Convergence issues (slow learning, plateaus, local minima)
|
||||
- Setting up new training pipeline
|
||||
|
||||
**Don't use for**: PyTorch implementation bugs (use pytorch-engineering), model architecture selection (use neural-architectures), production deployment (use ml-production), RL-specific training (use deep-rl), LLM fine-tuning specifics (use llm-specialist)
|
||||
|
||||
---
|
||||
|
||||
## Routing by Primary Symptom
|
||||
|
||||
### Symptom: "Model Not Learning" / "Loss Not Decreasing"
|
||||
|
||||
**Keywords**: stuck, flat loss, not improving, not learning, accuracy not increasing
|
||||
|
||||
**Diagnostic questions (ask BEFORE routing):**
|
||||
1. "Is loss completely flat from the start, decreasing very slowly, or was it learning then stopped?"
|
||||
2. "Any NaN or Inf values in your loss?"
|
||||
3. "What optimizer and learning rate are you using?"
|
||||
4. "What does your loss curve look like over time?"
|
||||
|
||||
**Route based on answers:**
|
||||
|
||||
| Loss Behavior | Likely Cause | Route To | Why |
|
||||
|---------------|--------------|----------|-----|
|
||||
| Flat from epoch 0 | LR too low OR wrong optimizer OR inappropriate loss | **learning-rate-scheduling** + **optimization-algorithms** | Need to diagnose starting conditions |
|
||||
| Was learning, then plateaued | Local minima OR LR too high OR overfitting | **learning-rate-scheduling** (scheduler) + check validation loss for overfitting | Adaptation needed during training |
|
||||
| Oscillating wildly | LR too high OR gradient instability | **learning-rate-scheduling** + **gradient-management** | Instability issues |
|
||||
| NaN or Inf | Gradient explosion OR numerical instability | **gradient-management** (PRIMARY) + **loss-functions-and-objectives** | Stability critical |
|
||||
| Loss function doesn't match task | Wrong objective | **loss-functions-and-objectives** | Fundamental mismatch |
|
||||
|
||||
**Multi-skill routing**: Often need **optimization-algorithms** (choose optimizer) + **learning-rate-scheduling** (choose LR/schedule) together for "not learning" issues.
|
||||
|
||||
---
|
||||
|
||||
### Symptom: "Training Unstable" / "Loss Spikes" / "NaN Values"
|
||||
|
||||
**Keywords**: NaN, Inf, exploding, diverging, unstable, spikes
|
||||
|
||||
**Diagnostic questions:**
|
||||
1. "When do NaN values appear - immediately at start, or after N epochs?"
|
||||
2. "What's your learning rate and schedule?"
|
||||
3. "Are you using mixed precision training?"
|
||||
4. "What loss function are you using?"
|
||||
|
||||
**Route to (in priority order):**
|
||||
|
||||
1. **gradient-management** (PRIMARY)
|
||||
- When: Gradient explosion, NaN gradients
|
||||
- Why: Gradient issues cause instability. Must stabilize gradients first.
|
||||
- Techniques: Gradient clipping, gradient scaling, NaN debugging
|
||||
|
||||
2. **learning-rate-scheduling** (SECONDARY)
|
||||
- When: LR too high causes instability
|
||||
- Why: High LR can cause divergence even with clipped gradients
|
||||
- Check: If NaN appears later in training, LR schedule might be increasing too much
|
||||
|
||||
3. **loss-functions-and-objectives** (if numerical issues)
|
||||
- When: Loss computation has numerical instability (log(0), division by zero)
|
||||
- Why: Numerical instability in loss propagates to gradients
|
||||
- Check: Custom loss functions especially prone to this
|
||||
|
||||
**Cross-pack note**: If using mixed precision (AMP), also check pytorch-engineering (mixed-precision-and-optimization) for gradient scaling issues.
|
||||
|
||||
---
|
||||
|
||||
### Symptom: "Model Overfits" / "Train/Val Gap Large"
|
||||
|
||||
**Keywords**: overfitting, train/val gap, poor generalization, memorizing training data
|
||||
|
||||
**Diagnostic questions:**
|
||||
1. "How large is your dataset (number of training examples)?"
|
||||
2. "What's the train accuracy vs validation accuracy?"
|
||||
3. "What regularization are you currently using?"
|
||||
4. "What's your model size (parameters) relative to dataset size?"
|
||||
|
||||
**Route to (multi-skill approach):**
|
||||
|
||||
1. **overfitting-prevention** (PRIMARY)
|
||||
- Techniques: Dropout, weight decay, early stopping, L1/L2 regularization
|
||||
- When: Always the first stop for overfitting
|
||||
- Why: Comprehensive regularization strategy
|
||||
|
||||
2. **data-augmentation-strategies** (HIGHLY RECOMMENDED)
|
||||
- When: Dataset is small (< 10K examples) or moderate (10K-100K)
|
||||
- Why: Increases effective dataset size, teaches invariances
|
||||
- Priority: Higher priority for smaller datasets
|
||||
|
||||
3. **hyperparameter-tuning**
|
||||
- When: Need to find optimal regularization strength
|
||||
- Why: Balance between underfitting and overfitting
|
||||
- Use: After implementing regularization techniques
|
||||
|
||||
**Decision factors:**
|
||||
- Small dataset (< 1K): data-augmentation is CRITICAL + overfitting-prevention
|
||||
- Medium dataset (1K-10K): overfitting-prevention + data-augmentation
|
||||
- Large dataset (> 10K) but still overfitting: overfitting-prevention + check model capacity
|
||||
- Model too large: Consider neural-architectures for model size discussion
|
||||
|
||||
---
|
||||
|
||||
### Symptom: "Training Too Slow" / "Low Throughput"
|
||||
|
||||
**Keywords**: slow training, low GPU utilization, takes too long, time per epoch
|
||||
|
||||
**CRITICAL: Diagnose bottleneck before routing**
|
||||
|
||||
**Diagnostic questions:**
|
||||
1. "Is it slow per-step (low throughput) or just need many steps?"
|
||||
2. "What's your GPU utilization percentage?"
|
||||
3. "Are you using data augmentation? How heavy?"
|
||||
4. "What's your current batch size?"
|
||||
|
||||
**Route based on bottleneck:**
|
||||
|
||||
| GPU Utilization | Likely Cause | Route To | Why |
|
||||
|-----------------|--------------|----------|-----|
|
||||
| < 50% consistently | Data loading bottleneck OR CPU preprocessing | **pytorch-engineering** (data loading, profiling) | Not compute-bound, infrastructure issue |
|
||||
| High (> 80%) but still slow | Batch size too small OR need distributed training | **batch-size-and-memory-tradeoffs** + possibly pytorch-engineering (distributed) | Compute-bound, need scaling |
|
||||
| High + heavy augmentation | Augmentation overhead | **data-augmentation-strategies** (optimization) + pytorch-engineering (profiling) | Augmentation CPU cost |
|
||||
| Memory-limited batch size | Can't increase batch size due to OOM | **batch-size-and-memory-tradeoffs** (gradient accumulation) + pytorch-engineering (memory) | Memory constraints limiting throughput |
|
||||
|
||||
**Cross-pack boundaries:**
|
||||
- Data loading issues → **pytorch-engineering** (DataLoader, prefetching, num_workers)
|
||||
- Distributed training setup → **pytorch-engineering** (distributed-training-strategies)
|
||||
- Batch size optimization for speed/memory → **training-optimization** (batch-size-and-memory-tradeoffs)
|
||||
- Profiling to identify bottleneck → **pytorch-engineering** (performance-profiling)
|
||||
|
||||
**Key principle**: Profile FIRST before optimizing. Low GPU utilization = wrong optimization target.
|
||||
|
||||
---
|
||||
|
||||
### Symptom: "Which X Should I Use?" (Direct Questions)
|
||||
|
||||
**Direct hyperparameter questions route to specific skills:**
|
||||
|
||||
| Question | Route To | Examples |
|
||||
|----------|----------|----------|
|
||||
| "Which optimizer?" | **optimization-algorithms** | SGD vs Adam vs AdamW, momentum, betas |
|
||||
| "Which learning rate?" | **learning-rate-scheduling** | Initial LR, warmup, schedule type |
|
||||
| "Which batch size?" | **batch-size-and-memory-tradeoffs** | Batch size effects on convergence and speed |
|
||||
| "Which loss function?" | **loss-functions-and-objectives** | Cross-entropy vs focal vs custom |
|
||||
| "How to prevent overfitting?" | **overfitting-prevention** | Dropout, weight decay, early stopping |
|
||||
| "Which augmentation?" | **data-augmentation-strategies** | Type and strength of augmentation |
|
||||
| "How to tune hyperparameters?" | **hyperparameter-tuning** | Search strategies, AutoML |
|
||||
| "How to track experiments?" | **experiment-tracking** | MLflow, W&B, TensorBoard |
|
||||
|
||||
**For new project setup, route to MULTIPLE in sequence:**
|
||||
1. **optimization-algorithms** - Choose optimizer
|
||||
2. **learning-rate-scheduling** - Choose initial LR and schedule
|
||||
3. **batch-size-and-memory-tradeoffs** - Choose batch size
|
||||
4. **experiment-tracking** - Set up tracking
|
||||
5. **training-loop-architecture** - Design training loop
|
||||
|
||||
---
|
||||
|
||||
### Symptom: "Need to Track Experiments" / "Compare Configurations"
|
||||
|
||||
**Keywords**: experiment tracking, MLflow, wandb, tensorboard, compare runs, log metrics
|
||||
|
||||
**Route to:**
|
||||
|
||||
1. **experiment-tracking** (PRIMARY)
|
||||
- Tools: MLflow, Weights & Biases, TensorBoard, Neptune
|
||||
- When: Setting up tracking, comparing runs, organizing experiments
|
||||
- Why: Systematic experiment management
|
||||
|
||||
2. **hyperparameter-tuning** (if systematic search)
|
||||
- When: Running many configurations systematically
|
||||
- Why: Automated hyperparameter search integrates with tracking
|
||||
- Combined: Hyperparameter search + experiment tracking workflow
|
||||
|
||||
3. **training-loop-architecture** (for integration)
|
||||
- When: Need to integrate tracking into training loop
|
||||
- Why: Proper callback and logging design
|
||||
- Combined: Training loop + tracking integration
|
||||
|
||||
---
|
||||
|
||||
## Cross-Cutting Multi-Skill Scenarios
|
||||
|
||||
### Scenario: New Training Setup (First Time)
|
||||
|
||||
**Route to (in order):**
|
||||
1. **optimization-algorithms** - Select optimizer for your task
|
||||
2. **learning-rate-scheduling** - Choose LR and warmup strategy
|
||||
3. **batch-size-and-memory-tradeoffs** - Determine batch size
|
||||
4. **loss-functions-and-objectives** - Verify loss function appropriate for task
|
||||
5. **experiment-tracking** - Set up experiment logging
|
||||
6. **training-loop-architecture** - Design training loop with checkpointing
|
||||
|
||||
**Why this order**: Foundation (optimizer/LR/batch) → Objective (loss) → Infrastructure (tracking/loop)
|
||||
|
||||
---
|
||||
|
||||
### Scenario: Convergence Issues
|
||||
|
||||
**Route to (diagnose first, then in order):**
|
||||
1. **gradient-management** - Check gradients not vanishing/exploding (use gradient checking)
|
||||
2. **learning-rate-scheduling** - Adjust LR schedule (might be too high/low/wrong schedule)
|
||||
3. **optimization-algorithms** - Consider different optimizer if current one unsuitable
|
||||
|
||||
**Why this order**: Stability (gradients) → Adaptation (LR) → Algorithm (optimizer)
|
||||
|
||||
**Common mistakes:**
|
||||
- Jumping to change optimizer without checking gradients
|
||||
- Blaming optimizer when LR is the issue
|
||||
- Not using gradient monitoring to diagnose
|
||||
|
||||
---
|
||||
|
||||
### Scenario: Overfitting Issues
|
||||
|
||||
**Route to (multi-pronged approach):**
|
||||
1. **overfitting-prevention** - Implement regularization (dropout, weight decay, early stopping)
|
||||
2. **data-augmentation-strategies** - Increase effective dataset size
|
||||
3. **hyperparameter-tuning** - Find optimal regularization strength
|
||||
|
||||
**Why all three**: Overfitting needs comprehensive strategy, not single technique.
|
||||
|
||||
**Prioritization:**
|
||||
- Small dataset (< 1K): data-augmentation is MOST critical
|
||||
- Medium dataset (1K-10K): overfitting-prevention + augmentation balanced
|
||||
- Large dataset but still overfitting: overfitting-prevention + check model size
|
||||
|
||||
---
|
||||
|
||||
### Scenario: Training Speed + Memory Constraints
|
||||
|
||||
**Route to:**
|
||||
1. **batch-size-and-memory-tradeoffs** (PRIMARY) - Gradient accumulation to simulate larger batch
|
||||
2. **pytorch-engineering/tensor-operations-and-memory** - Memory profiling and optimization
|
||||
3. **data-augmentation-strategies** - Reduce augmentation overhead if bottleneck
|
||||
|
||||
**Why**: Speed and memory often coupled—need to balance batch size, memory, and throughput.
|
||||
|
||||
---
|
||||
|
||||
### Scenario: Multi-Task Learning or Custom Loss
|
||||
|
||||
**Route to:**
|
||||
1. **loss-functions-and-objectives** (PRIMARY) - Multi-task loss design, uncertainty weighting
|
||||
2. **gradient-management** - Check gradients per task, gradient balancing
|
||||
3. **hyperparameter-tuning** - Tune task weights, loss coefficients
|
||||
|
||||
**Why**: Custom losses need careful design + gradient analysis + tuning.
|
||||
|
||||
---
|
||||
|
||||
## Ambiguous Queries - Clarification Protocol
|
||||
|
||||
When symptom unclear, ASK ONE diagnostic question before routing:
|
||||
|
||||
| Vague Query | Clarifying Question | Why |
|
||||
|-------------|---------------------|-----|
|
||||
| "Fix my training" | "What specific issue? Not learning? Unstable? Overfitting? Too slow?" | 4+ different routing paths |
|
||||
| "Improve model" | "Improve what? Training speed? Accuracy? Generalization?" | Different optimization targets |
|
||||
| "Training not working well" | "What's 'not working'? Loss behavior? Accuracy? Convergence speed?" | Need specific symptoms |
|
||||
| "Optimize hyperparameters" | "Which hyperparameters? All of them? Specific ones like LR?" | Specific vs broad search |
|
||||
| "Model performs poorly" | "Training accuracy poor or validation accuracy poor or both?" | Underfitting vs overfitting |
|
||||
|
||||
**Never guess when ambiguous. Ask once, route accurately.**
|
||||
|
||||
---
|
||||
|
||||
## Common Routing Mistakes
|
||||
|
||||
| Symptom | Wrong Route | Correct Route | Why |
|
||||
|---------|-------------|---------------|-----|
|
||||
| "Training slow" | batch-size-and-memory | ASK: Check GPU utilization first | Might be data loading, not compute |
|
||||
| "Not learning" | optimization-algorithms | ASK: Diagnose loss behavior | Could be LR, gradients, loss function |
|
||||
| "Loss NaN" | learning-rate-scheduling | gradient-management FIRST | Gradient explosion most common cause |
|
||||
| "Overfitting" | overfitting-prevention only | overfitting-prevention + data-augmentation | Need multi-pronged approach |
|
||||
| "Need to speed up training" | optimization-algorithms | Profile first (pytorch-engineering) | Don't optimize without measuring |
|
||||
| "Which optimizer for transformer" | neural-architectures | optimization-algorithms | Optimizer choice, not architecture |
|
||||
|
||||
**Key principle**: Diagnosis before solutions, clarification before routing, multi-skill for complex issues.
|
||||
|
||||
---
|
||||
|
||||
## When NOT to Use Training-Optimization Pack
|
||||
|
||||
**Skip training-optimization when:**
|
||||
|
||||
| Symptom | Wrong Pack | Correct Pack | Why |
|
||||
|---------|------------|--------------|-----|
|
||||
| "CUDA out of memory" | training-optimization | pytorch-engineering | Infrastructure issue, not training algorithm |
|
||||
| "DDP not working" | training-optimization | pytorch-engineering | Distributed setup, not hyperparameters |
|
||||
| "Which architecture to use" | training-optimization | neural-architectures | Architecture choice precedes training |
|
||||
| "Model won't load" | training-optimization | pytorch-engineering | Checkpointing/serialization issue |
|
||||
| "Inference too slow" | training-optimization | ml-production | Production optimization, not training |
|
||||
| "How to deploy model" | training-optimization | ml-production | Deployment concern |
|
||||
| "RL exploration issues" | training-optimization | deep-rl | RL-specific training concern |
|
||||
| "RLHF for LLM" | training-optimization | llm-specialist | LLM-specific technique |
|
||||
|
||||
**Training-optimization pack is for**: Framework-agnostic training algorithms, hyperparameters, optimization techniques, and training strategies that apply across architectures.
|
||||
|
||||
**Boundaries:**
|
||||
- PyTorch implementation/infrastructure → **pytorch-engineering**
|
||||
- Architecture selection → **neural-architectures**
|
||||
- Production/inference → **ml-production**
|
||||
- RL-specific algorithms → **deep-rl**
|
||||
- LLM-specific techniques → **llm-specialist**
|
||||
|
||||
---
|
||||
|
||||
## Red Flags - Stop and Reconsider
|
||||
|
||||
If you catch yourself about to:
|
||||
- ❌ Suggest specific optimizer without routing → Route to **optimization-algorithms**
|
||||
- ❌ Suggest learning rate value without diagnosis → Route to **learning-rate-scheduling**
|
||||
- ❌ Say "add dropout" without comprehensive strategy → Route to **overfitting-prevention**
|
||||
- ❌ Suggest "reduce batch size" without profiling → Check if memory or speed issue, route appropriately
|
||||
- ❌ Give trial-and-error fixes for NaN → Route to **gradient-management** for systematic debugging
|
||||
- ❌ Provide generic training advice → Identify specific symptom and route to specialist
|
||||
- ❌ Accept user's self-diagnosis without verification → Ask diagnostic questions first
|
||||
|
||||
**All of these mean: You're about to give incomplete advice. Route to specialist instead.**
|
||||
|
||||
---
|
||||
|
||||
## Common Rationalizations (Don't Do These)
|
||||
|
||||
| Rationalization | Reality | What To Do Instead |
|
||||
|-----------------|---------|-------------------|
|
||||
| "User is rushed, skip diagnostic questions" | Diagnosis takes 30 seconds, wrong route wastes 10+ minutes | Ask ONE quick diagnostic question: "Is loss flat, oscillating, or NaN?" |
|
||||
| "Symptoms are obvious, route immediately" | Symptoms often have multiple causes | Ask clarifying question to eliminate ambiguity |
|
||||
| "User suggested optimizer change" | User self-diagnosis can be wrong | "What loss behavior are you seeing?" to verify root cause |
|
||||
| "Expert user doesn't need routing" | Expert users benefit from specialist skills too | Route based on symptoms, not user sophistication |
|
||||
| "Just a quick question" | Quick questions deserve correct answers | Route to specialist—they have quick diagnostics too |
|
||||
| "Single solution will fix it" | Training issues often multi-causal | Consider multi-skill routing for complex symptoms |
|
||||
| "Time pressure means guess quickly" | Wrong guess wastes MORE time | Fast systematic diagnosis faster than trial-and-error |
|
||||
| "They already tried X" | Maybe tried X wrong or X wasn't the issue | Route to specialist to verify X was done correctly |
|
||||
| "Too complex to route" | Complex issues need specialists MORE | Use multi-skill routing for complex scenarios |
|
||||
| "Direct answer is helpful" | Wrong direct answer wastes time | Routing IS the helpful answer |
|
||||
|
||||
**If you catch yourself thinking ANY of these, STOP and route to specialist or ask diagnostic question.**
|
||||
|
||||
---
|
||||
|
||||
## Pressure Resistance - Critical Discipline
|
||||
|
||||
### Time/Emergency Pressure
|
||||
|
||||
| Pressure | Wrong Response | Correct Response |
|
||||
|----------|----------------|------------------|
|
||||
| "Demo tomorrow, need quick fix" | Give untested suggestions | "Fast systematic diagnosis ensures demo success: [question]" |
|
||||
| "Production training failing" | Panic and guess | "Quick clarification prevents longer outage: [question]" |
|
||||
| "Just tell me which optimizer" | "Use Adam" | "30-second clarification ensures right choice: [question]" |
|
||||
|
||||
**Emergency protocol**: Fast clarification (30 sec) → Correct routing (60 sec) → Specialist handles efficiently
|
||||
|
||||
---
|
||||
|
||||
### Authority/Hierarchy Pressure
|
||||
|
||||
| Pressure | Wrong Response | Correct Response |
|
||||
|----------|----------------|------------------|
|
||||
| "Senior said use SGD" | Accept without verification | "To apply SGD effectively, let me verify the symptoms: [question]" |
|
||||
| "PM wants optimizer change" | Change without diagnosis | "Let's diagnose to ensure optimizer is the issue: [question]" |
|
||||
|
||||
**Authority protocol**: Acknowledge → Verify symptoms → Route based on evidence, not opinion
|
||||
|
||||
---
|
||||
|
||||
### User Self-Diagnosis Pressure
|
||||
|
||||
| Pressure | Wrong Response | Correct Response |
|
||||
|----------|----------------|------------------|
|
||||
| "I think it's the optimizer" | Discuss optimizer choice | "What loss behavior makes you think optimizer? [diagnostic]" |
|
||||
| "Obviously need to clip gradients" | Implement clipping | "What symptoms suggest gradient issues? [verify]" |
|
||||
|
||||
**Verification protocol**: User attribution is hypothesis, not diagnosis. Verify with symptoms.
|
||||
|
||||
---
|
||||
|
||||
## Red Flags Checklist - Self-Check Before Routing
|
||||
|
||||
Before giving ANY training advice or routing, ask yourself:
|
||||
|
||||
1. ❓ **Did I identify specific symptoms?**
|
||||
- If no → Ask clarifying question
|
||||
- If yes → Proceed
|
||||
|
||||
2. ❓ **Is this symptom in my routing table?**
|
||||
- If yes → Route to specialist skill
|
||||
- If no → Ask diagnostic question
|
||||
|
||||
3. ❓ **Am I about to give direct advice?**
|
||||
- If yes → STOP. Why am I not routing?
|
||||
- Check rationalization table—am I making excuses?
|
||||
|
||||
4. ❓ **Could this symptom have multiple causes?**
|
||||
- If yes → Ask diagnostic question to narrow down
|
||||
- If no → Route confidently
|
||||
|
||||
5. ❓ **Is this training-optimization or another pack?**
|
||||
- PyTorch errors → pytorch-engineering
|
||||
- Architecture choice → neural-architectures
|
||||
- Deployment → ml-production
|
||||
|
||||
6. ❓ **Am I feeling pressure to skip routing?**
|
||||
- Time pressure → Route anyway (faster overall)
|
||||
- Authority pressure → Verify symptoms anyway
|
||||
- User self-diagnosis → Confirm with questions anyway
|
||||
- Expert user → Route anyway (specialists help experts too)
|
||||
|
||||
**If you failed ANY check, do NOT give direct advice. Route to specialist or ask clarifying question.**
|
||||
|
||||
---
|
||||
|
||||
## Training Optimization Specialist Skills
|
||||
|
||||
After routing, load the appropriate specialist skill for detailed guidance:
|
||||
|
||||
1. [optimization-algorithms.md](optimization-algorithms.md) - Optimizer selection (SGD, Adam, AdamW, momentum), hyperparameter tuning, optimizer comparison
|
||||
2. [learning-rate-scheduling.md](learning-rate-scheduling.md) - LR schedulers (step, cosine, exponential), warmup strategies, cyclical learning rates
|
||||
3. [loss-functions-and-objectives.md](loss-functions-and-objectives.md) - Custom losses, multi-task learning, weighted objectives, numerical stability
|
||||
4. [gradient-management.md](gradient-management.md) - Gradient clipping, accumulation, scaling, vanishing/exploding gradient diagnosis
|
||||
5. [batch-size-and-memory-tradeoffs.md](batch-size-and-memory-tradeoffs.md) - Batch size effects on convergence, gradient accumulation, memory optimization
|
||||
6. [data-augmentation-strategies.md](data-augmentation-strategies.md) - Augmentation techniques (geometric, color, mixing), policy design, AutoAugment
|
||||
7. [overfitting-prevention.md](overfitting-prevention.md) - Regularization (L1/L2, dropout, weight decay), early stopping, generalization techniques
|
||||
8. [training-loop-architecture.md](training-loop-architecture.md) - Training loop design, monitoring, logging, checkpointing integration, callbacks
|
||||
9. [hyperparameter-tuning.md](hyperparameter-tuning.md) - Search strategies (grid, random, Bayesian), AutoML, Optuna, Ray Tune
|
||||
10. [experiment-tracking.md](experiment-tracking.md) - MLflow, Weights & Biases, TensorBoard, Neptune, run comparison
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference: Symptom → Skills
|
||||
|
||||
| Symptom | Primary Skill | Secondary Skills | Diagnostic Question |
|
||||
|---------|---------------|------------------|---------------------|
|
||||
| Loss flat/stuck | learning-rate-scheduling | optimization-algorithms | "Flat from start or plateaued later?" |
|
||||
| Loss NaN/Inf | gradient-management | learning-rate-scheduling, loss-functions | "When does NaN appear?" |
|
||||
| Overfitting | overfitting-prevention | data-augmentation, hyperparameter-tuning | "Dataset size? Current regularization?" |
|
||||
| Training slow | batch-size-and-memory OR pytorch-engineering | data-augmentation | "GPU utilization percentage?" |
|
||||
| Oscillating loss | learning-rate-scheduling | gradient-management | "What's your current LR?" |
|
||||
| Which optimizer | optimization-algorithms | learning-rate-scheduling | Task type, architecture, dataset |
|
||||
| Which LR | learning-rate-scheduling | optimization-algorithms | Optimizer, task, current symptoms |
|
||||
| Track experiments | experiment-tracking | hyperparameter-tuning, training-loop | Tools preference, scale of experiments |
|
||||
| Poor generalization | overfitting-prevention | data-augmentation | Train vs val accuracy gap |
|
||||
| Convergence issues | gradient-management | learning-rate-scheduling, optimization-algorithms | Gradient norms, loss curve |
|
||||
|
||||
---
|
||||
|
||||
## Integration Notes
|
||||
|
||||
**Phase 1 - Standalone**: Training-optimization skills are self-contained and framework-agnostic.
|
||||
|
||||
**Cross-references with other packs:**
|
||||
- **pytorch-engineering**: Infrastructure, implementation, profiling, distributed training
|
||||
- **neural-architectures**: Architecture selection (precedes training optimization)
|
||||
- **deep-rl**: RL-specific training algorithms (policy gradients, Q-learning)
|
||||
- **llm-specialist**: LLM-specific techniques (RLHF, LoRA)
|
||||
- **ml-production**: Production/inference optimization (post-training)
|
||||
|
||||
**Current focus**: Route within training-optimization for framework-agnostic training concerns. Other packs handle implementation, architecture, and domain-specific issues.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
**This meta-skill's job**:
|
||||
1. Identify symptoms through diagnostic questions
|
||||
2. Map symptoms to appropriate specialist skills
|
||||
3. Route to one or multiple skills based on diagnosis
|
||||
4. Resist pressure to give direct advice
|
||||
5. Maintain clear boundaries with other packs
|
||||
|
||||
**Remember**: Diagnose before routing. Training issues often have multiple causes. Clarify symptoms, ask diagnostic questions, then route to specialist skills. Wrong routing wastes more time than asking one clarifying question.
|
||||
|
||||
**Route based on symptoms, not guesses. Let specialists do their job.**
|
||||
File diff suppressed because it is too large
Load Diff
1483
skills/using-training-optimization/data-augmentation-strategies.md
Normal file
1483
skills/using-training-optimization/data-augmentation-strategies.md
Normal file
File diff suppressed because it is too large
Load Diff
1942
skills/using-training-optimization/experiment-tracking.md
Normal file
1942
skills/using-training-optimization/experiment-tracking.md
Normal file
File diff suppressed because it is too large
Load Diff
2442
skills/using-training-optimization/gradient-management.md
Normal file
2442
skills/using-training-optimization/gradient-management.md
Normal file
File diff suppressed because it is too large
Load Diff
1635
skills/using-training-optimization/hyperparameter-tuning.md
Normal file
1635
skills/using-training-optimization/hyperparameter-tuning.md
Normal file
File diff suppressed because it is too large
Load Diff
2723
skills/using-training-optimization/learning-rate-scheduling.md
Normal file
2723
skills/using-training-optimization/learning-rate-scheduling.md
Normal file
File diff suppressed because it is too large
Load Diff
2138
skills/using-training-optimization/loss-functions-and-objectives.md
Normal file
2138
skills/using-training-optimization/loss-functions-and-objectives.md
Normal file
File diff suppressed because it is too large
Load Diff
1832
skills/using-training-optimization/optimization-algorithms.md
Normal file
1832
skills/using-training-optimization/optimization-algorithms.md
Normal file
File diff suppressed because it is too large
Load Diff
1464
skills/using-training-optimization/overfitting-prevention.md
Normal file
1464
skills/using-training-optimization/overfitting-prevention.md
Normal file
File diff suppressed because it is too large
Load Diff
882
skills/using-training-optimization/training-loop-architecture.md
Normal file
882
skills/using-training-optimization/training-loop-architecture.md
Normal file
@@ -0,0 +1,882 @@
|
||||
|
||||
# Training Loop Architecture
|
||||
|
||||
## Overview
|
||||
|
||||
**Core Principle:** A properly structured training loop is the foundation of all successful deep learning projects. Success requires: (1) correct train/val/test data separation, (2) validation after EVERY epoch (not just once), (3) complete checkpoint state (model + optimizer + scheduler), (4) comprehensive logging/monitoring, and (5) graceful error handling. Poor loop structure causes: silent overfitting, broken resume functionality, undetectable training issues, and memory leaks.
|
||||
|
||||
Training loop failures manifest as: overfitting with good metrics, crashes on resume, unexplained loss spikes, or out-of-memory errors. These stem from misunderstanding when validation runs, what state must be saved, or how to manage GPU memory. Systematic architecture beats trial-and-error fixes.
|
||||
|
||||
## When to Use
|
||||
|
||||
**Use this skill when:**
|
||||
- Implementing a new training loop from scratch
|
||||
- Training loop is crashing unexpectedly
|
||||
- Can't resume training from checkpoint correctly
|
||||
- Model overfits but validation metrics look good
|
||||
- Out-of-memory errors during training
|
||||
- Unsure about train/val/test data split
|
||||
- Need to monitor training progress properly
|
||||
- Implementing early stopping or checkpoint selection
|
||||
- Training loops show loss spikes or divergence on resume
|
||||
- Adding logging/monitoring to training
|
||||
|
||||
**Don't use when:**
|
||||
- Debugging single backward pass (use gradient-management skill)
|
||||
- Tuning learning rate (use learning-rate-scheduling skill)
|
||||
- Fixing specific loss function (use loss-functions-and-objectives skill)
|
||||
- Data loading issues (use data-augmentation-strategies skill)
|
||||
|
||||
**Symptoms triggering this skill:**
|
||||
- "Training loss decreases but validation loss increases (overfitting)"
|
||||
- "Training crashes when resuming from checkpoint"
|
||||
- "Out of memory errors after epoch 20"
|
||||
- "I validated on training data and didn't realize"
|
||||
- "Can't detect overfitting because I don't validate"
|
||||
- "Training loss spikes when resuming"
|
||||
- "My checkpoint doesn't load correctly"
|
||||
|
||||
|
||||
## Complete Training Loop Structure
|
||||
|
||||
### 1. The Standard Training Loop (The Reference)
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Setup logging (ALWAYS do this first)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingLoop:
|
||||
"""Complete training loop with validation, checkpointing, and monitoring."""
|
||||
|
||||
def __init__(self, model, optimizer, scheduler, criterion, device='cuda'):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.scheduler = scheduler
|
||||
self.criterion = criterion
|
||||
self.device = device
|
||||
|
||||
# Tracking metrics
|
||||
self.train_losses = []
|
||||
self.val_losses = []
|
||||
self.best_val_loss = float('inf')
|
||||
self.epochs_without_improvement = 0
|
||||
|
||||
def train_epoch(self, train_loader):
|
||||
"""Train for one epoch."""
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(self.device), target.to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
# Backward pass with gradient clipping (if needed)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Accumulate loss
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Log progress every 10 batches
|
||||
if batch_idx % 10 == 0:
|
||||
logger.debug(f"Batch {batch_idx}: loss={loss.item():.4f}")
|
||||
|
||||
avg_loss = total_loss / num_batches
|
||||
return avg_loss
|
||||
|
||||
def validate_epoch(self, val_loader):
|
||||
"""Validate on validation set (AFTER each epoch, not during)."""
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
with torch.no_grad(): # ✅ CRITICAL: No gradients during validation
|
||||
for data, target in val_loader:
|
||||
data, target = data.to(self.device), target.to(self.device)
|
||||
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = total_loss / num_batches
|
||||
return avg_loss
|
||||
|
||||
def save_checkpoint(self, epoch, val_loss, checkpoint_dir='checkpoints'):
|
||||
"""Save complete checkpoint (model + optimizer + scheduler)."""
|
||||
Path(checkpoint_dir).mkdir(exist_ok=True)
|
||||
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||
'val_loss': val_loss,
|
||||
'train_losses': self.train_losses,
|
||||
'val_losses': self.val_losses,
|
||||
}
|
||||
|
||||
# Save last checkpoint
|
||||
torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_latest.pt')
|
||||
|
||||
# Save best checkpoint
|
||||
if val_loss < self.best_val_loss:
|
||||
self.best_val_loss = val_loss
|
||||
torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_best.pt')
|
||||
logger.info(f"New best validation loss: {val_loss:.4f}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
"""Load checkpoint and resume training correctly."""
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
# ✅ CRITICAL ORDER: Load model, optimizer, scheduler (in that order)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
# Restore metrics history
|
||||
self.train_losses = checkpoint['train_losses']
|
||||
self.val_losses = checkpoint['val_losses']
|
||||
self.best_val_loss = min(self.val_losses) if self.val_losses else float('inf')
|
||||
|
||||
epoch = checkpoint['epoch']
|
||||
logger.info(f"Loaded checkpoint from epoch {epoch}")
|
||||
return epoch
|
||||
|
||||
def train(self, train_loader, val_loader, num_epochs, checkpoint_dir='checkpoints'):
|
||||
"""Full training loop with validation and checkpointing."""
|
||||
start_epoch = 0
|
||||
|
||||
# Try to resume from checkpoint if it exists
|
||||
checkpoint_path = f'{checkpoint_dir}/checkpoint_latest.pt'
|
||||
if Path(checkpoint_path).exists():
|
||||
start_epoch = self.load_checkpoint(checkpoint_path)
|
||||
logger.info(f"Resuming training from epoch {start_epoch}")
|
||||
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
try:
|
||||
# Train for one epoch
|
||||
train_loss = self.train_epoch(train_loader)
|
||||
self.train_losses.append(train_loss)
|
||||
|
||||
# ✅ CRITICAL: Validate after every epoch
|
||||
val_loss = self.validate_epoch(val_loader)
|
||||
self.val_losses.append(val_loss)
|
||||
|
||||
# Step scheduler (after epoch)
|
||||
self.scheduler.step()
|
||||
|
||||
# Log metrics
|
||||
logger.info(
|
||||
f"Epoch {epoch}: train_loss={train_loss:.4f}, "
|
||||
f"val_loss={val_loss:.4f}, lr={self.optimizer.param_groups[0]['lr']:.2e}"
|
||||
)
|
||||
|
||||
# Checkpoint every epoch
|
||||
self.save_checkpoint(epoch, val_loss, checkpoint_dir)
|
||||
|
||||
# Early stopping (optional)
|
||||
if val_loss < self.best_val_loss:
|
||||
self.epochs_without_improvement = 0
|
||||
else:
|
||||
self.epochs_without_improvement += 1
|
||||
if self.epochs_without_improvement >= 10:
|
||||
logger.info(f"Early stopping at epoch {epoch}")
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
self.save_checkpoint(epoch, val_loss, checkpoint_dir)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Error in epoch {epoch}: {e}")
|
||||
raise
|
||||
|
||||
logger.info("Training complete")
|
||||
return self.model
|
||||
```
|
||||
|
||||
### 2. Data Split: Train/Val/Test Separation (CRITICAL)
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset, Dataset
|
||||
|
||||
# ✅ CORRECT: Proper three-way split with NO data leakage
|
||||
class DataSplitter:
|
||||
"""Ensures clean train/val/test splits without data leakage."""
|
||||
|
||||
@staticmethod
|
||||
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
|
||||
"""
|
||||
Split dataset into train/val/test.
|
||||
|
||||
CRITICAL: Split indices first, then create loaders.
|
||||
This prevents any data leakage.
|
||||
"""
|
||||
assert train_ratio + val_ratio + test_ratio == 1.0
|
||||
|
||||
n = len(dataset)
|
||||
indices = list(range(n))
|
||||
|
||||
# First split: train vs (val + test)
|
||||
train_size = int(train_ratio * n)
|
||||
train_indices = indices[:train_size]
|
||||
remaining_indices = indices[train_size:]
|
||||
|
||||
# Second split: val vs test
|
||||
remaining_size = len(remaining_indices)
|
||||
val_size = int(val_ratio / (val_ratio + test_ratio) * remaining_size)
|
||||
val_indices = remaining_indices[:val_size]
|
||||
test_indices = remaining_indices[val_size:]
|
||||
|
||||
# Create subset datasets (same transforms, different data)
|
||||
train_dataset = Subset(dataset, train_indices)
|
||||
val_dataset = Subset(dataset, val_indices)
|
||||
test_dataset = Subset(dataset, test_indices)
|
||||
|
||||
logger.info(
|
||||
f"Dataset split: train={len(train_dataset)}, "
|
||||
f"val={len(val_dataset)}, test={len(test_dataset)}"
|
||||
)
|
||||
|
||||
return train_dataset, val_dataset, test_dataset
|
||||
|
||||
# Usage
|
||||
train_dataset, val_dataset, test_dataset = DataSplitter.split_dataset(
|
||||
full_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
|
||||
)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||
|
||||
# ✅ CRITICAL: Validate that splits are actually different
|
||||
print(f"Train samples: {len(train_loader.dataset)}")
|
||||
print(f"Val samples: {len(val_loader.dataset)}")
|
||||
print(f"Test samples: {len(test_loader.dataset)}")
|
||||
|
||||
# ✅ CRITICAL: Never mix splits (don't re-shuffle or combine)
|
||||
```
|
||||
|
||||
### 3. Monitoring and Logging (Reproducibility)
|
||||
|
||||
```python
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
class TrainingMonitor:
|
||||
"""Track all metrics for reproducibility and debugging."""
|
||||
|
||||
def __init__(self, log_dir='logs'):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Metrics to track
|
||||
self.metrics = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'epochs': [],
|
||||
'train_losses': [],
|
||||
'val_losses': [],
|
||||
'learning_rates': [],
|
||||
'gradient_norms': [],
|
||||
'batch_times': [],
|
||||
}
|
||||
|
||||
def log_epoch(self, epoch, train_loss, val_loss, lr, gradient_norm=None, batch_time=None):
|
||||
"""Log metrics for one epoch."""
|
||||
self.metrics['epochs'].append(epoch)
|
||||
self.metrics['train_losses'].append(train_loss)
|
||||
self.metrics['val_losses'].append(val_loss)
|
||||
self.metrics['learning_rates'].append(lr)
|
||||
if gradient_norm is not None:
|
||||
self.metrics['gradient_norms'].append(gradient_norm)
|
||||
if batch_time is not None:
|
||||
self.metrics['batch_times'].append(batch_time)
|
||||
|
||||
def save_metrics(self):
|
||||
"""Save metrics to JSON for post-training analysis."""
|
||||
metrics_path = self.log_dir / f'metrics_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
|
||||
with open(metrics_path, 'w') as f:
|
||||
json.dump(self.metrics, f, indent=2)
|
||||
logger.info(f"Metrics saved to {metrics_path}")
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot training curves."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
||||
|
||||
# Loss curves
|
||||
axes[0, 0].plot(self.metrics['epochs'], self.metrics['train_losses'], label='Train')
|
||||
axes[0, 0].plot(self.metrics['epochs'], self.metrics['val_losses'], label='Val')
|
||||
axes[0, 0].set_xlabel('Epoch')
|
||||
axes[0, 0].set_ylabel('Loss')
|
||||
axes[0, 0].legend()
|
||||
axes[0, 0].set_title('Training and Validation Loss')
|
||||
|
||||
# Learning rate schedule
|
||||
axes[0, 1].plot(self.metrics['epochs'], self.metrics['learning_rates'])
|
||||
axes[0, 1].set_xlabel('Epoch')
|
||||
axes[0, 1].set_ylabel('Learning Rate')
|
||||
axes[0, 1].set_title('Learning Rate Schedule')
|
||||
axes[0, 1].set_yscale('log')
|
||||
|
||||
# Gradient norms (if available)
|
||||
if self.metrics['gradient_norms']:
|
||||
axes[1, 0].plot(self.metrics['epochs'], self.metrics['gradient_norms'])
|
||||
axes[1, 0].set_xlabel('Epoch')
|
||||
axes[1, 0].set_ylabel('Gradient Norm')
|
||||
axes[1, 0].set_title('Gradient Norms')
|
||||
|
||||
# Batch times (if available)
|
||||
if self.metrics['batch_times']:
|
||||
axes[1, 1].plot(self.metrics['epochs'], self.metrics['batch_times'])
|
||||
axes[1, 1].set_xlabel('Epoch')
|
||||
axes[1, 1].set_ylabel('Time (seconds)')
|
||||
axes[1, 1].set_title('Batch Processing Time')
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = self.log_dir / f'training_curves_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'
|
||||
plt.savefig(plot_path)
|
||||
logger.info(f"Plot saved to {plot_path}")
|
||||
```
|
||||
|
||||
### 4. Checkpointing and Resuming (Complete State)
|
||||
|
||||
```python
|
||||
class CheckpointManager:
|
||||
"""Properly save and load ALL training state."""
|
||||
|
||||
def __init__(self, checkpoint_dir='checkpoints'):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
self.checkpoint_dir.mkdir(exist_ok=True)
|
||||
|
||||
def save_full_checkpoint(self, epoch, model, optimizer, scheduler, metrics, path_suffix=''):
|
||||
"""Save COMPLETE state for resuming training."""
|
||||
checkpoint = {
|
||||
# Model and optimizer state
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
|
||||
# Training metrics (for monitoring)
|
||||
'train_losses': metrics['train_losses'],
|
||||
'val_losses': metrics['val_losses'],
|
||||
'learning_rates': metrics['learning_rates'],
|
||||
|
||||
# Timestamp for recovery
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# Save as latest
|
||||
latest_path = self.checkpoint_dir / f'checkpoint_latest{path_suffix}.pt'
|
||||
torch.save(checkpoint, latest_path)
|
||||
|
||||
# Save periodically (every 10 epochs)
|
||||
if epoch % 10 == 0:
|
||||
periodic_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
|
||||
torch.save(checkpoint, periodic_path)
|
||||
|
||||
logger.info(f"Checkpoint saved: {latest_path}")
|
||||
return latest_path
|
||||
|
||||
def load_full_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
|
||||
"""Load COMPLETE state correctly."""
|
||||
if not Path(checkpoint_path).exists():
|
||||
logger.warning(f"Checkpoint not found: {checkpoint_path}")
|
||||
return 0, None
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
# ✅ CRITICAL ORDER: Model first, then optimizer, then scheduler
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
epoch = checkpoint['epoch']
|
||||
metrics = {
|
||||
'train_losses': checkpoint.get('train_losses', []),
|
||||
'val_losses': checkpoint.get('val_losses', []),
|
||||
'learning_rates': checkpoint.get('learning_rates', []),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Loaded checkpoint from epoch {epoch}, "
|
||||
f"saved at {checkpoint.get('timestamp', 'unknown')}"
|
||||
)
|
||||
return epoch, metrics
|
||||
|
||||
def get_best_checkpoint(self):
|
||||
"""Find checkpoint with best validation loss."""
|
||||
checkpoints = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
best_loss = float('inf')
|
||||
best_path = None
|
||||
|
||||
for ckpt_path in checkpoints:
|
||||
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
||||
val_losses = checkpoint.get('val_losses', [])
|
||||
if val_losses and min(val_losses) < best_loss:
|
||||
best_loss = min(val_losses)
|
||||
best_path = ckpt_path
|
||||
|
||||
return best_path
|
||||
```
|
||||
|
||||
### 5. Memory Management (Prevent Leaks)
|
||||
|
||||
```python
|
||||
class MemoryManager:
|
||||
"""Prevent out-of-memory errors during long training."""
|
||||
|
||||
def __init__(self, device='cuda'):
|
||||
self.device = device
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear GPU cache between epochs."""
|
||||
if self.device.startswith('cuda'):
|
||||
torch.cuda.empty_cache()
|
||||
# Optional: clear CUDA graphs
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def check_memory(self):
|
||||
"""Log GPU memory usage."""
|
||||
if self.device.startswith('cuda'):
|
||||
allocated = torch.cuda.memory_allocated() / 1e9
|
||||
reserved = torch.cuda.memory_reserved() / 1e9
|
||||
logger.info(f"GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")
|
||||
|
||||
def training_loop_with_memory_management(self, model, train_loader, optimizer, criterion):
|
||||
"""Training loop with proper memory management."""
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(self.device), target.to(self.device)
|
||||
|
||||
# Forward and backward
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# ✅ Clear temporary tensors (data and target go out of scope)
|
||||
# ✅ Don't hold onto loss or output after using them
|
||||
|
||||
# Periodically check memory
|
||||
if batch_idx % 100 == 0:
|
||||
self.check_memory()
|
||||
|
||||
# Clear cache between epochs
|
||||
self.clear_cache()
|
||||
|
||||
return total_loss / len(train_loader)
|
||||
```
|
||||
|
||||
|
||||
## Error Handling and Recovery
|
||||
|
||||
```python
|
||||
class RobustTrainingLoop:
|
||||
"""Training loop with proper error handling."""
|
||||
|
||||
def train_with_error_handling(self, model, train_loader, val_loader, optimizer,
|
||||
scheduler, criterion, num_epochs, checkpoint_dir):
|
||||
"""Training with error recovery."""
|
||||
checkpoint_manager = CheckpointManager(checkpoint_dir)
|
||||
memory_manager = MemoryManager()
|
||||
|
||||
# Resume from last checkpoint if available
|
||||
start_epoch, metrics = checkpoint_manager.load_full_checkpoint(
|
||||
model, optimizer, scheduler, f'{checkpoint_dir}/checkpoint_latest.pt'
|
||||
)
|
||||
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
try:
|
||||
# Train
|
||||
train_loss = self.train_epoch(model, train_loader, optimizer, criterion)
|
||||
|
||||
# Validate
|
||||
val_loss = self.validate_epoch(model, val_loader, criterion)
|
||||
|
||||
# Update scheduler
|
||||
scheduler.step()
|
||||
|
||||
# Log
|
||||
logger.info(
|
||||
f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
|
||||
f"lr={optimizer.param_groups[0]['lr']:.2e}"
|
||||
)
|
||||
|
||||
# Checkpoint
|
||||
checkpoint_manager.save_full_checkpoint(
|
||||
epoch, model, optimizer, scheduler,
|
||||
{'train_losses': [train_loss], 'val_losses': [val_loss]}
|
||||
)
|
||||
|
||||
# Memory management
|
||||
memory_manager.clear_cache()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("Training interrupted - checkpoint saved")
|
||||
checkpoint_manager.save_full_checkpoint(
|
||||
epoch, model, optimizer, scheduler,
|
||||
{'train_losses': [train_loss], 'val_losses': [val_loss]}
|
||||
)
|
||||
break
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e).lower():
|
||||
logger.error("Out of memory error")
|
||||
memory_manager.clear_cache()
|
||||
# Try to continue (reduce batch size in real scenario)
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Runtime error: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in epoch {epoch}: {e}")
|
||||
checkpoint_manager.save_full_checkpoint(
|
||||
epoch, model, optimizer, scheduler,
|
||||
{'train_losses': [train_loss], 'val_losses': [val_loss]}
|
||||
)
|
||||
raise
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
|
||||
## Common Pitfalls and How to Avoid Them
|
||||
|
||||
### Pitfall 1: Validating on Training Data
|
||||
```python
|
||||
# ❌ WRONG
|
||||
val_loader = train_loader # Same loader!
|
||||
|
||||
# ✅ CORRECT
|
||||
train_dataset, val_dataset = split_dataset(full_dataset)
|
||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
||||
```
|
||||
|
||||
### Pitfall 2: Missing Optimizer State in Checkpoint
|
||||
```python
|
||||
# ❌ WRONG
|
||||
torch.save({'model': model.state_dict()}, 'ckpt.pt')
|
||||
|
||||
# ✅ CORRECT
|
||||
torch.save({
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
}, 'ckpt.pt')
|
||||
```
|
||||
|
||||
### Pitfall 3: Not Validating During Training
|
||||
```python
|
||||
# ❌ WRONG
|
||||
for epoch in range(100):
|
||||
train_epoch()
|
||||
final_val = evaluate() # Only at the end!
|
||||
|
||||
# ✅ CORRECT
|
||||
for epoch in range(100):
|
||||
train_epoch()
|
||||
validate_epoch() # After every epoch
|
||||
```
|
||||
|
||||
### Pitfall 4: Holding Onto Tensor References
|
||||
```python
|
||||
# ❌ WRONG
|
||||
all_losses = []
|
||||
for data, target in loader:
|
||||
loss = criterion(model(data), target)
|
||||
all_losses.append(loss) # Memory leak!
|
||||
|
||||
# ✅ CORRECT
|
||||
total_loss = 0.0
|
||||
for data, target in loader:
|
||||
loss = criterion(model(data), target)
|
||||
total_loss += loss.item() # Scalar value
|
||||
```
|
||||
|
||||
### Pitfall 5: Forgetting torch.no_grad() in Validation
|
||||
```python
|
||||
# ❌ WRONG
|
||||
model.eval()
|
||||
for data, target in val_loader:
|
||||
output = model(data) # Gradients still computed!
|
||||
loss = criterion(output, target)
|
||||
|
||||
# ✅ CORRECT
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for data, target in val_loader:
|
||||
output = model(data) # No gradients
|
||||
loss = criterion(output, target)
|
||||
```
|
||||
|
||||
### Pitfall 6: Resetting Scheduler on Resume
|
||||
```python
|
||||
# ❌ WRONG
|
||||
checkpoint = torch.load('ckpt.pt')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=100) # Fresh scheduler!
|
||||
# Now at epoch 50, scheduler thinks it's epoch 0
|
||||
|
||||
# ✅ CORRECT
|
||||
checkpoint = torch.load('ckpt.pt')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler']) # Resume scheduler state
|
||||
```
|
||||
|
||||
### Pitfall 7: Not Handling Early Stopping Correctly
|
||||
```python
|
||||
# ❌ WRONG
|
||||
best_loss = float('inf')
|
||||
for epoch in range(100):
|
||||
val_loss = validate()
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
# No checkpoint! Can't recover best model
|
||||
|
||||
# ✅ CORRECT
|
||||
best_loss = float('inf')
|
||||
patience = 10
|
||||
patience_counter = 0
|
||||
for epoch in range(100):
|
||||
val_loss = validate()
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
patience_counter = 0
|
||||
save_checkpoint(model, optimizer, scheduler, epoch) # Save best
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= patience:
|
||||
break # Stop early
|
||||
```
|
||||
|
||||
### Pitfall 8: Mixing Train and Validation Mode
|
||||
```python
|
||||
# ❌ WRONG
|
||||
for epoch in range(100):
|
||||
for data, target in train_loader:
|
||||
output = model(data) # Is model in train or eval mode?
|
||||
loss = criterion(output, target)
|
||||
|
||||
# ✅ CORRECT
|
||||
model.train()
|
||||
for epoch in range(100):
|
||||
for data, target in train_loader:
|
||||
output = model(data) # Definitely in train mode
|
||||
loss = criterion(output, target)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for data, target in val_loader:
|
||||
output = model(data) # Definitely in eval mode
|
||||
```
|
||||
|
||||
### Pitfall 9: Loading Checkpoint on Wrong Device
|
||||
```python
|
||||
# ❌ WRONG
|
||||
checkpoint = torch.load('ckpt.pt') # Loads on GPU if saved on GPU
|
||||
model.load_state_dict(checkpoint['model']) # Might be on wrong device
|
||||
|
||||
# ✅ CORRECT
|
||||
checkpoint = torch.load('ckpt.pt', map_location='cuda:0') # Specify device
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
model.to('cuda:0') # Move to device
|
||||
```
|
||||
|
||||
### Pitfall 10: Not Clearing GPU Cache
|
||||
```python
|
||||
# ❌ WRONG
|
||||
for epoch in range(100):
|
||||
train_epoch()
|
||||
validate_epoch()
|
||||
# GPU cache growing every epoch
|
||||
|
||||
# ✅ CORRECT
|
||||
for epoch in range(100):
|
||||
train_epoch()
|
||||
validate_epoch()
|
||||
torch.cuda.empty_cache() # Clear cache
|
||||
```
|
||||
|
||||
|
||||
## Integration with Optimization Techniques
|
||||
|
||||
### Complete Training Loop with All Techniques
|
||||
|
||||
```python
|
||||
class FullyOptimizedTrainingLoop:
|
||||
"""Integrates: gradient clipping, mixed precision, learning rate scheduling."""
|
||||
|
||||
def train_with_all_techniques(self, model, train_loader, val_loader,
|
||||
num_epochs, checkpoint_dir='checkpoints'):
|
||||
"""Training with all optimization techniques integrated."""
|
||||
|
||||
# Setup
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = model.to(device)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# Mixed precision (if using AMP)
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
|
||||
for data, target in train_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Mixed precision forward pass
|
||||
with torch.autocast('cuda'):
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Gradient scaling for mixed precision
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Gradient clipping (unscale first!)
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
|
||||
# Optimizer step
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
train_loss = total_loss / len(train_loader)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for data, target in val_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
val_loss += criterion(output, target).item()
|
||||
|
||||
val_loss /= len(val_loader)
|
||||
|
||||
# Scheduler step
|
||||
scheduler.step()
|
||||
|
||||
logger.info(
|
||||
f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}, "
|
||||
f"lr={optimizer.param_groups[0]['lr']:.2e}"
|
||||
)
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
|
||||
## Rationalization Table: When to Deviate from Standard
|
||||
|
||||
| Situation | Standard Practice | Deviation | Rationale |
|
||||
|-----------|-------------------|-----------|-----------|
|
||||
| Validate only at end | Validate every epoch | ✗ Never | Can't detect overfitting |
|
||||
| Save only model | Save model + optimizer + scheduler | ✗ Never | Resume training breaks |
|
||||
| Mixed train/val | Separate datasets completely | ✗ Never | Data leakage and false metrics |
|
||||
| Constant batch size | Fix batch size for reproducibility | ✓ Sometimes | May need dynamic batching for memory |
|
||||
| Single LR | Use scheduler | ✓ Sometimes | <10 epoch training or hyperparameter search |
|
||||
| No early stopping | Implement early stopping | ✓ Sometimes | If training time unlimited |
|
||||
| Log every batch | Log every 10-100 batches | ✓ Often | Reduces I/O overhead |
|
||||
| GPU cache every epoch | Clear GPU cache periodically | ✓ Sometimes | Only if OOM issues |
|
||||
|
||||
|
||||
## Red Flags: Immediate Warning Signs
|
||||
|
||||
1. **Training loss much lower than validation loss** (>2x) → Overfitting
|
||||
2. **Loss spikes on resume** → Optimizer state not loaded
|
||||
3. **GPU memory grows over time** → Memory leak, likely tensor accumulation
|
||||
4. **Validation never runs** → Check if validation is in loop
|
||||
5. **Best model not saved** → Check checkpoint logic
|
||||
6. **Different results on resume** → Scheduler not loaded
|
||||
7. **Early stopping not working** → Checkpoint not at best model
|
||||
8. **OOM during training** → Clear GPU cache, check for accumulated tensors
|
||||
|
||||
|
||||
## Testing Your Training Loop
|
||||
|
||||
```python
|
||||
def test_training_loop():
|
||||
"""Quick test to verify training loop is correct."""
|
||||
|
||||
# Create dummy data
|
||||
X_train = torch.randn(100, 10)
|
||||
y_train = torch.randint(0, 2, (100,))
|
||||
X_val = torch.randn(20, 10)
|
||||
y_val = torch.randint(0, 2, (20,))
|
||||
|
||||
train_loader = DataLoader(
|
||||
list(zip(X_train, y_train)), batch_size=16
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
list(zip(X_val, y_val)), batch_size=16
|
||||
)
|
||||
|
||||
# Simple model
|
||||
model = nn.Sequential(
|
||||
nn.Linear(10, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 2)
|
||||
)
|
||||
|
||||
# Training
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
loop = TrainingLoop(model, optimizer, scheduler, criterion)
|
||||
|
||||
# Should complete without errors
|
||||
loop.train(train_loader, val_loader, num_epochs=5, checkpoint_dir='test_ckpts')
|
||||
|
||||
# Check outputs
|
||||
assert len(loop.train_losses) == 5
|
||||
assert len(loop.val_losses) == 5
|
||||
assert all(isinstance(l, float) for l in loop.train_losses)
|
||||
|
||||
print("✓ Training loop test passed")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_training_loop()
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user