# SHAP Workflows and Best Practices
This document collects workflows, best practices, and common use cases for applying SHAP across a range of model-interpretation scenarios.
## Basic Workflow Structure
Every SHAP analysis follows a general workflow:
1. **Train Model**: Build and train the machine learning model
2. **Select Explainer**: Choose appropriate explainer based on model type
3. **Compute SHAP Values**: Generate explanations for test samples
4. **Visualize Results**: Use plots to understand feature impacts
5. **Interpret and Act**: Draw conclusions and make decisions
## Workflow 1: Basic Model Explanation
**Use Case**: Understanding feature importance and prediction behavior for a trained model
```python
import shap
import pandas as pd
from sklearn.model_selection import train_test_split
# Step 1: Load and split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Step 2: Train model (example with XGBoost)
import xgboost as xgb
model = xgb.XGBClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
# Step 3: Create explainer
explainer = shap.TreeExplainer(model)
# Step 4: Compute SHAP values
shap_values = explainer(X_test)
# Step 5: Visualize global importance
shap.plots.beeswarm(shap_values, max_display=15)
# Step 6: Examine top features in detail
shap.plots.scatter(shap_values[:, "Feature1"])
shap.plots.scatter(shap_values[:, "Feature2"], color=shap_values[:, "Feature1"])
# Step 7: Explain individual predictions
shap.plots.waterfall(shap_values[0])
```
**Key Decisions**:
- Explainer type based on model architecture
- Background dataset size (for DeepExplainer, KernelExplainer)
- Number of samples to explain (all test set vs. subset)
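
These decisions often come down to a few lines of setup. Below is a minimal sketch of picking a background dataset and limiting the number of explained rows (reusing `X_train` and `X_test` from the workflow above; the sizes are illustrative, not prescriptive):
```python
import shap

# Background data for DeepExplainer / KernelExplainer:
# a small random subsample of the training data...
background = shap.sample(X_train, 100, random_state=0)

# ...or k weighted centroids, which keeps KernelExplainer runtimes manageable
background_kmeans = shap.kmeans(X_train, 50)

# Explain a manageable subset of the test set when it is large
X_explain = X_test.sample(500, random_state=0) if len(X_test) > 500 else X_test
```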
## Workflow 2: Model Debugging and Validation
**Use Case**: Identifying and fixing model issues, validating expected behavior
```python
import numpy as np

# Step 1: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
# Step 2: Identify prediction errors
predictions = model.predict(X_test)
errors = predictions != y_test
error_indices = np.where(errors)[0]
# Step 3: Analyze errors
print(f"Total errors: {len(error_indices)}")
print(f"Error rate: {len(error_indices) / len(y_test):.2%}")
# Step 4: Explain misclassified samples
for idx in error_indices[:10]:  # First 10 errors
    print(f"\n=== Error {idx} ===")
    print(f"Prediction: {predictions[idx]}, Actual: {y_test.iloc[idx]}")
    shap.plots.waterfall(shap_values[idx])
# Step 5: Check if model learned correct patterns
# Look for unexpected feature importance
shap.plots.beeswarm(shap_values)
# Step 6: Investigate specific feature relationships
# Verify nonlinear relationships make sense
for feature in model.feature_importances_.argsort()[-5:]:  # Top 5 features
    feature_name = X_test.columns[feature]
    shap.plots.scatter(shap_values[:, feature_name])
# Step 7: Validate feature interactions
# Check if interactions align with domain knowledge
shap.plots.scatter(shap_values[:, "Feature1"], color=shap_values[:, "Feature2"])
```
**Common Issues to Check**:
- Data leakage (feature with suspiciously high importance)
- Spurious correlations (unexpected feature relationships)
- Target leakage (features that shouldn't be predictive)
- Biases (disproportionate impact on certain groups)
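
A quick numeric screen for several of these issues is to rank features by mean |SHAP| and flag anything that dominates the attribution or that domain experts say should not be predictive. A minimal sketch (reusing `shap_values` and `X_test` from above; it assumes a single-output explanation, and the `suspect_features` names are hypothetical):
```python
import numpy as np
import pandas as pd

# Rank features by mean absolute SHAP value
mean_abs = pd.Series(np.abs(shap_values.values).mean(axis=0), index=X_test.columns)
ranking = mean_abs.sort_values(ascending=False)
print(ranking.head(10))

# A single feature carrying most of the attribution is a classic leakage signal
share = ranking / ranking.sum()
print("Possible leakage candidates:", share[share > 0.5].index.tolist())

# Features that should not be predictive according to domain knowledge
suspect_features = ['Record_ID', 'Created_At']  # hypothetical examples
for f in suspect_features:
    if f in ranking.index and ranking[f] > ranking.median():
        print(f"Warning: '{f}' is unexpectedly important ({ranking[f]:.4f})")
```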
## Workflow 3: Feature Engineering Guidance
**Use Case**: Using SHAP insights to improve feature engineering
```python
# Step 1: Initial model with baseline features
model_v1 = train_model(X_train_v1, y_train)
explainer_v1 = shap.TreeExplainer(model_v1)
shap_values_v1 = explainer_v1(X_test_v1)
# Step 2: Identify feature engineering opportunities
shap.plots.beeswarm(shap_values_v1)
# Check for:
# - Nonlinear relationships (candidates for transformation)
shap.plots.scatter(shap_values_v1[:, "Age"]) # Maybe age^2 or age bins?
# - Feature interactions (candidates for interaction terms)
shap.plots.scatter(shap_values_v1[:, "Income"], color=shap_values_v1[:, "Education"])
# Maybe create Income * Education interaction?
# Step 3: Engineer new features based on insights (apply to both train and test)
X_train_v2 = X_train_v1.copy()
X_train_v2['Age_squared'] = X_train_v2['Age'] ** 2
X_train_v2['Income_Education'] = X_train_v2['Income'] * X_train_v2['Education']

X_test_v2 = X_test_v1.copy()
X_test_v2['Age_squared'] = X_test_v2['Age'] ** 2
X_test_v2['Income_Education'] = X_test_v2['Income'] * X_test_v2['Education']
# Step 4: Retrain with engineered features
model_v2 = train_model(X_train_v2, y_train)
explainer_v2 = shap.TreeExplainer(model_v2)
shap_values_v2 = explainer_v2(X_test_v2)
# Step 5: Compare feature importance
shap.plots.bar({
    "Baseline": shap_values_v1,
    "With Engineered Features": shap_values_v2
})
# Step 6: Validate improvement
print(f"V1 Score: {model_v1.score(X_test_v1, y_test):.4f}")
print(f"V2 Score: {model_v2.score(X_test_v2, y_test):.4f}")
```
**Feature Engineering Insights from SHAP**:
- Strong nonlinear patterns → Try transformations (log, sqrt, polynomial)
- Color-coded interactions in scatter → Create interaction terms
- Redundant features in clustering → Remove or combine
- Unexpected importance → Investigate for data quality issues
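
Beyond the polynomial and interaction terms already shown above, the other insights in this list translate into similarly small transformations. A sketch of two of them (feature names are the hypothetical ones used in this workflow; always re-validate performance after each change):
```python
import numpy as np
import pandas as pd

X_train_v3 = X_train_v2.copy()

# Skewed feature with a strongly nonlinear SHAP dependence -> log transform
X_train_v3['Income_log'] = np.log1p(X_train_v3['Income'])

# Step-like SHAP dependence on Age -> coarse bins instead of the raw value
X_train_v3['Age_bin'] = pd.cut(X_train_v3['Age'], bins=[0, 25, 40, 60, 120], labels=False)

# Redundant features identified by clustering (see Technique 4 below) -> drop one
# X_train_v3 = X_train_v3.drop(columns=['Redundant_Feature'])  # hypothetical name
```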
## Workflow 4: Model Comparison and Selection
**Use Case**: Comparing multiple models to select the best interpretable model
```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
import xgboost as xgb
import matplotlib.pyplot as plt
# Step 1: Train multiple models
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000).fit(X_train, y_train),
    'Random Forest': RandomForestClassifier(n_estimators=100).fit(X_train, y_train),
    'XGBoost': xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)
}
# Step 2: Compute SHAP values for each model
shap_values_dict = {}
for name, model in models.items():
    if name == 'Logistic Regression':
        explainer = shap.LinearExplainer(model, X_train)
    else:
        explainer = shap.TreeExplainer(model)
    # Note: explanation units differ across models (e.g., log-odds vs. probability),
    # so compare feature rankings rather than raw magnitudes
    shap_values_dict[name] = explainer(X_test)
# Step 3: Compare global feature importance
shap.plots.bar(shap_values_dict)
# Step 4: Compare model scores
for name, model in models.items():
    score = model.score(X_test, y_test)
    print(f"{name}: {score:.4f}")
# Step 5: Check consistency of feature importance
for feature in X_test.columns[:5]:  # First 5 features (swap in the top-ranked ones as needed)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for idx, (name, shap_vals) in enumerate(shap_values_dict.items()):
        plt.sca(axes[idx])
        shap.plots.scatter(shap_vals[:, feature], show=False)
        plt.title(f"{name} - {feature}")
    plt.tight_layout()
    plt.show()
# Step 6: Analyze specific predictions across models
sample_idx = 0
for name, shap_vals in shap_values_dict.items():
    print(f"\n=== {name} ===")
    shap.plots.waterfall(shap_vals[sample_idx])
# Step 7: Decision based on:
# - Accuracy/Performance
# - Interpretability (consistent feature importance)
# - Deployment constraints
# - Stakeholder requirements
```
**Model Selection Criteria**:
- **Accuracy vs. Interpretability**: A slightly less accurate model with stable, easy-to-communicate SHAP explanations is sometimes preferable to a marginally better black box
- **Feature Consistency**: Models agreeing on feature importance are more trustworthy
- **Explanation Quality**: Clear, actionable explanations
- **Computational Cost**: TreeExplainer is faster than KernelExplainer
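
One way to quantify the feature-consistency criterion above is a rank correlation between the models' mean-|SHAP| orderings. A minimal sketch (reusing `shap_values_dict` and `X_test` from Workflow 4; it assumes each entry is a single-output explanation, so multi-class outputs may first need a class selected):
```python
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# Mean |SHAP| per feature for every model
importance = pd.DataFrame(
    {name: np.abs(sv.values).mean(axis=0) for name, sv in shap_values_dict.items()},
    index=X_test.columns,
)

# Pairwise Spearman rank correlation of feature-importance orderings
names = list(importance.columns)
for i in range(len(names)):
    for j in range(i + 1, len(names)):
        rho, _ = spearmanr(importance[names[i]], importance[names[j]])
        print(f"{names[i]} vs {names[j]}: rank correlation = {rho:.2f}")
```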
## Workflow 5: Fairness and Bias Analysis
**Use Case**: Detecting and analyzing model bias across demographic groups
```python
# Step 1: Identify protected attributes
protected_attr = 'Gender' # or 'Race', 'Age_Group', etc.
# Step 2: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
# Step 3: Compare feature importance across groups
groups = X_test[protected_attr].unique()
cohorts = {
    f"{protected_attr}={group}": shap_values[(X_test[protected_attr] == group).values]
    for group in groups
}
shap.plots.bar(cohorts)
# Step 4: Check if protected attribute has high SHAP importance
# (should be low/zero for fair models)
protected_importance = np.abs(shap_values[:, protected_attr].values).mean()
print(f"{protected_attr} mean |SHAP|: {protected_importance:.4f}")
# Step 5: Analyze predictions for each group
for group in groups:
    mask = (X_test[protected_attr] == group).values
    group_shap = shap_values[mask]
    print(f"\n=== {protected_attr} = {group} ===")
    print(f"Sample size: {mask.sum()}")
    print(f"Positive prediction rate: {(model.predict(X_test[mask]) == 1).mean():.2%}")

    # Visualize
    shap.plots.beeswarm(group_shap, max_display=10)
# Step 6: Check for proxy features
# Features correlated with protected attribute that shouldn't have high importance
# Example: 'Zip_Code' might be proxy for race
proxy_features = ['Zip_Code', 'Last_Name_Prefix'] # Domain-specific
for feature in proxy_features:
    if feature in X_test.columns:
        importance = np.abs(shap_values[:, feature].values).mean()
        print(f"Potential proxy '{feature}' importance: {importance:.4f}")
# Step 7: Mitigation strategies if bias found
# - Remove protected attribute and proxies
# - Add fairness constraints during training
# - Post-process predictions to equalize outcomes
# - Use different model architecture
```
**Fairness Metrics to Check**:
- **Demographic Parity**: Similar positive prediction rates across groups
- **Equal Opportunity**: Similar true positive rates across groups
- **Feature Importance Parity**: Similar feature rankings across groups
- **Protected Attribute Importance**: Should be minimal
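
The first two metrics come straight from predictions rather than SHAP values. A minimal sketch (reusing `model`, `X_test`, `y_test`, and `protected_attr` from Workflow 5; it assumes a binary target where 1 is the positive class and `y_test` shares `X_test`'s index):
```python
import pandas as pd

preds = pd.Series(model.predict(X_test), index=X_test.index)

for group in X_test[protected_attr].unique():
    mask = X_test[protected_attr] == group
    # Demographic parity: positive prediction rate per group
    positive_rate = (preds[mask] == 1).mean()
    # Equal opportunity: true positive rate per group
    positives = mask & (y_test == 1)
    tpr = (preds[positives] == 1).mean() if positives.sum() > 0 else float('nan')
    print(f"{protected_attr}={group}: positive rate={positive_rate:.2%}, TPR={tpr:.2%}")
```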
## Workflow 6: Deep Learning Model Explanation
**Use Case**: Explaining neural network predictions with DeepExplainer
```python
import tensorflow as tf
import shap
# Step 1: Load or build neural network
model = tf.keras.models.load_model('my_model.h5')
# Step 2: Select background dataset
# Use subset (100-1000 samples) from training data
background = X_train[:100]
# Step 3: Create DeepExplainer
explainer = shap.DeepExplainer(model, background)
# Step 4: Compute SHAP values (may take time)
# Explain subset of test data
test_subset = X_test[:50]
shap_values = explainer.shap_values(test_subset)
# Step 5: Handle multi-output models
# Depending on the SHAP version and the model's output shape, shap_values may be
# a list with one array per output (e.g., [class_0_values, class_1_values] for
# binary classification) or a single array (e.g., for regression)
if isinstance(shap_values, list):
    # Focus on the positive class
    shap_values_positive = shap_values[1]
    shap_exp = shap.Explanation(
        values=shap_values_positive,
        base_values=explainer.expected_value[1],
        data=test_subset
    )
else:
    shap_exp = shap.Explanation(
        values=shap_values,
        base_values=explainer.expected_value,
        data=test_subset
    )
# Step 6: Visualize
shap.plots.beeswarm(shap_exp)
shap.plots.waterfall(shap_exp[0])
# Step 7: For image/text data, use specialized plots
# Image: shap.image_plot
# Text: shap.plots.text (for transformers)
```
**Deep Learning Considerations**:
- Background dataset size affects accuracy and speed
- Multi-output handling (classification vs. regression)
- Specialized plots for image/text data
- Computational cost (consider GPU acceleration)
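
When DeepExplainer does not support a particular framework version or layer type, GradientExplainer (expected gradients) is a commonly used fallback with the same background/explain pattern. A minimal sketch under the setup of Workflow 6; the image plot is shown commented out because it assumes image-shaped inputs:
```python
import shap

# GradientExplainer approximates SHAP values via expected gradients and
# works with TensorFlow and PyTorch models
explainer = shap.GradientExplainer(model, background)
shap_values = explainer.shap_values(test_subset)

# For image inputs (n, height, width, channels), overlay attributions on the images:
# shap.image_plot(shap_values, test_subset)
```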
## Workflow 7: Production Deployment
**Use Case**: Integrating SHAP explanations into production systems
```python
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap
# Step 1: Train and save model
model = train_model(X_train, y_train)
joblib.dump(model, 'model.pkl')
# Step 2: Create and save explainer
explainer = shap.TreeExplainer(model)
joblib.dump(explainer, 'explainer.pkl')
# Step 3: Create explanation service
class ExplanationService:
    def __init__(self, model_path, explainer_path):
        self.model = joblib.load(model_path)
        self.explainer = joblib.load(explainer_path)

    def predict_with_explanation(self, X):
        """
        Returns predictions and explanations
        """
        # Prediction
        prediction = self.model.predict(X)

        # SHAP values
        shap_values = self.explainer(X)

        # Format explanation
        explanations = []
        for i in range(len(X)):
            exp = {
                'prediction': prediction[i],
                'base_value': shap_values.base_values[i],
                'shap_values': dict(zip(X.columns, shap_values.values[i])),
                'feature_values': X.iloc[i].to_dict()
            }
            explanations.append(exp)
        return explanations

    def get_top_features(self, X, n=5):
        """
        Returns the top N features for each prediction
        """
        shap_values = self.explainer(X)
        top_features = []
        for i in range(len(X)):
            # Get absolute SHAP values
            abs_shap = np.abs(shap_values.values[i])

            # Sort and get top N
            top_indices = abs_shap.argsort()[-n:][::-1]
            top_feature_names = X.columns[top_indices].tolist()
            top_shap_values = shap_values.values[i][top_indices].tolist()

            top_features.append({
                'features': top_feature_names,
                'shap_values': top_shap_values
            })
        return top_features
# Step 4: Usage in API
service = ExplanationService('model.pkl', 'explainer.pkl')
# Example API endpoint
def predict_endpoint(input_data):
    X = pd.DataFrame([input_data])
    explanations = service.predict_with_explanation(X)
    return {
        'prediction': explanations[0]['prediction'],
        'explanation': explanations[0]
    }
# Step 5: Generate static explanations for batch predictions
def batch_explain_and_save(X_batch, output_dir):
    shap_values = explainer(X_batch)

    # Save global plot
    shap.plots.beeswarm(shap_values, show=False)
    plt.savefig(f'{output_dir}/global_importance.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Save individual explanations
    for i in range(min(100, len(X_batch))):  # First 100
        shap.plots.waterfall(shap_values[i], show=False)
        plt.savefig(f'{output_dir}/explanation_{i}.png', dpi=300, bbox_inches='tight')
        plt.close()
```
**Production Best Practices**:
- Cache explainers to avoid recomputation
- Batch explanations when possible
- Limit explanation complexity (top N features)
- Monitor explanation latency
- Version explainers alongside models
- Consider pre-computing explanations for common inputs
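
One way to combine the caching and top-N recommendations is to memoize per-row explanations keyed by a hash of the input. A minimal sketch (the helper is hypothetical; `service` is the ExplanationService defined above):
```python
import hashlib
import json

import pandas as pd

_explanation_cache = {}

def cached_top_features(service, row: dict, n=5):
    """Return the top-n feature explanation for one input, caching by content hash."""
    key = hashlib.sha256(json.dumps(row, sort_keys=True, default=str).encode()).hexdigest()
    if key not in _explanation_cache:
        X = pd.DataFrame([row])
        _explanation_cache[key] = service.get_top_features(X, n=n)[0]
    return _explanation_cache[key]
```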
## Workflow 8: Time Series Model Explanation
**Use Case**: Explaining time series forecasting models
```python
# Step 1: Prepare data with time-based features
# Example: Predicting next day's sales
df['DayOfWeek'] = df['Date'].dt.dayofweek
df['Month'] = df['Date'].dt.month
df['Lag_1'] = df['Sales'].shift(1)
df['Lag_7'] = df['Sales'].shift(7)
df['Rolling_Mean_7'] = df['Sales'].rolling(7).mean()
# Step 2: Train model
# Drop rows with NaNs introduced by lagging/rolling, and split without shuffling
# so the test period comes after the training period (no temporal leakage)
features = ['DayOfWeek', 'Month', 'Lag_1', 'Lag_7', 'Rolling_Mean_7']
data = df.dropna(subset=features + ['Sales'])
X_train, X_test, y_train, y_test = train_test_split(
    data[features], data['Sales'], test_size=0.2, shuffle=False
)
model = xgb.XGBRegressor().fit(X_train, y_train)
# Step 3: Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
# Step 4: Analyze temporal patterns
# Which features drive predictions at different times?
shap.plots.beeswarm(shap_values)
# Step 5: Check lagged feature importance
# Lag features should have high importance for time series
lag_features = ['Lag_1', 'Lag_7', 'Rolling_Mean_7']
for feature in lag_features:
    shap.plots.scatter(shap_values[:, feature])
# Step 6: Explain specific predictions
# E.g., why was Monday's forecast so different?
monday_mask = (X_test['DayOfWeek'] == 0).values
shap.plots.waterfall(shap_values[monday_mask][0])
# Step 7: Validate seasonality understanding
shap.plots.scatter(shap_values[:, 'Month'])
```
**Time Series Considerations**:
- Lagged features and their importance
- Rolling statistics interpretation
- Seasonal patterns in SHAP values
- Avoiding data leakage in feature engineering
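
For the leakage point in particular, cross-validation should also respect time order. A minimal sketch using scikit-learn's TimeSeriesSplit on the `data` and `features` objects from Workflow 8:
```python
from sklearn.model_selection import TimeSeriesSplit
import xgboost as xgb

# Expanding-window CV: every fold trains on the past and evaluates on the future
tscv = TimeSeriesSplit(n_splits=5)
for fold, (train_idx, test_idx) in enumerate(tscv.split(data[features])):
    X_tr, X_te = data[features].iloc[train_idx], data[features].iloc[test_idx]
    y_tr, y_te = data['Sales'].iloc[train_idx], data['Sales'].iloc[test_idx]
    fold_model = xgb.XGBRegressor().fit(X_tr, y_tr)
    print(f"Fold {fold}: R^2 = {fold_model.score(X_te, y_te):.3f}")
```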
## Common Pitfalls and Solutions
### Pitfall 1: Wrong Explainer Choice
**Problem**: Using KernelExplainer for tree models (slow and unnecessary)
**Solution**: Always use TreeExplainer for tree-based models
### Pitfall 2: Insufficient Background Data
**Problem**: DeepExplainer/KernelExplainer with too few background samples
**Solution**: Use 100-1000 representative samples
### Pitfall 3: Misinterpreting Log-Odds
**Problem**: Confusion about units (probability vs. log-odds)
**Solution**: Check model output type; use link="logit" when needed
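
For example, XGBoost classifiers are explained in log-odds units by default. As an alternative to the `link="logit"` plotting option, TreeExplainer can be asked to explain the model's probability output against an interventional background; a sketch, reusing `model`, `X_train`, and `X_test` from Workflow 1:
```python
import shap

# Explain probabilities rather than log-odds; this mode requires a background dataset
explainer_prob = shap.TreeExplainer(
    model,
    data=shap.sample(X_train, 100, random_state=0),
    feature_perturbation="interventional",
    model_output="probability",
)
shap_values_prob = explainer_prob(X_test)
```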
### Pitfall 4: Ignoring Feature Correlations
**Problem**: Interpreting features as independent when they're correlated
**Solution**: Use feature clustering; understand domain relationships
### Pitfall 5: Overfitting to Explanations
**Problem**: Feature engineering based solely on SHAP without validation
**Solution**: Always validate improvements with cross-validation
### Pitfall 6: Data Leakage Undetected
**Problem**: Not noticing unexpected feature importance indicating leakage
**Solution**: Validate SHAP results against domain knowledge
### Pitfall 7: Computational Constraints Ignored
**Problem**: Computing SHAP for entire large dataset
**Solution**: Use sampling, batching, or subset analysis
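
A common pattern for the sampling/batching solution: explain a random subset for global plots and batch over the full data only when every row's explanation must be stored. A minimal sketch (reusing `explainer` and `X_test`; the sizes are illustrative):
```python
import numpy as np
import shap

# Random subset for global summary plots
rng = np.random.default_rng(0)
idx = rng.choice(len(X_test), size=min(2000, len(X_test)), replace=False)
shap_values_subset = explainer(X_test.iloc[idx])
shap.plots.beeswarm(shap_values_subset)

# Batch over the full dataset only if per-row explanations are required
batch_size = 1000
all_values = [
    explainer(X_test.iloc[i:i + batch_size]).values
    for i in range(0, len(X_test), batch_size)
]
```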
## Advanced Techniques
### Technique 1: SHAP Interaction Values
Capture pairwise feature interactions:
```python
explainer = shap.TreeExplainer(model)
shap_interaction_values = explainer.shap_interaction_values(X_test)
# Analyze specific interaction
feature1_idx = 0
feature2_idx = 3
interaction = shap_interaction_values[:, feature1_idx, feature2_idx]
print(f"Interaction strength: {np.abs(interaction).mean():.4f}")
```
### Technique 2: Partial Dependence with SHAP
Combine partial dependence plots with SHAP:
```python
from sklearn.inspection import partial_dependence
# SHAP dependence
shap.plots.scatter(shap_values[:, "Feature1"])
# Partial dependence (model-agnostic)
pd_result = partial_dependence(model, X_test, features=["Feature1"])
# 'grid_values' is the key in scikit-learn >= 1.3; older versions use 'values'
plt.plot(pd_result['grid_values'][0], pd_result['average'][0])
```
### Technique 3: Conditional Expectations
Analyze SHAP values conditioned on other features:
```python
# High-income group
high_income = (X_test['Income'] > X_test['Income'].median()).values
shap.plots.beeswarm(shap_values[high_income])

# Low-income group
low_income = (X_test['Income'] <= X_test['Income'].median()).values
shap.plots.beeswarm(shap_values[low_income])
```
### Technique 4: Feature Clustering for Redundancy
```python
# Create hierarchical clustering
clustering = shap.utils.hclust(X_train, y_train)
# Visualize with clustering
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.5)
# Identify redundant features to remove
# Features with distance < 0.1 are highly redundant
```
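
To turn the clustering into a concrete list of redundant groups, the partition tree returned by `shap.utils.hclust` can be cut at a distance threshold with SciPy. This sketch assumes the returned matrix follows SciPy's linkage format; the 0.1 cutoff mirrors the comment above and should be tuned:
```python
import numpy as np
from scipy.cluster.hierarchy import fcluster

# Cut the feature partition tree at a small distance: features in the same
# cluster carry nearly the same information about the target
labels = fcluster(clustering, t=0.1, criterion='distance')
for cluster_id in np.unique(labels):
    members = X_train.columns[labels == cluster_id]
    if len(members) > 1:
        print(f"Redundant group {cluster_id}: {list(members)}")
```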
## Integration with MLOps
**Experiment Tracking**:
```python
import mlflow
import numpy as np
import matplotlib.pyplot as plt

# Log SHAP values
with mlflow.start_run():
    # Train model
    model = train_model(X_train, y_train)

    # Compute SHAP
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)

    # Log plots
    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()

    # Log feature importance as metrics
    mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
    for feature, importance in zip(X_test.columns, mean_abs_shap):
        mlflow.log_metric(f"shap_{feature}", float(importance))
```
**Model Monitoring**:
```python
# Track SHAP distribution drift over time
def compute_shap_summary(shap_values):
    return {
        'mean': shap_values.values.mean(axis=0),
        'std': shap_values.values.std(axis=0),
        'percentiles': np.percentile(shap_values.values, [25, 50, 75], axis=0)
    }
# Compute baseline
baseline_summary = compute_shap_summary(shap_values_train)
# Monitor production data
production_summary = compute_shap_summary(shap_values_production)
# Detect per-feature drift (the threshold is domain-specific; calibrate it on historical runs)
drift_detected = np.abs(
    production_summary['mean'] - baseline_summary['mean']
) > threshold
```
Together, these workflows cover the most common, as well as several advanced, use cases for SHAP in practice.