# PyMC Workflows and Common Patterns
This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC.
## Standard Bayesian Workflow

### Complete Workflow Template
```python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

# 1. PREPARE DATA
# ===============
X = ...  # Predictor variables
y = ...  # Observed outcomes

# Standardize predictors for better sampling
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# 2. BUILD MODEL
# ==============
# Define coordinates for named dimensions
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y)),
}

with pm.Model(coords=coords) as model:
    # Register the data so it can be swapped out for predictions later
    X_data = pm.Data('X', X_scaled, dims=('obs_id', 'predictors'))
    y_data = pm.Data('y', y, dims='obs_id')

    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_data, dims='obs_id')

# 3. PRIOR PREDICTIVE CHECK
# =========================
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000, random_seed=42)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('Prior Predictive Check')
plt.show()

# 4. FIT MODEL
# ============
with model:
    # Quick VI exploration (optional)
    approx = pm.fit(n=20000, random_seed=42)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True},  # For model comparison
    )

# 5. CHECK DIAGNOSTICS
# ====================
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# R-hat and ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
    print("WARNING: Some R-hat values > 1.01, chains may not have converged")
if (summary['ess_bulk'] < 400).any():
    print("WARNING: Some ESS values < 400, consider more samples")

# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
print(f"Number of divergences: {divergences}")

# Trace plots
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# 6. POSTERIOR PREDICTIVE CHECK
# =============================
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize fit
az.plot_ppc(idata, num_pp_samples=100)
plt.title('Posterior Predictive Check')
plt.show()

# 7. ANALYZE RESULTS
# ==================
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# Forest plot for coefficients
az.plot_forest(idata, var_names=['beta'], combined=True)
plt.title('Coefficient Estimates')
plt.show()

# 8. PREDICTIONS FOR NEW DATA
# ===========================
X_new = ...  # New predictor values
X_new_scaled = (X_new - X.mean(axis=0)) / X.std(axis=0)  # Reuse training statistics

with model:
    # Swap in the new data; y gets dummy values of the matching shape
    pm.set_data(
        {'X': X_new_scaled, 'y': np.zeros(len(X_new_scaled))},
        coords={'obs_id': np.arange(len(X_new_scaled))},
    )
    # Sample predictions
    post_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['y_obs'],
        random_seed=42,
    )

# Prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])

# 9. SAVE RESULTS
# ===============
idata.to_netcdf('model_results.nc')  # Save for later
```
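To visualize the step 8 predictions, the predictive mean and HDI computed above can be plotted against the new inputs. A minimal sketch, assuming a hypothetical 1-D predictor array `x_new` for the horizontal axis:

```python
# Sketch: predictive mean with a 94% HDI ribbon (x_new is hypothetical)
hdi = y_pred_hdi['y_obs']  # dims: (obs_id, hdi), coords 'lower'/'higher'
order = np.argsort(x_new)  # sort for a clean line plot

plt.plot(x_new[order], y_pred_mean.values[order], label='Predictive mean')
plt.fill_between(x_new[order],
                 hdi.sel(hdi='lower').values[order],
                 hdi.sel(hdi='higher').values[order],
                 alpha=0.3, label='94% HDI')
plt.legend()
plt.show()
```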
## Model Building Patterns

### Linear Regression

```python
with pm.Model() as linear_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
### Logistic Regression

```python
with pm.Model() as logistic_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor
    logit_p = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```
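If you want the fitted probabilities themselves in the trace, not just the log-odds coefficients, a `pm.Deterministic` wrapping the inverse logit can be added; a small variant of the model above:

```python
with pm.Model() as logistic_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    logit_p = alpha + pm.math.dot(X, beta)

    # Deterministic nodes are stored in the posterior, so p can be
    # summarized and plotted directly after sampling
    p = pm.Deterministic('p', pm.math.invlogit(logit_p))
    y = pm.Bernoulli('y', p=p, observed=y_obs)
```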
### Hierarchical/Multilevel Model

```python
with pm.Model(coords={'group': group_names, 'obs': np.arange(n_obs)}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)
    mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)

    # Group-level parameters (non-centered)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='group')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='group')
    beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='group')
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='group')

    # Observation-level model
    mu = alpha[group_idx] + beta[group_idx] * X
    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs, dims='obs')
```
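The `group_idx` array above maps each observation to its group. If the raw data carry group labels, pandas' `factorize` builds both the integer index and the `group_names` coordinate in one call; a small helper sketch with hypothetical labels:

```python
import pandas as pd

# Hypothetical raw labels, one entry per observation
labels = pd.Series(['A', 'B', 'A', 'C', 'B'])

group_idx, group_names = pd.factorize(labels)
# group_idx   -> array([0, 1, 0, 2, 1])
# group_names -> Index(['A', 'B', 'C'])
n_obs = len(labels)
```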
### Poisson Regression (Count Data)

```python
with pm.Model() as poisson_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor on the log scale
    log_lambda = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```
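If the counts were collected over unequal exposures (different observation windows or population sizes), the usual extension is a log-exposure offset. A minimal sketch, where `exposure` is a hypothetical per-observation array:

```python
with pm.Model() as poisson_offset_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # log(exposure) enters with a fixed coefficient of 1,
    # so the model describes a rate per unit of exposure
    log_rate = alpha + pm.math.dot(X, beta) + pm.math.log(exposure)
    y = pm.Poisson('y', mu=pm.math.exp(log_rate), observed=y_obs)
```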
### Time Series (Autoregressive)

```python
with pm.Model() as ar_model:
    # Innovation standard deviation
    sigma = pm.HalfNormal('sigma', sigma=1)

    # AR coefficients
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)

    # Initial distribution
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)

    # AR process
    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```
### Mixture Model

```python
with pm.Model() as mixture_model:
    # Component weights
    w = pm.Dirichlet('w', a=np.ones(n_components))

    # Component parameters
    mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components)
    sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)

    # Mixture
    components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i]) for i in range(n_components)]
    y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs)
```
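Mixture posteriors are invariant to relabeling the components, so chains can disagree about which component is which ("label switching"). One common mitigation is to order the component means with a transform; a sketch of that variant:

```python
from pymc.distributions.transforms import ordered

with pm.Model() as ordered_mixture_model:
    w = pm.Dirichlet('w', a=np.ones(n_components))

    # Ordering the means breaks the relabeling symmetry;
    # the initial value must already satisfy the ordering
    mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components,
                   transform=ordered,
                   initval=np.linspace(-1, 1, n_components))
    sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)

    components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i])
                  for i in range(n_components)]
    y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs)
```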
## Data Preparation Best Practices

### Standardization

Standardize continuous predictors for better sampling:

```python
# Standardize
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std

# Model with scaled data
with pm.Model() as model:
    alpha = pm.Normal('alpha', 0, 1)
    beta_scaled = pm.Normal('beta_scaled', 0, 1)
    # ... rest of model ...

# Transform back to the original scale
beta_original = beta_scaled / X_std
alpha_original = alpha - (beta_scaled * X_mean / X_std).sum()
```
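The same back-transform can be applied to the posterior draws after sampling; a sketch assuming the model was sampled into `idata` and `beta_scaled` was declared with `dims='predictors'`:

```python
import xarray as xr

# Wrap the scaling constants as DataArrays aligned on 'predictors'
x_mean = xr.DataArray(X_mean, dims='predictors')
x_std = xr.DataArray(X_std, dims='predictors')

# Broadcasts over chains and draws automatically
beta_orig = idata.posterior['beta_scaled'] / x_std
alpha_orig = (idata.posterior['alpha']
              - (idata.posterior['beta_scaled'] * x_mean / x_std).sum('predictors'))
```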
### Handling Missing Data

Treat missing predictor values as parameters:

```python
import pytensor.tensor as pt

# Identify missing values
missing_idx = np.isnan(X)
X_observed = np.where(missing_idx, 0, X)  # Placeholder for missing entries

with pm.Model() as model:
    # Prior for missing values
    X_missing = pm.Normal('X_missing', mu=0, sigma=1, shape=missing_idx.sum())

    # Combine observed and imputed values
    X_flat = pt.as_tensor_variable(X_observed.flatten())
    X_complete = pt.set_subtensor(X_flat[np.flatnonzero(missing_idx)], X_missing)
    X_complete = X_complete.reshape(X.shape)
    # ... rest of model using X_complete ...
```
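For missing values in the *observed outcome* (rather than in predictors), PyMC imputes automatically when the observations are passed as a NumPy masked array:

```python
# NaNs in y become masked entries, which PyMC treats as
# unobserved nodes and samples alongside the parameters
y_masked = np.ma.masked_invalid(y)

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_masked)
```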
### Centering and Scaling

For regression models, center the predictors and the outcome:

```python
# Center
X_centered = X - X.mean(axis=0)
y_centered = y - y.mean()

with pm.Model() as model:
    # Simpler prior on the intercept, which is near 0 when data are centered
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, shape=n_predictors)
    mu = alpha + pm.math.dot(X_centered, beta)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_centered)
```
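Because the outcome was centered as well, add its mean back when reporting the intercept on the original scale (a one-line transform, assuming the model was sampled into `idata`):

```python
# y_centered = y - y.mean(), so on the original scale
#   y = (alpha + y.mean()) + X_centered @ beta
alpha_original = idata.posterior['alpha'] + y.mean()
```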
## Prior Selection Guidelines

### Weakly Informative Priors

Use when you have limited prior knowledge:

```python
# For standardized predictors
beta = pm.Normal('beta', mu=0, sigma=1)

# For scale parameters
sigma = pm.HalfNormal('sigma', sigma=1)

# For probabilities
p = pm.Beta('p', alpha=2, beta=2)  # Slight preference for middle values
```
### Informative Priors

Use domain knowledge:

```python
# Effect size from the literature: Cohen's d ≈ 0.3
beta = pm.Normal('beta', mu=0.3, sigma=0.1)

# Physical constraint: probability between 0.7 and 0.9
p = pm.Beta('p', alpha=8, beta=2)  # Check with a prior predictive!
```
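A quick way to verify that Beta(8, 2) really concentrates mass where intended is to compute the prior probability of the target interval directly (using scipy, an extra dependency here):

```python
from scipy import stats

prior = stats.beta(a=8, b=2)
mass = prior.cdf(0.9) - prior.cdf(0.7)
print(f"P(0.7 < p < 0.9) = {mass:.2f}")  # ~0.58: broader than the stated constraint
```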
### Prior Predictive Checks

Always validate priors:

```python
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000)

# Check whether predictions are on a reasonable scale
y_min = prior_pred.prior_predictive['y'].min().item()
y_max = prior_pred.prior_predictive['y'].max().item()
print(f"Prior predictive range: {y_min:.2f} to {y_max:.2f}")
print(f"Observed range: {y_obs.min():.2f} to {y_obs.max():.2f}")

# Visualize
az.plot_ppc(prior_pred, group='prior')
```
## Model Comparison Workflow

### Comparing Multiple Models

```python
import arviz as az

# Fit multiple models
models = {}
idatas = {}

# Model 1: Simple linear
with pm.Model() as models['linear']:
    # ... define model ...
    idatas['linear'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 2: With interaction
with pm.Model() as models['interaction']:
    # ... define model ...
    idatas['interaction'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 3: Hierarchical
with pm.Model() as models['hierarchical']:
    # ... define model ...
    idatas['hierarchical'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Compare using LOO
comparison = az.compare(idatas, ic='loo')
print(comparison)

# Visualize the comparison
az.plot_compare(comparison)
plt.show()

# Check LOO reliability
for name, idata in idatas.items():
    loo = az.loo(idata, pointwise=True)
    high_pareto_k = (loo.pareto_k > 0.7).sum().item()
    if high_pareto_k > 0:
        print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k")
```
### Model Weights

```python
# Get model weights (pseudo-BMA)
weights = comparison['weight']
print("Model probabilities:")
for name, weight in weights.items():
    print(f"  {name}: {weight:.2%}")

# Model averaging (weighted predictions)
def weighted_predictions(idatas, comparison):
    # Iterate over comparison.index so each weight stays paired
    # with the model it belongs to
    preds = []
    for name in comparison.index:
        weight = comparison.loc[name, 'weight']
        pred = idatas[name].posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
        preds.append(weight * pred)
    return sum(preds)

averaged_pred = weighted_predictions(idatas, comparison)
```
## Diagnostics and Troubleshooting

### Diagnosing Sampling Problems

```python
def diagnose_sampling(idata, var_names=None, max_treedepth=10):
    """Comprehensive sampling diagnostics."""
    # Check convergence
    summary = az.summary(idata, var_names=var_names)
    print("=== Convergence Diagnostics ===")
    bad_rhat = summary[summary['r_hat'] > 1.01]
    if len(bad_rhat) > 0:
        print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01")
        print(bad_rhat[['r_hat']])
    else:
        print("✓ All R-hat values <= 1.01")

    # Check effective sample size
    print("\n=== Effective Sample Size ===")
    low_ess = summary[summary['ess_bulk'] < 400]
    if len(low_ess) > 0:
        print(f"⚠️ {len(low_ess)} variables with ESS < 400")
        print(low_ess[['ess_bulk', 'ess_tail']])
    else:
        print("✓ All ESS values >= 400")

    # Check divergences
    print("\n=== Divergences ===")
    divergences = idata.sample_stats.diverging.sum().item()
    if divergences > 0:
        print(f"⚠️ {divergences} divergent transitions")
        print("   Consider: increase target_accept, reparameterize, or stronger priors")
    else:
        print("✓ No divergences")

    # Check tree depth against the sampler's limit (PyMC default: 10)
    print("\n=== NUTS Statistics ===")
    hits_max = (idata.sample_stats.tree_depth >= max_treedepth).sum().item()
    if hits_max > 0:
        print(f"⚠️ Hit max treedepth {hits_max} times")
        print("   Consider: reparameterize or increase max_treedepth")
    else:
        observed_max = idata.sample_stats.tree_depth.max().item()
        print(f"✓ No max treedepth issues (max observed: {observed_max})")
    return summary

# Usage
diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma'])
```
### Common Fixes

| Problem | Solution |
|---|---|
| Divergences | Increase `target_accept` (e.g. 0.95), use non-centered parameterization |
| Low ESS | Sample more draws, reparameterize to reduce correlation |
| High R-hat | Run longer chains, check for multimodality, improve initialization |
| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity |
| Biased posterior | Check prior predictive, ensure the likelihood is correct |
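As an illustration of the first rows of the table, several of these fixes are simply arguments to `pm.sample`; a minimal sketch (the exact values are situational, not prescriptive):

```python
with model:
    idata = pm.sample(
        draws=4000,              # more draws to raise ESS
        tune=2000,               # longer adaptation for high R-hat
        target_accept=0.95,      # smaller steps to eliminate divergences
        init='advi+adapt_diag',  # ADVI-based initialization
        random_seed=42,
    )
```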
## Using Named Dimensions (dims)

### Benefits of dims

- More readable code
- Easier subsetting and analysis
- Better xarray integration

```python
import pandas as pd

# Define coordinates
coords = {
    'predictors': ['age', 'income', 'education'],
    'groups': ['A', 'B', 'C'],
    'time': pd.date_range('2020-01-01', periods=100, freq='D'),
}

with pm.Model(coords=coords) as model:
    # Use dims instead of shape
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
    y = pm.Normal('y', mu=0, sigma=1, dims=('groups', 'time'), observed=data)

    # After sampling, the dimensions are preserved
    idata = pm.sample()

# Easy subsetting
beta_age = idata.posterior['beta'].sel(predictors='age')
group_A = idata.posterior['alpha'].sel(groups='A')
```
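Named dimensions also carry through to ArviZ, so summaries and plots can be restricted to specific coordinates via the `coords` argument:

```python
# Summarize only selected predictor coefficients
print(az.summary(idata, var_names=['beta'],
                 coords={'predictors': ['age', 'income']}))

# Forest plot for a single group's intercept
az.plot_forest(idata, var_names=['alpha'], coords={'groups': ['A']})
```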
## Saving and Loading Results

```python
# Save InferenceData
idata.to_netcdf('results.nc')

# Load InferenceData
loaded_idata = az.from_netcdf('results.nc')

# Save the model for later predictions
# (if standard pickle fails on your model, try cloudpickle instead)
import pickle

with open('model.pkl', 'wb') as f:
    pickle.dump({'model': model, 'idata': idata}, f)

# Load the model
with open('model.pkl', 'rb') as f:
    saved = pickle.load(f)
model = saved['model']
idata = saved['idata']
```