""" PyMC Model Diagnostics Script Comprehensive diagnostic checks for PyMC models. Run this after sampling to validate results before interpretation. Usage: from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report # Quick check check_diagnostics(idata) # Full report with plots create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/') """ import arviz as az import numpy as np import matplotlib.pyplot as plt from pathlib import Path def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01): """ Perform comprehensive diagnostic checks on MCMC samples. Parameters ---------- idata : arviz.InferenceData InferenceData object from pm.sample() var_names : list, optional Variables to check. If None, checks all model parameters ess_threshold : int Minimum acceptable effective sample size (default: 400) rhat_threshold : float Maximum acceptable R-hat value (default: 1.01) Returns ------- dict Dictionary with diagnostic results and flags """ print("="*70) print(" " * 20 + "MCMC DIAGNOSTICS REPORT") print("="*70) # Get summary statistics summary = az.summary(idata, var_names=var_names) results = { 'summary': summary, 'has_issues': False, 'issues': [] } # 1. Check R-hat (convergence) print("\n1. CONVERGENCE CHECK (R-hat)") print("-" * 70) bad_rhat = summary[summary['r_hat'] > rhat_threshold] if len(bad_rhat) > 0: print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}") print("\nTop 10 worst R-hat values:") print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10)) print("\n⚠️ Chains may not have converged!") print(" → Run longer chains or check for multimodality") results['has_issues'] = True results['issues'].append('convergence') else: print(f"✓ All R-hat values ≤ {rhat_threshold}") print(" Chains have converged successfully") # 2. Check Effective Sample Size print("\n2. EFFECTIVE SAMPLE SIZE (ESS)") print("-" * 70) low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold] low_ess_tail = summary[summary['ess_tail'] < ess_threshold] if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0: print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}") if len(low_ess_bulk) > 0: print(f"\n Bulk ESS issues ({len(low_ess_bulk)} parameters):") print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10)) if len(low_ess_tail) > 0: print(f"\n Tail ESS issues ({len(low_ess_tail)} parameters):") print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10)) print("\n⚠️ High autocorrelation detected!") print(" → Sample more draws or reparameterize to reduce correlation") results['has_issues'] = True results['issues'].append('low_ess') else: print(f"✓ All ESS values ≥ {ess_threshold}") print(" Sufficient effective samples") # 3. Check Divergences print("\n3. DIVERGENT TRANSITIONS") print("-" * 70) divergences = idata.sample_stats.diverging.sum().item() if divergences > 0: total_samples = len(idata.posterior.draw) * len(idata.posterior.chain) divergence_rate = divergences / total_samples * 100 print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)") print("\n Divergences indicate biased sampling in difficult posterior regions") print(" Solutions:") print(" → Increase target_accept (e.g., target_accept=0.95 or 0.99)") print(" → Use non-centered parameterization for hierarchical models") print(" → Add stronger/more informative priors") print(" → Check for model misspecification") results['has_issues'] = True results['issues'].append('divergences') results['n_divergences'] = divergences else: print("✓ No divergences detected") print(" NUTS explored the posterior successfully") # 4. Check Tree Depth print("\n4. TREE DEPTH") print("-" * 70) tree_depth = idata.sample_stats.tree_depth max_tree_depth = tree_depth.max().item() # Typical max_treedepth is 10 (default in PyMC) hits_max = (tree_depth >= 10).sum().item() if hits_max > 0: total_samples = len(idata.posterior.draw) * len(idata.posterior.chain) hit_rate = hits_max / total_samples * 100 print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)") print("\n Model may be difficult to explore efficiently") print(" Solutions:") print(" → Reparameterize model to improve geometry") print(" → Increase max_treedepth (if necessary)") results['issues'].append('max_treedepth') else: print(f"✓ No maximum tree depth issues") print(f" Maximum tree depth reached: {max_tree_depth}") # 5. Check Energy (if available) if hasattr(idata.sample_stats, 'energy'): print("\n5. ENERGY DIAGNOSTICS") print("-" * 70) print("✓ Energy statistics available") print(" Use az.plot_energy(idata) to visualize energy transitions") print(" Good separation indicates healthy HMC sampling") # Summary print("\n" + "="*70) print("SUMMARY") print("="*70) if not results['has_issues']: print("✓ All diagnostics passed!") print(" Your model has sampled successfully.") print(" Proceed with inference and interpretation.") else: print("⚠️ Some diagnostics failed!") print(f" Issues found: {', '.join(results['issues'])}") print(" Review warnings above and consider re-running with adjustments.") print("="*70) return results def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False): """ Create comprehensive diagnostic report with plots. Parameters ---------- idata : arviz.InferenceData InferenceData object from pm.sample() var_names : list, optional Variables to plot. If None, uses all model parameters output_dir : str Directory to save diagnostic plots show : bool Whether to display plots (default: False, just save) Returns ------- dict Diagnostic results from check_diagnostics """ # Create output directory output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # Run diagnostic checks results = check_diagnostics(idata, var_names=var_names) print(f"\nGenerating diagnostic plots in '{output_dir}'...") # 1. Trace plots fig, axes = plt.subplots( len(var_names) if var_names else 5, 2, figsize=(12, 10) ) az.plot_trace(idata, var_names=var_names, axes=axes) plt.tight_layout() plt.savefig(output_path / 'trace_plots.png', dpi=300, bbox_inches='tight') print(f" ✓ Saved trace plots") if show: plt.show() else: plt.close() # 2. Rank plots (check mixing) fig = plt.figure(figsize=(12, 8)) az.plot_rank(idata, var_names=var_names) plt.tight_layout() plt.savefig(output_path / 'rank_plots.png', dpi=300, bbox_inches='tight') print(f" ✓ Saved rank plots") if show: plt.show() else: plt.close() # 3. Autocorrelation plots fig = plt.figure(figsize=(12, 8)) az.plot_autocorr(idata, var_names=var_names, combined=True) plt.tight_layout() plt.savefig(output_path / 'autocorr_plots.png', dpi=300, bbox_inches='tight') print(f" ✓ Saved autocorrelation plots") if show: plt.show() else: plt.close() # 4. Energy plot (if available) if hasattr(idata.sample_stats, 'energy'): fig = plt.figure(figsize=(10, 6)) az.plot_energy(idata) plt.tight_layout() plt.savefig(output_path / 'energy_plot.png', dpi=300, bbox_inches='tight') print(f" ✓ Saved energy plot") if show: plt.show() else: plt.close() # 5. ESS plot fig = plt.figure(figsize=(10, 6)) az.plot_ess(idata, var_names=var_names, kind='evolution') plt.tight_layout() plt.savefig(output_path / 'ess_evolution.png', dpi=300, bbox_inches='tight') print(f" ✓ Saved ESS evolution plot") if show: plt.show() else: plt.close() # Save summary to CSV results['summary'].to_csv(output_path / 'summary_statistics.csv') print(f" ✓ Saved summary statistics") print(f"\nDiagnostic report complete! Files saved in '{output_dir}'") return results def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None): """ Compare prior and posterior distributions. Parameters ---------- idata : arviz.InferenceData InferenceData with posterior samples prior_idata : arviz.InferenceData InferenceData with prior samples var_names : list, optional Variables to compare output_path : str, optional If provided, save plot to this path Returns ------- None """ fig, axes = plt.subplots( len(var_names) if var_names else 3, 1, figsize=(10, 8) ) if not isinstance(axes, np.ndarray): axes = [axes] for idx, var in enumerate(var_names if var_names else list(idata.posterior.data_vars)[:3]): # Plot prior az.plot_dist( prior_idata.prior[var].values.flatten(), label='Prior', ax=axes[idx], color='blue', alpha=0.3 ) # Plot posterior az.plot_dist( idata.posterior[var].values.flatten(), label='Posterior', ax=axes[idx], color='green', alpha=0.3 ) axes[idx].set_title(f'{var}: Prior vs Posterior') axes[idx].legend() plt.tight_layout() if output_path: plt.savefig(output_path, dpi=300, bbox_inches='tight') print(f"Prior-posterior comparison saved to {output_path}") else: plt.show() # Example usage if __name__ == '__main__': print("This script provides diagnostic functions for PyMC models.") print("\nExample usage:") print(""" import pymc as pm from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report # After sampling with pm.Model() as model: # ... define model ... idata = pm.sample() # Quick diagnostic check results = check_diagnostics(idata) # Full diagnostic report with plots create_diagnostic_report( idata, var_names=['alpha', 'beta', 'sigma'], output_dir='my_diagnostics/' ) """)