
SHAP Visualization Reference

This document provides comprehensive information about all SHAP plotting functions, their parameters, use cases, and best practices for visualizing model explanations.

Overview

SHAP provides diverse visualization tools for explaining model predictions at both individual and global levels. Each plot type serves specific purposes in understanding feature importance, interactions, and prediction mechanisms.

Plot Types

Waterfall Plots

Purpose: Display explanations for individual predictions, showing how each feature moves the prediction from the baseline (expected value) toward the final prediction.

Function: shap.plots.waterfall(explanation, max_display=10, show=True)

Key Parameters:

  • explanation: Single row from an Explanation object (not multiple samples)
  • max_display: Number of features to show (default: 10); less impactful features collapse into a single "other features" term
  • show: Whether to display the plot immediately

Visual Elements:

  • X-axis: Shows SHAP values (contribution to prediction)
  • Starting point: Model's expected value (baseline)
  • Feature contributions: Red bars (positive) or blue bars (negative) showing how each feature moves the prediction
  • Feature values: Displayed in gray to the left of feature names
  • Ending point: Final model prediction

When to Use:

  • Explaining individual predictions in detail
  • Understanding which features drove a specific decision
  • Communicating model behavior for single instances (e.g., loan denial, diagnosis)
  • Debugging unexpected predictions

Important Notes:

  • For XGBoost classifiers, predictions are explained in log-odds units (margin output before logistic transformation)
  • SHAP values sum to the difference between baseline and final prediction (additivity property)
  • Use scatter plots alongside waterfall plots to explore patterns across multiple samples

Example:

import shap

# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Plot waterfall for first prediction
shap.plots.waterfall(shap_values[0])

# Show more features
shap.plots.waterfall(shap_values[0], max_display=20)
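
The additivity property noted above can be checked directly. A minimal sketch, assuming the shap_values computed in the example and a single-output model:

# Baseline plus the sum of a sample's SHAP values should (approximately)
# reproduce the model's raw output for that sample
reconstructed = shap_values[0].base_values + shap_values[0].values.sum()
print(reconstructed)  # compare against the model's margin/raw prediction for the first test row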

Beeswarm Plots

Purpose: Information-dense summary of how top features impact model output across the entire dataset, combining feature importance with value distributions.

Function: shap.plots.beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), color=None, show=True)

Key Parameters:

  • shap_values: Explanation object containing multiple samples
  • max_display: Number of features to display (default: 10)
  • order: How to rank features
    • Explanation.abs.mean(0): Mean absolute SHAP values (default)
    • Explanation.abs.max(0): Maximum absolute values (highlights outlier impacts)
  • color: matplotlib colormap; defaults to red-blue scheme
  • show: Whether to display the plot immediately

Visual Elements:

  • Y-axis: Features ranked by importance
  • X-axis: SHAP value (impact on model output)
  • Each dot: Single instance from dataset
  • Dot position (X): SHAP value magnitude
  • Dot color: Original feature value (red = high, blue = low)
  • Dot clustering: Shows density/distribution of impacts

When to Use:

  • Summarizing feature importance across entire datasets
  • Understanding both average and individual feature impacts
  • Identifying feature value patterns and their effects
  • Comparing global model behavior across features
  • Detecting nonlinear relationships (e.g., higher age → lower income likelihood)

Practical Variations:

# Standard beeswarm plot
shap.plots.beeswarm(shap_values)

# Show more features
shap.plots.beeswarm(shap_values, max_display=20)

# Order by maximum absolute values (highlight outliers)
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))

# Plot absolute SHAP values with fixed coloring
shap.plots.beeswarm(shap_values.abs, color="shap_red")

# Custom matplotlib colormap
import matplotlib.pyplot as plt
shap.plots.beeswarm(shap_values, color=plt.cm.viridis)

Bar Plots

Purpose: Display feature importance as mean absolute SHAP values, providing clean, simple visualizations of global feature impact.

Function: shap.plots.bar(shap_values, max_display=10, clustering=None, clustering_cutoff=0.5, show=True)

Key Parameters:

  • shap_values: Explanation object (can be single instance, global, or cohorts)
  • max_display: Maximum number of features/bars to show
  • clustering: Optional hierarchical clustering object from shap.utils.hclust
  • clustering_cutoff: Threshold for displaying clustering structure (0-1, default: 0.5)

Plot Types:

Global Bar Plot

Shows overall feature importance across all samples. Importance is calculated as the mean absolute SHAP value for each feature.

# Global feature importance
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.bar(shap_values)

Local Bar Plot

Displays SHAP values for a single instance with feature values shown in gray.

# Single prediction explanation
shap.plots.bar(shap_values[0])

Cohort Bar Plot

Compares feature importance across subgroups by passing a dictionary of Explanation objects.

# Compare cohorts
cohorts = {
    "Group A": shap_values[mask_A],
    "Group B": shap_values[mask_B]
}
shap.plots.bar(cohorts)

Feature Clustering: Identifies redundant features using model-based clustering (more accurate than correlation-based methods).

# Add feature clustering
clustering = shap.utils.hclust(X_train, y_train)
shap.plots.bar(shap_values, clustering=clustering)

# Adjust clustering display threshold
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.3)

When to Use:

  • Quick overview of global feature importance
  • Comparing feature importance across cohorts or models
  • Identifying redundant or correlated features
  • Clean, simple visualizations for presentations

Force Plots

Purpose: Additive force visualization showing how features push prediction higher (red) or lower (blue) from baseline.

Function: shap.plots.force(base_value, shap_values, features, feature_names=None, out_names=None, link="identity", matplotlib=False, show=True)

Key Parameters:

  • base_value: Expected value (baseline prediction)
  • shap_values: SHAP values for sample(s)
  • features: Feature values for sample(s)
  • feature_names: Optional feature names
  • link: Transform function ("identity" or "logit")
  • matplotlib: Use matplotlib backend (default: interactive JavaScript)

Visual Elements:

  • Baseline: Starting prediction (expected value)
  • Red arrows: Features pushing prediction higher
  • Blue arrows: Features pushing prediction lower
  • Final value: Resulting prediction

Interactive Features (JavaScript mode):

  • Hover for detailed feature information
  • Multiple samples create a stacked visualization
  • Dropdown menus let you reorder samples and change which values are plotted

When to Use:

  • Interactive exploration of predictions
  • Visualizing multiple predictions simultaneously
  • Presentations requiring interactive elements
  • Understanding prediction composition at a glance

Example:

# Single prediction force plot
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    matplotlib=True
)

# Multiple predictions (interactive, stacked visualization)
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values,
    X_test
)

Scatter Plots (Dependence Plots)

Purpose: Show relationship between feature values and their SHAP values, revealing how feature values impact predictions.

Function: shap.plots.scatter(shap_values, color=None, hist=True, alpha=1, show=True)

Key Parameters:

  • shap_values: Explanation object; select a single feature by indexing (e.g., shap_values[:, "Age"])
  • color: Feature to use for coloring points (string name or Explanation object)
  • hist: Show a light gray histogram of the feature's value distribution along the x-axis
  • alpha: Point transparency (useful for dense plots)

Visual Elements:

  • X-axis: Feature value
  • Y-axis: SHAP value (impact on prediction)
  • Point color: Another feature's value (for interaction detection)
  • Histogram: Distribution of feature values

When to Use:

  • Understanding feature-prediction relationships
  • Detecting nonlinear effects
  • Identifying feature interactions
  • Validating or discovering patterns in model behavior
  • Exploring counterintuitive predictions from waterfall plots

Interaction Detection: Color points by another feature to reveal interactions.

# Basic dependence plot
shap.plots.scatter(shap_values[:, "Age"])

# Color by another feature to show interactions
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education"])

# Multiple features in one plot
shap.plots.scatter(shap_values[:, ["Age", "Education", "Hours-per-week"]])

# Increase transparency for dense data
shap.plots.scatter(shap_values[:, "Age"], alpha=0.5)

Heatmap Plots

Purpose: Visualize SHAP values for multiple samples simultaneously, showing feature impacts across instances.

Function: shap.plots.heatmap(shap_values, instance_order=None, feature_values=None, max_display=10, show=True)

Key Parameters:

  • shap_values: Explanation object
  • instance_order: How to order instances (default: hierarchical clustering by explanation similarity; pass an Explanation operation such as shap_values.sum(1) for a custom ordering)
  • feature_values: Global summary value per feature, shown in the bar chart beside the heatmap (default: mean absolute SHAP values)
  • max_display: Maximum features to display

Visual Elements:

  • Rows: Individual instances/samples
  • Columns: Features
  • Cell color: SHAP value (red = positive, blue = negative)
  • Intensity: Magnitude of impact

When to Use:

  • Comparing explanations across multiple instances
  • Identifying patterns in feature impacts
  • Understanding which features vary most across predictions
  • Detecting subgroups or clusters with similar explanation patterns

Example:

# Basic heatmap
shap.plots.heatmap(shap_values)

# Order instances by model output
shap.plots.heatmap(shap_values, instance_order=shap_values.sum(1))

# Show specific subset
shap.plots.heatmap(shap_values[:100])

Violin Plots

Purpose: Similar to beeswarm plots but uses violin (kernel density) visualization instead of individual dots.

Function: shap.plots.violin(shap_values, features=None, feature_names=None, max_display=10, show=True)

When to Use:

  • Alternative to beeswarm when dataset is very large
  • Emphasizing distribution density over individual points
  • Cleaner visualization for presentations

Example:

shap.plots.violin(shap_values)

Decision Plots

Purpose: Show prediction paths through cumulative SHAP values, particularly useful for multiclass classification.

Function: shap.plots.decision(base_value, shap_values, features, feature_names=None, feature_order="importance", highlight=None, link="identity", show=True)

Key Parameters:

  • base_value: Expected value
  • shap_values: SHAP values for samples
  • features: Feature values
  • feature_order: How to order features ("importance" or list)
  • highlight: Indices of samples to highlight
  • link: Transform function

When to Use:

  • Multiclass classification explanations
  • Understanding cumulative feature effects
  • Comparing prediction paths across samples
  • Identifying where predictions diverge

Example:

# Decision plot for multiple predictions
shap.plots.decision(
    shap_values.base_values[0],
    shap_values.values,
    X_test,
    feature_names=X_test.columns.tolist()
)

# Highlight specific instances
shap.plots.decision(
    shap_values.base_values[0],
    shap_values.values,
    X_test,
    highlight=[0, 5, 10]
)

Plot Selection Guide

For Individual Predictions:

  • Waterfall: Best for detailed, sequential explanation
  • Force: Good for interactive exploration
  • Bar (local): Simple, clean single-prediction importance

For Global Understanding:

  • Beeswarm: Information-dense summary with value distributions
  • Bar (global): Clean, simple importance ranking
  • Violin: Distribution-focused alternative to beeswarm

For Feature Relationships:

  • Scatter: Understand feature-prediction relationships and interactions
  • Heatmap: Compare patterns across multiple instances

For Multiple Samples:

  • Heatmap: Grid view of SHAP values
  • Force (stacked): Interactive multi-sample visualization
  • Decision: Prediction paths for multiclass problems

For Cohort Comparison:

  • Bar (cohort): Clean comparison of feature importance
  • Multiple beeswarms: Side-by-side distribution comparisons

Visualization Best Practices

1. Start Global, Then Go Local:

  • Begin with beeswarm or bar plot to understand global patterns
  • Dive into waterfall or scatter plots for specific instances or features

2. Use Multiple Plot Types:

  • Different plots reveal different insights
  • Combine waterfall (individual) + scatter (relationship) + beeswarm (global)

3. Adjust max_display:

  • Default (10) is good for presentations
  • Increase (20-30) for detailed analysis
  • Consider clustering for redundant features

4. Color Meaningfully:

  • Use default red-blue for SHAP values (red = positive, blue = negative)
  • Color scatter plots by interacting features
  • Custom colormaps for specific domains

5. Consider Audience:

  • Technical audience: Beeswarm, scatter, heatmap
  • Non-technical audience: Waterfall, bar, force plots
  • Interactive presentations: Force plots with JavaScript

6. Save High-Quality Figures:

import matplotlib.pyplot as plt

# Create plot
shap.plots.beeswarm(shap_values, show=False)

# Save with high DPI
plt.savefig('shap_plot.png', dpi=300, bbox_inches='tight')
plt.close()

7. Handle Large Datasets:

  • Sample a subset for visualization (e.g., shap_values[:1000]; see the sketch after this list)
  • Use violin instead of beeswarm for very large datasets
  • Adjust alpha for scatter plots with many points
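
A minimal sketch of the subsampling approach above (the sample size is arbitrary; assumes the shap_values object from the earlier examples):

import numpy as np

# Plot a random subsample of 1,000 explanations instead of the full dataset
idx = np.random.choice(shap_values.shape[0], size=1000, replace=False)
shap.plots.beeswarm(shap_values[idx])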

Common Patterns and Workflows

Pattern 1: Complete Model Explanation

import numpy as np

# 1. Global importance
shap.plots.beeswarm(shap_values)

# 2. Relationships for the highest-impact features
# (assumes a single-output model, so shap_values.values is 2-D)
top_features = np.argsort(-np.abs(shap_values.values).mean(axis=0))[:3].tolist()
for feature in top_features:
    shap.plots.scatter(shap_values[:, feature])

# 3. Example predictions
interesting_indices = [0, 1, 2]  # replace with the rows you want to explain
for i in interesting_indices:
    shap.plots.waterfall(shap_values[i])

Pattern 2: Model Comparison

# Compute SHAP for multiple models
shap_model1 = explainer1(X_test)
shap_model2 = explainer2(X_test)

# Compare feature importance
shap.plots.bar({
    "Model 1": shap_model1,
    "Model 2": shap_model2
})

Pattern 3: Subgroup Analysis

# Define cohorts
male_mask = X_test['Sex'] == 'Male'
female_mask = X_test['Sex'] == 'Female'

# Compare cohorts
shap.plots.bar({
    "Male": shap_values[male_mask],
    "Female": shap_values[female_mask]
})

# Separate beeswarm plots
shap.plots.beeswarm(shap_values[male_mask])
shap.plots.beeswarm(shap_values[female_mask])

Pattern 4: Debugging Predictions

import numpy as np

# Identify misclassified samples
errors = (model.predict(X_test) != y_test)
error_indices = np.where(errors)[0]

# Explain the first few errors individually
for idx in error_indices[:5]:
    print(f"Sample {idx}:")
    shap.plots.waterfall(shap_values[idx])

# Then explore a feature implicated in the errors across the whole dataset
shap.plots.scatter(shap_values[:, "Key_Feature"])

Integration with Notebooks and Reports

Jupyter Notebooks:

  • Interactive force plots render inline once SHAP's JavaScript is loaded (see the sketch below)
  • Use show=True (default) for inline display
  • Combine with markdown explanations
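
In classic Jupyter notebooks, the interactive JavaScript plots generally need shap.initjs() called once per notebook. A minimal sketch, assuming the shap_values and X_test objects from the earlier examples:

import shap

# Load SHAP's JavaScript library so interactive force plots render inline
shap.initjs()

# Interactive, stacked force plot for the test set
shap.plots.force(shap_values.base_values[0], shap_values.values, X_test)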

Static Reports:

  • Use matplotlib backend for force plots
  • Save figures programmatically
  • Prefer waterfall and bar plots for clarity

Web Applications:

  • Export force plots as HTML
  • Use shap.save_html() for interactive visualizations
  • Consider generating plots on-demand
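
A minimal sketch of exporting an interactive force plot to a standalone HTML file, assuming the shap_values and X_test objects from the earlier examples:

# Build the interactive force plot and capture the returned object
force_plot = shap.plots.force(
    shap_values.base_values[0],
    shap_values.values,
    X_test
)

# Write it, along with the required JavaScript, to a standalone HTML file
shap.save_html("force_plot.html", force_plot)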

Troubleshooting Visualizations

Issue: Plots don't display

  • Solution: Ensure the matplotlib backend is configured correctly; use plt.show() if needed (see the sketch below)
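
For example, when a plot does not appear in a plain Python script, rendering it explicitly through matplotlib usually resolves it (a sketch; assumes shap and shap_values from the earlier examples):

import matplotlib.pyplot as plt

# Build the plot without displaying it, then render it explicitly
shap.plots.beeswarm(shap_values, show=False)
plt.show()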

Issue: Too many features cluttering plot

  • Solution: Reduce max_display parameter or use feature clustering

Issue: Colors reversed or confusing

  • Solution: Check the model output type (probability vs. log-odds) and use the appropriate link function (see the sketch below)
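
For example, when a classifier is explained in log-odds units, a logit link displays the force plot in probability space (a sketch; assumes log-odds SHAP values as in the earlier examples):

# Map log-odds explanations into probability space for display
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    link="logit"
)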

Issue: Slow plotting with large datasets

  • Solution: Sample subset of data; use shap_values[:1000] for visualization

Issue: Feature names missing

  • Solution: Ensure feature names are stored in the Explanation object (its feature_names attribute) or pass them explicitly to the plot function (see the sketch below)
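
For example, when SHAP values were computed from a plain NumPy array, names can be attached afterwards (a sketch; assumes X_test is a pandas DataFrame):

# Attach column names to an Explanation built from a NumPy array
shap_values.feature_names = list(X_test.columns)

# Or pass names explicitly to a plot that accepts them
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0].values,
    feature_names=list(X_test.columns),
    matplotlib=True
)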