PyMC Sampling and Inference Methods

This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference.

MCMC Sampling Methods

Primary Sampling Function

pm.sample(draws=1000, tune=1000, chains=4, **kwargs)

The main interface for MCMC sampling in PyMC.

Key Parameters:

  • draws: Number of samples to draw per chain (default: 1000)
  • tune: Number of tuning/warmup samples (default: 1000, discarded)
  • chains: Number of parallel chains (default: 4)
  • cores: Number of CPU cores to use (default: number of CPUs, capped at 4)
  • target_accept: Target acceptance rate for step size tuning (default: 0.8, increase to 0.9-0.95 for difficult posteriors)
  • random_seed: Random seed for reproducibility
  • return_inferencedata: Return ArviZ InferenceData object (default: True)
  • idata_kwargs: Additional kwargs for InferenceData creation (e.g., {"log_likelihood": True} for model comparison)

Returns: InferenceData object containing posterior samples, sampling statistics, and diagnostics

Example:

with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
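
If you plan to compare models with az.loo or az.compare, the pointwise log likelihood is not stored unless requested. A minimal sketch of asking for it via idata_kwargs (the random_seed value is arbitrary):

import arviz as az

with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(idata_kwargs={"log_likelihood": True}, random_seed=42)

# Model comparison then works directly on the resulting InferenceData
loo = az.loo(idata)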

Sampling Algorithms

PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually.

NUTS (No-U-Turn Sampler)

Default algorithm for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant.

  • Automatically tunes step size and mass matrix
  • Adaptive: explores posterior geometry during tuning
  • Best for smooth, continuous posteriors
  • Can struggle with high correlation or multimodality

Manual specification:

with model:
    idata = pm.sample(step=pm.NUTS(target_accept=0.95))

When to adjust:

  • Increase target_accept (0.9-0.99) if you see divergences
  • The default init='jitter+adapt_diag' jitters the starting point, then adapts a diagonal mass matrix
  • Use init='adapt_diag' (no jitter) if jittered starting points cause initialization failures

Metropolis

General-purpose Metropolis-Hastings sampler.

  • Works for both continuous and discrete variables
  • Less efficient than NUTS for smooth continuous posteriors
  • Useful for discrete parameters or non-differentiable models
  • Requires manual tuning

Example:

with model:
    idata = pm.sample(step=pm.Metropolis())

Slice Sampler

Slice sampling for univariate distributions.

  • No tuning required
  • Good for difficult univariate posteriors
  • Can be slow for high dimensions

Example:

with model:
    idata = pm.sample(step=pm.Slice())

CompoundStep

Combine different samplers for different parameters.

Example:

with model:
    # Use NUTS for continuous params, Metropolis for discrete
    step1 = pm.NUTS([continuous_var1, continuous_var2])
    step2 = pm.Metropolis([discrete_var])
    idata = pm.sample(step=[step1, step2])

Sampling Diagnostics

PyMC automatically computes diagnostics. Check these before trusting results:

Effective Sample Size (ESS)

Measures independent information in correlated samples.

  • Rule of thumb: ESS > 100 per chain (400 total for 4 chains)
  • Low ESS indicates high autocorrelation
  • Access via: az.ess(idata)

R-hat (Gelman-Rubin statistic)

Measures convergence across chains.

  • Rule of thumb: R-hat < 1.01 for all parameters
  • R-hat > 1.01 indicates non-convergence
  • Access via: az.rhat(idata)

Divergences

Indicate regions where NUTS struggled.

  • Rule of thumb: 0 divergences (or very few)
  • Divergences suggest biased samples
  • Fix: Increase target_accept, reparameterize, or use stronger priors
  • Access via: idata.sample_stats.diverging.sum()
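
A compact way to apply the three rules of thumb above in one place. This helper is illustrative, not part of PyMC; the thresholds are the values given above (ESS > 400 total, R-hat < 1.01, zero divergences):

import arviz as az

def check_diagnostics(idata, min_ess=400, max_rhat=1.01):
    # Effective sample size and R-hat for every free parameter
    ess = az.ess(idata)
    rhat = az.rhat(idata)
    low_ess = [v for v in ess.data_vars if float(ess[v].min()) < min_ess]
    high_rhat = [v for v in rhat.data_vars if float(rhat[v].max()) > max_rhat]
    n_divergent = int(idata.sample_stats["diverging"].sum())

    print("Low ESS:", low_ess or "none")
    print("High R-hat:", high_rhat or "none")
    print("Divergences:", n_divergent)

check_diagnostics(idata)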

Energy Plot

Visualizes Hamiltonian Monte Carlo energy transitions.

az.plot_energy(idata)

Good overlap between the marginal energy distribution and the energy transition distribution indicates healthy sampling; clear separation between them points to problems.

Handling Sampling Issues

Divergences

# Increase target acceptance rate
idata = pm.sample(target_accept=0.95)

# Or reparameterize. Non-centered parameterization helps for latent
# (unobserved) hierarchical parameters; see "Reparameterization Tricks" below.
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)

Slow Sampling

# Use fewer tuning steps if model is simple
idata = pm.sample(tune=500)

# Increase cores for parallelization
idata = pm.sample(cores=8, chains=8)

# Use ADVI to initialize NUTS
with model:
    idata = pm.sample(init='advi+adapt_diag')

High Autocorrelation

# Increase draws
idata = pm.sample(draws=5000)

# Reparameterize to reduce correlation
# Consider using QR decomposition for regression models
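
To see which parameters are affected before changing the model, ArviZ can plot per-chain autocorrelation and report ESS relative to the total number of draws (the names passed to var_names are illustrative):

import arviz as az

# Autocorrelation per parameter and chain
az.plot_autocorr(idata, var_names=['mu', 'sigma'])

# Relative ESS: values well below 1 indicate strong autocorrelation
print(az.ess(idata, relative=True))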

Variational Inference

Faster approximate inference for large models or quick exploration.

ADVI (Automatic Differentiation Variational Inference)

pm.fit(n=10000, method='advi', **kwargs)

Approximates the posterior with a simpler distribution (a mean-field Gaussian by default).

Key Parameters:

  • n: Number of iterations (default: 10000)
  • method: VI algorithm ('advi', 'fullrank_advi', 'svgd')
  • random_seed: Random seed

Returns: Approximation object for sampling and analysis

Example:

with model:
    approx = pm.fit(n=50000)
    # Draw samples from approximation
    idata = approx.sample(1000)
    # Or initialize MCMC directly from ADVI: pm.sample(init='advi+adapt_diag')

Trade-offs:

  • Pros: Much faster than MCMC, scales to large data
  • Cons: Approximate, may miss posterior structure, underestimates uncertainty

Full-Rank ADVI

Captures correlations between parameters.

with model:
    approx = pm.fit(method='fullrank_advi')

More accurate than mean-field but slower.

SVGD (Stein Variational Gradient Descent)

Non-parametric variational inference.

with model:
    approx = pm.fit(method='svgd', n=20000)

Better captures multimodality but more computationally expensive.
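
Minibatch ADVI

For data too large for full-batch gradients, ADVI can be combined with minibatching. A sketch assuming a simple linear model with predictor matrix X and response y (the names, priors, and batch size are illustrative):

import pymc as pm

# Stream random minibatches of the data at each gradient step
X_mb, y_mb = pm.Minibatch(X, y, batch_size=128)

with pm.Model():
    beta = pm.Normal('beta', 0, 1, shape=X.shape[1])
    sigma = pm.HalfNormal('sigma', 1)
    mu = pm.math.dot(X_mb, beta)
    # total_size rescales the minibatch likelihood to the full dataset size
    y_lik = pm.Normal('y_lik', mu, sigma, observed=y_mb, total_size=len(y))
    approx = pm.fit(n=30000, method='advi')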

Prior and Posterior Predictive Sampling

Prior Predictive Sampling

Sample from the prior distribution (before seeing data).

pm.sample_prior_predictive(draws=500, **kwargs)

Purpose:

  • Validate priors are reasonable
  • Check implied predictions before fitting
  • Ensure model generates plausible data

Example:

with model:
    prior_pred = pm.sample_prior_predictive(draws=1000)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior')

Posterior Predictive Sampling

Sample from posterior predictive distribution (after fitting).

pm.sample_posterior_predictive(trace, **kwargs)

Purpose:

  • Model validation via posterior predictive checks
  • Generate predictions for new data
  • Assess goodness-of-fit

Example:

with model:
    # After sampling
    idata = pm.sample()

    # Add posterior predictive samples
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Posterior predictive check
az.plot_ppc(idata)

Predictions for New Data

Update data and sample predictive distribution:

with model:
    # Original model fit
    idata = pm.sample()

    # Update with new predictor values
    pm.set_data({'X': X_new})

    # Sample predictions
    post_pred_new = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_pred']
    )
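
pm.set_data only works if the predictors were registered with a pm.Data container (pm.MutableData in older PyMC versions) when the model was built. A minimal sketch of that setup; the names X, y_pred, and the priors are illustrative:

import pymc as pm

with pm.Model() as model:
    X_data = pm.Data('X', X_train)            # updatable data container
    beta = pm.Normal('beta', 0, 1, shape=X_train.shape[1])
    sigma = pm.HalfNormal('sigma', 1)
    mu = pm.math.dot(X_data, beta)
    # Tie the observed shape to mu so predictions can use a different number of rows
    y_pred = pm.Normal('y_pred', mu, sigma, observed=y_train, shape=mu.shape)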

Maximum A Posteriori (MAP) Estimation

Find posterior mode (point estimate).

pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)

When to use:

  • Quick point estimates
  • Initialization for MCMC
  • When full posterior not needed

Example:

with model:
    map_estimate = pm.find_MAP()
    print(map_estimate)

Limitations:

  • Doesn't quantify uncertainty
  • Can find local optima in multimodal posteriors
  • Sensitive to prior specification

Inference Recommendations

Standard Workflow

  1. Start with ADVI for quick exploration:

    approx = pm.fit(n=20000)
    
  2. Run MCMC for full inference:

    idata = pm.sample(draws=2000, tune=1000)
    
  3. Check diagnostics:

    az.summary(idata)  # check r_hat, ess_bulk, ess_tail; var_names=['~name'] excludes variables
    
  4. Sample posterior predictive:

    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    

Choosing Inference Method

  • Small-medium models, need full uncertainty: MCMC with NUTS
  • Large models, initial exploration: ADVI
  • Discrete parameters: Metropolis or marginalize (see the sketch after this list)
  • Hierarchical models with divergences: non-centered parameterization + NUTS
  • Very large data: Minibatch ADVI (see Minibatch ADVI above)
  • Quick point estimates: MAP or ADVI
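
"Marginalize" here means summing the discrete variable out of the likelihood so that NUTS can sample the remaining continuous parameters. A sketch using a two-component mixture whose latent component indicator is absorbed by pm.NormalMixture (the priors and the data name are illustrative):

import numpy as np
import pymc as pm

with pm.Model():
    # Mixture weights and per-component parameters
    w = pm.Dirichlet('w', a=np.ones(2))
    mu = pm.Normal('mu', 0, 5, shape=2)
    sigma = pm.HalfNormal('sigma', 1, shape=2)
    # NormalMixture sums over the latent component indicator,
    # leaving only continuous parameters for NUTS
    y = pm.NormalMixture('y', w=w, mu=mu, sigma=sigma, observed=data)
    idata = pm.sample()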

Reparameterization Tricks

Non-centered parameterization for hierarchical models:

# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)

QR decomposition for correlated predictors:

import numpy as np
import pymc as pm

# Thin QR decomposition of the design matrix X (n x p)
Q, R = np.linalg.qr(X)
R_inv = np.linalg.inv(R)  # R is a fixed p x p array, so invert it with NumPy
p = X.shape[1]

with pm.Model():
    # Coefficients on the orthogonal Q scale are approximately uncorrelated
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)

    # Transform back to the original predictor scale
    beta = pm.Deterministic('beta', pm.math.dot(R_inv, beta_tilde))

    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)

Advanced Sampling

Sequential Monte Carlo (SMC)

For complex posteriors or model evidence estimation:

with model:
    idata = pm.sample_smc(draws=2000, chains=4)

Good for multimodal posteriors or when NUTS struggles.

Custom Initialization

Provide starting values:

start = {'mu': 0, 'sigma': 1}
with model:
    idata = pm.sample(initvals=start)

Or use MAP estimate:

with model:
    start = pm.find_MAP()
    idata = pm.sample(initvals=start)