# Vectorized Environments in Stable Baselines3
This document covers vectorized environments in Stable Baselines3 and how to use them for efficient parallel training.
## Overview
Vectorized environments stack multiple independent environment instances into a single environment that processes actions and observations in batches. Instead of interacting with one environment at a time, you interact with `n` environments simultaneously.
**Benefits:**
- **Speed:** Parallel execution significantly accelerates training
- **Sample efficiency:** Collect more diverse experiences faster
- **Required for:** Frame stacking and normalization wrappers
- **Better for:** On-policy algorithms (PPO, A2C)
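For a concrete sense of the batched interface, here is a minimal sketch (CartPole observations are 4-dimensional, so a batch from 4 environments has shape `(4, 4)`):
```python
import numpy as np
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env("CartPole-v1", n_envs=4)

obs = env.reset()
print(obs.shape)  # (4, 4): one observation row per environment

# One action per environment; rewards and dones come back batched as well
actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
obs, rewards, dones, infos = env.step(actions)
print(rewards.shape, dones.shape)  # (4,) (4,)
```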
## VecEnv Types
### DummyVecEnv
Executes environments sequentially on the current Python process.
```python
from stable_baselines3.common.vec_env import DummyVecEnv
# Method 1: Using make_vec_env
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=DummyVecEnv)
# Method 2: Manual creation
import gymnasium as gym

def make_env():
    def _init():
        return gym.make("CartPole-v1")
    return _init

env = DummyVecEnv([make_env() for _ in range(4)])
```
**When to use:**
- Lightweight environments (CartPole, simple grids)
- When multiprocessing overhead > computation time
- Debugging (easier to trace errors)
- Single-threaded environments
**Performance:** No actual parallelism (sequential execution).
### SubprocVecEnv
Executes each environment in a separate process, enabling true parallelism.
```python
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
```
**When to use:**
- Computationally expensive environments (physics simulations, 3D games)
- When environment computation time justifies the multiprocessing overhead (see the timing sketch below)
- When you need true parallel execution
**Important:** Requires wrapping code in `if __name__ == "__main__":` when using forkserver or spawn:
```python
from stable_baselines3 import PPO

if __name__ == "__main__":
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
    model = PPO("MlpPolicy", env)
    model.learn(total_timesteps=100000)
```
**Performance:** True parallelism across CPU cores.
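When it is unclear whether an environment is heavy enough to justify subprocesses, a rough timing comparison such as the sketch below can settle it empirically (numbers depend on your machine and environment; a lightweight task like CartPole will usually favor DummyVecEnv):
```python
import time
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

def time_vec_env(vec_env_cls, n_envs=8, n_steps=1_000):
    # Build the vectorized env with the given backend and step it with random actions
    env = make_vec_env("CartPole-v1", n_envs=n_envs, vec_env_cls=vec_env_cls)
    env.reset()
    start = time.perf_counter()
    for _ in range(n_steps):
        actions = np.array([env.action_space.sample() for _ in range(n_envs)])
        env.step(actions)
    env.close()
    return time.perf_counter() - start

if __name__ == "__main__":  # required because SubprocVecEnv spawns worker processes
    for cls in (DummyVecEnv, SubprocVecEnv):
        print(f"{cls.__name__}: {time_vec_env(cls):.2f}s for 1000 vectorized steps")
```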
## Quick Setup with make_vec_env
The easiest way to create vectorized environments:
```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
# Basic usage
env = make_vec_env("CartPole-v1", n_envs=4)
# With SubprocVecEnv
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
# With custom environment kwargs
env = make_vec_env(
    "MyEnv-v0",
    n_envs=4,
    env_kwargs={"difficulty": "hard", "max_steps": 500},
)
# With custom seed
env = make_vec_env("CartPole-v1", n_envs=4, seed=42)
```
## API Differences from Standard Gym
Vectorized environments have a different API than standard Gym environments:
### reset()
**Standard Gym:**
```python
obs, info = env.reset()
```
**VecEnv:**
```python
obs = env.reset() # Returns only observations (numpy array)
# Access info via env.reset_infos (list of dicts)
infos = env.reset_infos
```
### step()
**Standard Gym:**
```python
obs, reward, terminated, truncated, info = env.step(action)
```
**VecEnv:**
```python
obs, rewards, dones, infos = env.step(actions)
# Returns 4-tuple instead of 5-tuple
# dones = terminated | truncated
# actions is an array of shape (n_envs,) or (n_envs, action_dim)
```
### Auto-reset
**VecEnv automatically resets environments when episodes end:**
```python
import numpy as np

obs = env.reset()  # Shape: (n_envs, obs_dim)
for _ in range(1000):
    # env.action_space is the single-env action space, so sample one action per env
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)
    # If dones[i] is True, env i was automatically reset
    # The final observation before the reset is in infos[i]["terminal_observation"]
```
### Terminal Observations
When an episode ends, access the true final observation:
```python
obs, rewards, dones, infos = env.step(actions)
for i, done in enumerate(dones):
    if done:
        # obs[i] is already the reset observation
        # The true terminal observation is in the info dict
        terminal_obs = infos[i]["terminal_observation"]
        print(f"Episode ended with terminal observation: {terminal_obs}")
```
## Training with Vectorized Environments
### On-Policy Algorithms (PPO, A2C)
On-policy algorithms benefit greatly from vectorization:
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
# Create vectorized environment
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
# Train
model = PPO("MlpPolicy", env, verbose=1, n_steps=128)
model.learn(total_timesteps=100000)
# With n_envs=8 and n_steps=128:
# - Collects 8*128=1024 steps per rollout
# - Updates after every 1024 steps
```
**Rule of thumb:** Use 4-16 parallel environments for on-policy methods.
### Off-Policy Algorithms (SAC, TD3, DQN)
Off-policy algorithms can use vectorization but benefit less:
```python
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
# Use fewer environments (1-4)
env = make_vec_env("Pendulum-v1", n_envs=4)
# Set gradient_steps=-1 for efficiency
model = SAC(
    "MlpPolicy",
    env,
    verbose=1,
    train_freq=1,
    gradient_steps=-1,  # Do 1 gradient step per env step (4 total with 4 envs)
)
model.learn(total_timesteps=50000)
```
**Rule of thumb:** Use 1-4 parallel environments for off-policy methods.
## Wrappers for Vectorized Environments
### VecNormalize
Normalizes observations and rewards using running statistics.
```python
from stable_baselines3.common.vec_env import VecNormalize
env = make_vec_env("Pendulum-v1", n_envs=4)
# Wrap with normalization
env = VecNormalize(
    env,
    norm_obs=True,      # Normalize observations
    norm_reward=True,   # Normalize rewards
    clip_obs=10.0,      # Clip normalized observations
    clip_reward=10.0,   # Clip normalized rewards
    gamma=0.99,         # Discount factor for reward normalization
)
# Train
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=50000)
# Save model AND normalization statistics
model.save("ppo_pendulum")
env.save("vec_normalize.pkl")
# Load for evaluation
env = make_vec_env("Pendulum-v1", n_envs=1)
env = VecNormalize.load("vec_normalize.pkl", env)
env.training = False # Don't update stats during evaluation
env.norm_reward = False # Don't normalize rewards during evaluation
model = PPO.load("ppo_pendulum", env=env)
```
**When to use:**
- Continuous control tasks (especially MuJoCo)
- When observation scales vary widely
- When rewards have high variance
**Important:**
- Normalization statistics are NOT saved with the model; save them separately
- Disable training and reward normalization during evaluation
### VecFrameStack
Stacks observations from multiple consecutive frames.
```python
from stable_baselines3.common.vec_env import VecFrameStack
env = make_vec_env("PongNoFrameskip-v4", n_envs=8)
# Stack 4 frames
env = VecFrameStack(env, n_stack=4)
# Observations now have shape (n_envs, height, width, channels * n_stack); SB3 transposes them to channels-first for CnnPolicy
model = PPO("CnnPolicy", env)
model.learn(total_timesteps=1000000)
```
**When to use:**
- Atari games (stack 4 frames)
- Environments where velocity must be inferred from consecutive positions (see the non-image sketch below)
- Partially observable problems
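VecFrameStack also works on non-image observations; in that case the stacked frames are concatenated along the last axis (a small sketch, using CartPole's 4-dimensional observation as the example):
```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack

env = make_vec_env("CartPole-v1", n_envs=4)
env = VecFrameStack(env, n_stack=2)

obs = env.reset()
print(obs.shape)  # (4, 8): 4 envs, two consecutive 4-dim observations concatenated
```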
### VecVideoRecorder
Records videos of agent behavior.
```python
from stable_baselines3.common.vec_env import VecVideoRecorder
# The underlying env must render RGB frames for video recording
env = make_vec_env("CartPole-v1", n_envs=1, env_kwargs={"render_mode": "rgb_array"})
# Record videos
env = VecVideoRecorder(
    env,
    video_folder="./videos/",
    record_video_trigger=lambda x: x % 2000 == 0,  # Record every 2000 steps
    video_length=200,  # Max video length
    name_prefix="training",
)
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=10000)
```
**Output:** MP4 videos in the `./videos/` directory.
### VecCheckNan
Checks for NaN or infinite values in observations and rewards.
```python
from stable_baselines3.common.vec_env import VecCheckNan
env = make_vec_env("CustomEnv-v0", n_envs=4)
# Add NaN checking (useful for debugging)
env = VecCheckNan(env, raise_exception=True, warn_once=True)
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=10000)
```
**When to use:**
- Debugging custom environments
- Catching numerical instabilities
- Validating environment implementation
### VecTransposeImage
Transposes image observations from (height, width, channels) to (channels, height, width).
```python
from stable_baselines3.common.vec_env import VecTransposeImage
env = make_vec_env("PongNoFrameskip-v4", n_envs=4)
# Convert HWC to CHW format
env = VecTransposeImage(env)
model = PPO("CnnPolicy", env)
```
**When to use:**
- When the environment returns images in HWC format
- SB3 CNN policies expect CHW format (see the shape check below)
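A quick way to verify the layout is to print the observation space before and after wrapping (a sketch assuming the Atari extras are installed; raw Pong frames are 210x160 RGB). Note that recent SB3 versions usually insert this wrapper automatically when they detect a channels-last image space, so wrapping manually is mainly needed in custom setups:
```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecTransposeImage

env = make_vec_env("PongNoFrameskip-v4", n_envs=1)
print(env.observation_space.shape)  # e.g. (210, 160, 3) -> HWC (channels last)

env = VecTransposeImage(env)
print(env.observation_space.shape)  # e.g. (3, 210, 160) -> CHW (channels first)
```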
## Advanced Usage
### Custom VecEnv
Create a custom vectorized environment by subclassing an existing one:
```python
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym
class CustomVecEnv(DummyVecEnv):
    def step_wait(self):
        # Custom logic before/after stepping
        obs, rewards, dones, infos = super().step_wait()
        # Modify observations/rewards/etc. here
        return obs, rewards, dones, infos
```
### Environment Method Calls
Call methods on wrapped environments:
```python
env = make_vec_env("MyEnv-v0", n_envs=4)
# Call method on all environments
env.env_method("set_difficulty", "hard")
# Call method on specific environment
env.env_method("reset_level", indices=[0, 2])
# Get attribute from all environments
levels = env.get_attr("current_level")
```
### Setting Attributes
```python
# Set attribute on all environments
env.set_attr("difficulty", "hard")
# Set attribute on specific environments
env.set_attr("max_steps", 1000, indices=[1, 3])
```
## Performance Optimization
### Choosing Number of Environments
**On-Policy (PPO, A2C):**
```python
# General rule: 4-16 environments
# More environments = faster data collection
n_envs = 8
env = make_vec_env("CartPole-v1", n_envs=n_envs)
# Adjust n_steps to maintain same rollout length
# Total steps per rollout = n_envs * n_steps
model = PPO("MlpPolicy", env, n_steps=128) # 8*128 = 1024 steps/rollout
```
**Off-Policy (SAC, TD3, DQN):**
```python
# General rule: 1-4 environments
# More doesn't help as much (replay buffer provides diversity)
n_envs = 4
env = make_vec_env("Pendulum-v1", n_envs=n_envs)
model = SAC("MlpPolicy", env, gradient_steps=-1) # 1 grad step per env step
```
### CPU Core Utilization
```python
import multiprocessing
# Use one less than total cores (leave one for Python main process)
n_cpus = multiprocessing.cpu_count() - 1
env = make_vec_env("MyEnv-v0", n_envs=n_cpus, vec_env_cls=SubprocVecEnv)
```
### Memory Considerations
```python
# Large replay buffer + many environments = high memory usage
# Reduce buffer size if memory constrained
model = SAC(
    "MlpPolicy",
    env,
    buffer_size=100_000,  # Reduced from 1M
)
```
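A back-of-the-envelope estimate helps when sizing the buffer; the sketch below assumes float32 storage of observation, next observation, action, reward, and done per transition and ignores SB3 internals such as `optimize_memory_usage`:
```python
# Rough replay buffer memory estimate (an approximation, not SB3's exact accounting)
buffer_size = 100_000
obs_dim, action_dim = 3, 1   # e.g. Pendulum-v1
bytes_per_value = 4          # float32
per_transition = bytes_per_value * (2 * obs_dim + action_dim + 2)  # obs, next_obs, action, reward, done
print(f"~{buffer_size * per_transition / 1e6:.1f} MB")  # a few MB here; image observations are far larger
```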
## Common Issues
### Issue: "Can't pickle local object"
**Cause:** SubprocVecEnv requires picklable environments.
**Solution:** Define environment creation outside class/function:
```python
# Bad: make_env is a closure inside train() and cannot be pickled
def train():
    def make_env():
        return gym.make("CartPole-v1")
    env = SubprocVecEnv([make_env for _ in range(4)])

# Good: make_env is defined at module level, so it can be pickled
def make_env():
    return gym.make("CartPole-v1")

if __name__ == "__main__":
    env = SubprocVecEnv([make_env for _ in range(4)])
```
### Issue: Different behavior between single and vectorized env
**Cause:** Auto-reset in vectorized environments.
**Solution:** Handle terminal observations correctly:
```python
obs, rewards, dones, infos = env.step(actions)
for i, done in enumerate(dones):
    if done:
        terminal_obs = infos[i]["terminal_observation"]
        # Process terminal_obs if needed
```
### Issue: Slower with SubprocVecEnv than DummyVecEnv
**Cause:** Environment too lightweight (multiprocessing overhead > computation).
**Solution:** Use DummyVecEnv for simple environments:
```python
# For CartPole, use DummyVecEnv
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DummyVecEnv)
```
### Issue: Training crashes with SubprocVecEnv
**Cause:** Environment not properly isolated or has shared state.
**Solution:**
- Ensure environment has no shared global state
- Wrap code in `if __name__ == "__main__":`
- Use DummyVecEnv for debugging
## Best Practices
1. **Use appropriate VecEnv type:**
   - DummyVecEnv: Simple environments (CartPole, basic grids)
   - SubprocVecEnv: Complex environments (MuJoCo, Unity, 3D games)
2. **Adjust hyperparameters for vectorization:**
   - Divide `eval_freq` and `save_freq` by `n_envs` in callbacks (see the callback sketch after this list)
   - Keep `n_steps * n_envs` constant for on-policy algorithms
3. **Save normalization statistics:**
   - Always save VecNormalize stats alongside the model
   - Disable training and reward normalization during evaluation
4. **Monitor memory usage:**
   - More environments means more memory
   - Reduce buffer size if needed
5. **Test with DummyVecEnv first:**
   - Easier debugging
   - Ensure the environment works before parallelizing
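For the callback frequencies mentioned in point 2, a minimal sketch using `EvalCallback` (`eval_freq` counts calls to `env.step()`, i.e. one call per `n_envs` timesteps):
```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env

n_envs = 8
env = make_vec_env("CartPole-v1", n_envs=n_envs)
eval_env = make_vec_env("CartPole-v1", n_envs=1)

eval_callback = EvalCallback(
    eval_env,
    eval_freq=10_000 // n_envs,  # evaluate roughly every 10k total timesteps
    n_eval_episodes=5,
    best_model_save_path="./logs/",
)

model = PPO("MlpPolicy", env)
model.learn(total_timesteps=100_000, callback=eval_callback)
```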
## Examples
### Basic Training Loop
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
# Create vectorized environment
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
# Train
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000)
# Evaluate
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
```
### With Normalization
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
# Create and normalize
env = make_vec_env("Pendulum-v1", n_envs=4)
env = VecNormalize(env, norm_obs=True, norm_reward=True)
# Train
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=50000)
# Save both
model.save("model")
env.save("vec_normalize.pkl")
# Load for evaluation
eval_env = make_vec_env("Pendulum-v1", n_envs=1)
eval_env = VecNormalize.load("vec_normalize.pkl", eval_env)
eval_env.training = False
eval_env.norm_reward = False
model = PPO.load("model", env=eval_env)
```
## Additional Resources
- Official SB3 VecEnv Guide: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
- VecEnv API Reference: https://stable-baselines3.readthedocs.io/en/master/common/vec_env.html
- Multiprocessing Best Practices: https://docs.python.org/3/library/multiprocessing.html