Initial commit

New file: skills/shap/references/explainers.md (339 lines)
|
||||
# SHAP Explainers Reference
|
||||
|
||||
This document provides comprehensive information about all SHAP explainer classes, their parameters, methods, and when to use each type.
|
||||
|
||||
## Overview
|
||||
|
||||
SHAP provides specialized explainers for different model types, each optimized for specific architectures. The general `shap.Explainer` class automatically selects the appropriate algorithm based on the model type.
|
||||
|
||||
## Core Explainer Classes
|
||||
|
||||
### shap.Explainer (Auto-selector)
|
||||
|
||||
**Purpose**: Automatically uses Shapley values to explain any machine learning model or Python function by selecting the most appropriate explainer algorithm.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: The model to explain (function or model object)
|
||||
- `masker`: Background data or masker object for feature manipulation
|
||||
- `algorithm`: Optional override to force specific explainer type
|
||||
- `output_names`: Names for model outputs
|
||||
- `feature_names`: Names for input features
|
||||
|
||||
**When to Use**: Default choice when unsure which explainer to use; automatically selects the best algorithm based on model type.
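
A minimal usage sketch, assuming a trained `model` and data splits `X_train`/`X_test` already exist:

```python
import shap

# Let SHAP pick the algorithm based on the model type
explainer = shap.Explainer(model, X_train)

# Calling the explainer returns an Explanation object
shap_values = explainer(X_test)
```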
|
||||
|
||||
### TreeExplainer
|
||||
|
||||
**Purpose**: Fast and exact SHAP value computation for tree-based ensemble models using the Tree SHAP algorithm.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Tree-based model (XGBoost, LightGBM, CatBoost, PySpark, or scikit-learn trees)
|
||||
- `data`: Background dataset for feature integration (optional with tree_path_dependent)
|
||||
- `feature_perturbation`: How to handle dependent features
|
||||
- `"interventional"`: Requires background data; follows causal inference rules
|
||||
- `"tree_path_dependent"`: No background data needed; uses training examples per leaf
|
||||
- `"auto"`: Defaults to interventional if data provided, otherwise tree_path_dependent
|
||||
- `model_output`: What model output to explain
|
||||
- `"raw"`: Standard model output (default)
|
||||
- `"probability"`: Probability-transformed output
|
||||
- `"log_loss"`: Natural log of loss function
|
||||
- Custom method names like `"predict_proba"`
|
||||
- `feature_names`: Optional feature naming
|
||||
|
||||
**Supported Models**:
|
||||
- XGBoost (xgboost.XGBClassifier, xgboost.XGBRegressor, xgboost.Booster)
|
||||
- LightGBM (lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, lightgbm.Booster)
|
||||
- CatBoost (catboost.CatBoostClassifier, catboost.CatBoostRegressor)
|
||||
- PySpark MLlib tree models
|
||||
- scikit-learn (DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor, GradientBoostingClassifier, GradientBoostingRegressor)
|
||||
|
||||
**Key Methods**:
|
||||
- `shap_values(X)`: Computes SHAP values for samples; returns arrays where each row represents feature attribution
|
||||
- `shap_interaction_values(X)`: Estimates interaction effects between feature pairs; provides matrices with main effects and pairwise interactions
|
||||
- `explain_row(row)`: Explains individual rows with detailed attribution information
|
||||
|
||||
**When to Use**:
|
||||
- Primary choice for all tree-based models
|
||||
- When exact SHAP values are needed (not approximations)
|
||||
- When computational speed is important for large datasets
|
||||
- For models like random forests, gradient boosting, or XGBoost
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
import xgboost
|
||||
|
||||
# Train model
|
||||
model = xgboost.XGBClassifier().fit(X_train, y_train)
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.TreeExplainer(model)
|
||||
|
||||
# Compute SHAP values
|
||||
shap_values = explainer.shap_values(X_test)
|
||||
|
||||
# Compute interaction values
|
||||
shap_interaction = explainer.shap_interaction_values(X_test)
|
||||
```
|
||||
|
||||
### DeepExplainer
|
||||
|
||||
**Purpose**: Approximates SHAP values for deep learning models using an enhanced version of the DeepLIFT algorithm.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Framework-dependent specification
|
||||
- **TensorFlow**: Tuple of (input_tensor, output_tensor) where output is single-dimensional
|
||||
- **PyTorch**: `nn.Module` object or tuple of `(model, layer)` for layer-specific explanations
|
||||
- `data`: Background dataset for feature integration
|
||||
- **TensorFlow**: numpy arrays or pandas DataFrames
|
||||
- **PyTorch**: torch tensors
|
||||
- **Recommended size**: 100-1000 samples (not full training set) to balance accuracy and computational cost
|
||||
- `session` (TensorFlow only): Optional session object; auto-detected if None
|
||||
- `learning_phase_flags`: Custom learning phase tensors for handling batch norm/dropout during inference
|
||||
|
||||
**Supported Frameworks**:
|
||||
- **TensorFlow**: Full support including Keras models
|
||||
- **PyTorch**: Complete integration with nn.Module architecture
|
||||
|
||||
**Key Methods**:
|
||||
- `shap_values(X)`: Returns approximate SHAP values for the model applied to data X
|
||||
- `explain_row(row)`: Explains single rows with attribution values and expected outputs
|
||||
- `save(file)` / `load(file)`: Serialization support for explainer objects
|
||||
- `supports_model_with_masker(model, masker)`: Compatibility checker for model types
|
||||
|
||||
**When to Use**:
|
||||
- For deep neural networks in TensorFlow or PyTorch
|
||||
- When working with convolutional neural networks (CNNs)
|
||||
- For recurrent neural networks (RNNs) and transformers
|
||||
- When model-specific explanation is needed for deep learning architectures
|
||||
|
||||
**Key Design Feature**:
|
||||
Variance of expectation estimates scales approximately as 1/√N, where N is the number of background samples, enabling accuracy-efficiency trade-offs.
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
import tensorflow as tf
|
||||
|
||||
# Assume model is a Keras model
|
||||
model = tf.keras.models.load_model('my_model.h5')
|
||||
|
||||
# Select background samples (subset of training data)
|
||||
background = X_train[:100]
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.DeepExplainer(model, background)
|
||||
|
||||
# Compute SHAP values
|
||||
shap_values = explainer.shap_values(X_test[:10])
|
||||
```
|
||||
|
||||
### KernelExplainer
|
||||
|
||||
**Purpose**: Model-agnostic SHAP value computation using the Kernel SHAP method with weighted linear regression.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Function or model object that takes a matrix of samples and returns model outputs
|
||||
- `data`: Background dataset (numpy array, pandas DataFrame, or sparse matrix) used to simulate missing features
|
||||
- `feature_names`: Optional list of feature names; automatically derived from DataFrame column names if available
|
||||
- `link`: Connection function between feature importance and model output
|
||||
- `"identity"`: Direct relationship (default)
|
||||
- `"logit"`: For probability outputs
|
||||
|
||||
**Key Methods**:
|
||||
- `shap_values(X, **kwargs)`: Calculates SHAP values for sample predictions
|
||||
- `nsamples`: Evaluation count per prediction ("auto" or integer); higher values reduce variance
|
||||
- `l1_reg`: Feature selection regularization ("num_features(int)", "aic", "bic", or float)
|
||||
- Returns arrays where each row sums to the difference between model output and expected value
|
||||
- `explain_row(row)`: Explains individual predictions with attribution values and expected values
|
||||
- `save(file)` / `load(file)`: Persist and restore explainer objects
|
||||
|
||||
**When to Use**:
|
||||
- For black-box models where specialized explainers aren't available
|
||||
- When working with custom prediction functions
|
||||
- For any model type (neural networks, SVMs, ensemble methods, etc.)
|
||||
- When model-agnostic explanations are needed
|
||||
- **Note**: Slower than specialized explainers; use only when no specialized option exists
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
from sklearn.svm import SVC
|
||||
|
||||
# Train model
|
||||
model = SVC(probability=True).fit(X_train, y_train)
|
||||
|
||||
# Create prediction function
|
||||
predict_fn = lambda x: model.predict_proba(x)[:, 1]
|
||||
|
||||
# Select background samples
|
||||
background = shap.sample(X_train, 100)
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.KernelExplainer(predict_fn, background)
|
||||
|
||||
# Compute SHAP values (may be slow)
|
||||
shap_values = explainer.shap_values(X_test[:10])
|
||||
```
|
||||
|
||||
### LinearExplainer
|
||||
|
||||
**Purpose**: Specialized explainer for linear models that accounts for feature correlations.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Linear model or tuple of (coefficients, intercept)
|
||||
- `masker`: Background data for feature correlation
|
||||
- `feature_perturbation`: How to handle feature correlations
|
||||
- `"interventional"`: Assumes feature independence
|
||||
- `"correlation_dependent"`: Accounts for feature correlations
|
||||
|
||||
**Supported Models**:
|
||||
- scikit-learn linear models (LinearRegression, LogisticRegression, Ridge, Lasso, ElasticNet)
|
||||
- Custom linear models with coefficients and intercept
|
||||
|
||||
**When to Use**:
|
||||
- For linear regression and logistic regression models
|
||||
- When feature correlations are important to explanation accuracy
|
||||
- When extremely fast explanations are needed
|
||||
- For GLMs and other linear model types
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
# Train model
|
||||
model = LogisticRegression().fit(X_train, y_train)
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.LinearExplainer(model, X_train)
|
||||
|
||||
# Compute SHAP values
|
||||
shap_values = explainer.shap_values(X_test)
|
||||
```
|
||||
|
||||
### GradientExplainer
|
||||
|
||||
**Purpose**: Uses expected gradients to approximate SHAP values for neural networks.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Deep learning model (TensorFlow or PyTorch)
|
||||
- `data`: Background samples for integration
|
||||
- `batch_size`: Batch size for gradient computation
|
||||
- `local_smoothing`: Amount of noise to add for smoothing (default 0)
|
||||
|
||||
**When to Use**:
|
||||
- As an alternative to DeepExplainer for neural networks
|
||||
- When gradient-based explanations are preferred
|
||||
- For differentiable models where gradient information is available
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
import torch
|
||||
|
||||
# Assume model is a PyTorch model
|
||||
model = torch.load('model.pt')
|
||||
|
||||
# Select background samples
|
||||
background = X_train[:100]
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.GradientExplainer(model, background)
|
||||
|
||||
# Compute SHAP values
|
||||
shap_values = explainer.shap_values(X_test[:10])
|
||||
```
|
||||
|
||||
### PermutationExplainer
|
||||
|
||||
**Purpose**: Approximates Shapley values by iterating through permutations of inputs.
|
||||
|
||||
**Constructor Parameters**:
|
||||
- `model`: Prediction function
|
||||
- `masker`: Background data or masker object
|
||||
- `max_evals`: Maximum number of model evaluations per sample
|
||||
|
||||
**When to Use**:
|
||||
- When exact Shapley values are needed but specialized explainers aren't available
|
||||
- For small feature sets where permutation is tractable
|
||||
- As a more accurate alternative to KernelExplainer (but slower)
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
|
||||
# Create explainer
|
||||
explainer = shap.PermutationExplainer(model.predict, X_train)
|
||||
|
||||
# Compute SHAP values
|
||||
shap_values = explainer.shap_values(X_test[:10])
|
||||
```
|
||||
|
||||
## Explainer Selection Guide
|
||||
|
||||
**Decision Tree for Choosing an Explainer**:
|
||||
|
||||
1. **Is your model tree-based?** (XGBoost, LightGBM, CatBoost, Random Forest, etc.)
|
||||
- Yes → Use `TreeExplainer` (fast and exact)
|
||||
- No → Continue to step 2
|
||||
|
||||
2. **Is your model a deep neural network?** (TensorFlow, PyTorch, Keras)
|
||||
- Yes → Use `DeepExplainer` or `GradientExplainer`
|
||||
- No → Continue to step 3
|
||||
|
||||
3. **Is your model linear?** (Linear/Logistic Regression, GLMs)
|
||||
- Yes → Use `LinearExplainer` (extremely fast)
|
||||
- No → Continue to step 4
|
||||
|
||||
4. **Do you need model-agnostic explanations?**
|
||||
- Yes → Use `KernelExplainer` (slower but works with any model)
|
||||
- If computational budget allows and high accuracy is needed → Use `PermutationExplainer`
|
||||
|
||||
5. **Unsure or want automatic selection?**
|
||||
- Use `shap.Explainer` (auto-selects best algorithm)
|
||||
|
||||
## Common Parameters Across Explainers
|
||||
|
||||
**Background Data / Masker**:
|
||||
- Purpose: Represents the "typical" input to establish baseline expectations
|
||||
- Size recommendations: 50-1000 samples (more for complex models)
|
||||
- Selection: Random sample from training data or kmeans-selected representatives
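
For example, a background set can be built with either of the following (a sketch assuming a training matrix `X_train`):

```python
import shap

# Random subsample of the training data
background = shap.sample(X_train, 100)

# Or a k-means summary of the training data (weighted cluster centers)
background = shap.kmeans(X_train, 50)
```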
|
||||
|
||||
**Feature Names**:
|
||||
- Automatically extracted from pandas DataFrames
|
||||
- Can be manually specified for numpy arrays
|
||||
- Important for plot interpretability
|
||||
|
||||
**Model Output Specification**:
|
||||
- Raw model output vs. transformed output (probabilities, log-odds)
|
||||
- Critical for correct interpretation of SHAP values
|
||||
- Example: For XGBoost classifiers, SHAP explains margin output (log-odds) before logistic transformation
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
**Speed Ranking** (fastest to slowest):
|
||||
1. `LinearExplainer` - Nearly instantaneous
|
||||
2. `TreeExplainer` - Very fast, scales well
|
||||
3. `DeepExplainer` - Fast for neural networks
|
||||
4. `GradientExplainer` - Fast for neural networks
|
||||
5. `KernelExplainer` - Slow, use only when necessary
|
||||
6. `PermutationExplainer` - Very slow but most accurate for small feature sets
|
||||
|
||||
**Memory Considerations**:
|
||||
- `TreeExplainer`: Low memory overhead
|
||||
- `DeepExplainer`: Memory proportional to background sample size
|
||||
- `KernelExplainer`: Can be memory-intensive for large background datasets
|
||||
- For large datasets: Use batching or sample subsets
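
A minimal batching sketch, assuming `explainer` is already constructed and `X_test` is a numpy array:

```python
import numpy as np

batch_size = 1000
batches = [
    explainer(X_test[start:start + batch_size]).values
    for start in range(0, len(X_test), batch_size)
]
all_shap_values = np.concatenate(batches, axis=0)
```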
|
||||
|
||||
## Explainer Output: The Explanation Object
|
||||
|
||||
All explainers return `shap.Explanation` objects containing:
|
||||
- `values`: SHAP values (numpy array)
|
||||
- `base_values`: Expected model output (baseline)
|
||||
- `data`: Original feature values
|
||||
- `feature_names`: Names of features
|
||||
|
||||
The Explanation object supports:
|
||||
- Slicing: `explanation[0]` for first sample
|
||||
- Array operations: Compatible with numpy operations
|
||||
- Direct plotting: Can be passed to plot functions
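
A short sketch of working with an Explanation object; `shap_values` is assumed to come from an explainer, and the feature name "Age" is only illustrative:

```python
import shap

first = shap_values[0]               # explanation for the first sample
values = shap_values.values          # raw SHAP values as a numpy array
baseline = shap_values.base_values   # expected model output (baseline)
age_col = shap_values[:, "Age"]      # slice one feature by name, if names are set

shap.plots.waterfall(first)          # slices can be passed straight to plots
```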
|
||||
New file: skills/shap/references/plots.md (507 lines)
|
||||
# SHAP Visualization Reference
|
||||
|
||||
This document provides comprehensive information about all SHAP plotting functions, their parameters, use cases, and best practices for visualizing model explanations.
|
||||
|
||||
## Overview
|
||||
|
||||
SHAP provides diverse visualization tools for explaining model predictions at both individual and global levels. Each plot type serves specific purposes in understanding feature importance, interactions, and prediction mechanisms.
|
||||
|
||||
## Plot Types
|
||||
|
||||
### Waterfall Plots
|
||||
|
||||
**Purpose**: Display explanations for individual predictions, showing how each feature moves the prediction from the baseline (expected value) toward the final prediction.
|
||||
|
||||
**Function**: `shap.plots.waterfall(explanation, max_display=10, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `explanation`: Single row from an Explanation object (not multiple samples)
|
||||
- `max_display`: Number of features to show (default: 10); less impactful features collapse into a single "other features" term
|
||||
- `show`: Whether to display the plot immediately
|
||||
|
||||
**Visual Elements**:
|
||||
- **X-axis**: Shows SHAP values (contribution to prediction)
|
||||
- **Starting point**: Model's expected value (baseline)
|
||||
- **Feature contributions**: Red bars (positive) or blue bars (negative) showing how each feature moves the prediction
|
||||
- **Feature values**: Displayed in gray to the left of feature names
|
||||
- **Ending point**: Final model prediction
|
||||
|
||||
**When to Use**:
|
||||
- Explaining individual predictions in detail
|
||||
- Understanding which features drove a specific decision
|
||||
- Communicating model behavior for single instances (e.g., loan denial, diagnosis)
|
||||
- Debugging unexpected predictions
|
||||
|
||||
**Important Notes**:
|
||||
- For XGBoost classifiers, predictions are explained in log-odds units (margin output before logistic transformation)
|
||||
- SHAP values sum to the difference between baseline and final prediction (additivity property)
|
||||
- Use scatter plots alongside waterfall plots to explore patterns across multiple samples
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import shap
|
||||
|
||||
# Compute SHAP values
|
||||
explainer = shap.TreeExplainer(model)
|
||||
shap_values = explainer(X_test)
|
||||
|
||||
# Plot waterfall for first prediction
|
||||
shap.plots.waterfall(shap_values[0])
|
||||
|
||||
# Show more features
|
||||
shap.plots.waterfall(shap_values[0], max_display=20)
|
||||
```
|
||||
|
||||
### Beeswarm Plots
|
||||
|
||||
**Purpose**: Information-dense summary of how top features impact model output across the entire dataset, combining feature importance with value distributions.
|
||||
|
||||
**Function**: `shap.plots.beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), color=None, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `shap_values`: Explanation object containing multiple samples
|
||||
- `max_display`: Number of features to display (default: 10)
|
||||
- `order`: How to rank features
|
||||
- `Explanation.abs.mean(0)`: Mean absolute SHAP values (default)
|
||||
- `Explanation.abs.max(0)`: Maximum absolute values (highlights outlier impacts)
|
||||
- `color`: matplotlib colormap; defaults to red-blue scheme
|
||||
- `show`: Whether to display the plot immediately
|
||||
|
||||
**Visual Elements**:
|
||||
- **Y-axis**: Features ranked by importance
|
||||
- **X-axis**: SHAP value (impact on model output)
|
||||
- **Each dot**: Single instance from dataset
|
||||
- **Dot position (X)**: SHAP value magnitude
|
||||
- **Dot color**: Original feature value (red = high, blue = low)
|
||||
- **Dot clustering**: Shows density/distribution of impacts
|
||||
|
||||
**When to Use**:
|
||||
- Summarizing feature importance across entire datasets
|
||||
- Understanding both average and individual feature impacts
|
||||
- Identifying feature value patterns and their effects
|
||||
- Comparing global model behavior across features
|
||||
- Detecting nonlinear relationships (e.g., higher age → lower income likelihood)
|
||||
|
||||
**Practical Variations**:
|
||||
```python
|
||||
# Standard beeswarm plot
|
||||
shap.plots.beeswarm(shap_values)
|
||||
|
||||
# Show more features
|
||||
shap.plots.beeswarm(shap_values, max_display=20)
|
||||
|
||||
# Order by maximum absolute values (highlight outliers)
|
||||
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))
|
||||
|
||||
# Plot absolute SHAP values with fixed coloring
|
||||
shap.plots.beeswarm(shap_values.abs, color="shap_red")
|
||||
|
||||
# Custom matplotlib colormap
|
||||
shap.plots.beeswarm(shap_values, color=plt.cm.viridis)
|
||||
```
|
||||
|
||||
### Bar Plots
|
||||
|
||||
**Purpose**: Display feature importance as mean absolute SHAP values, providing clean, simple visualizations of global feature impact.
|
||||
|
||||
**Function**: `shap.plots.bar(shap_values, max_display=10, clustering=None, clustering_cutoff=0.5, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `shap_values`: Explanation object (can be single instance, global, or cohorts)
|
||||
- `max_display`: Maximum number of features/bars to show
|
||||
- `clustering`: Optional hierarchical clustering object from `shap.utils.hclust`
|
||||
- `clustering_cutoff`: Threshold for displaying clustering structure (0-1, default: 0.5)
|
||||
|
||||
**Plot Types**:
|
||||
|
||||
#### Global Bar Plot
|
||||
Shows overall feature importance across all samples. Importance calculated as mean absolute SHAP value.
|
||||
|
||||
```python
|
||||
# Global feature importance
|
||||
explainer = shap.TreeExplainer(model)
|
||||
shap_values = explainer(X_test)
|
||||
shap.plots.bar(shap_values)
|
||||
```
|
||||
|
||||
#### Local Bar Plot
|
||||
Displays SHAP values for a single instance with feature values shown in gray.
|
||||
|
||||
```python
|
||||
# Single prediction explanation
|
||||
shap.plots.bar(shap_values[0])
|
||||
```
|
||||
|
||||
#### Cohort Bar Plot
|
||||
Compares feature importance across subgroups by passing a dictionary of Explanation objects.
|
||||
|
||||
```python
|
||||
# Compare cohorts
|
||||
cohorts = {
|
||||
"Group A": shap_values[mask_A],
|
||||
"Group B": shap_values[mask_B]
|
||||
}
|
||||
shap.plots.bar(cohorts)
|
||||
```
|
||||
|
||||
**Feature Clustering**:
|
||||
Identifies redundant features using model-based clustering (more accurate than correlation-based methods).
|
||||
|
||||
```python
|
||||
# Add feature clustering
|
||||
clustering = shap.utils.hclust(X_train, y_train)
|
||||
shap.plots.bar(shap_values, clustering=clustering)
|
||||
|
||||
# Adjust clustering display threshold
|
||||
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.3)
|
||||
```
|
||||
|
||||
**When to Use**:
|
||||
- Quick overview of global feature importance
|
||||
- Comparing feature importance across cohorts or models
|
||||
- Identifying redundant or correlated features
|
||||
- Clean, simple visualizations for presentations
|
||||
|
||||
### Force Plots
|
||||
|
||||
**Purpose**: Additive force visualization showing how features push prediction higher (red) or lower (blue) from baseline.
|
||||
|
||||
**Function**: `shap.plots.force(base_value, shap_values, features, feature_names=None, out_names=None, link="identity", matplotlib=False, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `base_value`: Expected value (baseline prediction)
|
||||
- `shap_values`: SHAP values for sample(s)
|
||||
- `features`: Feature values for sample(s)
|
||||
- `feature_names`: Optional feature names
|
||||
- `link`: Transform function ("identity" or "logit")
|
||||
- `matplotlib`: Use matplotlib backend (default: interactive JavaScript)
|
||||
|
||||
**Visual Elements**:
|
||||
- **Baseline**: Starting prediction (expected value)
|
||||
- **Red arrows**: Features pushing prediction higher
|
||||
- **Blue arrows**: Features pushing prediction lower
|
||||
- **Final value**: Resulting prediction
|
||||
|
||||
**Interactive Features** (JavaScript mode):
|
||||
- Hover for detailed feature information
|
||||
- Multiple samples create stacked visualization
|
||||
- Can rotate for different perspectives
|
||||
|
||||
**When to Use**:
|
||||
- Interactive exploration of predictions
|
||||
- Visualizing multiple predictions simultaneously
|
||||
- Presentations requiring interactive elements
|
||||
- Understanding prediction composition at a glance
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
# Single prediction force plot
|
||||
shap.plots.force(
|
||||
shap_values.base_values[0],
|
||||
shap_values.values[0],
|
||||
X_test.iloc[0],
|
||||
matplotlib=True
|
||||
)
|
||||
|
||||
# Multiple predictions (interactive)
|
||||
shap.plots.force(
|
||||
shap_values.base_values,
|
||||
shap_values.values,
|
||||
X_test
|
||||
)
|
||||
```
|
||||
|
||||
### Scatter Plots (Dependence Plots)
|
||||
|
||||
**Purpose**: Show relationship between feature values and their SHAP values, revealing how feature values impact predictions.
|
||||
|
||||
**Function**: `shap.plots.scatter(shap_values, color=None, hist=True, alpha=1, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `shap_values`: Explanation object, can specify feature with subscript (e.g., `shap_values[:, "Age"]`)
|
||||
- `color`: Feature to use for coloring points (string name or Explanation object)
|
||||
- `hist`: Show a light histogram of the feature's values along the x-axis (data density)
|
||||
- `alpha`: Point transparency (useful for dense plots)
|
||||
|
||||
**Visual Elements**:
|
||||
- **X-axis**: Feature value
|
||||
- **Y-axis**: SHAP value (impact on prediction)
|
||||
- **Point color**: Another feature's value (for interaction detection)
|
||||
- **Histogram**: Distribution of feature values
|
||||
|
||||
**When to Use**:
|
||||
- Understanding feature-prediction relationships
|
||||
- Detecting nonlinear effects
|
||||
- Identifying feature interactions
|
||||
- Validating or discovering patterns in model behavior
|
||||
- Exploring counterintuitive predictions from waterfall plots
|
||||
|
||||
**Interaction Detection**:
|
||||
Color points by another feature to reveal interactions.
|
||||
|
||||
```python
|
||||
# Basic dependence plot
|
||||
shap.plots.scatter(shap_values[:, "Age"])
|
||||
|
||||
# Color by another feature to show interactions
|
||||
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education"])
|
||||
|
||||
# Multiple features in one plot
|
||||
shap.plots.scatter(shap_values[:, ["Age", "Education", "Hours-per-week"]])
|
||||
|
||||
# Increase transparency for dense data
|
||||
shap.plots.scatter(shap_values[:, "Age"], alpha=0.5)
|
||||
```
|
||||
|
||||
### Heatmap Plots
|
||||
|
||||
**Purpose**: Visualize SHAP values for multiple samples simultaneously, showing feature impacts across instances.
|
||||
|
||||
**Function**: `shap.plots.heatmap(shap_values, instance_order=None, feature_values=None, max_display=10, show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `shap_values`: Explanation object
|
||||
- `instance_order`: How to order instances (can be Explanation object for custom ordering)
|
||||
- `feature_values`: Display feature values on hover
|
||||
- `max_display`: Maximum features to display
|
||||
|
||||
**Visual Elements**:
|
||||
- **Rows**: Individual instances/samples
|
||||
- **Columns**: Features
|
||||
- **Cell color**: SHAP value (red = positive, blue = negative)
|
||||
- **Intensity**: Magnitude of impact
|
||||
|
||||
**When to Use**:
|
||||
- Comparing explanations across multiple instances
|
||||
- Identifying patterns in feature impacts
|
||||
- Understanding which features vary most across predictions
|
||||
- Detecting subgroups or clusters with similar explanation patterns
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
# Basic heatmap
|
||||
shap.plots.heatmap(shap_values)
|
||||
|
||||
# Order instances by model output
|
||||
shap.plots.heatmap(shap_values, instance_order=shap_values.sum(1))
|
||||
|
||||
# Show specific subset
|
||||
shap.plots.heatmap(shap_values[:100])
|
||||
```
|
||||
|
||||
### Violin Plots
|
||||
|
||||
**Purpose**: Similar to beeswarm plots but uses violin (kernel density) visualization instead of individual dots.
|
||||
|
||||
**Function**: `shap.plots.violin(shap_values, features=None, feature_names=None, max_display=10, show=True)`
|
||||
|
||||
**When to Use**:
|
||||
- Alternative to beeswarm when dataset is very large
|
||||
- Emphasizing distribution density over individual points
|
||||
- Cleaner visualization for presentations
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
shap.plots.violin(shap_values)
|
||||
```
|
||||
|
||||
### Decision Plots
|
||||
|
||||
**Purpose**: Show prediction paths through cumulative SHAP values, particularly useful for multiclass classification.
|
||||
|
||||
**Function**: `shap.plots.decision(base_value, shap_values, features, feature_names=None, feature_order="importance", highlight=None, link="identity", show=True)`
|
||||
|
||||
**Key Parameters**:
|
||||
- `base_value`: Expected value
|
||||
- `shap_values`: SHAP values for samples
|
||||
- `features`: Feature values
|
||||
- `feature_order`: How to order features ("importance" or list)
|
||||
- `highlight`: Indices of samples to highlight
|
||||
- `link`: Transform function
|
||||
|
||||
**When to Use**:
|
||||
- Multiclass classification explanations
|
||||
- Understanding cumulative feature effects
|
||||
- Comparing prediction paths across samples
|
||||
- Identifying where predictions diverge
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
# Decision plot for multiple predictions
|
||||
shap.plots.decision(
|
||||
shap_values.base_values,
|
||||
shap_values.values,
|
||||
X_test,
|
||||
feature_names=X_test.columns.tolist()
|
||||
)
|
||||
|
||||
# Highlight specific instances
|
||||
shap.plots.decision(
|
||||
shap_values.base_values,
|
||||
shap_values.values,
|
||||
X_test,
|
||||
highlight=[0, 5, 10]
|
||||
)
|
||||
```
|
||||
|
||||
## Plot Selection Guide
|
||||
|
||||
**For Individual Predictions**:
|
||||
- **Waterfall**: Best for detailed, sequential explanation
|
||||
- **Force**: Good for interactive exploration
|
||||
- **Bar (local)**: Simple, clean single-prediction importance
|
||||
|
||||
**For Global Understanding**:
|
||||
- **Beeswarm**: Information-dense summary with value distributions
|
||||
- **Bar (global)**: Clean, simple importance ranking
|
||||
- **Violin**: Distribution-focused alternative to beeswarm
|
||||
|
||||
**For Feature Relationships**:
|
||||
- **Scatter**: Understand feature-prediction relationships and interactions
|
||||
- **Heatmap**: Compare patterns across multiple instances
|
||||
|
||||
**For Multiple Samples**:
|
||||
- **Heatmap**: Grid view of SHAP values
|
||||
- **Force (stacked)**: Interactive multi-sample visualization
|
||||
- **Decision**: Prediction paths for multiclass problems
|
||||
|
||||
**For Cohort Comparison**:
|
||||
- **Bar (cohort)**: Clean comparison of feature importance
|
||||
- **Multiple beeswarms**: Side-by-side distribution comparisons
|
||||
|
||||
## Visualization Best Practices
|
||||
|
||||
**1. Start Global, Then Go Local**:
|
||||
- Begin with beeswarm or bar plot to understand global patterns
|
||||
- Dive into waterfall or scatter plots for specific instances or features
|
||||
|
||||
**2. Use Multiple Plot Types**:
|
||||
- Different plots reveal different insights
|
||||
- Combine waterfall (individual) + scatter (relationship) + beeswarm (global)
|
||||
|
||||
**3. Adjust max_display**:
|
||||
- Default (10) is good for presentations
|
||||
- Increase (20-30) for detailed analysis
|
||||
- Consider clustering for redundant features
|
||||
|
||||
**4. Color Meaningfully**:
|
||||
- Use default red-blue for SHAP values (red = positive, blue = negative)
|
||||
- Color scatter plots by interacting features
|
||||
- Custom colormaps for specific domains
|
||||
|
||||
**5. Consider Audience**:
|
||||
- Technical audience: Beeswarm, scatter, heatmap
|
||||
- Non-technical audience: Waterfall, bar, force plots
|
||||
- Interactive presentations: Force plots with JavaScript
|
||||
|
||||
**6. Save High-Quality Figures**:
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create plot
|
||||
shap.plots.beeswarm(shap_values, show=False)
|
||||
|
||||
# Save with high DPI
|
||||
plt.savefig('shap_plot.png', dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
```
|
||||
|
||||
**7. Handle Large Datasets**:
|
||||
- Sample subset for visualization (e.g., `shap_values[:1000]`)
|
||||
- Use violin instead of beeswarm for very large datasets
|
||||
- Adjust alpha for scatter plots with many points
|
||||
|
||||
## Common Patterns and Workflows
|
||||
|
||||
**Pattern 1: Complete Model Explanation**
|
||||
```python
|
||||
# 1. Global importance
|
||||
shap.plots.beeswarm(shap_values)
|
||||
|
||||
# 2. Top feature relationships
|
||||
for feature in top_features:
|
||||
shap.plots.scatter(shap_values[:, feature])
|
||||
|
||||
# 3. Example predictions
|
||||
for i in interesting_indices:
|
||||
shap.plots.waterfall(shap_values[i])
|
||||
```
|
||||
|
||||
**Pattern 2: Model Comparison**
|
||||
```python
|
||||
# Compute SHAP for multiple models
|
||||
shap_model1 = explainer1(X_test)
|
||||
shap_model2 = explainer2(X_test)
|
||||
|
||||
# Compare feature importance
|
||||
shap.plots.bar({
|
||||
"Model 1": shap_model1,
|
||||
"Model 2": shap_model2
|
||||
})
|
||||
```
|
||||
|
||||
**Pattern 3: Subgroup Analysis**
|
||||
```python
|
||||
# Define cohorts
|
||||
male_mask = X_test['Sex'] == 'Male'
|
||||
female_mask = X_test['Sex'] == 'Female'
|
||||
|
||||
# Compare cohorts
|
||||
shap.plots.bar({
|
||||
"Male": shap_values[male_mask],
|
||||
"Female": shap_values[female_mask]
|
||||
})
|
||||
|
||||
# Separate beeswarm plots
|
||||
shap.plots.beeswarm(shap_values[male_mask])
|
||||
shap.plots.beeswarm(shap_values[female_mask])
|
||||
```
|
||||
|
||||
**Pattern 4: Debugging Predictions**
|
||||
```python
|
||||
import numpy as np

# Identify outliers or errors
|
||||
errors = (model.predict(X_test) != y_test)
|
||||
error_indices = np.where(errors)[0]
|
||||
|
||||
# Explain errors
|
||||
for idx in error_indices[:5]:
|
||||
print(f"Sample {idx}:")
|
||||
shap.plots.waterfall(shap_values[idx])
|
||||
|
||||
# Explore key features
|
||||
shap.plots.scatter(shap_values[:, "Key_Feature"])
|
||||
```
|
||||
|
||||
## Integration with Notebooks and Reports
|
||||
|
||||
**Jupyter Notebooks**:
|
||||
- Interactive force plots work seamlessly
|
||||
- Use `show=True` (default) for inline display
|
||||
- Combine with markdown explanations
|
||||
|
||||
**Static Reports**:
|
||||
- Use matplotlib backend for force plots
|
||||
- Save figures programmatically
|
||||
- Prefer waterfall and bar plots for clarity
|
||||
|
||||
**Web Applications**:
|
||||
- Export force plots as HTML
|
||||
- Use shap.save_html() for interactive visualizations
|
||||
- Consider generating plots on-demand
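
For example, an interactive force plot can be written to a standalone HTML file roughly like this, reusing `shap_values` and `X_test` from earlier examples:

```python
import shap

force_plot = shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0]
)
shap.save_html("force_plot.html", force_plot)
```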
|
||||
|
||||
## Troubleshooting Visualizations
|
||||
|
||||
**Issue**: Plots don't display
|
||||
- **Solution**: Ensure matplotlib backend is set correctly; use `plt.show()` if needed
|
||||
|
||||
**Issue**: Too many features cluttering plot
|
||||
- **Solution**: Reduce `max_display` parameter or use feature clustering
|
||||
|
||||
**Issue**: Colors reversed or confusing
|
||||
- **Solution**: Check model output type (probability vs. log-odds) and use appropriate link function
|
||||
|
||||
**Issue**: Slow plotting with large datasets
|
||||
- **Solution**: Sample subset of data; use `shap_values[:1000]` for visualization
|
||||
|
||||
**Issue**: Feature names missing
|
||||
- **Solution**: Ensure feature_names are in Explanation object or pass explicitly to plot functions
|
||||
New file: skills/shap/references/theory.md (449 lines)
|
||||
# SHAP Theoretical Foundation
|
||||
|
||||
This document explains the theoretical foundations of SHAP (SHapley Additive exPlanations), including Shapley values from game theory, the principles that make SHAP unique, and connections to other explanation methods.
|
||||
|
||||
## Game Theory Origins
|
||||
|
||||
### Shapley Values
|
||||
|
||||
SHAP is grounded in **Shapley values**, a solution concept from cooperative game theory introduced by Lloyd Shapley in 1953.
|
||||
|
||||
**Core Concept**:
|
||||
In cooperative game theory, players collaborate to achieve a total payoff, and the question is: how should this payoff be fairly distributed among players?
|
||||
|
||||
**Mapping to Machine Learning**:
|
||||
- **Players** → Input features
|
||||
- **Game** → Model prediction task
|
||||
- **Payoff** → Model output (prediction value)
|
||||
- **Coalition** → Subset of features with known values
|
||||
- **Fair Distribution** → Attributing prediction to features
|
||||
|
||||
### The Shapley Value Formula
|
||||
|
||||
For a feature $i$, its Shapley value $\phi_i$ is:
|
||||
|
||||
$$\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!(|F|-|S|-1)!}{|F|!} [f(S \cup \{i\}) - f(S)]$$
|
||||
|
||||
Where:
|
||||
- $F$ is the set of all features
|
||||
- $S$ is a subset of features not including $i$
|
||||
- $f(S)$ is the model's expected output given only features in $S$
|
||||
- $|S|$ is the size of subset $S$
|
||||
|
||||
**Interpretation**:
|
||||
The Shapley value averages the marginal contribution of feature $i$ across all possible feature coalitions (subsets). The contribution is weighted by how likely each coalition is to occur.
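
To make the formula concrete, here is a brute-force sketch that enumerates every coalition. `model_fn`, `x`, and `baseline` are assumed inputs, missing features are simulated by substituting baseline values, and the approach is only practical for a handful of features:

```python
from itertools import combinations
from math import factorial
import numpy as np

def exact_shapley(model_fn, x, baseline):
    """Exact Shapley values for one instance by enumerating all coalitions."""
    n = len(x)

    def value(S):
        # Features in S keep their observed values; the rest take baseline values
        z = baseline.copy()
        idx = list(S)
        z[idx] = x[idx]
        return float(model_fn(z.reshape(1, -1))[0])

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in combinations(others, size):
                w = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += w * (value(set(S) | {i}) - value(set(S)))
    return phi
```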
|
||||
|
||||
### Key Properties of Shapley Values
|
||||
|
||||
**1. Efficiency (Additivity)**:
|
||||
$$\sum_{i=1}^{n} \phi_i = f(x) - f(\emptyset)$$
|
||||
|
||||
The sum of all SHAP values equals the difference between the model's prediction for the instance and the expected value (baseline).
|
||||
|
||||
This is why SHAP waterfall plots always sum to the total prediction change.
|
||||
|
||||
**2. Symmetry**:
|
||||
If two features $i$ and $j$ contribute equally to all coalitions, then $\phi_i = \phi_j$.
|
||||
|
||||
Features with identical effects receive identical attribution.
|
||||
|
||||
**3. Dummy**:
|
||||
If a feature $i$ doesn't change the model output for any coalition, then $\phi_i = 0$.
|
||||
|
||||
Irrelevant features receive zero attribution.
|
||||
|
||||
**4. Monotonicity**:
|
||||
If a feature's marginal contribution increases across coalitions, its Shapley value increases.
|
||||
|
||||
## From Game Theory to Machine Learning
|
||||
|
||||
### The Challenge
|
||||
|
||||
Computing exact Shapley values requires evaluating the model on all possible feature coalitions:
|
||||
- For $n$ features, there are $2^n$ possible coalitions
|
||||
- For 50 features, this is over 1 quadrillion evaluations
|
||||
|
||||
This exponential complexity makes exact computation intractable for most real-world models.
|
||||
|
||||
### SHAP's Solution: Additive Feature Attribution
|
||||
|
||||
SHAP connects Shapley values to **additive feature attribution methods**, enabling efficient computation.
|
||||
|
||||
**Additive Feature Attribution Model**:
|
||||
$$g(z') = \phi_0 + \sum_{i=1}^{M} \phi_i z'_i$$
|
||||
|
||||
Where:
|
||||
- $g$ is the explanation model
|
||||
- $z' \in \{0,1\}^M$ indicates feature presence/absence
|
||||
- $\phi_i$ is the attribution to feature $i$
|
||||
- $\phi_0$ is the baseline (expected value)
|
||||
|
||||
SHAP proves that **Shapley values are the only attribution values satisfying three desirable properties**: local accuracy, missingness, and consistency.
|
||||
|
||||
## SHAP Properties and Guarantees
|
||||
|
||||
### Local Accuracy
|
||||
|
||||
**Property**: The explanation matches the model's output:
|
||||
$$f(x) = g(x') = \phi_0 + \sum_{i=1}^{M} \phi_i x'_i$$
|
||||
|
||||
**Interpretation**: SHAP values exactly account for the model's prediction. This enables waterfall plots to precisely decompose predictions.
|
||||
|
||||
### Missingness
|
||||
|
||||
**Property**: If a feature is missing (not observed), its attribution is zero:
|
||||
$$x'_i = 0 \Rightarrow \phi_i = 0$$
|
||||
|
||||
**Interpretation**: Only features that are present contribute to explanations.
|
||||
|
||||
### Consistency
|
||||
|
||||
**Property**: If a model changes so a feature's marginal contribution increases (or stays the same) for all inputs, that feature's attribution should not decrease.
|
||||
|
||||
**Interpretation**: If a feature becomes more important to the model, its SHAP value reflects this. This enables meaningful model comparisons.
|
||||
|
||||
## SHAP as a Unified Framework
|
||||
|
||||
SHAP unifies several existing explanation methods by showing they're special cases of Shapley values under specific assumptions.
|
||||
|
||||
### LIME (Local Interpretable Model-agnostic Explanations)
|
||||
|
||||
**LIME's Approach**: Fit a local linear model around a prediction using perturbed samples.
|
||||
|
||||
**Connection to SHAP**: LIME approximates Shapley values but with suboptimal sample weighting. SHAP uses theoretically optimal weights derived from the Shapley value formula.
|
||||
|
||||
**Key Difference**: LIME's loss function and sampling don't guarantee consistency or exact additivity; SHAP does.
|
||||
|
||||
### DeepLIFT
|
||||
|
||||
**DeepLIFT's Approach**: Backpropagate contributions through neural networks by comparing to reference activations.
|
||||
|
||||
**Connection to SHAP**: DeepExplainer uses DeepLIFT but averages over multiple reference samples to approximate conditional expectations, yielding Shapley values.
|
||||
|
||||
### Layer-Wise Relevance Propagation (LRP)
|
||||
|
||||
**LRP's Approach**: Decompose neural network predictions by propagating relevance scores backward through layers.
|
||||
|
||||
**Connection to SHAP**: LRP is a special case of SHAP with specific propagation rules. SHAP generalizes these rules with Shapley value theory.
|
||||
|
||||
### Integrated Gradients
|
||||
|
||||
**Integrated Gradients' Approach**: Integrate gradients along path from baseline to input.
|
||||
|
||||
**Connection to SHAP**: When using a single reference point, Integrated Gradients approximates SHAP values for smooth models.
|
||||
|
||||
## SHAP Computation Methods
|
||||
|
||||
Different SHAP explainers use specialized algorithms to compute Shapley values efficiently for specific model types.
|
||||
|
||||
### Tree SHAP (TreeExplainer)
|
||||
|
||||
**Innovation**: Exploits tree structure to compute exact Shapley values in polynomial time instead of exponential.
|
||||
|
||||
**Algorithm**:
|
||||
- Traverses each tree path from root to leaf
|
||||
- Computes feature contributions using tree splits and weights
|
||||
- Aggregates across all trees in ensemble
|
||||
|
||||
**Complexity**: $O(TLD^2)$ where $T$ = number of trees, $L$ = max leaves, $D$ = max depth
|
||||
|
||||
**Key Advantage**: Exact Shapley values computed efficiently for tree-based models (XGBoost, LightGBM, Random Forest, etc.)
|
||||
|
||||
### Kernel SHAP (KernelExplainer)
|
||||
|
||||
**Innovation**: Uses weighted linear regression to estimate Shapley values for any model.
|
||||
|
||||
**Algorithm**:
|
||||
- Samples coalitions (feature subsets) according to Shapley kernel weights
|
||||
- Evaluates model on each coalition (missing features replaced by background values)
|
||||
- Fits weighted linear model to estimate feature attributions
|
||||
|
||||
**Complexity**: $O(n \cdot 2^M)$ but approximates with fewer samples
|
||||
|
||||
**Key Advantage**: Model-agnostic; works with any prediction function
|
||||
|
||||
**Trade-off**: Slower than specialized explainers; approximate rather than exact
|
||||
|
||||
### Deep SHAP (DeepExplainer)
|
||||
|
||||
**Innovation**: Combines DeepLIFT with Shapley value sampling.
|
||||
|
||||
**Algorithm**:
|
||||
- Computes DeepLIFT attributions for each reference sample
|
||||
- Averages attributions across multiple reference samples
|
||||
- Approximates conditional expectations: $E[f(x) | x_S]$
|
||||
|
||||
**Complexity**: $O(n \cdot m)$ where $m$ = number of reference samples
|
||||
|
||||
**Key Advantage**: Efficiently approximates Shapley values for deep neural networks
|
||||
|
||||
### Linear SHAP (LinearExplainer)
|
||||
|
||||
**Innovation**: Closed-form Shapley values for linear models.
|
||||
|
||||
**Algorithm**:
|
||||
- For independent features: $\phi_i = w_i \cdot (x_i - E[x_i])$
|
||||
- For correlated features: Adjusts for feature covariance
|
||||
|
||||
**Complexity**: $O(n)$ - nearly instantaneous
|
||||
|
||||
**Key Advantage**: Exact Shapley values with minimal computation
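
A small sketch that checks the closed form against `LinearExplainer` on synthetic, roughly independent features (all names here are illustrative):

```python
import numpy as np
import shap
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=500)

model = LinearRegression().fit(X, y)
explainer = shap.LinearExplainer(model, X)
shap_values = explainer(X)

# Closed form for the interventional case: phi_i = w_i * (x_i - E[x_i])
manual = model.coef_ * (X - X.mean(axis=0))
print(np.allclose(shap_values.values, manual, atol=1e-6))  # expected: True
```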
|
||||
|
||||
## Understanding Conditional Expectations
|
||||
|
||||
### The Core Challenge
|
||||
|
||||
Computing $f(S)$ (model output given only features in $S$) requires handling missing features.
|
||||
|
||||
**Question**: How should we represent "missing" features when the model requires all features as input?
|
||||
|
||||
### Two Approaches
|
||||
|
||||
**1. Interventional (Marginal) Approach**:
|
||||
- Replace missing features with values from background dataset
|
||||
- Estimates: $E[f(x) | x_S]$ by marginalizing over $x_{\bar{S}}$
|
||||
- Interpretation: "What would the model predict if we didn't know features $\bar{S}$?"
|
||||
|
||||
**2. Observational (Conditional) Approach**:
|
||||
- Use conditional distribution: $E[f(x) | x_S = x_S^*]$
|
||||
- Accounts for feature dependencies
|
||||
- Interpretation: "What would the model predict for similar instances with features $S = x_S^*$?"
|
||||
|
||||
**Trade-offs**:
|
||||
- **Interventional**: Simpler, assumes feature independence, matches causal interpretation
|
||||
- **Observational**: More accurate for correlated features, requires conditional distribution estimation
|
||||
|
||||
**TreeExplainer** supports both via `feature_perturbation` parameter.
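
For example, assuming a trained tree model `model` and a background matrix `X_background`:

```python
import shap

# Interventional / marginal: requires background data
explainer_int = shap.TreeExplainer(
    model, data=X_background, feature_perturbation="interventional"
)

# Path-dependent: no background data; uses training coverage of tree paths
explainer_path = shap.TreeExplainer(
    model, feature_perturbation="tree_path_dependent"
)
```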
|
||||
|
||||
## Baseline (Expected Value) Selection
|
||||
|
||||
The **baseline** $\phi_0 = E[f(x)]$ represents the model's average prediction.
|
||||
|
||||
### Computing the Baseline
|
||||
|
||||
**For TreeExplainer**:
|
||||
- With background data: Average prediction on background dataset
|
||||
- With tree_path_dependent: Weighted average using tree leaf distributions
|
||||
|
||||
**For DeepExplainer / KernelExplainer**:
|
||||
- Average prediction on background samples
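
In code, the baseline is exposed by the explainer and attached to each explanation (a sketch reusing `model`, `X_background`, and `X_test` from earlier examples; shapes vary with the number of model outputs):

```python
import shap

explainer = shap.TreeExplainer(model, data=X_background)
print(explainer.expected_value)       # baseline(s): scalar or one value per output

shap_values = explainer(X_test)
print(shap_values.base_values[:5])    # the same baseline attached to each sample
```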
|
||||
|
||||
### Importance of Baseline
|
||||
|
||||
- SHAP values measure deviation from baseline
|
||||
- Different baselines → different SHAP values (but still sum correctly)
|
||||
- Choose baseline representative of "typical" or "neutral" input
|
||||
- Common choices: Training set mean, median, or mode
|
||||
|
||||
## Interpreting SHAP Values
|
||||
|
||||
### Units and Scale
|
||||
|
||||
**SHAP values have the same units as the model output**:
|
||||
- Regression: Same units as target variable (dollars, temperature, etc.)
|
||||
- Classification (log-odds): Log-odds units
|
||||
- Classification (probability): Probability units (if model output transformed)
|
||||
|
||||
**Magnitude**: Higher absolute SHAP value = stronger feature impact
|
||||
|
||||
**Sign**:
|
||||
- Positive SHAP value = Feature pushes prediction higher
|
||||
- Negative SHAP value = Feature pushes prediction lower
|
||||
|
||||
### Additive Decomposition
|
||||
|
||||
For a prediction $f(x)$:
|
||||
$$f(x) = E[f(X)] + \sum_{i=1}^{n} \phi_i(x)$$
|
||||
|
||||
**Example**:
|
||||
- Expected value (baseline): 0.3
|
||||
- SHAP values: {Age: +0.15, Income: +0.10, Education: -0.05}
|
||||
- Prediction: $0.3 + 0.15 + 0.10 - 0.05 = 0.50$
|
||||
|
||||
### Global vs. Local Importance
|
||||
|
||||
**Local (Instance-level)**:
|
||||
- SHAP values for single prediction: $\phi_i(x)$
|
||||
- Explains: "Why did the model predict $f(x)$ for this instance?"
|
||||
- Visualization: Waterfall, force plots
|
||||
|
||||
**Global (Dataset-level)**:
|
||||
- Average absolute SHAP values: $E[|\phi_i(x)|]$
|
||||
- Explains: "Which features are most important overall?"
|
||||
- Visualization: Beeswarm, bar plots
|
||||
|
||||
**Key Insight**: Global importance is the aggregation of local importances, maintaining consistency between instance and dataset explanations.
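
For instance, the global ranking used by bar and beeswarm plots can be reproduced directly from the local values, assuming `shap_values` is an Explanation for many samples:

```python
import numpy as np

global_importance = np.abs(shap_values.values).mean(axis=0)  # E[|phi_i|] per feature
ranking = np.argsort(global_importance)[::-1]                # most important first
```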
|
||||
|
||||
## SHAP vs. Other Feature Importance Methods
|
||||
|
||||
### Comparison with Permutation Importance
|
||||
|
||||
**Permutation Importance**:
|
||||
- Shuffles a feature and measures accuracy drop
|
||||
- Global metric only (no instance-level explanations)
|
||||
- Can be misleading with correlated features
|
||||
|
||||
**SHAP**:
|
||||
- Provides both local and global importance
|
||||
- Handles feature correlations through coalitional averaging
|
||||
- Consistent: Additive property guarantees sum to prediction
|
||||
|
||||
### Comparison with Feature Coefficients (Linear Models)
|
||||
|
||||
**Feature Coefficients** ($w_i$):
|
||||
- Measure impact per unit change in feature
|
||||
- Don't account for feature scale or distribution
|
||||
|
||||
**SHAP for Linear Models**:
|
||||
- $\phi_i = w_i \cdot (x_i - E[x_i])$
|
||||
- Accounts for feature value relative to average
|
||||
- More interpretable for comparing features with different units/scales
|
||||
|
||||
### Comparison with Tree Feature Importance (Gini/Split-based)
|
||||
|
||||
**Gini/Split Importance**:
|
||||
- Based on training process (purity gain or frequency of splits)
|
||||
- Biased toward high-cardinality features
|
||||
- No instance-level explanations
|
||||
- Can be misleading (importance ≠ predictive power)
|
||||
|
||||
**SHAP (Tree SHAP)**:
|
||||
- Based on model output (prediction behavior)
|
||||
- Fair attribution through Shapley values
|
||||
- Provides instance-level explanations
|
||||
- Consistent and theoretically grounded
|
||||
|
||||
## Interactions and Higher-Order Effects
|
||||
|
||||
### SHAP Interaction Values
|
||||
|
||||
Standard SHAP captures main effects. **SHAP interaction values** capture pairwise interactions.
|
||||
|
||||
**Formula for Interaction**:
|
||||
$$\phi_{i,j} = \sum_{S \subseteq F \setminus \{i,j\}} \frac{|S|!(|F|-|S|-2)!}{2(|F|-1)!} \Delta_{ij}(S)$$
|
||||
|
||||
Where $\Delta_{ij}(S)$ is the interaction effect of features $i$ and $j$ given coalition $S$.
|
||||
|
||||
**Interpretation**:
|
||||
- $\phi_{i,i}$: Main effect of feature $i$
|
||||
- $\phi_{i,j}$ ($i \neq j$): Interaction effect between features $i$ and $j$
|
||||
|
||||
**Property**:
|
||||
$$\phi_i = \phi_{i,i} + \sum_{j \neq i} \phi_{i,j}$$
|
||||
|
||||
The SHAP value of feature $i$ equals its main effect plus the sum of its pairwise interaction terms; each total interaction effect is split evenly between $\phi_{i,j}$ and $\phi_{j,i}$, so each term carries half of that interaction.
|
||||
|
||||
### Computing Interactions
|
||||
|
||||
**TreeExplainer** supports exact interaction computation:
|
||||
```python
|
||||
explainer = shap.TreeExplainer(model)
|
||||
shap_interaction_values = explainer.shap_interaction_values(X)
|
||||
```
|
||||
|
||||
**Limitation**: Exponentially complex for other explainers (only practical for tree models)
|
||||
|
||||
## Theoretical Limitations and Considerations
|
||||
|
||||
### Computational Complexity
|
||||
|
||||
**Exact Computation**: $O(2^n)$ - intractable for large $n$
|
||||
|
||||
**Specialized Algorithms**:
|
||||
- Tree SHAP: $O(TLD^2)$ - efficient for trees
|
||||
- Deep SHAP, Kernel SHAP: Approximations required
|
||||
|
||||
**Implication**: For non-tree models with many features, explanations may be approximate.
|
||||
|
||||
### Feature Independence Assumption
|
||||
|
||||
**Kernel SHAP and Basic Implementation**: Assume features can be independently manipulated
|
||||
|
||||
**Challenge**: Real features are often correlated (e.g., height and weight)
|
||||
|
||||
**Solutions**:
|
||||
- Use observational approach (conditional expectations)
|
||||
- TreeExplainer with correlation-aware perturbation
|
||||
- Feature grouping for highly correlated features
|
||||
|
||||
### Out-of-Distribution Samples
|
||||
|
||||
**Issue**: Creating coalitions by replacing features may create unrealistic samples (outside training distribution)
|
||||
|
||||
**Example**: Setting "Age=5" and "Has PhD=Yes" simultaneously
|
||||
|
||||
**Implication**: SHAP values reflect model behavior on potentially unrealistic inputs
|
||||
|
||||
**Mitigation**: Use observational approach or carefully selected background data
|
||||
|
||||
### Causality
|
||||
|
||||
**SHAP measures association, not causation**
|
||||
|
||||
SHAP answers: "How does the model's prediction change with this feature?"
|
||||
SHAP does NOT answer: "What would happen if we changed this feature in reality?"
|
||||
|
||||
**Example**:
|
||||
- SHAP: "Hospital stay length increases prediction of mortality" (association)
|
||||
- Causality: "Longer hospital stays cause higher mortality" (incorrect!)
|
||||
|
||||
**Implication**: Use domain knowledge to interpret SHAP causally; SHAP alone doesn't establish causation.
|
||||
|
||||
## Advanced Theoretical Topics
|
||||
|
||||
### SHAP as Optimal Credit Allocation
|
||||
|
||||
SHAP is the unique attribution method satisfying:
|
||||
1. **Local accuracy**: Explanation matches model
|
||||
2. **Missingness**: Absent features have zero attribution
|
||||
3. **Consistency**: Attribution reflects feature importance changes
|
||||
|
||||
**Proof**: Lundberg & Lee (2017) showed Shapley values are the only solution satisfying these axioms.
|
||||
|
||||
### Connection to Functional ANOVA
|
||||
|
||||
SHAP values correspond to first-order terms in functional ANOVA decomposition:
|
||||
$$f(x) = f_0 + \sum_i f_i(x_i) + \sum_{i,j} f_{ij}(x_i, x_j) + ...$$
|
||||
|
||||
Where $f_i(x_i)$ captures main effect of feature $i$, and $\phi_i \approx f_i(x_i)$.
|
||||
|
||||
### Relationship to Sensitivity Analysis
|
||||
|
||||
SHAP generalizes sensitivity analysis:
|
||||
- **Sensitivity Analysis**: $\frac{\partial f}{\partial x_i}$ (local gradient)
|
||||
- **SHAP**: Integrated sensitivity over feature coalition space
|
||||
|
||||
Gradient-based methods (GradientExplainer, Integrated Gradients) approximate SHAP using derivatives.
|
||||
|
||||
## Practical Implications of Theory
|
||||
|
||||
### Why Use SHAP?
|
||||
|
||||
1. **Theoretical Guarantees**: Only method with consistency, local accuracy, and missingness
|
||||
2. **Unified Framework**: Connects and generalizes multiple explanation methods
|
||||
3. **Additive Decomposition**: Predictions precisely decompose into feature contributions
|
||||
4. **Model Comparison**: Consistency enables comparing feature importance across models
|
||||
5. **Versatility**: Works with any model type (with appropriate explainer)
|
||||
|
||||
### When to Be Cautious
|
||||
|
||||
1. **Computational Cost**: May be slow for complex models without specialized explainers
|
||||
2. **Feature Correlation**: Standard approaches may create unrealistic samples
|
||||
3. **Interpretation**: Requires understanding baseline, units, and assumptions
|
||||
4. **Causality**: SHAP doesn't imply causation; use domain knowledge
|
||||
5. **Approximations**: Non-tree methods use approximations; understand accuracy trade-offs
|
||||
|
||||
## References and Further Reading
|
||||
|
||||
**Foundational Papers**:
|
||||
- Shapley, L. S. (1953). "A Value for n-Person Games"
|
||||
- Lundberg, S. M., & Lee, S. I. (2017). "A Unified Approach to Interpreting Model Predictions" (NeurIPS)
|
||||
- Lundberg, S. M., et al. (2020). "From local explanations to global understanding with explainable AI for trees" (Nature Machine Intelligence)
|
||||
|
||||
**Key Concepts**:
|
||||
- Cooperative game theory and Shapley values
|
||||
- Additive feature attribution methods
|
||||
- Conditional expectation estimation
|
||||
- Tree SHAP algorithm and polynomial-time computation
|
||||
|
||||
This theoretical foundation explains why SHAP is a principled, versatile, and powerful tool for model interpretation.
|
||||
New file: skills/shap/references/workflows.md (605 lines)
|
||||
# SHAP Workflows and Best Practices
|
||||
|
||||
This document provides comprehensive workflows, best practices, and common use cases for using SHAP in various model interpretation scenarios.
|
||||
|
||||
## Basic Workflow Structure
|
||||
|
||||
Every SHAP analysis follows a general workflow:
|
||||
|
||||
1. **Train Model**: Build and train the machine learning model
|
||||
2. **Select Explainer**: Choose appropriate explainer based on model type
|
||||
3. **Compute SHAP Values**: Generate explanations for test samples
|
||||
4. **Visualize Results**: Use plots to understand feature impacts
|
||||
5. **Interpret and Act**: Draw conclusions and make decisions
|
||||
|
||||
## Workflow 1: Basic Model Explanation

**Use Case**: Understanding feature importance and prediction behavior for a trained model

```python
import shap
import pandas as pd
from sklearn.model_selection import train_test_split

# Step 1: Load and split data (X, y are your feature matrix and target)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Step 2: Train model (example with XGBoost)
import xgboost as xgb
model = xgb.XGBClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)

# Step 3: Create explainer
explainer = shap.TreeExplainer(model)

# Step 4: Compute SHAP values
shap_values = explainer(X_test)

# Step 5: Visualize global importance
shap.plots.beeswarm(shap_values, max_display=15)

# Step 6: Examine top features in detail
shap.plots.scatter(shap_values[:, "Feature1"])
shap.plots.scatter(shap_values[:, "Feature2"], color=shap_values[:, "Feature1"])

# Step 7: Explain individual predictions
shap.plots.waterfall(shap_values[0])
```

**Key Decisions**:
- Explainer type based on model architecture
- Background dataset size (for DeepExplainer, KernelExplainer)
- Number of samples to explain (full test set vs. subset)
## Workflow 2: Model Debugging and Validation

**Use Case**: Identifying and fixing model issues, validating expected behavior

```python
import numpy as np

# Step 1: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Step 2: Identify prediction errors
predictions = model.predict(X_test)
errors = predictions != y_test
error_indices = np.where(errors)[0]

# Step 3: Analyze errors
print(f"Total errors: {len(error_indices)}")
print(f"Error rate: {len(error_indices) / len(y_test):.2%}")

# Step 4: Explain misclassified samples
for idx in error_indices[:10]:  # First 10 errors
    print(f"\n=== Error {idx} ===")
    print(f"Prediction: {predictions[idx]}, Actual: {y_test.iloc[idx]}")
    shap.plots.waterfall(shap_values[idx])

# Step 5: Check whether the model learned sensible patterns
# Look for unexpected feature importance
shap.plots.beeswarm(shap_values)

# Step 6: Investigate specific feature relationships
# Verify that nonlinear relationships make sense
for feature_idx in model.feature_importances_.argsort()[-5:]:  # Top 5 features
    feature_name = X_test.columns[feature_idx]
    shap.plots.scatter(shap_values[:, feature_name])

# Step 7: Validate feature interactions
# Check whether interactions align with domain knowledge
shap.plots.scatter(shap_values[:, "Feature1"], color=shap_values[:, "Feature2"])
```

**Common Issues to Check**:
- Data leakage (a feature with suspiciously high importance; see the sketch below)
- Spurious correlations (unexpected feature relationships)
- Target leakage (features that shouldn't be predictive)
- Biases (disproportionate impact on certain groups)
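One simple heuristic for surfacing leakage suspects from the checklist above is to flag features whose share of the total mean |SHAP| is disproportionately large. This sketch reuses `shap_values` and `X_test` from the block above, assumes a single-output explanation (2-D `shap_values.values`), and the 40% cutoff is an arbitrary illustration rather than a rule:

```python
import numpy as np

# Heuristic: flag features that dominate the attribution mass.
mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
shares = mean_abs_shap / mean_abs_shap.sum()

for name, share in sorted(zip(X_test.columns, shares), key=lambda t: -t[1]):
    flag = "  <-- investigate for leakage" if share > 0.40 else ""
    print(f"{name:30s} {share:6.1%}{flag}")
```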
## Workflow 3: Feature Engineering Guidance

**Use Case**: Using SHAP insights to improve feature engineering

```python
# Step 1: Initial model with baseline features (train_model is your own training helper)
model_v1 = train_model(X_train_v1, y_train)
explainer_v1 = shap.TreeExplainer(model_v1)
shap_values_v1 = explainer_v1(X_test_v1)

# Step 2: Identify feature engineering opportunities
shap.plots.beeswarm(shap_values_v1)

# Check for:
# - Nonlinear relationships (candidates for transformation)
shap.plots.scatter(shap_values_v1[:, "Age"])  # Maybe age^2 or age bins?

# - Feature interactions (candidates for interaction terms)
shap.plots.scatter(shap_values_v1[:, "Income"], color=shap_values_v1[:, "Education"])
# Maybe create an Income * Education interaction?

# Step 3: Engineer new features based on insights
# (apply the same transformations to both train and test sets)
X_train_v2 = X_train_v1.copy()
X_train_v2['Age_squared'] = X_train_v2['Age'] ** 2
X_train_v2['Income_Education'] = X_train_v2['Income'] * X_train_v2['Education']
X_test_v2 = X_test_v1.copy()
X_test_v2['Age_squared'] = X_test_v2['Age'] ** 2
X_test_v2['Income_Education'] = X_test_v2['Income'] * X_test_v2['Education']

# Step 4: Retrain with engineered features
model_v2 = train_model(X_train_v2, y_train)
explainer_v2 = shap.TreeExplainer(model_v2)
shap_values_v2 = explainer_v2(X_test_v2)

# Step 5: Compare feature importance
shap.plots.bar({
    "Baseline": shap_values_v1,
    "With Engineered Features": shap_values_v2
})

# Step 6: Validate improvement
print(f"V1 Score: {model_v1.score(X_test_v1, y_test):.4f}")
print(f"V2 Score: {model_v2.score(X_test_v2, y_test):.4f}")
```

**Feature Engineering Insights from SHAP**:
- Strong nonlinear patterns → try transformations (log, sqrt, polynomial)
- Color-coded interactions in scatter plots → create interaction terms
- Redundant features in clustering → remove or combine
- Unexpected importance → investigate for data quality issues
## Workflow 4: Model Comparison and Selection

**Use Case**: Comparing multiple models to select the best interpretable model

```python
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# Step 1: Train multiple models
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000).fit(X_train, y_train),
    'Random Forest': RandomForestClassifier(n_estimators=100).fit(X_train, y_train),
    'XGBoost': xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)
}

# Step 2: Compute SHAP values for each model
shap_values_dict = {}
for name, model in models.items():
    if name == 'Logistic Regression':
        explainer = shap.LinearExplainer(model, X_train)
    else:
        explainer = shap.TreeExplainer(model)
    shap_values_dict[name] = explainer(X_test)

# Step 3: Compare global feature importance
shap.plots.bar(shap_values_dict)

# Step 4: Compare model scores
for name, model in models.items():
    score = model.score(X_test, y_test)
    print(f"{name}: {score:.4f}")

# Step 5: Check consistency of feature importance
for feature in X_test.columns[:5]:  # First 5 features
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for idx, (name, shap_vals) in enumerate(shap_values_dict.items()):
        plt.sca(axes[idx])
        shap.plots.scatter(shap_vals[:, feature], show=False)
        plt.title(f"{name} - {feature}")
    plt.tight_layout()
    plt.show()

# Step 6: Analyze specific predictions across models
sample_idx = 0
for name, shap_vals in shap_values_dict.items():
    print(f"\n=== {name} ===")
    shap.plots.waterfall(shap_vals[sample_idx])

# Step 7: Decide based on:
# - Accuracy / performance
# - Interpretability (consistent feature importance)
# - Deployment constraints
# - Stakeholder requirements
```

**Model Selection Criteria**:
- **Accuracy vs. Interpretability**: Sometimes simpler models with SHAP are preferable
- **Feature Consistency**: Models that agree on feature importance are more trustworthy (a quantitative check follows below)
- **Explanation Quality**: Clear, actionable explanations
- **Computational Cost**: TreeExplainer is faster than KernelExplainer
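To put a number on feature consistency, one option is to rank features by mean |SHAP| for each model and compare the rankings with a Spearman correlation. This sketch continues from `shap_values_dict` above; the 0.7 cutoff is purely illustrative:

```python
import numpy as np
from itertools import combinations
from scipy.stats import spearmanr

# Mean |SHAP| per feature, per model (rows align with X_test.columns).
importances = {}
for name, vals in shap_values_dict.items():
    imp = np.abs(vals.values).mean(axis=0)
    if imp.ndim > 1:          # multi-output explanations: average over outputs
        imp = imp.mean(axis=-1)
    importances[name] = imp

# Pairwise rank agreement between models.
for (name_a, imp_a), (name_b, imp_b) in combinations(importances.items(), 2):
    rho, _ = spearmanr(imp_a, imp_b)
    note = "" if rho >= 0.7 else "  <-- rankings diverge, inspect further"
    print(f"{name_a} vs {name_b}: Spearman rho = {rho:.2f}{note}")
```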
## Workflow 5: Fairness and Bias Analysis

**Use Case**: Detecting and analyzing model bias across demographic groups

```python
import numpy as np

# Step 1: Identify protected attributes
protected_attr = 'Gender'  # or 'Race', 'Age_Group', etc.

# Step 2: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Step 3: Compare feature importance across groups
groups = X_test[protected_attr].unique()
cohorts = {
    f"{protected_attr}={group}": shap_values[X_test[protected_attr] == group]
    for group in groups
}
shap.plots.bar(cohorts)

# Step 4: Check whether the protected attribute has high SHAP importance
# (it should be low/zero for fair models)
protected_importance = np.abs(shap_values[:, protected_attr].values).mean()
print(f"{protected_attr} mean |SHAP|: {protected_importance:.4f}")

# Step 5: Analyze predictions for each group
for group in groups:
    mask = X_test[protected_attr] == group
    group_shap = shap_values[mask]

    print(f"\n=== {protected_attr} = {group} ===")
    print(f"Sample size: {mask.sum()}")
    print(f"Positive prediction rate: {(model.predict(X_test[mask]) == 1).mean():.2%}")

    # Visualize
    shap.plots.beeswarm(group_shap, max_display=10)

# Step 6: Check for proxy features
# Features correlated with the protected attribute that shouldn't have high importance
# Example: 'Zip_Code' might be a proxy for race
proxy_features = ['Zip_Code', 'Last_Name_Prefix']  # Domain-specific
for feature in proxy_features:
    if feature in X_test.columns:
        importance = np.abs(shap_values[:, feature].values).mean()
        print(f"Potential proxy '{feature}' importance: {importance:.4f}")

# Step 7: Mitigation strategies if bias is found:
# - Remove the protected attribute and its proxies
# - Add fairness constraints during training
# - Post-process predictions to equalize outcomes
# - Use a different model architecture
```

**Fairness Metrics to Check** (a minimal sketch of the first two follows this list):
- **Demographic Parity**: Similar positive prediction rates across groups
- **Equal Opportunity**: Similar true positive rates across groups
- **Feature Importance Parity**: Similar feature rankings across groups
- **Protected Attribute Importance**: Should be minimal
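For reference, the first two metrics can be computed directly from predictions and labels, without SHAP. A minimal sketch, reusing `X_test`, `y_test`, `model`, and `protected_attr` from the workflow above (and assuming `y_test` is a pandas Series aligned with `X_test`):

```python
import pandas as pd

preds = pd.Series(model.predict(X_test), index=X_test.index)

for group, frame in X_test.groupby(protected_attr):
    group_preds = preds.loc[frame.index]
    group_truth = y_test.loc[frame.index]

    positive_rate = (group_preds == 1).mean()            # demographic parity
    tpr = (group_preds[group_truth == 1] == 1).mean()    # equal opportunity
    print(f"{protected_attr}={group}: positive rate={positive_rate:.2%}, TPR={tpr:.2%}")
```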
## Workflow 6: Deep Learning Model Explanation

**Use Case**: Explaining neural network predictions with DeepExplainer

```python
import numpy as np
import shap
import tensorflow as tf

# Step 1: Load or build the neural network
model = tf.keras.models.load_model('my_model.h5')

# Step 2: Select a background dataset
# Use a subset (100-1000 samples) from the training data
background = X_train[:100]

# Step 3: Create DeepExplainer
explainer = shap.DeepExplainer(model, background)

# Step 4: Compute SHAP values (may take time)
# Explain a subset of the test data
test_subset = X_test[:50]
shap_values = explainer.shap_values(test_subset)

# Step 5: Handle multi-output models
# For binary classification, shap_values is a list [class_0_values, class_1_values]
# For regression, it's a single array
if isinstance(shap_values, list):
    # Focus on the positive class; broadcast the scalar base value
    # so per-row indexing like shap_exp[0] works
    shap_exp = shap.Explanation(
        values=shap_values[1],
        base_values=np.repeat(explainer.expected_value[1], len(test_subset)),
        data=test_subset
    )
else:
    shap_exp = shap.Explanation(
        values=shap_values,
        base_values=np.repeat(explainer.expected_value, len(test_subset)),
        data=test_subset
    )

# Step 6: Visualize
shap.plots.beeswarm(shap_exp)
shap.plots.waterfall(shap_exp[0])

# Step 7: For image/text data, use specialized plots
# Image: shap.image_plot
# Text: shap.plots.text (for transformers)
```

**Deep Learning Considerations** (see the batching sketch below for the last point):
- Background dataset size affects both accuracy and speed
- Multi-output handling (classification vs. regression)
- Specialized plots for image/text data
- Computational cost (consider GPU acceleration and batching)
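Because DeepExplainer cost grows with both the background size and the number of rows explained, one pragmatic pattern is to explain the test set in chunks and concatenate the results. A minimal sketch continuing from the `explainer` above; the chunk size is arbitrary, and the axis handling is a best-effort guess for the list-vs-array return shapes described in Step 5:

```python
import numpy as np

def explain_in_batches(explainer, X, batch_size=64):
    """Compute SHAP values chunk by chunk to bound peak memory and latency."""
    pieces = []
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]
        pieces.append(np.asarray(explainer.shap_values(batch)))
    # Concatenate along the sample axis: axis 0 for a single-output array,
    # axis 1 when the per-class outputs were stacked into a leading axis.
    axis = 0 if pieces[0].ndim == 2 else 1
    return np.concatenate(pieces, axis=axis)

all_shap = explain_in_batches(explainer, X_test)
```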
## Workflow 7: Production Deployment

**Use Case**: Integrating SHAP explanations into production systems

```python
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap

# Step 1: Train and save the model
model = train_model(X_train, y_train)
joblib.dump(model, 'model.pkl')

# Step 2: Create and save the explainer
explainer = shap.TreeExplainer(model)
joblib.dump(explainer, 'explainer.pkl')

# Step 3: Create an explanation service
class ExplanationService:
    def __init__(self, model_path, explainer_path):
        self.model = joblib.load(model_path)
        self.explainer = joblib.load(explainer_path)

    def predict_with_explanation(self, X):
        """Return predictions together with per-feature explanations."""
        # Prediction
        prediction = self.model.predict(X)

        # SHAP values
        shap_values = self.explainer(X)

        # Format explanations
        explanations = []
        for i in range(len(X)):
            exp = {
                'prediction': prediction[i],
                'base_value': shap_values.base_values[i],
                'shap_values': dict(zip(X.columns, shap_values.values[i])),
                'feature_values': X.iloc[i].to_dict()
            }
            explanations.append(exp)

        return explanations

    def get_top_features(self, X, n=5):
        """Return the top-n features (by |SHAP|) for each prediction."""
        shap_values = self.explainer(X)

        top_features = []
        for i in range(len(X)):
            # Absolute SHAP values
            abs_shap = np.abs(shap_values.values[i])

            # Sort and take the top n
            top_indices = abs_shap.argsort()[-n:][::-1]
            top_feature_names = X.columns[top_indices].tolist()
            top_shap_values = shap_values.values[i][top_indices].tolist()

            top_features.append({
                'features': top_feature_names,
                'shap_values': top_shap_values
            })

        return top_features

# Step 4: Usage in an API
service = ExplanationService('model.pkl', 'explainer.pkl')

# Example API endpoint
def predict_endpoint(input_data):
    X = pd.DataFrame([input_data])
    explanations = service.predict_with_explanation(X)
    return {
        'prediction': explanations[0]['prediction'],
        'explanation': explanations[0]
    }

# Step 5: Generate static explanations for batch predictions
def batch_explain_and_save(X_batch, output_dir):
    shap_values = explainer(X_batch)

    # Save the global plot
    shap.plots.beeswarm(shap_values, show=False)
    plt.savefig(f'{output_dir}/global_importance.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Save individual explanations
    for i in range(min(100, len(X_batch))):  # First 100
        shap.plots.waterfall(shap_values[i], show=False)
        plt.savefig(f'{output_dir}/explanation_{i}.png', dpi=300, bbox_inches='tight')
        plt.close()
```

**Production Best Practices**:
- Cache explainers to avoid recomputation
- Batch explanations when possible
- Limit explanation complexity (top N features)
- Monitor explanation latency
- Version explainers alongside models
- Consider pre-computing explanations for common inputs (see the sketch below)
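One lightweight way to implement the last point is to memoize explanations keyed by a hash of the input payload, so repeated requests skip SHAP entirely. A minimal in-process sketch reusing `service` and pandas from the block above; the hashing scheme and dictionary cache are illustrative, and a real deployment would more likely use an external cache such as Redis:

```python
import hashlib
import json

_explanation_cache = {}

def cached_explanation(service, input_data: dict):
    """Return a cached explanation when the exact same input was seen before."""
    key = hashlib.sha256(json.dumps(input_data, sort_keys=True).encode()).hexdigest()
    if key not in _explanation_cache:
        X = pd.DataFrame([input_data])
        _explanation_cache[key] = service.predict_with_explanation(X)[0]
    return _explanation_cache[key]
```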
## Workflow 8: Time Series Model Explanation

**Use Case**: Explaining time series forecasting models

```python
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Step 1: Prepare data with time-based features
# Example: predicting the next day's sales
df['DayOfWeek'] = df['Date'].dt.dayofweek
df['Month'] = df['Date'].dt.month
df['Lag_1'] = df['Sales'].shift(1)
df['Lag_7'] = df['Sales'].shift(7)
df['Rolling_Mean_7'] = df['Sales'].rolling(7).mean()

# Step 2: Train model
# Keep the chronological order (shuffle=False) so the model never trains on the future
features = ['DayOfWeek', 'Month', 'Lag_1', 'Lag_7', 'Rolling_Mean_7']
X_train, X_test, y_train, y_test = train_test_split(
    df[features], df['Sales'], shuffle=False
)
model = xgb.XGBRegressor().fit(X_train, y_train)

# Step 3: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Step 4: Analyze temporal patterns
# Which features drive predictions at different times?
shap.plots.beeswarm(shap_values)

# Step 5: Check lagged feature importance
# Lag features should have high importance for time series
lag_features = ['Lag_1', 'Lag_7', 'Rolling_Mean_7']
for feature in lag_features:
    shap.plots.scatter(shap_values[:, feature])

# Step 6: Explain specific predictions
# E.g., why was Monday's forecast so different?
monday_mask = (X_test['DayOfWeek'] == 0).values
shap.plots.waterfall(shap_values[monday_mask][0])

# Step 7: Validate the model's understanding of seasonality
shap.plots.scatter(shap_values[:, 'Month'])
```

**Time Series Considerations**:
- Lagged features and their importance
- Rolling-statistics interpretation
- Seasonal patterns in SHAP values
- Avoiding data leakage in feature engineering (and in the train/test split, as above)
## Common Pitfalls and Solutions

### Pitfall 1: Wrong Explainer Choice
**Problem**: Using KernelExplainer for tree models (slow and unnecessary)
**Solution**: Always use TreeExplainer for tree-based models

### Pitfall 2: Insufficient Background Data
**Problem**: DeepExplainer/KernelExplainer with too few background samples
**Solution**: Use 100-1000 representative samples

### Pitfall 3: Misinterpreting Log-Odds
**Problem**: Confusion about units (probability vs. log-odds)
**Solution**: Check the model output type; use `link="logit"` when needed (see the sketch below)
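To make the units concrete: for an XGBoost binary classifier, TreeExplainer's raw SHAP output is in log-odds, so the base value plus the SHAP sum must be passed through the logistic function before it matches `predict_proba`. A self-contained sketch on synthetic data (dataset and model settings are illustrative):

```python
import numpy as np
import shap
import xgboost as xgb
from scipy.special import expit  # logistic function (inverse of logit)
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=300, n_features=5, random_state=0)
clf = xgb.XGBClassifier(n_estimators=50, max_depth=3).fit(X, y)

explainer = shap.TreeExplainer(clf)
sv = explainer(X)

# Raw SHAP output is in log-odds; convert to probability space to compare.
log_odds = sv.base_values + sv.values.sum(axis=1)
print(np.allclose(expit(log_odds), clf.predict_proba(X)[:, 1], atol=1e-3))  # expected: True
```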
### Pitfall 4: Ignoring Feature Correlations
**Problem**: Interpreting features as independent when they're correlated
**Solution**: Use feature clustering; understand domain relationships

### Pitfall 5: Overfitting to Explanations
**Problem**: Feature engineering based solely on SHAP without validation
**Solution**: Always validate improvements with cross-validation

### Pitfall 6: Data Leakage Undetected
**Problem**: Not noticing unexpected feature importance that indicates leakage
**Solution**: Validate SHAP results against domain knowledge

### Pitfall 7: Computational Constraints Ignored
**Problem**: Computing SHAP for an entire large dataset
**Solution**: Use sampling, batching, or subset analysis
## Advanced Techniques

### Technique 1: SHAP Interaction Values
Capture pairwise feature interactions:
```python
explainer = shap.TreeExplainer(model)
shap_interaction_values = explainer.shap_interaction_values(X_test)

# Analyze a specific interaction (indices are illustrative)
feature1_idx = 0
feature2_idx = 3
interaction = shap_interaction_values[:, feature1_idx, feature2_idx]
print(f"Interaction strength: {np.abs(interaction).mean():.4f}")
```

### Technique 2: Partial Dependence with SHAP
Combine partial dependence plots with SHAP:
```python
from sklearn.inspection import partial_dependence

# SHAP dependence
shap.plots.scatter(shap_values[:, "Feature1"])

# Partial dependence (model-agnostic; 'grid_values' is the key in recent scikit-learn versions)
pd_result = partial_dependence(model, X_test, features=["Feature1"])
plt.plot(pd_result['grid_values'][0], pd_result['average'][0])
```

### Technique 3: Conditional Expectations
Analyze SHAP values conditioned on other features:
```python
# High-income group
high_income = X_test['Income'] > X_test['Income'].median()
shap.plots.beeswarm(shap_values[high_income])

# Low-income group
low_income = X_test['Income'] <= X_test['Income'].median()
shap.plots.beeswarm(shap_values[low_income])
```

### Technique 4: Feature Clustering for Redundancy
```python
# Create a hierarchical clustering of the features
clustering = shap.utils.hclust(X_train, y_train)

# Visualize with clustering
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.5)

# Identify redundant features to remove:
# features merged at distance < 0.1 are highly redundant
```
## Integration with MLOps

**Experiment Tracking**:
```python
import mlflow
import numpy as np
import matplotlib.pyplot as plt

# Log SHAP artifacts and metrics for each run
with mlflow.start_run():
    # Train model
    model = train_model(X_train, y_train)

    # Compute SHAP
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)

    # Log plots
    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()

    # Log feature importance as metrics
    mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
    for feature, importance in zip(X_test.columns, mean_abs_shap):
        mlflow.log_metric(f"shap_{feature}", importance)
```

**Model Monitoring**:
```python
# Track SHAP distribution drift over time
def compute_shap_summary(shap_values):
    return {
        'mean': shap_values.values.mean(axis=0),
        'std': shap_values.values.std(axis=0),
        'percentiles': np.percentile(shap_values.values, [25, 50, 75], axis=0)
    }

# Compute the baseline from training-time explanations
baseline_summary = compute_shap_summary(shap_values_train)

# Summarize explanations computed on production data
production_summary = compute_shap_summary(shap_values_production)

# Detect drift (the threshold is problem-specific; tune it on historical data)
threshold = 0.1
drift_detected = np.abs(
    production_summary['mean'] - baseline_summary['mean']
) > threshold
```
This comprehensive workflows document covers the most common and advanced use cases for SHAP in practice.