Initial commit
314
skills/stable-baselines3/scripts/custom_env_template.py
Normal file
@@ -0,0 +1,314 @@
"""
Template for creating custom Gymnasium environments compatible with Stable Baselines3.

This template demonstrates:
- Proper Gymnasium environment structure
- Observation and action space definition
- Step and reset implementation
- Validation with SB3's env_checker
- Registration with Gymnasium
"""

import gymnasium as gym
from gymnasium import spaces
import numpy as np


class CustomEnv(gym.Env):
    """
    Custom Gymnasium Environment Template.

    This is a template for creating custom environments that work with
    Stable Baselines3. Modify the observation space, action space, reward
    function, and state transitions to match your specific problem.

    Example:
        A simple grid world where the agent tries to reach a goal position.
    """

    # Optional: Provide metadata for rendering modes
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    def __init__(self, grid_size=5, render_mode=None):
        """
        Initialize the environment.

        Args:
            grid_size: Size of the grid world (grid_size x grid_size)
            render_mode: How to render ('human', 'rgb_array', or None)
        """
        super().__init__()

        self.grid_size = grid_size
        self.render_mode = render_mode

        # Define action space
        # Example: 4 discrete actions (up, down, left, right)
        self.action_space = spaces.Discrete(4)

        # Define observation space
        # Example: 2D position [x, y] in continuous space
        # Note: Use np.float32 for observations (SB3 recommendation)
        self.observation_space = spaces.Box(
            low=0,
            high=grid_size - 1,
            shape=(2,),
            dtype=np.float32,
        )

        # Alternative observation spaces:
        # 1. Discrete: spaces.Discrete(n)
        # 2. Multi-discrete: spaces.MultiDiscrete([n1, n2, ...])
        # 3. Multi-binary: spaces.MultiBinary(n)
        # 4. Box (continuous): spaces.Box(low=, high=, shape=, dtype=np.float32)
        # 5. Dict: spaces.Dict({"key1": space1, "key2": space2})

        # For image observations (e.g., 84x84 RGB image):
        # self.observation_space = spaces.Box(
        #     low=0,
        #     high=255,
        #     shape=(3, 84, 84),  # (channels, height, width) - channel-first
        #     dtype=np.uint8,
        # )
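
        # Sketch of a Dict-space variant (hypothetical for this template): combine
        # agent and goal positions into one dictionary observation; with SB3 this
        # would pair with "MultiInputPolicy" instead of "MlpPolicy".
        # self.observation_space = spaces.Dict({
        #     "agent": spaces.Box(0, grid_size - 1, shape=(2,), dtype=np.float32),
        #     "goal": spaces.Box(0, grid_size - 1, shape=(2,), dtype=np.float32),
        # })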

        # Initialize state
        self._agent_position = None
        self._goal_position = None

    def reset(self, seed=None, options=None):
        """
        Reset the environment to initial state.

        Args:
            seed: Random seed for reproducibility
            options: Additional options (optional)

        Returns:
            observation: Initial observation
            info: Additional information dictionary
        """
        # Set seed for reproducibility
        super().reset(seed=seed)

        # Initialize agent position randomly
        self._agent_position = self.np_random.integers(0, self.grid_size, size=2)

        # Initialize goal position (different from agent)
        self._goal_position = self.np_random.integers(0, self.grid_size, size=2)
        while np.array_equal(self._agent_position, self._goal_position):
            self._goal_position = self.np_random.integers(0, self.grid_size, size=2)

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """
        Execute one step in the environment.

        Args:
            action: Action to take

        Returns:
            observation: New observation
            reward: Reward for this step
            terminated: Whether episode has ended (goal reached)
            truncated: Whether episode was truncated (time limit, etc.)
            info: Additional information dictionary
        """
        # Map action to direction (0: up, 1: down, 2: left, 3: right)
        direction = np.array([
            [-1, 0],  # up
            [1, 0],   # down
            [0, -1],  # left
            [0, 1],   # right
        ])[action]

        # Update agent position (clip to stay within grid)
        self._agent_position = np.clip(
            self._agent_position + direction,
            0,
            self.grid_size - 1,
        )

        # Check if goal is reached
        terminated = np.array_equal(self._agent_position, self._goal_position)

        # Calculate reward
        if terminated:
            reward = 1.0  # Goal reached
        else:
            # Negative reward based on distance to goal (encourages efficiency)
            distance = np.linalg.norm(self._agent_position - self._goal_position)
            reward = -0.1 * distance

        # Episode not truncated in this example (no time limit)
        truncated = False

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def _get_obs(self):
        """
        Get current observation.

        Returns:
            observation: Current state as defined by observation_space
        """
        # Return agent position as observation
        return self._agent_position.astype(np.float32)

        # For dict observations:
        # return {
        #     "agent": self._agent_position.astype(np.float32),
        #     "goal": self._goal_position.astype(np.float32),
        # }

    def _get_info(self):
        """
        Get additional information (for debugging/logging).

        Returns:
            info: Dictionary with additional information
        """
        return {
            "agent_position": self._agent_position,
            "goal_position": self._goal_position,
            "distance_to_goal": np.linalg.norm(
                self._agent_position - self._goal_position
            ),
        }

    def render(self):
        """
        Render the environment.

        Returns:
            Rendered frame (if render_mode is 'rgb_array')
        """
        if self.render_mode == "human":
            # Print simple text-based rendering
            grid = np.full((self.grid_size, self.grid_size), ".", dtype="<U1")
            grid[tuple(self._agent_position)] = "A"
            grid[tuple(self._goal_position)] = "G"

            print("\n" + "=" * (self.grid_size * 2 + 1))
            for row in grid:
                print(" ".join(row))
            print("=" * (self.grid_size * 2 + 1) + "\n")

        elif self.render_mode == "rgb_array":
            # Return RGB array for video recording
            # This is a placeholder - implement proper rendering as needed
            canvas = np.zeros(
                (self.grid_size * 50, self.grid_size * 50, 3),
                dtype=np.uint8,
            )
            # Draw agent and goal on canvas
            # ... (implement visual rendering)
            return canvas

    def close(self):
        """
        Clean up environment resources.
        """
        pass


# Optional: Register the environment with Gymnasium
# This allows creating the environment with gym.make("CustomEnv-v0")
gym.register(
    id="CustomEnv-v0",
    entry_point=__name__ + ":CustomEnv",
    max_episode_steps=100,
)
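
# A minimal usage sketch: once registered, the environment can also be created
# through Gymnasium's factory, which applies a TimeLimit wrapper using the
# max_episode_steps value above.
# env = gym.make("CustomEnv-v0")
# obs, info = env.reset(seed=0)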


def validate_environment():
    """
    Validate the custom environment with SB3's env_checker.
    """
    from stable_baselines3.common.env_checker import check_env

    print("Validating custom environment...")
    env = CustomEnv()
    check_env(env, warn=True)
    print("Environment validation passed!")


def test_environment():
    """
    Test the custom environment with random actions.
    """
    print("Testing environment with random actions...")
    env = CustomEnv(render_mode="human")

    obs, info = env.reset()
    print(f"Initial observation: {obs}")
    print(f"Initial info: {info}")

    for step in range(10):
        action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(action)

        print(f"\nStep {step + 1}:")
        print(f"  Action: {action}")
        print(f"  Observation: {obs}")
        print(f"  Reward: {reward:.3f}")
        print(f"  Terminated: {terminated}")
        print(f"  Info: {info}")

        env.render()

        if terminated or truncated:
            print("Episode finished!")
            break

    env.close()


def train_on_custom_env():
    """
    Train a PPO agent on the custom environment.
    """
    from stable_baselines3 import PPO

    print("Training PPO agent on custom environment...")

    # Create environment
    env = CustomEnv()

    # Validate first
    from stable_baselines3.common.env_checker import check_env
    check_env(env, warn=True)

    # Train agent
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)

    # Test trained agent
    obs, info = env.reset()
    for _ in range(20):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            print(f"Goal reached! Final reward: {reward}")
            break

    env.close()


if __name__ == "__main__":
    # Validate the environment
    validate_environment()

    # Test with random actions
    # test_environment()

    # Train an agent
    # train_on_custom_env()
245
skills/stable-baselines3/scripts/evaluate_agent.py
Normal file
@@ -0,0 +1,245 @@
"""
Template script for evaluating trained RL agents with Stable Baselines3.

This template demonstrates:
- Loading trained models
- Evaluating performance with statistics
- Recording videos of agent behavior
- Visualizing agent performance
"""

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder, VecNormalize
import os


def evaluate_agent(
    model_path,
    env_id="CartPole-v1",
    n_eval_episodes=10,
    deterministic=True,
    render=False,
    record_video=False,
    video_folder="./videos/",
    vec_normalize_path=None,
):
    """
    Evaluate a trained RL agent.

    Args:
        model_path: Path to the saved model
        env_id: Gymnasium environment ID
        n_eval_episodes: Number of episodes to evaluate
        deterministic: Use deterministic actions
        render: Render the environment during evaluation
        record_video: Record videos of the agent
        video_folder: Folder to save videos
        vec_normalize_path: Path to VecNormalize statistics (if used during training)

    Returns:
        mean_reward: Mean episode reward
        std_reward: Standard deviation of episode rewards
    """
    # Load the trained model
    print(f"Loading model from {model_path}...")
    model = PPO.load(model_path)

    # Create evaluation environment
    if render:
        env = gym.make(env_id, render_mode="human")
    elif record_video:
        # VecVideoRecorder needs rgb_array frames from the underlying env
        env = gym.make(env_id, render_mode="rgb_array")
    else:
        env = gym.make(env_id)

    # Wrap in DummyVecEnv for consistency
    env = DummyVecEnv([lambda: env])

    # Load VecNormalize statistics if they were used during training
    if vec_normalize_path and os.path.exists(vec_normalize_path):
        print(f"Loading VecNormalize statistics from {vec_normalize_path}...")
        env = VecNormalize.load(vec_normalize_path, env)
        env.training = False  # Don't update statistics during evaluation
        env.norm_reward = False  # Don't normalize rewards during evaluation

    # Set up video recording if requested
    if record_video:
        os.makedirs(video_folder, exist_ok=True)
        env = VecVideoRecorder(
            env,
            video_folder,
            record_video_trigger=lambda x: x == 0,  # Start recording at the first step
            video_length=1000,  # Max video length (in steps)
            name_prefix=f"eval-{env_id}",
        )
        print(f"Recording videos to {video_folder}...")

    # Evaluate the agent
    print(f"Evaluating for {n_eval_episodes} episodes...")
    mean_reward, std_reward = evaluate_policy(
        model,
        env,
        n_eval_episodes=n_eval_episodes,
        deterministic=deterministic,
        render=False,  # Rendering is driven by the env's render_mode set above
        return_episode_rewards=False,
    )
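
    # Variant sketch: evaluate_policy can instead return per-episode reward and
    # length lists (return_episode_rewards=True), which is useful for plotting
    # distributions rather than reporting a single mean/std pair.
    # episode_rewards, episode_lengths = evaluate_policy(
    #     model, env, n_eval_episodes=n_eval_episodes, return_episode_rewards=True
    # )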

    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    # Cleanup
    env.close()

    return mean_reward, std_reward


def watch_agent(
    model_path,
    env_id="CartPole-v1",
    n_episodes=5,
    deterministic=True,
    vec_normalize_path=None,
):
    """
    Watch a trained agent play (with rendering).

    Args:
        model_path: Path to the saved model
        env_id: Gymnasium environment ID
        n_episodes: Number of episodes to watch
        deterministic: Use deterministic actions
        vec_normalize_path: Path to VecNormalize statistics (if used during training)
    """
    # Load the trained model
    print(f"Loading model from {model_path}...")
    model = PPO.load(model_path)

    # Create environment with rendering
    env = gym.make(env_id, render_mode="human")

    # Load VecNormalize statistics if needed
    obs_normalization = None
    if vec_normalize_path and os.path.exists(vec_normalize_path):
        print(f"Loading VecNormalize statistics from {vec_normalize_path}...")
        # For rendering, we'll manually apply normalization
        dummy_env = DummyVecEnv([lambda: gym.make(env_id)])
        vec_env = VecNormalize.load(vec_normalize_path, dummy_env)
        obs_normalization = vec_env
        dummy_env.close()

    # Run episodes
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        step = 0

        print(f"\nEpisode {episode + 1}/{n_episodes}")

        while not done:
            # Apply observation normalization if needed
            if obs_normalization:
                obs_normalized = obs_normalization.normalize_obs(obs)
            else:
                obs_normalized = obs

            # Get action from model
            action, _states = model.predict(obs_normalized, deterministic=deterministic)

            # Take step in environment
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            episode_reward += reward
            step += 1

        print(f"Episode reward: {episode_reward:.2f} ({step} steps)")

    env.close()


def compare_models(
    model_paths,
    env_id="CartPole-v1",
    n_eval_episodes=10,
    deterministic=True,
):
    """
    Compare performance of multiple trained models.

    Args:
        model_paths: List of paths to saved models
        env_id: Gymnasium environment ID
        n_eval_episodes: Number of episodes to evaluate each model
        deterministic: Use deterministic actions
    """
    results = {}

    for model_path in model_paths:
        print(f"\nEvaluating {model_path}...")
        mean_reward, std_reward = evaluate_agent(
            model_path,
            env_id=env_id,
            n_eval_episodes=n_eval_episodes,
            deterministic=deterministic,
        )
        results[model_path] = {"mean": mean_reward, "std": std_reward}

    # Print comparison
    print("\n" + "=" * 60)
    print("Model Comparison Results")
    print("=" * 60)
    for model_path, stats in results.items():
        print(f"{model_path}: {stats['mean']:.2f} +/- {stats['std']:.2f}")
    print("=" * 60)

    return results


if __name__ == "__main__":
    # Example 1: Evaluate a trained model
    model_path = "./models/best_model/best_model.zip"
    evaluate_agent(
        model_path=model_path,
        env_id="CartPole-v1",
        n_eval_episodes=10,
        deterministic=True,
    )

    # Example 2: Record videos of agent behavior
    # evaluate_agent(
    #     model_path=model_path,
    #     env_id="CartPole-v1",
    #     n_eval_episodes=5,
    #     deterministic=True,
    #     record_video=True,
    #     video_folder="./videos/",
    # )

    # Example 3: Watch agent play with rendering
    # watch_agent(
    #     model_path=model_path,
    #     env_id="CartPole-v1",
    #     n_episodes=3,
    #     deterministic=True,
    # )

    # Example 4: Compare multiple models
    # compare_models(
    #     model_paths=[
    #         "./models/model_100k.zip",
    #         "./models/model_200k.zip",
    #         "./models/best_model/best_model.zip",
    #     ],
    #     env_id="CartPole-v1",
    #     n_eval_episodes=10,
    # )

    # Example 5: Evaluate with VecNormalize statistics
    # evaluate_agent(
    #     model_path="./models/best_model/best_model.zip",
    #     env_id="Pendulum-v1",
    #     n_eval_episodes=10,
    #     vec_normalize_path="./models/vec_normalize.pkl",
    # )
165
skills/stable-baselines3/scripts/train_rl_agent.py
Normal file
@@ -0,0 +1,165 @@
"""
Template script for training RL agents with Stable Baselines3.

This template demonstrates best practices for:
- Setting up training with proper monitoring
- Using callbacks for evaluation and checkpointing
- Vectorized environments for efficiency
- TensorBoard integration
- Model saving and loading
"""

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import (
    EvalCallback,
    CheckpointCallback,
    CallbackList,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
import os


def train_agent(
    env_id="CartPole-v1",
    algorithm=PPO,
    policy="MlpPolicy",
    n_envs=4,
    total_timesteps=100000,
    eval_freq=10000,
    save_freq=10000,
    log_dir="./logs/",
    save_path="./models/",
):
    """
    Train an RL agent with best practices.

    Args:
        env_id: Gymnasium environment ID
        algorithm: SB3 algorithm class (PPO, SAC, DQN, etc.)
        policy: Policy type ("MlpPolicy", "CnnPolicy", "MultiInputPolicy")
        n_envs: Number of parallel environments
        total_timesteps: Total training timesteps
        eval_freq: Frequency of evaluation (in timesteps)
        save_freq: Frequency of model checkpoints (in timesteps)
        log_dir: Directory for logs and TensorBoard
        save_path: Directory for model checkpoints
    """
    # Create directories
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)
    eval_log_dir = os.path.join(log_dir, "eval")
    os.makedirs(eval_log_dir, exist_ok=True)

    # Create training environment (vectorized for efficiency)
    print(f"Creating {n_envs} parallel training environments...")
    env = make_vec_env(
        env_id,
        n_envs=n_envs,
        vec_env_cls=SubprocVecEnv,  # Use SubprocVecEnv for parallel execution
        # vec_env_cls=DummyVecEnv,  # Use DummyVecEnv for lightweight environments
    )

    # Optional: Add normalization wrapper for better performance
    # Uncomment for continuous control tasks
    # env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    # Create separate evaluation environment
    print("Creating evaluation environment...")
    eval_env = make_vec_env(env_id, n_envs=1)
    # If using VecNormalize, wrap eval env but set training=False
    # eval_env = VecNormalize(eval_env, training=False, norm_reward=False)

    # Set up callbacks
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=os.path.join(save_path, "best_model"),
        log_path=eval_log_dir,
        eval_freq=eval_freq // n_envs,  # Adjust for number of environments
        n_eval_episodes=10,
        deterministic=True,
        render=False,
    )
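
    # Optional sketch (the 475.0 threshold is an assumption for CartPole-v1): stop
    # training early once evaluation reward crosses a target by attaching
    # StopTrainingOnRewardThreshold to the EvalCallback via callback_on_new_best.
    # from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold
    # stop_callback = StopTrainingOnRewardThreshold(reward_threshold=475.0, verbose=1)
    # eval_callback = EvalCallback(
    #     eval_env,
    #     callback_on_new_best=stop_callback,
    #     eval_freq=eval_freq // n_envs,
    # )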

    checkpoint_callback = CheckpointCallback(
        save_freq=save_freq // n_envs,  # Adjust for number of environments
        save_path=save_path,
        name_prefix="rl_model",
        save_replay_buffer=False,  # Set True for off-policy algorithms if needed
    )

    callback = CallbackList([eval_callback, checkpoint_callback])

    # Initialize the agent
    print(f"Initializing {algorithm.__name__} agent...")
    model = algorithm(
        policy,
        env,
        verbose=1,
        tensorboard_log=log_dir,
        # Algorithm-specific hyperparameters can be added here
        # learning_rate=3e-4,
        # n_steps=2048,  # For PPO/A2C
        # batch_size=64,
        # gamma=0.99,
    )

    # Train the agent
    print(f"Training for {total_timesteps} timesteps...")
    model.learn(
        total_timesteps=total_timesteps,
        callback=callback,
        tb_log_name=f"{algorithm.__name__}_{env_id}",
    )

    # Save final model
    final_model_path = os.path.join(save_path, "final_model")
    print(f"Saving final model to {final_model_path}...")
    model.save(final_model_path)

    # Save VecNormalize statistics if used
    # env.save(os.path.join(save_path, "vec_normalize.pkl"))

    print("Training complete!")
    print(f"Best model saved at: {os.path.join(save_path, 'best_model')}")
    print(f"Final model saved at: {final_model_path}")
    print(f"TensorBoard logs: {log_dir}")
    print(f"Run 'tensorboard --logdir {log_dir}' to view training progress")

    # Cleanup
    env.close()
    eval_env.close()

    return model


if __name__ == "__main__":
    # Example: Train PPO on CartPole
    train_agent(
        env_id="CartPole-v1",
        algorithm=PPO,
        policy="MlpPolicy",
        n_envs=4,
        total_timesteps=100000,
    )

    # Example: Train SAC on continuous control task
    # from stable_baselines3 import SAC
    # train_agent(
    #     env_id="Pendulum-v1",
    #     algorithm=SAC,
    #     policy="MlpPolicy",
    #     n_envs=4,
    #     total_timesteps=50000,
    # )

    # Example: Train DQN on discrete task
    # from stable_baselines3 import DQN
    # train_agent(
    #     env_id="LunarLander-v2",
    #     algorithm=DQN,
    #     policy="MlpPolicy",
    #     n_envs=1,  # DQN typically uses single env
    #     total_timesteps=100000,
    # )