# SHAP Visualization Reference

This document provides comprehensive information about all SHAP plotting functions, their parameters, use cases, and best practices for visualizing model explanations.

## Overview

SHAP provides diverse visualization tools for explaining model predictions at both individual and global levels. Each plot type serves specific purposes in understanding feature importance, interactions, and prediction mechanisms.

## Plot Types

### Waterfall Plots

**Purpose**: Display explanations for individual predictions, showing how each feature moves the prediction from the baseline (expected value) toward the final prediction.

**Function**: `shap.plots.waterfall(explanation, max_display=10, show=True)`

**Key Parameters**:
- `explanation`: Single row from an Explanation object (not multiple samples)
- `max_display`: Number of features to show (default: 10); less impactful features collapse into a single "other features" term
- `show`: Whether to display the plot immediately

**Visual Elements**:
- **X-axis**: Shows SHAP values (contribution to prediction)
- **Starting point**: Model's expected value (baseline)
- **Feature contributions**: Red bars (positive) or blue bars (negative) showing how each feature moves the prediction
- **Feature values**: Displayed in gray to the left of feature names
- **Ending point**: Final model prediction

**When to Use**:
- Explaining individual predictions in detail
- Understanding which features drove a specific decision
- Communicating model behavior for single instances (e.g., loan denial, diagnosis)
- Debugging unexpected predictions

**Important Notes**:
- For XGBoost classifiers, predictions are explained in log-odds units (margin output before logistic transformation)
- SHAP values sum to the difference between baseline and final prediction (additivity property)
- Use scatter plots alongside waterfall plots to explore patterns across multiple samples

**Example**:

```python
import shap

# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Plot waterfall for first prediction
shap.plots.waterfall(shap_values[0])

# Show more features
shap.plots.waterfall(shap_values[0], max_display=20)
```
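
As noted above, the baseline and the SHAP values are additive: for any single row they should reconstruct the model's raw output. A minimal verification sketch, assuming the same `model`, `X_test`, and `shap_values` as in the example above and an XGBoost classifier explained in log-odds (margin) space:

```python
import numpy as np

# Additivity check: base value + sum of SHAP values ≈ raw model output for that row.
# For an XGBoost classifier the raw output is the margin (log-odds before the sigmoid).
explanation = shap_values[0]
reconstructed = explanation.base_values + explanation.values.sum()
margin = model.predict(X_test.iloc[:1], output_margin=True)[0]  # XGBoost-specific argument

print(reconstructed, margin)
assert np.isclose(reconstructed, margin, atol=1e-3)
```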
### Beeswarm Plots

**Purpose**: Information-dense summary of how top features impact model output across the entire dataset, combining feature importance with value distributions.

**Function**: `shap.plots.beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), color=None, show=True)`

**Key Parameters**:
- `shap_values`: Explanation object containing multiple samples
- `max_display`: Number of features to display (default: 10)
- `order`: How to rank features
  - `Explanation.abs.mean(0)`: Mean absolute SHAP values (default)
  - `Explanation.abs.max(0)`: Maximum absolute values (highlights outlier impacts)
- `color`: matplotlib colormap; defaults to red-blue scheme
- `show`: Whether to display the plot immediately

**Visual Elements**:
- **Y-axis**: Features ranked by importance
- **X-axis**: SHAP value (impact on model output)
- **Each dot**: Single instance from dataset
- **Dot position (X)**: SHAP value magnitude
- **Dot color**: Original feature value (red = high, blue = low)
- **Dot clustering**: Shows density/distribution of impacts

**When to Use**:
- Summarizing feature importance across entire datasets
- Understanding both average and individual feature impacts
- Identifying feature value patterns and their effects
- Comparing global model behavior across features
- Detecting nonlinear relationships (e.g., higher age → lower income likelihood)

**Practical Variations**:

```python
import matplotlib.pyplot as plt

# Standard beeswarm plot
shap.plots.beeswarm(shap_values)

# Show more features
shap.plots.beeswarm(shap_values, max_display=20)

# Order by maximum absolute values (highlight outliers)
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))

# Plot absolute SHAP values with fixed coloring
shap.plots.beeswarm(shap_values.abs, color="shap_red")

# Custom matplotlib colormap
shap.plots.beeswarm(shap_values, color=plt.cm.viridis)
```
### Bar Plots

**Purpose**: Display feature importance as mean absolute SHAP values, providing clean, simple visualizations of global feature impact.

**Function**: `shap.plots.bar(shap_values, max_display=10, clustering=None, clustering_cutoff=0.5, show=True)`

**Key Parameters**:
- `shap_values`: Explanation object (can be single instance, global, or cohorts)
- `max_display`: Maximum number of features/bars to show
- `clustering`: Optional hierarchical clustering object from `shap.utils.hclust`
- `clustering_cutoff`: Threshold for displaying clustering structure (0-1, default: 0.5)

**Plot Types**:

#### Global Bar Plot

Shows overall feature importance across all samples. Importance is calculated as the mean absolute SHAP value.

```python
# Global feature importance
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.bar(shap_values)
```

#### Local Bar Plot

Displays SHAP values for a single instance with feature values shown in gray.

```python
# Single prediction explanation
shap.plots.bar(shap_values[0])
```

#### Cohort Bar Plot

Compares feature importance across subgroups by passing a dictionary of Explanation objects.

```python
# Compare cohorts
cohorts = {
    "Group A": shap_values[mask_A],
    "Group B": shap_values[mask_B]
}
shap.plots.bar(cohorts)
```
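
Cohorts can also be derived automatically. Recent SHAP versions expose a `cohorts` method on Explanation objects that splits the samples using a small decision tree fit to the SHAP values; a brief sketch (availability depends on your SHAP version):

```python
# Automatically split samples into 2 cohorts and compare mean |SHAP| per group
shap.plots.bar(shap_values.cohorts(2).abs.mean(0))
```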
**Feature Clustering**:
Identifies redundant features using model-based clustering, which often captures redundancy that purely correlation-based methods miss.

```python
# Add feature clustering
clustering = shap.utils.hclust(X_train, y_train)
shap.plots.bar(shap_values, clustering=clustering)

# Adjust clustering display threshold
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.3)
```
**When to Use**:
- Quick overview of global feature importance
- Comparing feature importance across cohorts or models
- Identifying redundant or correlated features
- Clean, simple visualizations for presentations

### Force Plots

**Purpose**: Additive force visualization showing how features push the prediction higher (red) or lower (blue) from the baseline.

**Function**: `shap.plots.force(base_value, shap_values, features, feature_names=None, out_names=None, link="identity", matplotlib=False, show=True)`

**Key Parameters**:
- `base_value`: Expected value (baseline prediction)
- `shap_values`: SHAP values for sample(s)
- `features`: Feature values for sample(s)
- `feature_names`: Optional feature names
- `link`: Transform function ("identity" or "logit")
- `matplotlib`: Use matplotlib backend (default: interactive JavaScript)

**Visual Elements**:
- **Baseline**: Starting prediction (expected value)
- **Red arrows**: Features pushing prediction higher
- **Blue arrows**: Features pushing prediction lower
- **Final value**: Resulting prediction

**Interactive Features** (JavaScript mode):
- Hover for detailed feature information
- Multiple samples create a stacked visualization
- Dropdown selectors reorder the stacked samples (e.g., by similarity or output value) and change what is plotted

**When to Use**:
- Interactive exploration of predictions
- Visualizing multiple predictions simultaneously
- Presentations requiring interactive elements
- Understanding prediction composition at a glance

**Example**:

```python
# Single prediction force plot
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    matplotlib=True
)

# Multiple predictions (interactive)
shap.plots.force(
    shap_values.base_values,
    shap_values.values,
    X_test
)
```
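
For classifiers whose SHAP values are in log-odds units, the `link` parameter can relabel the force plot axis in probability terms. A variation of the example above, assuming log-odds explanations:

```python
# Label the force plot axis in probability space by applying the logistic (logit link)
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    link="logit"
)
```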
### Scatter Plots (Dependence Plots)

**Purpose**: Show the relationship between feature values and their SHAP values, revealing how feature values impact predictions.

**Function**: `shap.plots.scatter(shap_values, color=None, hist=True, alpha=1, show=True)`

**Key Parameters**:
- `shap_values`: Explanation object, can specify feature with subscript (e.g., `shap_values[:, "Age"]`)
- `color`: Feature to use for coloring points (string name or Explanation object)
- `hist`: Show a light gray histogram of the feature's value distribution along the x-axis
- `alpha`: Point transparency (useful for dense plots)

**Visual Elements**:
- **X-axis**: Feature value
- **Y-axis**: SHAP value (impact on prediction)
- **Point color**: Another feature's value (for interaction detection)
- **Histogram**: Distribution of feature values

**When to Use**:
- Understanding feature-prediction relationships
- Detecting nonlinear effects
- Identifying feature interactions
- Validating or discovering patterns in model behavior
- Exploring counterintuitive predictions from waterfall plots

**Interaction Detection**:
Color points by another feature to reveal interactions.

```python
# Basic dependence plot
shap.plots.scatter(shap_values[:, "Age"])

# Color by another feature to show interactions
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education"])

# Multiple features in one plot
shap.plots.scatter(shap_values[:, ["Age", "Education", "Hours-per-week"]])

# Increase transparency for dense data
shap.plots.scatter(shap_values[:, "Age"], alpha=0.5)
```
### Heatmap Plots

**Purpose**: Visualize SHAP values for multiple samples simultaneously, showing feature impacts across instances.

**Function**: `shap.plots.heatmap(shap_values, instance_order=None, feature_values=None, max_display=10, show=True)`

**Key Parameters**:
- `shap_values`: Explanation object
- `instance_order`: How to order instances (defaults to clustering by explanation similarity; pass values such as `shap_values.sum(1)` for custom ordering)
- `feature_values`: Values used to rank and summarize features (defaults to mean absolute SHAP value)
- `max_display`: Maximum features to display

**Visual Elements**:
- **Rows**: Features
- **Columns**: Individual instances/samples
- **Cell color**: SHAP value (red = positive, blue = negative)
- **Intensity**: Magnitude of impact

**When to Use**:
- Comparing explanations across multiple instances
- Identifying patterns in feature impacts
- Understanding which features vary most across predictions
- Detecting subgroups or clusters with similar explanation patterns

**Example**:

```python
# Basic heatmap
shap.plots.heatmap(shap_values)

# Order instances by model output
shap.plots.heatmap(shap_values, instance_order=shap_values.sum(1))

# Show specific subset
shap.plots.heatmap(shap_values[:100])
```
### Violin Plots

**Purpose**: Similar to beeswarm plots but uses violin (kernel density) visualization instead of individual dots.

**Function**: `shap.plots.violin(shap_values, features=None, feature_names=None, max_display=10, show=True)`

**When to Use**:
- Alternative to beeswarm when the dataset is very large
- Emphasizing distribution density over individual points
- Cleaner visualization for presentations

**Example**:

```python
shap.plots.violin(shap_values)
```

### Decision Plots

**Purpose**: Show prediction paths through cumulative SHAP values, particularly useful for multiclass classification.

**Function**: `shap.plots.decision(base_value, shap_values, features, feature_names=None, feature_order="importance", highlight=None, link="identity", show=True)`

**Key Parameters**:
- `base_value`: Expected value
- `shap_values`: SHAP values for samples
- `features`: Feature values
- `feature_order`: How to order features ("importance" or list)
- `highlight`: Indices of samples to highlight
- `link`: Transform function

**When to Use**:
- Multiclass classification explanations
- Understanding cumulative feature effects
- Comparing prediction paths across samples
- Identifying where predictions diverge

**Example**:

```python
# Decision plot for multiple predictions
shap.plots.decision(
    shap_values.base_values[0],  # scalar expected value
    shap_values.values,
    X_test,
    feature_names=X_test.columns.tolist()
)

# Highlight specific instances
shap.plots.decision(
    shap_values.base_values[0],
    shap_values.values,
    X_test,
    highlight=[0, 5, 10]
)
```
## Plot Selection Guide

**For Individual Predictions**:
- **Waterfall**: Best for detailed, sequential explanation
- **Force**: Good for interactive exploration
- **Bar (local)**: Simple, clean single-prediction importance

**For Global Understanding**:
- **Beeswarm**: Information-dense summary with value distributions
- **Bar (global)**: Clean, simple importance ranking
- **Violin**: Distribution-focused alternative to beeswarm

**For Feature Relationships**:
- **Scatter**: Understand feature-prediction relationships and interactions
- **Heatmap**: Compare patterns across multiple instances

**For Multiple Samples**:
- **Heatmap**: Grid view of SHAP values
- **Force (stacked)**: Interactive multi-sample visualization
- **Decision**: Prediction paths for multiclass problems

**For Cohort Comparison**:
- **Bar (cohort)**: Clean comparison of feature importance
- **Multiple beeswarms**: Side-by-side distribution comparisons

## Visualization Best Practices

**1. Start Global, Then Go Local**:
- Begin with beeswarm or bar plot to understand global patterns
- Dive into waterfall or scatter plots for specific instances or features

**2. Use Multiple Plot Types**:
- Different plots reveal different insights
- Combine waterfall (individual) + scatter (relationship) + beeswarm (global)

**3. Adjust max_display**:
- Default (10) is good for presentations
- Increase (20-30) for detailed analysis
- Consider clustering for redundant features

**4. Color Meaningfully**:
- Use default red-blue for SHAP values (red = positive, blue = negative)
- Color scatter plots by interacting features
- Custom colormaps for specific domains

**5. Consider Audience**:
- Technical audience: Beeswarm, scatter, heatmap
- Non-technical audience: Waterfall, bar, force plots
- Interactive presentations: Force plots with JavaScript

**6. Save High-Quality Figures**:

```python
import matplotlib.pyplot as plt

# Create plot
shap.plots.beeswarm(shap_values, show=False)

# Save with high DPI
plt.savefig('shap_plot.png', dpi=300, bbox_inches='tight')
plt.close()
```
**7. Handle Large Datasets**:
- Sample a subset for visualization (e.g., `shap_values[:1000]`)
- Use violin instead of beeswarm for very large datasets
- Adjust alpha for scatter plots with many points
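
One way to subsample before plotting, assuming `shap_values` covers the full test set and your SHAP version supports integer-array indexing on Explanation objects (recent versions do); the sample size of 1,000 is arbitrary:

```python
import numpy as np

# Randomly select 1,000 rows for plotting to keep the beeswarm responsive
rng = np.random.default_rng(0)
idx = rng.choice(shap_values.shape[0], size=1000, replace=False)
shap.plots.beeswarm(shap_values[idx])
```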
## Common Patterns and Workflows

**Pattern 1: Complete Model Explanation**

```python
import numpy as np

# 1. Global importance
shap.plots.beeswarm(shap_values)

# 2. Top feature relationships (e.g., the three features with highest mean |SHAP|)
top_features = np.argsort(-shap_values.abs.mean(0).values)[:3]
for feature in top_features:
    shap.plots.scatter(shap_values[:, int(feature)])

# 3. Example predictions (indices chosen by the analyst)
interesting_indices = [0, 1, 2]
for i in interesting_indices:
    shap.plots.waterfall(shap_values[i])
```
**Pattern 2: Model Comparison**

```python
# Compute SHAP for multiple models
shap_model1 = explainer1(X_test)
shap_model2 = explainer2(X_test)

# Compare feature importance
shap.plots.bar({
    "Model 1": shap_model1,
    "Model 2": shap_model2
})
```
**Pattern 3: Subgroup Analysis**

```python
# Define cohorts
male_mask = X_test['Sex'] == 'Male'
female_mask = X_test['Sex'] == 'Female'

# Compare cohorts
shap.plots.bar({
    "Male": shap_values[male_mask],
    "Female": shap_values[female_mask]
})

# Separate beeswarm plots
shap.plots.beeswarm(shap_values[male_mask])
shap.plots.beeswarm(shap_values[female_mask])
```
**Pattern 4: Debugging Predictions**

```python
import numpy as np

# Identify misclassified samples
errors = (model.predict(X_test) != y_test)
error_indices = np.where(errors)[0]

# Explain errors
for idx in error_indices[:5]:
    print(f"Sample {idx}:")
    shap.plots.waterfall(shap_values[idx])

# Explore key features
shap.plots.scatter(shap_values[:, "Key_Feature"])
```
## Integration with Notebooks and Reports

**Jupyter Notebooks**:
- Interactive force plots work seamlessly once the JavaScript library is loaded (see the sketch below)
- Use `show=True` (default) for inline display
- Combine with markdown explanations
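
The interactive (JavaScript) plots need the SHAP JS library loaded into the notebook first, typically once per session; a minimal sketch:

```python
import shap

# Load the SHAP JavaScript library so interactive force plots render inline
shap.initjs()

# Interactive force plot for the first prediction
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0]
)
```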
**Static Reports**:
- Use the matplotlib backend for force plots
- Save figures programmatically
- Prefer waterfall and bar plots for clarity

**Web Applications**:
- Export force plots as HTML
- Use `shap.save_html()` for interactive visualizations (see the sketch below)
- Consider generating plots on-demand
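
A minimal export sketch using `shap.save_html` (the file name is illustrative):

```python
# Build an interactive force plot without displaying it
plot = shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    show=False
)

# Write the plot and its JavaScript dependencies into a standalone HTML file
shap.save_html("force_plot.html", plot)
```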
## Troubleshooting Visualizations

**Issue**: Plots don't display
- **Solution**: Ensure matplotlib backend is set correctly; use `plt.show()` if needed

**Issue**: Too many features cluttering plot
- **Solution**: Reduce `max_display` parameter or use feature clustering

**Issue**: Colors reversed or confusing
- **Solution**: Check model output type (probability vs. log-odds) and use appropriate link function
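
If you would rather explain probabilities than log-odds for a tree model, one option is to rebuild the explainer with a probability output; a sketch assuming a tree-based classifier and a background dataset `X_background` (a hypothetical name here):

```python
# Probability-space explanations require interventional perturbation with background data
explainer_prob = shap.TreeExplainer(
    model,
    data=X_background,                       # background sample used for expectations
    feature_perturbation="interventional",
    model_output="probability",
)
shap_values_prob = explainer_prob(X_test)

# Waterfall now shows contributions on the probability scale
shap.plots.waterfall(shap_values_prob[0])
```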
**Issue**: Slow plotting with large datasets
- **Solution**: Sample a subset of data; use `shap_values[:1000]` for visualization

**Issue**: Feature names missing
- **Solution**: Ensure feature names are stored in the Explanation object or pass them explicitly to plot functions (see the sketch below)
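
If the SHAP values were computed as raw arrays, names can be attached by wrapping everything in an Explanation object; a minimal sketch assuming `values` (an n_samples × n_features array), `base_value`, and a DataFrame `X_test` already exist:

```python
import shap

# Wrap raw arrays in an Explanation so plotting functions can label the features
explanation = shap.Explanation(
    values=values,                        # SHAP values, shape (n_samples, n_features)
    base_values=base_value,               # scalar or per-sample baseline
    data=X_test.values,                   # original feature values
    feature_names=list(X_test.columns),   # names shown on plot axes
)

shap.plots.beeswarm(explanation)
```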