# PyMC Workflows and Common Patterns

This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC.

## Standard Bayesian Workflow

### Complete Workflow Template

```python
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

# 1. PREPARE DATA
# ===============
X = ...  # Predictor variables
y = ...  # Observed outcomes

# Standardize predictors for better sampling
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# 2. BUILD MODEL
# ==============
# Define coordinates for named dimensions
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y)),
}

with pm.Model(coords=coords) as model:
    # Register the predictors in a data container so they can be
    # swapped out for predictions in step 8
    X_data = pm.Data('X', X_scaled, dims=('obs_id', 'predictors'))

    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood; shape follows mu (and hence the data container),
    # so predictions can use a different number of rows
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, shape=mu.shape)

# 3. PRIOR PREDICTIVE CHECK
# =========================
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000, random_seed=42)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100)
plt.title('Prior Predictive Check')
plt.show()

# 4. FIT MODEL
# ============
with model:
    # Quick VI exploration (optional)
    approx = pm.fit(n=20000, random_seed=42)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True},  # For model comparison
    )

# 5. CHECK DIAGNOSTICS
# ====================
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# R-hat and ESS
summary = az.summary(idata)
if (summary['r_hat'] > 1.01).any():
    print("WARNING: Some R-hat values > 1.01, chains may not have converged")
if (summary['ess_bulk'] < 400).any():
    print("WARNING: Some ESS values < 400, consider more samples")

# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
print(f"Number of divergences: {divergences}")

# Trace plots
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# 6. POSTERIOR PREDICTIVE CHECK
# =============================
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize fit
az.plot_ppc(idata, num_pp_samples=100)
plt.title('Posterior Predictive Check')
plt.show()

# 7. ANALYZE RESULTS
# ==================
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.show()

# Forest plot for coefficients
az.plot_forest(idata, var_names=['beta'], combined=True)
plt.title('Coefficient Estimates')
plt.show()

# 8. PREDICTIONS FOR NEW DATA
# ===========================
X_new = ...  # New predictor values
X_new_scaled = (X_new - X.mean(axis=0)) / X.std(axis=0)

with model:
    # Swap in the new data; update the obs_id coordinate to match its length
    pm.set_data(
        {'X': X_new_scaled},
        coords={'obs_id': np.arange(len(X_new_scaled))},
    )

    # Sample predictions
    post_pred = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_obs'],
        random_seed=42,
    )

# Prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])

# 9. SAVE RESULTS
# ===============
idata.to_netcdf('model_results.nc')  # Save for later
```
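The `X = ...` and `y = ...` placeholders accept any array-like data. As a quick way to exercise the template end to end, here is a minimal synthetic-data sketch; the sample size and "true" parameter values below are arbitrary illustrations, not part of the workflow itself:

```python
import numpy as np

# Synthetic data for trying out the template (values are arbitrary)
rng = np.random.default_rng(42)
n_obs, n_predictors = 200, 3
X = rng.normal(size=(n_obs, n_predictors))  # fills in `X = ...`
true_beta = np.array([0.5, -1.0, 0.25])
y = 1.0 + X @ true_beta + rng.normal(scale=0.3, size=n_obs)  # fills in `y = ...`
```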
## Model Building Patterns

### Linear Regression

```python
with pm.Model() as linear_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```

### Logistic Regression

```python
with pm.Model() as logistic_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor
    logit_p = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```

### Hierarchical/Multilevel Model

```python
with pm.Model(coords={'group': group_names, 'obs': np.arange(n_obs)}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)
    mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)

    # Group-level parameters (non-centered)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='group')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='group')
    beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='group')
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='group')

    # Observation-level model
    mu = alpha[group_idx] + beta[group_idx] * X
    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs, dims='obs')
```

### Poisson Regression (Count Data)

```python
with pm.Model() as poisson_model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    # Linear predictor on log scale
    log_lambda = alpha + pm.math.dot(X, beta)

    # Likelihood
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```

### Time Series (Autoregressive)

```python
with pm.Model() as ar_model:
    # Innovation standard deviation
    sigma = pm.HalfNormal('sigma', sigma=1)

    # AR coefficients
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)

    # Initial distribution
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)

    # AR process
    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```

### Mixture Model

```python
with pm.Model() as mixture_model:
    # Component weights
    w = pm.Dirichlet('w', a=np.ones(n_components))

    # Component parameters
    mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components)
    sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components)

    # Mixture
    components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i]) for i in range(n_components)]
    y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs)
```

## Data Preparation Best Practices

### Standardization

Standardize continuous predictors for better sampling:

```python
# Standardize
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std

# Model with scaled data
with pm.Model() as model:
    alpha = pm.Normal('alpha', 0, 1)
    beta_scaled = pm.Normal('beta_scaled', 0, 1)
    # ... rest of model ...

# Transform back to original scale
beta_original = beta_scaled / X_std
alpha_original = alpha - (beta_scaled * X_mean / X_std).sum()
```
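In practice the back-transformation is applied per posterior draw rather than to the symbolic model variables. A minimal sketch for a single standardized predictor, assuming the model above was sampled into `idata` and that `X_mean` and `X_std` are the scalars computed above:

```python
# Back-transform posterior draws to the original scale (single predictor).
# Assumes `idata` holds samples of 'alpha' and 'beta_scaled'.
post = idata.posterior
beta_original = post['beta_scaled'] / X_std  # slope per unit of raw X
alpha_original = post['alpha'] - post['beta_scaled'] * X_mean / X_std

print(beta_original.mean(dim=['chain', 'draw']).item())
print(alpha_original.mean(dim=['chain', 'draw']).item())
```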
### Handling Missing Data

Treat missing values as parameters:

```python
import pytensor.tensor as pt

# Identify missing values
missing_idx = np.isnan(X)
X_observed = np.where(missing_idx, 0.0, X)  # Placeholder for missing entries

with pm.Model() as model:
    # Prior for missing values
    X_missing = pm.Normal('X_missing', mu=0, sigma=1, shape=int(missing_idx.sum()))

    # Fill the imputed values into the placeholder positions
    X_complete = pt.set_subtensor(
        pt.as_tensor(X_observed)[missing_idx.nonzero()], X_missing
    )

    # ... rest of model using X_complete ...
```

### Centering and Scaling

For regression models, center predictors and outcome:

```python
# Center
X_centered = X - X.mean(axis=0)
y_centered = y - y.mean()

with pm.Model() as model:
    # Simpler prior on intercept
    alpha = pm.Normal('alpha', mu=0, sigma=1)  # Intercept near 0 when centered
    beta = pm.Normal('beta', mu=0, sigma=1, shape=n_predictors)

    mu = alpha + pm.math.dot(X_centered, beta)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_centered)
```

## Prior Selection Guidelines

### Weakly Informative Priors

Use when you have limited prior knowledge:

```python
# For standardized predictors
beta = pm.Normal('beta', mu=0, sigma=1)

# For scale parameters
sigma = pm.HalfNormal('sigma', sigma=1)

# For probabilities
p = pm.Beta('p', alpha=2, beta=2)  # Slight preference for middle values
```

### Informative Priors

Use domain knowledge:

```python
# Effect size from literature: Cohen's d ≈ 0.3
beta = pm.Normal('beta', mu=0.3, sigma=0.1)

# Physical constraint: probability between 0.7-0.9
p = pm.Beta('p', alpha=8, beta=2)  # Check with prior predictive!
```

### Prior Predictive Checks

Always validate priors:

```python
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000)

# Check if predictions are reasonable
prior_y = prior_pred.prior_predictive['y']
print(f"Prior predictive range: {prior_y.min().item():.2f} to {prior_y.max().item():.2f}")
print(f"Observed range: {y_obs.min():.2f} to {y_obs.max():.2f}")

# Visualize
az.plot_ppc(prior_pred, group='prior')
```
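Prior predictive checks can also arbitrate between candidate prior scales. A sketch of that idea, comparing the outcome ranges implied by a tight and a wide prior; `make_model`, the scale values, and the assumption that `x` and `y_obs` already exist are all illustrative:

```python
# Compare prior predictive spread under two candidate prior scales.
def make_model(prior_scale):
    with pm.Model() as m:
        alpha = pm.Normal('alpha', mu=0, sigma=prior_scale)
        beta = pm.Normal('beta', mu=0, sigma=prior_scale)
        sigma = pm.HalfNormal('sigma', sigma=1)
        pm.Normal('y', mu=alpha + beta * x, sigma=sigma, observed=y_obs)
    return m

for scale in (1.0, 10.0):
    with make_model(scale):
        pp = pm.sample_prior_predictive(draws=500, random_seed=42)
    y_draws = pp.prior_predictive['y']
    print(f"prior sigma={scale}: predictions span "
          f"{y_draws.min().item():.1f} to {y_draws.max().item():.1f}")
```

If the wide prior implies absurd outcomes (say, orders of magnitude beyond the observed range), the tighter prior is usually the better default.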
## Model Comparison Workflow

### Comparing Multiple Models

```python
import arviz as az

# Fit multiple models
models = {}
idatas = {}

# Model 1: Simple linear
with pm.Model() as models['linear']:
    # ... define model ...
    idatas['linear'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 2: With interaction
with pm.Model() as models['interaction']:
    # ... define model ...
    idatas['interaction'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Model 3: Hierarchical
with pm.Model() as models['hierarchical']:
    # ... define model ...
    idatas['hierarchical'] = pm.sample(idata_kwargs={'log_likelihood': True})

# Compare using LOO
comparison = az.compare(idatas, ic='loo')
print(comparison)

# Visualize comparison
az.plot_compare(comparison)
plt.show()

# Check LOO reliability
for name, idata in idatas.items():
    loo = az.loo(idata, pointwise=True)
    high_pareto_k = (loo.pareto_k > 0.7).sum().item()
    if high_pareto_k > 0:
        print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k")
```

### Model Weights

```python
# Get model weights (pseudo-BMA)
weights = comparison['weight'].values
print("Model probabilities:")
for name, weight in zip(comparison.index, weights):
    print(f"  {name}: {weight:.2%}")

# Model averaging (weighted predictions).
# az.compare sorts its result by rank, so look models up by name
# rather than zipping weights against the original dict order.
# Assumes each idata already has a posterior_predictive group.
def weighted_predictions(idatas, comparison):
    total = 0
    for name in comparison.index:
        weight = comparison.loc[name, 'weight']
        pred = idatas[name].posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
        total = total + weight * pred
    return total

averaged_pred = weighted_predictions(idatas, comparison)
```

## Diagnostics and Troubleshooting

### Diagnosing Sampling Problems

```python
def diagnose_sampling(idata, var_names=None):
    """Comprehensive sampling diagnostics"""
    # Check convergence
    summary = az.summary(idata, var_names=var_names)

    print("=== Convergence Diagnostics ===")
    bad_rhat = summary[summary['r_hat'] > 1.01]
    if len(bad_rhat) > 0:
        print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01")
        print(bad_rhat[['r_hat']])
    else:
        print("✓ All R-hat values ≤ 1.01")

    # Check effective sample size
    print("\n=== Effective Sample Size ===")
    low_ess = summary[summary['ess_bulk'] < 400]
    if len(low_ess) > 0:
        print(f"⚠️ {len(low_ess)} variables with ESS < 400")
        print(low_ess[['ess_bulk', 'ess_tail']])
    else:
        print("✓ All ESS values ≥ 400")

    # Check divergences
    print("\n=== Divergences ===")
    divergences = idata.sample_stats.diverging.sum().item()
    if divergences > 0:
        print(f"⚠️ {divergences} divergent transitions")
        print("   Consider: increase target_accept, reparameterize, or stronger priors")
    else:
        print("✓ No divergences")

    # Check tree depth against the sampler limit
    # (NUTS default max_treedepth is 10; adjust if you changed it)
    print("\n=== NUTS Statistics ===")
    tree_depth = idata.sample_stats.tree_depth
    hits_max = (tree_depth >= 10).sum().item()
    if hits_max > 0:
        print(f"⚠️ Hit max treedepth {hits_max} times")
        print("   Consider: reparameterize or increase max_treedepth")
    else:
        print(f"✓ No max treedepth issues (max observed: {tree_depth.max().item()})")

    return summary

# Usage
diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma'])
```

### Common Fixes

| Problem | Solution |
|---------|----------|
| Divergences | Increase `target_accept=0.95`, use non-centered parameterization |
| Low ESS | Sample more draws, reparameterize to reduce correlation |
| High R-hat | Run longer chains, check for multimodality, improve initialization |
| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity |
| Biased posterior | Check prior predictive, ensure likelihood is correct |
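The "non-centered parameterization" fix in the table amounts to sampling a standardized offset and shifting and scaling it deterministically, as in the hierarchical pattern earlier. A minimal before/after sketch (variable names and the group count of 8 are illustrative):

```python
# Centered form: prone to divergences when sigma_group is small
# (the classic "funnel" geometry)
with pm.Model() as centered:
    mu_group = pm.Normal('mu_group', mu=0, sigma=1)
    sigma_group = pm.HalfNormal('sigma_group', sigma=1)
    theta = pm.Normal('theta', mu=mu_group, sigma=sigma_group, shape=8)

# Non-centered form: sample a standard-normal offset, then rescale,
# which decorrelates theta from sigma_group
with pm.Model() as non_centered:
    mu_group = pm.Normal('mu_group', mu=0, sigma=1)
    sigma_group = pm.HalfNormal('sigma_group', sigma=1)
    theta_offset = pm.Normal('theta_offset', mu=0, sigma=1, shape=8)
    theta = pm.Deterministic('theta', mu_group + sigma_group * theta_offset)
```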
## Using Named Dimensions (dims)

### Benefits of dims

- More readable code
- Easier subsetting and analysis
- Better xarray integration

```python
import pandas as pd

# Define coordinates
coords = {
    'predictors': ['age', 'income', 'education'],
    'groups': ['A', 'B', 'C'],
    'time': pd.date_range('2020-01-01', periods=100, freq='D'),
}

with pm.Model(coords=coords) as model:
    # Use dims instead of shape
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
    y = pm.Normal('y', mu=0, sigma=1, dims=['groups', 'time'], observed=data)

    # After sampling, dimensions are preserved
    idata = pm.sample()

# Easy subsetting
beta_age = idata.posterior['beta'].sel(predictors='age')
group_A = idata.posterior['alpha'].sel(groups='A')
```

## Saving and Loading Results

```python
# Save InferenceData
idata.to_netcdf('results.nc')

# Load InferenceData
loaded_idata = az.from_netcdf('results.nc')

# Save model for later predictions
# (complex models may not pickle with the standard library; see note below)
import pickle

with open('model.pkl', 'wb') as f:
    pickle.dump({'model': model, 'idata': idata}, f)

# Load model
with open('model.pkl', 'rb') as f:
    saved = pickle.load(f)
model = saved['model']
idata = saved['idata']
```
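If the standard-library `pickle` fails on a model (for example, because it closes over locally defined functions), `cloudpickle` is a common fallback. A drop-in sketch, assuming the `cloudpickle` package is installed:

```python
# cloudpickle serializes some objects the standard pickle cannot
import cloudpickle

with open('model.pkl', 'wb') as f:
    cloudpickle.dump({'model': model, 'idata': idata}, f)

with open('model.pkl', 'rb') as f:
    saved = cloudpickle.load(f)
```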