# PufferLib Policies Guide

## Overview

PufferLib policies are standard PyTorch modules with optional utilities for observation processing and LSTM integration. The framework provides default architectures and tools while allowing full flexibility in policy design.

## Policy Architecture

### Basic Policy Structure

```python
import torch
import torch.nn as nn

from pufferlib.pytorch import layer_init


class BasicPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()
        self.observation_space = observation_space
        self.action_space = action_space

        # Encoder network
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.ReLU()
        )

        # Policy head (actor)
        self.actor = layer_init(nn.Linear(256, action_space.n), std=0.01)

        # Value head (critic)
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        """Forward pass through the policy."""
        # Encode observations
        features = self.encoder(observations)

        # Get action logits and value
        logits = self.actor(features)
        value = self.critic(features)

        return logits, value

    def get_action(self, observations, deterministic=False):
        """Sample an action from the policy."""
        logits, value = self.forward(observations)

        if deterministic:
            action = logits.argmax(dim=-1)
        else:
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()

        return action, value
```

### Layer Initialization

PufferLib provides `layer_init` for proper weight initialization:

```python
from pufferlib.pytorch import layer_init

# Default orthogonal initialization
layer = layer_init(nn.Linear(256, 256))

# Custom standard deviation
actor_head = layer_init(nn.Linear(256, num_actions), std=0.01)
critic_head = layer_init(nn.Linear(256, 1), std=1.0)

# Works with any layer type
conv = layer_init(nn.Conv2d(3, 32, kernel_size=8, stride=4))
```

## CNN Policies

For image-based observations:

```python
class CNNPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # CNN encoder for images
        self.encoder = nn.Sequential(
            layer_init(nn.Conv2d(3, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            # 64 * 7 * 7 assumes 84x84 input images
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, observations):
        # Normalize pixel values
        x = observations.float() / 255.0

        features = self.encoder(x)
        logits = self.actor(features)
        value = self.critic(features)

        return logits, value
```

### Efficient CNN Architecture

```python
class EfficientCNN(nn.Module):
    """Optimized CNN for Atari-style games."""

    def __init__(self, observation_space, action_space):
        super().__init__()

        in_channels = observation_space.shape[0]  # Typically 4 for framestack

        self.network = nn.Sequential(
            layer_init(nn.Conv2d(in_channels, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten()
        )

        # Calculate the flattened feature size with a dummy forward pass
        with torch.no_grad():
            sample = torch.zeros(1, *observation_space.shape)
            n_features = self.network(sample).shape[1]

        self.fc = layer_init(nn.Linear(n_features, 512))
        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, x):
        x = x.float() / 255.0
        x = self.network(x)
        x = torch.relu(self.fc(x))
        return self.actor(x), self.critic(x)
```
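As a quick sanity check, you can instantiate the CNN policy against concrete spaces and confirm that the dummy forward pass picked up the right feature size. This is a minimal sketch assuming `gymnasium`-style spaces and an Atari-style 4x84x84 framestack; the space shapes and action count are illustrative, not part of the PufferLib API.

```python
import gymnasium as gym
import numpy as np
import torch

# Hypothetical Atari-style spaces: 4 stacked 84x84 frames, 6 discrete actions
obs_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
act_space = gym.spaces.Discrete(6)

policy = EfficientCNN(obs_space, act_space)

# Batch of 8 random uint8 observations
obs = torch.randint(0, 256, (8, *obs_space.shape), dtype=torch.uint8)
logits, value = policy(obs)

assert logits.shape == (8, act_space.n)  # (batch, num_actions)
assert value.shape == (8, 1)             # (batch, 1)
```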
## Recurrent Policies (LSTM)

PufferLib provides optimized LSTM integration with automatic recurrence handling:

```python
from pufferlib.pytorch import LSTMWrapper


class RecurrentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        # Observation encoder
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 128)),
            nn.ReLU()
        )

        # LSTM layer
        self.lstm = nn.LSTM(128, hidden_size, num_layers=1)

        # Policy and value heads
        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

        # Hidden state size
        self.hidden_size = hidden_size

    def forward(self, observations, state=None):
        """
        Args:
            observations: (batch, obs_dim)
            state: Optional (h, c) tuple for the LSTM

        Returns:
            logits, value, new_state
        """
        batch_size = observations.shape[0]

        # Encode observations
        features = self.encoder(observations)

        # Initialize hidden state if needed
        if state is None:
            h = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            c = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            state = (h, c)

        # LSTM forward
        features = features.unsqueeze(0)  # Add sequence dimension
        lstm_out, new_state = self.lstm(features, state)
        lstm_out = lstm_out.squeeze(0)

        # Get outputs
        logits = self.actor(lstm_out)
        value = self.critic(lstm_out)

        return logits, value, new_state
```

### LSTM Optimization

PufferLib's LSTM optimization uses `LSTMCell` during rollouts and `LSTM` during training, for up to 3x faster inference:

```python
class OptimizedLSTMPolicy(nn.Module):
    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 128)),
            nn.ReLU()
        )

        # Use LSTMCell for step-by-step inference
        self.lstm_cell = nn.LSTMCell(128, hidden_size)

        # Use LSTM for batch training
        # Note: the cell and the LSTM must share weights so rollouts and
        # training use the same parameters (see the weight-tying sketch below)
        self.lstm = nn.LSTM(128, hidden_size, num_layers=1)

        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

        self.hidden_size = hidden_size

    def encode_observations(self, observations, state):
        """Fast inference using LSTMCell."""
        features = self.encoder(observations)

        if state is None:
            h = torch.zeros(observations.shape[0], self.hidden_size, device=features.device)
            c = torch.zeros(observations.shape[0], self.hidden_size, device=features.device)
        else:
            h, c = state

        # Step-by-step with LSTMCell (faster for inference)
        h, c = self.lstm_cell(features, (h, c))

        logits = self.actor(h)
        value = self.critic(h)

        return logits, value, (h, c)

    def decode_actions(self, observations, actions, state):
        """Batch training using LSTM."""
        seq_len, batch_size = observations.shape[:2]

        # Reshape for the LSTM
        obs_flat = observations.reshape(seq_len * batch_size, -1)
        features = self.encoder(obs_flat)
        features = features.reshape(seq_len, batch_size, -1)

        if state is None:
            h = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            c = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            state = (h, c)

        # Batch processing with LSTM (faster for training)
        lstm_out, new_state = self.lstm(features, state)

        # Flatten back
        lstm_out = lstm_out.reshape(seq_len * batch_size, -1)

        logits = self.actor(lstm_out)
        value = self.critic(lstm_out)

        return logits, value, new_state
```
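One detail the sketch above glosses over: the `LSTMCell` used for rollouts and the `LSTM` used for training must compute with the same parameters, otherwise data is collected under a different policy than the one being optimized. One illustrative way to do this (not necessarily how PufferLib's `LSTMWrapper` implements it) is to tie the parameters directly, since a single-layer `nn.LSTM` and an `nn.LSTMCell` use the same weight shapes and gate ordering:

```python
import torch.nn as nn


def tie_lstm_weights(policy: nn.Module) -> nn.Module:
    """Make the rollout LSTMCell reuse the training LSTM's parameters."""
    # For a single-layer nn.LSTM, weight_ih_l0 / weight_hh_l0 / bias_ih_l0 / bias_hh_l0
    # match the shapes and gate order of nn.LSTMCell's weight_ih / weight_hh / biases.
    policy.lstm_cell.weight_ih = policy.lstm.weight_ih_l0
    policy.lstm_cell.weight_hh = policy.lstm.weight_hh_l0
    policy.lstm_cell.bias_ih = policy.lstm.bias_ih_l0
    policy.lstm_cell.bias_hh = policy.lstm.bias_hh_l0
    return policy


# After construction (spaces as defined by your environment):
# policy = tie_lstm_weights(OptimizedLSTMPolicy(obs_space, act_space, hidden_size=256))
```

Because both modules then reference the same `nn.Parameter` objects, gradients from the batched training pass update the weights used during the next rollout.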
## Multi-Input Policies

For environments with multiple observation types:

```python
class MultiInputPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # Separate encoders for different observation types
        self.image_encoder = nn.Sequential(
            layer_init(nn.Conv2d(3, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            nn.Flatten()
        )

        self.vector_encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space['vector'].shape[0], 128)),
            nn.ReLU()
        )

        # Combine features: image features (64 * 9 * 9 assumes 84x84 inputs) + vector features
        combined_size = 64 * 9 * 9 + 128
        self.combiner = nn.Sequential(
            layer_init(nn.Linear(combined_size, 512)),
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, observations):
        # Process each observation type
        image_features = self.image_encoder(observations['image'].float() / 255.0)
        vector_features = self.vector_encoder(observations['vector'])

        # Combine
        combined = torch.cat([image_features, vector_features], dim=-1)
        features = self.combiner(combined)

        return self.actor(features), self.critic(features)
```

## Continuous Action Policies

For continuous control tasks:

```python
class ContinuousPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.ReLU()
        )

        # Mean of the action distribution
        self.actor_mean = layer_init(nn.Linear(256, action_space.shape[0]), std=0.01)

        # Log std of the action distribution
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_space.shape[0]))

        # Value head
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        features = self.encoder(observations)

        action_mean = self.actor_mean(features)
        action_std = torch.exp(self.actor_logstd)
        value = self.critic(features)

        return action_mean, action_std, value

    def get_action(self, observations, deterministic=False):
        action_mean, action_std, value = self.forward(observations)

        if deterministic:
            action = action_mean
        else:
            dist = torch.distributions.Normal(action_mean, action_std)
            action = dist.sample()

        return torch.tanh(action), value  # Bound actions to [-1, 1]
```

## Observation Processing

PufferLib provides utilities for unflattening observations:

```python
from pufferlib.pytorch import unflatten_observations


class PolicyWithUnflatten(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()
        self.observation_space = observation_space

        # Define encoders for each observation component
        self.encoders = nn.ModuleDict({
            'image': self._make_image_encoder(),
            'vector': self._make_vector_encoder()
        })

        # ... rest of policy ...

    def forward(self, flat_observations):
        # Unflatten observations into a structured format
        observations = unflatten_observations(
            flat_observations, self.observation_space
        )

        # Process each component
        image_features = self.encoders['image'](observations['image'])
        vector_features = self.encoders['vector'](observations['vector'])

        # Combine and continue...
```
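To make the idea concrete, a hand-rolled equivalent for a two-component observation amounts to slicing the flat tensor by each component's size and reshaping. The component names, shapes, and flat layout below are illustrative assumptions, not the actual PufferLib flattening scheme:

```python
import torch


def unflatten_two_part_obs(flat_obs, image_shape=(3, 84, 84), vector_dim=16):
    """Split a flat (batch, N) tensor into 'image' and 'vector' components.

    Illustrative sketch: assumes the flat layout is [image..., vector...] and that
    the shapes match how observations were flattened on the environment side.
    """
    batch_size = flat_obs.shape[0]
    image_size = image_shape[0] * image_shape[1] * image_shape[2]

    image = flat_obs[:, :image_size].reshape(batch_size, *image_shape)
    vector = flat_obs[:, image_size:image_size + vector_dim]

    return {'image': image, 'vector': vector}
```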
## Multi-Agent Policies

### Shared Parameters

All agents use the same policy:

```python
class SharedMultiAgentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_agents):
        super().__init__()
        self.num_agents = num_agents

        # Single policy shared across all agents
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(256, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        """
        Args:
            observations: (batch * num_agents, obs_dim)

        Returns:
            logits: (batch * num_agents, num_actions)
            values: (batch * num_agents, 1)
        """
        features = self.encoder(observations)
        return self.actor(features), self.critic(features)
```

### Independent Parameters

Each agent has its own policy:

```python
class IndependentMultiAgentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_agents):
        super().__init__()
        self.num_agents = num_agents
        self.num_actions = action_space.n

        # Separate policy (encoder + heads) for each agent
        self.policies = nn.ModuleList([
            self._make_policy(observation_space, action_space)
            for _ in range(num_agents)
        ])

    def _make_policy(self, observation_space, action_space):
        # Reuse the BasicPolicy defined earlier; each agent gets its own instance
        return BasicPolicy(observation_space, action_space)

    def forward(self, observations, agent_ids):
        """
        Args:
            observations: (batch, obs_dim)
            agent_ids: (batch,) which agent each obs belongs to
        """
        batch_size = observations.shape[0]
        logits = observations.new_zeros(batch_size, self.num_actions)
        values = observations.new_zeros(batch_size, 1)

        # Route each observation through its agent's policy,
        # writing results back in the original batch order
        for agent_id in range(self.num_agents):
            mask = agent_ids == agent_id
            if mask.any():
                agent_logits, agent_values = self.policies[agent_id](observations[mask])
                logits[mask] = agent_logits
                values[mask] = agent_values

        return logits, values
```

## Advanced Architectures

### Attention-Based Policy

```python
class AttentionPolicy(nn.Module):
    def __init__(self, observation_space, action_space, d_model=256, nhead=8):
        super().__init__()

        self.encoder = layer_init(nn.Linear(observation_space.shape[0], d_model))
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)

        self.actor = layer_init(nn.Linear(d_model, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(d_model, 1), std=1.0)

    def forward(self, observations):
        # Encode
        features = self.encoder(observations)

        # Self-attention
        features = features.unsqueeze(1)  # Add sequence dimension
        attn_out, _ = self.attention(features, features, features)
        attn_out = attn_out.squeeze(1)

        return self.actor(attn_out), self.critic(attn_out)
```

### Residual Policy

```python
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            layer_init(nn.Linear(dim, dim)),
            nn.ReLU(),
            layer_init(nn.Linear(dim, dim))
        )

    def forward(self, x):
        return x + self.block(x)


class ResidualPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_blocks=4):
        super().__init__()
        dim = 256

        self.encoder = layer_init(nn.Linear(observation_space.shape[0], dim))
        self.blocks = nn.Sequential(
            *[ResidualBlock(dim) for _ in range(num_blocks)]
        )

        self.actor = layer_init(nn.Linear(dim, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(dim, 1), std=1.0)

    def forward(self, observations):
        x = torch.relu(self.encoder(observations))
        x = self.blocks(x)
        return self.actor(x), self.critic(x)
```

## Policy Best Practices

### Initialization

```python
# Always use layer_init for proper initialization
good_layer = layer_init(nn.Linear(256, 256))

# Use a small std for the actor head (more stable early training)
actor = layer_init(nn.Linear(256, num_actions), std=0.01)

# Use std=1.0 for the critic head
critic = layer_init(nn.Linear(256, 1), std=1.0)
```
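The `std` arguments above follow the common PPO convention: orthogonal weight initialization scaled by `std`, with biases set to a constant. If you are not using PufferLib's helper, a minimal re-implementation of that pattern looks like the sketch below; this is an assumption about the convention, so check `pufferlib.pytorch.layer_init` for the actual signature and defaults.

```python
import numpy as np
import torch
import torch.nn as nn


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonal weight init scaled by std, constant bias (CleanRL-style convention)."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
```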
### Observation Normalization

```python
class NormalizedPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # Normalization statistics stored as buffers (saved with the state_dict, not trained)
        self.register_buffer('obs_mean', torch.zeros(observation_space.shape[0]))
        self.register_buffer('obs_std', torch.ones(observation_space.shape[0]))

        # ... rest of policy ...

    def forward(self, observations):
        # Normalize observations
        normalized_obs = (observations - self.obs_mean) / (self.obs_std + 1e-8)

        # Continue with normalized observations
        return self.policy(normalized_obs)

    def update_normalization(self, observations):
        """Update normalization statistics from a batch of observations."""
        self.obs_mean.copy_(observations.mean(dim=0))
        self.obs_std.copy_(observations.std(dim=0))
```

### Gradient Clipping

```python
# The PufferLib trainer handles gradient clipping automatically
trainer = PuffeRL(
    env=env,
    policy=policy,
    max_grad_norm=0.5  # Clip gradients to this norm
)
```

### Model Compilation

```python
# Enable torch.compile for faster training (PyTorch 2.0+)
policy = MyPolicy(observation_space, action_space)

# Option 1: compile the model yourself
policy = torch.compile(policy, mode='reduce-overhead')

# Option 2: let the trainer compile it for you
trainer = PuffeRL(env=env, policy=policy, compile=True)
```

## Debugging Policies

### Check Output Shapes

```python
def test_policy_shapes(policy, observation_space, batch_size=32):
    """Verify policy output shapes."""
    # Create dummy observations
    obs = torch.randn(batch_size, *observation_space.shape)

    # Forward pass
    logits, value = policy(obs)

    # Check shapes
    assert logits.shape == (batch_size, policy.action_space.n)
    assert value.shape == (batch_size, 1)

    print("✓ Policy shapes correct")
```

### Verify Gradients

```python
def check_gradients(policy, observation_space):
    """Check that gradients flow properly."""
    obs = torch.randn(1, *observation_space.shape, requires_grad=True)
    logits, value = policy(obs)

    # Backward pass
    loss = logits.sum() + value.sum()
    loss.backward()

    # Check that gradients exist and are finite
    for name, param in policy.named_parameters():
        if param.grad is None:
            print(f"⚠ No gradient for {name}")
        elif torch.isnan(param.grad).any():
            print(f"⚠ NaN gradient for {name}")
        else:
            print(f"✓ Gradient OK for {name}")
```
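Both helpers make a quick smoke test before launching training. A minimal sketch, assuming `gymnasium`-style spaces and the `BasicPolicy` from the start of this guide (the space shapes are illustrative):

```python
import gymnasium as gym

# Hypothetical spaces: 8-dimensional vector observations, 4 discrete actions
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,))
act_space = gym.spaces.Discrete(4)

policy = BasicPolicy(obs_space, act_space)

test_policy_shapes(policy, obs_space)  # shape assertions
check_gradients(policy, obs_space)     # per-parameter gradient report
```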