#!/usr/bin/env python3
"""
PufferLib Training Template

This template provides a complete training script for reinforcement learning
with PufferLib. Customize the environment, policy, and training configuration
as needed for your use case.
"""

import argparse
import os

import torch
import torch.nn as nn

import pufferlib
from pufferlib import PuffeRL
from pufferlib.pytorch import layer_init


class Policy(nn.Module):
    """Example policy network for flat observations and a discrete action space."""

    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space

        # Encoder network: two fully connected layers over flat observations
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], hidden_size)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
        )

        # Policy head (actor)
        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)

        # Value head (critic)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

    def forward(self, observations):
        """Forward pass through the policy."""
        features = self.encoder(observations)
        logits = self.actor(features)
        value = self.critic(features)
        return logits, value
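

# Illustrative helper, not part of the training path above and not part of
# PufferLib's API: it shows how the (logits, value) pair returned by Policy maps
# to sampled actions and log-probs with plain PyTorch. PuffeRL is assumed to do
# the equivalent internally during rollouts and PPO updates.
def sample_actions(policy, observations):
    """Sample discrete actions from the policy's logits (illustration only)."""
    logits, value = policy(observations)
    dist = torch.distributions.Categorical(logits=logits)
    actions = dist.sample()
    # log_prob feeds the PPO probability ratio; value feeds the critic loss
    return actions, dist.log_prob(actions), value.squeeze(-1)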


def make_env():
    """Create environment. Customize this for your task.

    Note: train() below builds its environment directly from the CLI flags via
    pufferlib.make; use this helper instead if you prefer a hard-coded setup.
    """
    # Option 1: Use Ocean environment
    return pufferlib.make('procgen-coinrun', num_envs=256)

    # Option 2: Use Gymnasium environment
    # return pufferlib.make('gym-CartPole-v1', num_envs=256)

    # Option 3: Use custom environment (see the illustrative sketch below)
    # from my_envs import MyEnvironment
    # return pufferlib.emulate(MyEnvironment, num_envs=256)
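
# Illustrative sketch of a custom environment for Option 3, kept as comments so
# the template gains no extra dependencies. Everything here is hypothetical:
# MyEnvironment, its spaces, and its dynamics are placeholders, and the
# pufferlib.emulate call is taken from the commented option above rather than
# verified here. A standard Gymnasium-style interface is assumed:
#
#   import gymnasium
#   import numpy as np
#
#   class MyEnvironment(gymnasium.Env):
#       def __init__(self):
#           self.observation_space = gymnasium.spaces.Box(-1, 1, (4,), dtype=np.float32)
#           self.action_space = gymnasium.spaces.Discrete(2)
#
#       def reset(self, seed=None, options=None):
#           return np.zeros(4, dtype=np.float32), {}
#
#       def step(self, action):
#           obs = np.zeros(4, dtype=np.float32)
#           reward, terminated, truncated, info = 0.0, False, False, {}
#           return obs, reward, terminated, truncated, info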


def create_policy(env):
    """Create policy network."""
    return Policy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        hidden_size=256
    )

def train(args):
    """Main training function."""
    # Set random seeds
    torch.manual_seed(args.seed)

    # Create environment
    print(f"Creating environment with {args.num_envs} parallel environments...")
    env = pufferlib.make(
        args.env_name,
        num_envs=args.num_envs,
        num_workers=args.num_workers
    )

    # Create policy
    print("Initializing policy...")
    policy = create_policy(env)

    # Fall back to CPU if CUDA was requested but is unavailable, so the policy
    # and the trainer always end up on the same device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU")
        args.device = 'cpu'

    if args.device == 'cuda':
        policy = policy.cuda()
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Using CPU")

    # Create logger
    if args.logger == 'wandb':
        from pufferlib import WandbLogger
        logger = WandbLogger(
            project=args.project,
            name=args.exp_name,
            config=vars(args)
        )
    elif args.logger == 'neptune':
        from pufferlib import NeptuneLogger
        logger = NeptuneLogger(
            project=args.project,
            name=args.exp_name,
            api_token=args.neptune_token
        )
    else:
        from pufferlib import NoLogger
        logger = NoLogger()

    # Create trainer
    print("Creating trainer...")
    trainer = PuffeRL(
        env=env,
        policy=policy,
        device=args.device,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_coef=args.clip_coef,
        ent_coef=args.ent_coef,
        vf_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        logger=logger,
        compile=args.compile
    )

    # Training loop
    print(f"Starting training for {args.num_iterations} iterations...")
    for iteration in range(1, args.num_iterations + 1):
        # Collect rollouts
        rollout_data = trainer.evaluate()

        # Train on batch
        train_metrics = trainer.train()

        # Log results
        trainer.mean_and_log()

        # Save checkpoint
        if iteration % args.save_freq == 0:
            checkpoint_path = f"{args.checkpoint_dir}/checkpoint_{iteration}.pt"
            trainer.save_checkpoint(checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

        # Print progress
        if iteration % args.log_freq == 0:
            mean_reward = rollout_data.get('mean_reward', 0)
            sps = rollout_data.get('sps', 0)
            print(f"Iteration {iteration}/{args.num_iterations} | "
                  f"Mean Reward: {mean_reward:.2f} | "
                  f"SPS: {sps:,.0f}")

    print("Training complete!")

    # Save final model
    final_path = f"{args.checkpoint_dir}/final_model.pt"
    trainer.save_checkpoint(final_path)
    print(f"Saved final model to {final_path}")


def main():
    parser = argparse.ArgumentParser(description='PufferLib Training')

    # Environment
    parser.add_argument('--env-name', type=str, default='procgen-coinrun',
                        help='Environment name')
    parser.add_argument('--num-envs', type=int, default=256,
                        help='Number of parallel environments')
    parser.add_argument('--num-workers', type=int, default=8,
                        help='Number of vectorization workers')

    # Training
    parser.add_argument('--num-iterations', type=int, default=10000,
                        help='Number of training iterations')
    parser.add_argument('--learning-rate', type=float, default=3e-4,
                        help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=32768,
                        help='Batch size for training')
    parser.add_argument('--n-epochs', type=int, default=4,
                        help='Number of training epochs per batch')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'], help='Device to use')

    # PPO parameters
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='Discount factor')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='GAE lambda')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                        help='PPO clipping coefficient')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help='Entropy coefficient')
    parser.add_argument('--vf-coef', type=float, default=0.5,
                        help='Value function coefficient')
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='Maximum gradient norm')

    # Logging
    parser.add_argument('--logger', type=str, default='none',
                        choices=['wandb', 'neptune', 'none'],
                        help='Logger to use')
    parser.add_argument('--project', type=str, default='pufferlib-training',
                        help='Project name for logging')
    parser.add_argument('--exp-name', type=str, default='experiment',
                        help='Experiment name')
    parser.add_argument('--neptune-token', type=str, default=None,
                        help='Neptune API token')
    parser.add_argument('--log-freq', type=int, default=10,
                        help='Logging frequency (iterations)')

    # Checkpointing
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--save-freq', type=int, default=100,
                        help='Checkpoint save frequency (iterations)')

    # Misc
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--compile', action='store_true',
                        help='Use torch.compile for faster training')

    args = parser.parse_args()

    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Run training
    train(args)


if __name__ == '__main__':
    main()