# Stable Baselines3 Callback System

This document provides comprehensive information about the callback system in Stable Baselines3 for monitoring and controlling training.

## Overview

Callbacks are functions called at specific points during training to:

- Monitor training metrics
- Save checkpoints
- Implement early stopping
- Log custom metrics
- Adjust hyperparameters dynamically
- Trigger evaluations

## Built-in Callbacks
### EvalCallback

Evaluates the agent periodically and saves the best model.

```python
from stable_baselines3.common.callbacks import EvalCallback

eval_callback = EvalCallback(
    eval_env,                                   # Separate evaluation environment
    best_model_save_path="./logs/best_model/",  # Where to save the best model
    log_path="./logs/eval/",                    # Where to save evaluation logs
    eval_freq=10000,                            # Evaluate every N steps
    n_eval_episodes=5,                          # Number of episodes per evaluation
    deterministic=True,                         # Use deterministic actions
    render=False,                               # Render during evaluation
    verbose=1,
    warn=True,
)

model.learn(total_timesteps=100000, callback=eval_callback)
```

**Key Features:**
- Automatically saves the best model based on mean evaluation reward
- Logs evaluation metrics to TensorBoard
- Can stop training when a reward threshold is reached (via `callback_on_new_best`, see below)

**Important:** `eval_freq` is counted in calls to the callback, and each call advances all parallel environments by one step. When using vectorized training environments, adjust `eval_freq` accordingly:

```python
# With 4 parallel environments, divide eval_freq by n_envs
eval_freq = 10000 // 4  # Evaluate every 10000 total environment steps
```
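After training, the best checkpoint can be reloaded from `best_model_save_path` (EvalCallback writes it as `best_model.zip`). A minimal sketch, assuming the agent is a PPO model and the paths used above:

```python
from stable_baselines3 import PPO

# Load the best model saved by EvalCallback
best_model = PPO.load("./logs/best_model/best_model")

# Use it for inference
obs, _ = eval_env.reset()  # Gymnasium API; a VecEnv's reset() returns obs only
action, _states = best_model.predict(obs, deterministic=True)
```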
### CheckpointCallback

Saves model checkpoints at regular intervals.

```python
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=10000,                  # Save every N steps
    save_path="./logs/checkpoints/",  # Directory for checkpoints
    name_prefix="rl_model",           # Prefix for checkpoint files
    save_replay_buffer=True,          # Save replay buffer (off-policy only)
    save_vecnormalize=True,           # Save VecNormalize stats
    verbose=2,
)

model.learn(total_timesteps=100000, callback=checkpoint_callback)
```

**Output Files:**
- `rl_model_10000_steps.zip` - Model at 10k steps
- `rl_model_20000_steps.zip` - Model at 20k steps
- etc.

**Important:** Adjust `save_freq` for vectorized environments (divide by `n_envs`), just as with `eval_freq` above.
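Resuming from a checkpoint means loading each saved piece back. A sketch for an off-policy algorithm such as SAC, assuming the prefix and paths above and a checkpoint written at 10,000 steps (`venv` stands for your raw vectorized training environment; the exact file names include the step count at which the checkpoint was saved):

```python
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecNormalize

# Restore normalization statistics first (only if VecNormalize was used)
env = VecNormalize.load(
    "./logs/checkpoints/rl_model_vecnormalize_10000_steps.pkl", venv
)

# Restore the model and, for off-policy algorithms, the replay buffer
model = SAC.load("./logs/checkpoints/rl_model_10000_steps.zip", env=env)
model.load_replay_buffer("./logs/checkpoints/rl_model_replay_buffer_10000_steps.pkl")

# Continue training without resetting the timestep counter
model.learn(total_timesteps=50000, reset_num_timesteps=False)
```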
### StopTrainingOnRewardThreshold

Stops training when the mean evaluation reward exceeds a threshold.

```python
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold

stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=200,  # Stop when mean reward >= 200
    verbose=1,
)

# Must be used with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best model is found
    eval_freq=10000,
    n_eval_episodes=5,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```
### StopTrainingOnNoModelImprovement

Stops training if the model doesn't improve for N consecutive evaluations.

```python
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement

stop_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=10,  # Stop after 10 evals with no improvement
    min_evals=20,                 # Minimum evaluations before stopping is allowed
    verbose=1,
)

# Use with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_after_eval=stop_callback,
    eval_freq=10000,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```
### StopTrainingOnMaxEpisodes

Stops training after a maximum number of episodes (counted per environment when training is vectorized).

```python
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

stop_callback = StopTrainingOnMaxEpisodes(
    max_episodes=1000,  # Stop after 1000 episodes per environment
    verbose=1,
)

model.learn(total_timesteps=1000000, callback=stop_callback)
```
### ProgressBarCallback

Displays a progress bar during training (requires `tqdm` and `rich`).

```python
from stable_baselines3.common.callbacks import ProgressBarCallback

progress_callback = ProgressBarCallback()

model.learn(total_timesteps=100000, callback=progress_callback)
```

**Output:**
```
100%|██████████| 100000/100000 [05:23<00:00, 309.31it/s]
```
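Recent SB3 versions also expose this directly through the `progress_bar` argument of `learn()`, so constructing the callback by hand is optional:

```python
model.learn(total_timesteps=100000, progress_bar=True)
```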
## Creating Custom Callbacks

### BaseCallback Structure

```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    Custom callback template.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Custom initialization

    def _init_callback(self) -> None:
        """
        Called once when training starts.
        Useful for initialization that requires access to the model/env.
        """
        pass

    def _on_training_start(self) -> None:
        """
        Called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        Called before collecting new samples (on-policy algorithms).
        """
        pass

    def _on_step(self) -> bool:
        """
        Called after every step in the environment.

        Returns:
            bool: If False, training will be stopped.
        """
        return True  # Continue training

    def _on_rollout_end(self) -> None:
        """
        Called after a rollout ends (on-policy algorithms).
        """
        pass

    def _on_training_end(self) -> None:
        """
        Called at the end of training.
        """
        pass
```
### Useful Attributes

Inside callbacks, you have access to:

- **`self.model`**: The RL algorithm instance
- **`self.training_env`**: The training environment
- **`self.n_calls`**: Number of times `_on_step()` was called
- **`self.num_timesteps`**: Total number of environment steps so far
- **`self.locals`**: Local variables from the algorithm (varies by algorithm)
- **`self.globals`**: Global variables from the algorithm
- **`self.logger`**: Logger for TensorBoard/CSV logging
- **`self.parent`**: Parent callback (when used as a child of an event callback such as `EvalCallback`)
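As a small illustration of these attributes together, here is a sketch of a callback that periodically snapshots the model; the name, interval, and path are arbitrary:

```python
import os

from stable_baselines3.common.callbacks import BaseCallback


class PeriodicSnapshotCallback(BaseCallback):
    """Save a snapshot of the model every `save_freq` calls and log progress."""

    def __init__(self, save_freq=50_000, save_path="./logs/snapshots/", verbose=0):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        # self.model / self.logger are available from here on
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq == 0:
            # self.num_timesteps counts total environment steps across all envs
            path = os.path.join(self.save_path, f"snapshot_{self.num_timesteps}")
            self.model.save(path)
            self.logger.record("custom/last_snapshot_step", self.num_timesteps)
        return True
```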
## Custom Callback Examples

### Example 1: Log Custom Metrics

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class LogCustomMetricsCallback(BaseCallback):
    """
    Log custom metrics to TensorBoard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Check if the (first) environment finished an episode
        if self.locals["dones"][0]:
            # The "episode" key is added by the Monitor wrapper
            episode_reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.episode_rewards.append(episode_reward)

            # Log to TensorBoard
            self.logger.record("custom/episode_reward", episode_reward)
            self.logger.record("custom/mean_reward_last_100",
                               np.mean(self.episode_rewards[-100:]))

        return True
```
### Example 2: Adjust Learning Rate

```python
class LinearScheduleCallback(BaseCallback):
    """
    Linearly decrease the learning rate during training.
    """

    def __init__(self, initial_lr=3e-4, final_lr=3e-5, verbose=0):
        super().__init__(verbose)
        self.initial_lr = initial_lr
        self.final_lr = final_lr

    def _on_step(self) -> bool:
        # Calculate progress (0 to 1)
        progress = self.num_timesteps / self.locals["total_timesteps"]

        # Linear interpolation
        new_lr = self.initial_lr + (self.final_lr - self.initial_lr) * progress

        # Update learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] = new_lr

        # Log learning rate
        self.logger.record("train/learning_rate", new_lr)

        return True
```
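For a plain learning-rate schedule, SB3 also accepts a callable as the `learning_rate` argument; it receives the remaining progress (from 1 down to 0), so no callback is needed. A sketch with the same endpoints as above:

```python
from stable_baselines3 import PPO


def linear_schedule(progress_remaining: float) -> float:
    # progress_remaining goes from 1 (start of training) to 0 (end)
    return 3e-5 + progress_remaining * (3e-4 - 3e-5)


model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule)
```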
### Example 3: Early Stopping on Moving Average

```python
import numpy as np


class EarlyStoppingCallback(BaseCallback):
    """
    Stop training once the moving average of episode rewards reaches a target.
    """

    def __init__(self, check_freq=10000, min_reward=200, window=100, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.min_reward = min_reward
        self.window = window
        self.rewards = []

    def _on_step(self) -> bool:
        # Collect episode rewards (requires the Monitor wrapper)
        if self.locals["dones"][0]:
            reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.rewards.append(reward)

        # Check every check_freq steps
        if self.n_calls % self.check_freq == 0 and len(self.rewards) >= self.window:
            mean_reward = np.mean(self.rewards[-self.window:])
            if self.verbose > 0:
                print(f"Mean reward: {mean_reward:.2f}")

            if mean_reward >= self.min_reward:
                if self.verbose > 0:
                    print("Stopping: reward threshold reached!")
                return False  # Stop training

        return True  # Continue training
```
### Example 4: Save Best Model by Custom Metric

```python
import os

import numpy as np


class SaveBestModelCallback(BaseCallback):
    """
    Save the model whenever a custom metric reaches a new best value.
    """

    def __init__(self, check_freq=1000, save_path="./best_model/", verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_score = -np.inf

    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Example custom metric: last recorded entropy loss.
            # Whether this key exists in self.locals depends on the algorithm
            # and on when _on_step() is called, hence the safe default.
            custom_metric = self.locals.get("entropy_losses", [0])[-1]

            if custom_metric > self.best_score:
                self.best_score = custom_metric
                if self.verbose > 0:
                    print(f"New best! Saving model to {self.save_path}")
                self.model.save(os.path.join(self.save_path, "best_model"))

        return True
```
### Example 5: Log Environment-Specific Information

```python
class EnvironmentInfoCallback(BaseCallback):
    """
    Log custom info from the environment.
    """

    def _on_step(self) -> bool:
        # Access the info dict of the first environment
        # (with a VecEnv, self.locals["infos"] is a list, one dict per env)
        info = self.locals["infos"][0]

        # Log custom metrics reported by the environment
        if "distance_to_goal" in info:
            self.logger.record("env/distance_to_goal", info["distance_to_goal"])

        if "success" in info:
            self.logger.record("env/success_rate", info["success"])

        return True
```
## Chaining Multiple Callbacks

Use `CallbackList` to combine multiple callbacks:

```python
from stable_baselines3.common.callbacks import CallbackList

callback_list = CallbackList([
    eval_callback,
    checkpoint_callback,
    progress_callback,
    custom_callback,
])

model.learn(total_timesteps=100000, callback=callback_list)
```

Or pass a list directly:

```python
model.learn(
    total_timesteps=100000,
    callback=[eval_callback, checkpoint_callback, custom_callback],
)
```
## Event-Based Callbacks

Callbacks can trigger other callbacks on specific events. `EvalCallback` is an `EventCallback`: it can launch a child callback when a new best model is found.

```python
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Stop training when the reward threshold is reached
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200)

# Evaluate periodically and trigger stop_callback when a new best model is found
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best model is found
    eval_freq=10000,
)
```
## Logging to TensorBoard

Use `self.logger.record()` to log metrics:

```python
class TensorBoardCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Log a scalar (placeholder value; use any number you want to track)
        value = self.num_timesteps
        self.logger.record("custom/my_metric", value)

        # Log several metrics at once
        self.logger.record("custom/metric1", 0.5)
        self.logger.record("custom/metric2", 1.0)

        # Recorded values are written out when the logger dumps,
        # at the algorithm's regular logging interval
        return True
```
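Beyond scalars, the SB3 logger can also record matplotlib figures (and images/videos) to TensorBoard. A sketch, assuming matplotlib is installed; the figure content here is a placeholder:

```python
import matplotlib.pyplot as plt

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure


class FigureLoggingCallback(BaseCallback):
    def _on_rollout_end(self) -> None:
        # Placeholder plot; replace with data you actually want to visualize
        fig = plt.figure()
        fig.add_subplot().plot([0, self.num_timesteps])
        # Exclude output formats that cannot render figures
        self.logger.record(
            "custom/progress_figure",
            Figure(fig, close=True),
            exclude=("stdout", "log", "json", "csv"),
        )
```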
**View in TensorBoard:**
```bash
tensorboard --logdir ./logs/
```

(Remember to pass `tensorboard_log="./logs/"` when constructing the model so that a TensorBoard writer is actually created.)
## Advanced Patterns

### Curriculum Learning

```python
class CurriculumCallback(BaseCallback):
    """
    Increase task difficulty over time.
    """

    def __init__(self, difficulty_schedule, verbose=0):
        super().__init__(verbose)
        # List of (progress_threshold, difficulty) pairs, e.g. [(0.0, 1), (0.5, 2)]
        self.difficulty_schedule = difficulty_schedule

    def _on_step(self) -> bool:
        # Update environment difficulty based on training progress
        progress = self.num_timesteps / self.locals["total_timesteps"]

        for threshold, difficulty in self.difficulty_schedule:
            if progress >= threshold:
                # set_difficulty must be a method defined by your environment
                self.training_env.env_method("set_difficulty", difficulty)

        return True
```
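Usage might look like the following; the schedule values are arbitrary and `set_difficulty` is assumed to exist on the environment:

```python
curriculum_callback = CurriculumCallback(
    difficulty_schedule=[(0.0, 1), (0.3, 2), (0.6, 3), (0.9, 4)],
)
model.learn(total_timesteps=1_000_000, callback=curriculum_callback)
```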
### Population-Based Training

```python
import numpy as np


class PopulationBasedCallback(BaseCallback):
    """
    Adjust hyperparameters based on performance.
    """

    def __init__(self, check_freq=10000, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.performance_history = []

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Evaluate performance
            perf = self._evaluate_performance()
            self.performance_history.append(perf)

            # Adjust hyperparameters if performance plateaus
            if len(self.performance_history) >= 3:
                recent = self.performance_history[-3:]
                if max(recent) - min(recent) < 0.01:  # Plateau detected
                    self._adjust_hyperparameters()

        return True

    def _evaluate_performance(self):
        # Simple proxy: mean reward of the recent episodes tracked by the
        # algorithm (requires Monitor-wrapped environments)
        ep_info_buffer = self.model.ep_info_buffer
        if len(ep_info_buffer) == 0:
            return 0.0
        return float(np.mean([ep_info["r"] for ep_info in ep_info_buffer]))

    def _adjust_hyperparameters(self):
        # Example: increase learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] *= 1.2
```
## Debugging Tips

### Print Available Attributes

```python
class DebugCallback(BaseCallback):
    def _on_step(self) -> bool:
        if self.n_calls == 1:
            print("Available in self.locals:")
            for key in self.locals.keys():
                print(f"  {key}: {type(self.locals[key])}")
        return True
```
### Common Issues

1. **Callback not being called:**
   - Ensure the callback is passed to `model.learn()`
   - Check that `_on_step()` returns `True`

2. **AttributeError in callback:**
   - Not all keys are available in `self.locals` for every algorithm or training phase
   - Use `self.locals.get("key", default)` for safety

3. **Memory leaks:**
   - Don't store large arrays in callback state
   - Clear buffers periodically

4. **Performance impact:**
   - Minimize computation in `_on_step()` (it is called every step)
   - Use a `check_freq` counter to limit expensive operations
## Best Practices

1. **Use appropriate callback timing:**
   - `_on_step()`: For metrics that change every step
   - `_on_rollout_end()`: For metrics computed over rollouts
   - `_init_callback()`: For one-time initialization

2. **Log efficiently:**
   - Don't log every step (it hurts performance)
   - Aggregate metrics and log periodically

3. **Handle vectorized environments:**
   - Remember that `dones`, `infos`, etc. are arrays with one entry per environment
   - Check `dones[i]` for each environment (see the sketch after this list)

4. **Test callbacks independently:**
   - Create simple test cases
   - Verify callback behavior before long training runs

5. **Document custom callbacks:**
   - Write clear docstrings
   - Include example usage in comments
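A minimal sketch of per-environment handling inside `_on_step()`, assuming the training environments are wrapped with `Monitor` so that episode statistics appear in `infos`:

```python
class PerEnvEpisodeCallback(BaseCallback):
    """Collect episode rewards from every parallel environment."""

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # dones and infos have one entry per parallel environment
        for done, info in zip(self.locals["dones"], self.locals["infos"]):
            if done and "episode" in info:
                self.episode_rewards.append(info["episode"]["r"])
        return True
```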
## Additional Resources

- Official SB3 Callbacks Guide: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
- Callback API Reference: https://stable-baselines3.readthedocs.io/en/master/common/callbacks.html
- TensorBoard Documentation: https://www.tensorflow.org/tensorboard