# Stable Baselines3 Callback System

This document provides comprehensive information about the callback system in Stable Baselines3 for monitoring and controlling training; a minimal usage sketch follows the overview list below.

## Overview

Callbacks are functions called at specific points during training to:

- Monitor training metrics
- Save checkpoints
- Implement early stopping
- Log custom metrics
- Adjust hyperparameters dynamically
- Trigger evaluations
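
To make the pattern concrete, here is a minimal sketch of passing a callback to `model.learn()`. The PPO/CartPole combination is purely illustrative; any SB3 algorithm and environment work the same way.

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback

# Illustrative setup: separate training and evaluation environments
train_env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", train_env, verbose=0)

# Callbacks are passed to learn(); SB3 invokes them at the hooks described in this document
eval_callback = EvalCallback(eval_env, eval_freq=5_000, n_eval_episodes=5)
model.learn(total_timesteps=50_000, callback=eval_callback)
```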
## Built-in Callbacks

### EvalCallback

Evaluates the agent periodically and saves the best model.

```python
from stable_baselines3.common.callbacks import EvalCallback

eval_callback = EvalCallback(
    eval_env,                                   # Separate evaluation environment
    best_model_save_path="./logs/best_model/",  # Where to save the best model
    log_path="./logs/eval/",                    # Where to save evaluation logs
    eval_freq=10000,                            # Evaluate every N calls to env.step()
    n_eval_episodes=5,                          # Number of episodes per evaluation
    deterministic=True,                         # Use deterministic actions
    render=False,                               # Render during evaluation
    verbose=1,
    warn=True,
)

model.learn(total_timesteps=100000, callback=eval_callback)
```

**Key Features:**
- Automatically saves the best model based on mean evaluation reward
- Logs evaluation metrics to TensorBoard
- Can stop training when a reward threshold is reached (via `callback_on_new_best`)

**Important:** `eval_freq` is counted in calls to `env.step()`, so when using vectorized training environments, divide it by the number of parallel environments:

```python
# With 4 parallel environments, divide eval_freq by n_envs
eval_freq = 10000 // 4  # Evaluate every 10000 total environment steps
```
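
For example, a short end-to-end sketch with four parallel training environments (environment id and paths are illustrative):

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor

n_envs = 4
train_env = make_vec_env("CartPole-v1", n_envs=n_envs)  # 4 parallel training envs
eval_env = Monitor(gym.make("CartPole-v1"))              # single evaluation env

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/best_model/",
    eval_freq=max(10000 // n_envs, 1),  # ~every 10000 total environment steps
    n_eval_episodes=5,
)

model = PPO("MlpPolicy", train_env)
model.learn(total_timesteps=100_000, callback=eval_callback)
```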
### CheckpointCallback

Saves model checkpoints at regular intervals.

```python
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=10000,                  # Save every N calls to env.step()
    save_path="./logs/checkpoints/",  # Directory for checkpoints
    name_prefix="rl_model",           # Prefix for checkpoint files
    save_replay_buffer=True,          # Save replay buffer (off-policy algorithms only)
    save_vecnormalize=True,           # Save VecNormalize statistics
    verbose=2,
)

model.learn(total_timesteps=100000, callback=checkpoint_callback)
```

**Output Files:**
- `rl_model_10000_steps.zip` - Model at 10k steps
- `rl_model_20000_steps.zip` - Model at 20k steps
- etc.

**Important:** As with `eval_freq`, adjust `save_freq` for vectorized environments (divide by `n_envs`).
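
Checkpoints can be restored later. A sketch, assuming the checkpoints came from an off-policy algorithm such as SAC; the file names below are illustrative, so check the save directory for the exact names CheckpointCallback produced:

```python
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Restore model weights and hyperparameters from a checkpoint
model = SAC.load("./logs/checkpoints/rl_model_20000_steps.zip")

# Off-policy algorithms can also restore a saved replay buffer
model.load_replay_buffer("./logs/checkpoints/rl_model_replay_buffer_20000_steps.pkl")

# If VecNormalize statistics were saved, reload them onto a fresh VecEnv
venv = make_vec_env("Pendulum-v1", n_envs=1)
env = VecNormalize.load("./logs/checkpoints/rl_model_vecnormalize_20000_steps.pkl", venv)
model.set_env(env)
```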
### StopTrainingOnRewardThreshold

Stops training when the mean evaluation reward exceeds a threshold.

```python
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold

stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=200,  # Stop when mean reward >= 200
    verbose=1,
)

# Must be used with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best model is found
    eval_freq=10000,
    n_eval_episodes=5,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```

### StopTrainingOnNoModelImprovement

Stops training if the best model does not improve for N consecutive evaluations.

```python
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement

stop_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=10,  # Stop after 10 evaluations without improvement
    min_evals=20,                 # Minimum number of evaluations before stopping is allowed
    verbose=1,
)

# Use with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_after_eval=stop_callback,
    eval_freq=10000,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```
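
Both stopping conditions can be attached to a single `EvalCallback`. A sketch (with `eval_env` defined as in the earlier examples):

```python
from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnNoModelImprovement,
    StopTrainingOnRewardThreshold,
)

# Stop as soon as the target reward is reached...
reward_stop = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
# ...or give up if evaluations stop improving
plateau_stop = StopTrainingOnNoModelImprovement(max_no_improvement_evals=10, min_evals=20, verbose=1)

eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=reward_stop,  # Checked only when a new best model is found
    callback_after_eval=plateau_stop,  # Checked after every evaluation
    eval_freq=10000,
    n_eval_episodes=5,
)

model.learn(total_timesteps=1_000_000, callback=eval_callback)
```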
### StopTrainingOnMaxEpisodes

Stops training after a maximum number of episodes.

```python
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

stop_callback = StopTrainingOnMaxEpisodes(
    max_episodes=1000,  # Stop after 1000 episodes (per environment when training with multiple envs)
    verbose=1,
)

model.learn(total_timesteps=1000000, callback=stop_callback)
```

### ProgressBarCallback

Displays a progress bar during training (requires the `tqdm` and `rich` packages).

```python
from stable_baselines3.common.callbacks import ProgressBarCallback

progress_callback = ProgressBarCallback()

model.learn(total_timesteps=100000, callback=progress_callback)
```

**Output:**
```
100%|██████████| 100000/100000 [05:23<00:00, 309.31it/s]
```
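
Equivalently, recent SB3 versions let you request the same progress bar directly from `learn()`:

```python
# Same effect without constructing the callback explicitly
model.learn(total_timesteps=100000, progress_bar=True)
```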
## Creating Custom Callbacks

### BaseCallback Structure

```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    Custom callback template.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Custom initialization

    def _init_callback(self) -> None:
        """
        Called once when training starts.
        Useful for initialization that requires access to the model/env.
        """
        pass

    def _on_training_start(self) -> None:
        """
        Called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        Called before collecting new samples (on-policy algorithms).
        """
        pass

    def _on_step(self) -> bool:
        """
        Called after every step in the environment.

        Returns:
            bool: If False, training will be stopped.
        """
        return True  # Continue training

    def _on_rollout_end(self) -> None:
        """
        Called after a rollout ends (on-policy algorithms).
        """
        pass

    def _on_training_end(self) -> None:
        """
        Called at the end of training.
        """
        pass
```

### Useful Attributes

Inside callbacks, you have access to the following (a short demonstration follows the list):

- **`self.model`**: The RL algorithm instance
- **`self.training_env`**: The training environment (a `VecEnv`)
- **`self.n_calls`**: Number of times `_on_step()` was called
- **`self.num_timesteps`**: Total number of environment steps taken so far (across all parallel environments)
- **`self.locals`**: Local variables from the algorithm (contents vary by algorithm)
- **`self.globals`**: Global variables from the algorithm
- **`self.logger`**: Logger for TensorBoard/CSV logging
- **`self.parent`**: Parent callback (set when the callback is triggered by an `EventCallback` such as `EvalCallback`)
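
A short illustrative callback that reads these attributes; it assumes episodes are wrapped with `Monitor` (SB3 does this by default) so that `ep_info_buffer` is populated, and the 1000-step interval is arbitrary:

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class AttributeDemoCallback(BaseCallback):
    """Log a few built-in attributes every 1000 calls."""

    def _on_step(self) -> bool:
        if self.n_calls % 1000 == 0:
            # num_timesteps counts steps across all parallel environments
            self.logger.record("demo/num_timesteps", self.num_timesteps)
            self.logger.record("demo/n_envs", self.training_env.num_envs)
            # Mean reward of recently finished episodes, if any were recorded
            ep_info = self.model.ep_info_buffer
            if len(ep_info) > 0:
                self.logger.record("demo/mean_ep_reward", float(np.mean([e["r"] for e in ep_info])))
        return True
```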
## Custom Callback Examples

### Example 1: Log Custom Metrics

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class LogCustomMetricsCallback(BaseCallback):
    """
    Log custom metrics to TensorBoard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Check if the first environment's episode ended
        if self.locals["dones"][0]:
            # Episode statistics are added to `info` by the Monitor wrapper
            episode_reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.episode_rewards.append(episode_reward)

            # Log to TensorBoard
            self.logger.record("custom/episode_reward", episode_reward)
            self.logger.record("custom/mean_reward_last_100",
                               np.mean(self.episode_rewards[-100:]))

        return True
```
### Example 2: Adjust Learning Rate

```python
from stable_baselines3.common.callbacks import BaseCallback


class LinearScheduleCallback(BaseCallback):
    """
    Linearly decrease the learning rate during training.
    """

    def __init__(self, initial_lr=3e-4, final_lr=3e-5, verbose=0):
        super().__init__(verbose)
        self.initial_lr = initial_lr
        self.final_lr = final_lr

    def _on_step(self) -> bool:
        # Calculate progress (0 to 1)
        progress = self.num_timesteps / self.locals["total_timesteps"]

        # Linear interpolation
        new_lr = self.initial_lr + (self.final_lr - self.initial_lr) * progress

        # Update the optimizer learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] = new_lr

        # Keep the model's own schedule in sync, since SB3 re-applies
        # model.lr_schedule to the optimizer at each policy update
        self.model.lr_schedule = lambda _: new_lr

        # Log the learning rate
        self.logger.record("train/learning_rate", new_lr)

        return True
```
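
Note that SB3 also supports this natively: `learning_rate` may be a callable of the remaining training progress (which goes from 1 to 0), so a linear schedule needs no callback at all. A sketch:

```python
from stable_baselines3 import PPO


def linear_schedule(initial_lr: float, final_lr: float):
    """Return a schedule mapping progress_remaining (1 -> 0) to a learning rate."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining is 1.0 at the start of training and 0.0 at the end
        return final_lr + (initial_lr - final_lr) * progress_remaining
    return schedule


model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(3e-4, 3e-5))
```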
### Example 3: Early Stopping on Moving Average

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class EarlyStoppingCallback(BaseCallback):
    """
    Stop training once the moving average of episode rewards reaches a target.
    """

    def __init__(self, check_freq=10000, min_reward=200, window=100, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.min_reward = min_reward
        self.window = window
        self.rewards = []

    def _on_step(self) -> bool:
        # Collect episode rewards (from the Monitor wrapper) as episodes finish
        if self.locals["dones"][0]:
            reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.rewards.append(reward)

        # Check every check_freq steps
        if self.n_calls % self.check_freq == 0 and len(self.rewards) >= self.window:
            mean_reward = np.mean(self.rewards[-self.window:])
            if self.verbose > 0:
                print(f"Mean reward: {mean_reward:.2f}")

            if mean_reward >= self.min_reward:
                if self.verbose > 0:
                    print("Stopping: reward threshold reached!")
                return False  # Stop training

        return True  # Continue training
```
### Example 4: Save Best Model by Custom Metric

```python
import os

import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class SaveBestModelCallback(BaseCallback):
    """
    Save the model whenever a custom metric reaches a new best value.
    """

    def __init__(self, check_freq=1000, save_path="./best_model/", verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_score = -np.inf

    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Compute a custom metric; the key below is only an illustration and
            # may not be present for every algorithm, hence the .get() default
            custom_metric = self.locals.get("entropy_losses", [0])[-1]

            if custom_metric > self.best_score:
                self.best_score = custom_metric
                if self.verbose > 0:
                    print(f"New best! Saving model to {self.save_path}")
                self.model.save(os.path.join(self.save_path, "best_model"))

        return True
```
### Example 5: Log Environment-Specific Information

```python
from stable_baselines3.common.callbacks import BaseCallback


class EnvironmentInfoCallback(BaseCallback):
    """
    Log custom entries from the environment's `info` dict.
    """

    def _on_step(self) -> bool:
        # Access the info dict of the first environment
        info = self.locals["infos"][0]

        # Log custom metrics provided by the environment
        if "distance_to_goal" in info:
            self.logger.record("env/distance_to_goal", info["distance_to_goal"])

        if "success" in info:
            self.logger.record("env/success_rate", info["success"])

        return True
```
## Chaining Multiple Callbacks

Use `CallbackList` to combine multiple callbacks:

```python
from stable_baselines3.common.callbacks import CallbackList

callback_list = CallbackList([
    eval_callback,
    checkpoint_callback,
    progress_callback,
    custom_callback,
])

model.learn(total_timesteps=100000, callback=callback_list)
```

Or pass a list directly:

```python
model.learn(
    total_timesteps=100000,
    callback=[eval_callback, checkpoint_callback, custom_callback]
)
```

## Event-Based Callbacks

Callbacks can trigger other callbacks on specific events. `EvalCallback` is itself an `EventCallback` and can fire a child callback whenever a new best model is found:

```python
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Stop training when the reward threshold is reached
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200)

# Evaluate periodically and trigger stop_callback when a new best model is found
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best model is found
    eval_freq=10000,
)
```
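
Another built-in event callback is `EveryNTimesteps`, which fires a child callback every N environment steps. For example, to trigger a checkpoint periodically without worrying about `save_freq` arithmetic:

```python
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps

# The child callback fires on every event, so save_freq=1 here
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/checkpoints/")

# Fire the child callback every 50,000 environment steps
event_callback = EveryNTimesteps(n_steps=50_000, callback=checkpoint_on_event)

model.learn(total_timesteps=1_000_000, callback=event_callback)
```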
## Logging to TensorBoard

Use `self.logger.record()` to log metrics:

```python
class TensorBoardCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Log a scalar (replace `value` with your own quantity)
        self.logger.record("custom/my_metric", value)

        # Log multiple metrics
        self.logger.record("custom/metric1", value1)
        self.logger.record("custom/metric2", value2)

        # Recorded values are written out when the logger is dumped,
        # which SB3 does automatically at its logging interval
        return True
```

**View in TensorBoard:**
```bash
tensorboard --logdir ./logs/
```
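
For TensorBoard output to appear, the model must be created with a `tensorboard_log` directory, e.g.:

```python
from stable_baselines3 import PPO

# Event files are written under ./logs/ and picked up by `tensorboard --logdir ./logs/`
model = PPO("MlpPolicy", "CartPole-v1", tensorboard_log="./logs/")
model.learn(total_timesteps=100_000)
```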
## Advanced Patterns

### Curriculum Learning

```python
from stable_baselines3.common.callbacks import BaseCallback


class CurriculumCallback(BaseCallback):
    """
    Increase task difficulty over time.
    """

    def __init__(self, difficulty_schedule, verbose=0):
        super().__init__(verbose)
        # List of (progress_threshold, difficulty) pairs, e.g. [(0.0, 1), (0.5, 2), (0.8, 3)]
        self.difficulty_schedule = difficulty_schedule

    def _on_step(self) -> bool:
        # Update environment difficulty based on training progress
        progress = self.num_timesteps / self.locals["total_timesteps"]

        for threshold, difficulty in self.difficulty_schedule:
            if progress >= threshold:
                # Requires the environment to expose a `set_difficulty` method
                self.training_env.env_method("set_difficulty", difficulty)

        return True
```
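
`env_method` calls the named method on every sub-environment, so this pattern assumes a custom environment that exposes `set_difficulty`. A minimal hypothetical sketch of such an environment (the spaces and reward are placeholders):

```python
import gymnasium as gym
import numpy as np


class DifficultyEnv(gym.Env):
    """Hypothetical environment whose difficulty can be changed at run time."""

    def __init__(self):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)
        self.difficulty = 1

    def set_difficulty(self, difficulty: int) -> None:
        # Called by CurriculumCallback through training_env.env_method(...)
        self.difficulty = difficulty

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        obs = self.observation_space.sample()
        # Harder difficulty -> smaller reward, purely for illustration
        reward = 1.0 / self.difficulty
        return obs, reward, False, False, {}
```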
### Population-Based Training

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class PopulationBasedCallback(BaseCallback):
    """
    Adjust hyperparameters based on performance.
    """

    def __init__(self, check_freq=10000, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.performance_history = []

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Evaluate performance
            perf = self._evaluate_performance()
            self.performance_history.append(perf)

            # Adjust hyperparameters if performance plateaus
            if len(self.performance_history) >= 3:
                recent = self.performance_history[-3:]
                if max(recent) - min(recent) < 0.01:  # Plateau detected
                    self._adjust_hyperparameters()

        return True

    def _evaluate_performance(self):
        # Mean reward of recently finished episodes (recorded by the Monitor wrapper)
        episode_rewards = [ep_info["r"] for ep_info in self.model.ep_info_buffer]
        return float(np.mean(episode_rewards)) if episode_rewards else 0.0

    def _adjust_hyperparameters(self):
        # Example: increase the learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] *= 1.2
```
## Debugging Tips

### Print Available Attributes

```python
class DebugCallback(BaseCallback):
    def _on_step(self) -> bool:
        if self.n_calls == 1:
            print("Available in self.locals:")
            for key in self.locals.keys():
                print(f"  {key}: {type(self.locals[key])}")
        return True
```

### Common Issues

1. **Callback not being called:**
   - Ensure the callback is passed to `model.learn()`
   - Check that `_on_step()` returns `True`

2. **AttributeError in callback:**
   - Not all keys are available for every algorithm or training phase
   - Use `self.locals.get("key", default)` for safety

3. **Memory leaks:**
   - Don't store large arrays in callback state
   - Clear buffers periodically

4. **Performance impact:**
   - Minimize computation in `_on_step()` (it is called every step)
   - Use a `check_freq` to limit expensive operations
## Best Practices

1. **Use appropriate callback timing:**
   - `_on_step()`: For metrics that change every step
   - `_on_rollout_end()`: For metrics computed over whole rollouts
   - `_init_callback()`: For one-time initialization

2. **Log efficiently:**
   - Don't log every step (it hurts performance)
   - Aggregate metrics and log periodically

3. **Handle vectorized environments:**
   - Remember that `dones`, `infos`, etc. are arrays with one entry per environment
   - Check `dones[i]` for each environment (see the sketch after this list)

4. **Test callbacks independently:**
   - Create simple test cases
   - Verify callback behavior before long training runs

5. **Document custom callbacks:**
   - Write clear docstrings
   - Include example usage in comments
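
A sketch of per-environment handling, adapting Example 1 so that no finished episode is missed when several environments run in parallel:

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class VecEnvRewardCallback(BaseCallback):
    """Collect episode rewards from every parallel environment."""

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # dones and infos have one entry per parallel environment
        for done, info in zip(self.locals["dones"], self.locals["infos"]):
            if done and "episode" in info:  # "episode" is added by the Monitor wrapper
                self.episode_rewards.append(info["episode"]["r"])
                self.logger.record("custom/mean_reward_last_100",
                                   float(np.mean(self.episode_rewards[-100:])))
        return True
```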
## Additional Resources

- Official SB3 Callbacks Guide: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
- Callback API Reference: https://stable-baselines3.readthedocs.io/en/master/common/callbacks.html
- TensorBoard Documentation: https://www.tensorflow.org/tensorboard