# Creating Custom Environments for Stable Baselines3

This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.

## Environment Structure

### Required Methods

Every custom environment must inherit from `gymnasium.Env` and implement:

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np


class CustomEnv(gym.Env):
    def __init__(self):
        """Initialize environment, define action_space and observation_space"""
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset environment to initial state"""
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Execute one timestep"""
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # Episode ended naturally
        truncated = False  # Episode ended due to time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        """Visualize environment (optional)"""
        pass

    def close(self):
        """Cleanup resources (optional)"""
        pass
```
### Method Details

#### `__init__(self, ...)`

**Purpose:** Initialize the environment and define spaces.

**Requirements:**
- Must call `super().__init__()`
- Must define `self.action_space`
- Must define `self.observation_space`

**Example:**
```python
def __init__(self, grid_size=10, max_steps=100):
    super().__init__()
    self.grid_size = grid_size
    self.max_steps = max_steps
    self.current_step = 0

    # Define spaces
    self.action_space = spaces.Discrete(4)
    self.observation_space = spaces.Box(
        low=0, high=grid_size - 1, shape=(2,), dtype=np.float32
    )
```
#### `reset(self, seed=None, options=None)`

**Purpose:** Reset the environment to an initial state.

**Requirements:**
- Must call `super().reset(seed=seed)`
- Must return an `(observation, info)` tuple
- Observation must match `observation_space`
- Info must be a dictionary (can be empty)

**Example:**
```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Initialize state
    self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.current_step = 0

    observation = self._get_observation()
    info = {"episode": "started"}

    return observation, info
```
#### `step(self, action)`

**Purpose:** Execute one timestep in the environment.

**Requirements:**
- Must return a 5-tuple: `(observation, reward, terminated, truncated, info)`
- Action must be valid according to `action_space`
- Observation must match `observation_space`
- Reward should be a float
- Terminated: True if the episode ended naturally (goal reached, failure, etc.)
- Truncated: True if the episode ended due to a time limit
- Info must be a dictionary

**Example:**
```python
def step(self, action):
    # Apply action
    self.agent_pos += self._action_to_direction(action)
    self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
    self.current_step += 1

    # Calculate reward
    distance = np.linalg.norm(self.agent_pos - self.goal_pos)
    goal_reached = bool(distance < 1.0)  # cast to a plain Python bool

    if goal_reached:
        reward = 100.0
    else:
        reward = -distance * 0.1

    # Check termination conditions
    terminated = goal_reached
    truncated = self.current_step >= self.max_steps

    observation = self._get_observation()
    info = {"distance": distance, "steps": self.current_step}

    return observation, reward, terminated, truncated, info
```
## Space Types

### Discrete

For discrete actions (e.g., {0, 1, 2, 3}).

```python
self.action_space = spaces.Discrete(4)  # 4 actions: 0, 1, 2, 3
```

**Important:** SB3 does NOT support `Discrete` spaces with `start != 0`. Always start from 0.
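
If the underlying problem labels its actions starting from something other than 0, keep the space zero-based and translate inside `step()`. A minimal sketch (the 1-based domain labels here are illustrative, not part of the original guide):

```python
self.action_space = spaces.Discrete(4)  # the agent always chooses from {0, 1, 2, 3}

def step(self, action):
    domain_action = action + 1  # map to a domain that labels its actions 1..4
    # ... apply domain_action to the environment state ...
```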
### Box (Continuous)

For continuous values within a range.

```python
# 1D continuous action in [-1, 1]
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

# 2D position observation
self.observation_space = spaces.Box(
    low=0, high=10, shape=(2,), dtype=np.float32
)

# 3D RGB image (channel-first format)
self.observation_space = spaces.Box(
    low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
)
```

**Important for Images:**
- Must be `dtype=np.uint8` in range [0, 255]
- Use **channel-first** format: (channels, height, width)
- SB3 automatically normalizes by dividing by 255
- Set `normalize_images=False` in `policy_kwargs` if observations are pre-normalized (see the sketch below)
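
A minimal sketch of passing that flag through `policy_kwargs`; the hypothetical `VisionEnv` stands in for any image environment whose observations should not be divided by 255 again:

```python
from stable_baselines3 import PPO

env = VisionEnv()  # hypothetical image-based environment (see the pattern later in this guide)

model = PPO(
    "CnnPolicy",
    env,
    policy_kwargs=dict(normalize_images=False),  # skip SB3's built-in division by 255
    verbose=1,
)
```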
### MultiDiscrete

For multiple discrete variables.

```python
# Two discrete variables: first with 3 options, second with 4 options
self.action_space = spaces.MultiDiscrete([3, 4])
```

### MultiBinary

For binary vectors.

```python
# 5 binary flags
self.action_space = spaces.MultiBinary(5)  # e.g., [0, 1, 1, 0, 1]
```

### Dict

For dictionary observations (e.g., combining an image with sensor readings).

```python
self.observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
    "vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})
```

**Important:** When using Dict observations, use `"MultiInputPolicy"` instead of `"MlpPolicy"`.

```python
model = PPO("MultiInputPolicy", env, verbose=1)
```

### Tuple

For tuple observations (less common). Note that SB3 does not support `Tuple` observation spaces; prefer a `Dict` space instead.

```python
self.observation_space = spaces.Tuple((
    spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    spaces.Discrete(3),
))
```
## Important Constraints and Best Practices

### Data Types

- **Observations:** Use `np.float32` for continuous values
- **Images:** Use `np.uint8` in range [0, 255]
- **Rewards:** Return a Python float or `np.float32`
- **Terminated/Truncated:** Return Python bools (see the sketch below)
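
A sketch of what these constraints look like at the point where values leave the environment; the `_get_observation` helper and the grid-world state variables are illustrative:

```python
def _get_observation(self):
    # Build the observation with the exact dtype declared in observation_space
    return np.array([self.agent_x, self.agent_y], dtype=np.float32)

def step(self, action):
    # ... environment logic ...
    obs = self._get_observation()
    reward = -0.1                                  # plain Python float
    terminated = bool(self._check_goal())          # plain Python bool, not np.bool_
    truncated = bool(self.current_step >= self.max_steps)
    return obs, reward, terminated, truncated, {}
```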
### Random Number Generation

Always use `self.np_random` for reproducibility:

```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    # Use self.np_random instead of np.random
    random_pos = self.np_random.integers(0, 10, size=2)
    random_float = self.np_random.random()
```

### Episode Termination

- **Terminated:** Natural ending (goal reached, agent died, etc.)
- **Truncated:** Artificial ending (time limit, external interrupt)

```python
def step(self, action):
    # ... environment logic ...

    goal_reached = self._check_goal()
    time_limit_exceeded = self.current_step >= self.max_steps

    terminated = goal_reached  # Natural ending
    truncated = time_limit_exceeded  # Time limit

    return observation, reward, terminated, truncated, info
```

### Info Dictionary

Use the info dict for debugging and logging:

```python
info = {
    "episode_length": self.current_step,
    "distance_to_goal": distance,
    "success": goal_reached,
    "total_reward": self.cumulative_reward,
}
```

**Special Keys:**
- `"terminal_observation"`: Automatically added by VecEnv when an episode ends (see the sketch below)
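
Because SB3's vectorized environments reset automatically when an episode ends, the observation returned by `step()` at that point already belongs to the next episode; the true final observation is stored under this key. A minimal sketch of reading it, reusing the `CustomEnv` from the first example:

```python
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: CustomEnv()])
obs = vec_env.reset()

for _ in range(1000):
    actions = [vec_env.action_space.sample()]
    obs, rewards, dones, infos = vec_env.step(actions)
    if dones[0]:
        final_obs = infos[0]["terminal_observation"]  # real last observation of the episode
        # `obs` is already the first observation of the next episode
```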
## Advanced Features

### Metadata

Provide rendering information:

```python
class CustomEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # ...
```

### Render Modes

```python
def render(self):
    if self.render_mode == "human":
        # Print or display for human viewing
        print(f"Agent at {self.agent_pos}")

    elif self.render_mode == "rgb_array":
        # Return numpy array (height, width, 3) for video recording
        canvas = np.zeros((500, 500, 3), dtype=np.uint8)
        # Draw environment on canvas
        return canvas
```

### Goal-Conditioned Environments (for HER)

For Hindsight Experience Replay, use this specific observation structure:

```python
self.observation_space = spaces.Dict({
    "observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
})

def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER environments; must also handle batches of goals"""
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
```
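
To actually train with HER, a goal-conditioned environment like the one above is paired with an off-policy algorithm and `HerReplayBuffer`. A minimal sketch; the `GoalEnv` name and the hyperparameter values are illustrative:

```python
from stable_baselines3 import SAC, HerReplayBuffer

env = GoalEnv()  # hypothetical goal-conditioned environment structured as shown above

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```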
## Environment Validation

Always validate your environment before training:

```python
from stable_baselines3.common.env_checker import check_env

env = CustomEnv()
check_env(env, warn=True)
```

**Common Validation Errors:**

1. **"Observation is not within bounds"**
   - Check that observations stay within the defined space
   - Ensure the correct dtype (np.float32 for Box spaces)

2. **"Reset should return tuple"**
   - Return `(observation, info)`, not just the observation

3. **"Step should return 5-tuple"**
   - Return `(obs, reward, terminated, truncated, info)`

4. **"Action is out of bounds"**
   - Verify the action_space definition matches the expected actions

5. **"Observation/Action dtype mismatch"**
   - Ensure observations match the space dtype (usually np.float32); a quick way to check is sketched below
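
When chasing down bounds or dtype errors, it can help to probe a single observation directly before re-running `check_env`. A small sketch using only the Gymnasium space API:

```python
env = CustomEnv()
obs, _ = env.reset()

print(obs.dtype, obs.shape)                 # compare with env.observation_space
print(env.observation_space.contains(obs))  # False -> out of bounds, wrong shape, or wrong dtype

# A common fix is to cast explicitly where the observation is built:
obs = obs.astype(np.float32)
```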
## Environment Registration

Register your environment with Gymnasium:

```python
import gymnasium as gym
from gymnasium.envs.registration import register

register(
    id="MyCustomEnv-v0",
    entry_point="my_module:CustomEnv",
    max_episode_steps=200,
    kwargs={"grid_size": 10},  # Default kwargs
)

# Now can use with gym.make
env = gym.make("MyCustomEnv-v0")
```
## Testing Custom Environments

### Basic Testing

```python
def test_environment(env, n_episodes=5):
    """Test environment with random actions"""
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        steps = 0

        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps += 1
            done = terminated or truncated

        print(f"Episode {episode+1}: Reward={episode_reward:.2f}, Steps={steps}")
```

### Training Test

```python
from stable_baselines3 import PPO

def train_test(env, timesteps=10000):
    """Quick training test"""
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=timesteps)

    # Evaluate
    obs, info = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
```
## Common Patterns

### Grid World

```python
class GridWorldEnv(gym.Env):
    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.action_space = spaces.Discrete(4)  # up, down, left, right
        self.observation_space = spaces.Box(0, size - 1, shape=(2,), dtype=np.float32)
```

### Continuous Control

```python
class ContinuousEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
```

### Image-Based Environment

```python
class VisionEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        # Channel-first: (channels, height, width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
        )
```

### Multi-Modal Environment

```python
class MultiModalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict({
            "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
            "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
        })
```
## Performance Considerations

### Efficient Observation Generation

```python
# Pre-allocate arrays
def __init__(self):
    # ...
    self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)

def _get_observation(self):
    # Reuse buffer instead of allocating a new array
    self._obs_buffer[0] = self.agent_x
    self._obs_buffer[1] = self.agent_y
    return self._obs_buffer
```

### Vectorization

Make environment operations vectorizable:

```python
# Good: uses numpy operations
def step(self, action):
    direction = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])[action]
    self.pos = np.clip(self.pos + direction, 0, self.size - 1)

# Avoid: Python loops when possible
# for i in range(len(self.agents)):
#     self.agents[i].update()
```
## Troubleshooting

### "Observation out of bounds"
- Check that all observations are within the defined space
- Verify the correct dtype (np.float32 vs np.float64)

### "NaN or Inf in observation/reward"
- Add checks: `assert np.isfinite(reward)`
- Use the `VecCheckNan` wrapper to catch issues early (see the sketch below)
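
A minimal sketch of wrapping a custom environment with `VecCheckNan` so that the first NaN or Inf raises immediately instead of silently corrupting training; the `CustomEnv` name follows the earlier examples:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan

vec_env = VecCheckNan(DummyVecEnv([lambda: CustomEnv()]), raise_exception=True)

model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=10_000)
```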
### "Policy doesn't learn"
- Check reward scaling (normalize rewards; see the `VecNormalize` sketch below)
- Verify observation normalization
- Ensure the reward signal is meaningful
- Check whether exploration is sufficient
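
For the first two points, SB3 ships a `VecNormalize` wrapper that keeps running statistics and normalizes observations and rewards on the fly. A minimal sketch, assuming the `CustomEnv` from earlier:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

vec_env = DummyVecEnv([lambda: CustomEnv()])
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=50_000)

# Save the normalization statistics together with the model
vec_env.save("vec_normalize.pkl")
```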
### "Training crashes"
- Validate the environment with `check_env()`
- Check for race conditions in the custom environment
- Verify that action/observation spaces are consistent

## Additional Resources

- Template: See `scripts/custom_env_template.py`
- Gymnasium Documentation: https://gymnasium.farama.org/
- SB3 Custom Env Guide: https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html