Initial commit
skills/shap/references/explainers.md
# SHAP Explainers Reference

This document describes the SHAP explainer classes, their parameters and methods, and when to use each type.

## Overview

SHAP provides specialized explainers for different model types, each optimized for a specific architecture. The general `shap.Explainer` class automatically selects an appropriate algorithm based on the model type.

## Core Explainer Classes

### shap.Explainer (Auto-selector)

**Purpose**: Explains any machine learning model or Python function with Shapley values, automatically selecting the most appropriate explainer algorithm.

**Constructor Parameters**:
- `model`: The model to explain (function or model object)
- `masker`: Background data or masker object used to hide ("mask") features
- `algorithm`: Optional override that forces a specific explainer type
- `output_names`: Names for model outputs
- `feature_names`: Names for input features

**When to Use**: Default choice when unsure which explainer to use; it selects an algorithm based on the model type.

### TreeExplainer

**Purpose**: Fast and exact SHAP value computation for tree-based ensemble models using the Tree SHAP algorithm.

**Constructor Parameters**:
- `model`: Tree-based model (XGBoost, LightGBM, CatBoost, PySpark, or scikit-learn trees)
- `data`: Background dataset used to integrate out features (optional with `"tree_path_dependent"`)
- `feature_perturbation`: How to handle dependent features
  - `"interventional"`: Requires background data; follows causal inference rules
  - `"tree_path_dependent"`: No background data needed; uses the number of training examples that reached each leaf as the background distribution
  - `"auto"`: Defaults to `"interventional"` if data is provided, otherwise `"tree_path_dependent"`
- `model_output`: What model output to explain
  - `"raw"`: Standard model output (default)
  - `"probability"`: Probability-transformed output
  - `"log_loss"`: Natural log of the model's loss function
  - Custom method names such as `"predict_proba"`
- `feature_names`: Optional list of feature names

**Supported Models**:
- XGBoost (xgboost.XGBClassifier, xgboost.XGBRegressor, xgboost.Booster)
- LightGBM (lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, lightgbm.Booster)
- CatBoost (catboost.CatBoostClassifier, catboost.CatBoostRegressor)
- PySpark MLlib tree models
- scikit-learn (DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor, GradientBoostingClassifier, GradientBoostingRegressor)

**Key Methods**:
- `shap_values(X)`: Computes SHAP values for the samples in `X`; returns arrays in which each row holds the feature attributions for one sample
- `shap_interaction_values(X)`: Estimates interaction effects between feature pairs; returns one matrix per sample with main effects on the diagonal and pairwise interactions off the diagonal
- `explain_row(row)`: Explains an individual row with detailed attribution information

**When to Use**:
- Primary choice for all tree-based models
- When exact SHAP values are needed (not approximations)
- When computational speed is important for large datasets
- For models like random forests, gradient boosting, or XGBoost

**Example**:
```python
import shap
import xgboost

# Train model (X_train, y_train, and X_test are assumed to be defined)
model = xgboost.XGBClassifier().fit(X_train, y_train)

# Create explainer
explainer = shap.TreeExplainer(model)

# Compute SHAP values
shap_values = explainer.shap_values(X_test)

# Compute interaction values
shap_interaction = explainer.shap_interaction_values(X_test)
```

### DeepExplainer

**Purpose**: Approximates SHAP values for deep learning models using an enhanced version of the DeepLIFT algorithm.

**Constructor Parameters**:
- `model`: Framework-dependent specification
  - **TensorFlow**: Tuple of (input_tensor, output_tensor) where the output is single-dimensional
  - **PyTorch**: `nn.Module` object or tuple of `(model, layer)` for layer-specific explanations
- `data`: Background dataset used to integrate out features
  - **TensorFlow**: numpy arrays or pandas DataFrames
  - **PyTorch**: torch tensors
  - **Recommended size**: 100-1000 samples (not the full training set) to balance accuracy and computational cost
- `session` (TensorFlow only): Optional session object; auto-detected if None
- `learning_phase_flags`: Custom learning phase tensors for handling batch norm/dropout during inference

**Supported Frameworks**:
- **TensorFlow**: Full support including Keras models
- **PyTorch**: Complete integration with nn.Module architecture

**Key Methods**:
- `shap_values(X)`: Returns approximate SHAP values for the model applied to data X
- `explain_row(row)`: Explains single rows with attribution values and expected outputs
- `save(file)` / `load(file)`: Serialization support for explainer objects
- `supports_model_with_masker(model, masker)`: Compatibility checker for model types

**When to Use**:
- For deep neural networks in TensorFlow or PyTorch
- When working with convolutional neural networks (CNNs)
- For recurrent neural networks (RNNs) and transformers, provided the operations they use are supported
- When model-specific explanation is needed for deep learning architectures

**Key Design Feature**:
The sampling error of the expectation estimates shrinks roughly as 1/√N, where N is the number of background samples, so a larger background set trades extra computation for accuracy.

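
The 1/√N scaling can be illustrated without any deep learning framework. The sketch below is a plain-Python simulation, not part of SHAP: `spread_of_mean` is a hypothetical helper that draws standard-normal "background outputs" and measures how much their mean fluctuates across repeated draws. Quadrupling the background size should roughly halve the spread.

```python
import random
import statistics


def spread_of_mean(n_background, n_trials=1000, seed=0):
    """Standard deviation of the sample mean over repeated background draws."""
    rng = random.Random(seed)
    means = []
    for _ in range(n_trials):
        draws = [rng.gauss(0.0, 1.0) for _ in range(n_background)]
        means.append(sum(draws) / n_background)
    return statistics.stdev(means)


spread_100 = spread_of_mean(100)
spread_400 = spread_of_mean(400)

# Expected to be close to sqrt(400 / 100) = 2
ratio = spread_100 / spread_400
print(f"spread with 100 samples: {spread_100:.4f}")
print(f"spread with 400 samples: {spread_400:.4f}")
print(f"ratio: {ratio:.2f}")
```

In practice this suggests diminishing returns: going from 100 to 1000 background samples costs ten times the computation but reduces the sampling error only by a factor of about three.
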
**Example**:
```python
import shap
import tensorflow as tf

# Assume model is a Keras model
model = tf.keras.models.load_model('my_model.h5')

# Select background samples (subset of training data; X_train is assumed to be defined)
background = X_train[:100]

# Create explainer
explainer = shap.DeepExplainer(model, background)

# Compute SHAP values
shap_values = explainer.shap_values(X_test[:10])
```

### KernelExplainer

**Purpose**: Model-agnostic SHAP value computation using the Kernel SHAP method with weighted linear regression.

**Constructor Parameters**:
- `model`: Function or model object that takes a matrix of samples and returns model outputs
- `data`: Background dataset (numpy array, pandas DataFrame, or sparse matrix) used to simulate missing features
- `feature_names`: Optional list of feature names; automatically derived from DataFrame column names if available
- `link`: Link function connecting the feature attributions to the model output
  - `"identity"`: Attributions sum directly in the model's output space (default)
  - `"logit"`: For probability outputs; attributions sum in log-odds space

**Key Methods**:
- `shap_values(X, **kwargs)`: Calculates SHAP values for sample predictions
  - `nsamples`: Number of model evaluations per prediction (`"auto"` or an integer); higher values reduce variance
  - `l1_reg`: Feature selection regularization (`"num_features(int)"`, `"aic"`, `"bic"`, or a float)
  - Returns arrays where each row sums to the difference between the model output and the expected value
- `explain_row(row)`: Explains individual predictions with attribution values and expected values
- `save(file)` / `load(file)`: Persist and restore explainer objects

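
The additivity property noted above (each row of SHAP values sums to the model output minus the expected value) can be checked on a toy function by computing Shapley values exactly. The sketch below is plain Python and is not SHAP's implementation: it enumerates every feature coalition, which is the quantity Kernel SHAP estimates through sampling and weighted regression. The names `exact_shapley` and `toy_model` are hypothetical, and "missing" features are filled in from each background row and averaged.

```python
from itertools import combinations
from math import factorial


def exact_shapley(f, x, background):
    """Brute-force Shapley values of f at x against a background set."""
    n = len(x)

    def value(subset):
        # Expected output when only the features in `subset` take x's values.
        total = 0.0
        for row in background:
            z = [x[i] if i in subset else row[i] for i in range(n)]
            total += f(z)
        return total / len(background)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi, value(set())


def toy_model(z):
    # One additive term and one pairwise interaction
    return 2.0 * z[0] + z[1] * z[2]


background = [[0.0, 1.0, 1.0], [2.0, 3.0, 1.0]]
x = [3.0, 2.0, 4.0]
phi, base_value = exact_shapley(toy_model, x, background)

print("base value:", base_value)
print("shapley values:", phi)
print("base + sum(phi):", base_value + sum(phi), "model output:", toy_model(x))
```

Enumeration costs on the order of 2^n model evaluations per feature, which is why sampling-based estimators and the `nsamples` budget exist.
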
**When to Use**:
- For black-box models where specialized explainers aren't available
- When working with custom prediction functions
- For any model type (neural networks, SVMs, ensemble methods, etc.)
- When model-agnostic explanations are needed
- **Note**: Slower than specialized explainers; use only when no specialized option exists

**Example**:
```python
import shap
from sklearn.svm import SVC

# Train model (X_train, y_train, and X_test are assumed to be defined)
model = SVC(probability=True).fit(X_train, y_train)

# Create prediction function (probability of the positive class)
predict_fn = lambda x: model.predict_proba(x)[:, 1]

# Select background samples
background = shap.sample(X_train, 100)

# Create explainer
explainer = shap.KernelExplainer(predict_fn, background)

# Compute SHAP values (may be slow)
shap_values = explainer.shap_values(X_test[:10])
```

### LinearExplainer

**Purpose**: Specialized explainer for linear models that can account for feature correlations.

**Constructor Parameters**:
- `model`: Linear model or tuple of (coefficients, intercept)
- `masker`: Background data used to estimate feature means and correlations
- `feature_perturbation`: How to handle feature correlations
  - `"interventional"`: Assumes feature independence
  - `"correlation_dependent"`: Accounts for feature correlations

**Supported Models**:
- scikit-learn linear models (LinearRegression, LogisticRegression, Ridge, Lasso, ElasticNet)
- Custom linear models with coefficients and intercept

**When to Use**:
- For linear regression and logistic regression models
- When feature correlations are important to explanation accuracy
- When extremely fast explanations are needed
- For GLMs and other linear model types

**Example**:
```python
import shap
from sklearn.linear_model import LogisticRegression

# Train model (X_train, y_train, and X_test are assumed to be defined)
model = LogisticRegression().fit(X_train, y_train)

# Create explainer
explainer = shap.LinearExplainer(model, X_train)

# Compute SHAP values
shap_values = explainer.shap_values(X_test)
```

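
Under the feature-independence assumption of the `"interventional"` option, the SHAP value of feature i in a linear model has a closed form: the coefficient times the feature's deviation from its background mean. The sketch below is a plain-Python illustration of that formula, not the library's code; `linear_shap` and the example numbers are made up for illustration.

```python
def linear_shap(weights, intercept, x, background):
    """Closed-form SHAP values for a linear model, assuming independent features."""
    n = len(weights)
    means = [sum(row[i] for row in background) / len(background) for i in range(n)]
    # phi_i = w_i * (x_i - E[x_i])
    phi = [w * (xi - m) for w, xi, m in zip(weights, x, means)]
    # The base value is the model output at the background mean
    base_value = intercept + sum(w * m for w, m in zip(weights, means))
    return phi, base_value


weights = [0.5, -2.0, 1.0]
intercept = 0.25
background = [[1.0, 0.0, 2.0], [3.0, 1.0, 4.0]]
x = [4.0, 1.0, 2.0]

phi, base_value = linear_shap(weights, intercept, x, background)
prediction = intercept + sum(w * xi for w, xi in zip(weights, x))

print("shap values:", phi)
print("base value:", base_value)
print("base + sum(phi):", base_value + sum(phi), "prediction:", prediction)
```

Because no sampling is involved, the cost is a single pass over the features, which is why linear explanations are so fast.
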
### GradientExplainer

**Purpose**: Uses expected gradients to approximate SHAP values for neural networks.

**Constructor Parameters**:
- `model`: Deep learning model (TensorFlow or PyTorch)
- `data`: Background samples for integration
- `batch_size`: Batch size for gradient computation
- `local_smoothing`: Amount of noise to add for smoothing (default 0)

**When to Use**:
- As an alternative to DeepExplainer for neural networks
- When gradient-based explanations are preferred
- For differentiable models where gradient information is available

**Example**:
```python
import shap
import torch

# Load a pickled PyTorch model. Recent PyTorch versions need weights_only=False
# to unpickle a full nn.Module; only do this for files you trust.
model = torch.load('model.pt', weights_only=False)
model.eval()

# Select background samples (PyTorch models expect torch tensors;
# X_train and X_test are assumed to be array-like)
background = torch.as_tensor(X_train[:100], dtype=torch.float32)

# Create explainer
explainer = shap.GradientExplainer(model, background)

# Compute SHAP values
inputs = torch.as_tensor(X_test[:10], dtype=torch.float32)
shap_values = explainer.shap_values(inputs)
```

### PermutationExplainer

**Purpose**: Approximates Shapley values by iterating through permutations of inputs.

**Constructor Parameters**:
- `model`: Prediction function
- `masker`: Background data or masker object
- `max_evals`: Maximum number of model evaluations per sample

**When to Use**:
- When accurate, model-agnostic Shapley value estimates are needed but specialized explainers aren't available
- For small feature sets where permutation is tractable
- As a more accurate alternative to KernelExplainer (but slower)

**Example**:
```python
import shap

# Create explainer (model, X_train, and X_test are assumed to be defined)
explainer = shap.PermutationExplainer(model.predict, X_train)

# Compute SHAP values
shap_values = explainer.shap_values(X_test[:10])
```

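
To make the permutation idea concrete, the sketch below approximates Shapley values in plain Python by walking random feature orderings forward and in reverse, switching one feature at a time from a baseline to the explained sample. It is an illustration under simplifying assumptions, not the library's implementation: `permutation_shap` is a hypothetical helper, and a single baseline row stands in for a full background set or masker.

```python
import random


def permutation_shap(f, x, baseline, n_permutations=50, seed=0):
    """Average marginal contributions over random feature orderings."""
    rng = random.Random(seed)
    n = len(x)
    phi = [0.0] * n
    for _ in range(n_permutations):
        order = list(range(n))
        rng.shuffle(order)
        # Use each ordering forward and reversed (antithetic sampling)
        for perm in (order, order[::-1]):
            z = list(baseline)
            previous = f(z)
            for i in perm:
                z[i] = x[i]  # reveal feature i
                current = f(z)
                phi[i] += current - previous
                previous = current
    return [p / (2 * n_permutations) for p in phi]


def toy_model(z):
    # One additive term and one pairwise interaction
    return 2.0 * z[0] + z[1] * z[2]


x = [3.0, 2.0, 4.0]
baseline = [1.0, 1.0, 1.0]
phi = permutation_shap(toy_model, x, baseline)

print("shap values:", phi)
print("sum:", sum(phi), "f(x) - f(baseline):", toy_model(x) - toy_model(baseline))
```

Each ordering telescopes from `f(baseline)` to `f(x)`, so the attributions always sum to the output difference regardless of how many permutations are used. For this toy model, which has only a pairwise interaction, pairing each ordering with its reverse already recovers the exact values.
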
## Explainer Selection Guide

**Decision Tree for Choosing an Explainer**:

1. **Is your model tree-based?** (XGBoost, LightGBM, CatBoost, Random Forest, etc.)
   - Yes → Use `TreeExplainer` (fast and exact)
   - No → Continue to step 2

2. **Is your model a deep neural network?** (TensorFlow, PyTorch, Keras)
   - Yes → Use `DeepExplainer` or `GradientExplainer`
   - No → Continue to step 3

3. **Is your model linear?** (Linear/Logistic Regression, GLMs)
   - Yes → Use `LinearExplainer` (extremely fast)
   - No → Continue to step 4

4. **Do you need model-agnostic explanations?**
   - Yes → Use `KernelExplainer` (slower but works with any model)
   - If computational budget allows and high accuracy is needed → Use `PermutationExplainer`

5. **Unsure or want automatic selection?**
   - Use `shap.Explainer` (auto-selects best algorithm)

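
The steps above can be written down as a small lookup function. This is only a sketch of the decision order: `choose_explainer` and its `model_family` labels are hypothetical and supplied by the caller, and SHAP itself does not expose such a function.

```python
def choose_explainer(model_family, high_accuracy_budget=False):
    """Return the explainer name suggested by the selection steps.

    model_family: "tree", "deep", "linear", or anything else for a
    black-box model (labels invented for this sketch).
    """
    if model_family == "tree":
        return "TreeExplainer"
    if model_family == "deep":
        return "DeepExplainer or GradientExplainer"
    if model_family == "linear":
        return "LinearExplainer"
    # Model-agnostic fallback (step 4)
    if high_accuracy_budget:
        return "PermutationExplainer"
    return "KernelExplainer"


for family in ("tree", "deep", "linear", "svm"):
    print(family, "->", choose_explainer(family))
```
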
## Common Parameters Across Explainers

**Background Data / Masker**:
- Purpose: Represents the "typical" input to establish baseline expectations
- Size recommendations: 50-1000 samples (more for complex models)
- Selection: Random sample from training data or kmeans-selected representatives

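
Random selection of a background set can be sketched in a few lines. The helper below is a plain-Python stand-in with a hypothetical name, `pick_background`; it operates on a list of rows and uses a fixed seed so that explanations are reproducible between runs.

```python
import random


def pick_background(rows, size=100, seed=0):
    """Return a reproducible random subset of rows to use as background data."""
    if len(rows) <= size:
        return list(rows)
    rng = random.Random(seed)
    return rng.sample(rows, size)


# Synthetic stand-in for a training set with two features
training_rows = [[float(i), float(i % 7)] for i in range(5000)]
background = pick_background(training_rows, size=100)

print(len(background), "background rows selected from", len(training_rows))
```

Fixing the seed matters because SHAP values are defined relative to the background: a different sample shifts the base value and, slightly, every attribution.
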
**Feature Names**:
- Automatically extracted from pandas DataFrames
- Can be manually specified for numpy arrays
- Important for plot interpretability

**Model Output Specification**:
- Raw model output vs. transformed output (probabilities, log-odds)
- Critical for correct interpretation of SHAP values
- Example: For XGBoost classifiers, SHAP explains margin output (log-odds) before logistic transformation

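
The practical consequence of explaining the margin is that attributions add up in log-odds space, and the probability is obtained only after applying the logistic function to the total. The sketch below uses made-up numbers for a single prediction; `base_value` and `shap_values` are illustrative values, not output from a real model.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


# Hypothetical log-odds explanation for one sample
base_value = -1.0               # expected margin over the background
shap_values = [0.8, -0.3, 1.5]  # per-feature contributions in log-odds

margin = base_value + sum(shap_values)
probability = sigmoid(margin)

print(f"margin (log-odds): {margin:.3f}")
print(f"predicted probability: {probability:.3f}")
print(f"baseline probability: {sigmoid(base_value):.3f}")
```

A contribution of +1.5 here means "raises the log-odds by 1.5", not "raises the probability by 1.5"; its effect on the probability depends on where the other features have already moved the margin.
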
## Performance Considerations

**Speed Ranking** (roughly fastest to slowest; actual speed depends on the model and data size):
1. `LinearExplainer` - Nearly instantaneous
2. `TreeExplainer` - Very fast, scales well
3. `DeepExplainer` - Fast for neural networks
4. `GradientExplainer` - Fast for neural networks
5. `KernelExplainer` - Slow, use only when necessary
6. `PermutationExplainer` - Very slow, but accurate for small feature sets

**Memory Considerations**:
- `TreeExplainer`: Low memory overhead
- `DeepExplainer`: Memory proportional to background sample size
- `KernelExplainer`: Can be memory-intensive for large background datasets
- For large datasets: Use batching or sample subsets

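
Batching can be as simple as slicing the input and concatenating the results. The sketch below is framework-free: `iter_batches` and `explain_in_batches` are hypothetical helpers, and `fake_explain` is a stand-in for a real call such as an explainer's `shap_values` method.

```python
def iter_batches(rows, batch_size):
    """Yield consecutive slices of at most batch_size rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def explain_in_batches(explain_fn, rows, batch_size=256):
    """Apply explain_fn to each batch and concatenate the per-row results."""
    results = []
    for batch in iter_batches(rows, batch_size):
        results.extend(explain_fn(batch))
    return results


def fake_explain(batch):
    # Stand-in for a real explainer call; returns one value per row
    return [sum(row) for row in batch]


rows = [[i, i + 1] for i in range(1000)]
attributions = explain_in_batches(fake_explain, rows, batch_size=256)

print(len(attributions), "rows explained in",
      len(list(iter_batches(rows, 256))), "batches")
```
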
## Explainer Output: The Explanation Object

Calling an explainer directly, as in `explainer(X)`, returns a `shap.Explanation` object; the `shap_values(X)` method used in the examples above returns plain numpy arrays instead. An `Explanation` contains:
- `values`: SHAP values (numpy array)
- `base_values`: Expected model output (baseline)
- `data`: Original feature values
- `feature_names`: Names of features

The Explanation object supports:
- Slicing: `explanation[0]` for first sample
- Array operations: Compatible with numpy operations
- Direct plotting: Can be passed to plot functions