Initial commit

Zhongwei Li
2025-11-30 08:30:10 +08:00
commit f0bd18fb4e
824 changed files with 331919 additions and 0 deletions

scripts/model_comparison.py

@@ -0,0 +1,387 @@
"""
PyMC Model Comparison Script
Utilities for comparing multiple Bayesian models using information criteria
and cross-validation metrics.
Usage:
from scripts.model_comparison import compare_models, plot_model_comparison
# Compare multiple models
comparison = compare_models(
{'model1': idata1, 'model2': idata2, 'model3': idata3},
ic='loo'
)
# Visualize comparison
plot_model_comparison(comparison, output_path='model_comparison.png')
"""
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict
def compare_models(models_dict: Dict[str, az.InferenceData],
ic='loo',
scale='deviance',
verbose=True):
"""
Compare multiple models using information criteria.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects.
All models must have log_likelihood computed.
ic : str
Information criterion to use: 'loo' (default) or 'waic'
scale : str
Scale for IC: 'deviance' (default), 'log', or 'negative_log'
verbose : bool
Print detailed comparison results (default: True)
Returns
-------
pd.DataFrame
Comparison DataFrame with model rankings and statistics
Notes
-----
Models must be fit with idata_kwargs={'log_likelihood': True} or
log-likelihood computed afterwards with pm.compute_log_likelihood().
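    Examples
    --------
    A minimal sketch (assuming "import pymc as pm"; the model definition,
    data, and the second model's idata_b are placeholders):

        with pm.Model() as model_a:
            # ... priors and likelihood ...
            idata_a = pm.sample(idata_kwargs={'log_likelihood': True})

        comparison = compare_models({'A': idata_a, 'B': idata_b}, ic='loo')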
"""
if verbose:
print("="*70)
print(f" " * 25 + f"MODEL COMPARISON ({ic.upper()})")
print("="*70)
# Perform comparison
comparison = az.compare(models_dict, ic=ic, scale=scale)
if verbose:
print("\nModel Rankings:")
print("-"*70)
print(comparison.to_string())
print("\n" + "="*70)
print("INTERPRETATION GUIDE")
print("="*70)
print(f"• rank: Model ranking (0 = best)")
print(f"{ic}: {ic.upper()} estimate (lower is better)")
print(f"• p_{ic}: Effective number of parameters")
print(f"• d{ic}: Difference from best model")
print(f"• weight: Model probability (pseudo-BMA)")
print(f"• se: Standard error of {ic.upper()}")
print(f"• dse: Standard error of the difference")
print(f"• warning: True if model has reliability issues")
print(f"• scale: {scale}")
print("\n" + "="*70)
print("MODEL SELECTION GUIDELINES")
print("="*70)
best_model = comparison.index[0]
print(f"\n✓ Best model: {best_model}")
# Check for clear winner
if len(comparison) > 1:
            # Column name differs across ArviZ versions ('elpd_diff' in recent releases)
            diff_col = 'elpd_diff' if 'elpd_diff' in comparison.columns else f'd_{ic}'
            delta = comparison.iloc[1][diff_col]
delta_se = comparison.iloc[1]['dse']
if delta > 10:
print(f" → STRONG evidence for {best_model}{ic} > 10)")
elif delta > 4:
print(f" → MODERATE evidence for {best_model} (4 < Δ{ic} < 10)")
elif delta > 2:
print(f" → WEAK evidence for {best_model} (2 < Δ{ic} < 4)")
else:
print(f" → Models are SIMILAR (Δ{ic} < 2)")
print(f" Consider model averaging or choose based on simplicity")
# Check if difference is significant relative to SE
if delta > 2 * delta_se:
print(f" → Difference is > 2 SE, likely reliable")
else:
print(f" → Difference is < 2 SE, uncertain distinction")
# Check for warnings
if comparison['warning'].any():
print("\n⚠️ WARNING: Some models have reliability issues")
warned_models = comparison[comparison['warning']].index.tolist()
print(f" Models with warnings: {', '.join(warned_models)}")
print(f" → Check Pareto-k diagnostics with check_loo_reliability()")
return comparison
def check_loo_reliability(models_dict: Dict[str, az.InferenceData],
threshold=0.7,
verbose=True):
"""
Check LOO-CV reliability using Pareto-k diagnostics.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects
threshold : float
Pareto-k threshold for flagging observations (default: 0.7)
verbose : bool
Print detailed diagnostics (default: True)
Returns
-------
dict
Dictionary with Pareto-k diagnostics for each model
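    Examples
    --------
    A minimal sketch (model names and InferenceData objects are placeholders):

        diag = check_loo_reliability({'A': idata_a, 'B': idata_b}, threshold=0.7)
        print(diag['A']['max_k'])    # largest Pareto-k for model 'A'
        print(diag['A']['n_high'])   # observations above the threshold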
"""
if verbose:
print("="*70)
print(" " * 20 + "LOO RELIABILITY CHECK")
print("="*70)
results = {}
for name, idata in models_dict.items():
if verbose:
print(f"\n{name}:")
print("-"*70)
# Compute LOO with pointwise results
loo_result = az.loo(idata, pointwise=True)
pareto_k = loo_result.pareto_k.values
# Count problematic observations
n_high = (pareto_k > threshold).sum()
n_very_high = (pareto_k > 1.0).sum()
results[name] = {
'pareto_k': pareto_k,
'n_high': n_high,
'n_very_high': n_very_high,
'max_k': pareto_k.max(),
'loo': loo_result
}
if verbose:
print(f"Pareto-k diagnostics:")
print(f" • Good (k < 0.5): {(pareto_k < 0.5).sum()} observations")
print(f" • OK (0.5 ≤ k < 0.7): {((pareto_k >= 0.5) & (pareto_k < 0.7)).sum()} observations")
print(f" • Bad (0.7 ≤ k < 1.0): {((pareto_k >= 0.7) & (pareto_k < 1.0)).sum()} observations")
print(f" • Very bad (k ≥ 1.0): {(pareto_k >= 1.0).sum()} observations")
print(f" • Maximum k: {pareto_k.max():.3f}")
if n_high > 0:
print(f"\n⚠️ {n_high} observations with k > {threshold}")
print(" LOO approximation may be unreliable for these points")
print(" Solutions:")
print(" → Use WAIC instead (less sensitive to outliers)")
print(" → Investigate influential observations")
print(" → Consider more flexible model")
if n_very_high > 0:
print(f"\n⚠️ {n_very_high} observations with k > 1.0")
print(" These points have very high influence")
print(" → Strongly consider K-fold CV or other validation")
else:
print(f"✓ All Pareto-k values < {threshold}")
print(" LOO estimates are reliable")
return results
def plot_model_comparison(comparison, output_path=None, show=True):
"""
Visualize model comparison results.
Parameters
----------
comparison : pd.DataFrame
Comparison DataFrame from az.compare()
output_path : str, optional
If provided, save plot to this path
show : bool
Whether to display plot (default: True)
Returns
-------
matplotlib.figure.Figure
The comparison figure
"""
    # az.plot_compare draws on its own axes; capture the parent figure so the
    # returned object is the figure that was actually drawn
    ax = az.plot_compare(comparison, figsize=(10, 6))
    fig = ax.get_figure()
    ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Comparison plot saved to {output_path}")
if show:
plt.show()
else:
plt.close()
return fig
def model_averaging(models_dict: Dict[str, az.InferenceData],
weights=None,
var_name='y_obs',
ic='loo'):
"""
Perform Bayesian model averaging using model weights.
Parameters
----------
models_dict : dict
Dictionary mapping model names to InferenceData objects
weights : array-like, optional
        Model weights. If None, computed with az.compare (stacking weights by default)
var_name : str
Name of the predicted variable (default: 'y_obs')
ic : str
Information criterion for computing weights if not provided
Returns
-------
np.ndarray
Averaged predictions across models
np.ndarray
Model weights used
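    Examples
    --------
    A minimal sketch (assumes both models store posterior predictive draws
    for a variable named 'y_obs'):

        averaged, w = model_averaging({'A': idata_a, 'B': idata_b}, var_name='y_obs')
        point_pred = averaged.mean(axis=(0, 1))  # average over chains and draws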
"""
if weights is None:
comparison = az.compare(models_dict, ic=ic)
weights = comparison['weight'].values
model_names = comparison.index.tolist()
else:
model_names = list(models_dict.keys())
weights = np.array(weights)
weights = weights / weights.sum() # Normalize
print("="*70)
print(" " * 22 + "BAYESIAN MODEL AVERAGING")
print("="*70)
print("\nModel weights:")
for name, weight in zip(model_names, weights):
print(f" {name}: {weight:.4f} ({weight*100:.2f}%)")
    # Extract predictions, keeping weights aligned with the models that
    # actually provide posterior predictive samples
    predictions = []
    kept_weights = []
    for name, weight in zip(model_names, weights):
        idata = models_dict[name]
        if hasattr(idata, 'posterior_predictive'):
            predictions.append(idata.posterior_predictive[var_name].values)
            kept_weights.append(weight)
        else:
            print(f"Warning: {name} missing posterior_predictive, skipping")
    # Weighted average (renormalize in case any model was skipped)
    kept_weights = np.array(kept_weights)
    kept_weights = kept_weights / kept_weights.sum()
    averaged = sum(w * p for w, p in zip(kept_weights, predictions))
    print(f"\n✓ Model averaging complete")
    print(f"  Combined predictions using {len(predictions)} models")
    return averaged, kept_weights
def cross_validation_comparison(models_dict: Dict[str, az.InferenceData],
k=10,
verbose=True):
"""
Perform k-fold cross-validation comparison (conceptual guide).
Note: This function provides guidance. Full k-fold CV requires
re-fitting models k times, which should be done in the main script.
Parameters
----------
models_dict : dict
Dictionary of model names to InferenceData
k : int
Number of folds (default: 10)
verbose : bool
Print guidance
Returns
-------
None
"""
if verbose:
print("="*70)
print(" " * 20 + "K-FOLD CROSS-VALIDATION GUIDE")
print("="*70)
print(f"\nTo perform {k}-fold CV:")
print("""
1. Split data into k folds
2. For each fold:
- Train all models on k-1 folds
- Compute log-likelihood on held-out fold
3. Sum log-likelihoods across folds for each model
4. Compare models using total CV score
Example code:
-------------
from sklearn.model_selection import KFold
kf = KFold(n_splits=k, shuffle=True, random_state=42)
cv_scores = {name: [] for name in models_dict.keys()}
for train_idx, test_idx in kf.split(X):
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
for name in models_dict.keys():
# Fit the model on the training fold (create_model is a user-supplied model factory)
with create_model(name, X_train, y_train) as model:
idata = pm.sample()
# Compute log-likelihood on test set
with model:
pm.set_data({'X': X_test, 'y': y_test})
log_lik = pm.compute_log_likelihood(idata, extend_inferencedata=False)
cv_scores[name].append(float(log_lik.to_array().sum()))
# Compare total CV scores
for name, scores in cv_scores.items():
print(f"{name}: {np.sum(scores):.2f}")
""")
print("\nNote: K-fold CV is expensive but most reliable for model comparison")
print(" Use when LOO has reliability issues (high Pareto-k values)")
# Example usage
if __name__ == '__main__':
print("This script provides model comparison utilities for PyMC.")
print("\nExample usage:")
print("""
import pymc as pm
from scripts.model_comparison import compare_models, check_loo_reliability
# Fit multiple models (must include log_likelihood)
with pm.Model() as model1:
# ... define model 1 ...
idata1 = pm.sample(idata_kwargs={'log_likelihood': True})
with pm.Model() as model2:
# ... define model 2 ...
idata2 = pm.sample(idata_kwargs={'log_likelihood': True})
# Compare models
models = {'Simple': idata1, 'Complex': idata2}
comparison = compare_models(models, ic='loo')
# Check reliability
reliability = check_loo_reliability(models)
# Visualize
plot_model_comparison(comparison, output_path='comparison.png')
# Model averaging
averaged_pred, weights = model_averaging(models, var_name='y_obs')
""")

scripts/model_diagnostics.py

@@ -0,0 +1,350 @@
"""
PyMC Model Diagnostics Script
Comprehensive diagnostic checks for PyMC models.
Run this after sampling to validate results before interpretation.
Usage:
from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report
# Quick check
check_diagnostics(idata)
# Full report with plots
create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/')
"""
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01):
"""
Perform comprehensive diagnostic checks on MCMC samples.
Parameters
----------
idata : arviz.InferenceData
InferenceData object from pm.sample()
var_names : list, optional
Variables to check. If None, checks all model parameters
ess_threshold : int
Minimum acceptable effective sample size (default: 400)
rhat_threshold : float
Maximum acceptable R-hat value (default: 1.01)
Returns
-------
dict
Dictionary with diagnostic results and flags
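    Examples
    --------
    A minimal sketch (assumes idata was returned by pm.sample()):

        results = check_diagnostics(idata, ess_threshold=400, rhat_threshold=1.01)
        if results['has_issues']:
            print(results['issues'])  # e.g. ['divergences', 'low_ess']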
"""
print("="*70)
print(" " * 20 + "MCMC DIAGNOSTICS REPORT")
print("="*70)
# Get summary statistics
summary = az.summary(idata, var_names=var_names)
results = {
'summary': summary,
'has_issues': False,
'issues': []
}
# 1. Check R-hat (convergence)
print("\n1. CONVERGENCE CHECK (R-hat)")
print("-" * 70)
bad_rhat = summary[summary['r_hat'] > rhat_threshold]
if len(bad_rhat) > 0:
print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}")
print("\nTop 10 worst R-hat values:")
print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10))
print("\n⚠️ Chains may not have converged!")
print(" → Run longer chains or check for multimodality")
results['has_issues'] = True
results['issues'].append('convergence')
else:
print(f"✓ All R-hat values ≤ {rhat_threshold}")
print(" Chains have converged successfully")
# 2. Check Effective Sample Size
print("\n2. EFFECTIVE SAMPLE SIZE (ESS)")
print("-" * 70)
low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold]
low_ess_tail = summary[summary['ess_tail'] < ess_threshold]
if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0:
print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}")
if len(low_ess_bulk) > 0:
print(f"\n Bulk ESS issues ({len(low_ess_bulk)} parameters):")
print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10))
if len(low_ess_tail) > 0:
print(f"\n Tail ESS issues ({len(low_ess_tail)} parameters):")
print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10))
print("\n⚠️ High autocorrelation detected!")
print(" → Sample more draws or reparameterize to reduce correlation")
results['has_issues'] = True
results['issues'].append('low_ess')
else:
print(f"✓ All ESS values ≥ {ess_threshold}")
print(" Sufficient effective samples")
# 3. Check Divergences
print("\n3. DIVERGENT TRANSITIONS")
print("-" * 70)
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
divergence_rate = divergences / total_samples * 100
print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)")
print("\n Divergences indicate biased sampling in difficult posterior regions")
print(" Solutions:")
print(" → Increase target_accept (e.g., target_accept=0.95 or 0.99)")
print(" → Use non-centered parameterization for hierarchical models")
print(" → Add stronger/more informative priors")
print(" → Check for model misspecification")
results['has_issues'] = True
results['issues'].append('divergences')
results['n_divergences'] = divergences
else:
print("✓ No divergences detected")
print(" NUTS explored the posterior successfully")
# 4. Check Tree Depth
print("\n4. TREE DEPTH")
print("-" * 70)
tree_depth = idata.sample_stats.tree_depth
max_tree_depth = tree_depth.max().item()
# Typical max_treedepth is 10 (default in PyMC)
hits_max = (tree_depth >= 10).sum().item()
if hits_max > 0:
total_samples = len(idata.posterior.draw) * len(idata.posterior.chain)
hit_rate = hits_max / total_samples * 100
print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)")
print("\n Model may be difficult to explore efficiently")
print(" Solutions:")
print(" → Reparameterize model to improve geometry")
print(" → Increase max_treedepth (if necessary)")
results['issues'].append('max_treedepth')
else:
print(f"✓ No maximum tree depth issues")
print(f" Maximum tree depth reached: {max_tree_depth}")
# 5. Check Energy (if available)
if hasattr(idata.sample_stats, 'energy'):
print("\n5. ENERGY DIAGNOSTICS")
print("-" * 70)
print("✓ Energy statistics available")
print(" Use az.plot_energy(idata) to visualize energy transitions")
print(" Good separation indicates healthy HMC sampling")
# Summary
print("\n" + "="*70)
print("SUMMARY")
print("="*70)
if not results['has_issues']:
print("✓ All diagnostics passed!")
print(" Your model has sampled successfully.")
print(" Proceed with inference and interpretation.")
else:
print("⚠️ Some diagnostics failed!")
print(f" Issues found: {', '.join(results['issues'])}")
print(" Review warnings above and consider re-running with adjustments.")
print("="*70)
return results
def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False):
"""
Create comprehensive diagnostic report with plots.
Parameters
----------
idata : arviz.InferenceData
InferenceData object from pm.sample()
var_names : list, optional
Variables to plot. If None, uses all model parameters
output_dir : str
Directory to save diagnostic plots
show : bool
Whether to display plots (default: False, just save)
Returns
-------
dict
Diagnostic results from check_diagnostics
"""
# Create output directory
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Run diagnostic checks
results = check_diagnostics(idata, var_names=var_names)
print(f"\nGenerating diagnostic plots in '{output_dir}'...")
    # 1. Trace plots (let ArviZ size the grid to the number of variables)
    az.plot_trace(idata, var_names=var_names, figsize=(12, 10))
plt.tight_layout()
plt.savefig(output_path / 'trace_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved trace plots")
if show:
plt.show()
else:
plt.close()
# 2. Rank plots (check mixing)
    az.plot_rank(idata, var_names=var_names, figsize=(12, 8))
plt.tight_layout()
plt.savefig(output_path / 'rank_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved rank plots")
if show:
plt.show()
else:
plt.close()
# 3. Autocorrelation plots
    az.plot_autocorr(idata, var_names=var_names, combined=True, figsize=(12, 8))
plt.tight_layout()
plt.savefig(output_path / 'autocorr_plots.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved autocorrelation plots")
if show:
plt.show()
else:
plt.close()
# 4. Energy plot (if available)
if hasattr(idata.sample_stats, 'energy'):
        az.plot_energy(idata, figsize=(10, 6))
plt.tight_layout()
plt.savefig(output_path / 'energy_plot.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved energy plot")
if show:
plt.show()
else:
plt.close()
# 5. ESS plot
    az.plot_ess(idata, var_names=var_names, kind='evolution', figsize=(10, 6))
plt.tight_layout()
plt.savefig(output_path / 'ess_evolution.png', dpi=300, bbox_inches='tight')
print(f" ✓ Saved ESS evolution plot")
if show:
plt.show()
else:
plt.close()
# Save summary to CSV
results['summary'].to_csv(output_path / 'summary_statistics.csv')
print(f" ✓ Saved summary statistics")
print(f"\nDiagnostic report complete! Files saved in '{output_dir}'")
return results
def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None):
"""
Compare prior and posterior distributions.
Parameters
----------
idata : arviz.InferenceData
InferenceData with posterior samples
prior_idata : arviz.InferenceData
InferenceData with prior samples
var_names : list, optional
Variables to compare
output_path : str, optional
If provided, save plot to this path
Returns
-------
None
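    Examples
    --------
    A minimal sketch (assumes "import pymc as pm", that the model context is
    still available, and that idata holds the posterior):

        with model:
            prior_idata = pm.sample_prior_predictive()
        compare_prior_posterior(idata, prior_idata, var_names=['alpha', 'beta'])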
"""
fig, axes = plt.subplots(
len(var_names) if var_names else 3,
1,
figsize=(10, 8)
)
if not isinstance(axes, np.ndarray):
axes = [axes]
for idx, var in enumerate(var_names if var_names else list(idata.posterior.data_vars)[:3]):
# Plot prior
az.plot_dist(
prior_idata.prior[var].values.flatten(),
label='Prior',
ax=axes[idx],
color='blue',
            plot_kwargs={'alpha': 0.3}
)
# Plot posterior
az.plot_dist(
idata.posterior[var].values.flatten(),
label='Posterior',
ax=axes[idx],
color='green',
            plot_kwargs={'alpha': 0.3}
)
axes[idx].set_title(f'{var}: Prior vs Posterior')
axes[idx].legend()
plt.tight_layout()
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Prior-posterior comparison saved to {output_path}")
else:
plt.show()
# Example usage
if __name__ == '__main__':
print("This script provides diagnostic functions for PyMC models.")
print("\nExample usage:")
print("""
import pymc as pm
from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report
# After sampling
with pm.Model() as model:
# ... define model ...
idata = pm.sample()
# Quick diagnostic check
results = check_diagnostics(idata)
# Full diagnostic report with plots
create_diagnostic_report(
idata,
var_names=['alpha', 'beta', 'sigma'],
output_dir='my_diagnostics/'
)
""")