Initial commit
This commit is contained in:
350
skills/pymc/scripts/model_diagnostics.py
Normal file
350
skills/pymc/scripts/model_diagnostics.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
PyMC Model Diagnostics Script
|
||||
|
||||
Comprehensive diagnostic checks for PyMC models.
|
||||
Run this after sampling to validate results before interpretation.
|
||||
|
||||
Usage:
|
||||
from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report
|
||||
|
||||
# Quick check
|
||||
check_diagnostics(idata)
|
||||
|
||||
# Full report with plots
|
||||
create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/')
|
||||
"""
|
||||
|
||||
import arviz as az
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01):
    """
    Perform comprehensive diagnostic checks on MCMC samples.

    Runs four checks — R-hat convergence, effective sample size,
    divergent transitions, and tree-depth saturation — printing a
    human-readable report along the way.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData object from pm.sample()
    var_names : list, optional
        Variables to check. If None, checks all model parameters
    ess_threshold : int
        Minimum acceptable effective sample size (default: 400)
    rhat_threshold : float
        Maximum acceptable R-hat value (default: 1.01)

    Returns
    -------
    dict
        Keys: 'summary' (DataFrame from az.summary), 'has_issues' (bool),
        'issues' (list of issue tags), plus 'n_divergences' when
        divergences were found.

    Notes
    -----
    Assumes NUTS sample stats ('diverging', 'tree_depth') are present in
    ``idata.sample_stats`` — TODO confirm for non-NUTS samplers.
    """
    print("=" * 70)
    print(" " * 20 + "MCMC DIAGNOSTICS REPORT")
    print("=" * 70)

    # Summary statistics; provides the r_hat / ess_bulk / ess_tail columns
    # used by the checks below.
    summary = az.summary(idata, var_names=var_names)

    results = {
        'summary': summary,
        'has_issues': False,
        'issues': []
    }

    # Total post-warmup draws across all chains; hoisted here because both
    # the divergence and tree-depth checks express counts as percentages.
    total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)

    # 1. Check R-hat (convergence)
    print("\n1. CONVERGENCE CHECK (R-hat)")
    print("-" * 70)
    bad_rhat = summary[summary['r_hat'] > rhat_threshold]

    if len(bad_rhat) > 0:
        print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}")
        print("\nTop 10 worst R-hat values:")
        print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10))
        print("\n⚠️ Chains may not have converged!")
        print(" → Run longer chains or check for multimodality")
        results['has_issues'] = True
        results['issues'].append('convergence')
    else:
        print(f"✓ All R-hat values ≤ {rhat_threshold}")
        print(" Chains have converged successfully")

    # 2. Check Effective Sample Size
    print("\n2. EFFECTIVE SAMPLE SIZE (ESS)")
    print("-" * 70)
    low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold]
    low_ess_tail = summary[summary['ess_tail'] < ess_threshold]

    if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0:
        print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}")

        if len(low_ess_bulk) > 0:
            print(f"\n Bulk ESS issues ({len(low_ess_bulk)} parameters):")
            print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10))

        if len(low_ess_tail) > 0:
            print(f"\n Tail ESS issues ({len(low_ess_tail)} parameters):")
            print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10))

        print("\n⚠️ High autocorrelation detected!")
        print(" → Sample more draws or reparameterize to reduce correlation")
        results['has_issues'] = True
        results['issues'].append('low_ess')
    else:
        print(f"✓ All ESS values ≥ {ess_threshold}")
        print(" Sufficient effective samples")

    # 3. Check Divergences
    print("\n3. DIVERGENT TRANSITIONS")
    print("-" * 70)
    divergences = idata.sample_stats.diverging.sum().item()

    if divergences > 0:
        divergence_rate = divergences / total_samples * 100

        print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)")
        print("\n Divergences indicate biased sampling in difficult posterior regions")
        print(" Solutions:")
        print(" → Increase target_accept (e.g., target_accept=0.95 or 0.99)")
        print(" → Use non-centered parameterization for hierarchical models")
        print(" → Add stronger/more informative priors")
        print(" → Check for model misspecification")
        results['has_issues'] = True
        results['issues'].append('divergences')
        results['n_divergences'] = divergences
    else:
        print("✓ No divergences detected")
        print(" NUTS explored the posterior successfully")

    # 4. Check Tree Depth
    print("\n4. TREE DEPTH")
    print("-" * 70)
    tree_depth = idata.sample_stats.tree_depth
    max_tree_depth = tree_depth.max().item()

    # Typical max_treedepth is 10 (default in PyMC)
    hits_max = (tree_depth >= 10).sum().item()

    if hits_max > 0:
        hit_rate = hits_max / total_samples * 100

        print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)")
        print("\n Model may be difficult to explore efficiently")
        print(" Solutions:")
        print(" → Reparameterize model to improve geometry")
        print(" → Increase max_treedepth (if necessary)")
        # BUG FIX: previously 'max_treedepth' was appended to the issue list
        # without setting has_issues, so the summary below still reported
        # "All diagnostics passed!" despite a recorded issue.
        results['has_issues'] = True
        results['issues'].append('max_treedepth')
    else:
        # (no f-prefix needed: no placeholders in this message)
        print("✓ No maximum tree depth issues")
        print(f" Maximum tree depth reached: {max_tree_depth}")

    # 5. Check Energy (if available)
    if hasattr(idata.sample_stats, 'energy'):
        print("\n5. ENERGY DIAGNOSTICS")
        print("-" * 70)
        print("✓ Energy statistics available")
        print(" Use az.plot_energy(idata) to visualize energy transitions")
        print(" Good separation indicates healthy HMC sampling")

    # Summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)

    if not results['has_issues']:
        print("✓ All diagnostics passed!")
        print(" Your model has sampled successfully.")
        print(" Proceed with inference and interpretation.")
    else:
        print("⚠️ Some diagnostics failed!")
        print(f" Issues found: {', '.join(results['issues'])}")
        print(" Review warnings above and consider re-running with adjustments.")

    print("="*70)

    return results
|
||||
|
||||
|
||||
def _export_current_figure(path, message, show):
    """Tighten layout, save the current figure to *path*, log *message*, then show or close it."""
    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    print(message)
    if show:
        plt.show()
    else:
        plt.close()


def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False):
    """
    Create comprehensive diagnostic report with plots.

    Runs check_diagnostics, then writes trace, rank, autocorrelation,
    energy (when available) and ESS-evolution plots plus a summary CSV
    into *output_dir*.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData object from pm.sample()
    var_names : list, optional
        Variables to plot. If None, uses all model parameters
    output_dir : str
        Directory to save diagnostic plots
    show : bool
        Whether to display plots (default: False, just save)

    Returns
    -------
    dict
        Diagnostic results from check_diagnostics
    """
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Run diagnostic checks
    results = check_diagnostics(idata, var_names=var_names)

    print(f"\nGenerating diagnostic plots in '{output_dir}'...")

    # 1. Trace plots.
    # BUG FIX: the original pre-allocated a hard-coded (5, 2) axes grid when
    # var_names was None and passed it to az.plot_trace, which breaks whenever
    # the grid shape does not match the number of plotted variables. Let
    # ArviZ size its own figure; plt.savefig then picks up the current figure.
    az.plot_trace(idata, var_names=var_names)
    _export_current_figure(output_path / 'trace_plots.png', " ✓ Saved trace plots", show)

    # 2. Rank plots (check mixing).
    # NOTE: the original wrapped each plot in plt.figure(...), but the ArviZ
    # call creates its own figure, orphaning (and leaking) the empty one and
    # ignoring the requested figsize — so the wrapper is dropped here.
    az.plot_rank(idata, var_names=var_names)
    _export_current_figure(output_path / 'rank_plots.png', " ✓ Saved rank plots", show)

    # 3. Autocorrelation plots
    az.plot_autocorr(idata, var_names=var_names, combined=True)
    _export_current_figure(output_path / 'autocorr_plots.png', " ✓ Saved autocorrelation plots", show)

    # 4. Energy plot (only when the sampler recorded energy statistics)
    if hasattr(idata.sample_stats, 'energy'):
        az.plot_energy(idata)
        _export_current_figure(output_path / 'energy_plot.png', " ✓ Saved energy plot", show)

    # 5. ESS evolution plot (shows how ESS grows with draws)
    az.plot_ess(idata, var_names=var_names, kind='evolution')
    _export_current_figure(output_path / 'ess_evolution.png', " ✓ Saved ESS evolution plot", show)

    # Save summary to CSV
    results['summary'].to_csv(output_path / 'summary_statistics.csv')
    print(" ✓ Saved summary statistics")

    print(f"\nDiagnostic report complete! Files saved in '{output_dir}'")

    return results
|
||||
|
||||
|
||||
def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None):
    """
    Compare prior and posterior distributions.

    Draws one subplot per variable with the prior (blue) and posterior
    (green) densities overlaid.

    Parameters
    ----------
    idata : arviz.InferenceData
        InferenceData with posterior samples
    prior_idata : arviz.InferenceData
        InferenceData with prior samples
    var_names : list, optional
        Variables to compare. If None, the first three posterior variables
        are used.
    output_path : str, optional
        If provided, save plot to this path; otherwise the figure is shown.

    Returns
    -------
    None
    """
    # BUG FIX: resolve the variable list ONCE and size the subplot grid from
    # it. The original hard-coded 3 rows when var_names was None but then
    # iterated list(idata.posterior.data_vars)[:3], which may hold fewer
    # variables — leaving blank axes and a grid/loop mismatch.
    variables = list(var_names) if var_names else list(idata.posterior.data_vars)[:3]

    fig, axes = plt.subplots(len(variables), 1, figsize=(10, 8))

    # plt.subplots returns a bare Axes (not an array) for a single row;
    # normalize so indexing below always works.
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    for idx, var in enumerate(variables):
        # Plot prior
        az.plot_dist(
            prior_idata.prior[var].values.flatten(),
            label='Prior',
            ax=axes[idx],
            color='blue',
            alpha=0.3
        )

        # Plot posterior
        az.plot_dist(
            idata.posterior[var].values.flatten(),
            label='Posterior',
            ax=axes[idx],
            color='green',
            alpha=0.3
        )

        axes[idx].set_title(f'{var}: Prior vs Posterior')
        axes[idx].legend()

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Prior-posterior comparison saved to {output_path}")
    else:
        plt.show()
|
||||
|
||||
|
||||
# Example usage
# NOTE: this module is a library; executing it directly only prints a short
# usage guide — no sampling or plotting happens here.
if __name__ == '__main__':
    print("This script provides diagnostic functions for PyMC models.")
    print("\nExample usage:")
    print("""
    import pymc as pm
    from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report

    # After sampling
    with pm.Model() as model:
        # ... define model ...
        idata = pm.sample()

    # Quick diagnostic check
    results = check_diagnostics(idata)

    # Full diagnostic report with plots
    create_diagnostic_report(
        idata,
        var_names=['alpha', 'beta', 'sigma'],
        output_dir='my_diagnostics/'
    )
    """)
|
||||
Reference in New Issue
Block a user