# Creating Custom Environments for Stable Baselines3

This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.

## Environment Structure

### Required Methods

Every custom environment must inherit from `gymnasium.Env` and implement:

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np


class CustomEnv(gym.Env):
    def __init__(self):
        """Initialize environment, define action_space and observation_space"""
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset environment to initial state"""
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Execute one timestep"""
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # Episode ended naturally
        truncated = False   # Episode ended due to time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        """Visualize environment (optional)"""
        pass

    def close(self):
        """Cleanup resources (optional)"""
        pass
```

### Method Details

#### `__init__(self, ...)`

**Purpose:** Initialize the environment and define spaces.

**Requirements:**
- Must call `super().__init__()`
- Must define `self.action_space`
- Must define `self.observation_space`

**Example:**

```python
def __init__(self, grid_size=10, max_steps=100):
    super().__init__()
    self.grid_size = grid_size
    self.max_steps = max_steps
    self.current_step = 0

    # Define spaces
    self.action_space = spaces.Discrete(4)
    self.observation_space = spaces.Box(
        low=0, high=grid_size - 1, shape=(2,), dtype=np.float32
    )
```

#### `reset(self, seed=None, options=None)`

**Purpose:** Reset the environment to an initial state.

**Requirements:**
- Must call `super().reset(seed=seed)`
- Must return `(observation, info)` tuple
- Observation must match `observation_space`
- Info must be a dictionary (can be empty)

**Example:**

```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Initialize state
    self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.current_step = 0

    observation = self._get_observation()
    info = {"episode": "started"}
    return observation, info
```

#### `step(self, action)`

**Purpose:** Execute one timestep in the environment.

**Requirements:**
- Must return 5-tuple: `(observation, reward, terminated, truncated, info)`
- Action must be valid according to `action_space`
- Observation must match `observation_space`
- Reward should be a float
- Terminated: True if episode ended naturally (goal reached, failure, etc.)
- Truncated: True if episode ended due to time limit
- Info must be a dictionary

**Example:**

```python
def step(self, action):
    # Apply action
    self.agent_pos += self._action_to_direction(action)
    self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
    self.current_step += 1

    # Calculate reward
    distance = np.linalg.norm(self.agent_pos - self.goal_pos)
    goal_reached = distance < 1.0
    if goal_reached:
        reward = 100.0
    else:
        reward = -distance * 0.1

    # Check termination conditions
    terminated = goal_reached
    truncated = self.current_step >= self.max_steps

    observation = self._get_observation()
    info = {"distance": distance, "steps": self.current_step}
    return observation, reward, terminated, truncated, info
```
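The `reset()` and `step()` examples above call two helper methods that are not defined in this guide. A minimal sketch of what they might look like for this grid-world setup (the names `_action_to_direction` and `_get_observation` come from the snippets above; the bodies below are illustrative assumptions):

```python
def _action_to_direction(self, action):
    """Map a discrete action (0-3) to a unit movement on the grid."""
    directions = {
        0: np.array([0, 1]),    # up
        1: np.array([0, -1]),   # down
        2: np.array([1, 0]),    # right
        3: np.array([-1, 0]),   # left
    }
    return directions[int(action)]

def _get_observation(self):
    """Return the agent position in the observation space dtype."""
    # Cast to np.float32 so check_env() does not flag a dtype mismatch
    return self.agent_pos.astype(np.float32)
```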
## Space Types

### Discrete

For discrete actions (e.g., {0, 1, 2, 3}).

```python
self.action_space = spaces.Discrete(4)  # 4 actions: 0, 1, 2, 3
```

**Important:** SB3 does NOT support `Discrete` spaces with `start != 0`. Always start from 0.

### Box (Continuous)

For continuous values within a range.

```python
# 1D continuous action in [-1, 1]
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

# 2D position observation
self.observation_space = spaces.Box(
    low=0, high=10, shape=(2,), dtype=np.float32
)

# 3D RGB image (channel-first format)
self.observation_space = spaces.Box(
    low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
)
```

**Important for Images:**
- Must be `dtype=np.uint8` in range [0, 255]
- Use **channel-first** format: (channels, height, width)
- SB3 automatically normalizes by dividing by 255
- Set `normalize_images=False` in `policy_kwargs` if the images are already pre-normalized

### MultiDiscrete

For multiple discrete variables.

```python
# Two discrete variables: first with 3 options, second with 4 options
self.action_space = spaces.MultiDiscrete([3, 4])
```

### MultiBinary

For binary vectors.

```python
# 5 binary flags
self.action_space = spaces.MultiBinary(5)  # e.g., [0, 1, 1, 0, 1]
```

### Dict

For dictionary observations (e.g., combining an image with sensor readings).

```python
self.observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
    "vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})
```

**Important:** When using Dict observations, use `"MultiInputPolicy"` instead of `"MlpPolicy"`.

```python
model = PPO("MultiInputPolicy", env, verbose=1)
```

### Tuple

For tuple observations (less common). Note that SB3 does not support `Tuple` observation spaces; prefer a `Dict` space (or flatten the components) when training with SB3.

```python
self.observation_space = spaces.Tuple((
    spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    spaces.Discrete(3),
))
```

## Important Constraints and Best Practices

### Data Types

- **Observations:** Use `np.float32` for continuous values
- **Images:** Use `np.uint8` in range [0, 255]
- **Rewards:** Return Python float or `np.float32`
- **Terminated/Truncated:** Return Python bool

### Random Number Generation

Always use `self.np_random` for reproducibility:

```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Use self.np_random instead of np.random
    random_pos = self.np_random.integers(0, 10, size=2)
    random_float = self.np_random.random()
```
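Because `super().reset(seed=seed)` seeds `self.np_random`, two resets with the same seed should produce identical initial observations. A quick sanity check (a minimal sketch; `MyGridEnv` is a hypothetical stand-in for any environment whose state is drawn entirely from `self.np_random`):

```python
import numpy as np

env = MyGridEnv()

obs_a, _ = env.reset(seed=123)
obs_b, _ = env.reset(seed=123)

# Identical seeds should yield identical initial observations
assert np.array_equal(obs_a, obs_b), "Seeded resets should be reproducible"
```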
### Episode Termination

- **Terminated:** Natural ending (goal reached, agent died, etc.)
- **Truncated:** Artificial ending (time limit, external interrupt)

```python
def step(self, action):
    # ... environment logic ...
    goal_reached = self._check_goal()
    time_limit_exceeded = self.current_step >= self.max_steps

    terminated = goal_reached          # Natural ending
    truncated = time_limit_exceeded    # Time limit

    return observation, reward, terminated, truncated, info
```

### Info Dictionary

Use the info dict for debugging and logging:

```python
info = {
    "episode_length": self.current_step,
    "distance_to_goal": distance,
    "success": goal_reached,
    "total_reward": self.cumulative_reward,
}
```

**Special Keys:**
- `"terminal_observation"`: Automatically added by VecEnv when an episode ends

## Advanced Features

### Metadata

Provide rendering information:

```python
class CustomEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # ...
```

### Render Modes

```python
def render(self):
    if self.render_mode == "human":
        # Print or display for human viewing
        print(f"Agent at {self.agent_pos}")
    elif self.render_mode == "rgb_array":
        # Return numpy array (height, width, 3) for video recording
        canvas = np.zeros((500, 500, 3), dtype=np.uint8)
        # Draw environment on canvas
        return canvas
```

### Goal-Conditioned Environments (for HER)

For Hindsight Experience Replay, use this specific observation structure:

```python
self.observation_space = spaces.Dict({
    "observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
})

def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER environments; must accept batched goals."""
    # axis=-1 keeps this correct for both single and batched goal arrays
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
```
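To train with HER, pass `HerReplayBuffer` to an off-policy algorithm. A minimal sketch, assuming a goal-conditioned environment class (hypothetically named `MyGoalEnv` here) that exposes the Dict observation space and `compute_reward` shown above:

```python
from stable_baselines3 import SAC, HerReplayBuffer

env = MyGoalEnv()

model = SAC(
    "MultiInputPolicy",  # required for Dict observations
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```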
**"Observation/Action dtype mismatch"** - Ensure observations match space dtype (usually np.float32) ## Environment Registration Register your environment with Gymnasium: ```python import gymnasium as gym from gymnasium.envs.registration import register register( id="MyCustomEnv-v0", entry_point="my_module:CustomEnv", max_episode_steps=200, kwargs={"grid_size": 10}, # Default kwargs ) # Now can use with gym.make env = gym.make("MyCustomEnv-v0") ``` ## Testing Custom Environments ### Basic Testing ```python def test_environment(env, n_episodes=5): """Test environment with random actions""" for episode in range(n_episodes): obs, info = env.reset() episode_reward = 0 done = False steps = 0 while not done: action = env.action_space.sample() obs, reward, terminated, truncated, info = env.step(action) episode_reward += reward steps += 1 done = terminated or truncated print(f"Episode {episode+1}: Reward={episode_reward:.2f}, Steps={steps}") ``` ### Training Test ```python from stable_baselines3 import PPO def train_test(env, timesteps=10000): """Quick training test""" model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=timesteps) # Evaluate obs, info = env.reset() for _ in range(100): action, _states = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: break ``` ## Common Patterns ### Grid World ```python class GridWorldEnv(gym.Env): def __init__(self, size=10): super().__init__() self.size = size self.action_space = spaces.Discrete(4) # up, down, left, right self.observation_space = spaces.Box(0, size-1, shape=(2,), dtype=np.float32) ``` ### Continuous Control ```python class ContinuousEnv(gym.Env): def __init__(self): super().__init__() self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32) ``` ### Image-Based Environment ```python class VisionEnv(gym.Env): def __init__(self): super().__init__() self.action_space = spaces.Discrete(4) # Channel-first: (channels, height, width) self.observation_space = spaces.Box( low=0, high=255, shape=(3, 84, 84), dtype=np.uint8 ) ``` ### Multi-Modal Environment ```python class MultiModalEnv(gym.Env): def __init__(self): super().__init__() self.action_space = spaces.Discrete(4) self.observation_space = spaces.Dict({ "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8), "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32), }) ``` ## Performance Considerations ### Efficient Observation Generation ```python # Pre-allocate arrays def __init__(self): # ... 
## Testing Custom Environments

### Basic Testing

```python
def test_environment(env, n_episodes=5):
    """Test environment with random actions"""
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        steps = 0

        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps += 1
            done = terminated or truncated

        print(f"Episode {episode + 1}: Reward={episode_reward:.2f}, Steps={steps}")
```

### Training Test

```python
from stable_baselines3 import PPO

def train_test(env, timesteps=10000):
    """Quick training test"""
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=timesteps)

    # Evaluate
    obs, info = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
```

## Common Patterns

### Grid World

```python
class GridWorldEnv(gym.Env):
    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.action_space = spaces.Discrete(4)  # up, down, left, right
        self.observation_space = spaces.Box(0, size - 1, shape=(2,), dtype=np.float32)
```

### Continuous Control

```python
class ContinuousEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
```

### Image-Based Environment

```python
class VisionEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        # Channel-first: (channels, height, width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
        )
```

### Multi-Modal Environment

```python
class MultiModalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict({
            "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
            "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
        })
```

## Performance Considerations

### Efficient Observation Generation

```python
# Pre-allocate arrays
def __init__(self):
    # ...
    self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)

def _get_observation(self):
    # Reuse buffer instead of allocating a new array
    self._obs_buffer[0] = self.agent_x
    self._obs_buffer[1] = self.agent_y
    return self._obs_buffer
```

### Vectorization

Make environment operations vectorizable:

```python
# Good: uses numpy operations
def step(self, action):
    direction = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])[action]
    self.pos = np.clip(self.pos + direction, 0, self.size - 1)

# Avoid: Python loops when possible
# for i in range(len(self.agents)):
#     self.agents[i].update()
```

## Troubleshooting

### "Observation out of bounds"
- Check that all observations are within the defined space
- Verify correct dtype (np.float32 vs np.float64)

### "NaN or Inf in observation/reward"
- Add checks: `assert np.isfinite(reward)`
- Use the `VecCheckNan` wrapper to catch issues

### "Policy doesn't learn"
- Check reward scaling (normalize rewards)
- Verify observation normalization
- Ensure the reward signal is meaningful
- Check if exploration is sufficient

### "Training crashes"
- Validate the environment with `check_env()`
- Check for race conditions in the custom env
- Verify action/observation spaces are consistent

## Additional Resources

- Template: See `scripts/custom_env_template.py`
- Gymnasium Documentation: https://gymnasium.farama.org/
- SB3 Custom Env Guide: https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html
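## Putting It All Together

Finally, a minimal end-to-end sketch that ties the pieces together: create the registered environment, validate it, train briefly, and evaluate. It assumes `MyCustomEnv-v0` was registered as shown in the Registration section (so the `my_module` entry point is importable); `Monitor` and `evaluate_policy` are SB3's built-in episode-statistics wrapper and evaluation helper.

```python
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# 1. Create the registered environment (truncation comes from
#    max_episode_steps in the register() call)
env = gym.make("MyCustomEnv-v0")

# 2. Validate it before training
check_env(env, warn=True)

# 3. Train a quick baseline (Monitor records episode statistics)
env = Monitor(env)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

# 4. Evaluate the trained policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
```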