PufferLib Policies Guide

Overview

PufferLib policies are standard PyTorch modules with optional utilities for observation processing and LSTM integration. The framework provides default architectures and tools while allowing full flexibility in policy design.

Policy Architecture

Basic Policy Structure

import torch
import torch.nn as nn
from pufferlib.pytorch import layer_init

class BasicPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space

        # Encoder network
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.ReLU()
        )

        # Policy head (actor)
        self.actor = layer_init(nn.Linear(256, action_space.n), std=0.01)

        # Value head (critic)
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        """Forward pass through policy."""
        # Encode observations
        features = self.encoder(observations)

        # Get action logits and value
        logits = self.actor(features)
        value = self.critic(features)

        return logits, value

    def get_action(self, observations, deterministic=False):
        """Sample action from policy."""
        logits, value = self.forward(observations)

        if deterministic:
            action = logits.argmax(dim=-1)
        else:
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()

        return action, value
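
A quick usage sketch (assuming gymnasium spaces; the shapes here are arbitrary examples):

import gymnasium as gym

# Hypothetical spaces for illustration
obs_space = gym.spaces.Box(low=-1, high=1, shape=(8,))
act_space = gym.spaces.Discrete(4)

policy = BasicPolicy(obs_space, act_space)
obs = torch.randn(32, 8)  # Batch of 32 observations

action, value = policy.get_action(obs)
print(action.shape, value.shape)  # torch.Size([32]) torch.Size([32, 1])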

Layer Initialization

PufferLib provides layer_init for proper weight initialization:

from pufferlib.pytorch import layer_init

# Default orthogonal initialization
layer = layer_init(nn.Linear(256, 256))

# Custom standard deviation
actor_head = layer_init(nn.Linear(256, num_actions), std=0.01)
critic_head = layer_init(nn.Linear(256, 1), std=1.0)

# Works with any layer type
conv = layer_init(nn.Conv2d(3, 32, kernel_size=8, stride=4))
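
If you want to know what the helper does, it is roughly the CleanRL-style initializer sketched below; this is an approximation, so check pufferlib.pytorch for the exact defaults:

def layer_init_sketch(layer, std=2**0.5, bias_const=0.0):
    """Orthogonal weights scaled by std, constant bias."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer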

CNN Policies

For image-based observations:

class CNNPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # CNN encoder for images
        self.encoder = nn.Sequential(
            layer_init(nn.Conv2d(3, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),  # 7x7 feature map assumes 84x84 input
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, observations):
        # Normalize pixel values
        x = observations.float() / 255.0

        features = self.encoder(x)
        logits = self.actor(features)
        value = self.critic(features)

        return logits, value

Efficient CNN Architecture

class EfficientCNN(nn.Module):
    """Optimized CNN for Atari-style games."""

    def __init__(self, observation_space, action_space):
        super().__init__()

        in_channels = observation_space.shape[0]  # Typically 4 for framestack

        self.network = nn.Sequential(
            layer_init(nn.Conv2d(in_channels, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten()
        )

        # Calculate feature size
        with torch.no_grad():
            sample = torch.zeros(1, *observation_space.shape)
            n_features = self.network(sample).shape[1]

        self.fc = layer_init(nn.Linear(n_features, 512))
        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, x):
        x = x.float() / 255.0
        x = self.network(x)
        x = torch.relu(self.fc(x))

        return self.actor(x), self.critic(x)
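
Because the flattened size is measured with a dummy forward pass, the same class works at any input resolution. For example, assuming a channels-first gymnasium Box space:

import gymnasium as gym

obs_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84))
act_space = gym.spaces.Discrete(6)

policy = EfficientCNN(obs_space, act_space)
obs = torch.randint(0, 256, (8, 4, 84, 84), dtype=torch.uint8)
logits, value = policy(obs)  # logits: (8, 6), value: (8, 1)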

Recurrent Policies (LSTM)

PufferLib provides optimized LSTM integration with automatic recurrence handling. A recurrent policy written directly in PyTorch looks like this:

class RecurrentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        # Observation encoder
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 128)),
            nn.ReLU()
        )

        # LSTM layer
        self.lstm = nn.LSTM(128, hidden_size, num_layers=1)

        # Policy and value heads
        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

        # Hidden state
        self.hidden_size = hidden_size

    def forward(self, observations, state=None):
        """
        Args:
            observations: (batch, obs_dim)
            state: Optional (h, c) tuple for LSTM

        Returns:
            logits, value, new_state
        """
        batch_size = observations.shape[0]

        # Encode observations
        features = self.encoder(observations)

        # Initialize hidden state if needed
        if state is None:
            h = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            c = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            state = (h, c)

        # LSTM forward
        features = features.unsqueeze(0)  # Add sequence dimension
        lstm_out, new_state = self.lstm(features, state)
        lstm_out = lstm_out.squeeze(0)

        # Get outputs
        logits = self.actor(lstm_out)
        value = self.critic(lstm_out)

        return logits, value, new_state
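
During rollouts the state must be carried across steps and reset when an episode ends. A minimal sketch; env, rollout_length, and the tensor-valued done flag are assumptions about your vectorized environment API:

state = None
obs = env.reset()

for step in range(rollout_length):
    logits, value, state = policy(obs, state)
    action = torch.distributions.Categorical(logits=logits).sample()
    obs, reward, done, info = env.step(action)

    # Zero the hidden state of environments that just finished an episode
    if done.any():
        keep = (~done).float().view(1, -1, 1)  # Broadcast over (layers, batch, hidden)
        state = (state[0] * keep, state[1] * keep)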

LSTM Optimization

PufferLib's LSTM optimization uses an LSTMCell during rollouts and a full LSTM during training, for up to 3x faster inference. Since nn.LSTMCell and nn.LSTM keep separate parameters, the two modules must be explicitly tied so both paths use the same weights:

class OptimizedLSTMPolicy(nn.Module):
    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 128)),
            nn.ReLU()
        )

        # Use LSTMCell for step-by-step inference
        self.lstm_cell = nn.LSTMCell(128, hidden_size)

        # Use LSTM for batch training
        self.lstm = nn.LSTM(128, hidden_size, num_layers=1)

        # Tie the cell's parameters to the LSTM's so inference and training
        # share the same weights (otherwise the two paths diverge)
        self.lstm_cell.weight_ih = self.lstm.weight_ih_l0
        self.lstm_cell.weight_hh = self.lstm.weight_hh_l0
        self.lstm_cell.bias_ih = self.lstm.bias_ih_l0
        self.lstm_cell.bias_hh = self.lstm.bias_hh_l0

        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

        self.hidden_size = hidden_size

    def encode_observations(self, observations, state):
        """Fast inference using LSTMCell."""
        features = self.encoder(observations)

        if state is None:
            h = torch.zeros(observations.shape[0], self.hidden_size, device=features.device)
            c = torch.zeros(observations.shape[0], self.hidden_size, device=features.device)
        else:
            h, c = state

        # Step-by-step with LSTMCell (faster for inference)
        h, c = self.lstm_cell(features, (h, c))

        logits = self.actor(h)
        value = self.critic(h)

        return logits, value, (h, c)

    def decode_actions(self, observations, actions, state):
        """Batch training using LSTM."""
        seq_len, batch_size = observations.shape[:2]

        # Reshape for LSTM
        obs_flat = observations.reshape(seq_len * batch_size, -1)
        features = self.encoder(obs_flat)
        features = features.reshape(seq_len, batch_size, -1)

        if state is None:
            h = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            c = torch.zeros(1, batch_size, self.hidden_size, device=features.device)
            state = (h, c)

        # Batch processing with LSTM (faster for training)
        lstm_out, new_state = self.lstm(features, state)

        # Flatten back
        lstm_out = lstm_out.reshape(seq_len * batch_size, -1)

        logits = self.actor(lstm_out)
        value = self.critic(lstm_out)

        return logits, value, new_state

Multi-Input Policies

For environments with multiple observation types:

class MultiInputPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # Separate encoders for different observation types
        self.image_encoder = nn.Sequential(
            layer_init(nn.Conv2d(3, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            nn.Flatten()
        )

        self.vector_encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space['vector'].shape[0], 128)),
            nn.ReLU()
        )

        # Combine features
        combined_size = 64 * 9 * 9 + 128  # 9x9 image map (84x84 RGB input) + 128 vector features
        self.combiner = nn.Sequential(
            layer_init(nn.Linear(combined_size, 512)),
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(512, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, observations):
        # Process each observation type
        image_features = self.image_encoder(observations['image'].float() / 255.0)
        vector_features = self.vector_encoder(observations['vector'])

        # Combine
        combined = torch.cat([image_features, vector_features], dim=-1)
        features = self.combiner(combined)

        return self.actor(features), self.critic(features)
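
A usage sketch with dummy tensors, where vector_dim stands in for observation_space['vector'].shape[0]:

obs = {
    'image': torch.randint(0, 256, (16, 3, 84, 84), dtype=torch.uint8),
    'vector': torch.randn(16, vector_dim),
}
logits, value = policy(obs)  # logits: (16, num_actions), value: (16, 1)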

Continuous Action Policies

For continuous control tasks:

class ContinuousPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.ReLU()
        )

        # Mean of action distribution
        self.actor_mean = layer_init(nn.Linear(256, action_space.shape[0]), std=0.01)

        # Log std of action distribution
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_space.shape[0]))

        # Value head
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        features = self.encoder(observations)

        action_mean = self.actor_mean(features)
        action_std = torch.exp(self.actor_logstd)

        value = self.critic(features)

        return action_mean, action_std, value

    def get_action(self, observations, deterministic=False):
        action_mean, action_std, value = self.forward(observations)

        if deterministic:
            action = action_mean
        else:
            dist = torch.distributions.Normal(action_mean, action_std)
            action = dist.sample()

        return torch.tanh(action), value  # Squash actions to [-1, 1]
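
For the training update you also need log-probabilities and entropy of stored actions. A minimal sketch of a helper method (hypothetical, not part of the class above); note that tanh squashing strictly requires a change-of-variables correction to the log-prob, omitted here for brevity:

    def evaluate_actions(self, observations, actions):
        """Log-prob, entropy, and value for given (pre-tanh) actions."""
        action_mean, action_std, value = self.forward(observations)
        dist = torch.distributions.Normal(action_mean, action_std)

        # Sum over action dimensions for the joint log-prob
        log_prob = dist.log_prob(actions).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)

        return log_prob, entropy, value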

Observation Processing

PufferLib provides utilities for unflattening observations:

from pufferlib.pytorch import unflatten_observations

class PolicyWithUnflatten(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        self.observation_space = observation_space

        # Define encoders for each observation component
        self.encoders = nn.ModuleDict({
            'image': self._make_image_encoder(),
            'vector': self._make_vector_encoder()
        })

        # ... rest of policy ...

    def forward(self, flat_observations):
        # Unflatten observations into structured format
        observations = unflatten_observations(
            flat_observations,
            self.observation_space
        )

        # Process each component
        image_features = self.encoders['image'](observations['image'])
        vector_features = self.encoders['vector'](observations['vector'])

        # Combine and continue...

Multi-Agent Policies

Shared Parameters

All agents use the same policy:

class SharedMultiAgentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_agents):
        super().__init__()

        self.num_agents = num_agents

        # Single policy shared across all agents
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU()
        )

        self.actor = layer_init(nn.Linear(256, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(256, 1), std=1.0)

    def forward(self, observations):
        """
        Args:
            observations: (batch * num_agents, obs_dim)
        Returns:
            logits: (batch * num_agents, num_actions)
            values: (batch * num_agents, 1)
        """
        features = self.encoder(observations)
        return self.actor(features), self.critic(features)
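
Since agents share weights, per-agent observations can be folded straight into the batch dimension. A sketch, assuming observations arrive as (num_envs, num_agents, obs_dim):

num_envs, num_agents, obs_dim = observations.shape

# Fold agents into the batch: each agent is just another batch element
flat_obs = observations.reshape(num_envs * num_agents, obs_dim)
logits, values = policy(flat_obs)

# Unfold to recover the per-agent structure
logits = logits.reshape(num_envs, num_agents, -1)
values = values.reshape(num_envs, num_agents, 1)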

Independent Parameters

Each agent has its own policy:

class IndependentMultiAgentPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_agents):
        super().__init__()

        self.num_agents = num_agents

        # Separate policy for each agent
        self.policies = nn.ModuleList([
            self._make_policy(observation_space, action_space)
            for _ in range(num_agents)
        ])

    def _make_policy(self, observation_space, action_space):
        # Feature extractor only; per-agent actor/critic heads follow the same pattern
        return nn.Sequential(
            layer_init(nn.Linear(observation_space.shape[0], 256)),
            nn.ReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.ReLU()
        )

    def forward(self, observations, agent_ids):
        """
        Args:
            observations: (batch, obs_dim)
            agent_ids: (batch,) which agent each obs belongs to
        Returns:
            features: (batch, 256), rows aligned with the input batch
        """
        features = torch.zeros(observations.shape[0], 256, device=observations.device)
        for agent_id in range(self.num_agents):
            mask = agent_ids == agent_id
            if mask.any():
                # Write back in place so batch order is preserved
                # (torch.cat would scramble rows by agent id)
                features[mask] = self.policies[agent_id](observations[mask])

        return features
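
Usage sketch: agent_ids lines each observation up with its owner, and thanks to the in-place writes above the output keeps the original batch order (obs_dim is a placeholder):

obs = torch.randn(6, obs_dim)                 # Mixed batch from several agents
agent_ids = torch.tensor([0, 1, 0, 2, 1, 2])  # Owner of each row

features = policy(obs, agent_ids)             # (6, 256), same row order as obs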

Advanced Architectures

Attention-Based Policy

class AttentionPolicy(nn.Module):
    def __init__(self, observation_space, action_space, d_model=256, nhead=8):
        super().__init__()

        self.encoder = layer_init(nn.Linear(observation_space.shape[0], d_model))

        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)

        self.actor = layer_init(nn.Linear(d_model, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(d_model, 1), std=1.0)

    def forward(self, observations):
        # Encode
        features = self.encoder(observations)

        # Self-attention over a single token: softmax over one key makes this
        # little more than a learned projection (see the entity version below)
        features = features.unsqueeze(1)  # Add sequence dimension
        attn_out, _ = self.attention(features, features, features)
        attn_out = attn_out.squeeze(1)

        return self.actor(attn_out), self.critic(attn_out)
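
The pattern pays off when observations form a set of tokens, e.g. one embedding per visible entity. A sketch under that assumption (entity_dim and num_actions are placeholders):

class EntityAttentionPolicy(nn.Module):
    def __init__(self, entity_dim, num_actions, d_model=256, nhead=8):
        super().__init__()
        self.embed = layer_init(nn.Linear(entity_dim, d_model))
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.actor = layer_init(nn.Linear(d_model, num_actions), std=0.01)
        self.critic = layer_init(nn.Linear(d_model, 1), std=1.0)

    def forward(self, entities):
        # entities: (batch, num_entities, entity_dim)
        tokens = self.embed(entities)
        attn_out, _ = self.attention(tokens, tokens, tokens)
        pooled = attn_out.mean(dim=1)  # Pool over the entity set
        return self.actor(pooled), self.critic(pooled)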

Residual Policy

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            layer_init(nn.Linear(dim, dim)),
            nn.ReLU(),
            layer_init(nn.Linear(dim, dim))
        )

    def forward(self, x):
        return x + self.block(x)

class ResidualPolicy(nn.Module):
    def __init__(self, observation_space, action_space, num_blocks=4):
        super().__init__()

        dim = 256

        self.encoder = layer_init(nn.Linear(observation_space.shape[0], dim))

        self.blocks = nn.Sequential(
            *[ResidualBlock(dim) for _ in range(num_blocks)]
        )

        self.actor = layer_init(nn.Linear(dim, action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(dim, 1), std=1.0)

    def forward(self, observations):
        x = torch.relu(self.encoder(observations))
        x = self.blocks(x)
        return self.actor(x), self.critic(x)

Policy Best Practices

Initialization

# Always use layer_init for proper initialization
good_layer = layer_init(nn.Linear(256, 256))

# Use small std for actor head (more stable early training)
actor = layer_init(nn.Linear(256, num_actions), std=0.01)

# Use std=1.0 for critic head
critic = layer_init(nn.Linear(256, 1), std=1.0)

Observation Normalization

class NormalizedPolicy(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()

        # Running statistics as buffers: saved with the state dict, never trained
        self.register_buffer('obs_mean', torch.zeros(observation_space.shape[0]))
        self.register_buffer('obs_std', torch.ones(observation_space.shape[0]))

        # ... rest of policy ...

    def forward(self, observations):
        # Normalize observations
        normalized_obs = (observations - self.obs_mean) / (self.obs_std + 1e-8)

        # Continue with normalized observations
        return self.policy(normalized_obs)

    def update_normalization(self, observations, momentum=0.01):
        """Update running statistics with an exponential moving average."""
        self.obs_mean.lerp_(observations.mean(dim=0), momentum)
        self.obs_std.lerp_(observations.std(dim=0), momentum)

Gradient Clipping

# PufferLib trainer handles gradient clipping automatically
trainer = PuffeRL(
    env=env,
    policy=policy,
    max_grad_norm=0.5  # Clip gradients to this norm
)

Model Compilation

# Enable torch.compile for faster training (PyTorch 2.0+)
policy = MyPolicy(observation_space, action_space)

# Option 1: compile the model yourself
policy = torch.compile(policy, mode='reduce-overhead')
trainer = PuffeRL(env=env, policy=policy)

# Option 2: let the trainer compile it (don't combine with Option 1)
trainer = PuffeRL(env=env, policy=policy, compile=True)

Debugging Policies

Check Output Shapes

def test_policy_shapes(policy, observation_space, batch_size=32):
    """Verify policy output shapes."""
    # Create dummy observations
    obs = torch.randn(batch_size, *observation_space.shape)

    # Forward pass
    logits, value = policy(obs)

    # Check shapes
    assert logits.shape == (batch_size, policy.action_space.n)
    assert value.shape == (batch_size, 1)

    print("✓ Policy shapes correct")

Verify Gradients

def check_gradients(policy, observation_space):
    """Check that gradients flow properly."""
    obs = torch.randn(1, *observation_space.shape, requires_grad=True)

    logits, value = policy(obs)

    # Backward pass
    loss = logits.sum() + value.sum()
    loss.backward()

    # Check gradients exist
    for name, param in policy.named_parameters():
        if param.grad is None:
            print(f"⚠ No gradient for {name}")
        elif torch.isnan(param.grad).any():
            print(f"⚠ NaN gradient for {name}")
        else:
            print(f"✓ Gradient OK for {name}")