
Creating Custom Environments for Stable Baselines3

This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.

Environment Structure

Required Methods

Every custom environment must inherit from gymnasium.Env and implement:

import gymnasium as gym
from gymnasium import spaces
import numpy as np

class CustomEnv(gym.Env):
    def __init__(self):
        """Initialize environment, define action_space and observation_space"""
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset environment to initial state"""
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Execute one timestep"""
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # Episode ended naturally
        truncated = False   # Episode ended due to time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        """Visualize environment (optional)"""
        pass

    def close(self):
        """Cleanup resources (optional)"""
        pass

Method Details

__init__(self, ...)

Purpose: Initialize the environment and define spaces.

Requirements:

  • Must call super().__init__()
  • Must define self.action_space
  • Must define self.observation_space

Example:

def __init__(self, grid_size=10, max_steps=100):
    super().__init__()
    self.grid_size = grid_size
    self.max_steps = max_steps
    self.current_step = 0

    # Define spaces
    self.action_space = spaces.Discrete(4)
    self.observation_space = spaces.Box(
        low=0, high=grid_size-1, shape=(2,), dtype=np.float32
    )

reset(self, seed=None, options=None)

Purpose: Reset the environment to an initial state.

Requirements:

  • Must call super().reset(seed=seed)
  • Must return (observation, info) tuple
  • Observation must match observation_space
  • Info must be a dictionary (can be empty)

Example:

def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Initialize state
    self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.current_step = 0

    observation = self._get_observation()
    info = {"episode": "started"}

    return observation, info

step(self, action)

Purpose: Execute one timestep in the environment.

Requirements:

  • Must return 5-tuple: (observation, reward, terminated, truncated, info)
  • Action must be valid according to action_space
  • Observation must match observation_space
  • Reward should be a float
  • Terminated: True if episode ended naturally (goal reached, failure, etc.)
  • Truncated: True if episode ended due to time limit
  • Info must be a dictionary

Example:

def step(self, action):
    # Apply action
    self.agent_pos += self._action_to_direction(action)
    self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
    self.current_step += 1

    # Calculate reward
    distance = np.linalg.norm(self.agent_pos - self.goal_pos)
    goal_reached = bool(distance < 1.0)  # cast to a Python bool, as SB3 expects for terminated

    if goal_reached:
        reward = 100.0
    else:
        reward = -distance * 0.1

    # Check termination conditions
    terminated = goal_reached
    truncated = self.current_step >= self.max_steps

    observation = self._get_observation()
    info = {"distance": distance, "steps": self.current_step}

    return observation, reward, terminated, truncated, info

Space Types

Discrete

For discrete actions (e.g., {0, 1, 2, 3}).

self.action_space = spaces.Discrete(4)  # 4 actions: 0, 1, 2, 3

Important: SB3 does NOT support Discrete spaces with start != 0. Always start from 0.
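If the task's native actions do not start at 0, keep the space zero-based and apply the offset inside step(). A minimal, illustrative sketch (the class name and the native 1-to-4 action range are assumptions, not part of SB3):

import gymnasium as gym
from gymnasium import spaces
import numpy as np

class ShiftedActionEnv(gym.Env):
    """Sketch: the task's native actions are 1..4, but the agent sees 0..3."""
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)  # zero-based, as SB3 requires
        self.observation_space = spaces.Box(0, 1, shape=(1,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        native_action = action + 1  # map 0..3 to the task's native 1..4
        reward = float(native_action == 4)  # toy reward for illustration
        return np.zeros(1, dtype=np.float32), reward, False, False, {}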

Box (Continuous)

For continuous values within a range.

# 1D continuous action in [-1, 1]
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

# 2D position observation
self.observation_space = spaces.Box(
    low=0, high=10, shape=(2,), dtype=np.float32
)

# 3D RGB image (channel-first format)
self.observation_space = spaces.Box(
    low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
)

Important for Images:

  • Must be dtype=np.uint8 in range [0, 255]
  • Use channel-first format: (channels, height, width); a conversion sketch follows this list
  • SB3 automatically normalizes by dividing by 255
  • Set normalize_images=False in policy_kwargs if pre-normalized
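
Most renderers and image libraries produce frames in (height, width, channels) order, so a transpose is usually needed before returning the observation. A minimal sketch (the helper name to_channel_first is illustrative):

import numpy as np

def to_channel_first(frame_hwc):
    """Convert an (H, W, C) uint8 frame to the (C, H, W) layout SB3 expects."""
    assert frame_hwc.dtype == np.uint8
    return np.transpose(frame_hwc, (2, 0, 1))

# Example with a dummy 84x84 RGB frame
obs = to_channel_first(np.zeros((84, 84, 3), dtype=np.uint8))
assert obs.shape == (3, 84, 84)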

MultiDiscrete

For multiple discrete variables.

# Two discrete variables: first with 3 options, second with 4 options
self.action_space = spaces.MultiDiscrete([3, 4])

MultiBinary

For binary vectors.

# 5 binary flags
self.action_space = spaces.MultiBinary(5)  # e.g., [0, 1, 1, 0, 1]

Dict

For dictionary observations (e.g., combining image with sensors).

self.observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
    "vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})

Important: When using Dict observations, use "MultiInputPolicy" instead of "MlpPolicy".

model = PPO("MultiInputPolicy", env, verbose=1)
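
Observations for this space are plain Python dicts whose keys, shapes, and dtypes match the space. A minimal sketch (self.frame, self.sensor_values, and self.mode are illustrative attributes):

def _get_observation(self):
    return {
        "image": self.frame.astype(np.uint8),             # (3, 64, 64) uint8
        "vector": self.sensor_values.astype(np.float32),  # (4,) float32
        "discrete": int(self.mode),                       # scalar in {0, 1, 2}
    }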

Tuple

For tuple observations (less common). Note that SB3 does not support Tuple observation spaces directly; prefer a Dict space with "MultiInputPolicy" when training with SB3.

self.observation_space = spaces.Tuple((
    spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    spaces.Discrete(3),
))

Important Constraints and Best Practices

Data Types

  • Observations: Use np.float32 for continuous values
  • Images: Use np.uint8 in range [0, 255]
  • Rewards: Return Python float or np.float32
  • Terminated/Truncated: Return Python bool

Random Number Generation

Always use self.np_random for reproducibility:

def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    # Use self.np_random instead of np.random
    random_pos = self.np_random.integers(0, 10, size=2)
    random_float = self.np_random.random()
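
When reset() draws all randomness from self.np_random (as in the grid-world example above), the same seed reproduces the same initial state. A quick check sketch:

env = CustomEnv()
obs_a, _ = env.reset(seed=42)
obs_b, _ = env.reset(seed=42)
assert np.allclose(obs_a, obs_b)  # identical initial observation for the same seed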

Episode Termination

  • Terminated: Natural ending (goal reached, agent died, etc.)
  • Truncated: Artificial ending (time limit, external interrupt)

Example:

def step(self, action):
    # ... environment logic ...

    goal_reached = self._check_goal()
    time_limit_exceeded = self.current_step >= self.max_steps

    terminated = goal_reached  # Natural ending
    truncated = time_limit_exceeded  # Time limit

    return observation, reward, terminated, truncated, info

Info Dictionary

Use the info dict for debugging and logging:

info = {
    "episode_length": self.current_step,
    "distance_to_goal": distance,
    "success": goal_reached,
    "total_reward": self.cumulative_reward,
}

Special Keys:

  • "terminal_observation": Automatically added by VecEnv when episode ends

Advanced Features

Metadata

Provide rendering information:

class CustomEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # ...

Render Modes

def render(self):
    if self.render_mode == "human":
        # Print or display for human viewing
        print(f"Agent at {self.agent_pos}")

    elif self.render_mode == "rgb_array":
        # Return numpy array (height, width, 3) for video recording
        canvas = np.zeros((500, 500, 3), dtype=np.uint8)
        # Draw environment on canvas
        return canvas
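
With render_mode="rgb_array" and render_fps defined in metadata, episodes can be recorded using Gymnasium's RecordVideo wrapper. A minimal sketch (the output folder is illustrative):

from gymnasium.wrappers import RecordVideo

# Wrap an rgb_array-mode environment to save episode videos
env = RecordVideo(CustomEnv(render_mode="rgb_array"), video_folder="./videos")
obs, info = env.reset()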

Goal-Conditioned Environments (for HER)

For Hindsight Experience Replay, use specific observation structure:

self.observation_space = spaces.Dict({
    "observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
})

def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER; called with batches of goals by HerReplayBuffer"""
    # axis=-1 keeps this correct for both single and batched goal arrays
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
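
Such an environment can then be trained with an off-policy algorithm and SB3's HER replay buffer. A minimal sketch (the hyperparameters are illustrative):

from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",  # Dict observations require the multi-input policy
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    verbose=1,
)
model.learn(total_timesteps=10_000)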

Environment Validation

Always validate your environment before training:

from stable_baselines3.common.env_checker import check_env

env = CustomEnv()
check_env(env, warn=True)

Common Validation Errors:

  1. "Observation is not within bounds"

    • Check that observations stay within defined space
    • Ensure correct dtype (np.float32 for Box spaces)
  2. "Reset should return tuple"

    • Return (observation, info), not just observation
  3. "Step should return 5-tuple"

    • Return (obs, reward, terminated, truncated, info)
  4. "Action is out of bounds"

    • Verify action_space definition matches expected actions
  5. "Observation/Action dtype mismatch"

    • Ensure observations match space dtype (usually np.float32)

Environment Registration

Register your environment with Gymnasium:

import gymnasium as gym
from gymnasium.envs.registration import register

register(
    id="MyCustomEnv-v0",
    entry_point="my_module:CustomEnv",
    max_episode_steps=200,
    kwargs={"grid_size": 10},  # Default kwargs
)

# Now can use with gym.make
env = gym.make("MyCustomEnv-v0")
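
Default kwargs can be overridden at construction time, and the registered id also works with SB3's vectorized helper:

from stable_baselines3.common.env_util import make_vec_env

env = gym.make("MyCustomEnv-v0", grid_size=20)       # override the default kwarg
vec_env = make_vec_env("MyCustomEnv-v0", n_envs=4)   # four parallel copies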

Testing Custom Environments

Basic Testing

def test_environment(env, n_episodes=5):
    """Test environment with random actions"""
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        steps = 0

        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps += 1
            done = terminated or truncated

        print(f"Episode {episode+1}: Reward={episode_reward:.2f}, Steps={steps}")

Training Test

from stable_baselines3 import PPO

def train_test(env, timesteps=10000):
    """Quick training test"""
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=timesteps)

    # Evaluate
    obs, info = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
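
Alternatively, SB3's evaluation helper averages the return over several episodes (assuming model and env as above):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")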

Common Patterns

Grid World

class GridWorldEnv(gym.Env):
    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.action_space = spaces.Discrete(4)  # up, down, left, right
        self.observation_space = spaces.Box(0, size-1, shape=(2,), dtype=np.float32)

Continuous Control

class ContinuousEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)

Image-Based Environment

class VisionEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        # Channel-first: (channels, height, width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
        )

Multi-Modal Environment

class MultiModalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict({
            "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
            "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
        })

Performance Considerations

Efficient Observation Generation

# Pre-allocate arrays
def __init__(self):
    # ...
    self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)

def _get_observation(self):
    # Reuse the buffer instead of allocating a new array each step;
    # copy it before storing externally, since the same array is mutated
    self._obs_buffer[0] = self.agent_x
    self._obs_buffer[1] = self.agent_y
    return self._obs_buffer

Vectorization

Make environment operations vectorizable:

# Good: Uses numpy operations
def step(self, action):
    direction = np.array([[0,1], [0,-1], [1,0], [-1,0]])[action]
    self.pos = np.clip(self.pos + direction, 0, self.size-1)

# Avoid: Python loops when possible
# for i in range(len(self.agents)):
#     self.agents[i].update()
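
The lookup table itself can be built once in __init__ rather than recreated on every step. A small sketch of the same grid-world logic (attribute names are illustrative):

# Precompute the action-to-direction table once
def __init__(self, size=10):
    super().__init__()
    self.size = size
    self.pos = np.zeros(2, dtype=np.int64)
    self._directions = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])

def step(self, action):
    # Only an index lookup and a vectorized clip per step
    self.pos = np.clip(self.pos + self._directions[action], 0, self.size - 1)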

Troubleshooting

"Observation out of bounds"

  • Check that all observations are within defined space
  • Verify correct dtype (np.float32 vs np.float64)

"NaN or Inf in observation/reward"

  • Add checks: assert np.isfinite(reward)
  • Use VecCheckNan wrapper to catch issues (see the sketch below)
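
A minimal sketch of wrapping the environment so NaN/Inf values raise an error immediately:

from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan

vec_env = VecCheckNan(DummyVecEnv([lambda: CustomEnv()]), raise_exception=True)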

"Policy doesn't learn"

  • Check reward scaling (normalize rewards)
  • Verify observation normalization (a VecNormalize sketch follows this list)
  • Ensure reward signal is meaningful
  • Check if exploration is sufficient
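
Observation and reward normalization can be handled during training with VecNormalize; a minimal sketch:

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

vec_env = DummyVecEnv([lambda: CustomEnv()])
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True)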

"Training crashes"

  • Validate environment with check_env()
  • Check for race conditions in custom env
  • Verify action/observation spaces are consistent

Additional Resources