""" Template for creating custom Gymnasium environments compatible with Stable Baselines3. This template demonstrates: - Proper Gymnasium environment structure - Observation and action space definition - Step and reset implementation - Validation with SB3's env_checker - Registration with Gymnasium """ import gymnasium as gym from gymnasium import spaces import numpy as np class CustomEnv(gym.Env): """ Custom Gymnasium Environment Template. This is a template for creating custom environments that work with Stable Baselines3. Modify the observation space, action space, reward function, and state transitions to match your specific problem. Example: A simple grid world where the agent tries to reach a goal position. """ # Optional: Provide metadata for rendering modes metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} def __init__(self, grid_size=5, render_mode=None): """ Initialize the environment. Args: grid_size: Size of the grid world (grid_size x grid_size) render_mode: How to render ('human', 'rgb_array', or None) """ super().__init__() self.grid_size = grid_size self.render_mode = render_mode # Define action space # Example: 4 discrete actions (up, down, left, right) self.action_space = spaces.Discrete(4) # Define observation space # Example: 2D position [x, y] in continuous space # Note: Use np.float32 for observations (SB3 recommendation) self.observation_space = spaces.Box( low=0, high=grid_size - 1, shape=(2,), dtype=np.float32, ) # Alternative observation spaces: # 1. Discrete: spaces.Discrete(n) # 2. Multi-discrete: spaces.MultiDiscrete([n1, n2, ...]) # 3. Multi-binary: spaces.MultiBinary(n) # 4. Box (continuous): spaces.Box(low=, high=, shape=, dtype=np.float32) # 5. Dict: spaces.Dict({"key1": space1, "key2": space2}) # For image observations (e.g., 84x84 RGB image): # self.observation_space = spaces.Box( # low=0, # high=255, # shape=(3, 84, 84), # (channels, height, width) - channel-first # dtype=np.uint8, # ) # Initialize state self._agent_position = None self._goal_position = None def reset(self, seed=None, options=None): """ Reset the environment to initial state. Args: seed: Random seed for reproducibility options: Additional options (optional) Returns: observation: Initial observation info: Additional information dictionary """ # Set seed for reproducibility super().reset(seed=seed) # Initialize agent position randomly self._agent_position = self.np_random.integers(0, self.grid_size, size=2) # Initialize goal position (different from agent) self._goal_position = self.np_random.integers(0, self.grid_size, size=2) while np.array_equal(self._agent_position, self._goal_position): self._goal_position = self.np_random.integers(0, self.grid_size, size=2) observation = self._get_obs() info = self._get_info() return observation, info def step(self, action): """ Execute one step in the environment. Args: action: Action to take Returns: observation: New observation reward: Reward for this step terminated: Whether episode has ended (goal reached) truncated: Whether episode was truncated (time limit, etc.) info: Additional information dictionary """ # Map action to direction (0: up, 1: down, 2: left, 3: right) direction = np.array([ [-1, 0], # up [1, 0], # down [0, -1], # left [0, 1], # right ])[action] # Update agent position (clip to stay within grid) self._agent_position = np.clip( self._agent_position + direction, 0, self.grid_size - 1, ) # Check if goal is reached terminated = np.array_equal(self._agent_position, self._goal_position) # Calculate reward if terminated: reward = 1.0 # Goal reached else: # Negative reward based on distance to goal (encourages efficiency) distance = np.linalg.norm(self._agent_position - self._goal_position) reward = -0.1 * distance # Episode not truncated in this example (no time limit) truncated = False observation = self._get_obs() info = self._get_info() return observation, reward, terminated, truncated, info def _get_obs(self): """ Get current observation. Returns: observation: Current state as defined by observation_space """ # Return agent position as observation return self._agent_position.astype(np.float32) # For dict observations: # return { # "agent": self._agent_position.astype(np.float32), # "goal": self._goal_position.astype(np.float32), # } def _get_info(self): """ Get additional information (for debugging/logging). Returns: info: Dictionary with additional information """ return { "agent_position": self._agent_position, "goal_position": self._goal_position, "distance_to_goal": np.linalg.norm( self._agent_position - self._goal_position ), } def render(self): """ Render the environment. Returns: Rendered frame (if render_mode is 'rgb_array') """ if self.render_mode == "human": # Print simple text-based rendering grid = np.zeros((self.grid_size, self.grid_size), dtype=str) grid[:, :] = "." grid[tuple(self._agent_position)] = "A" grid[tuple(self._goal_position)] = "G" print("\n" + "=" * (self.grid_size * 2 + 1)) for row in grid: print(" ".join(row)) print("=" * (self.grid_size * 2 + 1) + "\n") elif self.render_mode == "rgb_array": # Return RGB array for video recording # This is a placeholder - implement proper rendering as needed canvas = np.zeros(( self.grid_size * 50, self.grid_size * 50, 3 ), dtype=np.uint8) # Draw agent and goal on canvas # ... (implement visual rendering) return canvas def close(self): """ Clean up environment resources. """ pass # Optional: Register the environment with Gymnasium # This allows creating the environment with gym.make("CustomEnv-v0") gym.register( id="CustomEnv-v0", entry_point=__name__ + ":CustomEnv", max_episode_steps=100, ) def validate_environment(): """ Validate the custom environment with SB3's env_checker. """ from stable_baselines3.common.env_checker import check_env print("Validating custom environment...") env = CustomEnv() check_env(env, warn=True) print("Environment validation passed!") def test_environment(): """ Test the custom environment with random actions. """ print("Testing environment with random actions...") env = CustomEnv(render_mode="human") obs, info = env.reset() print(f"Initial observation: {obs}") print(f"Initial info: {info}") for step in range(10): action = env.action_space.sample() # Random action obs, reward, terminated, truncated, info = env.step(action) print(f"\nStep {step + 1}:") print(f" Action: {action}") print(f" Observation: {obs}") print(f" Reward: {reward:.3f}") print(f" Terminated: {terminated}") print(f" Info: {info}") env.render() if terminated or truncated: print("Episode finished!") break env.close() def train_on_custom_env(): """ Train a PPO agent on the custom environment. """ from stable_baselines3 import PPO print("Training PPO agent on custom environment...") # Create environment env = CustomEnv() # Validate first from stable_baselines3.common.env_checker import check_env check_env(env, warn=True) # Train agent model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000) # Test trained agent obs, info = env.reset() for _ in range(20): action, _states = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: print(f"Goal reached! Final reward: {reward}") break env.close() if __name__ == "__main__": # Validate the environment validate_environment() # Test with random actions # test_environment() # Train an agent # train_on_custom_env()