# SHAP Visualization Reference
This document provides comprehensive information about all SHAP plotting functions, their parameters, use cases, and best practices for visualizing model explanations.
## Overview
SHAP provides diverse visualization tools for explaining model predictions at both individual and global levels. Each plot type serves specific purposes in understanding feature importance, interactions, and prediction mechanisms.
## Plot Types
### Waterfall Plots
**Purpose**: Display explanations for individual predictions, showing how each feature moves the prediction from the baseline (expected value) toward the final prediction.
**Function**: `shap.plots.waterfall(explanation, max_display=10, show=True)`
**Key Parameters**:
- `explanation`: Single row from an Explanation object (not multiple samples)
- `max_display`: Number of features to show (default: 10); less impactful features collapse into a single "other features" term
- `show`: Whether to display the plot immediately
**Visual Elements**:
- **X-axis**: Shows SHAP values (contribution to prediction)
- **Starting point**: Model's expected value (baseline)
- **Feature contributions**: Red bars (positive) or blue bars (negative) showing how each feature moves the prediction
- **Feature values**: Displayed in gray to the left of feature names
- **Ending point**: Final model prediction
**When to Use**:
- Explaining individual predictions in detail
- Understanding which features drove a specific decision
- Communicating model behavior for single instances (e.g., loan denial, diagnosis)
- Debugging unexpected predictions
**Important Notes**:
- For XGBoost classifiers, predictions are explained in log-odds units (the margin output before the logistic transformation)
- SHAP values sum to the difference between the baseline and the final prediction (additivity property); see the sketch after the example below
- Use scatter plots alongside waterfall plots to explore patterns across multiple samples
**Example**:
```python
import shap
# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
# Plot waterfall for first prediction
shap.plots.waterfall(shap_values[0])
# Show more features
shap.plots.waterfall(shap_values[0], max_display=20)
```
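To sanity-check the additivity property and the log-odds units noted above, here is a minimal sketch. It assumes the `explainer`, `shap_values`, and `X_test` objects from the example; the logistic conversion applies only when the model is explained on the margin (log-odds) scale:
```python
import numpy as np

# Additivity: base value plus the sum of SHAP values reconstructs the raw model output
row = shap_values[0]
reconstructed = row.base_values + row.values.sum()
print("Reconstructed prediction (margin/log-odds):", reconstructed)

# For an XGBoost classifier explained in log-odds, apply the logistic function
# to recover the predicted probability
probability = 1 / (1 + np.exp(-reconstructed))
print("Implied probability:", probability)
```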
### Beeswarm Plots
**Purpose**: Information-dense summary of how top features impact model output across the entire dataset, combining feature importance with value distributions.
**Function**: `shap.plots.beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), color=None, show=True)`
**Key Parameters**:
- `shap_values`: Explanation object containing multiple samples
- `max_display`: Number of features to display (default: 10)
- `order`: How to rank features
- `Explanation.abs.mean(0)`: Mean absolute SHAP values (default)
- `Explanation.abs.max(0)`: Maximum absolute values (highlights outlier impacts)
- `color`: matplotlib colormap; defaults to red-blue scheme
- `show`: Whether to display the plot immediately
**Visual Elements**:
- **Y-axis**: Features ranked by importance
- **X-axis**: SHAP value (impact on model output)
- **Each dot**: Single instance from dataset
- **Dot position (X)**: SHAP value magnitude
- **Dot color**: Original feature value (red = high, blue = low)
- **Dot clustering**: Shows density/distribution of impacts
**When to Use**:
- Summarizing feature importance across entire datasets
- Understanding both average and individual feature impacts
- Identifying feature value patterns and their effects
- Comparing global model behavior across features
- Detecting nonlinear relationships (e.g., higher age → lower income likelihood)
**Practical Variations**:
```python
import matplotlib.pyplot as plt
# Standard beeswarm plot
shap.plots.beeswarm(shap_values)
# Show more features
shap.plots.beeswarm(shap_values, max_display=20)
# Order by maximum absolute values (highlight outliers)
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))
# Plot absolute SHAP values with fixed coloring
shap.plots.beeswarm(shap_values.abs, color="shap_red")
# Custom matplotlib colormap
shap.plots.beeswarm(shap_values, color=plt.cm.viridis)
```
### Bar Plots
**Purpose**: Display feature importance as mean absolute SHAP values, providing clean, simple visualizations of global feature impact.
**Function**: `shap.plots.bar(shap_values, max_display=10, clustering=None, clustering_cutoff=0.5, show=True)`
**Key Parameters**:
- `shap_values`: Explanation object (can be single instance, global, or cohorts)
- `max_display`: Maximum number of features/bars to show
- `clustering`: Optional hierarchical clustering object from `shap.utils.hclust`
- `clustering_cutoff`: Threshold for displaying clustering structure (0-1, default: 0.5)
**Plot Types**:
#### Global Bar Plot
Shows overall feature importance across all samples. Importance calculated as mean absolute SHAP value.
```python
# Global feature importance
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.bar(shap_values)
```
#### Local Bar Plot
Displays SHAP values for a single instance with feature values shown in gray.
```python
# Single prediction explanation
shap.plots.bar(shap_values[0])
```
#### Cohort Bar Plot
Compares feature importance across subgroups by passing a dictionary of Explanation objects.
```python
# Compare cohorts
cohorts = {
    "Group A": shap_values[mask_A],
    "Group B": shap_values[mask_B]
}
shap.plots.bar(cohorts)
```
**Feature Clustering**:
Identifies redundant features using model-based clustering (more accurate than correlation-based methods).
```python
# Add feature clustering
clustering = shap.utils.hclust(X_train, y_train)
shap.plots.bar(shap_values, clustering=clustering)
# Adjust clustering display threshold
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.3)
```
**When to Use**:
- Quick overview of global feature importance
- Comparing feature importance across cohorts or models
- Identifying redundant or correlated features
- Clean, simple visualizations for presentations
### Force Plots
**Purpose**: Additive force visualization showing how features push prediction higher (red) or lower (blue) from baseline.
**Function**: `shap.plots.force(base_value, shap_values, features, feature_names=None, out_names=None, link="identity", matplotlib=False, show=True)`
**Key Parameters**:
- `base_value`: Expected value (baseline prediction)
- `shap_values`: SHAP values for sample(s)
- `features`: Feature values for sample(s)
- `feature_names`: Optional feature names
- `link`: Transform function ("identity" or "logit"); see the logit sketch after the example below
- `matplotlib`: Use matplotlib backend (default: interactive JavaScript)
**Visual Elements**:
- **Baseline**: Starting prediction (expected value)
- **Red arrows**: Features pushing prediction higher
- **Blue arrows**: Features pushing prediction lower
- **Final value**: Resulting prediction
**Interactive Features** (JavaScript mode):
- Hover for detailed feature information
- Multiple samples are rotated 90° and stacked horizontally into a single plot
- Dropdown selectors let you reorder samples and choose which values are displayed
**When to Use**:
- Interactive exploration of predictions
- Visualizing multiple predictions simultaneously
- Presentations requiring interactive elements
- Understanding prediction composition at a glance
**Example**:
```python
# Single prediction force plot
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    matplotlib=True
)
# Multiple predictions (interactive)
shap.plots.force(
    shap_values.base_values,
    shap_values.values,
    X_test
)
```
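For classifiers explained in log-odds, the `link="logit"` option displays the force plot on the probability scale. A minimal sketch, assuming the same `shap_values` and `X_test` as above and rendering with the interactive JavaScript backend:
```python
# Force plot in probability space via the logistic link
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    link="logit"
)
```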
### Scatter Plots (Dependence Plots)
**Purpose**: Show relationship between feature values and their SHAP values, revealing how feature values impact predictions.
**Function**: `shap.plots.scatter(shap_values, color=None, hist=True, alpha=1, show=True)`
**Key Parameters**:
- `shap_values`: Explanation object, can specify feature with subscript (e.g., `shap_values[:, "Age"]`)
- `color`: Feature to use for coloring points (string name or Explanation object)
- `hist`: Show histogram of feature values on y-axis
- `alpha`: Point transparency (useful for dense plots)
**Visual Elements**:
- **X-axis**: Feature value
- **Y-axis**: SHAP value (impact on prediction)
- **Point color**: Another feature's value (for interaction detection)
- **Histogram**: Distribution of feature values
**When to Use**:
- Understanding feature-prediction relationships
- Detecting nonlinear effects
- Identifying feature interactions
- Validating or discovering patterns in model behavior
- Exploring counterintuitive predictions from waterfall plots
**Interaction Detection**:
Color points by another feature to reveal interactions.
```python
# Basic dependence plot
shap.plots.scatter(shap_values[:, "Age"])
# Color by another feature to show interactions
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education"])
# Multiple features in one plot
shap.plots.scatter(shap_values[:, ["Age", "Education", "Hours-per-week"]])
# Increase transparency for dense data
shap.plots.scatter(shap_values[:, "Age"], alpha=0.5)
```
### Heatmap Plots
**Purpose**: Visualize SHAP values for multiple samples simultaneously, showing feature impacts across instances.
**Function**: `shap.plots.heatmap(shap_values, instance_order=None, feature_values=None, max_display=10, show=True)`
**Key Parameters**:
- `shap_values`: Explanation object
- `instance_order`: How to order instances along the x-axis (default: hierarchical clustering of explanations; pass e.g. `shap_values.sum(1)` to order by model output)
- `feature_values`: Per-feature summary values used to order features (default: mean absolute SHAP value)
- `max_display`: Maximum features to display
**Visual Elements**:
- **Rows**: Individual instances/samples
- **Columns**: Features
- **Cell color**: SHAP value (red = positive, blue = negative)
- **Intensity**: Magnitude of impact
**When to Use**:
- Comparing explanations across multiple instances
- Identifying patterns in feature impacts
- Understanding which features vary most across predictions
- Detecting subgroups or clusters with similar explanation patterns
**Example**:
```python
# Basic heatmap
shap.plots.heatmap(shap_values)
# Order instances by model output
shap.plots.heatmap(shap_values, instance_order=shap_values.sum(1))
# Show specific subset
shap.plots.heatmap(shap_values[:100])
```
### Violin Plots
**Purpose**: Similar to beeswarm plots but uses violin (kernel density) visualization instead of individual dots.
**Function**: `shap.plots.violin(shap_values, features=None, feature_names=None, max_display=10, show=True)`
**When to Use**:
- Alternative to beeswarm when dataset is very large
- Emphasizing distribution density over individual points
- Cleaner visualization for presentations
**Example**:
```python
shap.plots.violin(shap_values)
```
### Decision Plots
**Purpose**: Show prediction paths through cumulative SHAP values, particularly useful for multiclass classification.
**Function**: `shap.plots.decision(base_value, shap_values, features, feature_names=None, feature_order="importance", highlight=None, link="identity", show=True)`
**Key Parameters**:
- `base_value`: Expected value
- `shap_values`: SHAP values for samples
- `features`: Feature values
- `feature_order`: How to order features ("importance" or list)
- `highlight`: Indices of samples to highlight
- `link`: Transform function
**When to Use**:
- Multiclass classification explanations
- Understanding cumulative feature effects
- Comparing prediction paths across samples
- Identifying where predictions diverge
**Example**:
```python
# Decision plot for multiple predictions
shap.plots.decision(
    shap_values.base_values,
    shap_values.values,
    X_test,
    feature_names=X_test.columns.tolist()
)
# Highlight specific instances
shap.plots.decision(
    shap_values.base_values,
    shap_values.values,
    X_test,
    highlight=[0, 5, 10]
)
```
## Plot Selection Guide
**For Individual Predictions**:
- **Waterfall**: Best for detailed, sequential explanation
- **Force**: Good for interactive exploration
- **Bar (local)**: Simple, clean single-prediction importance
**For Global Understanding**:
- **Beeswarm**: Information-dense summary with value distributions
- **Bar (global)**: Clean, simple importance ranking
- **Violin**: Distribution-focused alternative to beeswarm
**For Feature Relationships**:
- **Scatter**: Understand feature-prediction relationships and interactions
- **Heatmap**: Compare patterns across multiple instances
**For Multiple Samples**:
- **Heatmap**: Grid view of SHAP values
- **Force (stacked)**: Interactive multi-sample visualization
- **Decision**: Prediction paths for multiclass problems
**For Cohort Comparison**:
- **Bar (cohort)**: Clean comparison of feature importance
- **Multiple beeswarms**: Side-by-side distribution comparisons
## Visualization Best Practices
**1. Start Global, Then Go Local**:
- Begin with beeswarm or bar plot to understand global patterns
- Dive into waterfall or scatter plots for specific instances or features
**2. Use Multiple Plot Types**:
- Different plots reveal different insights
- Combine waterfall (individual) + scatter (relationship) + beeswarm (global)
**3. Adjust max_display**:
- Default (10) is good for presentations
- Increase (20-30) for detailed analysis
- Consider clustering for redundant features
**4. Color Meaningfully**:
- Use default red-blue for SHAP values (red = positive, blue = negative)
- Color scatter plots by interacting features
- Custom colormaps for specific domains
**5. Consider Audience**:
- Technical audience: Beeswarm, scatter, heatmap
- Non-technical audience: Waterfall, bar, force plots
- Interactive presentations: Force plots with JavaScript
**6. Save High-Quality Figures**:
```python
import matplotlib.pyplot as plt
# Create plot
shap.plots.beeswarm(shap_values, show=False)
# Save with high DPI
plt.savefig('shap_plot.png', dpi=300, bbox_inches='tight')
plt.close()
```
**7. Handle Large Datasets**:
- Sample a subset for visualization (e.g., `shap_values[:1000]`); see the sketch after this list
- Use violin instead of beeswarm for very large datasets
- Adjust alpha for scatter plots with many points
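A random subsample usually preserves the global picture while rendering much faster. A minimal sketch, assuming `shap_values` is an Explanation object with more than 1,000 rows:
```python
import numpy as np
import shap

# Randomly subsample 1,000 rows of the Explanation for plotting
rng = np.random.default_rng(0)
idx = rng.choice(len(shap_values), size=1000, replace=False)
shap.plots.beeswarm(shap_values[idx])
```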
## Common Patterns and Workflows
**Pattern 1: Complete Model Explanation**
```python
# 1. Global importance
shap.plots.beeswarm(shap_values)
# 2. Top feature relationships (top_features: feature names chosen from the beeswarm plot)
for feature in top_features:
    shap.plots.scatter(shap_values[:, feature])
# 3. Example predictions (interesting_indices: row indices chosen by the analyst)
for i in interesting_indices:
    shap.plots.waterfall(shap_values[i])
```
**Pattern 2: Model Comparison**
```python
# Compute SHAP for multiple models
shap_model1 = explainer1(X_test)
shap_model2 = explainer2(X_test)
# Compare feature importance
shap.plots.bar({
    "Model 1": shap_model1,
    "Model 2": shap_model2
})
```
**Pattern 3: Subgroup Analysis**
```python
# Define cohorts
male_mask = X_test['Sex'] == 'Male'
female_mask = X_test['Sex'] == 'Female'
# Compare cohorts
shap.plots.bar({
    "Male": shap_values[male_mask],
    "Female": shap_values[female_mask]
})
# Separate beeswarm plots
shap.plots.beeswarm(shap_values[male_mask])
shap.plots.beeswarm(shap_values[female_mask])
```
**Pattern 4: Debugging Predictions**
```python
import numpy as np
# Identify misclassified samples
errors = (model.predict(X_test) != y_test)
error_indices = np.where(errors)[0]
# Explain the first few errors
for idx in error_indices[:5]:
    print(f"Sample {idx}:")
    shap.plots.waterfall(shap_values[idx])
# Explore key features
shap.plots.scatter(shap_values[:, "Key_Feature"])
```
## Integration with Notebooks and Reports
**Jupyter Notebooks**:
- Interactive force plots work seamlessly (see the sketch after this list)
- Use `show=True` (default) for inline display
- Combine with markdown explanations
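A minimal notebook sketch, assuming the `shap_values` and `X_test` objects from earlier; `shap.initjs()` loads the JavaScript library that interactive force plots require:
```python
import shap

shap.initjs()  # load the JS visualization library into the notebook
# Interactive force plot for the first prediction
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0]
)
```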
**Static Reports**:
- Use matplotlib backend for force plots
- Save figures programmatically
- Prefer waterfall and bar plots for clarity
**Web Applications**:
- Export force plots as HTML
- Use `shap.save_html()` for interactive visualizations (see the sketch after this list)
- Consider generating plots on-demand
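A minimal export sketch, assuming the `shap_values` and `X_test` objects from earlier; the output filename is illustrative:
```python
import shap

# Build the interactive force plot without rendering it inline
plot = shap.plots.force(
    shap_values.base_values,
    shap_values.values,
    X_test
)
# Write a standalone HTML file that embeds the required JavaScript
shap.save_html("force_plot.html", plot)
```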
## Troubleshooting Visualizations
**Issue**: Plots don't display
- **Solution**: Ensure matplotlib backend is set correctly; use `plt.show()` if needed
**Issue**: Too many features cluttering plot
- **Solution**: Reduce `max_display` parameter or use feature clustering
**Issue**: Colors reversed or confusing
- **Solution**: Check model output type (probability vs. log-odds) and use appropriate link function
**Issue**: Slow plotting with large datasets
- **Solution**: Sample subset of data; use `shap_values[:1000]` for visualization
**Issue**: Feature names missing
- **Solution**: Ensure feature names are stored in the Explanation object or pass them explicitly to the plot functions (see the sketch below)
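One way to attach names explicitly, as a minimal sketch; it assumes `shap_values` is an Explanation object, `X_test` is a pandas DataFrame, and the Explanation's `feature_names` attribute can be assigned directly:
```python
# Attach column names so plots label features correctly (assumes a writable attribute)
shap_values.feature_names = list(X_test.columns)
shap.plots.beeswarm(shap_values)
```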