PyMC Sampling and Inference Methods
This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference.
MCMC Sampling Methods
Primary Sampling Function
pm.sample(draws=1000, tune=1000, chains=4, **kwargs)
The main interface for MCMC sampling in PyMC.
Key Parameters:
- draws: Number of samples to draw per chain (default: 1000)
- tune: Number of tuning/warmup samples (default: 1000, discarded)
- chains: Number of parallel chains (default: 4)
- cores: Number of CPU cores to use (default: all available)
- target_accept: Target acceptance rate for step-size tuning (default: 0.8; increase to 0.9-0.95 for difficult posteriors)
- random_seed: Random seed for reproducibility
- return_inferencedata: Return ArviZ InferenceData object (default: True)
- idata_kwargs: Additional kwargs for InferenceData creation (e.g., {"log_likelihood": True} for model comparison)
Returns: InferenceData object containing posterior samples, sampling statistics, and diagnostics
Example:
with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
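If you plan to compare models with LOO or WAIC later, the pointwise log likelihood can be stored at sampling time via idata_kwargs (a short sketch continuing the example above; az.loo is ArviZ's LOO-CV estimator):
import arviz as az

with model:
    # Store pointwise log likelihoods alongside the posterior
    idata = pm.sample(idata_kwargs={"log_likelihood": True})
print(az.loo(idata))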
Sampling Algorithms
PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually.
NUTS (No-U-Turn Sampler)
Default algorithm for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant.
- Automatically tunes step size and mass matrix
- Adaptive: explores posterior geometry during tuning
- Best for smooth, continuous posteriors
- Can struggle with high correlation or multimodality
Manual specification:
with model:
    idata = pm.sample(step=pm.NUTS(target_accept=0.95))
When to adjust:
- Increase target_accept (0.9-0.99) if seeing divergences
- Use init='jitter+adapt_diag' (the default) for most models
- Use init='adapt_diag' for faster initialization when jittered starting points cause failures
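For example, a more conservative configuration for a model that produces divergences (a sketch; model is assumed to be already defined):
with model:
    idata = pm.sample(
        target_accept=0.95,        # tighter step-size adaptation
        init='jitter+adapt_diag',  # the default initialization, made explicit
    )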
Metropolis
General-purpose Metropolis-Hastings sampler.
- Works for both continuous and discrete variables
- Less efficient than NUTS for smooth continuous posteriors
- Useful for discrete parameters or non-differentiable models
- Requires manual tuning
Example:
with model:
    idata = pm.sample(step=pm.Metropolis())
Slice Sampler
Slice sampling for univariate distributions.
- No tuning required
- Good for difficult univariate posteriors
- Can be slow for high dimensions
Example:
with model:
    idata = pm.sample(step=pm.Slice())
CompoundStep
Combine different samplers for different parameters.
Example:
with pm.Model() as model:
    # NUTS for the continuous parameters, Metropolis for the discrete one
    mu = pm.Normal('mu', 0, 1)
    sigma = pm.HalfNormal('sigma', 1)
    k = pm.Poisson('k', 5)
    step1 = pm.NUTS([mu, sigma])
    step2 = pm.Metropolis([k])
    idata = pm.sample(step=[step1, step2])
Sampling Diagnostics
PyMC automatically computes diagnostics. Check these before trusting results:
Effective Sample Size (ESS)
Measures independent information in correlated samples.
- Rule of thumb: bulk ESS > 400 in total (roughly 100 per chain with 4 chains)
- Low ESS indicates high autocorrelation
- Access via: az.ess(idata)
R-hat (Gelman-Rubin statistic)
Measures convergence across chains.
- Rule of thumb: R-hat < 1.01 for all parameters
- R-hat > 1.01 indicates non-convergence
- Access via: az.rhat(idata)
Divergences
Indicate regions where NUTS struggled.
- Rule of thumb: 0 divergences (or very few)
- Divergences suggest biased samples
- Fix: increase target_accept, reparameterize, or use stronger priors
- Access via: idata.sample_stats.diverging.sum()
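A compact way to run all three checks in one pass (a sketch; check_diagnostics is a hypothetical helper, with thresholds matching the rules of thumb above):
import arviz as az

def check_diagnostics(idata, ess_min=400, rhat_max=1.01):
    # Hypothetical helper: flag common MCMC problems in one pass
    problems = []
    ess = az.ess(idata).to_dataframe()    # bulk ESS per parameter
    if (ess < ess_min).any().any():
        problems.append('low ESS')
    rhat = az.rhat(idata).to_dataframe()  # R-hat per parameter
    if (rhat > rhat_max).any().any():
        problems.append('R-hat above threshold')
    n_div = int(idata.sample_stats['diverging'].sum())  # total divergences
    if n_div > 0:
        problems.append(f'{n_div} divergences')
    return problems or ['all checks passed']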
Energy Plot
Visualizes Hamiltonian Monte Carlo energy transitions.
az.plot_energy(idata)
Good separation between energy distributions indicates healthy sampling.
Handling Sampling Issues
Divergences
# Increase target acceptance rate
idata = pm.sample(target_accept=0.95)
# Or reparameterize using a non-centered parameterization. Note that
# non-centering applies to latent variables (e.g., group-level effects),
# not to observed data; see "Reparameterization Tricks" below for the pattern.
Slow Sampling
# Use fewer tuning steps if model is simple
idata = pm.sample(tune=500)
# Increase cores for parallelization
idata = pm.sample(cores=8, chains=8)
# Use variational inference for initialization
with model:
    approx = pm.fit()  # run ADVI
    idata = pm.sample(initvals=approx.sample(return_inferencedata=False)[0])
# (pm.sample(init='advi+adapt_diag') combines these in a single call)
High Autocorrelation
# Increase draws
idata = pm.sample(draws=5000)
# Reparameterize to reduce correlation
# Consider using QR decomposition for regression models
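Before simply increasing draws, it can help to see where the autocorrelation is concentrated (a sketch using ArviZ; 'mu' is a placeholder parameter name):
import arviz as az

# Per-chain autocorrelation for selected parameters
az.plot_autocorr(idata, var_names=['mu'])
# Bulk ESS relative to the raw draw count quantifies the information loss
print(az.ess(idata))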
Variational Inference
Faster approximate inference for large models or quick exploration.
ADVI (Automatic Differentiation Variational Inference)
pm.fit(n=10000, method='advi', **kwargs)
Approximates posterior with simpler distribution (typically mean-field Gaussian).
Key Parameters:
- n: Number of iterations (default: 10000)
- method: VI algorithm ('advi', 'fullrank_advi', 'svgd')
- random_seed: Random seed
Returns: Approximation object for sampling and analysis
Example:
with model:
    approx = pm.fit(n=50000)
    # Draw samples from the approximation
    idata = approx.sample(1000)
    # Or extract a point for MCMC initialization
    start = approx.sample(return_inferencedata=False)[0]
Trade-offs:
- Pros: Much faster than MCMC, scales to large data
- Cons: Approximate, may miss posterior structure, underestimates uncertainty
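Because mean-field ADVI tends to understate uncertainty, it is worth comparing its marginals against even a short NUTS run before relying on it (a sketch; model is assumed defined and the variable names are illustrative):
import arviz as az

with model:
    idata_nuts = pm.sample(draws=500)   # short reference run
    approx = pm.fit(n=30000)            # mean-field ADVI
    idata_advi = approx.sample(1000)
# Overlay credible intervals from both methods
az.plot_forest([idata_nuts, idata_advi], model_names=['NUTS', 'ADVI'])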
Full-Rank ADVI
Captures correlations between parameters.
with model:
    approx = pm.fit(method='fullrank_advi')
More accurate than mean-field but slower.
SVGD (Stein Variational Gradient Descent)
Non-parametric variational inference.
with model:
    approx = pm.fit(method='svgd', n=20000)
Better captures multimodality but more computationally expensive.
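Minibatch ADVI
Scales ADVI to very large datasets by evaluating the likelihood on random subsets of the data (recommended in the table below for very large data). A minimal sketch with simulated data; total_size rescales the minibatch likelihood to the full dataset:
import numpy as np
import pymc as pm

N, P = 100_000, 5
X_full = np.random.normal(size=(N, P))                   # simulated predictors
y_full = X_full @ np.ones(P) + np.random.normal(size=N)  # simulated outcomes

# Random 128-row views that are resampled at every gradient step
X_mb, y_mb = pm.Minibatch(X_full, y_full, batch_size=128)

with pm.Model():
    w = pm.Normal('w', 0, 1, shape=P)
    sigma = pm.HalfNormal('sigma', 1)
    mu = pm.math.dot(X_mb, w)
    # total_size rescales the likelihood to the full dataset size
    y = pm.Normal('y', mu, sigma, observed=y_mb, total_size=N)
    approx = pm.fit(n=20000)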
Prior and Posterior Predictive Sampling
Prior Predictive Sampling
Sample from the prior distribution (before seeing data).
pm.sample_prior_predictive(draws=500, **kwargs)
Purpose:
- Validate priors are reasonable
- Check implied predictions before fitting
- Ensure model generates plausible data
Example:
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior')
Posterior Predictive Sampling
Sample from posterior predictive distribution (after fitting).
pm.sample_posterior_predictive(trace, **kwargs)
Purpose:
- Model validation via posterior predictive checks
- Generate predictions for new data
- Assess goodness-of-fit
Example:
with model:
    # After sampling
    idata = pm.sample()
    # Add posterior predictive samples
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Posterior predictive check
az.plot_ppc(idata)
Predictions for New Data
Update data and sample predictive distribution:
with model:
    # Original model fit; 'X' must be a pm.Data container defined in the model
    idata = pm.sample()
    # Update the container with new predictor values
    pm.set_data({'X': X_new})
    # Sample predictions for the new inputs
    post_pred_new = pm.sample_posterior_predictive(
        idata,
        var_names=['y_pred']
    )
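For pm.set_data to work, the predictors must be registered as data containers when the model is built. A minimal end-to-end sketch (assuming a recent PyMC where pm.Data containers are mutable; older versions use pm.MutableData):
import numpy as np
import pymc as pm

x_train = np.random.normal(size=100)
y_train = 2.0 * x_train + np.random.normal(size=100)
x_new = np.linspace(-3, 3, 50)

with pm.Model() as model:
    x = pm.Data('x', x_train)       # mutable predictor container
    beta = pm.Normal('beta', 0, 1)
    sigma = pm.HalfNormal('sigma', 1)
    # shape is tied to the container so predictions can change size
    y_pred = pm.Normal('y_pred', beta * x, sigma,
                       observed=y_train, shape=x.shape)
    idata = pm.sample()
    pm.set_data({'x': x_new})       # swap in the new inputs
    preds = pm.sample_posterior_predictive(idata, var_names=['y_pred'])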
Maximum A Posteriori (MAP) Estimation
Find posterior mode (point estimate).
pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)
When to use:
- Quick point estimates
- Initialization for MCMC
- When full posterior not needed
Example:
with model:
    map_estimate = pm.find_MAP()
    print(map_estimate)
Limitations:
- Doesn't quantify uncertainty
- Can find local optima in multimodal posteriors
- Sensitive to prior specification
Inference Recommendations
Standard Workflow
1. Start with ADVI for quick exploration: approx = pm.fit(n=20000)
2. Run MCMC for full inference: idata = pm.sample(draws=2000, tune=1000)
3. Check diagnostics: az.summary(idata, var_names=['~mu_log__'])  # exclude transformed vars
4. Sample posterior predictive: pm.sample_posterior_predictive(idata, extend_inferencedata=True)
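Putting the four steps together on a toy model (a sketch with simulated data):
import numpy as np
import pymc as pm
import arviz as az

data = np.random.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 5)
    y = pm.Normal('y', mu, sigma, observed=data)
    approx = pm.fit(n=20000)                  # quick ADVI exploration
    idata = pm.sample(draws=2000, tune=1000)  # full MCMC inference
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

print(az.summary(idata))  # check ESS and R-hat
az.plot_ppc(idata)        # posterior predictive check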
Choosing Inference Method
| Scenario | Recommended Method |
|---|---|
| Small-medium models, need full uncertainty | MCMC with NUTS |
| Large models, initial exploration | ADVI |
| Discrete parameters | Metropolis or marginalize |
| Hierarchical models with divergences | Non-centered parameterization + NUTS |
| Very large data | Minibatch ADVI |
| Quick point estimates | MAP or ADVI |
Reparameterization Tricks
Non-centered parameterization for hierarchical models:
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)
# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
QR decomposition for correlated predictors:
import numpy as np
import pymc as pm

# Thin QR decomposition of the design matrix X (n x p)
Q, R = np.linalg.qr(X)
R_inv = np.linalg.inv(R)  # R is a constant, so invert once with NumPy
with pm.Model():
    # Coefficients on the decorrelated Q basis
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)
    # Transform back to the original predictor scale
    beta = pm.Deterministic('beta', pm.math.dot(R_inv, beta_tilde))
    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)
Advanced Sampling
Sequential Monte Carlo (SMC)
For complex posteriors or model evidence estimation:
with model:
    idata = pm.sample_smc(draws=2000, chains=4)
Good for multimodal posteriors or when NUTS struggles.
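When the goal is evidence estimation, the SMC trace stores a log marginal likelihood estimate in sample_stats (the attribute name here is an assumption; verify it against your PyMC version):
# Evidence estimate per chain from the SMC run above
# (attribute name assumed; check your PyMC version)
print(idata.sample_stats['log_marginal_likelihood'])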
Custom Initialization
Provide starting values:
start = {'mu': 0, 'sigma': 1}
with model:
    idata = pm.sample(initvals=start)
Or use MAP estimate:
with model:
    start = pm.find_MAP()
    idata = pm.sample(initvals=start)