Initial commit

commit f0bd18fb4e
Author: Zhongwei Li
Date: 2025-11-30 08:30:10 +08:00
824 changed files with 331919 additions and 0 deletions

skills/pymc/SKILL.md Normal file

@@ -0,0 +1,566 @@
---
name: pymc-bayesian-modeling
description: "Bayesian modeling with PyMC. Build hierarchical models, MCMC (NUTS), variational inference, LOO/WAIC comparison, posterior checks, for probabilistic programming and inference."
---
# PyMC Bayesian Modeling
## Overview
PyMC is a Python library for Bayesian modeling and probabilistic programming. Build, fit, validate, and compare Bayesian models using PyMC's modern API (version 5.x+), including hierarchical models, MCMC sampling (NUTS), variational inference, and model comparison (LOO, WAIC).
## When to Use This Skill
This skill should be used when:
- Building Bayesian models (linear/logistic regression, hierarchical models, time series, etc.)
- Performing MCMC sampling or variational inference
- Conducting prior/posterior predictive checks
- Diagnosing sampling issues (divergences, convergence, ESS)
- Comparing multiple models using information criteria (LOO, WAIC)
- Implementing uncertainty quantification through Bayesian methods
- Working with hierarchical/multilevel data structures
- Handling missing data or measurement error in a principled way
## Standard Bayesian Workflow
Follow this workflow for building and validating Bayesian models:
### 1. Data Preparation
```python
import pymc as pm
import arviz as az
import numpy as np
# Load and prepare data
X = ... # Predictors
y = ... # Outcomes
# Standardize predictors for better sampling
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
```
**Key practices:**
- Standardize continuous predictors (improves sampling efficiency)
- Center outcomes when possible
- Handle missing data explicitly (treat missing values as parameters; see the sketch after this list)
- Use named dimensions with `coords` for clarity
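If the outcome contains missing values, PyMC can treat them as latent parameters automatically when the observed array is a NumPy masked array. A minimal sketch (assumes `y` is a 1-D array that may contain NaNs):
```python
y_masked = np.ma.masked_invalid(y)

with pm.Model() as imputation_model:
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Masked entries become latent variables and are imputed during sampling
    pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_masked)
```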
### 2. Model Building
```python
coords = {
'predictors': ['var1', 'var2', 'var3'],
'obs_id': np.arange(len(y))
}
with pm.Model(coords=coords) as model:
    # Data container (pm.Data) so predictors can be swapped out for predictions later
    X_data = pm.Data('X_scaled', X_scaled, dims=('obs_id', 'predictors'))
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
```
**Key practices:**
- Use weakly informative priors (not flat priors)
- Use `HalfNormal` or `Exponential` for scale parameters
- Use named dimensions (`dims`) instead of `shape` when possible
- Use `pm.Data()` for values that will be updated for predictions
### 3. Prior Predictive Check
**Always validate priors before fitting:**
```python
with model:
prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)
# Visualize
az.plot_ppc(prior_pred, group='prior')
```
**Check:**
- Do prior predictions span reasonable values?
- Are extreme values plausible given domain knowledge?
- If priors generate implausible data, adjust and re-check
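A quick numeric sanity check of the prior draws (a sketch; assumes the likelihood variable is named `y_obs` as in the model above):
```python
prior_y = prior_pred.prior_predictive['y_obs']
print("prior predictive range:", float(prior_y.min()), "to", float(prior_y.max()))
print("observed range:        ", float(np.min(y)), "to", float(np.max(y)))
```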
### 4. Fit Model
```python
with model:
# Optional: Quick exploration with ADVI
# approx = pm.fit(n=20000)
# Full MCMC inference
idata = pm.sample(
draws=2000,
tune=1000,
chains=4,
target_accept=0.9,
random_seed=42,
idata_kwargs={'log_likelihood': True} # For model comparison
)
```
**Key parameters:**
- `draws=2000`: Number of samples per chain
- `tune=1000`: Warmup samples (discarded)
- `chains=4`: Run 4 chains for convergence checking
- `target_accept=0.9`: Higher for difficult posteriors (0.95-0.99)
- Pass `idata_kwargs={'log_likelihood': True}` to store pointwise log-likelihoods for model comparison
### 5. Check Diagnostics
**Use the diagnostic script:**
```python
from scripts.model_diagnostics import check_diagnostics
results = check_diagnostics(idata, var_names=['alpha', 'beta', 'sigma'])
```
**Check:**
- **R-hat < 1.01**: Chains have converged
- **ESS > 400**: Sufficient effective samples
- **No divergences**: NUTS sampled successfully
- **Trace plots**: Chains should mix well (fuzzy caterpillar)
**If issues arise:**
- Divergences → Increase `target_accept=0.95`, use non-centered parameterization
- Low ESS → Sample more draws, reparameterize to reduce correlation
- High R-hat → Run longer, check for multimodality
### 6. Posterior Predictive Check
**Validate model fit:**
```python
with model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize
az.plot_ppc(idata)
```
**Check:**
- Do posterior predictions capture observed data patterns?
- Are systematic deviations evident (model misspecification)?
- Consider alternative models if fit is poor
### 7. Analyze Results
```python
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
# Coefficient estimates
az.plot_forest(idata, var_names=['beta'], combined=True)
```
### 8. Make Predictions
```python
X_new = ... # New predictor values
X_new_scaled = (X_new - X_mean) / X_std
with model:
pm.set_data({'X_scaled': X_new_scaled}, coords={'obs_id': np.arange(len(X_new))})
post_pred = pm.sample_posterior_predictive(
idata.posterior,
var_names=['y_obs'],
random_seed=42
)
# Extract prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])
```
## Common Model Patterns
### Linear Regression
For continuous outcomes with linear relationships:
```python
with pm.Model() as linear_model:
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
sigma = pm.HalfNormal('sigma', sigma=1)
mu = alpha + pm.math.dot(X, beta)
y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
**Use template:** `assets/linear_regression_template.py`
### Logistic Regression
For binary outcomes:
```python
with pm.Model() as logistic_model:
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
logit_p = alpha + pm.math.dot(X, beta)
y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```
### Hierarchical Models
For grouped data (use non-centered parameterization):
```python
with pm.Model(coords={'groups': group_names}) as hierarchical_model:
# Hyperpriors
mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)
# Group-level (non-centered)
alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')
# Observation-level
mu = alpha[group_idx]
sigma = pm.HalfNormal('sigma', sigma=1)
y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
**Use template:** `assets/hierarchical_model_template.py`
**Critical:** Always use non-centered parameterization for hierarchical models to avoid divergences.
### Poisson Regression
For count data:
```python
with pm.Model() as poisson_model:
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
log_lambda = alpha + pm.math.dot(X, beta)
y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```
For overdispersed counts, use `NegativeBinomial` instead.
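A sketch of the overdispersed variant, reusing `X`, `y_obs`, and `n_predictors` from the patterns above (the `Exponential` prior on the dispersion is an assumption, not a fixed recommendation):
```python
with pm.Model() as negbin_model:
    intercept = pm.Normal('intercept', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    log_mu = intercept + pm.math.dot(X, beta)
    # Dispersion: larger alpha -> closer to Poisson
    alpha = pm.Exponential('alpha', lam=1)
    y = pm.NegativeBinomial('y', mu=pm.math.exp(log_mu), alpha=alpha, observed=y_obs)
```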
### Time Series
For autoregressive processes:
```python
with pm.Model() as ar_model:
sigma = pm.HalfNormal('sigma', sigma=1)
rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)
init_dist = pm.Normal.dist(mu=0, sigma=sigma)
y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```
## Model Comparison
### Comparing Models
Use LOO or WAIC for model comparison:
```python
from scripts.model_comparison import compare_models, check_loo_reliability
# Fit models with log_likelihood
models = {
'Model1': idata1,
'Model2': idata2,
'Model3': idata3
}
# Compare using LOO
comparison = compare_models(models, ic='loo')
# Check reliability
check_loo_reliability(models)
```
**Interpretation:**
- **Δloo < 2**: Models are similar, choose simpler model
- **2 < Δloo < 4**: Weak evidence for better model
- **4 < Δloo < 10**: Moderate evidence
- **Δloo > 10**: Strong evidence for better model
**Check Pareto-k values:**
- k < 0.7: LOO reliable
- k > 0.7: Consider WAIC or k-fold CV
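The bundled script wraps ArviZ; the direct calls look roughly like this (assumes each `idata` was sampled with `idata_kwargs={'log_likelihood': True}`):
```python
loo1 = az.loo(idata1)   # elpd_loo, p_loo, and Pareto-k diagnostics
loo2 = az.loo(idata2)
comparison = az.compare({'Model1': idata1, 'Model2': idata2}, ic='loo')
print(comparison[['elpd_loo', 'elpd_diff', 'se', 'weight']])
```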
### Model Averaging
When models are similar, average predictions:
```python
from scripts.model_comparison import model_averaging
averaged_pred, weights = model_averaging(models, var_name='y_obs')
```
## Distribution Selection Guide
### For Priors
**Scale parameters** (σ, τ):
- `pm.HalfNormal('sigma', sigma=1)` - Default choice
- `pm.Exponential('sigma', lam=1)` - Alternative
- `pm.Gamma('sigma', alpha=2, beta=1)` - More informative
**Unbounded parameters**:
- `pm.Normal('theta', mu=0, sigma=1)` - For standardized data
- `pm.StudentT('theta', nu=3, mu=0, sigma=1)` - Robust to outliers
**Positive parameters**:
- `pm.LogNormal('theta', mu=0, sigma=1)`
- `pm.Gamma('theta', alpha=2, beta=1)`
**Probabilities**:
- `pm.Beta('p', alpha=2, beta=2)` - Weakly informative
- `pm.Uniform('p', lower=0, upper=1)` - Non-informative (use sparingly)
**Correlation matrices**:
- `pm.LKJCorr('corr', n=n_vars, eta=2)` - eta=1 uniform, eta>1 prefers identity
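In practice, full covariance priors usually go through `pm.LKJCholeskyCov`, which combines the LKJ correlation prior with per-dimension scales. A sketch (the 3-dimensional `data` array is an assumption):
```python
with pm.Model() as mv_model:
    chol, corr, stds = pm.LKJCholeskyCov(
        'chol', n=3, eta=2.0,
        sd_dist=pm.Exponential.dist(1.0),
        compute_corr=True
    )
    mu = pm.Normal('mu', mu=0, sigma=1, shape=3)
    pm.MvNormal('obs', mu=mu, chol=chol, observed=data)
```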
### For Likelihoods
**Continuous outcomes**:
- `pm.Normal('y', mu=mu, sigma=sigma)` - Default for continuous data
- `pm.StudentT('y', nu=nu, mu=mu, sigma=sigma)` - Robust to outliers
**Count data**:
- `pm.Poisson('y', mu=rate)` - Equidispersed counts
- `pm.NegativeBinomial('y', mu=mu, alpha=alpha)` - Overdispersed counts
- `pm.ZeroInflatedPoisson('y', psi=psi, mu=mu)` - Excess zeros
**Binary outcomes**:
- `pm.Bernoulli('y', p=p)` or `pm.Bernoulli('y', logit_p=logit_p)`
**Categorical outcomes**:
- `pm.Categorical('y', p=probs)`
**See:** `references/distributions.md` for comprehensive distribution reference
## Sampling and Inference
### MCMC with NUTS
Default and recommended for most models:
```python
idata = pm.sample(
draws=2000,
tune=1000,
chains=4,
target_accept=0.9,
random_seed=42
)
```
**Adjust when needed:**
- Divergences → `target_accept=0.95` or higher
- Slow sampling → Use ADVI for initialization
- Discrete parameters → Use `pm.Metropolis()` for discrete vars
### Variational Inference
Fast approximation for exploration or initialization:
```python
with model:
    approx = pm.fit(n=20000, method='advi')
    # Draw samples from the approximation
    idata_approx = approx.sample(1000)

# Or let ADVI initialize NUTS directly
with model:
    idata = pm.sample(init='advi+adapt_diag')
```
**Trade-offs:**
- Much faster than MCMC
- Approximate (may underestimate uncertainty)
- Good for large models or quick exploration
**See:** `references/sampling_inference.md` for detailed sampling guide
## Diagnostic Scripts
### Comprehensive Diagnostics
```python
from scripts.model_diagnostics import create_diagnostic_report
create_diagnostic_report(
idata,
var_names=['alpha', 'beta', 'sigma'],
output_dir='diagnostics/'
)
```
Creates:
- Trace plots
- Rank plots (mixing check)
- Autocorrelation plots
- Energy plots
- ESS evolution
- Summary statistics CSV
### Quick Diagnostic Check
```python
from scripts.model_diagnostics import check_diagnostics
results = check_diagnostics(idata)
```
Checks R-hat, ESS, divergences, and tree depth.
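If the bundled script is unavailable, a rough ArviZ-only equivalent:
```python
summary = az.summary(idata)
print("max R-hat:   ", summary['r_hat'].max())
print("min bulk ESS:", summary['ess_bulk'].min())
print("divergences: ", int(idata.sample_stats.diverging.sum()))
```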
## Common Issues and Solutions
### Divergences
**Symptom:** `idata.sample_stats.diverging.sum() > 0`
**Solutions:**
1. Increase `target_accept=0.95` or `0.99`
2. Use non-centered parameterization (hierarchical models)
3. Add stronger priors to constrain parameters
4. Check for model misspecification
### Low Effective Sample Size
**Symptom:** `ESS < 400`
**Solutions:**
1. Sample more draws: `draws=5000`
2. Reparameterize to reduce posterior correlation
3. Use QR decomposition for regression with correlated predictors
### High R-hat
**Symptom:** `R-hat > 1.01`
**Solutions:**
1. Run longer chains: `tune=2000, draws=5000`
2. Check for multimodality
3. Improve initialization with ADVI
### Slow Sampling
**Solutions:**
1. Use ADVI initialization
2. Reduce model complexity
3. Increase parallelization: `cores=8, chains=8`
4. Use variational inference if appropriate
## Best Practices
### Model Building
1. **Always standardize predictors** for better sampling
2. **Use weakly informative priors** (not flat)
3. **Use named dimensions** (`dims`) for clarity
4. **Non-centered parameterization** for hierarchical models
5. **Check prior predictive** before fitting
### Sampling
1. **Run multiple chains** (at least 4) for convergence
2. **Use `target_accept=0.9`** as baseline (higher if needed)
3. **Include `log_likelihood=True`** for model comparison
4. **Set random seed** for reproducibility
### Validation
1. **Check diagnostics** before interpretation (R-hat, ESS, divergences)
2. **Posterior predictive check** for model validation
3. **Compare multiple models** when appropriate
4. **Report uncertainty** (HDI intervals, not just point estimates)
### Workflow
1. Start simple, add complexity gradually
2. Prior predictive check → Fit → Diagnostics → Posterior predictive check
3. Iterate on model specification based on checks
4. Document assumptions and prior choices
## Resources
This skill includes:
### References (`references/`)
- **`distributions.md`**: Comprehensive catalog of PyMC distributions organized by category (continuous, discrete, multivariate, mixture, time series). Use when selecting priors or likelihoods.
- **`sampling_inference.md`**: Detailed guide to sampling algorithms (NUTS, Metropolis, SMC), variational inference (ADVI, SVGD), and handling sampling issues. Use when encountering convergence problems or choosing inference methods.
- **`workflows.md`**: Complete workflow examples and code patterns for common model types, data preparation, prior selection, and model validation. Use as a cookbook for standard Bayesian analyses.
### Scripts (`scripts/`)
- **`model_diagnostics.py`**: Automated diagnostic checking and report generation. Functions: `check_diagnostics()` for quick checks, `create_diagnostic_report()` for comprehensive analysis with plots.
- **`model_comparison.py`**: Model comparison utilities using LOO/WAIC. Functions: `compare_models()`, `check_loo_reliability()`, `model_averaging()`.
### Templates (`assets/`)
- **`linear_regression_template.py`**: Complete template for Bayesian linear regression with full workflow (data prep, prior checks, fitting, diagnostics, predictions).
- **`hierarchical_model_template.py`**: Complete template for hierarchical/multilevel models with non-centered parameterization and group-level analysis.
## Quick Reference
### Model Building
```python
with pm.Model(coords={'var': names}) as model:
# Priors
param = pm.Normal('param', mu=0, sigma=1, dims='var')
# Likelihood
y = pm.Normal('y', mu=..., sigma=..., observed=data)
```
### Sampling
```python
idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```
### Diagnostics
```python
from scripts.model_diagnostics import check_diagnostics
check_diagnostics(idata)
```
### Model Comparison
```python
from scripts.model_comparison import compare_models
compare_models({'m1': idata1, 'm2': idata2}, ic='loo')
```
### Predictions
```python
with model:
pm.set_data({'X': X_new})
pred = pm.sample_posterior_predictive(idata.posterior)
```
## Additional Notes
- PyMC integrates with ArviZ for visualization and diagnostics
- Use `pm.model_to_graphviz(model)` to visualize model structure
- Save results with `idata.to_netcdf('results.nc')`
- Load with `az.from_netcdf('results.nc')`
- For very large models, consider minibatch ADVI or data subsampling
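A minimal minibatch ADVI sketch (an illustration, not part of the bundled templates; assumes NumPy arrays `X_scaled` and `y`, with `total_size` rescaling the minibatch likelihood to the full dataset):
```python
X_mb, y_mb = pm.Minibatch(X_scaled, y, batch_size=128)

with pm.Model() as minibatch_model:
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, shape=X_scaled.shape[1])
    sigma = pm.HalfNormal('sigma', sigma=1)
    mu = alpha + pm.math.dot(X_mb, beta)
    pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_mb, total_size=len(y))
    approx = pm.fit(n=50000, method='advi')
```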

skills/pymc/assets/hierarchical_model_template.py Normal file

@@ -0,0 +1,333 @@
"""
PyMC Hierarchical/Multilevel Model Template
This template provides a complete workflow for Bayesian hierarchical models,
useful for grouped/nested data (e.g., students within schools, patients within hospitals).
Customize the sections marked with # TODO
"""
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# =============================================================================
# 1. DATA PREPARATION
# =============================================================================
# TODO: Load your data with group structure
# Example:
# df = pd.read_csv('data.csv')
# groups = df['group_id'].values
# X = df['predictor'].values
# y = df['outcome'].values
# For demonstration: Generate hierarchical data
np.random.seed(42)
n_groups = 10
n_per_group = 20
n_obs = n_groups * n_per_group
# True hierarchical structure
true_mu_alpha = 5.0
true_sigma_alpha = 2.0
true_mu_beta = 1.5
true_sigma_beta = 0.5
true_sigma = 1.0
group_alphas = np.random.normal(true_mu_alpha, true_sigma_alpha, n_groups)
group_betas = np.random.normal(true_mu_beta, true_sigma_beta, n_groups)
# Generate data
groups = np.repeat(np.arange(n_groups), n_per_group)
X = np.random.randn(n_obs)
y = group_alphas[groups] + group_betas[groups] * X + np.random.randn(n_obs) * true_sigma
# TODO: Customize group names
group_names = [f'Group_{i}' for i in range(n_groups)]
# =============================================================================
# 2. BUILD HIERARCHICAL MODEL
# =============================================================================
print("Building hierarchical model...")
coords = {
'groups': group_names,
'obs': np.arange(n_obs)
}
with pm.Model(coords=coords) as hierarchical_model:
# Data containers (for later predictions)
X_data = pm.Data('X_data', X)
groups_data = pm.Data('groups_data', groups)
# Hyperpriors (population-level parameters)
# TODO: Adjust hyperpriors based on your domain knowledge
mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=5)
mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
sigma_beta = pm.HalfNormal('sigma_beta', sigma=5)
# Group-level parameters (non-centered parameterization)
# Non-centered parameterization improves sampling efficiency
alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')
beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='groups')
beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='groups')
# Observation-level model
mu = alpha[groups_data] + beta[groups_data] * X_data
# Observation noise
sigma = pm.HalfNormal('sigma', sigma=5)
# Likelihood
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs')
print("Model built successfully!")
print(f"Groups: {n_groups}")
print(f"Observations: {n_obs}")
# =============================================================================
# 3. PRIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning prior predictive check...")
with hierarchical_model:
prior_pred = pm.sample_prior_predictive(samples=500, random_seed=42)
# Visualize prior predictions
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
ax.set_title('Prior Predictive Check')
plt.tight_layout()
plt.savefig('hierarchical_prior_check.png', dpi=300, bbox_inches='tight')
print("Prior predictive check saved to 'hierarchical_prior_check.png'")
# =============================================================================
# 4. FIT MODEL
# =============================================================================
print("\nFitting hierarchical model...")
print("(This may take a few minutes due to model complexity)")
with hierarchical_model:
# MCMC sampling with higher target_accept for hierarchical models
idata = pm.sample(
draws=2000,
tune=2000, # More tuning for hierarchical models
chains=4,
target_accept=0.95, # Higher for better convergence
random_seed=42,
idata_kwargs={'log_likelihood': True}
)
print("Sampling complete!")
# =============================================================================
# 5. CHECK DIAGNOSTICS
# =============================================================================
print("\n" + "="*60)
print("DIAGNOSTICS")
print("="*60)
# Summary for key parameters
summary = az.summary(
idata,
var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma', 'alpha', 'beta']
)
print("\nParameter Summary:")
print(summary)
# Check convergence
bad_rhat = summary[summary['r_hat'] > 1.01]
if len(bad_rhat) > 0:
print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
print(bad_rhat[['r_hat']])
else:
print("\n✓ All R-hat values < 1.01 (good convergence)")
# Check effective sample size
low_ess = summary[summary['ess_bulk'] < 400]
if len(low_ess) > 0:
print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
print(low_ess[['ess_bulk']].head(10))
else:
print("\n✓ All ESS values > 400 (sufficient samples)")
# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
print(f"\n⚠️ WARNING: {divergences} divergent transitions")
print(" This is common in hierarchical models - non-centered parameterization already applied")
print(" Consider even higher target_accept or stronger hyperpriors")
else:
print("\n✓ No divergences")
# Trace plots for hyperparameters
fig, axes = plt.subplots(5, 2, figsize=(12, 12))
az.plot_trace(
idata,
var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma'],
axes=axes
)
plt.tight_layout()
plt.savefig('hierarchical_trace_plots.png', dpi=300, bbox_inches='tight')
print("\nTrace plots saved to 'hierarchical_trace_plots.png'")
# =============================================================================
# 6. POSTERIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning posterior predictive check...")
with hierarchical_model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize fit
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
ax.set_title('Posterior Predictive Check')
plt.tight_layout()
plt.savefig('hierarchical_posterior_check.png', dpi=300, bbox_inches='tight')
print("Posterior predictive check saved to 'hierarchical_posterior_check.png'")
# =============================================================================
# 7. ANALYZE HIERARCHICAL STRUCTURE
# =============================================================================
print("\n" + "="*60)
print("POPULATION-LEVEL (HYPERPARAMETER) ESTIMATES")
print("="*60)
# Population-level estimates
hyper_summary = summary.loc[['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma']]
print(hyper_summary[['mean', 'sd', 'hdi_3%', 'hdi_97%']])
# Forest plot for group-level parameters
fig, axes = plt.subplots(1, 2, figsize=(14, 8))
# Group intercepts
az.plot_forest(idata, var_names=['alpha'], combined=True, ax=axes[0])
axes[0].set_title('Group-Level Intercepts (α)')
axes[0].set_yticklabels(group_names)
axes[0].axvline(idata.posterior['mu_alpha'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[0].legend()
# Group slopes
az.plot_forest(idata, var_names=['beta'], combined=True, ax=axes[1])
axes[1].set_title('Group-Level Slopes (β)')
axes[1].set_yticklabels(group_names)
axes[1].axvline(idata.posterior['mu_beta'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[1].legend()
plt.tight_layout()
plt.savefig('group_level_estimates.png', dpi=300, bbox_inches='tight')
print("\nGroup-level estimates saved to 'group_level_estimates.png'")
# Shrinkage visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Intercepts
alpha_samples = idata.posterior['alpha'].values.reshape(-1, n_groups)
alpha_means = alpha_samples.mean(axis=0)
mu_alpha_mean = idata.posterior['mu_alpha'].mean().item()
axes[0].scatter(range(n_groups), alpha_means, alpha=0.6)
axes[0].axhline(mu_alpha_mean, color='red', linestyle='--', label='Population mean')
axes[0].set_xlabel('Group')
axes[0].set_ylabel('Intercept')
axes[0].set_title('Group Intercepts (showing shrinkage to population mean)')
axes[0].legend()
# Slopes
beta_samples = idata.posterior['beta'].values.reshape(-1, n_groups)
beta_means = beta_samples.mean(axis=0)
mu_beta_mean = idata.posterior['mu_beta'].mean().item()
axes[1].scatter(range(n_groups), beta_means, alpha=0.6)
axes[1].axhline(mu_beta_mean, color='red', linestyle='--', label='Population mean')
axes[1].set_xlabel('Group')
axes[1].set_ylabel('Slope')
axes[1].set_title('Group Slopes (showing shrinkage to population mean)')
axes[1].legend()
plt.tight_layout()
plt.savefig('shrinkage_plot.png', dpi=300, bbox_inches='tight')
print("Shrinkage plot saved to 'shrinkage_plot.png'")
# =============================================================================
# 8. PREDICTIONS FOR NEW DATA
# =============================================================================
# TODO: Specify new data
# For existing groups:
# new_X = np.array([...])
# new_groups = np.array([0, 1, 2, ...]) # Existing group indices
# For a new group (predict using population-level parameters):
# Just use mu_alpha and mu_beta
print("\n" + "="*60)
print("PREDICTIONS FOR NEW DATA")
print("="*60)
# Example: Predict for existing groups
new_X = np.array([-2, -1, 0, 1, 2])
new_groups = np.array([0, 2, 4, 6, 8]) # Select some groups
with hierarchical_model:
pm.set_data({'X_data': new_X, 'groups_data': new_groups}, coords={'obs': np.arange(len(new_X))})
post_pred = pm.sample_posterior_predictive(
idata.posterior,
var_names=['y_obs'],
random_seed=42
)
y_pred_samples = post_pred.posterior_predictive['y_obs']
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
y_pred_hdi = az.hdi(post_pred.posterior_predictive, hdi_prob=0.95)['y_obs'].values
print(f"Predictions for existing groups:")
print(f"{'Group':<10} {'X':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
print("-"*65)
for i, g in enumerate(new_groups):
print(f"{group_names[g]:<10} {new_X[i]:<10.2f} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
# Predict for a new group (using population parameters)
print(f"\nPrediction for a NEW group (using population-level parameters):")
new_X_newgroup = np.array([0.0])
# Manually compute using population parameters
mu_alpha_samples = idata.posterior['mu_alpha'].values.flatten()
mu_beta_samples = idata.posterior['mu_beta'].values.flatten()
sigma_samples = idata.posterior['sigma'].values.flatten()
# Predicted mean for new group
y_pred_newgroup = mu_alpha_samples + mu_beta_samples * new_X_newgroup[0]
y_pred_mean_newgroup = y_pred_newgroup.mean()
y_pred_hdi_newgroup = az.hdi(y_pred_newgroup, hdi_prob=0.95)
print(f"X = {new_X_newgroup[0]:.2f}")
print(f"Predicted mean: {y_pred_mean_newgroup:.3f}")
print(f"95% HDI: [{y_pred_hdi_newgroup[0]:.3f}, {y_pred_hdi_newgroup[1]:.3f}]")
# =============================================================================
# 9. SAVE RESULTS
# =============================================================================
idata.to_netcdf('hierarchical_model_results.nc')
print("\nResults saved to 'hierarchical_model_results.nc'")
summary.to_csv('hierarchical_model_summary.csv')
print("Summary saved to 'hierarchical_model_summary.csv'")
print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)

skills/pymc/assets/linear_regression_template.py Normal file

@@ -0,0 +1,241 @@
"""
PyMC Linear Regression Template
This template provides a complete workflow for Bayesian linear regression,
including data preparation, model building, diagnostics, and predictions.
Customize the sections marked with # TODO
"""
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# =============================================================================
# 1. DATA PREPARATION
# =============================================================================
# TODO: Load your data
# Example:
# df = pd.read_csv('data.csv')
# X = df[['predictor1', 'predictor2', 'predictor3']].values
# y = df['outcome'].values
# For demonstration:
np.random.seed(42)
n_samples = 100
n_predictors = 3
X = np.random.randn(n_samples, n_predictors)
true_beta = np.array([1.5, -0.8, 2.1])
true_alpha = 0.5
y = true_alpha + X @ true_beta + np.random.randn(n_samples) * 0.5
# Standardize predictors for better sampling
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
# =============================================================================
# 2. BUILD MODEL
# =============================================================================
# TODO: Customize predictor names
predictor_names = ['predictor1', 'predictor2', 'predictor3']
coords = {
'predictors': predictor_names,
'obs_id': np.arange(len(y))
}
with pm.Model(coords=coords) as linear_model:
    # Data container (pm.Data) so predictors can be swapped out for predictions later
    X_data = pm.Data('X_scaled', X_scaled, dims=('obs_id', 'predictors'))
    # Priors
    # TODO: Adjust prior parameters based on your domain knowledge
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
# =============================================================================
# 3. PRIOR PREDICTIVE CHECK
# =============================================================================
print("Running prior predictive check...")
with linear_model:
prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)
# Visualize prior predictions
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
ax.set_title('Prior Predictive Check')
plt.tight_layout()
plt.savefig('prior_predictive_check.png', dpi=300, bbox_inches='tight')
print("Prior predictive check saved to 'prior_predictive_check.png'")
# =============================================================================
# 4. FIT MODEL
# =============================================================================
print("\nFitting model...")
with linear_model:
# Optional: Quick ADVI exploration
# approx = pm.fit(n=20000, random_seed=42)
# MCMC sampling
idata = pm.sample(
draws=2000,
tune=1000,
chains=4,
target_accept=0.9,
random_seed=42,
idata_kwargs={'log_likelihood': True}
)
print("Sampling complete!")
# =============================================================================
# 5. CHECK DIAGNOSTICS
# =============================================================================
print("\n" + "="*60)
print("DIAGNOSTICS")
print("="*60)
# Summary statistics
summary = az.summary(idata, var_names=['alpha', 'beta', 'sigma'])
print("\nParameter Summary:")
print(summary)
# Check convergence
bad_rhat = summary[summary['r_hat'] > 1.01]
if len(bad_rhat) > 0:
print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
print(bad_rhat[['r_hat']])
else:
print("\n✓ All R-hat values < 1.01 (good convergence)")
# Check effective sample size
low_ess = summary[summary['ess_bulk'] < 400]
if len(low_ess) > 0:
print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
print(low_ess[['ess_bulk', 'ess_tail']])
else:
print("\n✓ All ESS values > 400 (sufficient samples)")
# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
print(f"\n⚠️ WARNING: {divergences} divergent transitions")
print(" Consider increasing target_accept or reparameterizing")
else:
print("\n✓ No divergences")
# Trace plots
fig, axes = plt.subplots(len(['alpha', 'beta', 'sigma']), 2, figsize=(12, 8))
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'], axes=axes)
plt.tight_layout()
plt.savefig('trace_plots.png', dpi=300, bbox_inches='tight')
print("\nTrace plots saved to 'trace_plots.png'")
# =============================================================================
# 6. POSTERIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning posterior predictive check...")
with linear_model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize fit
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
ax.set_title('Posterior Predictive Check')
plt.tight_layout()
plt.savefig('posterior_predictive_check.png', dpi=300, bbox_inches='tight')
print("Posterior predictive check saved to 'posterior_predictive_check.png'")
# =============================================================================
# 7. ANALYZE RESULTS
# =============================================================================
# Posterior distributions (let ArviZ size the grid: alpha, three betas, and sigma need 5 panels)
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.savefig('posterior_distributions.png', dpi=300, bbox_inches='tight')
print("Posterior distributions saved to 'posterior_distributions.png'")
# Forest plot for coefficients
fig, ax = plt.subplots(figsize=(8, 6))
az.plot_forest(idata, var_names=['beta'], combined=True, ax=ax)
ax.set_title('Coefficient Estimates (95% HDI)')
ax.set_yticklabels(predictor_names)
plt.tight_layout()
plt.savefig('coefficient_forest_plot.png', dpi=300, bbox_inches='tight')
print("Forest plot saved to 'coefficient_forest_plot.png'")
# Print coefficient estimates
print("\n" + "="*60)
print("COEFFICIENT ESTIMATES")
print("="*60)
beta_samples = idata.posterior['beta']
for name in predictor_names:
    mean = beta_samples.sel(predictors=name).mean().item()
    # az.hdi returns a Dataset; pull out the 'beta' variable as a plain array
    hdi = az.hdi(beta_samples.sel(predictors=name), hdi_prob=0.95)['beta'].values
    print(f"{name:20s}: {mean:7.3f} [95% HDI: {hdi[0]:7.3f}, {hdi[1]:7.3f}]")
# =============================================================================
# 8. PREDICTIONS FOR NEW DATA
# =============================================================================
# TODO: Provide new data for predictions
# X_new = np.array([[...], [...], ...]) # New predictor values
# For demonstration, use some test data
X_new = np.random.randn(10, n_predictors)
X_new_scaled = (X_new - X_mean) / X_std
# Update model data and predict
with linear_model:
pm.set_data({'X_scaled': X_new_scaled}, coords={'obs_id': np.arange(len(X_new))})
post_pred = pm.sample_posterior_predictive(
idata.posterior,
var_names=['y_obs'],
random_seed=42
)
# Extract predictions
y_pred_samples = post_pred.posterior_predictive['y_obs']
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
y_pred_hdi = az.hdi(post_pred.posterior_predictive, hdi_prob=0.95)['y_obs'].values
print("\n" + "="*60)
print("PREDICTIONS FOR NEW DATA")
print("="*60)
print(f"{'Index':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
print("-"*60)
for i in range(len(X_new)):
print(f"{i:<10} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
# =============================================================================
# 9. SAVE RESULTS
# =============================================================================
# Save InferenceData
idata.to_netcdf('linear_regression_results.nc')
print("\nResults saved to 'linear_regression_results.nc'")
# Save summary to CSV
summary.to_csv('model_summary.csv')
print("Summary saved to 'model_summary.csv'")
print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)

skills/pymc/references/distributions.md Normal file

@@ -0,0 +1,320 @@
# PyMC Distributions Reference
This reference provides a comprehensive catalog of probability distributions available in PyMC, organized by category. Use this to select appropriate distributions for priors and likelihoods when building Bayesian models.
## Continuous Distributions
Continuous distributions define probability densities over real-valued domains.
### Common Continuous Distributions
**`pm.Normal(name, mu, sigma)`**
- Normal (Gaussian) distribution
- Parameters: `mu` (mean), `sigma` (standard deviation)
- Support: (-∞, ∞)
- Common uses: Default prior for unbounded parameters, likelihood for continuous data with additive noise
**`pm.HalfNormal(name, sigma)`**
- Half-normal distribution (positive half of normal)
- Parameters: `sigma` (standard deviation)
- Support: [0, ∞)
- Common uses: Prior for scale/standard deviation parameters
**`pm.Uniform(name, lower, upper)`**
- Uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: [lower, upper]
- Common uses: Weakly informative prior when parameter must be bounded
**`pm.Beta(name, alpha, beta)`**
- Beta distribution
- Parameters: `alpha`, `beta` (shape parameters)
- Support: [0, 1]
- Common uses: Prior for probabilities and proportions
**`pm.Gamma(name, alpha, beta)`**
- Gamma distribution
- Parameters: `alpha` (shape), `beta` (rate)
- Support: (0, ∞)
- Common uses: Prior for positive parameters, rate parameters
**`pm.Exponential(name, lam)`**
- Exponential distribution
- Parameters: `lam` (rate parameter)
- Support: [0, ∞)
- Common uses: Prior for scale parameters, waiting times
**`pm.LogNormal(name, mu, sigma)`**
- Log-normal distribution
- Parameters: `mu`, `sigma` (parameters of underlying normal)
- Support: (0, ∞)
- Common uses: Prior for positive parameters with multiplicative effects
**`pm.StudentT(name, nu, mu, sigma)`**
- Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `sigma` (scale)
- Support: (-∞, ∞)
- Common uses: Robust alternative to normal for outlier-resistant models
**`pm.Cauchy(name, alpha, beta)`**
- Cauchy distribution
- Parameters: `alpha` (location), `beta` (scale)
- Support: (-∞, ∞)
- Common uses: Heavy-tailed alternative to normal
### Specialized Continuous Distributions
**`pm.Laplace(name, mu, b)`** - Laplace (double exponential) distribution
**`pm.AsymmetricLaplace(name, kappa, mu, b)`** - Asymmetric Laplace distribution
**`pm.InverseGamma(name, alpha, beta)`** - Inverse gamma distribution
**`pm.Weibull(name, alpha, beta)`** - Weibull distribution for reliability analysis
**`pm.Logistic(name, mu, s)`** - Logistic distribution
**`pm.LogitNormal(name, mu, sigma)`** - Logit-normal distribution for (0,1) support
**`pm.Pareto(name, alpha, m)`** - Pareto distribution for power-law phenomena
**`pm.ChiSquared(name, nu)`** - Chi-squared distribution
**`pm.ExGaussian(name, mu, sigma, nu)`** - Exponentially modified Gaussian
**`pm.VonMises(name, mu, kappa)`** - Von Mises (circular normal) distribution
**`pm.SkewNormal(name, mu, sigma, alpha)`** - Skew-normal distribution
**`pm.Triangular(name, lower, c, upper)`** - Triangular distribution
**`pm.Gumbel(name, mu, beta)`** - Gumbel distribution for extreme values
**`pm.Rice(name, nu, sigma)`** - Rice (Rician) distribution
**`pm.Moyal(name, mu, sigma)`** - Moyal distribution
**`pm.Kumaraswamy(name, a, b)`** - Kumaraswamy distribution (Beta alternative)
**`pm.Interpolated(name, x_points, pdf_points)`** - Custom distribution from interpolation
## Discrete Distributions
Discrete distributions define probabilities over integer-valued domains.
### Common Discrete Distributions
**`pm.Bernoulli(name, p)`**
- Bernoulli distribution (binary outcome)
- Parameters: `p` (success probability)
- Support: {0, 1}
- Common uses: Binary classification, coin flips
**`pm.Binomial(name, n, p)`**
- Binomial distribution
- Parameters: `n` (number of trials), `p` (success probability)
- Support: {0, 1, ..., n}
- Common uses: Number of successes in fixed trials
**`pm.Poisson(name, mu)`**
- Poisson distribution
- Parameters: `mu` (rate parameter)
- Support: {0, 1, 2, ...}
- Common uses: Count data, rates, occurrences
**`pm.Categorical(name, p)`**
- Categorical distribution
- Parameters: `p` (probability vector)
- Support: {0, 1, ..., K-1}
- Common uses: Multi-class classification
**`pm.DiscreteUniform(name, lower, upper)`**
- Discrete uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: {lower, ..., upper}
- Common uses: Uniform prior over finite integers
**`pm.NegativeBinomial(name, mu, alpha)`**
- Negative binomial distribution
- Parameters: `mu` (mean), `alpha` (dispersion)
- Support: {0, 1, 2, ...}
- Common uses: Overdispersed count data
**`pm.Geometric(name, p)`**
- Geometric distribution
- Parameters: `p` (success probability)
- Support: {0, 1, 2, ...}
- Common uses: Number of failures before first success
### Specialized Discrete Distributions
**`pm.BetaBinomial(name, alpha, beta, n)`** - Beta-binomial (overdispersed binomial)
**`pm.HyperGeometric(name, N, k, n)`** - Hypergeometric distribution
**`pm.DiscreteWeibull(name, q, beta)`** - Discrete Weibull distribution
**`pm.OrderedLogistic(name, eta, cutpoints)`** - Ordered logistic for ordinal data
**`pm.OrderedProbit(name, eta, cutpoints)`** - Ordered probit for ordinal data
## Multivariate Distributions
Multivariate distributions define joint probability distributions over vector-valued random variables.
### Common Multivariate Distributions
**`pm.MvNormal(name, mu, cov)`**
- Multivariate normal distribution
- Parameters: `mu` (mean vector), `cov` (covariance matrix)
- Common uses: Correlated continuous variables, Gaussian processes
**`pm.Dirichlet(name, a)`**
- Dirichlet distribution
- Parameters: `a` (concentration parameters)
- Support: Simplex (sums to 1)
- Common uses: Prior for probability vectors, topic modeling
**`pm.Multinomial(name, n, p)`**
- Multinomial distribution
- Parameters: `n` (number of trials), `p` (probability vector)
- Common uses: Count data across multiple categories
**`pm.MvStudentT(name, nu, mu, cov)`**
- Multivariate Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `cov` (scale matrix)
- Common uses: Robust multivariate modeling
### Specialized Multivariate Distributions
**`pm.LKJCorr(name, n, eta)`** - LKJ correlation matrix prior (for correlation matrices)
**`pm.LKJCholeskyCov(name, n, eta, sd_dist)`** - LKJ prior with Cholesky decomposition
**`pm.Wishart(name, nu, V)`** - Wishart distribution (for covariance matrices)
**`pm.InverseWishart(name, nu, V)`** - Inverse Wishart distribution
**`pm.MatrixNormal(name, mu, rowcov, colcov)`** - Matrix normal distribution
**`pm.KroneckerNormal(name, mu, covs, sigma)`** - Kronecker-structured normal
**`pm.CAR(name, mu, W, alpha, tau)`** - Conditional autoregressive (spatial)
**`pm.ICAR(name, W, sigma)`** - Intrinsic conditional autoregressive (spatial)
## Mixture Distributions
Mixture distributions combine multiple component distributions.
**`pm.Mixture(name, w, comp_dists)`**
- General mixture distribution
- Parameters: `w` (weights), `comp_dists` (component distributions)
- Common uses: Clustering, multi-modal data
**`pm.NormalMixture(name, w, mu, sigma)`**
- Mixture of normal distributions
- Common uses: Mixture of Gaussians clustering
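A sketch of a two-component Gaussian mixture (assumes `import numpy as np` and an observed array `y_obs`):
```python
with pm.Model() as mixture_model:
    w = pm.Dirichlet('w', a=np.ones(2))               # component weights
    mu = pm.Normal('mu', mu=0, sigma=5, shape=2)      # component means
    sigma = pm.HalfNormal('sigma', sigma=1, shape=2)  # component scales
    pm.NormalMixture('y', w=w, mu=mu, sigma=sigma, observed=y_obs)
```
In practice, an ordering constraint or distinct initial values on `mu` helps avoid label switching between components.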
### Zero-Inflated and Hurdle Models
**`pm.ZeroInflatedPoisson(name, psi, mu)`** - Excess zeros in count data
**`pm.ZeroInflatedBinomial(name, psi, n, p)`** - Zero-inflated binomial
**`pm.ZeroInflatedNegativeBinomial(name, psi, mu, alpha)`** - Zero-inflated negative binomial
**`pm.HurdlePoisson(name, psi, mu)`** - Hurdle Poisson (two-part model)
**`pm.HurdleGamma(name, psi, alpha, beta)`** - Hurdle gamma
**`pm.HurdleLogNormal(name, psi, mu, sigma)`** - Hurdle log-normal
## Time Series Distributions
Distributions designed for temporal data and sequential modeling.
**`pm.AR(name, rho, sigma, init_dist)`**
- Autoregressive process
- Parameters: `rho` (AR coefficients), `sigma` (innovation std), `init_dist` (initial distribution)
- Common uses: Time series modeling, sequential data
**`pm.GaussianRandomWalk(name, mu, sigma, init_dist)`**
- Gaussian random walk
- Parameters: `mu` (drift), `sigma` (step size), `init_dist` (initial value)
- Common uses: Cumulative processes, random walk priors
**`pm.MvGaussianRandomWalk(name, mu, cov, init_dist)`**
- Multivariate Gaussian random walk
**`pm.GARCH11(name, omega, alpha_1, beta_1)`**
- GARCH(1,1) volatility model
- Common uses: Financial time series, volatility modeling
**`pm.EulerMaruyama(name, dt, sde_fn, sde_pars, init_dist)`**
- Stochastic differential equation via Euler-Maruyama discretization
- Common uses: Continuous-time processes
## Special Distributions
**`pm.Deterministic(name, var)`**
- Deterministic transformation (not a random variable)
- Use for computed quantities derived from other variables
**`pm.Potential(name, logp)`**
- Add arbitrary log-probability contribution
- Use for custom likelihood components or constraints
**`pm.Flat(name)`**
- Improper flat prior (constant density)
- Use sparingly; can cause sampling issues
**`pm.HalfFlat(name)`**
- Improper flat prior on positive reals
- Use sparingly; can cause sampling issues
## Distribution Modifiers
**`pm.Truncated(name, dist, lower, upper)`**
- Truncate any distribution to specified bounds
**`pm.Censored(name, dist, lower, upper)`**
- Handle censored observations (observed bounds, not exact values)
**`pm.CustomDist(name, ..., logp, random)`**
- Define custom distributions with user-specified log-probability and random sampling functions
**`pm.Simulator(name, fn, params, ...)`**
- Custom distributions via simulation (for likelihood-free inference)
## Usage Tips
### Choosing Priors
1. **Scale parameters** (σ, τ): Use `HalfNormal`, `HalfCauchy`, `Exponential`, or `Gamma`
2. **Probabilities**: Use `Beta` or `Uniform(0, 1)`
3. **Unbounded parameters**: Use `Normal` or `StudentT` (for robustness)
4. **Positive parameters**: Use `LogNormal`, `Gamma`, or `Exponential`
5. **Correlation matrices**: Use `LKJCorr`
6. **Count data**: Use `Poisson` or `NegativeBinomial` (for overdispersion)
### Shape Broadcasting
PyMC distributions support NumPy-style broadcasting. Use the `shape` parameter to create vectors or arrays of random variables:
```python
# Vector of 5 independent normals
beta = pm.Normal('beta', mu=0, sigma=1, shape=5)
# 3x4 matrix of independent gammas
tau = pm.Gamma('tau', alpha=2, beta=1, shape=(3, 4))
```
### Using dims for Named Dimensions
Instead of shape, use `dims` for more readable models:
```python
with pm.Model(coords={'predictors': ['age', 'income', 'education']}) as model:
beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
```

skills/pymc/references/sampling_inference.md Normal file

@@ -0,0 +1,424 @@
# PyMC Sampling and Inference Methods
This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference.
## MCMC Sampling Methods
### Primary Sampling Function
**`pm.sample(draws=1000, tune=1000, chains=4, **kwargs)`**
The main interface for MCMC sampling in PyMC.
**Key Parameters:**
- `draws`: Number of samples to draw per chain (default: 1000)
- `tune`: Number of tuning/warmup samples (default: 1000, discarded)
- `chains`: Number of parallel chains (default: 4)
- `cores`: Number of CPU cores to use (default: all available)
- `target_accept`: Target acceptance rate for step size tuning (default: 0.8, increase to 0.9-0.95 for difficult posteriors)
- `random_seed`: Random seed for reproducibility
- `return_inferencedata`: Return ArviZ InferenceData object (default: True)
- `idata_kwargs`: Additional kwargs for InferenceData creation (e.g., `{"log_likelihood": True}` for model comparison)
**Returns:** InferenceData object containing posterior samples, sampling statistics, and diagnostics
**Example:**
```python
with pm.Model() as model:
# ... define model ...
idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```
### Sampling Algorithms
PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually.
#### NUTS (No-U-Turn Sampler)
**Default algorithm** for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant.
- Automatically tunes step size and mass matrix
- Adaptive: explores posterior geometry during tuning
- Best for smooth, continuous posteriors
- Can struggle with high correlation or multimodality
**Manual specification:**
```python
with model:
idata = pm.sample(step=pm.NUTS(target_accept=0.95))
```
**When to adjust:**
- Increase `target_accept` (0.9-0.99) if seeing divergences
- Use `init='adapt_diag'` for faster initialization (default)
- Use `init='jitter+adapt_diag'` for difficult initializations
#### Metropolis
General-purpose Metropolis-Hastings sampler.
- Works for both continuous and discrete variables
- Less efficient than NUTS for smooth continuous posteriors
- Useful for discrete parameters or non-differentiable models
- Requires manual tuning
**Example:**
```python
with model:
idata = pm.sample(step=pm.Metropolis())
```
#### Slice Sampler
Slice sampling for univariate distributions.
- No tuning required
- Good for difficult univariate posteriors
- Can be slow for high dimensions
**Example:**
```python
with model:
idata = pm.sample(step=pm.Slice())
```
#### CompoundStep
Combine different samplers for different parameters.
**Example:**
```python
with model:
# Use NUTS for continuous params, Metropolis for discrete
step1 = pm.NUTS([continuous_var1, continuous_var2])
step2 = pm.Metropolis([discrete_var])
idata = pm.sample(step=[step1, step2])
```
### Sampling Diagnostics
PyMC automatically computes diagnostics. Check these before trusting results:
#### Effective Sample Size (ESS)
Measures independent information in correlated samples.
- **Rule of thumb**: bulk ESS > 400 in total (roughly 100 per chain with 4 chains)
- Low ESS indicates high autocorrelation
- Access via: `az.ess(idata)`
#### R-hat (Gelman-Rubin statistic)
Measures convergence across chains.
- **Rule of thumb**: R-hat < 1.01 for all parameters
- R-hat > 1.01 indicates non-convergence
- Access via: `az.rhat(idata)`
#### Divergences
Indicate regions where NUTS struggled.
- **Rule of thumb**: 0 divergences (or very few)
- Divergences suggest biased samples
- **Fix**: Increase `target_accept`, reparameterize, or use stronger priors
- Access via: `idata.sample_stats.diverging.sum()`
#### Energy Plot
Visualizes Hamiltonian Monte Carlo energy transitions.
```python
az.plot_energy(idata)
```
Good separation between energy distributions indicates healthy sampling.
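A numeric companion to the energy plot is the Bayesian fraction of missing information (a sketch); values well below roughly 0.3 are commonly taken to indicate problematic energy transitions:
```python
print(az.bfmi(idata))  # one value per chain
```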
### Handling Sampling Issues
#### Divergences
```python
# Increase target acceptance rate
idata = pm.sample(target_accept=0.95)

# Or reparameterize latent group-level effects (non-centered parameterization)
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (usually samples better):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```
#### Slow Sampling
```python
# Use fewer tuning steps if model is simple
idata = pm.sample(tune=500)
# Increase cores for parallelization
idata = pm.sample(cores=8, chains=8)
# Use variational inference to initialize NUTS
with model:
    idata = pm.sample(init='advi+adapt_diag')
```
#### High Autocorrelation
```python
# Increase draws
idata = pm.sample(draws=5000)
# Reparameterize to reduce correlation
# Consider using QR decomposition for regression models
```
## Variational Inference
Faster approximate inference for large models or quick exploration.
### ADVI (Automatic Differentiation Variational Inference)
**`pm.fit(n=10000, method='advi', **kwargs)`**
Approximates posterior with simpler distribution (typically mean-field Gaussian).
**Key Parameters:**
- `n`: Number of iterations (default: 10000)
- `method`: VI algorithm ('advi', 'fullrank_advi', 'svgd')
- `random_seed`: Random seed
**Returns:** Approximation object for sampling and analysis
**Example:**
```python
with model:
approx = pm.fit(n=50000)
# Draw samples from approximation
idata = approx.sample(1000)
# Or let ADVI initialize NUTS when running MCMC afterwards
idata = pm.sample(init='advi+adapt_diag')
```
**Trade-offs:**
- **Pros**: Much faster than MCMC, scales to large data
- **Cons**: Approximate, may miss posterior structure, underestimates uncertainty
### Full-Rank ADVI
Captures correlations between parameters.
```python
with model:
approx = pm.fit(method='fullrank_advi')
```
More accurate than mean-field but slower.
### SVGD (Stein Variational Gradient Descent)
Non-parametric variational inference.
```python
with model:
approx = pm.fit(method='svgd', n=20000)
```
Better captures multimodality but more computationally expensive.
## Prior and Posterior Predictive Sampling
### Prior Predictive Sampling
Sample from the prior distribution (before seeing data).
**`pm.sample_prior_predictive(samples=500, **kwargs)`**
**Purpose:**
- Validate priors are reasonable
- Check implied predictions before fitting
- Ensure model generates plausible data
**Example:**
```python
with model:
prior_pred = pm.sample_prior_predictive(samples=1000)
# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior')
```
### Posterior Predictive Sampling
Sample from posterior predictive distribution (after fitting).
**`pm.sample_posterior_predictive(trace, **kwargs)`**
**Purpose:**
- Model validation via posterior predictive checks
- Generate predictions for new data
- Assess goodness-of-fit
**Example:**
```python
with model:
# After sampling
idata = pm.sample()
# Add posterior predictive samples
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
# Posterior predictive check
az.plot_ppc(idata)
```
### Predictions for New Data
Update data and sample predictive distribution:
```python
with model:
# Original model fit
idata = pm.sample()
# Update with new predictor values
pm.set_data({'X': X_new})
# Sample predictions
post_pred_new = pm.sample_posterior_predictive(
idata.posterior,
var_names=['y_pred']
)
```
## Maximum A Posteriori (MAP) Estimation
Find posterior mode (point estimate).
**`pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)`**
**When to use:**
- Quick point estimates
- Initialization for MCMC
- When full posterior not needed
**Example:**
```python
with model:
map_estimate = pm.find_MAP()
print(map_estimate)
```
**Limitations:**
- Doesn't quantify uncertainty
- Can find local optima in multimodal posteriors
- Sensitive to prior specification
## Inference Recommendations
### Standard Workflow
1. **Start with ADVI** for quick exploration:
```python
approx = pm.fit(n=20000)
```
2. **Run MCMC** for full inference:
```python
idata = pm.sample(draws=2000, tune=1000)
```
3. **Check diagnostics**:
```python
az.summary(idata, var_names=['~mu_log__']) # Exclude transformed vars
```
4. **Sample posterior predictive**:
```python
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
### Choosing Inference Method
| Scenario | Recommended Method |
|----------|-------------------|
| Small-medium models, need full uncertainty | MCMC with NUTS |
| Large models, initial exploration | ADVI |
| Discrete parameters | Metropolis or marginalize |
| Hierarchical models with divergences | Non-centered parameterization + NUTS |
| Very large data | Minibatch ADVI |
| Quick point estimates | MAP or ADVI |
### Reparameterization Tricks
**Non-centered parameterization** for hierarchical models:
```python
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)
# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```
**QR decomposition** for correlated predictors:
```python
import numpy as np

# Thin QR decomposition of the design matrix
Q, R = np.linalg.qr(X)
R_inv = np.linalg.inv(R)  # fixed matrix, computed once outside the model

with pm.Model():
    # Coefficients on the (nearly) orthogonal Q basis
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)
    # Transform back to the original predictor scale
    beta = pm.Deterministic('beta', pm.math.dot(R_inv, beta_tilde))
    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)
```
## Advanced Sampling
### Sequential Monte Carlo (SMC)
For complex posteriors or model evidence estimation:
```python
with model:
idata = pm.sample_smc(draws=2000, chains=4)
```
Good for multimodal posteriors or when NUTS struggles.
### Custom Initialization
Provide starting values:
```python
initvals = {'mu': 0.0, 'sigma': 1.0}
with model:
    idata = pm.sample(initvals=initvals)
```
Or use MAP estimate:
```python
with model:
    map_estimate = pm.find_MAP()
    idata = pm.sample(initvals=map_estimate)
```

skills/pymc/references/workflows.md Normal file

@@ -0,0 +1,526 @@
# PyMC Workflows and Common Patterns
This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC.
## Standard Bayesian Workflow
### Complete Workflow Template
```python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
# 1. PREPARE DATA
# ===============
X = ... # Predictor variables
y = ... # Observed outcomes
# Standardize predictors for better sampling
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)
# 2. BUILD MODEL
# ==============
# Define coordinates for named dimensions
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as model:
    # Data container (pm.Data) so predictors can be swapped out for predictions later
    X_data = pm.Data('X', X_scaled, dims=('obs_id', 'predictors'))
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
# 3. PRIOR PREDICTIVE CHECK
# ==========================
with model:
prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)
# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('Prior Predictive Check')
plt.show()
# 4. FIT MODEL
# ============
with model:
# Quick VI exploration (optional)
approx = pm.fit(n=20000, random_seed=42)
# Full MCMC inference
idata = pm.sample(
draws=2000,
tune=1000,
chains=4,
target_accept=0.9,
random_seed=42,
idata_kwargs={'log_likelihood': True} # For model comparison
)
# 5. CHECK DIAGNOSTICS
# ====================
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))
# R-hat and ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
print("WARNING: Some R-hat values > 1.01, chains may not have converged")
if (summary['ess_bulk'] < 400).any():
print("WARNING: Some ESS values < 400, consider more samples")
# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
print(f"Number of divergences: {divergences}")
# Trace plots
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()
# 6. POSTERIOR PREDICTIVE CHECK
# ==============================
with model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize fit
az.plot_ppc(idata, num_pp_samples=100)
plt.title('Posterior Predictive Check')
plt.show()
# 7. ANALYZE RESULTS
# ==================
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()
# Forest plot for coefficients
az.plot_forest(idata, var_names=['beta'], combined=True)
plt.title('Coefficient Estimates')
plt.show()
# 8. PREDICTIONS FOR NEW DATA
# ============================
X_new = ...  # New predictor values
X_new_scaled = (X_new - X.mean(axis=0)) / X.std(axis=0)

with model:
    # Swap in the new predictors; update the obs_id coordinate to match the new rows
    # (if the predictive draws keep the training shape, also register y as pm.Data and update it here)
    pm.set_data(
        {'X': X_new_scaled},
        coords={'obs_id': np.arange(len(X_new_scaled))}
    )
    # Sample predictions
    post_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['y_obs'],
        random_seed=42
    )

# Prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])
# 9. SAVE RESULTS
# ===============
idata.to_netcdf('model_results.nc') # Save for later
```
## Model Building Patterns
### Linear Regression
```python
with pm.Model() as linear_model:
# Priors
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
sigma = pm.HalfNormal('sigma', sigma=1)
# Linear predictor
mu = alpha + pm.math.dot(X, beta)
# Likelihood
y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
### Logistic Regression
```python
with pm.Model() as logistic_model:
# Priors
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
# Linear predictor
logit_p = alpha + pm.math.dot(X, beta)
# Likelihood
y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```
### Hierarchical/Multilevel Model
```python
with pm.Model(coords={'group': group_names, 'obs': np.arange(n_obs)}) as hierarchical_model:
# Hyperpriors
mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)
mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)
# Group-level parameters (non-centered)
alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='group')
alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='group')
beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='group')
beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='group')
# Observation-level model
mu = alpha[group_idx] + beta[group_idx] * X
sigma = pm.HalfNormal('sigma', sigma=1)
y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs, dims='obs')
```
### Poisson Regression (Count Data)
```python
with pm.Model() as poisson_model:
# Priors
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
# Linear predictor on log scale
log_lambda = alpha + pm.math.dot(X, beta)
# Likelihood
y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```
### Time Series (Autoregressive)
```python
with pm.Model() as ar_model:
# Innovation standard deviation
sigma = pm.HalfNormal('sigma', sigma=1)
# AR coefficients
rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)
# Initial distribution
init_dist = pm.Normal.dist(mu=0, sigma=sigma)
# AR process
y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```
### Mixture Model
```python
with pm.Model() as mixture_model:
# Component weights
w = pm.Dirichlet('w', a=np.ones(n_components))
# Component parameters
mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components)
sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)
# Mixture
components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i]) for i in range(n_components)]
y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs)
```
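When every component is Gaussian, `pm.NormalMixture` is a more compact equivalent; a sketch reusing the same `w`, `mu`, and `sigma` priors as above:

```python
with pm.Model() as normal_mixture_model:
    w = pm.Dirichlet('w', a=np.ones(n_components))
    mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components)
    sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)

    # Shorthand for a mixture of Normal components
    y = pm.NormalMixture('y', w=w, mu=mu, sigma=sigma, observed=y_obs)
```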
## Data Preparation Best Practices
### Standardization
Standardize continuous predictors for better sampling:
```python
# Standardize
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
# Model with scaled data
with pm.Model() as model:
beta_scaled = pm.Normal('beta_scaled', 0, 1)
# ... rest of model ...
# Transform back to original scale
beta_original = beta_scaled / X_std
alpha_original = alpha - (beta_scaled * X_mean / X_std).sum()
```
### Handling Missing Data
Treat missing values as parameters:
```python
import pytensor.tensor as pt

# Identify missing values
missing_mask = np.isnan(X)
X_filled = np.where(missing_mask, 0.0, X)  # Placeholder for missing entries

with pm.Model() as model:
    # Prior for the missing values (treated as parameters)
    X_missing = pm.Normal('X_missing', mu=0, sigma=1, shape=int(missing_mask.sum()))
    # Insert the imputed values at the missing positions
    X_complete = pt.set_subtensor(
        pt.as_tensor_variable(X_filled)[missing_mask.nonzero()], X_missing
    )
    # ... rest of model using X_complete ...
```
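For missing values in the *outcome* variable, PyMC can impute automatically: pass a NumPy masked array as `observed` and the masked entries become free parameters. A minimal sketch (variable names are illustrative):

```python
# Masked entries in y are imputed as extra parameters during sampling
y_masked = np.ma.masked_invalid(y)  # mask NaNs in the outcome

with pm.Model() as imputation_model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_lik = pm.Normal('y_lik', mu=mu, sigma=sigma, observed=y_masked)
```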
### Centering and Scaling
For regression models, center predictors and outcome:
```python
# Center
X_centered = X - X.mean(axis=0)
y_centered = y - y.mean()
with pm.Model() as model:
# Simpler prior on intercept
alpha = pm.Normal('alpha', mu=0, sigma=1) # Intercept near 0 when centered
beta = pm.Normal('beta', mu=0, sigma=1, shape=n_predictors)
mu = alpha + pm.math.dot(X_centered, beta)
sigma = pm.HalfNormal('sigma', sigma=1)
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_centered)
```
## Prior Selection Guidelines
### Weakly Informative Priors
Use when you have limited prior knowledge:
```python
# For standardized predictors
beta = pm.Normal('beta', mu=0, sigma=1)
# For scale parameters
sigma = pm.HalfNormal('sigma', sigma=1)
# For probabilities
p = pm.Beta('p', alpha=2, beta=2) # Slight preference for middle values
```
### Informative Priors
Use domain knowledge:
```python
# Effect size from literature: Cohen's d ≈ 0.3
beta = pm.Normal('beta', mu=0.3, sigma=0.1)
# Physical constraint: probability between 0.7-0.9
p = pm.Beta('p', alpha=8, beta=2) # Check with prior predictive!
```
### Prior Predictive Checks
Always validate priors:
```python
with model:
prior_pred = pm.sample_prior_predictive(samples=1000)
# Check if predictions are reasonable
print(f"Prior predictive range: {prior_pred.prior_predictive['y'].min():.2f} to {prior_pred.prior_predictive['y'].max():.2f}")
print(f"Observed range: {y_obs.min():.2f} to {y_obs.max():.2f}")
# Visualize
az.plot_ppc(prior_pred, group='prior')
```
## Model Comparison Workflow
### Comparing Multiple Models
```python
import arviz as az
# Fit multiple models
models = {}
idatas = {}
# Model 1: Simple linear
with pm.Model() as models['linear']:
# ... define model ...
idatas['linear'] = pm.sample(idata_kwargs={'log_likelihood': True})
# Model 2: With interaction
with pm.Model() as models['interaction']:
# ... define model ...
idatas['interaction'] = pm.sample(idata_kwargs={'log_likelihood': True})
# Model 3: Hierarchical
with pm.Model() as models['hierarchical']:
# ... define model ...
idatas['hierarchical'] = pm.sample(idata_kwargs={'log_likelihood': True})
# Compare using LOO
comparison = az.compare(idatas, ic='loo')
print(comparison)
# Visualize comparison
az.plot_compare(comparison)
plt.show()
# Check LOO reliability
for name, idata in idatas.items():
loo = az.loo(idata, pointwise=True)
high_pareto_k = (loo.pareto_k > 0.7).sum().item()
if high_pareto_k > 0:
print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k")
```
### Model Weights
```python
# Get model weights (pseudo-BMA)
weights = comparison['weight'].values
print("Model probabilities:")
for name, weight in zip(comparison.index, weights):
print(f" {name}: {weight:.2%}")
# Model averaging (weighted predictions)
def weighted_predictions(idatas, weights):
preds = []
for (name, idata), weight in zip(idatas.items(), weights):
pred = idata.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
preds.append(weight * pred)
return sum(preds)
averaged_pred = weighted_predictions(idatas, weights)
```
## Diagnostics and Troubleshooting
### Diagnosing Sampling Problems
```python
def diagnose_sampling(idata, var_names=None):
"""Comprehensive sampling diagnostics"""
# Check convergence
summary = az.summary(idata, var_names=var_names)
print("=== Convergence Diagnostics ===")
bad_rhat = summary[summary['r_hat'] > 1.01]
if len(bad_rhat) > 0:
print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01")
print(bad_rhat[['r_hat']])
else:
print("✓ All R-hat values < 1.01")
# Check effective sample size
print("\n=== Effective Sample Size ===")
low_ess = summary[summary['ess_bulk'] < 400]
if len(low_ess) > 0:
print(f"⚠️ {len(low_ess)} variables with ESS < 400")
print(low_ess[['ess_bulk', 'ess_tail']])
else:
print("✓ All ESS values > 400")
# Check divergences
print("\n=== Divergences ===")
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
print(f"⚠️ {divergences} divergent transitions")
print(" Consider: increase target_accept, reparameterize, or stronger priors")
else:
print("✓ No divergences")
    # Check tree depth (NUTS stops doubling at max_treedepth, default 10)
    print("\n=== NUTS Statistics ===")
    tree_depth = idata.sample_stats.tree_depth
    max_treedepth = 10  # PyMC's default limit
    hits_max = (tree_depth >= max_treedepth).sum().item()
    if hits_max > 0:
        print(f"⚠️ Hit max treedepth {hits_max} times")
        print("  Consider: reparameterize or increase max_treedepth")
    else:
        print(f"✓ No max treedepth issues (max observed: {tree_depth.max().item()})")
return summary
# Usage
diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma'])
```
### Common Fixes
| Problem | Solution |
|---------|----------|
| Divergences | Increase `target_accept=0.95`, use non-centered parameterization |
| Low ESS | Sample more draws, reparameterize to reduce correlation |
| High R-hat | Run longer chains, check for multimodality, improve initialization |
| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity |
| Biased posterior | Check prior predictive, ensure likelihood is correct |
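As a concrete starting point, the first few fixes usually amount to re-running the sampler with adjusted settings; a sketch (the exact values depend on the model):

```python
with model:
    idata = pm.sample(
        draws=4000,              # more draws when ESS is low
        tune=2000,               # longer warmup when R-hat is high
        chains=4,
        target_accept=0.95,      # reduces divergences
        init='advi+adapt_diag',  # ADVI-based initialization for faster adaptation
        random_seed=42,
    )
```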
## Using Named Dimensions (dims)
### Benefits of dims
- More readable code
- Easier subsetting and analysis
- Better xarray integration
```python
import pandas as pd

# Define coordinates
coords = {
'predictors': ['age', 'income', 'education'],
'groups': ['A', 'B', 'C'],
'time': pd.date_range('2020-01-01', periods=100, freq='D')
}
with pm.Model(coords=coords) as model:
# Use dims instead of shape
beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
y = pm.Normal('y', mu=0, sigma=1, dims=['groups', 'time'], observed=data)
# After sampling, dimensions are preserved
idata = pm.sample()
# Easy subsetting
beta_age = idata.posterior['beta'].sel(predictors='age')
group_A = idata.posterior['alpha'].sel(groups='A')
```
## Saving and Loading Results
```python
# Save InferenceData
idata.to_netcdf('results.nc')
# Load InferenceData
loaded_idata = az.from_netcdf('results.nc')
# Save model for later predictions
import pickle
with open('model.pkl', 'wb') as f:
pickle.dump({'model': model, 'idata': idata}, f)
# Load model
with open('model.pkl', 'rb') as f:
saved = pickle.load(f)
model = saved['model']
idata = saved['idata']
```


@@ -0,0 +1,387 @@
"""
PyMC Model Comparison Script
Utilities for comparing multiple Bayesian models using information criteria
and cross-validation metrics.
Usage:
from scripts.model_comparison import compare_models, plot_model_comparison
# Compare multiple models
comparison = compare_models(
{'model1': idata1, 'model2': idata2, 'model3': idata3},
ic='loo'
)
# Visualize comparison
plot_model_comparison(comparison, output_path='model_comparison.png')
"""
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict
def compare_models(models_dict: Dict[str, az.InferenceData],
ic='loo',
scale='deviance',
verbose=True):
"""
Compare multiple models using information criteria.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects.
All models must have log_likelihood computed.
ic : str
Information criterion to use: 'loo' (default) or 'waic'
scale : str
Scale for IC: 'deviance' (default), 'log', or 'negative_log'
verbose : bool
Print detailed comparison results (default: True)
Returns
-------
pd.DataFrame
Comparison DataFrame with model rankings and statistics
Notes
-----
Models must be fit with idata_kwargs={'log_likelihood': True} or
log-likelihood computed afterwards with pm.compute_log_likelihood().
"""
if verbose:
print("="*70)
print(f" " * 25 + f"MODEL COMPARISON ({ic.upper()})")
print("="*70)
# Perform comparison
comparison = az.compare(models_dict, ic=ic, scale=scale)
if verbose:
print("\nModel Rankings:")
print("-"*70)
print(comparison.to_string())
print("\n" + "="*70)
print("INTERPRETATION GUIDE")
print("="*70)
print(f"• rank: Model ranking (0 = best)")
print(f"{ic}: {ic.upper()} estimate (lower is better)")
print(f"• p_{ic}: Effective number of parameters")
print(f"• d{ic}: Difference from best model")
print(f"• weight: Model probability (pseudo-BMA)")
print(f"• se: Standard error of {ic.upper()}")
print(f"• dse: Standard error of the difference")
print(f"• warning: True if model has reliability issues")
print(f"• scale: {scale}")
print("\n" + "="*70)
print("MODEL SELECTION GUIDELINES")
print("="*70)
best_model = comparison.index[0]
print(f"\n✓ Best model: {best_model}")
# Check for clear winner
if len(comparison) > 1:
            delta = comparison.iloc[1]['elpd_diff']
delta_se = comparison.iloc[1]['dse']
if delta > 10:
print(f" → STRONG evidence for {best_model}{ic} > 10)")
elif delta > 4:
print(f" → MODERATE evidence for {best_model} (4 < Δ{ic} < 10)")
elif delta > 2:
print(f" → WEAK evidence for {best_model} (2 < Δ{ic} < 4)")
else:
print(f" → Models are SIMILAR (Δ{ic} < 2)")
print(f" Consider model averaging or choose based on simplicity")
# Check if difference is significant relative to SE
if delta > 2 * delta_se:
print(f" → Difference is > 2 SE, likely reliable")
else:
print(f" → Difference is < 2 SE, uncertain distinction")
# Check for warnings
if comparison['warning'].any():
print("\n⚠️ WARNING: Some models have reliability issues")
warned_models = comparison[comparison['warning']].index.tolist()
print(f" Models with warnings: {', '.join(warned_models)}")
print(f" → Check Pareto-k diagnostics with check_loo_reliability()")
return comparison
def check_loo_reliability(models_dict: Dict[str, az.InferenceData],
threshold=0.7,
verbose=True):
"""
Check LOO-CV reliability using Pareto-k diagnostics.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects
threshold : float
Pareto-k threshold for flagging observations (default: 0.7)
verbose : bool
Print detailed diagnostics (default: True)
Returns
-------
dict
Dictionary with Pareto-k diagnostics for each model
"""
if verbose:
print("="*70)
print(" " * 20 + "LOO RELIABILITY CHECK")
print("="*70)
results = {}
for name, idata in models_dict.items():
if verbose:
print(f"\n{name}:")
print("-"*70)
# Compute LOO with pointwise results
loo_result = az.loo(idata, pointwise=True)
pareto_k = loo_result.pareto_k.values
# Count problematic observations
n_high = (pareto_k > threshold).sum()
n_very_high = (pareto_k > 1.0).sum()
results[name] = {
'pareto_k': pareto_k,
'n_high': n_high,
'n_very_high': n_very_high,
'max_k': pareto_k.max(),
'loo': loo_result
}
if verbose:
print(f"Pareto-k diagnostics:")
print(f" • Good (k < 0.5): {(pareto_k < 0.5).sum()} observations")
print(f" • OK (0.5 ≤ k < 0.7): {((pareto_k >= 0.5) & (pareto_k < 0.7)).sum()} observations")
print(f" • Bad (0.7 ≤ k < 1.0): {((pareto_k >= 0.7) & (pareto_k < 1.0)).sum()} observations")
print(f" • Very bad (k ≥ 1.0): {(pareto_k >= 1.0).sum()} observations")
print(f" • Maximum k: {pareto_k.max():.3f}")
if n_high > 0:
print(f"\n⚠️ {n_high} observations with k > {threshold}")
print(" LOO approximation may be unreliable for these points")
print(" Solutions:")
print(" → Use WAIC instead (less sensitive to outliers)")
print(" → Investigate influential observations")
print(" → Consider more flexible model")
if n_very_high > 0:
print(f"\n⚠️ {n_very_high} observations with k > 1.0")
print(" These points have very high influence")
print(" → Strongly consider K-fold CV or other validation")
else:
print(f"✓ All Pareto-k values < {threshold}")
print(" LOO estimates are reliable")
return results
def plot_model_comparison(comparison, output_path=None, show=True):
"""
Visualize model comparison results.
Parameters
----------
comparison : pd.DataFrame
Comparison DataFrame from az.compare()
output_path : str, optional
If provided, save plot to this path
show : bool
Whether to display plot (default: True)
Returns
-------
matplotlib.figure.Figure
The comparison figure
"""
fig = plt.figure(figsize=(10, 6))
az.plot_compare(comparison)
plt.title('Model Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Comparison plot saved to {output_path}")
if show:
plt.show()
else:
plt.close()
return fig
def model_averaging(models_dict: Dict[str, az.InferenceData],
weights=None,
var_name='y_obs',
ic='loo'):
"""
Perform Bayesian model averaging using model weights.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects
weights : array-like, optional
Model weights. If None, computed from IC (pseudo-BMA weights)
var_name : str
Name of the predicted variable (default: 'y_obs')
ic : str
Information criterion for computing weights if not provided
Returns
-------
np.ndarray
Averaged predictions across models
np.ndarray
Model weights used
"""
if weights is None:
comparison = az.compare(models_dict, ic=ic)
weights = comparison['weight'].values
model_names = comparison.index.tolist()
else:
model_names = list(models_dict.keys())
weights = np.array(weights)
weights = weights / weights.sum() # Normalize
print("="*70)
print(" " * 22 + "BAYESIAN MODEL AVERAGING")
print("="*70)
print("\nModel weights:")
for name, weight in zip(model_names, weights):
print(f" {name}: {weight:.4f} ({weight*100:.2f}%)")
    # Extract predictions and average (keep weights aligned with the models actually used)
    predictions = []
    used_weights = []
    for name, weight in zip(model_names, weights):
        idata = models_dict[name]
        if hasattr(idata, 'posterior_predictive'):
            predictions.append(idata.posterior_predictive[var_name].values)
            used_weights.append(weight)
        else:
            print(f"Warning: {name} missing posterior_predictive, skipping")
    used_weights = np.array(used_weights)
    used_weights = used_weights / used_weights.sum()  # Renormalize over the models used
    # Weighted average
    averaged = sum(w * p for w, p in zip(used_weights, predictions))
    print(f"\n✓ Model averaging complete")
    print(f"  Combined predictions using {len(predictions)} models")
    return averaged, used_weights
def cross_validation_comparison(models_dict: Dict[str, az.InferenceData],
k=10,
verbose=True):
"""
Perform k-fold cross-validation comparison (conceptual guide).
Note: This function provides guidance. Full k-fold CV requires
re-fitting models k times, which should be done in the main script.
Parameters
----------
models_dict : dict
Dictionary of model names to InferenceData
k : int
Number of folds (default: 10)
verbose : bool
Print guidance
Returns
-------
None
"""
if verbose:
print("="*70)
print(" " * 20 + "K-FOLD CROSS-VALIDATION GUIDE")
print("="*70)
print(f"\nTo perform {k}-fold CV:")
print("""
1. Split data into k folds
2. For each fold:
- Train all models on k-1 folds
- Compute log-likelihood on held-out fold
3. Sum log-likelihoods across folds for each model
4. Compare models using total CV score
Example code:
-------------
from sklearn.model_selection import KFold
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
cv_scores = {name: [] for name in models_dict.keys()}
for train_idx, test_idx in kf.split(X):
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
for name in models_dict.keys():
# Fit model on train set
with create_model(name, X_train, y_train) as model:
idata = pm.sample()
            # Compute log-likelihood on the held-out fold
            with model:
                pm.set_data({'X': X_test, 'y': y_test})
                pm.compute_log_likelihood(idata)  # adds a log_likelihood group to idata
            # 'y_obs' is the (hypothetical) name of the observed variable in create_model
            log_lik = float(idata.log_likelihood['y_obs'].sum())
            cv_scores[name].append(log_lik)
# Compare total CV scores
for name, scores in cv_scores.items():
print(f"{name}: {np.sum(scores):.2f}")
""")
print("\nNote: K-fold CV is expensive but most reliable for model comparison")
print(" Use when LOO has reliability issues (high Pareto-k values)")
# Example usage
if __name__ == '__main__':
print("This script provides model comparison utilities for PyMC.")
print("\nExample usage:")
print("""
import pymc as pm
from scripts.model_comparison import compare_models, check_loo_reliability
# Fit multiple models (must include log_likelihood)
with pm.Model() as model1:
# ... define model 1 ...
idata1 = pm.sample(idata_kwargs={'log_likelihood': True})
with pm.Model() as model2:
# ... define model 2 ...
idata2 = pm.sample(idata_kwargs={'log_likelihood': True})
# Compare models
models = {'Simple': idata1, 'Complex': idata2}
comparison = compare_models(models, ic='loo')
# Check reliability
reliability = check_loo_reliability(models)
# Visualize
plot_model_comparison(comparison, output_path='comparison.png')
# Model averaging
averaged_pred, weights = model_averaging(models, var_name='y_obs')
""")


@@ -0,0 +1,350 @@
"""
PyMC Model Diagnostics Script
Comprehensive diagnostic checks for PyMC models.
Run this after sampling to validate results before interpretation.
Usage:
from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report
# Quick check
check_diagnostics(idata)
# Full report with plots
create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/')
"""
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01):
"""
Perform comprehensive diagnostic checks on MCMC samples.
Parameters
----------
idata : arviz.InferenceData
InferenceData object from pm.sample()
var_names : list, optional
Variables to check. If None, checks all model parameters
ess_threshold : int
Minimum acceptable effective sample size (default: 400)
rhat_threshold : float
Maximum acceptable R-hat value (default: 1.01)
Returns
-------
dict
Dictionary with diagnostic results and flags
"""
print("="*70)
print(" " * 20 + "MCMC DIAGNOSTICS REPORT")
print("="*70)
# Get summary statistics
summary = az.summary(idata, var_names=var_names)
results = {
'summary': summary,
'has_issues': False,
'issues': []
}
# 1. Check R-hat (convergence)
print("\n1. CONVERGENCE CHECK (R-hat)")
print("-" * 70)
bad_rhat = summary[summary['r_hat'] > rhat_threshold]
if len(bad_rhat) > 0:
print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}")
print("\nTop 10 worst R-hat values:")
print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10))
print("\n⚠️ Chains may not have converged!")
print(" → Run longer chains or check for multimodality")
results['has_issues'] = True
results['issues'].append('convergence')
else:
print(f"✓ All R-hat values ≤ {rhat_threshold}")
print(" Chains have converged successfully")
# 2. Check Effective Sample Size
print("\n2. EFFECTIVE SAMPLE SIZE (ESS)")
print("-" * 70)
low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold]
low_ess_tail = summary[summary['ess_tail'] < ess_threshold]
if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0:
print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}")
if len(low_ess_bulk) > 0:
print(f"\n Bulk ESS issues ({len(low_ess_bulk)} parameters):")
print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10))
if len(low_ess_tail) > 0:
print(f"\n Tail ESS issues ({len(low_ess_tail)} parameters):")
print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10))
print("\n⚠️ High autocorrelation detected!")
print(" → Sample more draws or reparameterize to reduce correlation")
results['has_issues'] = True
results['issues'].append('low_ess')
else:
print(f"✓ All ESS values ≥ {ess_threshold}")
print(" Sufficient effective samples")
# 3. Check Divergences
print("\n3. DIVERGENT TRANSITIONS")
print("-" * 70)
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
divergence_rate = divergences / total_samples * 100
print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)")
print("\n Divergences indicate biased sampling in difficult posterior regions")
print(" Solutions:")
print(" → Increase target_accept (e.g., target_accept=0.95 or 0.99)")
print(" → Use non-centered parameterization for hierarchical models")
print(" → Add stronger/more informative priors")
print(" → Check for model misspecification")
results['has_issues'] = True
results['issues'].append('divergences')
results['n_divergences'] = divergences
else:
print("✓ No divergences detected")
print(" NUTS explored the posterior successfully")
# 4. Check Tree Depth
print("\n4. TREE DEPTH")
print("-" * 70)
tree_depth = idata.sample_stats.tree_depth
max_tree_depth = tree_depth.max().item()
# Typical max_treedepth is 10 (default in PyMC)
hits_max = (tree_depth >= 10).sum().item()
if hits_max > 0:
total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
hit_rate = hits_max / total_samples * 100
print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)")
print("\n Model may be difficult to explore efficiently")
print(" Solutions:")
print(" → Reparameterize model to improve geometry")
print(" → Increase max_treedepth (if necessary)")
results['issues'].append('max_treedepth')
else:
print(f"✓ No maximum tree depth issues")
print(f" Maximum tree depth reached: {max_tree_depth}")
# 5. Check Energy (if available)
if hasattr(idata.sample_stats, 'energy'):
print("\n5. ENERGY DIAGNOSTICS")
print("-" * 70)
print("✓ Energy statistics available")
print(" Use az.plot_energy(idata) to visualize energy transitions")
print(" Good separation indicates healthy HMC sampling")
# Summary
print("\n" + "="*70)
print("SUMMARY")
print("="*70)
if not results['has_issues']:
print("✓ All diagnostics passed!")
print(" Your model has sampled successfully.")
print(" Proceed with inference and interpretation.")
else:
print("⚠️ Some diagnostics failed!")
print(f" Issues found: {', '.join(results['issues'])}")
print(" Review warnings above and consider re-running with adjustments.")
print("="*70)
return results
def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False):
"""
Create comprehensive diagnostic report with plots.
Parameters
----------
idata : arviz.InferenceData
InferenceData object from pm.sample()
var_names : list, optional
Variables to plot. If None, uses all model parameters
output_dir : str
Directory to save diagnostic plots
show : bool
Whether to display plots (default: False, just save)
Returns
-------
dict
Diagnostic results from check_diagnostics
"""
# Create output directory
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Run diagnostic checks
results = check_diagnostics(idata, var_names=var_names)
print(f"\nGenerating diagnostic plots in '{output_dir}'...")
    # 1. Trace plots (let ArviZ size the axes grid to the number of variables)
    az.plot_trace(idata, var_names=var_names, figsize=(12, 10))
    plt.tight_layout()
    plt.savefig(output_path / 'trace_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved trace plots")
if show:
plt.show()
else:
plt.close()
# 2. Rank plots (check mixing)
fig = plt.figure(figsize=(12, 8))
az.plot_rank(idata, var_names=var_names)
plt.tight_layout()
plt.savefig(output_path / 'rank_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved rank plots")
if show:
plt.show()
else:
plt.close()
# 3. Autocorrelation plots
fig = plt.figure(figsize=(12, 8))
az.plot_autocorr(idata, var_names=var_names, combined=True)
plt.tight_layout()
plt.savefig(output_path / 'autocorr_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved autocorrelation plots")
if show:
plt.show()
else:
plt.close()
# 4. Energy plot (if available)
if hasattr(idata.sample_stats, 'energy'):
fig = plt.figure(figsize=(10, 6))
az.plot_energy(idata)
plt.tight_layout()
plt.savefig(output_path / 'energy_plot.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved energy plot")
if show:
plt.show()
else:
plt.close()
# 5. ESS plot
fig = plt.figure(figsize=(10, 6))
az.plot_ess(idata, var_names=var_names, kind='evolution')
plt.tight_layout()
plt.savefig(output_path / 'ess_evolution.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved ESS evolution plot")
if show:
plt.show()
else:
plt.close()
# Save summary to CSV
results['summary'].to_csv(output_path / 'summary_statistics.csv')
print(f" ✓ Saved summary statistics")
print(f"\nDiagnostic report complete! Files saved in '{output_dir}'")
return results
def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None):
"""
Compare prior and posterior distributions.
Parameters
----------
idata : arviz.InferenceData
InferenceData with posterior samples
prior_idata : arviz.InferenceData
InferenceData with prior samples
var_names : list, optional
Variables to compare
output_path : str, optional
If provided, save plot to this path
Returns
-------
None
"""
fig, axes = plt.subplots(
len(var_names) if var_names else 3,
1,
figsize=(10, 8)
)
if not isinstance(axes, np.ndarray):
axes = [axes]
    for idx, var in enumerate(var_names if var_names else list(idata.posterior.data_vars)[:3]):
        # Plot prior
        az.plot_dist(
            prior_idata.prior[var].values.flatten(),
            label='Prior',
            ax=axes[idx],
            color='blue'
        )
        # Plot posterior
        az.plot_dist(
            idata.posterior[var].values.flatten(),
            label='Posterior',
            ax=axes[idx],
            color='green'
        )
        axes[idx].set_title(f'{var}: Prior vs Posterior')
        axes[idx].legend()
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Prior-posterior comparison saved to {output_path}")
else:
plt.show()
# Example usage
if __name__ == '__main__':
print("This script provides diagnostic functions for PyMC models.")
print("\nExample usage:")
print("""
import pymc as pm
from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report
# After sampling
with pm.Model() as model:
# ... define model ...
idata = pm.sample()
# Quick diagnostic check
results = check_diagnostics(idata)
# Full diagnostic report with plots
create_diagnostic_report(
idata,
var_names=['alpha', 'beta', 'sigma'],
output_dir='my_diagnostics/'
)
""")