Creating Custom Environments for Stable Baselines3
This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.
Environment Structure
Required Methods
Every custom environment must inherit from gymnasium.Env and implement:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class CustomEnv(gym.Env):
    def __init__(self):
        """Initialize environment, define action_space and observation_space"""
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset environment to initial state"""
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Execute one timestep"""
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # Episode ended naturally
        truncated = False   # Episode ended due to time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        """Visualize environment (optional)"""
        pass

    def close(self):
        """Cleanup resources (optional)"""
        pass
Method Details
__init__(self, ...)
Purpose: Initialize the environment and define spaces.
Requirements:
- Must call super().__init__()
- Must define self.action_space
- Must define self.observation_space
Example:
def __init__(self, grid_size=10, max_steps=100):
    super().__init__()
    self.grid_size = grid_size
    self.max_steps = max_steps
    self.current_step = 0

    # Define spaces
    self.action_space = spaces.Discrete(4)
    self.observation_space = spaces.Box(
        low=0, high=grid_size - 1, shape=(2,), dtype=np.float32
    )
reset(self, seed=None, options=None)
Purpose: Reset the environment to an initial state.
Requirements:
- Must call super().reset(seed=seed)
- Must return an (observation, info) tuple
- Observation must match observation_space
- Info must be a dictionary (can be empty)
Example:
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Initialize state
    self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.current_step = 0

    observation = self._get_observation()
    info = {"episode": "started"}
    return observation, info
step(self, action)
Purpose: Execute one timestep in the environment.
Requirements:
- Must return a 5-tuple: (observation, reward, terminated, truncated, info)
- Action must be valid according to action_space
- Observation must match observation_space
- Reward should be a float
- Terminated: True if episode ended naturally (goal reached, failure, etc.)
- Truncated: True if episode ended due to time limit
- Info must be a dictionary
Example:
def step(self, action):
    # Apply action
    self.agent_pos += self._action_to_direction(action)
    self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
    self.current_step += 1

    # Calculate reward
    distance = np.linalg.norm(self.agent_pos - self.goal_pos)
    goal_reached = distance < 1.0
    if goal_reached:
        reward = 100.0
    else:
        reward = -distance * 0.1

    # Check termination conditions
    terminated = goal_reached
    truncated = self.current_step >= self.max_steps

    observation = self._get_observation()
    info = {"distance": distance, "steps": self.current_step}
    return observation, reward, terminated, truncated, info
Space Types
Discrete
For discrete actions (e.g., {0, 1, 2, 3}).
self.action_space = spaces.Discrete(4) # 4 actions: 0, 1, 2, 3
Important: SB3 does NOT support Discrete spaces with start != 0. Always start from 0.
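For contrast, a minimal sketch of the kind of definition SB3 rejects (Gymnasium's Discrete does accept a start argument, but SB3 expects it to remain 0):
# NOT supported by SB3: actions would be 1, 2, 3, 4 instead of starting at 0
# self.action_space = spaces.Discrete(4, start=1)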
Box (Continuous)
For continuous values within a range.
# 1D continuous action in [-1, 1]
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
# 2D position observation
self.observation_space = spaces.Box(
    low=0, high=10, shape=(2,), dtype=np.float32
)

# 3D RGB image (channel-first format)
self.observation_space = spaces.Box(
    low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
)
Important for Images:
- Must be dtype=np.uint8 in range [0, 255]
- Use channel-first format: (channels, height, width)
- SB3 automatically normalizes by dividing by 255
- Set normalize_images=False in policy_kwargs if observations are pre-normalized
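A minimal sketch of that last flag, assuming env is an image-based environment instance that already returns scaled observations (normalize_images is a standard SB3 policy kwarg):
from stable_baselines3 import PPO

# Tell the policy not to divide image observations by 255 (they are already scaled)
model = PPO("CnnPolicy", env, policy_kwargs=dict(normalize_images=False))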
MultiDiscrete
For multiple discrete variables.
# Two discrete variables: first with 3 options, second with 4 options
self.action_space = spaces.MultiDiscrete([3, 4])
MultiBinary
For binary vectors.
# 5 binary flags
self.action_space = spaces.MultiBinary(5) # e.g., [0, 1, 1, 0, 1]
Dict
For dictionary observations (e.g., combining image with sensors).
self.observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
    "vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})
Important: When using Dict observations, use "MultiInputPolicy" instead of "MlpPolicy".
model = PPO("MultiInputPolicy", env, verbose=1)
Tuple
For tuple observations (less common).
self.observation_space = spaces.Tuple((
    spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    spaces.Discrete(3),
))
Important Constraints and Best Practices
Data Types
- Observations: Use np.float32 for continuous values
- Images: Use np.uint8 in range [0, 255]
- Rewards: Return a Python float or np.float32
- Terminated/Truncated: Return Python bool
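A small sketch of the observation dtype point, assuming hypothetical agent_x and agent_y attributes:
def _get_observation(self):
    # Explicit float32 cast so the observation matches a float32 Box space
    return np.array([self.agent_x, self.agent_y], dtype=np.float32)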
Random Number Generation
Always use self.np_random for reproducibility:
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Use self.np_random instead of np.random
    random_pos = self.np_random.integers(0, 10, size=2)
    random_float = self.np_random.random()
Episode Termination
- Terminated: Natural ending (goal reached, agent died, etc.)
- Truncated: Artificial ending (time limit, external interrupt)
def step(self, action):
    # ... environment logic ...

    goal_reached = self._check_goal()
    time_limit_exceeded = self.current_step >= self.max_steps

    terminated = goal_reached        # Natural ending
    truncated = time_limit_exceeded  # Time limit

    return observation, reward, terminated, truncated, info
Info Dictionary
Use the info dict for debugging and logging:
info = {
    "episode_length": self.current_step,
    "distance_to_goal": distance,
    "success": goal_reached,
    "total_reward": self.cumulative_reward,
}
Special Keys:
- "terminal_observation": Automatically added by a VecEnv when an episode ends
Advanced Features
Metadata
Provide rendering information:
class CustomEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # ...
Render Modes
def render(self):
    if self.render_mode == "human":
        # Print or display for human viewing
        print(f"Agent at {self.agent_pos}")
    elif self.render_mode == "rgb_array":
        # Return numpy array (height, width, 3) for video recording
        canvas = np.zeros((500, 500, 3), dtype=np.uint8)
        # Draw environment on canvas
        return canvas
Goal-Conditioned Environments (for HER)
For Hindsight Experience Replay, use specific observation structure:
self.observation_space = spaces.Dict({
    "observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
})

def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER environments (must also handle batched goals)"""
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
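A minimal sketch of plugging such an environment into HER, assuming GoalEnv is a hypothetical goal-conditioned env built with the Dict space and compute_reward above:
from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    GoalEnv(),  # hypothetical goal-conditioned env sketched above
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
)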
Environment Validation
Always validate your environment before training:
from stable_baselines3.common.env_checker import check_env
env = CustomEnv()
check_env(env, warn=True)
Common Validation Errors:
- "Observation is not within bounds"
  - Check that observations stay within the defined space
  - Ensure correct dtype (np.float32 for Box spaces)
- "Reset should return tuple"
  - Return (observation, info), not just the observation
- "Step should return 5-tuple"
  - Return (obs, reward, terminated, truncated, info)
- "Action is out of bounds"
  - Verify the action_space definition matches expected actions
- "Observation/Action dtype mismatch"
  - Ensure observations match the space dtype (usually np.float32)
Environment Registration
Register your environment with Gymnasium:
import gymnasium as gym
from gymnasium.envs.registration import register
register(
    id="MyCustomEnv-v0",
    entry_point="my_module:CustomEnv",
    max_episode_steps=200,
    kwargs={"grid_size": 10},  # Default kwargs
)

# Now can use with gym.make
env = gym.make("MyCustomEnv-v0")
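The max_episode_steps entry makes gym.make apply a TimeLimit wrapper that handles truncation automatically, and the default kwargs can be overridden per call; a small sketch assuming the registration above:
# Override the default grid_size for this instance
env = gym.make("MyCustomEnv-v0", grid_size=20)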
Testing Custom Environments
Basic Testing
def test_environment(env, n_episodes=5):
    """Test environment with random actions"""
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        steps = 0

        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps += 1
            done = terminated or truncated

        print(f"Episode {episode+1}: Reward={episode_reward:.2f}, Steps={steps}")
Training Test
from stable_baselines3 import PPO
def train_test(env, timesteps=10000):
    """Quick training test"""
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=timesteps)

    # Evaluate
    obs, info = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
Common Patterns
Grid World
class GridWorldEnv(gym.Env):
    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.action_space = spaces.Discrete(4)  # up, down, left, right
        self.observation_space = spaces.Box(0, size - 1, shape=(2,), dtype=np.float32)
Continuous Control
class ContinuousEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
Image-Based Environment
class VisionEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        # Channel-first: (channels, height, width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
        )
Multi-Modal Environment
class MultiModalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict({
            "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
            "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
        })
Performance Considerations
Efficient Observation Generation
# Pre-allocate arrays
def __init__(self):
    # ...
    self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)

def _get_observation(self):
    # Reuse buffer instead of allocating a new array
    self._obs_buffer[0] = self.agent_x
    self._obs_buffer[1] = self.agent_y
    return self._obs_buffer
Vectorization
Make environment operations vectorizable:
# Good: Uses numpy operations
def step(self, action):
    direction = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])[action]
    self.pos = np.clip(self.pos + direction, 0, self.size - 1)

# Avoid: Python loops when possible
# for i in range(len(self.agents)):
#     self.agents[i].update()
Troubleshooting
"Observation out of bounds"
- Check that all observations are within defined space
- Verify correct dtype (np.float32 vs np.float64)
"NaN or Inf in observation/reward"
- Add checks: assert np.isfinite(reward)
- Use the VecCheckNan wrapper to catch issues (see the sketch below)
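A minimal sketch of that wrapper, assuming the CustomEnv defined earlier:
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan

vec_env = VecCheckNan(DummyVecEnv([lambda: CustomEnv()]), raise_exception=True)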
"Policy doesn't learn"
- Check reward scaling (normalize rewards)
- Verify observation normalization (see the VecNormalize sketch after this list)
- Ensure reward signal is meaningful
- Check if exploration is sufficient
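A hedged sketch covering both normalization points with SB3's VecNormalize wrapper, assuming the CustomEnv defined earlier:
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Keep running statistics and normalize observations and rewards on the fly
vec_env = VecNormalize(DummyVecEnv([lambda: CustomEnv()]), norm_obs=True, norm_reward=True)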
"Training crashes"
- Validate environment with check_env()
- Check for race conditions in the custom env
- Verify action/observation spaces are consistent
Additional Resources
- Template: See scripts/custom_env_template.py
- Gymnasium Documentation: https://gymnasium.farama.org/
- SB3 Custom Env Guide: https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html