Initial commit

skills/stable-baselines3/references/algorithms.md (new file)
@@ -0,0 +1,333 @@

# Stable Baselines3 Algorithm Reference
|
||||
|
||||
This document provides detailed characteristics of all RL algorithms in Stable Baselines3 to help select the right algorithm for specific tasks.
|
||||
|
||||
## Algorithm Comparison Table

| Algorithm | Type | Action Space | Sample Efficiency | Training Speed | Use Case |
|-----------|------|--------------|-------------------|----------------|----------|
| **PPO** | On-Policy | All | Medium | Fast | General-purpose, stable |
| **A2C** | On-Policy | All | Low | Very Fast | Quick prototyping, multiprocessing |
| **SAC** | Off-Policy | Continuous | High | Medium | Continuous control, sample-efficient |
| **TD3** | Off-Policy | Continuous | High | Medium | Continuous control, deterministic |
| **DDPG** | Off-Policy | Continuous | High | Medium | Continuous control (use TD3 instead) |
| **DQN** | Off-Policy | Discrete | Medium | Medium | Discrete actions, Atari games |
| **HER** | Off-Policy | All | Very High | Medium | Goal-conditioned tasks |
| **RecurrentPPO** | On-Policy | All | Medium | Slow | Partial observability (POMDP) |

## Detailed Algorithm Characteristics
|
||||
|
||||
### PPO (Proximal Policy Optimization)
|
||||
|
||||
**Overview:** General-purpose on-policy algorithm with good performance across many tasks.
|
||||
|
||||
**Strengths:**
|
||||
- Stable and reliable training
|
||||
- Works with all action space types (Discrete, Box, MultiDiscrete, MultiBinary)
|
||||
- Good balance between sample efficiency and training speed
|
||||
- Excellent for multiprocessing with vectorized environments
|
||||
- Easy to tune
|
||||
|
||||
**Weaknesses:**
|
||||
- Less sample-efficient than off-policy methods
|
||||
- Requires many environment interactions
|
||||
|
||||
**Best For:**
|
||||
- General-purpose RL tasks
|
||||
- When stability is important
|
||||
- When you have cheap environment simulations
|
||||
- Tasks with continuous or discrete actions
|
||||
|
||||
**Hyperparameter Guidance:**
|
||||
- `n_steps`: 2048-4096 for continuous, 128-256 for Atari
|
||||
- `learning_rate`: 3e-4 is a good default
|
||||
- `n_epochs`: 10 for continuous, 4 for Atari
|
||||
- `batch_size`: 64
|
||||
- `gamma`: 0.99 (0.995-0.999 for long episodes)
|
||||
|
||||
### A2C (Advantage Actor-Critic)
|
||||
|
||||
**Overview:** Synchronous variant of A3C, simpler than PPO but less stable.
|
||||
|
||||
**Strengths:**
|
||||
- Very fast training (simpler than PPO)
|
||||
- Works with all action space types
|
||||
- Good for quick prototyping
|
||||
- Memory efficient
|
||||
|
||||
**Weaknesses:**
|
||||
- Less stable than PPO
|
||||
- Requires careful hyperparameter tuning
|
||||
- Lower sample efficiency
|
||||
|
||||
**Best For:**
|
||||
- Quick experimentation
|
||||
- When training speed is critical
|
||||
- Simple environments
|
||||
|
||||
**Hyperparameter Guidance:**
|
||||
- `n_steps`: 5-256 depending on task
|
||||
- `learning_rate`: 7e-4
|
||||
- `gamma`: 0.99
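
The values above are close to A2C's library defaults; a minimal sketch (assuming a standard Gym environment id) is:

```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env("CartPole-v1", n_envs=8)
model = A2C("MlpPolicy", env, n_steps=5, learning_rate=7e-4, gamma=0.99, verbose=1)
model.learn(total_timesteps=100_000)
```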
|
||||
|
||||
### SAC (Soft Actor-Critic)
|
||||
|
||||
**Overview:** Off-policy algorithm with entropy regularization, state-of-the-art for continuous control.
|
||||
|
||||
**Strengths:**
|
||||
- Excellent sample efficiency
|
||||
- Very stable training
|
||||
- Automatic entropy tuning
|
||||
- Good exploration through stochastic policy
|
||||
- State-of-the-art for robotics
|
||||
|
||||
**Weaknesses:**
|
||||
- Only supports continuous action spaces (Box)
|
||||
- Slower wall-clock time than on-policy methods
|
||||
- More complex hyperparameters
|
||||
|
||||
**Best For:**
|
||||
- Continuous control (robotics, physics simulations)
|
||||
- When sample efficiency is critical
|
||||
- Expensive environment simulations
|
||||
- Tasks requiring good exploration
|
||||
|
||||
**Hyperparameter Guidance:**
|
||||
- `learning_rate`: 3e-4
|
||||
- `buffer_size`: 1M for most tasks
|
||||
- `learning_starts`: 10000
|
||||
- `batch_size`: 256
|
||||
- `tau`: 0.005 (target network update rate)
|
||||
- `train_freq`: 1 with `gradient_steps=-1` for best performance
|
||||
|
||||
### TD3 (Twin Delayed DDPG)
|
||||
|
||||
**Overview:** Improved DDPG with double Q-learning and delayed policy updates.
|
||||
|
||||
**Strengths:**
|
||||
- High sample efficiency
|
||||
- Deterministic policy (good for deployment)
|
||||
- More stable than DDPG
|
||||
- Good for continuous control
|
||||
|
||||
**Weaknesses:**
|
||||
- Only supports continuous action spaces (Box)
|
||||
- Less exploration than SAC
|
||||
- Requires careful tuning
|
||||
|
||||
**Best For:**
|
||||
- Continuous control tasks
|
||||
- When deterministic policies are preferred
|
||||
- Sample-efficient learning
|
||||
|
||||
**Hyperparameter Guidance:**
|
||||
- `learning_rate`: 1e-3
|
||||
- `buffer_size`: 1M
|
||||
- `learning_starts`: 10000
|
||||
- `batch_size`: 100
|
||||
- `policy_delay`: 2 (update policy every 2 critic updates)
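
A minimal TD3 sketch using these values; the Gaussian action noise is an assumption (a common default, not part of the guidance above), and `env` is assumed to be an existing continuous-action environment:

```python
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# TD3 has a deterministic policy, so exploration noise is added explicitly
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    learning_rate=1e-3,
    buffer_size=1_000_000,
    learning_starts=10000,
    batch_size=100,
    policy_delay=2,
    verbose=1,
)
```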
|
||||
|
||||
### DDPG (Deep Deterministic Policy Gradient)
|
||||
|
||||
**Overview:** Early off-policy continuous control algorithm.
|
||||
|
||||
**Strengths:**
|
||||
- Continuous action space support
|
||||
- Off-policy learning
|
||||
|
||||
**Weaknesses:**
|
||||
- Less stable than TD3 or SAC
|
||||
- Sensitive to hyperparameters
|
||||
- Generally outperformed by TD3
|
||||
|
||||
**Best For:**
|
||||
- Legacy compatibility
|
||||
- **Recommendation:** Use TD3 instead for new projects
|
||||
|
||||
### DQN (Deep Q-Network)
|
||||
|
||||
**Overview:** Classic off-policy algorithm for discrete action spaces.
|
||||
|
||||
**Strengths:**
|
||||
- Sample-efficient for discrete actions
|
||||
- Experience replay enables reuse of past data
|
||||
- Proven success on Atari games
|
||||
|
||||
**Weaknesses:**
|
||||
- Only supports discrete action spaces
|
||||
- Can be unstable without proper tuning
|
||||
- Overestimation bias
|
||||
|
||||
**Best For:**
|
||||
- Discrete action tasks
|
||||
- Atari games and similar environments
|
||||
- When sample efficiency matters
|
||||
|
||||
**Hyperparameter Guidance:**
|
||||
- `learning_rate`: 1e-4
|
||||
- `buffer_size`: 100K-1M depending on task
|
||||
- `learning_starts`: 50000 for Atari
|
||||
- `batch_size`: 32
|
||||
- `exploration_fraction`: 0.1
|
||||
- `exploration_final_eps`: 0.05
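
A minimal sketch wiring these values into a DQN model; the policy choice and `env` (an Atari-style image environment) are assumptions for illustration:

```python
from stable_baselines3 import DQN

model = DQN(
    "CnnPolicy",  # use "MlpPolicy" for vector observations
    env,
    learning_rate=1e-4,
    buffer_size=100_000,
    learning_starts=50_000,
    batch_size=32,
    exploration_fraction=0.1,
    exploration_final_eps=0.05,
    verbose=1,
)
model.learn(total_timesteps=1_000_000)
```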
|
||||
|
||||
**Variants:**
|
||||
- **QR-DQN**: Distributional RL version for better value estimates
|
||||
- **Maskable DQN**: For environments with action masking
|
||||
|
||||
### HER (Hindsight Experience Replay)
|
||||
|
||||
**Overview:** Not a standalone algorithm but a replay buffer strategy for goal-conditioned tasks.
|
||||
|
||||
**Strengths:**
|
||||
- Dramatically improves learning in sparse reward settings
|
||||
- Learns from failures by relabeling goals
|
||||
- Works with any off-policy algorithm (SAC, TD3, DQN)
|
||||
|
||||
**Weaknesses:**
|
||||
- Only for goal-conditioned environments
|
||||
- Requires specific observation structure (Dict with "observation", "achieved_goal", "desired_goal")
|
||||
|
||||
**Best For:**
|
||||
- Goal-conditioned tasks (robotics manipulation, navigation)
|
||||
- Sparse reward environments
|
||||
- Tasks where goal is clear but reward is binary
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from stable_baselines3 import SAC, HerReplayBuffer
|
||||
|
||||
model = SAC(
|
||||
"MultiInputPolicy",
|
||||
env,
|
||||
replay_buffer_class=HerReplayBuffer,
|
||||
replay_buffer_kwargs=dict(
|
||||
n_sampled_goal=4,
|
||||
goal_selection_strategy="future", # or "episode", "final"
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### RecurrentPPO
|
||||
|
||||
**Overview:** PPO with LSTM policy for handling partial observability.
|
||||
|
||||
**Strengths:**
|
||||
- Handles partial observability (POMDP)
|
||||
- Can learn temporal dependencies
|
||||
- Good for tasks that require memory
|
||||
|
||||
**Weaknesses:**
|
||||
- Slower training than standard PPO
|
||||
- More complex to tune
|
||||
- Requires sequential data
|
||||
|
||||
**Best For:**
|
||||
- Partially observable environments
|
||||
- Tasks requiring memory (e.g., navigation without full map)
|
||||
- Time-series problems
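
RecurrentPPO ships with SB3-Contrib rather than the core package; a minimal sketch, assuming `sb3_contrib` is installed and `env` already exists:

```python
from sb3_contrib import RecurrentPPO

# "MlpLstmPolicy" combines an MLP feature extractor with an LSTM for memory
model = RecurrentPPO("MlpLstmPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)
```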
|
||||
|
||||
## Algorithm Selection Guide
|
||||
|
||||
### Decision Tree
|
||||
|
||||
1. **What is your action space?**
|
||||
- **Continuous (Box)** → Consider PPO, SAC, or TD3
|
||||
- **Discrete** → Consider PPO, A2C, or DQN
|
||||
- **MultiDiscrete/MultiBinary** → Use PPO or A2C
|
||||
|
||||
2. **Is sample efficiency critical?**
|
||||
- **Yes (expensive simulations)** → Use off-policy: SAC, TD3, DQN, or HER
|
||||
- **No (cheap simulations)** → Use on-policy: PPO, A2C
|
||||
|
||||
3. **Do you need fast wall-clock training?**
|
||||
- **Yes** → Use PPO or A2C with vectorized environments
|
||||
- **No** → Any algorithm works
|
||||
|
||||
4. **Is the task goal-conditioned with sparse rewards?**
|
||||
- **Yes** → Use HER with SAC or TD3
|
||||
- **No** → Continue with standard algorithms
|
||||
|
||||
5. **Is the environment partially observable?**
|
||||
- **Yes** → Use RecurrentPPO
|
||||
- **No** → Use standard algorithms
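
The decision tree can be condensed into a small helper; this is a hypothetical illustration (not part of SB3), and the string arguments are assumptions:

```python
def suggest_algorithm(action_space: str, expensive_sim: bool,
                      goal_conditioned: bool = False,
                      partially_observable: bool = False) -> str:
    """Rough mapping from task properties to a starting algorithm."""
    if partially_observable:
        return "RecurrentPPO"
    if goal_conditioned:
        return "SAC + HER" if action_space == "continuous" else "DQN + HER"
    if action_space == "continuous":
        return "SAC" if expensive_sim else "PPO"
    # discrete, multi-discrete, or multi-binary actions
    return "DQN" if expensive_sim else "PPO"
```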
|
||||
|
||||
### Quick Recommendations
|
||||
|
||||
- **Starting out / General tasks:** PPO
|
||||
- **Continuous control / Robotics:** SAC
|
||||
- **Discrete actions / Atari:** DQN or PPO
|
||||
- **Goal-conditioned / Sparse rewards:** SAC + HER
|
||||
- **Fast prototyping:** A2C
|
||||
- **Sample efficiency critical:** SAC, TD3, or DQN
|
||||
- **Partial observability:** RecurrentPPO
|
||||
|
||||
## Training Configuration Tips
|
||||
|
||||
### For On-Policy Algorithms (PPO, A2C)
|
||||
|
||||
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Use vectorized environments for speed
env = make_vec_env(env_id, n_envs=8, vec_env_cls=SubprocVecEnv)

model = PPO(
    "MlpPolicy",
    env,
    n_steps=2048,  # Collect this many steps per environment before update
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    gamma=0.99,
)
```
|
||||
|
||||
### For Off-Policy Algorithms (SAC, TD3, DQN)
|
||||
|
||||
```python
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env

# Fewer environments, but use gradient_steps=-1 for efficiency
env = make_vec_env(env_id, n_envs=4)

model = SAC(
    "MlpPolicy",
    env,
    buffer_size=1_000_000,
    learning_starts=10000,
    batch_size=256,
    train_freq=1,
    gradient_steps=-1,  # Do 1 gradient step per env step (4 with 4 envs)
    learning_rate=3e-4,
)
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
1. **Using DQN with continuous actions** - DQN only works with discrete actions
|
||||
2. **Not using vectorized environments with PPO/A2C** - Wastes potential speedup
|
||||
3. **Using too few environments** - On-policy methods need many samples
|
||||
4. **Using too large replay buffer** - Can cause memory issues
|
||||
5. **Not tuning learning rate** - Critical for stable training
|
||||
6. **Ignoring reward scaling** - Normalize rewards for better learning
|
||||
7. **Wrong policy type** - Use "CnnPolicy" for images, "MultiInputPolicy" for dict observations
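
For pitfall 7, a quick sketch of matching the policy string to the observation type; the environment variables here are hypothetical placeholders:

```python
from stable_baselines3 import PPO

model = PPO("MlpPolicy", vector_env)        # flat Box vector observations
model = PPO("CnnPolicy", image_env)         # image observations (uint8, channel-first)
model = PPO("MultiInputPolicy", dict_env)   # Dict observations
```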
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Approximate expected performance (mean reward) on common benchmarks:
|
||||
|
||||
### Continuous Control (MuJoCo)
|
||||
- **HalfCheetah-v3**: PPO ~1800, SAC ~12000, TD3 ~9500
|
||||
- **Hopper-v3**: PPO ~2500, SAC ~3600, TD3 ~3600
|
||||
- **Walker2d-v3**: PPO ~3000, SAC ~5500, TD3 ~5000
|
||||
|
||||
### Discrete Control (Atari)
|
||||
- **Breakout**: PPO ~400, DQN ~300
|
||||
- **Pong**: PPO ~20, DQN ~20
|
||||
- **Space Invaders**: PPO ~1000, DQN ~800
|
||||
|
||||
*Note: Performance varies significantly with hyperparameters and training time.*
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **RL Baselines3 Zoo**: Collection of pre-trained agents and hyperparameters: https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
- **Hyperparameter Tuning**: Use Optuna for systematic tuning
|
||||
- **Custom Policies**: Extend base policies for custom network architectures
|
||||
- **Contribution Repo**: SB3-Contrib for experimental algorithms (QR-DQN, TQC, etc.)
|
||||
skills/stable-baselines3/references/callbacks.md (new file)
@@ -0,0 +1,556 @@

# Stable Baselines3 Callback System
|
||||
|
||||
This document provides comprehensive information about the callback system in Stable Baselines3 for monitoring and controlling training.
|
||||
|
||||
## Overview
|
||||
|
||||
Callbacks are functions called at specific points during training to:
|
||||
- Monitor training metrics
|
||||
- Save checkpoints
|
||||
- Implement early stopping
|
||||
- Log custom metrics
|
||||
- Adjust hyperparameters dynamically
|
||||
- Trigger evaluations
|
||||
|
||||
## Built-in Callbacks
|
||||
|
||||
### EvalCallback
|
||||
|
||||
Evaluates the agent periodically and saves the best model.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
|
||||
eval_callback = EvalCallback(
|
||||
eval_env, # Separate evaluation environment
|
||||
best_model_save_path="./logs/best_model/", # Where to save best model
|
||||
log_path="./logs/eval/", # Where to save evaluation logs
|
||||
eval_freq=10000, # Evaluate every N steps
|
||||
n_eval_episodes=5, # Number of episodes per evaluation
|
||||
deterministic=True, # Use deterministic actions
|
||||
render=False, # Render during evaluation
|
||||
verbose=1,
|
||||
warn=True,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=100000, callback=eval_callback)
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Automatically saves best model based on mean reward
|
||||
- Logs evaluation metrics to TensorBoard
|
||||
- Can stop training if reward threshold reached
|
||||
|
||||
**Important:** When using vectorized training environments, adjust `eval_freq`:
|
||||
```python
|
||||
# With 4 parallel environments, divide eval_freq by n_envs
|
||||
eval_freq = 10000 // 4 # Evaluate every 10000 total environment steps
|
||||
```
|
||||
|
||||
### CheckpointCallback
|
||||
|
||||
Saves model checkpoints at regular intervals.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=10000, # Save every N steps
|
||||
save_path="./logs/checkpoints/", # Directory for checkpoints
|
||||
name_prefix="rl_model", # Prefix for checkpoint files
|
||||
save_replay_buffer=True, # Save replay buffer (off-policy only)
|
||||
save_vecnormalize=True, # Save VecNormalize stats
|
||||
verbose=2,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=100000, callback=checkpoint_callback)
|
||||
```
|
||||
|
||||
**Output Files:**
|
||||
- `rl_model_10000_steps.zip` - Model at 10k steps
|
||||
- `rl_model_20000_steps.zip` - Model at 20k steps
|
||||
- etc.
|
||||
|
||||
**Important:** Adjust `save_freq` for vectorized environments (divide by n_envs).
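
A minimal sketch of that adjustment (4 parallel environments assumed):

```python
n_envs = 4
checkpoint_callback = CheckpointCallback(
    save_freq=max(10000 // n_envs, 1),  # save_freq is counted in per-environment steps
    save_path="./logs/checkpoints/",
)
```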
|
||||
|
||||
### StopTrainingOnRewardThreshold
|
||||
|
||||
Stops training when mean reward exceeds a threshold.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold
|
||||
|
||||
stop_callback = StopTrainingOnRewardThreshold(
|
||||
reward_threshold=200, # Stop when mean reward >= 200
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
# Must be used with EvalCallback
|
||||
eval_callback = EvalCallback(
|
||||
eval_env,
|
||||
callback_on_new_best=stop_callback, # Trigger when new best found
|
||||
eval_freq=10000,
|
||||
n_eval_episodes=5,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=1000000, callback=eval_callback)
|
||||
```
|
||||
|
||||
### StopTrainingOnNoModelImprovement
|
||||
|
||||
Stops training if model doesn't improve for N evaluations.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement
|
||||
|
||||
stop_callback = StopTrainingOnNoModelImprovement(
|
||||
max_no_improvement_evals=10, # Stop after 10 evals with no improvement
|
||||
min_evals=20, # Minimum evaluations before stopping
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
# Use with EvalCallback
|
||||
eval_callback = EvalCallback(
|
||||
eval_env,
|
||||
callback_after_eval=stop_callback,
|
||||
eval_freq=10000,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=1000000, callback=eval_callback)
|
||||
```
|
||||
|
||||
### StopTrainingOnMaxEpisodes
|
||||
|
||||
Stops training after a maximum number of episodes.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
|
||||
|
||||
stop_callback = StopTrainingOnMaxEpisodes(
|
||||
max_episodes=1000, # Stop after 1000 episodes
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=1000000, callback=stop_callback)
|
||||
```
|
||||
|
||||
### ProgressBarCallback
|
||||
|
||||
Displays a progress bar during training (requires tqdm).
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||
|
||||
progress_callback = ProgressBarCallback()
|
||||
|
||||
model.learn(total_timesteps=100000, callback=progress_callback)
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```
|
||||
100%|██████████| 100000/100000 [05:23<00:00, 309.31it/s]
|
||||
```
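
In recent SB3 versions the same bar can usually be enabled directly through `learn()` (this also requires `tqdm` and `rich`):

```python
model.learn(total_timesteps=100000, progress_bar=True)
```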
|
||||
|
||||
## Creating Custom Callbacks
|
||||
|
||||
### BaseCallback Structure
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
class CustomCallback(BaseCallback):
|
||||
"""
|
||||
Custom callback template.
|
||||
"""
|
||||
|
||||
def __init__(self, verbose=0):
|
||||
super().__init__(verbose)
|
||||
# Custom initialization
|
||||
|
||||
def _init_callback(self) -> None:
|
||||
"""
|
||||
Called once when training starts.
|
||||
Useful for initialization that requires access to model/env.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
"""
|
||||
Called before the first rollout starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _on_rollout_start(self) -> None:
|
||||
"""
|
||||
Called before collecting new samples (on-policy algorithms).
|
||||
"""
|
||||
pass
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
"""
|
||||
Called after every step in the environment.
|
||||
|
||||
Returns:
|
||||
bool: If False, training will be stopped.
|
||||
"""
|
||||
return True # Continue training
|
||||
|
||||
def _on_rollout_end(self) -> None:
|
||||
"""
|
||||
Called after rollout ends (on-policy algorithms).
|
||||
"""
|
||||
pass
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
"""
|
||||
Called at the end of training.
|
||||
"""
|
||||
pass
|
||||
```
|
||||
|
||||
### Useful Attributes
|
||||
|
||||
Inside callbacks, you have access to:
|
||||
|
||||
- **`self.model`**: The RL algorithm instance
|
||||
- **`self.training_env`**: The training environment
|
||||
- **`self.n_calls`**: Number of times `_on_step()` was called
|
||||
- **`self.num_timesteps`**: Total number of environment steps
|
||||
- **`self.locals`**: Local variables from the algorithm (varies by algorithm)
|
||||
- **`self.globals`**: Global variables from the algorithm
|
||||
- **`self.logger`**: Logger for TensorBoard/CSV logging
|
||||
- **`self.parent`**: Parent callback (if used in CallbackList)
|
||||
|
||||
## Custom Callback Examples
|
||||
|
||||
### Example 1: Log Custom Metrics
|
||||
|
||||
```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class LogCustomMetricsCallback(BaseCallback):
    """
    Log custom metrics to TensorBoard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Check if episode ended
        if self.locals["dones"][0]:
            # Log episode reward
            episode_reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.episode_rewards.append(episode_reward)

            # Log to TensorBoard
            self.logger.record("custom/episode_reward", episode_reward)
            self.logger.record("custom/mean_reward_last_100",
                               np.mean(self.episode_rewards[-100:]))

        return True
```
|
||||
|
||||
### Example 2: Adjust Learning Rate
|
||||
|
||||
```python
|
||||
class LinearScheduleCallback(BaseCallback):
|
||||
"""
|
||||
Linearly decrease learning rate during training.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_lr=3e-4, final_lr=3e-5, verbose=0):
|
||||
super().__init__(verbose)
|
||||
self.initial_lr = initial_lr
|
||||
self.final_lr = final_lr
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Calculate progress (0 to 1)
|
||||
progress = self.num_timesteps / self.locals["total_timesteps"]
|
||||
|
||||
# Linear interpolation
|
||||
new_lr = self.initial_lr + (self.final_lr - self.initial_lr) * progress
|
||||
|
||||
# Update learning rate
|
||||
for param_group in self.model.policy.optimizer.param_groups:
|
||||
param_group["lr"] = new_lr
|
||||
|
||||
# Log learning rate
|
||||
self.logger.record("train/learning_rate", new_lr)
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
### Example 3: Early Stopping on Moving Average
|
||||
|
||||
```python
|
||||
class EarlyStoppingCallback(BaseCallback):
|
||||
"""
|
||||
Stop training if moving average of rewards doesn't improve.
|
||||
"""
|
||||
|
||||
def __init__(self, check_freq=10000, min_reward=200, window=100, verbose=0):
|
||||
super().__init__(verbose)
|
||||
self.check_freq = check_freq
|
||||
self.min_reward = min_reward
|
||||
self.window = window
|
||||
self.rewards = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Collect episode rewards
|
||||
if self.locals["dones"][0]:
|
||||
reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
|
||||
self.rewards.append(reward)
|
||||
|
||||
# Check every check_freq steps
|
||||
if self.n_calls % self.check_freq == 0 and len(self.rewards) >= self.window:
|
||||
mean_reward = np.mean(self.rewards[-self.window:])
|
||||
if self.verbose > 0:
|
||||
print(f"Mean reward: {mean_reward:.2f}")
|
||||
|
||||
if mean_reward >= self.min_reward:
|
||||
if self.verbose > 0:
|
||||
print(f"Stopping: reward threshold reached!")
|
||||
return False # Stop training
|
||||
|
||||
return True # Continue training
|
||||
```
|
||||
|
||||
### Example 4: Save Best Model by Custom Metric
|
||||
|
||||
```python
|
||||
class SaveBestModelCallback(BaseCallback):
|
||||
"""
|
||||
Save model when custom metric is best.
|
||||
"""
|
||||
|
||||
def __init__(self, check_freq=1000, save_path="./best_model/", verbose=0):
|
||||
super().__init__(verbose)
|
||||
self.check_freq = check_freq
|
||||
self.save_path = save_path
|
||||
self.best_score = -np.inf
|
||||
|
||||
def _init_callback(self) -> None:
|
||||
if self.save_path is not None:
|
||||
os.makedirs(self.save_path, exist_ok=True)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls % self.check_freq == 0:
|
||||
# Calculate custom metric (example: policy entropy)
|
||||
custom_metric = self.locals.get("entropy_losses", [0])[-1]
|
||||
|
||||
if custom_metric > self.best_score:
|
||||
self.best_score = custom_metric
|
||||
if self.verbose > 0:
|
||||
print(f"New best! Saving model to {self.save_path}")
|
||||
self.model.save(os.path.join(self.save_path, "best_model"))
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
### Example 5: Log Environment-Specific Information
|
||||
|
||||
```python
|
||||
class EnvironmentInfoCallback(BaseCallback):
|
||||
"""
|
||||
Log custom info from environment.
|
||||
"""
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Access info dict from environment
|
||||
info = self.locals["infos"][0]
|
||||
|
||||
# Log custom metrics from environment
|
||||
if "distance_to_goal" in info:
|
||||
self.logger.record("env/distance_to_goal", info["distance_to_goal"])
|
||||
|
||||
if "success" in info:
|
||||
self.logger.record("env/success_rate", info["success"])
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
## Chaining Multiple Callbacks
|
||||
|
||||
Use `CallbackList` to combine multiple callbacks:
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import CallbackList
|
||||
|
||||
callback_list = CallbackList([
|
||||
eval_callback,
|
||||
checkpoint_callback,
|
||||
progress_callback,
|
||||
custom_callback,
|
||||
])
|
||||
|
||||
model.learn(total_timesteps=100000, callback=callback_list)
|
||||
```
|
||||
|
||||
Or pass a list directly:
|
||||
|
||||
```python
|
||||
model.learn(
|
||||
total_timesteps=100000,
|
||||
callback=[eval_callback, checkpoint_callback, custom_callback]
|
||||
)
|
||||
```
|
||||
|
||||
## Event-Based Callbacks
|
||||
|
||||
Callbacks can trigger other callbacks on specific events:
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.callbacks import EventCallback
|
||||
|
||||
# Stop training when reward threshold reached
|
||||
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200)
|
||||
|
||||
# Evaluate periodically and trigger stop_callback when new best found
|
||||
eval_callback = EvalCallback(
|
||||
eval_env,
|
||||
callback_on_new_best=stop_callback, # Triggered when new best model
|
||||
eval_freq=10000,
|
||||
)
|
||||
```
|
||||
|
||||
## Logging to TensorBoard
|
||||
|
||||
Use `self.logger.record()` to log metrics:
|
||||
|
||||
```python
|
||||
class TensorBoardCallback(BaseCallback):
|
||||
def _on_step(self) -> bool:
|
||||
# Log scalar
|
||||
self.logger.record("custom/my_metric", value)
|
||||
|
||||
# Log multiple metrics
|
||||
self.logger.record("custom/metric1", value1)
|
||||
self.logger.record("custom/metric2", value2)
|
||||
|
||||
# Logger automatically writes to TensorBoard
|
||||
return True
|
||||
```
|
||||
|
||||
**View in TensorBoard:**
|
||||
```bash
|
||||
tensorboard --logdir ./logs/
|
||||
```
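
Note that the logger only writes TensorBoard event files if the model was created with a `tensorboard_log` directory; a minimal sketch, assuming PPO and an existing `env`:

```python
from stable_baselines3 import PPO

model = PPO("MlpPolicy", env, tensorboard_log="./logs/", verbose=1)
model.learn(total_timesteps=100_000, callback=TensorBoardCallback())
```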
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Curriculum Learning
|
||||
|
||||
```python
|
||||
class CurriculumCallback(BaseCallback):
|
||||
"""
|
||||
Increase task difficulty over time.
|
||||
"""
|
||||
|
||||
def __init__(self, difficulty_schedule, verbose=0):
|
||||
super().__init__(verbose)
|
||||
self.difficulty_schedule = difficulty_schedule
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Update environment difficulty based on progress
|
||||
progress = self.num_timesteps / self.locals["total_timesteps"]
|
||||
|
||||
for threshold, difficulty in self.difficulty_schedule:
|
||||
if progress >= threshold:
|
||||
self.training_env.env_method("set_difficulty", difficulty)
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
### Population-Based Training
|
||||
|
||||
```python
|
||||
class PopulationBasedCallback(BaseCallback):
|
||||
"""
|
||||
Adjust hyperparameters based on performance.
|
||||
"""
|
||||
|
||||
def __init__(self, check_freq=10000, verbose=0):
|
||||
super().__init__(verbose)
|
||||
self.check_freq = check_freq
|
||||
self.performance_history = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls % self.check_freq == 0:
|
||||
# Evaluate performance
|
||||
perf = self._evaluate_performance()
|
||||
self.performance_history.append(perf)
|
||||
|
||||
# Adjust hyperparameters if performance plateaus
|
||||
if len(self.performance_history) >= 3:
|
||||
recent = self.performance_history[-3:]
|
||||
if max(recent) - min(recent) < 0.01: # Plateau detected
|
||||
self._adjust_hyperparameters()
|
||||
|
||||
        return True

    def _evaluate_performance(self):
        # Minimal placeholder: mean episode reward from Monitor episode infos,
        # if the model exposes them; replace with a task-specific metric
        ep_info_buffer = getattr(self.model, "ep_info_buffer", None)
        if ep_info_buffer:
            return float(np.mean([ep_info["r"] for ep_info in ep_info_buffer]))
        return 0.0

    def _adjust_hyperparameters(self):
        # Example: increase learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] *= 1.2
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
### Print Available Attributes
|
||||
|
||||
```python
|
||||
class DebugCallback(BaseCallback):
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls == 1:
|
||||
print("Available in self.locals:")
|
||||
for key in self.locals.keys():
|
||||
print(f" {key}: {type(self.locals[key])}")
|
||||
return True
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Callback not being called:**
|
||||
- Ensure callback is passed to `model.learn()`
|
||||
- Check that `_on_step()` returns `True`
|
||||
|
||||
2. **AttributeError in callback:**
|
||||
- Not all attributes available in all callbacks
|
||||
- Use `self.locals.get("key", default)` for safety
|
||||
|
||||
3. **Memory leaks:**
|
||||
- Don't store large arrays in callback state
|
||||
- Clear buffers periodically
|
||||
|
||||
4. **Performance impact:**
|
||||
- Minimize computation in `_on_step()` (called every step)
|
||||
- Use `check_freq` to limit expensive operations
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use appropriate callback timing:**
|
||||
- `_on_step()`: For metrics that change every step
|
||||
- `_on_rollout_end()`: For metrics computed over rollouts
|
||||
- `_init_callback()`: For one-time initialization
|
||||
|
||||
2. **Log efficiently:**
|
||||
- Don't log every step (hurts performance)
|
||||
- Aggregate metrics and log periodically
|
||||
|
||||
3. **Handle vectorized environments:**
|
||||
- Remember that `dones`, `infos`, etc. are arrays
|
||||
- Check `dones[i]` for each environment (see the sketch after this list)
|
||||
|
||||
4. **Test callbacks independently:**
|
||||
- Create simple test cases
|
||||
- Verify callback behavior before long training runs
|
||||
|
||||
5. **Document custom callbacks:**
|
||||
- Clear docstrings
|
||||
- Example usage in comments
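
As referenced in best practice 3, a minimal sketch of handling every parallel environment (assuming the environments are wrapped in `Monitor`, so `infos` contains an `"episode"` entry when an episode ends):

```python
class PerEnvLoggingCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Iterate over all parallel environments, not just index 0
        for i, done in enumerate(self.locals["dones"]):
            if done:
                ep_info = self.locals["infos"][i].get("episode")
                if ep_info is not None:
                    self.logger.record("custom/episode_reward", ep_info["r"])
        return True
```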
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Official SB3 Callbacks Guide: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
|
||||
- Callback API Reference: https://stable-baselines3.readthedocs.io/en/master/common/callbacks.html
|
||||
- TensorBoard Documentation: https://www.tensorflow.org/tensorboard
|
||||
skills/stable-baselines3/references/custom_environments.md (new file)
@@ -0,0 +1,526 @@

# Creating Custom Environments for Stable Baselines3
|
||||
|
||||
This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.
|
||||
|
||||
## Environment Structure
|
||||
|
||||
### Required Methods
|
||||
|
||||
Every custom environment must inherit from `gymnasium.Env` and implement:
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
|
||||
class CustomEnv(gym.Env):
|
||||
def __init__(self):
|
||||
"""Initialize environment, define action_space and observation_space"""
|
||||
super().__init__()
|
||||
self.action_space = spaces.Discrete(4)
|
||||
self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""Reset environment to initial state"""
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
info = {}
|
||||
return observation, info
|
||||
|
||||
def step(self, action):
|
||||
"""Execute one timestep"""
|
||||
observation = self.observation_space.sample()
|
||||
reward = 0.0
|
||||
terminated = False # Episode ended naturally
|
||||
truncated = False # Episode ended due to time limit
|
||||
info = {}
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self):
|
||||
"""Visualize environment (optional)"""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
"""Cleanup resources (optional)"""
|
||||
pass
|
||||
```
|
||||
|
||||
### Method Details
|
||||
|
||||
#### `__init__(self, ...)`
|
||||
|
||||
**Purpose:** Initialize the environment and define spaces.
|
||||
|
||||
**Requirements:**
|
||||
- Must call `super().__init__()`
|
||||
- Must define `self.action_space`
|
||||
- Must define `self.observation_space`
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def __init__(self, grid_size=10, max_steps=100):
|
||||
super().__init__()
|
||||
self.grid_size = grid_size
|
||||
self.max_steps = max_steps
|
||||
self.current_step = 0
|
||||
|
||||
# Define spaces
|
||||
self.action_space = spaces.Discrete(4)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=grid_size-1, shape=(2,), dtype=np.float32
|
||||
)
|
||||
```
|
||||
|
||||
#### `reset(self, seed=None, options=None)`
|
||||
|
||||
**Purpose:** Reset the environment to an initial state.
|
||||
|
||||
**Requirements:**
|
||||
- Must call `super().reset(seed=seed)`
|
||||
- Must return `(observation, info)` tuple
|
||||
- Observation must match `observation_space`
|
||||
- Info must be a dictionary (can be empty)
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def reset(self, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
|
||||
# Initialize state
|
||||
self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
|
||||
self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
|
||||
self.current_step = 0
|
||||
|
||||
observation = self._get_observation()
|
||||
info = {"episode": "started"}
|
||||
|
||||
return observation, info
|
||||
```
|
||||
|
||||
#### `step(self, action)`
|
||||
|
||||
**Purpose:** Execute one timestep in the environment.
|
||||
|
||||
**Requirements:**
|
||||
- Must return 5-tuple: `(observation, reward, terminated, truncated, info)`
|
||||
- Action must be valid according to `action_space`
|
||||
- Observation must match `observation_space`
|
||||
- Reward should be a float
|
||||
- Terminated: True if episode ended naturally (goal reached, failure, etc.)
|
||||
- Truncated: True if episode ended due to time limit
|
||||
- Info must be a dictionary
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def step(self, action):
|
||||
# Apply action
|
||||
self.agent_pos += self._action_to_direction(action)
|
||||
self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
|
||||
self.current_step += 1
|
||||
|
||||
# Calculate reward
|
||||
distance = np.linalg.norm(self.agent_pos - self.goal_pos)
|
||||
goal_reached = distance < 1.0
|
||||
|
||||
if goal_reached:
|
||||
reward = 100.0
|
||||
else:
|
||||
reward = -distance * 0.1
|
||||
|
||||
# Check termination conditions
|
||||
terminated = goal_reached
|
||||
truncated = self.current_step >= self.max_steps
|
||||
|
||||
observation = self._get_observation()
|
||||
info = {"distance": distance, "steps": self.current_step}
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
```
|
||||
|
||||
## Space Types
|
||||
|
||||
### Discrete
|
||||
|
||||
For discrete actions (e.g., {0, 1, 2, 3}).
|
||||
|
||||
```python
|
||||
self.action_space = spaces.Discrete(4) # 4 actions: 0, 1, 2, 3
|
||||
```
|
||||
|
||||
**Important:** SB3 does NOT support `Discrete` spaces with `start != 0`. Always start from 0.
|
||||
|
||||
### Box (Continuous)
|
||||
|
||||
For continuous values within a range.
|
||||
|
||||
```python
|
||||
# 1D continuous action in [-1, 1]
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
|
||||
|
||||
# 2D position observation
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=10, shape=(2,), dtype=np.float32
|
||||
)
|
||||
|
||||
# 3D RGB image (channel-first format)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
|
||||
)
|
||||
```
|
||||
|
||||
**Important for Images:**
|
||||
- Must be `dtype=np.uint8` in range [0, 255]
|
||||
- Use **channel-first** format: (channels, height, width)
|
||||
- SB3 automatically normalizes by dividing by 255
|
||||
- Set `normalize_images=False` in policy_kwargs if pre-normalized
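
A minimal sketch of that last point, assuming an existing image-based `env` that already returns observations scaled to [0, 1]:

```python
from stable_baselines3 import PPO

model = PPO(
    "CnnPolicy",
    env,
    policy_kwargs=dict(normalize_images=False),  # skip the built-in /255 scaling
)
```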
|
||||
|
||||
### MultiDiscrete
|
||||
|
||||
For multiple discrete variables.
|
||||
|
||||
```python
|
||||
# Two discrete variables: first with 3 options, second with 4 options
|
||||
self.action_space = spaces.MultiDiscrete([3, 4])
|
||||
```
|
||||
|
||||
### MultiBinary
|
||||
|
||||
For binary vectors.
|
||||
|
||||
```python
|
||||
# 5 binary flags
|
||||
self.action_space = spaces.MultiBinary(5) # e.g., [0, 1, 1, 0, 1]
|
||||
```
|
||||
|
||||
### Dict
|
||||
|
||||
For dictionary observations (e.g., combining image with sensors).
|
||||
|
||||
```python
|
||||
self.observation_space = spaces.Dict({
|
||||
"image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
|
||||
"vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
|
||||
"discrete": spaces.Discrete(3),
|
||||
})
|
||||
```
|
||||
|
||||
**Important:** When using Dict observations, use `"MultiInputPolicy"` instead of `"MlpPolicy"`.
|
||||
|
||||
```python
|
||||
model = PPO("MultiInputPolicy", env, verbose=1)
|
||||
```
|
||||
|
||||
### Tuple
|
||||
|
||||
For tuple observations (less common).
|
||||
|
||||
```python
|
||||
self.observation_space = spaces.Tuple((
|
||||
spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
|
||||
spaces.Discrete(3),
|
||||
))
|
||||
```
|
||||
|
||||
## Important Constraints and Best Practices
|
||||
|
||||
### Data Types
|
||||
|
||||
- **Observations:** Use `np.float32` for continuous values
|
||||
- **Images:** Use `np.uint8` in range [0, 255]
|
||||
- **Rewards:** Return Python float or `np.float32`
|
||||
- **Terminated/Truncated:** Return Python bool
|
||||
|
||||
### Random Number Generation
|
||||
|
||||
Always use `self.np_random` for reproducibility:
|
||||
|
||||
```python
|
||||
def reset(self, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
# Use self.np_random instead of np.random
|
||||
random_pos = self.np_random.integers(0, 10, size=2)
|
||||
random_float = self.np_random.random()
|
||||
```
|
||||
|
||||
### Episode Termination
|
||||
|
||||
- **Terminated:** Natural ending (goal reached, agent died, etc.)
|
||||
- **Truncated:** Artificial ending (time limit, external interrupt)
|
||||
|
||||
```python
|
||||
def step(self, action):
|
||||
# ... environment logic ...
|
||||
|
||||
goal_reached = self._check_goal()
|
||||
time_limit_exceeded = self.current_step >= self.max_steps
|
||||
|
||||
terminated = goal_reached # Natural ending
|
||||
truncated = time_limit_exceeded # Time limit
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
```
|
||||
|
||||
### Info Dictionary
|
||||
|
||||
Use the info dict for debugging and logging:
|
||||
|
||||
```python
|
||||
info = {
|
||||
"episode_length": self.current_step,
|
||||
"distance_to_goal": distance,
|
||||
"success": goal_reached,
|
||||
"total_reward": self.cumulative_reward,
|
||||
}
|
||||
```
|
||||
|
||||
**Special Keys:**
|
||||
- `"terminal_observation"`: Automatically added by VecEnv when episode ends
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Metadata
|
||||
|
||||
Provide rendering information:
|
||||
|
||||
```python
|
||||
class CustomEnv(gym.Env):
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 30,
|
||||
}
|
||||
|
||||
def __init__(self, render_mode=None):
|
||||
super().__init__()
|
||||
self.render_mode = render_mode
|
||||
# ...
|
||||
```
|
||||
|
||||
### Render Modes
|
||||
|
||||
```python
|
||||
def render(self):
|
||||
if self.render_mode == "human":
|
||||
# Print or display for human viewing
|
||||
print(f"Agent at {self.agent_pos}")
|
||||
|
||||
elif self.render_mode == "rgb_array":
|
||||
# Return numpy array (height, width, 3) for video recording
|
||||
canvas = np.zeros((500, 500, 3), dtype=np.uint8)
|
||||
# Draw environment on canvas
|
||||
return canvas
|
||||
```
|
||||
|
||||
### Goal-Conditioned Environments (for HER)
|
||||
|
||||
For Hindsight Experience Replay, use specific observation structure:
|
||||
|
||||
```python
|
||||
self.observation_space = spaces.Dict({
|
||||
"observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
|
||||
"achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
|
||||
"desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
|
||||
})
|
||||
|
||||
def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER environments (must also handle batched goals)"""
    # HerReplayBuffer calls this with batches of goals, so reduce over the last axis
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
|
||||
```
|
||||
|
||||
## Environment Validation
|
||||
|
||||
Always validate your environment before training:
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
env = CustomEnv()
|
||||
check_env(env, warn=True)
|
||||
```
|
||||
|
||||
**Common Validation Errors:**
|
||||
|
||||
1. **"Observation is not within bounds"**
|
||||
- Check that observations stay within defined space
|
||||
- Ensure correct dtype (np.float32 for Box spaces)
|
||||
|
||||
2. **"Reset should return tuple"**
|
||||
- Return `(observation, info)`, not just observation
|
||||
|
||||
3. **"Step should return 5-tuple"**
|
||||
- Return `(obs, reward, terminated, truncated, info)`
|
||||
|
||||
4. **"Action is out of bounds"**
|
||||
- Verify action_space definition matches expected actions
|
||||
|
||||
5. **"Observation/Action dtype mismatch"**
|
||||
- Ensure observations match space dtype (usually np.float32)
|
||||
|
||||
## Environment Registration
|
||||
|
||||
Register your environment with Gymnasium:
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
register(
|
||||
id="MyCustomEnv-v0",
|
||||
entry_point="my_module:CustomEnv",
|
||||
max_episode_steps=200,
|
||||
kwargs={"grid_size": 10}, # Default kwargs
|
||||
)
|
||||
|
||||
# Now can use with gym.make
|
||||
env = gym.make("MyCustomEnv-v0")
|
||||
```
|
||||
|
||||
## Testing Custom Environments
|
||||
|
||||
### Basic Testing
|
||||
|
||||
```python
|
||||
def test_environment(env, n_episodes=5):
|
||||
"""Test environment with random actions"""
|
||||
for episode in range(n_episodes):
|
||||
obs, info = env.reset()
|
||||
episode_reward = 0
|
||||
done = False
|
||||
steps = 0
|
||||
|
||||
while not done:
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
episode_reward += reward
|
||||
steps += 1
|
||||
done = terminated or truncated
|
||||
|
||||
print(f"Episode {episode+1}: Reward={episode_reward:.2f}, Steps={steps}")
|
||||
```
|
||||
|
||||
### Training Test
|
||||
|
||||
```python
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
def train_test(env, timesteps=10000):
|
||||
"""Quick training test"""
|
||||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=timesteps)
|
||||
|
||||
# Evaluate
|
||||
obs, info = env.reset()
|
||||
for _ in range(100):
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
if terminated or truncated:
|
||||
break
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Grid World
|
||||
|
||||
```python
|
||||
class GridWorldEnv(gym.Env):
|
||||
def __init__(self, size=10):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.action_space = spaces.Discrete(4) # up, down, left, right
|
||||
self.observation_space = spaces.Box(0, size-1, shape=(2,), dtype=np.float32)
|
||||
```
|
||||
|
||||
### Continuous Control
|
||||
|
||||
```python
|
||||
class ContinuousEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
|
||||
```
|
||||
|
||||
### Image-Based Environment
|
||||
|
||||
```python
|
||||
class VisionEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_space = spaces.Discrete(4)
|
||||
# Channel-first: (channels, height, width)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
|
||||
)
|
||||
```
|
||||
|
||||
### Multi-Modal Environment
|
||||
|
||||
```python
|
||||
class MultiModalEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_space = spaces.Discrete(4)
|
||||
self.observation_space = spaces.Dict({
|
||||
"image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
|
||||
"sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
|
||||
})
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Efficient Observation Generation
|
||||
|
||||
```python
|
||||
# Pre-allocate arrays
|
||||
def __init__(self):
|
||||
# ...
|
||||
self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)
|
||||
|
||||
def _get_observation(self):
|
||||
# Reuse buffer instead of allocating new array
|
||||
self._obs_buffer[0] = self.agent_x
|
||||
self._obs_buffer[1] = self.agent_y
|
||||
return self._obs_buffer
|
||||
```
|
||||
|
||||
### Vectorization
|
||||
|
||||
Make environment operations vectorizable:
|
||||
|
||||
```python
|
||||
# Good: Uses numpy operations
|
||||
def step(self, action):
|
||||
direction = np.array([[0,1], [0,-1], [1,0], [-1,0]])[action]
|
||||
self.pos = np.clip(self.pos + direction, 0, self.size-1)
|
||||
|
||||
# Avoid: Python loops when possible
|
||||
# for i in range(len(self.agents)):
|
||||
# self.agents[i].update()
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Observation out of bounds"
|
||||
- Check that all observations are within defined space
|
||||
- Verify correct dtype (np.float32 vs np.float64)
|
||||
|
||||
### "NaN or Inf in observation/reward"
|
||||
- Add checks: `assert np.isfinite(reward)`
|
||||
- Use `VecCheckNan` wrapper to catch issues
|
||||
|
||||
### "Policy doesn't learn"
|
||||
- Check reward scaling (normalize rewards)
|
||||
- Verify observation normalization
|
||||
- Ensure reward signal is meaningful
|
||||
- Check if exploration is sufficient
|
||||
|
||||
### "Training crashes"
|
||||
- Validate environment with `check_env()`
|
||||
- Check for race conditions in custom env
|
||||
- Verify action/observation spaces are consistent
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Template: See `scripts/custom_env_template.py`
|
||||
- Gymnasium Documentation: https://gymnasium.farama.org/
|
||||
- SB3 Custom Env Guide: https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html
|
||||
skills/stable-baselines3/references/vectorized_envs.md (new file)
@@ -0,0 +1,568 @@

# Vectorized Environments in Stable Baselines3
|
||||
|
||||
This document provides comprehensive information about vectorized environments in Stable Baselines3 for efficient parallel training.
|
||||
|
||||
## Overview
|
||||
|
||||
Vectorized environments stack multiple independent environment instances into a single environment that processes actions and observations in batches. Instead of interacting with one environment at a time, you interact with `n` environments simultaneously.
|
||||
|
||||
**Benefits:**
|
||||
- **Speed:** Parallel execution significantly accelerates training
|
||||
- **Sample efficiency:** Collect more diverse experiences faster
|
||||
- **Required for:** Frame stacking and normalization wrappers
|
||||
- **Better for:** On-policy algorithms (PPO, A2C)
|
||||
|
||||
## VecEnv Types
|
||||
|
||||
### DummyVecEnv
|
||||
|
||||
Executes environments sequentially on the current Python process.
|
||||
|
||||
```python
import gymnasium as gym

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env

# Method 1: Using make_vec_env
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=DummyVecEnv)

# Method 2: Manual creation
def make_env():
    def _init():
        return gym.make("CartPole-v1")
    return _init

env = DummyVecEnv([make_env() for _ in range(4)])
```
|
||||
|
||||
**When to use:**
|
||||
- Lightweight environments (CartPole, simple grids)
|
||||
- When multiprocessing overhead > computation time
|
||||
- Debugging (easier to trace errors)
|
||||
- Single-threaded environments
|
||||
|
||||
**Performance:** No actual parallelism (sequential execution).
|
||||
|
||||
### SubprocVecEnv
|
||||
|
||||
Executes each environment in a separate process, enabling true parallelism.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Computationally expensive environments (physics simulations, 3D games)
|
||||
- When environment computation time justifies multiprocessing overhead
|
||||
- When you need true parallel execution
|
||||
|
||||
**Important:** Requires wrapping code in `if __name__ == "__main__":` when using forkserver or spawn:
|
||||
|
||||
```python
|
||||
if __name__ == "__main__":
|
||||
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
|
||||
model = PPO("MlpPolicy", env)
|
||||
model.learn(total_timesteps=100000)
|
||||
```
|
||||
|
||||
**Performance:** True parallelism across CPU cores.
|
||||
|
||||
## Quick Setup with make_vec_env
|
||||
|
||||
The easiest way to create vectorized environments:
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
# Basic usage
|
||||
env = make_vec_env("CartPole-v1", n_envs=4)
|
||||
|
||||
# With SubprocVecEnv
|
||||
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
|
||||
|
||||
# With custom environment kwargs
|
||||
env = make_vec_env(
|
||||
"MyEnv-v0",
|
||||
n_envs=4,
|
||||
env_kwargs={"difficulty": "hard", "max_steps": 500}
|
||||
)
|
||||
|
||||
# With custom seed
|
||||
env = make_vec_env("CartPole-v1", n_envs=4, seed=42)
|
||||
```
|
||||
|
||||
## API Differences from Standard Gym
|
||||
|
||||
Vectorized environments have a different API than standard Gym environments:
|
||||
|
||||
### reset()
|
||||
|
||||
**Standard Gym:**
|
||||
```python
|
||||
obs, info = env.reset()
|
||||
```
|
||||
|
||||
**VecEnv:**
|
||||
```python
|
||||
obs = env.reset() # Returns only observations (numpy array)
|
||||
# Access info via env.reset_infos (list of dicts)
|
||||
infos = env.reset_infos
|
||||
```
|
||||
|
||||
### step()
|
||||
|
||||
**Standard Gym:**
|
||||
```python
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
```
|
||||
|
||||
**VecEnv:**
|
||||
```python
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
# Returns 4-tuple instead of 5-tuple
|
||||
# dones = terminated | truncated
|
||||
# actions is an array of shape (n_envs,) or (n_envs, action_dim)
|
||||
```
|
||||
|
||||
### Auto-reset
|
||||
|
||||
**VecEnv automatically resets environments when episodes end:**
|
||||
|
||||
```python
|
||||
obs = env.reset() # Shape: (n_envs, obs_dim)
|
||||
for _ in range(1000):
|
||||
actions = env.action_space.sample() # Shape: (n_envs,)
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
# If dones[i] is True, env i was automatically reset
|
||||
# Final observation before reset available in infos[i]["terminal_observation"]
|
||||
```
|
||||
|
||||
### Terminal Observations
|
||||
|
||||
When an episode ends, access the true final observation:
|
||||
|
||||
```python
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
for i, done in enumerate(dones):
|
||||
if done:
|
||||
# The obs[i] is already the reset observation
|
||||
# True terminal observation is in info
|
||||
terminal_obs = infos[i]["terminal_observation"]
|
||||
print(f"Episode ended with terminal observation: {terminal_obs}")
|
||||
```
|
||||
|
||||
## Training with Vectorized Environments
|
||||
|
||||
### On-Policy Algorithms (PPO, A2C)
|
||||
|
||||
On-policy algorithms benefit greatly from vectorization:
|
||||
|
||||
```python
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
# Create vectorized environment
|
||||
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
|
||||
|
||||
# Train
|
||||
model = PPO("MlpPolicy", env, verbose=1, n_steps=128)
|
||||
model.learn(total_timesteps=100000)
|
||||
|
||||
# With n_envs=8 and n_steps=128:
|
||||
# - Collects 8*128=1024 steps per rollout
|
||||
# - Updates after every 1024 steps
|
||||
```
|
||||
|
||||
**Rule of thumb:** Use 4-16 parallel environments for on-policy methods.
|
||||
|
||||
### Off-Policy Algorithms (SAC, TD3, DQN)
|
||||
|
||||
Off-policy algorithms can use vectorization but benefit less:
|
||||
|
||||
```python
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
# Use fewer environments (1-4)
|
||||
env = make_vec_env("Pendulum-v1", n_envs=4)
|
||||
|
||||
# Set gradient_steps=-1 for efficiency
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
train_freq=1,
|
||||
gradient_steps=-1, # Do 1 gradient step per env step (4 total with 4 envs)
|
||||
)
|
||||
model.learn(total_timesteps=50000)
|
||||
```
|
||||
|
||||
**Rule of thumb:** Use 1-4 parallel environments for off-policy methods.
|
||||
|
||||
## Wrappers for Vectorized Environments
|
||||
|
||||
### VecNormalize
|
||||
|
||||
Normalizes observations and rewards using running statistics.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
env = make_vec_env("Pendulum-v1", n_envs=4)
|
||||
|
||||
# Wrap with normalization
|
||||
env = VecNormalize(
|
||||
env,
|
||||
norm_obs=True, # Normalize observations
|
||||
norm_reward=True, # Normalize rewards
|
||||
clip_obs=10.0, # Clip normalized observations
|
||||
clip_reward=10.0, # Clip normalized rewards
|
||||
gamma=0.99, # Discount factor for reward normalization
|
||||
)
|
||||
|
||||
# Train
|
||||
model = PPO("MlpPolicy", env)
|
||||
model.learn(total_timesteps=50000)
|
||||
|
||||
# Save model AND normalization statistics
|
||||
model.save("ppo_pendulum")
|
||||
env.save("vec_normalize.pkl")
|
||||
|
||||
# Load for evaluation
|
||||
env = make_vec_env("Pendulum-v1", n_envs=1)
|
||||
env = VecNormalize.load("vec_normalize.pkl", env)
|
||||
env.training = False # Don't update stats during evaluation
|
||||
env.norm_reward = False # Don't normalize rewards during evaluation
|
||||
|
||||
model = PPO.load("ppo_pendulum", env=env)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Continuous control tasks (especially MuJoCo)
|
||||
- When observation scales vary widely
|
||||
- When rewards have high variance
|
||||
|
||||
**Important:**
|
||||
- Statistics are NOT saved with model - save separately
|
||||
- Disable training and reward normalization during evaluation
|
||||
|
||||
### VecFrameStack
|
||||
|
||||
Stacks observations from multiple consecutive frames.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.vec_env import VecFrameStack
|
||||
|
||||
env = make_vec_env("PongNoFrameskip-v4", n_envs=8)
|
||||
|
||||
# Stack 4 frames
|
||||
env = VecFrameStack(env, n_stack=4)
|
||||
|
||||
# Now observations have shape: (n_envs, n_stack, height, width)
|
||||
model = PPO("CnnPolicy", env)
|
||||
model.learn(total_timesteps=1000000)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Atari games (stack 4 frames)
|
||||
- Environments where velocity information is needed
|
||||
- Partial observability problems
|
||||
|
||||
### VecVideoRecorder
|
||||
|
||||
Records videos of agent behavior.
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.vec_env import VecVideoRecorder
|
||||
|
||||
env = make_vec_env("CartPole-v1", n_envs=1)
|
||||
|
||||
# Record videos
|
||||
env = VecVideoRecorder(
|
||||
env,
|
||||
video_folder="./videos/",
|
||||
record_video_trigger=lambda x: x % 2000 == 0, # Record every 2000 steps
|
||||
video_length=200, # Max video length
|
||||
name_prefix="training"
|
||||
)
|
||||
|
||||
model = PPO("MlpPolicy", env)
|
||||
model.learn(total_timesteps=10000)
|
||||
```
|
||||
|
||||
**Output:** MP4 videos in `./videos/` directory.
|
||||
|
||||
### VecCheckNan

Checks for NaN or infinite values in observations and rewards.

```python
from stable_baselines3.common.vec_env import VecCheckNan

env = make_vec_env("CustomEnv-v0", n_envs=4)

# Add NaN checking (useful for debugging)
env = VecCheckNan(env, raise_exception=True, warn_once=True)

model = PPO("MlpPolicy", env)
model.learn(total_timesteps=10000)
```

**When to use:**
- Debugging custom environments
- Catching numerical instabilities
- Validating environment implementation
### VecTransposeImage

Transposes image observations from (height, width, channels) to (channels, height, width).

```python
from stable_baselines3.common.vec_env import VecTransposeImage

env = make_vec_env("PongNoFrameskip-v4", n_envs=4)

# Convert HWC to CHW format
env = VecTransposeImage(env)

model = PPO("CnnPolicy", env)
```

**When to use:**
- When the environment returns images in HWC format (SB3's CNN policies expect CHW)
- Usually unnecessary: SB3 applies this wrapper automatically for image observation spaces
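A quick way to confirm the layout is to inspect the observation space before and after wrapping. A small sketch (the printed shapes are examples for raw Pong frames):

```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecTransposeImage

env = make_vec_env("PongNoFrameskip-v4", n_envs=4)
print(env.observation_space.shape)  # e.g. (210, 160, 3) -> HWC

env = VecTransposeImage(env)
print(env.observation_space.shape)  # e.g. (3, 210, 160) -> CHW
```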
## Advanced Usage

### Custom VecEnv

Create a custom vectorized environment by subclassing an existing VecEnv:

```python
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym

class CustomVecEnv(DummyVecEnv):
    def step_wait(self):
        # Custom logic before/after stepping
        obs, rewards, dones, infos = super().step_wait()
        # Modify observations/rewards/etc. here
        return obs, rewards, dones, infos
```
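Building on the class defined above, the subclass is constructed like any `DummyVecEnv`, from a list of environment factory callables. A minimal usage sketch:

```python
from stable_baselines3 import PPO

env = CustomVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

model = PPO("MlpPolicy", env)
model.learn(total_timesteps=10_000)
```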
### Environment Method Calls

Call methods on the underlying (wrapped) environments:

```python
env = make_vec_env("MyEnv-v0", n_envs=4)

# Call a method on all environments (returns a list with one result per env)
env.env_method("set_difficulty", "hard")

# Call a method only on specific environments
env.env_method("reset_level", indices=[0, 2])

# Get an attribute from all environments
levels = env.get_attr("current_level")
```
### Setting Attributes

```python
# Set attribute on all environments
env.set_attr("difficulty", "hard")

# Set attribute on specific environments
env.set_attr("max_steps", 1000, indices=[1, 3])
```
## Performance Optimization

### Choosing Number of Environments

**On-Policy (PPO, A2C):**
```python
# General rule: 4-16 environments
# More environments = faster data collection
n_envs = 8
env = make_vec_env("CartPole-v1", n_envs=n_envs)

# Adjust n_steps to maintain the same rollout length
# Total steps per rollout = n_envs * n_steps
model = PPO("MlpPolicy", env, n_steps=128)  # 8 * 128 = 1024 steps/rollout
```

**Off-Policy (SAC, TD3, DQN):**
```python
# General rule: 1-4 environments
# More doesn't help as much (the replay buffer already provides diversity)
n_envs = 4
env = make_vec_env("Pendulum-v1", n_envs=n_envs)

model = SAC("MlpPolicy", env, gradient_steps=-1)  # One gradient step per env step collected
```
### CPU Core Utilization

```python
import multiprocessing

from stable_baselines3.common.vec_env import SubprocVecEnv

# Use one less than the total core count (leave one core for the main Python process)
n_cpus = multiprocessing.cpu_count() - 1
env = make_vec_env("MyEnv-v0", n_envs=n_cpus, vec_env_cls=SubprocVecEnv)
```
### Memory Considerations

```python
# Large replay buffer + many environments = high memory usage
# Reduce the buffer size if memory is constrained
model = SAC(
    "MlpPolicy",
    env,
    buffer_size=100_000,  # Reduced from the default of 1M transitions
)
```
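As a rough back-of-the-envelope check (an approximation, not an SB3 API: it counts only observations, next observations, actions, rewards, and done flags, and ignores per-transition info overhead), buffer memory can be estimated like this:

```python
def approx_replay_buffer_gb(buffer_size: int, obs_bytes: int, action_bytes: int = 4) -> float:
    """Very rough replay-buffer size estimate: obs + next_obs + action + reward + done."""
    per_transition = 2 * obs_bytes + action_bytes + 4 + 4
    return buffer_size * per_transition / 1e9

# Example: stacked Atari frames (84 * 84 * 4 uint8 = 28,224 bytes per observation)
print(f"{approx_replay_buffer_gb(1_000_000, 84 * 84 * 4):.1f} GB")  # ~56 GB
```

The off-policy algorithms also expose an `optimize_memory_usage` option that avoids storing next observations separately, roughly halving the observation storage at the cost of a more complex buffer layout.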
## Common Issues

### Issue: "Can't pickle local object"

**Cause:** SubprocVecEnv requires picklable environment factory functions; functions defined inside another function cannot be pickled.

**Solution:** Define the environment factory at module level:

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import SubprocVecEnv

# Bad: make_env is a local (unpicklable) function
def train():
    def make_env():
        return gym.make("CartPole-v1")
    env = SubprocVecEnv([make_env for _ in range(4)])

# Good: make_env is defined at module level
def make_env():
    return gym.make("CartPole-v1")

if __name__ == "__main__":
    env = SubprocVecEnv([make_env for _ in range(4)])
```
### Issue: Different behavior between single and vectorized env

**Cause:** Vectorized environments reset automatically when an episode ends, so `step()` returns the first observation of the next episode instead of the terminal one.

**Solution:** Read the terminal observation from `infos` when needed:

```python
obs, rewards, dones, infos = env.step(actions)
for i, done in enumerate(dones):
    if done:
        terminal_obs = infos[i]["terminal_observation"]
        # Process terminal_obs if needed
```
### Issue: Slower with SubprocVecEnv than DummyVecEnv

**Cause:** The environment is too lightweight; inter-process communication costs more than the environment computation itself.

**Solution:** Use DummyVecEnv for cheap environments:

```python
# For CartPole, DummyVecEnv is usually faster
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DummyVecEnv)
```
### Issue: Training crashes with SubprocVecEnv

**Cause:** The environment is not properly isolated or relies on shared state.

**Solution:**
- Ensure the environment has no shared global state
- Wrap the entry point in `if __name__ == "__main__":`
- Use DummyVecEnv while debugging
- If the default start method misbehaves, set one explicitly (see the sketch below)
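A minimal sketch of the safe pattern, assuming a module-level factory; the explicit `start_method` is optional, and which methods are available depends on the operating system:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env():
    # Module-level factory: picklable and importable by the worker processes
    return gym.make("CartPole-v1")


if __name__ == "__main__":
    # "spawn" is the most portable start method; "fork" only exists on Unix
    env = SubprocVecEnv([make_env for _ in range(4)], start_method="spawn")
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)
```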
## Best Practices

1. **Use the appropriate VecEnv type:**
   - DummyVecEnv: simple, cheap environments (CartPole, basic grid worlds)
   - SubprocVecEnv: expensive environments (MuJoCo, Unity, 3D games)

2. **Adjust hyperparameters for vectorization:**
   - Divide `eval_freq` and `save_freq` by `n_envs` in callbacks (they count callback calls, and each call covers `n_envs` steps); see the callback sketch after this list
   - Keep `n_steps * n_envs` roughly constant for on-policy algorithms

3. **Save normalization statistics:**
   - Always save the VecNormalize statistics alongside the model (they are not stored in the model file)
   - Disable stat updates during evaluation

4. **Monitor memory usage:**
   - More environments = more memory
   - Reduce the replay buffer size if needed

5. **Test with DummyVecEnv first:**
   - Easier debugging
   - Ensure the environment works before parallelizing
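A minimal sketch of frequency scaling for callbacks (the target intervals of 10k/50k steps are arbitrary examples):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.env_util import make_vec_env

n_envs = 8
env = make_vec_env("CartPole-v1", n_envs=n_envs)
eval_env = make_vec_env("CartPole-v1", n_envs=1)

# Callback frequencies are counted in calls (one call = n_envs env steps),
# so divide the desired step interval by n_envs
eval_callback = EvalCallback(eval_env, eval_freq=max(10_000 // n_envs, 1), n_eval_episodes=5)
checkpoint_callback = CheckpointCallback(save_freq=max(50_000 // n_envs, 1), save_path="./checkpoints/")

model = PPO("MlpPolicy", env, n_steps=256)  # 8 * 256 = 2048 steps per rollout
model.learn(total_timesteps=100_000, callback=[eval_callback, checkpoint_callback])
```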
## Examples

### Basic Training Loop

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    # Create vectorized environment (the guard is required with SubprocVecEnv)
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)

    # Train
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=100000)

    # Evaluate
    obs = env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = env.step(action)
```
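For a quick quantitative check, SB3's `evaluate_policy` helper avoids writing the evaluation loop by hand; continuing with the `model` and `env` from the block above:

```python
from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")
```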
### With Normalization

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

# Create and normalize
env = make_vec_env("Pendulum-v1", n_envs=4)
env = VecNormalize(env, norm_obs=True, norm_reward=True)

# Train
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=50000)

# Save both
model.save("model")
env.save("vec_normalize.pkl")

# Load for evaluation
eval_env = make_vec_env("Pendulum-v1", n_envs=1)
eval_env = VecNormalize.load("vec_normalize.pkl", eval_env)
eval_env.training = False
eval_env.norm_reward = False

model = PPO.load("model", env=eval_env)
```
## Additional Resources

- Official SB3 VecEnv Guide: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
- VecEnv API Reference: https://stable-baselines3.readthedocs.io/en/master/common/vec_env.html
- Multiprocessing Best Practices: https://docs.python.org/3/library/multiprocessing.html