Initial commit

skills/pymc/SKILL.md

---
name: pymc-bayesian-modeling
description: "Bayesian modeling with PyMC: hierarchical models, MCMC sampling (NUTS), variational inference, LOO/WAIC model comparison, and posterior predictive checks for probabilistic programming and inference."
---

# PyMC Bayesian Modeling

## Overview

PyMC is a Python library for Bayesian modeling and probabilistic programming. This skill covers building, fitting, validating, and comparing Bayesian models with PyMC's modern API (version 5.x+), including hierarchical models, MCMC sampling (NUTS), variational inference, and model comparison (LOO, WAIC).

## When to Use This Skill

This skill should be used when:
- Building Bayesian models (linear/logistic regression, hierarchical models, time series, etc.)
- Performing MCMC sampling or variational inference
- Conducting prior/posterior predictive checks
- Diagnosing sampling issues (divergences, convergence, ESS)
- Comparing multiple models using information criteria (LOO, WAIC)
- Implementing uncertainty quantification through Bayesian methods
- Working with hierarchical/multilevel data structures
- Handling missing data or measurement error in a principled way

## Standard Bayesian Workflow

Follow this workflow for building and validating Bayesian models:

### 1. Data Preparation

```python
import pymc as pm
import arviz as az
import numpy as np

# Load and prepare data
X = ...  # Predictors
y = ...  # Outcomes

# Standardize predictors for better sampling
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
```

**Key practices:**
- Standardize continuous predictors (improves sampling efficiency)
- Center outcomes when possible
- Handle missing data explicitly by treating missing values as parameters (see the sketch below)
- Use named dimensions with `coords` for clarity
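
Where the missing-data bullet applies, PyMC can impute automatically: passing a NumPy masked array as `observed` turns each masked entry into a latent variable sampled along with the parameters. A minimal sketch, assuming a hypothetical outcome vector `y_raw` containing NaNs:

```python
import numpy as np
import pymc as pm

# y_raw is a hypothetical outcome vector with NaNs marking missing values
y_masked = np.ma.masked_invalid(y_raw)

with pm.Model() as missing_model:
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Masked entries become latent parameters and are imputed during sampling
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_masked)
```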

### 2. Model Building

```python
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as model:
    # Data container (so predictors can be swapped out for predictions later)
    X_data = pm.Data('X_scaled', X_scaled)

    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
```

**Key practices:**
- Use weakly informative priors (not flat priors)
- Use `HalfNormal` or `Exponential` for scale parameters
- Use named dimensions (`dims`) instead of `shape` when possible
- Use `pm.Data()` for values that will be updated for predictions (as done with `X_data` above)

### 3. Prior Predictive Check

**Always validate priors before fitting:**

```python
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000, random_seed=42)

# Visualize
az.plot_ppc(prior_pred, group='prior')
```

**Check:**
- Do prior predictions span reasonable values?
- Are extreme values plausible given domain knowledge?
- If priors generate implausible data, adjust and re-check

### 4. Fit Model

```python
with model:
    # Optional: Quick exploration with ADVI
    # approx = pm.fit(n=20000)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True}  # For model comparison
    )
```

**Key parameters:**
- `draws=2000`: Number of samples per chain
- `tune=1000`: Warmup samples (discarded)
- `chains=4`: Run 4 chains for convergence checking
- `target_accept=0.9`: Higher for difficult posteriors (0.95-0.99)
- `idata_kwargs={'log_likelihood': True}`: Store pointwise log-likelihood for model comparison

### 5. Check Diagnostics

**Use the diagnostic script:**

```python
from scripts.model_diagnostics import check_diagnostics

results = check_diagnostics(idata, var_names=['alpha', 'beta', 'sigma'])
```

**Check:**
- **R-hat < 1.01**: Chains have converged
- **ESS > 400**: Sufficient effective samples
- **No divergences**: NUTS sampled successfully
- **Trace plots**: Chains should mix well (fuzzy caterpillar)

**If issues arise:**
- Divergences → Increase `target_accept=0.95`, use non-centered parameterization
- Low ESS → Sample more draws, reparameterize to reduce correlation
- High R-hat → Run longer, check for multimodality

### 6. Posterior Predictive Check

**Validate model fit:**

```python
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize
az.plot_ppc(idata)
```

**Check:**
- Do posterior predictions capture observed data patterns?
- Are systematic deviations evident (model misspecification)?
- Consider alternative models if fit is poor

### 7. Analyze Results

```python
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])

# Coefficient estimates
az.plot_forest(idata, var_names=['beta'], combined=True)
```

### 8. Make Predictions

```python
X_new = ...  # New predictor values
X_new_scaled = (X_new - X_mean) / X_std

with model:
    # Update the data container; coords resizes the obs_id dimension
    pm.set_data({'X_scaled': X_new_scaled},
                coords={'obs_id': np.arange(len(X_new_scaled))})
    post_pred = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_obs'],
        random_seed=42
    )

# Extract prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])
```

## Common Model Patterns

### Linear Regression

For continuous outcomes with linear relationships:

```python
with pm.Model() as linear_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)

    mu = alpha + pm.math.dot(X, beta)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```

**Use template:** `assets/linear_regression_template.py`

### Logistic Regression

For binary outcomes:

```python
with pm.Model() as logistic_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    logit_p = alpha + pm.math.dot(X, beta)
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```
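
To report predicted probabilities rather than log-odds, one option is to store the inverse-logit as a `Deterministic` so it appears in the trace. This is an optional addition to the model above, not part of the original pattern:

```python
with logistic_model:
    # Predicted probabilities on the (0, 1) scale
    p = pm.Deterministic('p', pm.math.sigmoid(logit_p))
```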

### Hierarchical Models

For grouped data (use non-centered parameterization):

```python
with pm.Model(coords={'groups': group_names}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)

    # Group-level (non-centered)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')

    # Observation-level
    mu = alpha[group_idx]
    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```

**Use template:** `assets/hierarchical_model_template.py`

**Critical:** Always use non-centered parameterization for hierarchical models to avoid divergences.

### Poisson Regression

For count data:

```python
with pm.Model() as poisson_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    log_lambda = alpha + pm.math.dot(X, beta)
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```

For overdispersed counts, use `NegativeBinomial` instead.
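
A negative-binomial variant of the model above might look like the sketch below. The `Exponential` prior on the dispersion is one reasonable choice, not the only one; the intercept is renamed to avoid clashing with the distribution's `alpha` dispersion argument:

```python
with pm.Model() as negbinom_model:
    intercept = pm.Normal('intercept', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    dispersion = pm.Exponential('dispersion', lam=1.0)

    log_lambda = intercept + pm.math.dot(X, beta)
    # alpha here is the NegativeBinomial dispersion parameter
    y = pm.NegativeBinomial('y', mu=pm.math.exp(log_lambda),
                            alpha=dispersion, observed=y_obs)
```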

### Time Series

For autoregressive processes:

```python
with pm.Model() as ar_model:
    sigma = pm.HalfNormal('sigma', sigma=1)
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)

    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```

## Model Comparison

### Comparing Models

Use LOO or WAIC for model comparison:

```python
from scripts.model_comparison import compare_models, check_loo_reliability

# Fit models with log_likelihood
models = {
    'Model1': idata1,
    'Model2': idata2,
    'Model3': idata3
}

# Compare using LOO
comparison = compare_models(models, ic='loo')

# Check reliability
check_loo_reliability(models)
```

**Interpretation:**
- **Δloo < 2**: Models are similar; choose the simpler model
- **2 < Δloo < 4**: Weak evidence for the better model
- **4 < Δloo < 10**: Moderate evidence
- **Δloo > 10**: Strong evidence for the better model

**Check Pareto-k values:**
- k < 0.7: LOO reliable
- k > 0.7: Consider WAIC or k-fold CV
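
If the helper scripts are unavailable, the same comparison can be run directly with ArviZ, which these scripts presumably wrap; it requires models fitted with `idata_kwargs={'log_likelihood': True}`:

```python
import arviz as az

# Rank models by expected log predictive density (LOO)
comparison = az.compare({'Model1': idata1, 'Model2': idata2}, ic='loo')
print(comparison[['rank', 'elpd_loo', 'elpd_diff', 'dse', 'weight']])

# Pareto-k diagnostics for a single model
loo_result = az.loo(idata1, pointwise=True)
print(f"max Pareto-k: {float(loo_result.pareto_k.max()):.2f}")
```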

### Model Averaging

When models are similar, average predictions:

```python
from scripts.model_comparison import model_averaging

averaged_pred, weights = model_averaging(models, var_name='y_obs')
```

## Distribution Selection Guide

### For Priors

**Scale parameters** (σ, τ):
- `pm.HalfNormal('sigma', sigma=1)` - Default choice
- `pm.Exponential('sigma', lam=1)` - Alternative
- `pm.Gamma('sigma', alpha=2, beta=1)` - More informative

**Unbounded parameters**:
- `pm.Normal('theta', mu=0, sigma=1)` - For standardized data
- `pm.StudentT('theta', nu=3, mu=0, sigma=1)` - Robust to outliers

**Positive parameters**:
- `pm.LogNormal('theta', mu=0, sigma=1)`
- `pm.Gamma('theta', alpha=2, beta=1)`

**Probabilities**:
- `pm.Beta('p', alpha=2, beta=2)` - Weakly informative
- `pm.Uniform('p', lower=0, upper=1)` - Non-informative (use sparingly)

**Correlation matrices**:
- `pm.LKJCorr('corr', n=n_vars, eta=2)` - eta=1 is uniform over correlation matrices; eta>1 concentrates mass near the identity
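
For a full covariance prior (correlations plus scales), `pm.LKJCholeskyCov` is usually more convenient and sampler-friendly than `LKJCorr` with separate scale priors. A minimal sketch for 3-dimensional observations, where `data` is a hypothetical array of shape `(n, 3)`:

```python
with pm.Model() as cov_model:
    # Returns the Cholesky factor plus the implied correlations and scales
    chol, corr, stds = pm.LKJCholeskyCov(
        'chol', n=3, eta=2.0,
        sd_dist=pm.Exponential.dist(1.0),
        compute_corr=True
    )
    mu = pm.Normal('mu', mu=0, sigma=1, shape=3)
    y = pm.MvNormal('y', mu=mu, chol=chol, observed=data)
```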

### For Likelihoods

**Continuous outcomes**:
- `pm.Normal('y', mu=mu, sigma=sigma)` - Default for continuous data
- `pm.StudentT('y', nu=nu, mu=mu, sigma=sigma)` - Robust to outliers

**Count data**:
- `pm.Poisson('y', mu=lam)` - Equidispersed counts (note: `lambda` is a reserved word in Python, so use another variable name)
- `pm.NegativeBinomial('y', mu=mu, alpha=alpha)` - Overdispersed counts
- `pm.ZeroInflatedPoisson('y', psi=psi, mu=mu)` - Excess zeros

**Binary outcomes**:
- `pm.Bernoulli('y', p=p)` or `pm.Bernoulli('y', logit_p=logit_p)`

**Categorical outcomes**:
- `pm.Categorical('y', p=probs)`

**See:** `references/distributions.md` for comprehensive distribution reference

## Sampling and Inference

### MCMC with NUTS

Default and recommended for most models:

```python
idata = pm.sample(
    draws=2000,
    tune=1000,
    chains=4,
    target_accept=0.9,
    random_seed=42
)
```

**Adjust when needed:**
- Divergences → `target_accept=0.95` or higher
- Slow sampling → Use ADVI for initialization
- Discrete parameters → Use `pm.Metropolis()` for the discrete variables

### Variational Inference

Fast approximation for exploration or initialization:

```python
with model:
    approx = pm.fit(n=20000, method='advi')

    # Use for initialization
    start = approx.sample(return_inferencedata=False)[0]
    idata = pm.sample(initvals=start)
```

**Trade-offs:**
- Much faster than MCMC
- Approximate (may underestimate uncertainty)
- Good for large models or quick exploration

**See:** `references/sampling_inference.md` for detailed sampling guide

## Diagnostic Scripts

### Comprehensive Diagnostics

```python
from scripts.model_diagnostics import create_diagnostic_report

create_diagnostic_report(
    idata,
    var_names=['alpha', 'beta', 'sigma'],
    output_dir='diagnostics/'
)
```

Creates:
- Trace plots
- Rank plots (mixing check)
- Autocorrelation plots
- Energy plots
- ESS evolution
- Summary statistics CSV
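
These outputs correspond to standard ArviZ calls, which can also be used directly if the script is unavailable:

```python
import arviz as az

az.plot_trace(idata)                           # Trace plots
az.plot_rank(idata)                            # Rank plots (mixing check)
az.plot_autocorr(idata, var_names=['alpha'])   # Autocorrelation plots
az.plot_energy(idata)                          # Energy plot (NUTS)
az.plot_ess(idata, kind='evolution')           # ESS evolution
az.summary(idata).to_csv('summary.csv')        # Summary statistics CSV
```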

### Quick Diagnostic Check

```python
from scripts.model_diagnostics import check_diagnostics

results = check_diagnostics(idata)
```

Checks R-hat, ESS, divergences, and tree depth.

## Common Issues and Solutions

### Divergences

**Symptom:** `idata.sample_stats.diverging.sum() > 0`

**Solutions:**
1. Increase `target_accept=0.95` or `0.99`
2. Use non-centered parameterization (hierarchical models)
3. Add stronger priors to constrain parameters
4. Check for model misspecification

### Low Effective Sample Size

**Symptom:** `ESS < 400`

**Solutions:**
1. Sample more draws: `draws=5000`
2. Reparameterize to reduce posterior correlation
3. Use a QR decomposition for regression with correlated predictors (see the sketch below)
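
A sketch of the QR approach mentioned above, following the standard "thin QR" reparameterization; variable names are illustrative:

```python
import numpy as np
import pymc as pm

n, p = X_scaled.shape
Q, R = np.linalg.qr(X_scaled)              # thin QR decomposition
Q_star = Q * np.sqrt(n - 1)
R_star_inv = np.linalg.inv(R / np.sqrt(n - 1))

with pm.Model() as qr_model:
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    # Coefficients in the decorrelated Q-space sample much more efficiently
    beta_tilde = pm.Normal('beta_tilde', mu=0, sigma=1, shape=p)
    sigma = pm.HalfNormal('sigma', sigma=1)

    mu = alpha + pm.math.dot(Q_star, beta_tilde)
    y_lik = pm.Normal('y_lik', mu=mu, sigma=sigma, observed=y)

    # Map back to coefficients on the original predictor scale
    beta = pm.Deterministic('beta', pm.math.dot(R_star_inv, beta_tilde))
```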

### High R-hat

**Symptom:** `R-hat > 1.01`

**Solutions:**
1. Run longer chains: `tune=2000, draws=5000`
2. Check for multimodality
3. Improve initialization with ADVI

### Slow Sampling

**Solutions:**
1. Use ADVI initialization
2. Reduce model complexity
3. Increase parallelization: `cores=8, chains=8`
4. Use variational inference if appropriate

## Best Practices

### Model Building

1. **Always standardize predictors** for better sampling
2. **Use weakly informative priors** (not flat)
3. **Use named dimensions** (`dims`) for clarity
4. **Non-centered parameterization** for hierarchical models
5. **Check prior predictive** before fitting

### Sampling

1. **Run multiple chains** (at least 4) for convergence checking
2. **Use `target_accept=0.9`** as a baseline (higher if needed)
3. **Include `idata_kwargs={'log_likelihood': True}`** for model comparison
4. **Set a random seed** for reproducibility

### Validation

1. **Check diagnostics** before interpretation (R-hat, ESS, divergences)
2. **Posterior predictive check** for model validation
3. **Compare multiple models** when appropriate
4. **Report uncertainty** (HDI intervals, not just point estimates)

### Workflow

1. Start simple, add complexity gradually
2. Prior predictive check → Fit → Diagnostics → Posterior predictive check
3. Iterate on model specification based on checks
4. Document assumptions and prior choices

## Resources

This skill includes:

### References (`references/`)

- **`distributions.md`**: Comprehensive catalog of PyMC distributions organized by category (continuous, discrete, multivariate, mixture, time series). Use when selecting priors or likelihoods.

- **`sampling_inference.md`**: Detailed guide to sampling algorithms (NUTS, Metropolis, SMC), variational inference (ADVI, SVGD), and handling sampling issues. Use when encountering convergence problems or choosing inference methods.

- **`workflows.md`**: Complete workflow examples and code patterns for common model types, data preparation, prior selection, and model validation. Use as a cookbook for standard Bayesian analyses.

### Scripts (`scripts/`)

- **`model_diagnostics.py`**: Automated diagnostic checking and report generation. Functions: `check_diagnostics()` for quick checks, `create_diagnostic_report()` for comprehensive analysis with plots.

- **`model_comparison.py`**: Model comparison utilities using LOO/WAIC. Functions: `compare_models()`, `check_loo_reliability()`, `model_averaging()`.

### Templates (`assets/`)

- **`linear_regression_template.py`**: Complete template for Bayesian linear regression with full workflow (data prep, prior checks, fitting, diagnostics, predictions).

- **`hierarchical_model_template.py`**: Complete template for hierarchical/multilevel models with non-centered parameterization and group-level analysis.

## Quick Reference

### Model Building
```python
with pm.Model(coords={'var': names}) as model:
    # Priors
    param = pm.Normal('param', mu=0, sigma=1, dims='var')
    # Likelihood
    y = pm.Normal('y', mu=..., sigma=..., observed=data)
```

### Sampling
```python
idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```

### Diagnostics
```python
from scripts.model_diagnostics import check_diagnostics
check_diagnostics(idata)
```

### Model Comparison
```python
from scripts.model_comparison import compare_models
compare_models({'m1': idata1, 'm2': idata2}, ic='loo')
```

### Predictions
```python
with model:
    pm.set_data({'X': X_new})
    pred = pm.sample_posterior_predictive(idata.posterior)
```

## Additional Notes

- PyMC integrates with ArviZ for visualization and diagnostics
- Use `pm.model_to_graphviz(model)` to visualize model structure
- Save results with `idata.to_netcdf('results.nc')`
- Load with `az.from_netcdf('results.nc')`
- For very large models, consider minibatch ADVI or data subsampling (see the sketch below)
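
A minibatch-ADVI sketch for the large-model case, assuming hypothetical large arrays `X_big` and `y_big`; `pm.Minibatch` streams random batches and `total_size` rescales the minibatch log-likelihood to the full dataset:

```python
batch_X, batch_y = pm.Minibatch(X_big, y_big, batch_size=128)

with pm.Model() as minibatch_model:
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, shape=X_big.shape[1])
    sigma = pm.HalfNormal('sigma', sigma=1)

    mu = alpha + pm.math.dot(batch_X, beta)
    # total_size tells PyMC how to rescale the batch likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma,
                      observed=batch_y, total_size=len(y_big))

    approx = pm.fit(n=30000, method='advi')

idata_vi = approx.sample(1000)
```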

skills/pymc/assets/hierarchical_model_template.py

"""
|
||||
PyMC Hierarchical/Multilevel Model Template
|
||||
|
||||
This template provides a complete workflow for Bayesian hierarchical models,
|
||||
useful for grouped/nested data (e.g., students within schools, patients within hospitals).
|
||||
|
||||
Customize the sections marked with # TODO
|
||||
"""
|
||||
|
||||
import pymc as pm
|
||||
import arviz as az
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# =============================================================================
|
||||
# 1. DATA PREPARATION
|
||||
# =============================================================================
|
||||
|
||||
# TODO: Load your data with group structure
|
||||
# Example:
|
||||
# df = pd.read_csv('data.csv')
|
||||
# groups = df['group_id'].values
|
||||
# X = df['predictor'].values
|
||||
# y = df['outcome'].values
|
||||
|
||||
# For demonstration: Generate hierarchical data
|
||||
np.random.seed(42)
|
||||
n_groups = 10
|
||||
n_per_group = 20
|
||||
n_obs = n_groups * n_per_group
|
||||
|
||||
# True hierarchical structure
|
||||
true_mu_alpha = 5.0
|
||||
true_sigma_alpha = 2.0
|
||||
true_mu_beta = 1.5
|
||||
true_sigma_beta = 0.5
|
||||
true_sigma = 1.0
|
||||
|
||||
group_alphas = np.random.normal(true_mu_alpha, true_sigma_alpha, n_groups)
|
||||
group_betas = np.random.normal(true_mu_beta, true_sigma_beta, n_groups)
|
||||
|
||||
# Generate data
|
||||
groups = np.repeat(np.arange(n_groups), n_per_group)
|
||||
X = np.random.randn(n_obs)
|
||||
y = group_alphas[groups] + group_betas[groups] * X + np.random.randn(n_obs) * true_sigma
|
||||
|
||||
# TODO: Customize group names
|
||||
group_names = [f'Group_{i}' for i in range(n_groups)]
|
||||
|
||||
# =============================================================================
|
||||
# 2. BUILD HIERARCHICAL MODEL
|
||||
# =============================================================================
|
||||
|
||||
print("Building hierarchical model...")
|
||||
|
||||
coords = {
|
||||
'groups': group_names,
|
||||
'obs': np.arange(n_obs)
|
||||
}
|
||||
|
||||
with pm.Model(coords=coords) as hierarchical_model:
|
||||
# Data containers (for later predictions)
|
||||
X_data = pm.Data('X_data', X)
|
||||
groups_data = pm.Data('groups_data', groups)
|
||||
|
||||
# Hyperpriors (population-level parameters)
|
||||
# TODO: Adjust hyperpriors based on your domain knowledge
|
||||
mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
|
||||
sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=5)
|
||||
|
||||
mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
|
||||
sigma_beta = pm.HalfNormal('sigma_beta', sigma=5)
|
||||
|
||||
# Group-level parameters (non-centered parameterization)
|
||||
# Non-centered parameterization improves sampling efficiency
|
||||
alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
|
||||
alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')
|
||||
|
||||
beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='groups')
|
||||
beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='groups')
|
||||
|
||||
# Observation-level model
|
||||
mu = alpha[groups_data] + beta[groups_data] * X_data
|
||||
|
||||
# Observation noise
|
||||
sigma = pm.HalfNormal('sigma', sigma=5)
|
||||
|
||||
# Likelihood
|
||||
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs')
|
||||
|
||||
print("Model built successfully!")
|
||||
print(f"Groups: {n_groups}")
|
||||
print(f"Observations: {n_obs}")
|
||||
|
||||
# =============================================================================
|
||||
# 3. PRIOR PREDICTIVE CHECK
|
||||
# =============================================================================
|
||||
|
||||
print("\nRunning prior predictive check...")
|
||||
with hierarchical_model:
|
||||
prior_pred = pm.sample_prior_predictive(samples=500, random_seed=42)
|
||||
|
||||
# Visualize prior predictions
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
|
||||
ax.set_title('Prior Predictive Check')
|
||||
plt.tight_layout()
|
||||
plt.savefig('hierarchical_prior_check.png', dpi=300, bbox_inches='tight')
|
||||
print("Prior predictive check saved to 'hierarchical_prior_check.png'")
|
||||
|
||||
# =============================================================================
|
||||
# 4. FIT MODEL
|
||||
# =============================================================================
|
||||
|
||||
print("\nFitting hierarchical model...")
|
||||
print("(This may take a few minutes due to model complexity)")
|
||||
|
||||
with hierarchical_model:
|
||||
# MCMC sampling with higher target_accept for hierarchical models
|
||||
idata = pm.sample(
|
||||
draws=2000,
|
||||
tune=2000, # More tuning for hierarchical models
|
||||
chains=4,
|
||||
target_accept=0.95, # Higher for better convergence
|
||||
random_seed=42,
|
||||
idata_kwargs={'log_likelihood': True}
|
||||
)
|
||||
|
||||
print("Sampling complete!")
|
||||
|
||||
# =============================================================================
|
||||
# 5. CHECK DIAGNOSTICS
|
||||
# =============================================================================
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("DIAGNOSTICS")
|
||||
print("="*60)
|
||||
|
||||
# Summary for key parameters
|
||||
summary = az.summary(
|
||||
idata,
|
||||
var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma', 'alpha', 'beta']
|
||||
)
|
||||
print("\nParameter Summary:")
|
||||
print(summary)
|
||||
|
||||
# Check convergence
|
||||
bad_rhat = summary[summary['r_hat'] > 1.01]
|
||||
if len(bad_rhat) > 0:
|
||||
print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
|
||||
print(bad_rhat[['r_hat']])
|
||||
else:
|
||||
print("\n✓ All R-hat values < 1.01 (good convergence)")
|
||||
|
||||
# Check effective sample size
|
||||
low_ess = summary[summary['ess_bulk'] < 400]
|
||||
if len(low_ess) > 0:
|
||||
print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
|
||||
print(low_ess[['ess_bulk']].head(10))
|
||||
else:
|
||||
print("\n✓ All ESS values > 400 (sufficient samples)")
|
||||
|
||||
# Check divergences
|
||||
divergences = idata.sample_stats.diverging.sum().item()
|
||||
if divergences > 0:
|
||||
print(f"\n⚠️ WARNING: {divergences} divergent transitions")
|
||||
print(" This is common in hierarchical models - non-centered parameterization already applied")
|
||||
print(" Consider even higher target_accept or stronger hyperpriors")
|
||||
else:
|
||||
print("\n✓ No divergences")
|
||||
|
||||
# Trace plots for hyperparameters
|
||||
fig, axes = plt.subplots(5, 2, figsize=(12, 12))
|
||||
az.plot_trace(
|
||||
idata,
|
||||
var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma'],
|
||||
axes=axes
|
||||
)
plt.tight_layout()
plt.savefig('hierarchical_trace_plots.png', dpi=300, bbox_inches='tight')
print("\nTrace plots saved to 'hierarchical_trace_plots.png'")

# =============================================================================
# 6. POSTERIOR PREDICTIVE CHECK
# =============================================================================

print("\nRunning posterior predictive check...")
with hierarchical_model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize fit
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
ax.set_title('Posterior Predictive Check')
plt.tight_layout()
plt.savefig('hierarchical_posterior_check.png', dpi=300, bbox_inches='tight')
print("Posterior predictive check saved to 'hierarchical_posterior_check.png'")

# =============================================================================
# 7. ANALYZE HIERARCHICAL STRUCTURE
# =============================================================================

print("\n" + "="*60)
print("POPULATION-LEVEL (HYPERPARAMETER) ESTIMATES")
print("="*60)

# Population-level estimates
hyper_summary = summary.loc[['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma']]
print(hyper_summary[['mean', 'sd', 'hdi_3%', 'hdi_97%']])

# Forest plot for group-level parameters
fig, axes = plt.subplots(1, 2, figsize=(14, 8))

# Group intercepts
az.plot_forest(idata, var_names=['alpha'], combined=True, ax=axes[0])
axes[0].set_title('Group-Level Intercepts (α)')
axes[0].set_yticklabels(group_names)
axes[0].axvline(idata.posterior['mu_alpha'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[0].legend()

# Group slopes
az.plot_forest(idata, var_names=['beta'], combined=True, ax=axes[1])
axes[1].set_title('Group-Level Slopes (β)')
axes[1].set_yticklabels(group_names)
axes[1].axvline(idata.posterior['mu_beta'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[1].legend()

plt.tight_layout()
plt.savefig('group_level_estimates.png', dpi=300, bbox_inches='tight')
print("\nGroup-level estimates saved to 'group_level_estimates.png'")

# Shrinkage visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Intercepts
alpha_samples = idata.posterior['alpha'].values.reshape(-1, n_groups)
alpha_means = alpha_samples.mean(axis=0)
mu_alpha_mean = idata.posterior['mu_alpha'].mean().item()

axes[0].scatter(range(n_groups), alpha_means, alpha=0.6)
axes[0].axhline(mu_alpha_mean, color='red', linestyle='--', label='Population mean')
axes[0].set_xlabel('Group')
axes[0].set_ylabel('Intercept')
axes[0].set_title('Group Intercepts (showing shrinkage to population mean)')
axes[0].legend()

# Slopes
beta_samples = idata.posterior['beta'].values.reshape(-1, n_groups)
beta_means = beta_samples.mean(axis=0)
mu_beta_mean = idata.posterior['mu_beta'].mean().item()

axes[1].scatter(range(n_groups), beta_means, alpha=0.6)
axes[1].axhline(mu_beta_mean, color='red', linestyle='--', label='Population mean')
axes[1].set_xlabel('Group')
axes[1].set_ylabel('Slope')
axes[1].set_title('Group Slopes (showing shrinkage to population mean)')
axes[1].legend()

plt.tight_layout()
plt.savefig('shrinkage_plot.png', dpi=300, bbox_inches='tight')
print("Shrinkage plot saved to 'shrinkage_plot.png'")

# =============================================================================
# 8. PREDICTIONS FOR NEW DATA
# =============================================================================

# TODO: Specify new data
# For existing groups:
# new_X = np.array([...])
# new_groups = np.array([0, 1, 2, ...])  # Existing group indices

# For a new group (predict using population-level parameters):
# Just use mu_alpha and mu_beta

print("\n" + "="*60)
print("PREDICTIONS FOR NEW DATA")
print("="*60)

# Example: Predict for existing groups
new_X = np.array([-2, -1, 0, 1, 2])
new_groups = np.array([0, 2, 4, 6, 8])  # Select some groups

with hierarchical_model:
    # 'obs' is a coordinate, so it is resized via coords, not set_data itself
    pm.set_data({'X_data': new_X, 'groups_data': new_groups},
                coords={'obs': np.arange(len(new_X))})

    post_pred = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_obs'],
        random_seed=42
    )

y_pred_samples = post_pred.posterior_predictive['y_obs']
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
y_pred_hdi = az.hdi(y_pred_samples, hdi_prob=0.95).values

print("Predictions for existing groups:")
print(f"{'Group':<10} {'X':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
print("-"*65)
for i, g in enumerate(new_groups):
    print(f"{group_names[g]:<10} {new_X[i]:<10.2f} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")

# Predict for a new group (using population parameters)
print("\nPrediction for a NEW group (using population-level parameters):")
new_X_newgroup = np.array([0.0])

# Manually compute using population parameters
mu_alpha_samples = idata.posterior['mu_alpha'].values.flatten()
mu_beta_samples = idata.posterior['mu_beta'].values.flatten()
sigma_samples = idata.posterior['sigma'].values.flatten()

# Predicted mean for the new group
y_pred_newgroup = mu_alpha_samples + mu_beta_samples * new_X_newgroup[0]
y_pred_mean_newgroup = y_pred_newgroup.mean()
y_pred_hdi_newgroup = az.hdi(y_pred_newgroup, hdi_prob=0.95)

print(f"X = {new_X_newgroup[0]:.2f}")
print(f"Predicted mean: {y_pred_mean_newgroup:.3f}")
print(f"95% HDI: [{y_pred_hdi_newgroup[0]:.3f}, {y_pred_hdi_newgroup[1]:.3f}]")

# =============================================================================
# 9. SAVE RESULTS
# =============================================================================

idata.to_netcdf('hierarchical_model_results.nc')
print("\nResults saved to 'hierarchical_model_results.nc'")

summary.to_csv('hierarchical_model_summary.csv')
print("Summary saved to 'hierarchical_model_summary.csv'")

print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)

skills/pymc/assets/linear_regression_template.py

"""
|
||||
PyMC Linear Regression Template
|
||||
|
||||
This template provides a complete workflow for Bayesian linear regression,
|
||||
including data preparation, model building, diagnostics, and predictions.
|
||||
|
||||
Customize the sections marked with # TODO
|
||||
"""
|
||||
|
||||
import pymc as pm
|
||||
import arviz as az
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# =============================================================================
|
||||
# 1. DATA PREPARATION
|
||||
# =============================================================================
|
||||
|
||||
# TODO: Load your data
|
||||
# Example:
|
||||
# df = pd.read_csv('data.csv')
|
||||
# X = df[['predictor1', 'predictor2', 'predictor3']].values
|
||||
# y = df['outcome'].values
|
||||
|
||||
# For demonstration:
|
||||
np.random.seed(42)
|
||||
n_samples = 100
|
||||
n_predictors = 3
|
||||
|
||||
X = np.random.randn(n_samples, n_predictors)
|
||||
true_beta = np.array([1.5, -0.8, 2.1])
|
||||
true_alpha = 0.5
|
||||
y = true_alpha + X @ true_beta + np.random.randn(n_samples) * 0.5
|
||||
|
||||
# Standardize predictors for better sampling
|
||||
X_mean = X.mean(axis=0)
|
||||
X_std = X.std(axis=0)
|
||||
X_scaled = (X - X_mean) / X_std
|
||||
|
||||
# =============================================================================
|
||||
# 2. BUILD MODEL
|
||||
# =============================================================================
|
||||
|
||||
# TODO: Customize predictor names
|
||||
predictor_names = ['predictor1', 'predictor2', 'predictor3']
|
||||
|
||||
coords = {
|
||||
'predictors': predictor_names,
|
||||
'obs_id': np.arange(len(y))
|
||||
}
|
||||
|
||||
with pm.Model(coords=coords) as linear_model:
|
||||
# Priors
|
||||
# TODO: Adjust prior parameters based on your domain knowledge
|
||||
alpha = pm.Normal('alpha', mu=0, sigma=1)
|
||||
beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
|
||||
sigma = pm.HalfNormal('sigma', sigma=1)
|
||||
|
||||
# Linear predictor
|
||||
mu = alpha + pm.math.dot(X_scaled, beta)
|
||||
|
||||
# Likelihood
|
||||
y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
|
||||
|
||||
# =============================================================================
|
||||
# 3. PRIOR PREDICTIVE CHECK
|
||||
# =============================================================================
|
||||
|
||||
print("Running prior predictive check...")
|
||||
with linear_model:
|
||||
prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)
|
||||
|
||||
# Visualize prior predictions
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
|
||||
ax.set_title('Prior Predictive Check')
|
||||
plt.tight_layout()
|
||||
plt.savefig('prior_predictive_check.png', dpi=300, bbox_inches='tight')
|
||||
print("Prior predictive check saved to 'prior_predictive_check.png'")
|
||||
|
||||
# =============================================================================
|
||||
# 4. FIT MODEL
|
||||
# =============================================================================
|
||||
|
||||
print("\nFitting model...")
|
||||
with linear_model:
|
||||
# Optional: Quick ADVI exploration
|
||||
# approx = pm.fit(n=20000, random_seed=42)
|
||||
|
||||
# MCMC sampling
|
||||
idata = pm.sample(
|
||||
draws=2000,
|
||||
tune=1000,
|
||||
chains=4,
|
||||
target_accept=0.9,
|
||||
random_seed=42,
|
||||
idata_kwargs={'log_likelihood': True}
|
||||
)
|
||||
|
||||
print("Sampling complete!")
|
||||
|
||||
# =============================================================================
|
||||
# 5. CHECK DIAGNOSTICS
|
||||
# =============================================================================
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("DIAGNOSTICS")
|
||||
print("="*60)
|
||||
|
||||
# Summary statistics
|
||||
summary = az.summary(idata, var_names=['alpha', 'beta', 'sigma'])
|
||||
print("\nParameter Summary:")
|
||||
print(summary)
|
||||
|
||||
# Check convergence
|
||||
bad_rhat = summary[summary['r_hat'] > 1.01]
|
||||
if len(bad_rhat) > 0:
|
||||
print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
|
||||
print(bad_rhat[['r_hat']])
|
||||
else:
|
||||
print("\n✓ All R-hat values < 1.01 (good convergence)")
|
||||
|
||||
# Check effective sample size
|
||||
low_ess = summary[summary['ess_bulk'] < 400]
|
||||
if len(low_ess) > 0:
|
||||
print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
|
||||
print(low_ess[['ess_bulk', 'ess_tail']])
|
||||
else:
|
||||
print("\n✓ All ESS values > 400 (sufficient samples)")
|
||||
|
||||
# Check divergences
|
||||
divergences = idata.sample_stats.diverging.sum().item()
|
||||
if divergences > 0:
|
||||
print(f"\n⚠️ WARNING: {divergences} divergent transitions")
|
||||
print(" Consider increasing target_accept or reparameterizing")
|
||||
else:
|
||||
print("\n✓ No divergences")
|
||||
|
||||
# Trace plots
|
||||
fig, axes = plt.subplots(len(['alpha', 'beta', 'sigma']), 2, figsize=(12, 8))
|
||||
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'], axes=axes)
|
||||
plt.tight_layout()
|
||||
plt.savefig('trace_plots.png', dpi=300, bbox_inches='tight')
|
||||
print("\nTrace plots saved to 'trace_plots.png'")
|
||||
|
||||
# =============================================================================
|
||||
# 6. POSTERIOR PREDICTIVE CHECK
|
||||
# =============================================================================
|
||||
|
||||
print("\nRunning posterior predictive check...")
|
||||
with linear_model:
|
||||
pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
|
||||
|
||||
# Visualize fit
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
|
||||
ax.set_title('Posterior Predictive Check')
|
||||
plt.tight_layout()
|
||||
plt.savefig('posterior_predictive_check.png', dpi=300, bbox_inches='tight')
|
||||
print("Posterior predictive check saved to 'posterior_predictive_check.png'")
|
||||
|
||||
# =============================================================================
|
||||
# 7. ANALYZE RESULTS
|
||||
# =============================================================================
|
||||
|
||||
# Posterior distributions
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
|
||||
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'], ax=axes)
|
||||
plt.tight_layout()
|
||||
plt.savefig('posterior_distributions.png', dpi=300, bbox_inches='tight')
|
||||
print("Posterior distributions saved to 'posterior_distributions.png'")
|
||||
|
||||
# Forest plot for coefficients
|
||||
fig, ax = plt.subplots(figsize=(8, 6))
|
||||
az.plot_forest(idata, var_names=['beta'], combined=True, ax=ax)
|
||||
ax.set_title('Coefficient Estimates (95% HDI)')
|
||||
ax.set_yticklabels(predictor_names)
|
||||
plt.tight_layout()
|
||||
plt.savefig('coefficient_forest_plot.png', dpi=300, bbox_inches='tight')
|
||||
print("Forest plot saved to 'coefficient_forest_plot.png'")
|
||||
|
||||
# Print coefficient estimates
|
||||
print("\n" + "="*60)
|
||||
print("COEFFICIENT ESTIMATES")
|
||||
print("="*60)
|
||||
beta_samples = idata.posterior['beta']
|
||||
for i, name in enumerate(predictor_names):
|
||||
mean = beta_samples.sel(predictors=name).mean().item()
|
||||
hdi = az.hdi(beta_samples.sel(predictors=name), hdi_prob=0.95)
|
||||
print(f"{name:20s}: {mean:7.3f} [95% HDI: {hdi.values[0]:7.3f}, {hdi.values[1]:7.3f}]")
|
||||
|
||||
# =============================================================================
|
||||
# 8. PREDICTIONS FOR NEW DATA
|
||||
# =============================================================================
|
||||
|
||||
# TODO: Provide new data for predictions
|
||||
# X_new = np.array([[...], [...], ...]) # New predictor values
|
||||
|
||||
# For demonstration, use some test data
|
||||
X_new = np.random.randn(10, n_predictors)
|
||||
X_new_scaled = (X_new - X_mean) / X_std
|
||||
|
||||
# Update model data and predict
|
||||
with linear_model:
|
||||
pm.set_data({'X_scaled': X_new_scaled, 'obs_id': np.arange(len(X_new))})
|
||||
|
||||
post_pred = pm.sample_posterior_predictive(
|
||||
idata.posterior,
|
||||
var_names=['y_obs'],
|
||||
random_seed=42
|
||||
)
|
||||
|
||||
# Extract predictions
|
||||
y_pred_samples = post_pred.posterior_predictive['y_obs']
|
||||
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
|
||||
y_pred_hdi = az.hdi(y_pred_samples, hdi_prob=0.95).values
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("PREDICTIONS FOR NEW DATA")
|
||||
print("="*60)
|
||||
print(f"{'Index':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
|
||||
print("-"*60)
|
||||
for i in range(len(X_new)):
|
||||
print(f"{i:<10} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
|
||||
|
||||
# =============================================================================
|
||||
# 9. SAVE RESULTS
|
||||
# =============================================================================
|
||||
|
||||
# Save InferenceData
|
||||
idata.to_netcdf('linear_regression_results.nc')
|
||||
print("\nResults saved to 'linear_regression_results.nc'")
|
||||
|
||||
# Save summary to CSV
|
||||
summary.to_csv('model_summary.csv')
|
||||
print("Summary saved to 'model_summary.csv'")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("ANALYSIS COMPLETE")
|
||||
print("="*60)
|
||||

skills/pymc/references/distributions.md

# PyMC Distributions Reference

This reference provides a comprehensive catalog of probability distributions available in PyMC, organized by category. Use this to select appropriate distributions for priors and likelihoods when building Bayesian models.

## Continuous Distributions

Continuous distributions define probability densities over real-valued domains.

### Common Continuous Distributions

**`pm.Normal(name, mu, sigma)`**
- Normal (Gaussian) distribution
- Parameters: `mu` (mean), `sigma` (standard deviation)
- Support: (-∞, ∞)
- Common uses: Default prior for unbounded parameters, likelihood for continuous data with additive noise

**`pm.HalfNormal(name, sigma)`**
- Half-normal distribution (positive half of normal)
- Parameters: `sigma` (standard deviation)
- Support: [0, ∞)
- Common uses: Prior for scale/standard deviation parameters

**`pm.Uniform(name, lower, upper)`**
- Uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: [lower, upper]
- Common uses: Weakly informative prior when parameter must be bounded

**`pm.Beta(name, alpha, beta)`**
- Beta distribution
- Parameters: `alpha`, `beta` (shape parameters)
- Support: [0, 1]
- Common uses: Prior for probabilities and proportions

**`pm.Gamma(name, alpha, beta)`**
- Gamma distribution
- Parameters: `alpha` (shape), `beta` (rate)
- Support: (0, ∞)
- Common uses: Prior for positive parameters, rate parameters

**`pm.Exponential(name, lam)`**
- Exponential distribution
- Parameters: `lam` (rate parameter)
- Support: [0, ∞)
- Common uses: Prior for scale parameters, waiting times

**`pm.LogNormal(name, mu, sigma)`**
- Log-normal distribution
- Parameters: `mu`, `sigma` (parameters of underlying normal)
- Support: (0, ∞)
- Common uses: Prior for positive parameters with multiplicative effects

**`pm.StudentT(name, nu, mu, sigma)`**
- Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `sigma` (scale)
- Support: (-∞, ∞)
- Common uses: Robust alternative to normal for outlier-resistant models

**`pm.Cauchy(name, alpha, beta)`**
- Cauchy distribution
- Parameters: `alpha` (location), `beta` (scale)
- Support: (-∞, ∞)
- Common uses: Heavy-tailed alternative to normal

### Specialized Continuous Distributions

**`pm.Laplace(name, mu, b)`** - Laplace (double exponential) distribution

**`pm.AsymmetricLaplace(name, kappa, mu, b)`** - Asymmetric Laplace distribution

**`pm.InverseGamma(name, alpha, beta)`** - Inverse gamma distribution

**`pm.Weibull(name, alpha, beta)`** - Weibull distribution for reliability analysis

**`pm.Logistic(name, mu, s)`** - Logistic distribution

**`pm.LogitNormal(name, mu, sigma)`** - Logit-normal distribution for (0,1) support

**`pm.Pareto(name, alpha, m)`** - Pareto distribution for power-law phenomena

**`pm.ChiSquared(name, nu)`** - Chi-squared distribution

**`pm.ExGaussian(name, mu, sigma, nu)`** - Exponentially modified Gaussian

**`pm.VonMises(name, mu, kappa)`** - Von Mises (circular normal) distribution

**`pm.SkewNormal(name, mu, sigma, alpha)`** - Skew-normal distribution

**`pm.Triangular(name, lower, c, upper)`** - Triangular distribution

**`pm.Gumbel(name, mu, beta)`** - Gumbel distribution for extreme values

**`pm.Rice(name, nu, sigma)`** - Rice (Rician) distribution

**`pm.Moyal(name, mu, sigma)`** - Moyal distribution

**`pm.Kumaraswamy(name, a, b)`** - Kumaraswamy distribution (Beta alternative)

**`pm.Interpolated(name, x_points, pdf_points)`** - Custom distribution from interpolation

## Discrete Distributions

Discrete distributions define probabilities over integer-valued domains.

### Common Discrete Distributions

**`pm.Bernoulli(name, p)`**
- Bernoulli distribution (binary outcome)
- Parameters: `p` (success probability)
- Support: {0, 1}
- Common uses: Binary classification, coin flips

**`pm.Binomial(name, n, p)`**
- Binomial distribution
- Parameters: `n` (number of trials), `p` (success probability)
- Support: {0, 1, ..., n}
- Common uses: Number of successes in fixed trials

**`pm.Poisson(name, mu)`**
- Poisson distribution
- Parameters: `mu` (rate parameter)
- Support: {0, 1, 2, ...}
- Common uses: Count data, rates, occurrences

**`pm.Categorical(name, p)`**
- Categorical distribution
- Parameters: `p` (probability vector)
- Support: {0, 1, ..., K-1}
- Common uses: Multi-class classification

**`pm.DiscreteUniform(name, lower, upper)`**
- Discrete uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: {lower, ..., upper}
- Common uses: Uniform prior over finite integers

**`pm.NegativeBinomial(name, mu, alpha)`**
- Negative binomial distribution
- Parameters: `mu` (mean), `alpha` (dispersion)
- Support: {0, 1, 2, ...}
- Common uses: Overdispersed count data

**`pm.Geometric(name, p)`**
- Geometric distribution
- Parameters: `p` (success probability)
- Support: {1, 2, 3, ...}
- Common uses: Number of trials until the first success

### Specialized Discrete Distributions

**`pm.BetaBinomial(name, alpha, beta, n)`** - Beta-binomial (overdispersed binomial)

**`pm.HyperGeometric(name, N, k, n)`** - Hypergeometric distribution

**`pm.DiscreteWeibull(name, q, beta)`** - Discrete Weibull distribution

**`pm.OrderedLogistic(name, eta, cutpoints)`** - Ordered logistic for ordinal data

**`pm.OrderedProbit(name, eta, cutpoints)`** - Ordered probit for ordinal data

## Multivariate Distributions

Multivariate distributions define joint probability distributions over vector-valued random variables.

### Common Multivariate Distributions

**`pm.MvNormal(name, mu, cov)`**
- Multivariate normal distribution
- Parameters: `mu` (mean vector), `cov` (covariance matrix)
- Common uses: Correlated continuous variables, Gaussian processes

**`pm.Dirichlet(name, a)`**
- Dirichlet distribution
- Parameters: `a` (concentration parameters)
- Support: Simplex (sums to 1)
- Common uses: Prior for probability vectors, topic modeling

**`pm.Multinomial(name, n, p)`**
- Multinomial distribution
- Parameters: `n` (number of trials), `p` (probability vector)
- Common uses: Count data across multiple categories

**`pm.MvStudentT(name, nu, mu, cov)`**
- Multivariate Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `cov` (scale matrix)
- Common uses: Robust multivariate modeling

### Specialized Multivariate Distributions

**`pm.LKJCorr(name, n, eta)`** - LKJ correlation matrix prior (for correlation matrices)

**`pm.LKJCholeskyCov(name, n, eta, sd_dist)`** - LKJ prior with Cholesky decomposition

**`pm.Wishart(name, nu, V)`** - Wishart distribution (for covariance matrices)

**`pm.InverseWishart(name, nu, V)`** - Inverse Wishart distribution

**`pm.MatrixNormal(name, mu, rowcov, colcov)`** - Matrix normal distribution

**`pm.KroneckerNormal(name, mu, covs, sigma)`** - Kronecker-structured normal

**`pm.CAR(name, mu, W, alpha, tau)`** - Conditional autoregressive (spatial)

**`pm.ICAR(name, W, sigma)`** - Intrinsic conditional autoregressive (spatial)

## Mixture Distributions

Mixture distributions combine multiple component distributions.

**`pm.Mixture(name, w, comp_dists)`**
- General mixture distribution
- Parameters: `w` (weights), `comp_dists` (component distributions)
- Common uses: Clustering, multi-modal data

**`pm.NormalMixture(name, w, mu, sigma)`**
- Mixture of normal distributions
- Common uses: Mixture of Gaussians clustering
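
A two-component Gaussian mixture sketch; the `data` array is hypothetical:

```python
import numpy as np
import pymc as pm

with pm.Model() as mixture_model:
    w = pm.Dirichlet('w', a=np.ones(2))             # mixture weights
    mu = pm.Normal('mu', mu=0, sigma=5, shape=2)    # component means
    sigma = pm.HalfNormal('sigma', sigma=2, shape=2)
    y = pm.NormalMixture('y', w=w, mu=mu, sigma=sigma, observed=data)
```

Note that mixture posteriors are prone to label switching across chains, so an ordering constraint or informative priors on the component means is often needed.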

### Zero-Inflated and Hurdle Models

**`pm.ZeroInflatedPoisson(name, psi, mu)`** - Excess zeros in count data

**`pm.ZeroInflatedBinomial(name, psi, n, p)`** - Zero-inflated binomial

**`pm.ZeroInflatedNegativeBinomial(name, psi, mu, alpha)`** - Zero-inflated negative binomial

**`pm.HurdlePoisson(name, psi, mu)`** - Hurdle Poisson (two-part model)

**`pm.HurdleGamma(name, psi, alpha, beta)`** - Hurdle gamma

**`pm.HurdleLogNormal(name, psi, mu, sigma)`** - Hurdle log-normal

## Time Series Distributions

Distributions designed for temporal data and sequential modeling.

**`pm.AR(name, rho, sigma, init_dist)`**
- Autoregressive process
- Parameters: `rho` (AR coefficients), `sigma` (innovation std), `init_dist` (initial distribution)
- Common uses: Time series modeling, sequential data

**`pm.GaussianRandomWalk(name, mu, sigma, init_dist)`**
- Gaussian random walk
- Parameters: `mu` (drift), `sigma` (step size), `init_dist` (initial value)
- Common uses: Cumulative processes, random walk priors

**`pm.MvGaussianRandomWalk(name, mu, cov, init_dist)`**
- Multivariate Gaussian random walk

**`pm.GARCH11(name, omega, alpha_1, beta_1)`**
- GARCH(1,1) volatility model
- Common uses: Financial time series, volatility modeling

**`pm.EulerMaruyama(name, dt, sde_fn, sde_pars, init_dist)`**
- Stochastic differential equation via Euler-Maruyama discretization
- Common uses: Continuous-time processes

## Special Distributions

**`pm.Deterministic(name, var)`**
- Deterministic transformation (not a random variable)
- Use for computed quantities derived from other variables

**`pm.Potential(name, logp)`**
- Add arbitrary log-probability contribution
- Use for custom likelihood components or constraints

**`pm.Flat(name)`**
- Improper flat prior (constant density)
- Use sparingly; can cause sampling issues

**`pm.HalfFlat(name)`**
- Improper flat prior on positive reals
- Use sparingly; can cause sampling issues

## Distribution Modifiers

**`pm.Truncated(name, dist, lower, upper)`**
- Truncate any distribution to specified bounds

**`pm.Censored(name, dist, lower, upper)`**
- Handle censored observations (observed bounds, not exact values)

**`pm.CustomDist(name, ..., logp, random)`**
- Define custom distributions with user-specified log-probability and random sampling functions

**`pm.Simulator(name, fn, params, ...)`**
- Custom distributions via simulation (for likelihood-free inference)
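
A sketch of the two modifiers above; the saturation threshold and the `y_censored` array are hypothetical:

```python
import pymc as pm

with pm.Model() as censored_model:
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Prior restricted to [0, 10] via truncation
    theta = pm.Truncated('theta', pm.Normal.dist(mu=5, sigma=3),
                         lower=0, upper=10)

    # Likelihood where the instrument saturates at 4.0: values recorded
    # as 4.0 contribute P(Y >= 4.0) instead of a point density
    y = pm.Censored('y', pm.Normal.dist(mu=mu, sigma=sigma),
                    lower=None, upper=4.0, observed=y_censored)
```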

## Usage Tips

### Choosing Priors

1. **Scale parameters** (σ, τ): Use `HalfNormal`, `HalfCauchy`, `Exponential`, or `Gamma`
2. **Probabilities**: Use `Beta` or `Uniform(0, 1)`
3. **Unbounded parameters**: Use `Normal` or `StudentT` (for robustness)
4. **Positive parameters**: Use `LogNormal`, `Gamma`, or `Exponential`
5. **Correlation matrices**: Use `LKJCorr`
6. **Count data**: Use `Poisson` or `NegativeBinomial` (for overdispersion)

### Shape Broadcasting

PyMC distributions support NumPy-style broadcasting. Use the `shape` parameter to create vectors or arrays of random variables:

```python
# Vector of 5 independent normals
beta = pm.Normal('beta', mu=0, sigma=1, shape=5)

# 3x4 matrix of independent gammas
tau = pm.Gamma('tau', alpha=2, beta=1, shape=(3, 4))
```

### Using dims for Named Dimensions

Instead of `shape`, use `dims` for more readable models:

```python
with pm.Model(coords={'predictors': ['age', 'income', 'education']}) as model:
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
```

skills/pymc/references/sampling_inference.md

# PyMC Sampling and Inference Methods
|
||||
|
||||
This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference.
|
||||
|
||||
## MCMC Sampling Methods
|
||||
|
||||
### Primary Sampling Function
|
||||
|
||||
**`pm.sample(draws=1000, tune=1000, chains=4, **kwargs)`**
|
||||
|
||||
The main interface for MCMC sampling in PyMC.
|
||||
|
||||
**Key Parameters:**
|
||||
- `draws`: Number of samples to draw per chain (default: 1000)
|
||||
- `tune`: Number of tuning/warmup samples (default: 1000, discarded)
|
||||
- `chains`: Number of parallel chains (default: 4)
|
||||
- `cores`: Number of CPU cores to use (default: all available)
|
||||
- `target_accept`: Target acceptance rate for step size tuning (default: 0.8, increase to 0.9-0.95 for difficult posteriors)
|
||||
- `random_seed`: Random seed for reproducibility
|
||||
- `return_inferencedata`: Return ArviZ InferenceData object (default: True)
|
||||
- `idata_kwargs`: Additional kwargs for InferenceData creation (e.g., `{"log_likelihood": True}` for model comparison)
|
||||
|
||||
**Returns:** InferenceData object containing posterior samples, sampling statistics, and diagnostics

**Example:**
```python
with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```

### Sampling Algorithms

PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually.

#### NUTS (No-U-Turn Sampler)

**Default algorithm** for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant.

- Automatically tunes step size and mass matrix
- Adaptive: explores posterior geometry during tuning
- Best for smooth, continuous posteriors
- Can struggle with high correlation or multimodality

**Manual specification:**
```python
with model:
    idata = pm.sample(step=pm.NUTS(target_accept=0.95))
```

**When to adjust:**
- Increase `target_accept` (0.9-0.99) if seeing divergences
- Use `init='jitter+adapt_diag'` (the default) or `init='adapt_diag'` for jitter-free initialization
- Use `init='advi+adapt_diag'` for difficult initializations
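
For example (a sketch; `model` is assumed to be defined):

```python
with model:
    # Fall back to an ADVI-based start when the default initialization struggles
    idata = pm.sample(init='advi+adapt_diag', target_accept=0.95)
```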

#### Metropolis

General-purpose Metropolis-Hastings sampler.

- Works for both continuous and discrete variables
- Less efficient than NUTS for smooth continuous posteriors
- Useful for discrete parameters or non-differentiable models
- Requires manual tuning

**Example:**
```python
with model:
    idata = pm.sample(step=pm.Metropolis())
```

#### Slice Sampler

Slice sampling for univariate distributions.

- No tuning required
- Good for difficult univariate posteriors
- Can be slow in high dimensions

**Example:**
```python
with model:
    idata = pm.sample(step=pm.Slice())
```

#### CompoundStep

Combine different samplers for different parameters.

**Example:**
```python
with model:
    # Use NUTS for continuous params, Metropolis for discrete
    step1 = pm.NUTS([continuous_var1, continuous_var2])
    step2 = pm.Metropolis([discrete_var])
    idata = pm.sample(step=[step1, step2])
```

### Sampling Diagnostics

PyMC automatically computes diagnostics. Check these before trusting results:

#### Effective Sample Size (ESS)

Measures independent information in correlated samples.

- **Rule of thumb**: ESS > 400 per chain (1600 total for 4 chains)
- Low ESS indicates high autocorrelation
- Access via: `az.ess(idata)`

#### R-hat (Gelman-Rubin statistic)

Measures convergence across chains.

- **Rule of thumb**: R-hat < 1.01 for all parameters
- R-hat > 1.01 indicates non-convergence
- Access via: `az.rhat(idata)`

#### Divergences

Indicate regions where NUTS struggled.

- **Rule of thumb**: 0 divergences (or very few)
- Divergences suggest biased samples
- **Fix**: Increase `target_accept`, reparameterize, or use stronger priors
- Access via: `idata.sample_stats.diverging.sum()`
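
A quick way to pull the three checks together (a sketch; `idata` comes from `pm.sample()`):

```python
import arviz as az

ess = az.ess(idata)
rhat = az.rhat(idata)
n_div = int(idata.sample_stats.diverging.sum())

print(f"min bulk ESS: {float(ess.to_array().min()):.0f}")
print(f"max R-hat:    {float(rhat.to_array().max()):.4f}")
print(f"divergences:  {n_div}")
```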

#### Energy Plot

Visualizes Hamiltonian Monte Carlo energy transitions.

```python
az.plot_energy(idata)
```

Good overlap between the marginal energy and energy-transition distributions indicates healthy sampling; a clear mismatch between them signals sampling problems.

### Handling Sampling Issues

#### Divergences

```python
# Increase target acceptance rate
idata = pm.sample(target_accept=0.95)
```

Or reparameterize. For hierarchical models, the usual fix is the non-centered parameterization of the latent group-level effects:

```python
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```

#### Slow Sampling

```python
# Use fewer tuning steps if model is simple
idata = pm.sample(tune=500)

# Increase cores for parallelization
idata = pm.sample(cores=8, chains=8)

# Use variational inference for initialization (built-in shortcut)
with model:
    idata = pm.sample(init='advi+adapt_diag')
```

#### High Autocorrelation

```python
# Increase draws
idata = pm.sample(draws=5000)

# Reparameterize to reduce correlation
# Consider using QR decomposition for regression models
```

## Variational Inference

Faster approximate inference for large models or quick exploration.

### ADVI (Automatic Differentiation Variational Inference)

**`pm.fit(n=10000, method='advi', **kwargs)`**

Approximates the posterior with a simpler distribution (typically a mean-field Gaussian).

**Key Parameters:**
- `n`: Number of iterations (default: 10000)
- `method`: VI algorithm ('advi', 'fullrank_advi', 'svgd')
- `random_seed`: Random seed

**Returns:** Approximation object for sampling and analysis

**Example:**
```python
with model:
    approx = pm.fit(n=50000)
    # Draw samples from the approximation
    idata = approx.sample(1000)
    # For MCMC initialization, prefer pm.sample(init='advi+adapt_diag')
```

**Trade-offs:**
- **Pros**: Much faster than MCMC, scales to large data
- **Cons**: Approximate, may miss posterior structure, underestimates uncertainty

### Full-Rank ADVI

Captures correlations between parameters.

```python
with model:
    approx = pm.fit(method='fullrank_advi')
```

More accurate than mean-field but slower.

### SVGD (Stein Variational Gradient Descent)

Non-parametric variational inference.

```python
with model:
    approx = pm.fit(method='svgd', n=20000)
```

Better captures multimodality but is more computationally expensive.

## Prior and Posterior Predictive Sampling

### Prior Predictive Sampling

Sample from the prior distribution (before seeing data).

**`pm.sample_prior_predictive(samples=500, **kwargs)`**

**Purpose:**
- Validate priors are reasonable
- Check implied predictions before fitting
- Ensure model generates plausible data

**Example:**
```python
with model:
    prior_pred = pm.sample_prior_predictive(samples=1000)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior')
```

### Posterior Predictive Sampling

Sample from the posterior predictive distribution (after fitting).

**`pm.sample_posterior_predictive(trace, **kwargs)`**

**Purpose:**
- Model validation via posterior predictive checks
- Generate predictions for new data
- Assess goodness-of-fit

**Example:**
```python
with model:
    # After sampling
    idata = pm.sample()

    # Add posterior predictive samples
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Posterior predictive check
az.plot_ppc(idata)
```

### Predictions for New Data

Update data and sample the predictive distribution. This requires the predictors to have been registered with `pm.Data` inside the model:

```python
with model:
    # Original model fit (assumes X was created via pm.Data('X', ...))
    idata = pm.sample()

    # Update with new predictor values
    pm.set_data({'X': X_new})

    # Sample predictions
    post_pred_new = pm.sample_posterior_predictive(
        idata,
        var_names=['y_pred']
    )
```

## Maximum A Posteriori (MAP) Estimation

Find the posterior mode (point estimate).

**`pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)`**

**When to use:**
- Quick point estimates
- Initialization for MCMC
- When the full posterior is not needed

**Example:**
```python
with model:
    map_estimate = pm.find_MAP()
    print(map_estimate)
```

**Limitations:**
- Doesn't quantify uncertainty
- Can find local optima in multimodal posteriors
- Sensitive to prior specification

## Inference Recommendations

### Standard Workflow

1. **Start with ADVI** for quick exploration:

   ```python
   approx = pm.fit(n=20000)
   ```

2. **Run MCMC** for full inference:

   ```python
   idata = pm.sample(draws=2000, tune=1000)
   ```

3. **Check diagnostics**:

   ```python
   az.summary(idata, var_names=['~mu_log__'])  # Exclude transformed vars
   ```

4. **Sample posterior predictive**:

   ```python
   pm.sample_posterior_predictive(idata, extend_inferencedata=True)
   ```

### Choosing Inference Method

| Scenario | Recommended Method |
|----------|-------------------|
| Small-medium models, need full uncertainty | MCMC with NUTS |
| Large models, initial exploration | ADVI |
| Discrete parameters | Metropolis or marginalize |
| Hierarchical models with divergences | Non-centered parameterization + NUTS |
| Very large data | Minibatch ADVI |
| Quick point estimates | MAP or ADVI |
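
Minibatch ADVI appears in the table but is not shown elsewhere in this reference; a minimal sketch (batch size, data names, and priors are illustrative):

```python
# Draw synchronized random slices of the data at each gradient step
batch_X, batch_y = pm.Minibatch(X_scaled, y, batch_size=128)

with pm.Model() as mb_model:
    beta = pm.Normal('beta', 0, 1, shape=X_scaled.shape[1])
    sigma = pm.HalfNormal('sigma', 1)
    mu = pm.math.dot(batch_X, beta)
    # total_size rescales the minibatch likelihood to the full dataset
    pm.Normal('y_obs', mu=mu, sigma=sigma, observed=batch_y, total_size=len(y))
    approx = pm.fit(n=30000)
```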

### Reparameterization Tricks

**Non-centered parameterization** for hierarchical models:

```python
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```

**QR decomposition** for correlated predictors:

```python
import numpy as np

# QR decomposition of the design matrix
Q, R = np.linalg.qr(X)
R_inv = np.linalg.inv(R)  # R is a fixed NumPy matrix, so invert it once up front
p = X.shape[1]

with pm.Model():
    # Uncorrelated coefficients
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)

    # Transform back to the original scale
    beta = pm.Deterministic('beta', pm.math.dot(R_inv, beta_tilde))

    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)
```

## Advanced Sampling

### Sequential Monte Carlo (SMC)

For complex posteriors or model evidence estimation:

```python
with model:
    idata = pm.sample_smc(draws=2000, chains=4)
```

Good for multimodal posteriors or when NUTS struggles.

### Custom Initialization

Provide starting values:

```python
initvals = {'mu': 0, 'sigma': 1}
with model:
    idata = pm.sample(initvals=initvals)
```

Or use the MAP estimate:

```python
with model:
    start = pm.find_MAP()
    idata = pm.sample(initvals=start)
```

skills/pymc/references/workflows.md

# PyMC Workflows and Common Patterns

This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC.

## Standard Bayesian Workflow

### Complete Workflow Template

```python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

# 1. PREPARE DATA
# ===============
X = ...  # Predictor variables
y = ...  # Observed outcomes

# Standardize predictors for better sampling
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# 2. BUILD MODEL
# ==============
# Define coordinates for named dimensions (must be passed to pm.Model)
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as model:
    # Register predictors as data so they can be swapped out for predictions
    X_data = pm.Data('X', X_scaled, dims=('obs_id', 'predictors'))

    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')

# 3. PRIOR PREDICTIVE CHECK
# ==========================
with model:
    prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('Prior Predictive Check')
plt.show()

# 4. FIT MODEL
# ============
with model:
    # Quick VI exploration (optional)
    approx = pm.fit(n=20000, random_seed=42)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True}  # For model comparison
    )

# 5. CHECK DIAGNOSTICS
# ====================
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# R-hat and ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
    print("WARNING: Some R-hat values > 1.01, chains may not have converged")

if (summary['ess_bulk'] < 400).any():
    print("WARNING: Some ESS values < 400, consider more samples")

# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
print(f"Number of divergences: {divergences}")

# Trace plots
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# 6. POSTERIOR PREDICTIVE CHECK
# ==============================
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize fit
az.plot_ppc(idata, num_pp_samples=100)
plt.title('Posterior Predictive Check')
plt.show()

# 7. ANALYZE RESULTS
# ==================
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# Forest plot for coefficients
az.plot_forest(idata, var_names=['beta'], combined=True)
plt.title('Coefficient Estimates')
plt.show()

# 8. PREDICTIONS FOR NEW DATA
# ============================
X_new = ...  # New predictor values
X_new_scaled = (X_new - X.mean(axis=0)) / X.std(axis=0)

with model:
    # Update the pm.Data container (resize obs_id to match the new data)
    pm.set_data({'X': X_new_scaled}, coords={'obs_id': np.arange(len(X_new_scaled))})

    # Sample predictions
    post_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['y_obs'],
        random_seed=42
    )

# Prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])

# 9. SAVE RESULTS
# ===============
idata.to_netcdf('model_results.nc')  # Save for later
```

## Model Building Patterns

### Linear Regression

```python
with pm.Model() as linear_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```

### Logistic Regression

```python
with pm.Model() as logistic_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor
    logit_p = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```

### Hierarchical/Multilevel Model

```python
with pm.Model(coords={'group': group_names, 'obs': np.arange(n_obs)}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)

    mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)

    # Group-level parameters (non-centered)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='group')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='group')

    beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='group')
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='group')

    # Observation-level model
    mu = alpha[group_idx] + beta[group_idx] * X

    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs, dims='obs')
```

### Poisson Regression (Count Data)

```python
with pm.Model() as poisson_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor on log scale
    log_lambda = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```

### Time Series (Autoregressive)

```python
with pm.Model() as ar_model:
    # Innovation standard deviation
    sigma = pm.HalfNormal('sigma', sigma=1)

    # AR coefficients
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)

    # Initial distribution
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)

    # AR process
    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```

### Mixture Model

```python
with pm.Model() as mixture_model:
    # Component weights
    w = pm.Dirichlet('w', a=np.ones(n_components))

    # Component parameters
    mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components)
    sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)

    # Mixture
    components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i]) for i in range(n_components)]
    y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs)
```

## Data Preparation Best Practices

### Standardization

Standardize continuous predictors for better sampling:

```python
# Standardize
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std

# Model with scaled data
with pm.Model() as model:
    beta_scaled = pm.Normal('beta_scaled', 0, 1)
    # ... rest of model ...

# Transform back to original scale
beta_original = beta_scaled / X_std
alpha_original = alpha - (beta_scaled * X_mean / X_std).sum()
```

### Handling Missing Data

Treat missing values as parameters. PyMC imputes them automatically when the observed array is a NumPy masked array:

```python
# Mask missing entries (NaNs) so PyMC treats them as latent parameters
X_masked = np.ma.masked_invalid(X)

with pm.Model() as model:
    # Each masked entry gets its own latent variable drawn from this distribution
    X_complete = pm.Normal('X_complete', mu=0, sigma=1, observed=X_masked)

    # ... rest of model using X_complete ...
```

### Centering and Scaling

For regression models, center predictors and outcome:

```python
# Center
X_centered = X - X.mean(axis=0)
y_centered = y - y.mean()

with pm.Model() as model:
    # Simpler prior on intercept
    alpha = pm.Normal('alpha', mu=0, sigma=1)  # Intercept near 0 when centered
    beta = pm.Normal('beta', mu=0, sigma=1, shape=n_predictors)

    mu = alpha + pm.math.dot(X_centered, beta)
    sigma = pm.HalfNormal('sigma', sigma=1)

    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_centered)
```

## Prior Selection Guidelines

### Weakly Informative Priors

Use when you have limited prior knowledge:

```python
# For standardized predictors
beta = pm.Normal('beta', mu=0, sigma=1)

# For scale parameters
sigma = pm.HalfNormal('sigma', sigma=1)

# For probabilities
p = pm.Beta('p', alpha=2, beta=2)  # Slight preference for middle values
```

### Informative Priors

Use domain knowledge:

```python
# Effect size from literature: Cohen's d ≈ 0.3
beta = pm.Normal('beta', mu=0.3, sigma=0.1)

# Physical constraint: probability between 0.7-0.9
p = pm.Beta('p', alpha=8, beta=2)  # Check with prior predictive!
```

### Prior Predictive Checks

Always validate priors:

```python
with model:
    prior_pred = pm.sample_prior_predictive(samples=1000)

# Check if predictions are reasonable
y_pp = prior_pred.prior_predictive['y']
print(f"Prior predictive range: {float(y_pp.min()):.2f} to {float(y_pp.max()):.2f}")
print(f"Observed range: {y_obs.min():.2f} to {y_obs.max():.2f}")

# Visualize
az.plot_ppc(prior_pred, group='prior')
```

## Model Comparison Workflow

### Comparing Multiple Models

```python
import arviz as az

# Fit multiple models
models = {}
idatas = {}

# Model 1: Simple linear
with pm.Model() as models['linear']:
    # ... define model ...
    idatas['linear'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 2: With interaction
with pm.Model() as models['interaction']:
    # ... define model ...
    idatas['interaction'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 3: Hierarchical
with pm.Model() as models['hierarchical']:
    # ... define model ...
    idatas['hierarchical'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Compare using LOO
comparison = az.compare(idatas, ic='loo')
print(comparison)

# Visualize comparison
az.plot_compare(comparison)
plt.show()

# Check LOO reliability
for name, idata in idatas.items():
    loo = az.loo(idata, pointwise=True)
    high_pareto_k = (loo.pareto_k > 0.7).sum().item()
    if high_pareto_k > 0:
        print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k")
```

### Model Weights

```python
# Get model weights (pseudo-BMA)
weights = comparison['weight'].values

print("Model probabilities:")
for name, weight in zip(comparison.index, weights):
    print(f"  {name}: {weight:.2%}")

# Model averaging (weighted predictions)
# Iterate over the comparison table so names and weights stay aligned
def weighted_predictions(idatas, comparison):
    preds = []
    for name, weight in zip(comparison.index, comparison['weight'].values):
        pred = idatas[name].posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
        preds.append(weight * pred)
    return sum(preds)

averaged_pred = weighted_predictions(idatas, comparison)
```

## Diagnostics and Troubleshooting

### Diagnosing Sampling Problems

```python
def diagnose_sampling(idata, var_names=None):
    """Comprehensive sampling diagnostics"""

    # Check convergence
    summary = az.summary(idata, var_names=var_names)

    print("=== Convergence Diagnostics ===")
    bad_rhat = summary[summary['r_hat'] > 1.01]
    if len(bad_rhat) > 0:
        print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01")
        print(bad_rhat[['r_hat']])
    else:
        print("✓ All R-hat values < 1.01")

    # Check effective sample size
    print("\n=== Effective Sample Size ===")
    low_ess = summary[summary['ess_bulk'] < 400]
    if len(low_ess) > 0:
        print(f"⚠️ {len(low_ess)} variables with ESS < 400")
        print(low_ess[['ess_bulk', 'ess_tail']])
    else:
        print("✓ All ESS values > 400")

    # Check divergences
    print("\n=== Divergences ===")
    divergences = idata.sample_stats.diverging.sum().item()
    if divergences > 0:
        print(f"⚠️ {divergences} divergent transitions")
        print("   Consider: increase target_accept, reparameterize, or stronger priors")
    else:
        print("✓ No divergences")

    # Check tree depth (compare against the sampler's limit, default 10)
    print("\n=== NUTS Statistics ===")
    tree_depth = idata.sample_stats.tree_depth
    hits_max = (tree_depth >= 10).sum().item()
    if hits_max > 0:
        print(f"⚠️ Hit max treedepth {hits_max} times")
        print("   Consider: reparameterize or increase max_treedepth")
    else:
        print(f"✓ No max treedepth issues (max: {tree_depth.max().item()})")

    return summary

# Usage
diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma'])
```

### Common Fixes

| Problem | Solution |
|---------|----------|
| Divergences | Increase `target_accept=0.95`, use non-centered parameterization |
| Low ESS | Sample more draws, reparameterize to reduce correlation |
| High R-hat | Run longer chains, check for multimodality, improve initialization |
| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity |
| Biased posterior | Check prior predictive, ensure likelihood is correct |

## Using Named Dimensions (dims)

### Benefits of dims

- More readable code
- Easier subsetting and analysis
- Better xarray integration

```python
import pandas as pd

# Define coordinates
coords = {
    'predictors': ['age', 'income', 'education'],
    'groups': ['A', 'B', 'C'],
    'time': pd.date_range('2020-01-01', periods=100, freq='D')
}

with pm.Model(coords=coords) as model:
    # Use dims instead of shape
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
    y = pm.Normal('y', mu=0, sigma=1, dims=['groups', 'time'], observed=data)

    # After sampling, dimensions are preserved
    idata = pm.sample()

# Easy subsetting
beta_age = idata.posterior['beta'].sel(predictors='age')
group_A = idata.posterior['alpha'].sel(groups='A')
```

## Saving and Loading Results

```python
# Save InferenceData
idata.to_netcdf('results.nc')

# Load InferenceData
loaded_idata = az.from_netcdf('results.nc')

# Save model for later predictions
import pickle

with open('model.pkl', 'wb') as f:
    pickle.dump({'model': model, 'idata': idata}, f)

# Load model
with open('model.pkl', 'rb') as f:
    saved = pickle.load(f)
model = saved['model']
idata = saved['idata']
```

skills/pymc/scripts/model_comparison.py

"""
PyMC Model Comparison Script

Utilities for comparing multiple Bayesian models using information criteria
and cross-validation metrics.

Usage:
    from scripts.model_comparison import compare_models, plot_model_comparison

    # Compare multiple models
    comparison = compare_models(
        {'model1': idata1, 'model2': idata2, 'model3': idata3},
        ic='loo'
    )

    # Visualize comparison
    plot_model_comparison(comparison, output_path='model_comparison.png')
"""

import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict


def compare_models(models_dict: Dict[str, az.InferenceData],
                   ic='loo',
                   scale='deviance',
                   verbose=True):
    """
    Compare multiple models using information criteria.

    Parameters
    ----------
    models_dict : dict
        Dictionary mapping model names to InferenceData objects.
        All models must have log_likelihood computed.
    ic : str
        Information criterion to use: 'loo' (default) or 'waic'
    scale : str
        Scale for IC: 'deviance' (default), 'log', or 'negative_log'
    verbose : bool
        Print detailed comparison results (default: True)

    Returns
    -------
    pd.DataFrame
        Comparison DataFrame with model rankings and statistics

    Notes
    -----
    Models must be fit with idata_kwargs={'log_likelihood': True} or
    log-likelihood computed afterwards with pm.compute_log_likelihood().
    """
    if verbose:
        print("="*70)
        print(" " * 25 + f"MODEL COMPARISON ({ic.upper()})")
        print("="*70)

    # Perform comparison
    comparison = az.compare(models_dict, ic=ic, scale=scale)

    if verbose:
        print("\nModel Rankings:")
        print("-"*70)
        print(comparison.to_string())

        print("\n" + "="*70)
        print("INTERPRETATION GUIDE")
        print("="*70)
        print("• rank: Model ranking (0 = best)")
        print(f"• elpd_{ic}: {ic.upper()} estimate (lower is better on the deviance scale)")
        print(f"• p_{ic}: Effective number of parameters")
        print("• elpd_diff: Difference from best model")
        print("• weight: Model weight (stacking or pseudo-BMA)")
        print(f"• se: Standard error of {ic.upper()}")
        print("• dse: Standard error of the difference")
        print("• warning: True if model has reliability issues")
        print(f"• scale: {scale}")

        print("\n" + "="*70)
        print("MODEL SELECTION GUIDELINES")
        print("="*70)

        best_model = comparison.index[0]
        print(f"\n✓ Best model: {best_model}")

        # Check for a clear winner
        if len(comparison) > 1:
            delta = comparison.iloc[1]['elpd_diff']
            delta_se = comparison.iloc[1]['dse']

            if delta > 10:
                print(f"  → STRONG evidence for {best_model} (Δ{ic} > 10)")
            elif delta > 4:
                print(f"  → MODERATE evidence for {best_model} (4 < Δ{ic} < 10)")
            elif delta > 2:
                print(f"  → WEAK evidence for {best_model} (2 < Δ{ic} < 4)")
            else:
                print(f"  → Models are SIMILAR (Δ{ic} < 2)")
                print("     Consider model averaging or choose based on simplicity")

            # Check if difference is significant relative to SE
            if delta > 2 * delta_se:
                print("  → Difference is > 2 SE, likely reliable")
            else:
                print("  → Difference is < 2 SE, uncertain distinction")

        # Check for warnings
        if comparison['warning'].any():
            print("\n⚠️ WARNING: Some models have reliability issues")
            warned_models = comparison[comparison['warning']].index.tolist()
            print(f"   Models with warnings: {', '.join(warned_models)}")
            print("   → Check Pareto-k diagnostics with check_loo_reliability()")

    return comparison


def check_loo_reliability(models_dict: Dict[str, az.InferenceData],
                          threshold=0.7,
                          verbose=True):
    """
    Check LOO-CV reliability using Pareto-k diagnostics.

    Parameters
    ----------
    models_dict : dict
        Dictionary mapping model names to InferenceData objects
    threshold : float
        Pareto-k threshold for flagging observations (default: 0.7)
    verbose : bool
        Print detailed diagnostics (default: True)

    Returns
    -------
    dict
        Dictionary with Pareto-k diagnostics for each model
    """
    if verbose:
        print("="*70)
        print(" " * 20 + "LOO RELIABILITY CHECK")
        print("="*70)

    results = {}

    for name, idata in models_dict.items():
        if verbose:
            print(f"\n{name}:")
            print("-"*70)

        # Compute LOO with pointwise results
        loo_result = az.loo(idata, pointwise=True)
        pareto_k = loo_result.pareto_k.values

        # Count problematic observations
        n_high = (pareto_k > threshold).sum()
        n_very_high = (pareto_k > 1.0).sum()

        results[name] = {
            'pareto_k': pareto_k,
            'n_high': n_high,
            'n_very_high': n_very_high,
            'max_k': pareto_k.max(),
            'loo': loo_result
        }

        if verbose:
            print("Pareto-k diagnostics:")
            print(f"  • Good (k < 0.5): {(pareto_k < 0.5).sum()} observations")
            print(f"  • OK (0.5 ≤ k < 0.7): {((pareto_k >= 0.5) & (pareto_k < 0.7)).sum()} observations")
            print(f"  • Bad (0.7 ≤ k < 1.0): {((pareto_k >= 0.7) & (pareto_k < 1.0)).sum()} observations")
            print(f"  • Very bad (k ≥ 1.0): {(pareto_k >= 1.0).sum()} observations")
            print(f"  • Maximum k: {pareto_k.max():.3f}")

            if n_high > 0:
                print(f"\n⚠️ {n_high} observations with k > {threshold}")
                print("   LOO approximation may be unreliable for these points")
                print("   Solutions:")
                print("   → Use WAIC instead (less sensitive to outliers)")
                print("   → Investigate influential observations")
                print("   → Consider more flexible model")

                if n_very_high > 0:
                    print(f"\n⚠️ {n_very_high} observations with k > 1.0")
                    print("   These points have very high influence")
                    print("   → Strongly consider K-fold CV or other validation")
            else:
                print(f"✓ All Pareto-k values < {threshold}")
                print("  LOO estimates are reliable")

    return results


def plot_model_comparison(comparison, output_path=None, show=True):
    """
    Visualize model comparison results.

    Parameters
    ----------
    comparison : pd.DataFrame
        Comparison DataFrame from az.compare()
    output_path : str, optional
        If provided, save plot to this path
    show : bool
        Whether to display plot (default: True)

    Returns
    -------
    matplotlib.figure.Figure
        The comparison figure
    """
    # Draw onto explicitly created axes so the returned figure holds the plot
    fig, ax = plt.subplots(figsize=(10, 6))
    az.plot_compare(comparison, ax=ax)
    ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Comparison plot saved to {output_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig


def model_averaging(models_dict: Dict[str, az.InferenceData],
                    weights=None,
                    var_name='y_obs',
                    ic='loo'):
    """
    Perform Bayesian model averaging using model weights.

    Parameters
    ----------
    models_dict : dict
        Dictionary mapping model names to InferenceData objects
    weights : array-like, optional
        Model weights. If None, computed from IC (pseudo-BMA weights)
    var_name : str
        Name of the predicted variable (default: 'y_obs')
    ic : str
        Information criterion for computing weights if not provided

    Returns
    -------
    np.ndarray
        Averaged predictions across models
    np.ndarray
        Model weights used
    """
    if weights is None:
        comparison = az.compare(models_dict, ic=ic)
        weights = comparison['weight'].values
        model_names = comparison.index.tolist()
    else:
        model_names = list(models_dict.keys())
        weights = np.array(weights)
        weights = weights / weights.sum()  # Normalize

    print("="*70)
    print(" " * 22 + "BAYESIAN MODEL AVERAGING")
    print("="*70)
    print("\nModel weights:")
    for name, weight in zip(model_names, weights):
        print(f"  {name}: {weight:.4f} ({weight*100:.2f}%)")

    # Extract predictions and average, keeping weights aligned with kept models
    predictions = []
    used_weights = []
    for name, weight in zip(model_names, weights):
        idata = models_dict[name]
        if 'posterior_predictive' in idata:
            pred = idata.posterior_predictive[var_name].values
        else:
            print(f"Warning: {name} missing posterior_predictive, skipping")
            continue
        predictions.append(pred)
        used_weights.append(weight)

    # Weighted average
    averaged = sum(w * p for w, p in zip(used_weights, predictions))

    print("\n✓ Model averaging complete")
    print(f"  Combined predictions using {len(predictions)} models")

    return averaged, weights


def cross_validation_comparison(models_dict: Dict[str, az.InferenceData],
                                k=10,
                                verbose=True):
    """
    Perform k-fold cross-validation comparison (conceptual guide).

    Note: This function provides guidance. Full k-fold CV requires
    re-fitting models k times, which should be done in the main script.

    Parameters
    ----------
    models_dict : dict
        Dictionary of model names to InferenceData
    k : int
        Number of folds (default: 10)
    verbose : bool
        Print guidance

    Returns
    -------
    None
    """
    if verbose:
        print("="*70)
        print(" " * 20 + "K-FOLD CROSS-VALIDATION GUIDE")
        print("="*70)
        print(f"\nTo perform {k}-fold CV:")
        print("""
    1. Split data into k folds
    2. For each fold:
       - Train all models on k-1 folds
       - Compute log-likelihood on held-out fold
    3. Sum log-likelihoods across folds for each model
    4. Compare models using total CV score

    Example code:
    -------------
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    cv_scores = {name: [] for name in models_dict.keys()}

    for train_idx, test_idx in kf.split(X):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        for name in models_dict.keys():
            # Fit model on train set
            with create_model(name, X_train, y_train) as model:
                idata = pm.sample()

            # Compute log-likelihood on test set
            with model:
                pm.set_data({'X': X_test, 'y': y_test})
                idata = pm.compute_log_likelihood(idata, extend_inferencedata=True)
                log_lik = float(idata.log_likelihood['y'].sum())

            cv_scores[name].append(log_lik)

    # Compare total CV scores
    for name, scores in cv_scores.items():
        print(f"{name}: {np.sum(scores):.2f}")
    """)

        print("\nNote: K-fold CV is expensive but most reliable for model comparison")
        print("      Use when LOO has reliability issues (high Pareto-k values)")


# Example usage
if __name__ == '__main__':
    print("This script provides model comparison utilities for PyMC.")
    print("\nExample usage:")
    print("""
    import pymc as pm
    from scripts.model_comparison import compare_models, check_loo_reliability

    # Fit multiple models (must include log_likelihood)
    with pm.Model() as model1:
        # ... define model 1 ...
        idata1 = pm.sample(idata_kwargs={'log_likelihood': True})

    with pm.Model() as model2:
        # ... define model 2 ...
        idata2 = pm.sample(idata_kwargs={'log_likelihood': True})

    # Compare models
    models = {'Simple': idata1, 'Complex': idata2}
    comparison = compare_models(models, ic='loo')

    # Check reliability
    reliability = check_loo_reliability(models)

    # Visualize
    plot_model_comparison(comparison, output_path='comparison.png')

    # Model averaging
    averaged_pred, weights = model_averaging(models, var_name='y_obs')
    """)

skills/pymc/scripts/model_diagnostics.py

"""
PyMC Model Diagnostics Script

Comprehensive diagnostic checks for PyMC models.
Run this after sampling to validate results before interpretation.

Usage:
    from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report

    # Quick check
    check_diagnostics(idata)

    # Full report with plots
    create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/')
"""

import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path


def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01):
    """
    Perform comprehensive diagnostic checks on MCMC samples.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData object from pm.sample()
    var_names : list, optional
        Variables to check. If None, checks all model parameters
    ess_threshold : int
        Minimum acceptable effective sample size (default: 400)
    rhat_threshold : float
        Maximum acceptable R-hat value (default: 1.01)

    Returns
    -------
    dict
        Dictionary with diagnostic results and flags
    """
    print("="*70)
    print(" " * 20 + "MCMC DIAGNOSTICS REPORT")
    print("="*70)

    # Get summary statistics
    summary = az.summary(idata, var_names=var_names)

    results = {
        'summary': summary,
        'has_issues': False,
        'issues': []
    }

    # 1. Check R-hat (convergence)
    print("\n1. CONVERGENCE CHECK (R-hat)")
    print("-" * 70)
    bad_rhat = summary[summary['r_hat'] > rhat_threshold]

    if len(bad_rhat) > 0:
        print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}")
        print("\nTop 10 worst R-hat values:")
        print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10))
        print("\n⚠️ Chains may not have converged!")
        print("   → Run longer chains or check for multimodality")
        results['has_issues'] = True
        results['issues'].append('convergence')
    else:
        print(f"✓ All R-hat values ≤ {rhat_threshold}")
        print("  Chains have converged successfully")

    # 2. Check Effective Sample Size
    print("\n2. EFFECTIVE SAMPLE SIZE (ESS)")
    print("-" * 70)
    low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold]
    low_ess_tail = summary[summary['ess_tail'] < ess_threshold]

    if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0:
        print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}")

        if len(low_ess_bulk) > 0:
            print(f"\n  Bulk ESS issues ({len(low_ess_bulk)} parameters):")
            print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10))

        if len(low_ess_tail) > 0:
            print(f"\n  Tail ESS issues ({len(low_ess_tail)} parameters):")
            print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10))

        print("\n⚠️ High autocorrelation detected!")
        print("   → Sample more draws or reparameterize to reduce correlation")
        results['has_issues'] = True
        results['issues'].append('low_ess')
    else:
        print(f"✓ All ESS values ≥ {ess_threshold}")
        print("  Sufficient effective samples")

    # 3. Check Divergences
    print("\n3. DIVERGENT TRANSITIONS")
    print("-" * 70)
    divergences = idata.sample_stats.diverging.sum().item()

    if divergences > 0:
        total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
        divergence_rate = divergences / total_samples * 100

        print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)")
        print("\n  Divergences indicate biased sampling in difficult posterior regions")
        print("  Solutions:")
        print("  → Increase target_accept (e.g., target_accept=0.95 or 0.99)")
        print("  → Use non-centered parameterization for hierarchical models")
        print("  → Add stronger/more informative priors")
        print("  → Check for model misspecification")
        results['has_issues'] = True
        results['issues'].append('divergences')
        results['n_divergences'] = divergences
    else:
        print("✓ No divergences detected")
        print("  NUTS explored the posterior successfully")

    # 4. Check Tree Depth
    print("\n4. TREE DEPTH")
    print("-" * 70)
    tree_depth = idata.sample_stats.tree_depth
    max_tree_depth = tree_depth.max().item()

    # Typical max_treedepth is 10 (default in PyMC)
    hits_max = (tree_depth >= 10).sum().item()

    if hits_max > 0:
        total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
        hit_rate = hits_max / total_samples * 100

        print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)")
        print("\n  Model may be difficult to explore efficiently")
        print("  Solutions:")
        print("  → Reparameterize model to improve geometry")
        print("  → Increase max_treedepth (if necessary)")
        results['issues'].append('max_treedepth')
    else:
        print("✓ No maximum tree depth issues")
        print(f"  Maximum tree depth reached: {max_tree_depth}")

    # 5. Check Energy (if available)
    if hasattr(idata.sample_stats, 'energy'):
        print("\n5. ENERGY DIAGNOSTICS")
        print("-" * 70)
        print("✓ Energy statistics available")
        print("  Use az.plot_energy(idata) to visualize energy transitions")
        print("  Good overlap between the two distributions indicates healthy HMC sampling")

    # Summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)

    if not results['has_issues']:
        print("✓ All diagnostics passed!")
        print("  Your model has sampled successfully.")
        print("  Proceed with inference and interpretation.")
    else:
        print("⚠️ Some diagnostics failed!")
        print(f"  Issues found: {', '.join(results['issues'])}")
        print("  Review warnings above and consider re-running with adjustments.")

    print("="*70)

    return results


def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False):
    """
    Create comprehensive diagnostic report with plots.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData object from pm.sample()
    var_names : list, optional
        Variables to plot. If None, uses all model parameters
    output_dir : str
        Directory to save diagnostic plots
    show : bool
        Whether to display plots (default: False, just save)

    Returns
    -------
    dict
        Diagnostic results from check_diagnostics
    """
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Run diagnostic checks
    results = check_diagnostics(idata, var_names=var_names)

    print(f"\nGenerating diagnostic plots in '{output_dir}'...")

    # 1. Trace plots (let ArviZ size the axes grid to the variables)
    az.plot_trace(idata, var_names=var_names)
    plt.tight_layout()
    plt.savefig(output_path / 'trace_plots.png', dpi=300, bbox_inches='tight')
    print("  ✓ Saved trace plots")
    if show:
        plt.show()
    else:
        plt.close()

    # 2. Rank plots (check mixing)
    az.plot_rank(idata, var_names=var_names)
    plt.tight_layout()
    plt.savefig(output_path / 'rank_plots.png', dpi=300, bbox_inches='tight')
    print("  ✓ Saved rank plots")
    if show:
        plt.show()
    else:
        plt.close()

    # 3. Autocorrelation plots
    az.plot_autocorr(idata, var_names=var_names, combined=True)
    plt.tight_layout()
    plt.savefig(output_path / 'autocorr_plots.png', dpi=300, bbox_inches='tight')
    print("  ✓ Saved autocorrelation plots")
    if show:
        plt.show()
    else:
        plt.close()

    # 4. Energy plot (if available)
    if hasattr(idata.sample_stats, 'energy'):
        az.plot_energy(idata)
        plt.tight_layout()
        plt.savefig(output_path / 'energy_plot.png', dpi=300, bbox_inches='tight')
        print("  ✓ Saved energy plot")
        if show:
            plt.show()
        else:
            plt.close()

    # 5. ESS plot
    az.plot_ess(idata, var_names=var_names, kind='evolution')
    plt.tight_layout()
    plt.savefig(output_path / 'ess_evolution.png', dpi=300, bbox_inches='tight')
    print("  ✓ Saved ESS evolution plot")
    if show:
        plt.show()
    else:
        plt.close()

    # Save summary to CSV
    results['summary'].to_csv(output_path / 'summary_statistics.csv')
    print("  ✓ Saved summary statistics")

    print(f"\nDiagnostic report complete! Files saved in '{output_dir}'")

    return results


def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None):
    """
    Compare prior and posterior distributions.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData with posterior samples
    prior_idata : arviz.InferenceData
        InferenceData with prior samples
    var_names : list, optional
        Variables to compare
    output_path : str, optional
        If provided, save plot to this path

    Returns
    -------
    None
    """
    if var_names is None:
        var_names = list(idata.posterior.data_vars)[:3]

    fig, axes = plt.subplots(len(var_names), 1, figsize=(10, 8))

    if not isinstance(axes, np.ndarray):
        axes = [axes]

    for idx, var in enumerate(var_names):
        # Plot prior
        az.plot_dist(
            prior_idata.prior[var].values.flatten(),
            label='Prior',
            ax=axes[idx],
            color='blue'
        )

        # Plot posterior
        az.plot_dist(
            idata.posterior[var].values.flatten(),
            label='Posterior',
            ax=axes[idx],
            color='green'
        )

        axes[idx].set_title(f'{var}: Prior vs Posterior')
        axes[idx].legend()

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Prior-posterior comparison saved to {output_path}")
    else:
        plt.show()


# Example usage
if __name__ == '__main__':
    print("This script provides diagnostic functions for PyMC models.")
    print("\nExample usage:")
    print("""
    import pymc as pm
    from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report

    # After sampling
    with pm.Model() as model:
        # ... define model ...
        idata = pm.sample()

    # Quick diagnostic check
    results = check_diagnostics(idata)

    # Full diagnostic report with plots
    create_diagnostic_report(
        idata,
        var_names=['alpha', 'beta', 'sigma'],
        output_dir='my_diagnostics/'
    )
    """)