Stable Baselines3 Callback System
This document provides comprehensive information about the callback system in Stable Baselines3 for monitoring and controlling training.
Overview
Callbacks are functions called at specific points during training to:
- Monitor training metrics
- Save checkpoints
- Implement early stopping
- Log custom metrics
- Adjust hyperparameters dynamically
- Trigger evaluations
Built-in Callbacks
EvalCallback
Evaluates the agent periodically and saves the best model.
from stable_baselines3.common.callbacks import EvalCallback
eval_callback = EvalCallback(
eval_env, # Separate evaluation environment
best_model_save_path="./logs/best_model/", # Where to save best model
log_path="./logs/eval/", # Where to save evaluation logs
eval_freq=10000, # Evaluate every N steps
n_eval_episodes=5, # Number of episodes per evaluation
deterministic=True, # Use deterministic actions
render=False, # Render during evaluation
verbose=1,
warn=True,
)
model.learn(total_timesteps=100000, callback=eval_callback)
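The example assumes an eval_env already exists. A minimal sketch of building one (assuming a Gymnasium CartPole environment; wrapping it in Monitor lets EvalCallback report accurate episode statistics):
import gymnasium as gym
from stable_baselines3.common.monitor import Monitor

# A separate, Monitor-wrapped environment used only for evaluation
eval_env = Monitor(gym.make("CartPole-v1"))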
Key Features:
- Automatically saves the best model based on mean evaluation reward
- Logs evaluation metrics to TensorBoard
- Can stop training when a reward threshold is reached (via callback_on_new_best, see below)
Important: eval_freq counts calls to _on_step(), and each call advances n_envs environment steps. When using vectorized training environments, adjust it accordingly:
# With 4 parallel environments, divide eval_freq by n_envs
eval_freq = 10000 // 4 # Evaluate every 10000 total environment steps
CheckpointCallback
Saves model checkpoints at regular intervals.
from stable_baselines3.common.callbacks import CheckpointCallback
checkpoint_callback = CheckpointCallback(
save_freq=10000, # Save every N steps
save_path="./logs/checkpoints/", # Directory for checkpoints
name_prefix="rl_model", # Prefix for checkpoint files
save_replay_buffer=True, # Save replay buffer (off-policy only)
save_vecnormalize=True, # Save VecNormalize stats
verbose=2,
)
model.learn(total_timesteps=100000, callback=checkpoint_callback)
Output Files:
- rl_model_10000_steps.zip - Model at 10k steps
- rl_model_20000_steps.zip - Model at 20k steps
- etc.
Important: Adjust save_freq for vectorized environments (divide by n_envs).
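A saved checkpoint can later be reloaded to evaluate or resume training. A minimal sketch, assuming a PPO model and an existing env (the .zip extension is optional when loading):
from stable_baselines3 import PPO

# Load the checkpoint and continue training without resetting the step counter
model = PPO.load("./logs/checkpoints/rl_model_10000_steps", env=env)
model.learn(total_timesteps=50000, reset_num_timesteps=False)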
StopTrainingOnRewardThreshold
Stops training when the mean evaluation reward reaches a threshold.
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold
stop_callback = StopTrainingOnRewardThreshold(
reward_threshold=200, # Stop when mean reward >= 200
verbose=1,
)
# Must be used with EvalCallback
eval_callback = EvalCallback(
eval_env,
callback_on_new_best=stop_callback, # Trigger when new best found
eval_freq=10000,
n_eval_episodes=5,
)
model.learn(total_timesteps=1000000, callback=eval_callback)
StopTrainingOnNoModelImprovement
Stops training if the model doesn't improve for N consecutive evaluations.
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement
stop_callback = StopTrainingOnNoModelImprovement(
max_no_improvement_evals=10, # Stop after 10 evals with no improvement
min_evals=20, # Minimum evaluations before stopping
verbose=1,
)
# Use with EvalCallback
eval_callback = EvalCallback(
eval_env,
callback_after_eval=stop_callback,
eval_freq=10000,
)
model.learn(total_timesteps=1000000, callback=eval_callback)
StopTrainingOnMaxEpisodes
Stops training after a maximum number of episodes.
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
stop_callback = StopTrainingOnMaxEpisodes(
max_episodes=1000, # Stop after 1000 episodes
verbose=1,
)
model.learn(total_timesteps=1000000, callback=stop_callback)
ProgressBarCallback
Displays a progress bar during training (requires tqdm).
from stable_baselines3.common.callbacks import ProgressBarCallback
progress_callback = ProgressBarCallback()
model.learn(total_timesteps=100000, callback=progress_callback)
Output:
100%|██████████| 100000/100000 [05:23<00:00, 309.31it/s]
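Recent SB3 versions also expose this as a flag on learn(), which sets up ProgressBarCallback internally:
model.learn(total_timesteps=100000, progress_bar=True)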
Creating Custom Callbacks
BaseCallback Structure
from stable_baselines3.common.callbacks import BaseCallback
class CustomCallback(BaseCallback):
"""
Custom callback template.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
# Custom initialization
def _init_callback(self) -> None:
"""
Called once when training starts.
Useful for initialization that requires access to model/env.
"""
pass
def _on_training_start(self) -> None:
"""
Called before the first rollout starts.
"""
pass
def _on_rollout_start(self) -> None:
"""
Called before collecting new samples (on-policy algorithms).
"""
pass
    def _on_step(self) -> bool:
        """
        Called after each environment step, i.e. once per env.step() call
        (every n_envs steps when using vectorized environments).
        Returns:
            bool: If False, training will be stopped.
        """
        return True  # Continue training
def _on_rollout_end(self) -> None:
"""
Called after rollout ends (on-policy algorithms).
"""
pass
def _on_training_end(self) -> None:
"""
Called at the end of training.
"""
pass
Useful Attributes
Inside callbacks, you have access to:
- self.model: The RL algorithm instance
- self.training_env: The training environment (a VecEnv)
- self.n_calls: Number of times _on_step() was called
- self.num_timesteps: Total number of environment steps so far
- self.locals: Local variables from the algorithm (contents vary by algorithm)
- self.globals: Global variables from the algorithm
- self.logger: Logger for TensorBoard/CSV logging
- self.parent: Parent callback (set when the callback is triggered by an EventCallback)
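A small sketch that exercises several of these attributes (purely illustrative):
class AttributeDemoCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Print a status line every 1000 calls to _on_step()
        if self.n_calls % 1000 == 0:
            print(f"calls={self.n_calls}, timesteps={self.num_timesteps}, "
                  f"n_envs={self.training_env.num_envs}")
        return True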
Custom Callback Examples
Example 1: Log Custom Metrics
import numpy as np

class LogCustomMetricsCallback(BaseCallback):
    """
    Log custom metrics to TensorBoard.
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
    def _on_step(self) -> bool:
        # Check if the first environment's episode ended
        if self.locals["dones"][0]:
            # Episode statistics are provided by the Monitor wrapper
            episode_reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.episode_rewards.append(episode_reward)
            # Log to TensorBoard
            self.logger.record("custom/episode_reward", episode_reward)
            self.logger.record("custom/mean_reward_last_100",
                               np.mean(self.episode_rewards[-100:]))
        return True
Example 2: Adjust Learning Rate
class LinearScheduleCallback(BaseCallback):
"""
Linearly decrease learning rate during training.
"""
def __init__(self, initial_lr=3e-4, final_lr=3e-5, verbose=0):
super().__init__(verbose)
self.initial_lr = initial_lr
self.final_lr = final_lr
def _on_step(self) -> bool:
# Calculate progress (0 to 1)
progress = self.num_timesteps / self.locals["total_timesteps"]
# Linear interpolation
new_lr = self.initial_lr + (self.final_lr - self.initial_lr) * progress
# Update learning rate
for param_group in self.model.policy.optimizer.param_groups:
param_group["lr"] = new_lr
# Log learning rate
self.logger.record("train/learning_rate", new_lr)
return True
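For simple schedules, a callback is not strictly needed: SB3 accepts a callable as learning_rate, which it calls with progress_remaining (going from 1 at the start of training to 0 at the end). A sketch, assuming PPO and an existing env:
from stable_baselines3 import PPO

def linear_schedule(initial_lr: float):
    def schedule(progress_remaining: float) -> float:
        # progress_remaining decreases from 1 to 0 over training
        return initial_lr * progress_remaining
    return schedule

model = PPO("MlpPolicy", env, learning_rate=linear_schedule(3e-4))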
Example 3: Early Stopping on Moving Average
import numpy as np

class EarlyStoppingCallback(BaseCallback):
    """
    Stop training once the moving average of episode rewards reaches a threshold.
    """
    def __init__(self, check_freq=10000, min_reward=200, window=100, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.min_reward = min_reward
        self.window = window
        self.rewards = []
    def _on_step(self) -> bool:
        # Collect episode rewards (requires Monitor-wrapped envs)
        if self.locals["dones"][0]:
            reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.rewards.append(reward)
        # Check every check_freq steps
        if self.n_calls % self.check_freq == 0 and len(self.rewards) >= self.window:
            mean_reward = np.mean(self.rewards[-self.window:])
            if self.verbose > 0:
                print(f"Mean reward: {mean_reward:.2f}")
            if mean_reward >= self.min_reward:
                if self.verbose > 0:
                    print("Stopping: reward threshold reached!")
                return False  # Stop training
        return True  # Continue training
Example 4: Save Best Model by Custom Metric
import os
import numpy as np

class SaveBestModelCallback(BaseCallback):
    """
    Save the model whenever a custom metric reaches a new best value.
    """
    def __init__(self, check_freq=1000, save_path="./best_model/", verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_score = -np.inf
    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)
    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Custom metric (example: mean reward over recent training episodes;
            # requires Monitor-wrapped envs so ep_info_buffer is populated)
            ep_infos = self.model.ep_info_buffer
            custom_metric = np.mean([ep["r"] for ep in ep_infos]) if ep_infos else -np.inf
            if custom_metric > self.best_score:
                self.best_score = custom_metric
                if self.verbose > 0:
                    print(f"New best! Saving model to {self.save_path}")
                self.model.save(os.path.join(self.save_path, "best_model"))
        return True
Example 5: Log Environment-Specific Information
class EnvironmentInfoCallback(BaseCallback):
    """
    Log custom info from the environment's info dict.
    """
    def _on_step(self) -> bool:
        # Info dict of the first environment
        info = self.locals["infos"][0]
        # record() keeps only the latest value per key until the logger dumps,
        # so frequent calls are cheap but intermediate values are overwritten
        if "distance_to_goal" in info:
            self.logger.record("env/distance_to_goal", info["distance_to_goal"])
        if "success" in info:
            self.logger.record("env/success_rate", info["success"])
        return True
Chaining Multiple Callbacks
Use CallbackList to combine multiple callbacks:
from stable_baselines3.common.callbacks import CallbackList
callback_list = CallbackList([
eval_callback,
checkpoint_callback,
progress_callback,
custom_callback,
])
model.learn(total_timesteps=100000, callback=callback_list)
Or pass a list directly:
model.learn(
total_timesteps=100000,
callback=[eval_callback, checkpoint_callback, custom_callback]
)
Event-Based Callbacks
Callbacks can trigger other callbacks on specific events:
from stable_baselines3.common.callbacks import EventCallback
# Stop training when reward threshold reached
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200)
# Evaluate periodically and trigger stop_callback when new best found
eval_callback = EvalCallback(
eval_env,
callback_on_new_best=stop_callback, # Triggered when new best model
eval_freq=10000,
)
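Another event-based helper is EveryNTimesteps, which wraps a child callback and triggers it on a fixed timestep schedule regardless of the number of parallel environments:
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps

# Fire the checkpoint callback every 5000 environment timesteps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/checkpoints/")
event_callback = EveryNTimesteps(n_steps=5000, callback=checkpoint_on_event)

model.learn(total_timesteps=100000, callback=event_callback)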
Logging to TensorBoard
Use self.logger.record() to log metrics:
class TensorBoardCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Log a scalar ("value" stands in for a metric you compute)
        self.logger.record("custom/my_metric", value)
        # Log multiple metrics
        self.logger.record("custom/metric1", value1)
        self.logger.record("custom/metric2", value2)
        # Recorded values are written to TensorBoard when the logger
        # is dumped (once per logging interval)
        return True
View in TensorBoard:
tensorboard --logdir ./logs/
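For the recorded values to reach TensorBoard, the model must be created with a log directory. A minimal sketch, assuming PPO and an existing env:
from stable_baselines3 import PPO

model = PPO("MlpPolicy", env, tensorboard_log="./logs/")
model.learn(total_timesteps=100000, callback=TensorBoardCallback())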
Advanced Patterns
Curriculum Learning
class CurriculumCallback(BaseCallback):
    """
    Increase task difficulty over time.
    """
    def __init__(self, difficulty_schedule, verbose=0):
        super().__init__(verbose)
        # List of (progress_threshold, difficulty) pairs, sorted by threshold
        self.difficulty_schedule = difficulty_schedule
        self.current_difficulty = None
    def _on_step(self) -> bool:
        # Training progress from 0 to 1
        progress = self.num_timesteps / self.locals["total_timesteps"]
        # Pick the difficulty of the highest threshold reached so far
        difficulty = self.current_difficulty
        for threshold, level in self.difficulty_schedule:
            if progress >= threshold:
                difficulty = level
        if difficulty != self.current_difficulty:
            # env_method calls the named method on every sub-environment;
            # assumes the environment implements set_difficulty()
            self.training_env.env_method("set_difficulty", difficulty)
            self.current_difficulty = difficulty
        return True
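A possible usage, where the schedule maps progress thresholds to difficulty levels (set_difficulty is assumed to be implemented by the environment):
# Difficulty 1 from the start, 2 after 50% of training, 3 after 80%
curriculum_callback = CurriculumCallback(
    difficulty_schedule=[(0.0, 1), (0.5, 2), (0.8, 3)],
)
model.learn(total_timesteps=100000, callback=curriculum_callback)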
Population-Based Training
import numpy as np

class PopulationBasedCallback(BaseCallback):
    """
    Adjust hyperparameters based on performance.
    """
    def __init__(self, check_freq=10000, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.performance_history = []
    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Evaluate performance
            perf = self._evaluate_performance()
            self.performance_history.append(perf)
            # Adjust hyperparameters if performance plateaus
            if len(self.performance_history) >= 3:
                recent = self.performance_history[-3:]
                if max(recent) - min(recent) < 0.01:  # Plateau detected
                    self._adjust_hyperparameters()
        return True
    def _evaluate_performance(self):
        # Mean reward over recent training episodes
        # (requires Monitor-wrapped environments)
        ep_infos = self.model.ep_info_buffer
        return float(np.mean([ep["r"] for ep in ep_infos])) if ep_infos else 0.0
    def _adjust_hyperparameters(self):
        # Example: increase the learning rate by 20%
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] *= 1.2
Debugging Tips
Print Available Attributes
class DebugCallback(BaseCallback):
def _on_step(self) -> bool:
if self.n_calls == 1:
print("Available in self.locals:")
for key in self.locals.keys():
print(f" {key}: {type(self.locals[key])}")
return True
Common Issues
- Callback not being called:
  - Ensure the callback is passed to model.learn()
  - Check that _on_step() returns True
- AttributeError in callback:
  - Not all attributes are available in all callbacks
  - Use self.locals.get("key", default) for safety
- Memory leaks:
  - Don't store large arrays in callback state
  - Clear buffers periodically
- Performance impact:
  - Minimize computation in _on_step() (it is called every step)
  - Use a check_freq counter to limit expensive operations
Best Practices
- Use appropriate callback timing:
  - _on_step(): for metrics that change every step
  - _on_rollout_end(): for metrics computed over rollouts
  - _init_callback(): for one-time initialization
- Log efficiently:
  - Don't log every step (it hurts performance)
  - Aggregate metrics and log periodically
- Handle vectorized environments:
  - Remember that dones, infos, etc. are arrays
  - Check dones[i] for each environment
- Test callbacks independently:
  - Create simple test cases
  - Verify callback behavior before long training runs
- Document custom callbacks:
  - Write clear docstrings
  - Include example usage in comments
Additional Resources
- Official SB3 Callbacks Guide: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
- Callback API Reference: https://stable-baselines3.readthedocs.io/en/master/common/callbacks.html
- TensorBoard Documentation: https://www.tensorflow.org/tensorboard