Files
2025-11-30 08:30:10 +08:00

823 lines
17 KiB
Markdown

# Seaborn Common Use Cases and Examples
This document provides practical examples for common data visualization scenarios using seaborn.
## Exploratory Data Analysis
### Quick Dataset Overview
```python
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
# Load data
df = pd.read_csv('data.csv')
# Pairwise relationships for all numeric variables
sns.pairplot(df, hue='target_variable', corner=True, diag_kind='kde')
plt.suptitle('Dataset Overview', y=1.01)
plt.savefig('overview.png', dpi=300, bbox_inches='tight')
```
### Distribution Exploration
```python
# Multiple distributions across categories
g = sns.displot(
data=df,
x='measurement',
hue='condition',
col='timepoint',
kind='kde',
fill=True,
height=3,
aspect=1.5,
col_wrap=3,
common_norm=False
)
g.set_axis_labels('Measurement Value', 'Density')
g.set_titles('{col_name}')
```
### Correlation Analysis
```python
# numpy is needed for the triangular mask below
import numpy as np
# Compute correlation matrix (numeric columns only)
corr = df.select_dtypes(include='number').corr()
# Create mask for upper triangle (diagonal included) so each pair appears once
mask = np.triu(np.ones_like(corr, dtype=bool))
# Plot heatmap on the axes created here rather than on whatever axes is current
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={'shrink': 0.8},
    ax=ax
)
ax.set_title('Correlation Matrix')
plt.tight_layout()
```
## Scientific Publications
### Multi-Panel Figure with Different Plot Types
```python
# Set publication style
sns.set_theme(style='ticks', context='paper', font_scale=1.1)
sns.set_palette('colorblind')
# Create figure with custom layout
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
# Panel A: Time series
ax1 = fig.add_subplot(gs[0, :2])
sns.lineplot(
data=timeseries_df,
x='time',
y='expression',
hue='gene',
style='treatment',
markers=True,
dashes=False,
ax=ax1
)
ax1.set_title('A. Gene Expression Over Time', loc='left', fontweight='bold')
ax1.set_xlabel('Time (hours)')
ax1.set_ylabel('Expression Level (AU)')
# Panel B: Distribution comparison
ax2 = fig.add_subplot(gs[0, 2])
sns.violinplot(
data=expression_df,
x='treatment',
y='expression',
inner='box',
ax=ax2
)
ax2.set_title('B. Expression Distribution', loc='left', fontweight='bold')
ax2.set_xlabel('Treatment')
ax2.set_ylabel('')
# Panel C: Correlation
ax3 = fig.add_subplot(gs[1, 0])
sns.scatterplot(
data=correlation_df,
x='gene1',
y='gene2',
hue='cell_type',
alpha=0.6,
ax=ax3
)
sns.regplot(
data=correlation_df,
x='gene1',
y='gene2',
scatter=False,
color='black',
ax=ax3
)
ax3.set_title('C. Gene Correlation', loc='left', fontweight='bold')
ax3.set_xlabel('Gene 1 Expression')
ax3.set_ylabel('Gene 2 Expression')
# Panel D: Heatmap
ax4 = fig.add_subplot(gs[1, 1:])
sns.heatmap(
sample_matrix,
cmap='RdBu_r',
center=0,
annot=True,
fmt='.1f',
cbar_kws={'label': 'Log2 Fold Change'},
ax=ax4
)
ax4.set_title('D. Treatment Effects', loc='left', fontweight='bold')
ax4.set_xlabel('Sample')
ax4.set_ylabel('Gene')
# Clean up
sns.despine()
plt.savefig('figure.pdf', dpi=300, bbox_inches='tight')
plt.savefig('figure.png', dpi=300, bbox_inches='tight')
```
### Box Plot with Significance Annotations
```python
import numpy as np
from scipy import stats
# Single source of truth for the category order used by both layers
treatment_order = ['Control', 'Low', 'Medium', 'High']
# Create plot
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(
    data=df,
    x='treatment',
    y='response',
    order=treatment_order,
    # seaborn 0.13 deprecates `palette` without `hue`: assign the x variable
    # to hue and suppress the redundant legend to get the same colouring
    hue='treatment',
    hue_order=treatment_order,
    palette='Set2',
    legend=False,
    ax=ax
)
# Add individual points
sns.stripplot(
    data=df,
    x='treatment',
    y='response',
    order=treatment_order,
    color='black',
    alpha=0.3,
    size=3,
    ax=ax
)
# Add significance bars
def add_significance_bar(ax, x1, x2, y, h, text):
    """Draw a significance bracket between two x positions and label it.

    Parameters
    ----------
    ax : object with ``plot`` and ``text`` methods (e.g. a matplotlib Axes)
    x1, x2 : x coordinates of the two bracket legs
    y : y coordinate where the legs start
    h : height of the bracket above ``y``
    text : label placed centred on top of the bracket (e.g. '***', 'ns')
    """
    # Up from (x1, y), across at y + h, down to (x2, y)
    ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], 'k-', lw=1.5)
    ax.text((x1+x2)/2, y+h, text, ha='center', va='bottom')
y_max = df['response'].max()
add_significance_bar(ax, 0, 3, y_max + 1, 0.5, '***')
add_significance_bar(ax, 0, 1, y_max + 3, 0.5, 'ns')
ax.set_ylabel('Response (μM)')
ax.set_xlabel('Treatment Condition')
ax.set_title('Treatment Response Analysis')
sns.despine()
```
## Time Series Analysis
### Multiple Time Series with Confidence Bands
```python
# Plot with automatic aggregation
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
data=timeseries_df,
x='timestamp',
y='value',
hue='sensor',
style='location',
markers=True,
dashes=False,
errorbar=('ci', 95),
ax=ax
)
# Customize
ax.set_xlabel('Date')
ax.set_ylabel('Measurement (units)')
ax.set_title('Sensor Measurements Over Time')
ax.legend(title='Sensor & Location', bbox_to_anchor=(1.05, 1), loc='upper left')
# Format x-axis for dates
import matplotlib.dates as mdates
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
```
### Faceted Time Series
```python
# Create faceted time series
g = sns.relplot(
data=long_timeseries,
x='date',
y='measurement',
hue='device',
col='location',
row='metric',
kind='line',
height=3,
aspect=2,
errorbar='sd',
facet_kws={'sharex': True, 'sharey': False}
)
# Customize facet titles
g.set_titles('{row_name} - {col_name}')
g.set_axis_labels('Date', 'Value')
# Rotate x-axis labels on every facet of the grid
for facet_ax in g.axes.flat:
    facet_ax.tick_params(axis='x', rotation=45)
g.tight_layout()
```
## Categorical Comparisons
### Nested Categorical Variables
```python
# Create figure
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Left panel: Grouped bar plot
sns.barplot(
data=df,
x='category',
y='value',
hue='subcategory',
errorbar=('ci', 95),
capsize=0.1,
ax=axes[0]
)
axes[0].set_title('Mean Values with 95% CI')
axes[0].set_ylabel('Value (units)')
axes[0].legend(title='Subcategory')
# Right panel: Strip + violin plot
sns.violinplot(
data=df,
x='category',
y='value',
hue='subcategory',
inner=None,
alpha=0.3,
ax=axes[1]
)
sns.stripplot(
data=df,
x='category',
y='value',
hue='subcategory',
dodge=True,
size=3,
alpha=0.6,
ax=axes[1]
)
axes[1].set_title('Distribution of Individual Values')
axes[1].set_ylabel('')
axes[1].get_legend().remove()
plt.tight_layout()
```
### Point Plot for Trends
```python
# Show how values change across categories
sns.pointplot(
data=df,
x='timepoint',
y='score',
hue='treatment',
markers=['o', 's', '^'],
linestyles=['-', '--', '-.'],
dodge=0.3,
capsize=0.1,
errorbar=('ci', 95)
)
plt.xlabel('Timepoint')
plt.ylabel('Performance Score')
plt.title('Treatment Effects Over Time')
plt.legend(title='Treatment', bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.tight_layout()
```
## Regression and Relationships
### Linear Regression with Facets
```python
# Fit separate regressions for each category
g = sns.lmplot(
data=df,
x='predictor',
y='response',
hue='treatment',
col='cell_line',
height=4,
aspect=1.2,
scatter_kws={'alpha': 0.5, 's': 50},
ci=95,
palette='Set2'
)
g.set_axis_labels('Predictor Variable', 'Response Variable')
g.set_titles('{col_name}')
g.tight_layout()
```
### Polynomial Regression
```python
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# One panel per polynomial order so the fits can be compared side by side
for panel, order in zip(axes, [1, 2, 3]):
    sns.regplot(
        data=df,
        x='x',
        y='y',
        order=order,
        scatter_kws={'alpha': 0.5},
        line_kws={'color': 'red'},
        ci=95,
        ax=panel
    )
    panel.set_title(f'Order {order} Polynomial Fit')
    panel.set_xlabel('X Variable')
    panel.set_ylabel('Y Variable')
plt.tight_layout()
```
### Residual Analysis
```python
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Main regression
sns.regplot(data=df, x='x', y='y', ax=axes[0, 0])
axes[0, 0].set_title('Regression Fit')
# Residuals vs fitted
sns.residplot(data=df, x='x', y='y', lowess=True,
scatter_kws={'alpha': 0.5},
line_kws={'color': 'red', 'lw': 2},
ax=axes[0, 1])
axes[0, 1].set_title('Residuals vs Fitted')
axes[0, 1].axhline(0, ls='--', color='gray')
# Q-Q plot (using scipy)
from scipy import stats as sp_stats
residuals = df['y'] - np.poly1d(np.polyfit(df['x'], df['y'], 1))(df['x'])
sp_stats.probplot(residuals, dist="norm", plot=axes[1, 0])
axes[1, 0].set_title('Q-Q Plot')
# Histogram of residuals
sns.histplot(residuals, kde=True, ax=axes[1, 1])
axes[1, 1].set_title('Residual Distribution')
axes[1, 1].set_xlabel('Residuals')
plt.tight_layout()
```
## Bivariate and Joint Distributions
### Joint Plot with Multiple Representations
```python
# Scatter with marginals
g = sns.jointplot(
    data=df,
    x='var1',
    y='var2',
    hue='category',
    kind='scatter',
    height=8,
    ratio=4,
    space=0.1,
    joint_kws={'alpha': 0.5, 's': 50},
    # With `hue`, jointplot draws the marginals with kdeplot, so histogram
    # options such as `kde` or `bins` are not accepted here
    marginal_kws={'fill': True}
)
# Add reference lines
g.ax_joint.axline((0, 0), slope=1, color='r', ls='--', alpha=0.5, label='y=x')
g.ax_joint.legend()
g.set_axis_labels('Variable 1', 'Variable 2', fontsize=12)
```
### KDE Contour Plot
```python
fig, ax = plt.subplots(figsize=(8, 8))
# Bivariate KDE with filled contours
sns.kdeplot(
data=df,
x='x',
y='y',
fill=True,
levels=10,
cmap='viridis',
thresh=0.05,
ax=ax
)
# Overlay scatter
sns.scatterplot(
data=df,
x='x',
y='y',
color='white',
edgecolor='black',
s=50,
alpha=0.6,
ax=ax
)
ax.set_xlabel('X Variable')
ax.set_ylabel('Y Variable')
ax.set_title('Bivariate Distribution')
```
### Hexbin with Marginals
```python
# For large datasets
g = sns.jointplot(
data=large_df,
x='x',
y='y',
kind='hex',
height=8,
ratio=5,
space=0.1,
joint_kws={'gridsize': 30, 'cmap': 'viridis'},
marginal_kws={'bins': 50, 'color': 'skyblue'}
)
g.set_axis_labels('X Variable', 'Y Variable')
```
## Matrix and Heatmap Visualizations
### Hierarchical Clustering Heatmap
```python
# Prepare data (samples x features)
data_matrix = df.set_index('sample_id')[feature_columns]
# Create color annotations
# row_colors is indexed by sample_id, matching data_matrix's row index
row_colors = df.set_index('sample_id')['condition'].map({
    'control': '#1f77b4',
    'treatment': '#ff7f0e'
})
# A Series passed as col_colors is aligned to the column labels, so it must be
# indexed by them; with a default integer index no label matches and the
# annotation strip comes out blank
col_colors = pd.Series(
    ['#2ca02c' if 'gene' in col else '#d62728' for col in data_matrix.columns],
    index=data_matrix.columns
)
# Plot
g = sns.clustermap(
    data_matrix,
    method='ward',
    metric='euclidean',
    z_score=0,  # Normalize rows
    cmap='RdBu_r',
    center=0,
    row_colors=row_colors,
    col_colors=col_colors,
    figsize=(12, 10),
    dendrogram_ratio=(0.1, 0.1),
    cbar_pos=(0.02, 0.8, 0.03, 0.15),
    linewidths=0.5
)
g.ax_heatmap.set_xlabel('Features')
g.ax_heatmap.set_ylabel('Samples')
plt.savefig('clustermap.png', dpi=300, bbox_inches='tight')
```
### Annotated Heatmap with Custom Colorbar
```python
# Pivot data for heatmap
pivot_data = df.pivot(index='row_var', columns='col_var', values='value')
# Create heatmap
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
pivot_data,
annot=True,
fmt='.1f',
cmap='RdYlGn',
center=pivot_data.mean().mean(),
vmin=pivot_data.min().min(),
vmax=pivot_data.max().max(),
linewidths=0.5,
linecolor='gray',
cbar_kws={
'label': 'Value (units)',
'orientation': 'vertical',
'shrink': 0.8,
'aspect': 20
},
ax=ax
)
ax.set_title('Variable Relationships', fontsize=14, pad=20)
ax.set_xlabel('Column Variable', fontsize=12)
ax.set_ylabel('Row Variable', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
```
## Statistical Comparisons
### Before/After Comparison
```python
# Reshape data for paired comparison
df_paired = df.melt(
id_vars='subject',
value_vars=['before', 'after'],
var_name='timepoint',
value_name='measurement'
)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Left: Individual trajectories (one faint line per subject)
# groupby walks the frame once instead of re-filtering it for every subject
for _, subject_data in df_paired.groupby('subject', sort=False):
    axes[0].plot(subject_data['timepoint'], subject_data['measurement'],
                 'o-', alpha=0.3, color='gray')
# Overlay the group mean with its 95% CI
sns.pointplot(
    data=df_paired,
    x='timepoint',
    y='measurement',
    color='red',
    markers='D',
    # `scale` is deprecated since seaborn 0.13; size the elements directly
    # with Line2D properties instead
    markersize=10,
    linewidth=4,
    errorbar=('ci', 95),
    capsize=0.2,
    ax=axes[0]
)
axes[0].set_title('Individual Changes')
axes[0].set_ylabel('Measurement')
# Right: Distribution comparison
sns.violinplot(
data=df_paired,
x='timepoint',
y='measurement',
inner='box',
ax=axes[1]
)
sns.swarmplot(
data=df_paired,
x='timepoint',
y='measurement',
color='black',
alpha=0.5,
size=3,
ax=axes[1]
)
axes[1].set_title('Distribution Comparison')
axes[1].set_ylabel('')
plt.tight_layout()
```
### Dose-Response Curve
```python
# Create dose-response plot
fig, ax = plt.subplots(figsize=(8, 6))
# Both layers must share one category order so points and means line up
dose_order = sorted(dose_df['dose'].unique())
# Plot individual points
sns.stripplot(
    data=dose_df,
    x='dose',
    y='response',
    order=dose_order,
    color='gray',
    alpha=0.3,
    jitter=0.2,
    ax=ax
)
# Overlay mean with CI
sns.pointplot(
    data=dose_df,
    x='dose',
    y='response',
    order=dose_order,
    color='blue',
    markers='o',
    # `scale` is deprecated since seaborn 0.13; size the elements directly
    # with Line2D properties instead
    markersize=8,
    linewidth=3,
    errorbar=('ci', 95),
    capsize=0.1,
    ax=ax
)
# Fit sigmoid curve
from scipy.optimize import curve_fit
def sigmoid(x, bottom, top, ec50, hill):
    """Four-parameter logistic (Hill) dose-response curve.

    Parameters
    ----------
    x : dose(s); a scalar, list or array
    bottom, top : the two plateaus of the curve
    ec50 : dose at which the response is halfway between ``bottom`` and ``top``
    hill : Hill slope

    Returns
    -------
    bottom + (top - bottom) / (1 + (ec50 / x) ** hill), evaluated elementwise.
    """
    x = np.asarray(x, dtype=float)
    # A dose of zero makes ec50 / x infinite; let it evaluate to its limit
    # instead of raising or warning
    with np.errstate(divide='ignore'):
        ratio = (ec50 / x) ** hill
    return bottom + (top - bottom) / (1 + ratio)
doses_numeric = dose_df['dose'].astype(float)
unique_doses = np.sort(doses_numeric.unique())
# Start the optimiser from data-driven values; curve_fit otherwise starts every
# parameter at 1, which is far from typical dose/response magnitudes
initial_guess = [
    dose_df['response'].min(),
    dose_df['response'].max(),
    np.median(unique_doses),
    1.0,
]
params, _ = curve_fit(sigmoid, doses_numeric, dose_df['response'], p0=initial_guess)
# Evaluate the fit on a dense log-spaced grid (requires all doses > 0)
x_smooth = np.logspace(np.log10(unique_doses.min()),
                       np.log10(unique_doses.max()), 100)
y_smooth = sigmoid(x_smooth, *params)
# The categorical axis places the sorted doses at 0, 1, 2, ...; map the dense
# grid onto those positions by interpolating in log-dose space
x_positions = np.interp(np.log10(x_smooth), np.log10(unique_doses),
                        np.arange(len(unique_doses)))
ax.plot(x_positions, y_smooth, 'r-', linewidth=2, label='Sigmoid Fit')
ax.set_xlabel('Dose')
ax.set_ylabel('Response')
ax.set_title('Dose-Response Analysis')
ax.legend()
sns.despine()
```
## Custom Styling
### Custom Color Palette from Hex Codes
```python
# Define custom palette
custom_palette = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']
sns.set_palette(custom_palette)
# Or use for specific plot
sns.scatterplot(
data=df,
x='x',
y='y',
hue='category',
palette=custom_palette
)
```
### Publication-Ready Theme
```python
# Set comprehensive theme
sns.set_theme(
context='paper',
style='ticks',
palette='colorblind',
font='Arial',
font_scale=1.1,
rc={
'figure.dpi': 300,
'savefig.dpi': 300,
'savefig.format': 'pdf',
'axes.linewidth': 1.0,
'axes.labelweight': 'bold',
'xtick.major.width': 1.0,
'ytick.major.width': 1.0,
'xtick.direction': 'out',
'ytick.direction': 'out',
'legend.frameon': False,
'pdf.fonttype': 42, # True Type fonts for PDFs
}
)
```
### Diverging Colormap Centered on Zero
```python
# For data with meaningful zero point (e.g., log fold change)
from matplotlib.colors import TwoSlopeNorm
# Find data range from the matrix that is actually plotted
vmin, vmax = pivot_data.min().min(), pivot_data.max().max()
vcenter = 0
# Create norm (TwoSlopeNorm needs vmin < vcenter < vmax, i.e. the data must
# contain values on both sides of zero)
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
# Plot
# The norm already pins zero to the middle of the colormap; also passing
# `center=0` would re-crop the colormap and shift zero off the neutral colour
sns.heatmap(
    pivot_data,
    cmap='RdBu_r',
    norm=norm,
    annot=True,
    fmt='.2f'
)
```
## Large Datasets
### Downsampling Strategy
```python
# For very large datasets, sample intelligently
def smart_sample(df, target_size=10000, category_col=None, random_state=None):
    """Downsample ``df`` to roughly ``target_size`` rows for plotting.

    Parameters
    ----------
    df : pandas.DataFrame
    target_size : maximum number of rows wanted
    category_col : optional column to stratify on; every category gets the
        same quota (``target_size // n_categories``, but at least one row),
        and categories smaller than the quota are kept whole
    random_state : optional seed forwarded to ``DataFrame.sample`` so the
        result is reproducible

    Returns
    -------
    ``df`` itself when it already fits, otherwise a sampled DataFrame with all
    original columns. Stratified results come back in shuffled row order.
    """
    if len(df) <= target_size:
        return df
    if category_col is None:
        # Simple random sampling
        return df.sample(target_size, random_state=random_state)
    # Stratified sampling: shuffle once, then keep the first rows of each group.
    # This keeps the grouping column in the result and tolerates small groups.
    n_categories = df[category_col].nunique()
    per_category = max(1, target_size // max(1, n_categories))
    shuffled = df.sample(frac=1, random_state=random_state)
    return shuffled.groupby(category_col, sort=False).head(per_category)
# Use sampled data for visualization
df_sampled = smart_sample(large_df, target_size=5000, category_col='category')
sns.scatterplot(data=df_sampled, x='x', y='y', hue='category', alpha=0.5)
```
### Hexbin for Dense Scatter Plots
```python
# For millions of points
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Regular scatter (slow)
axes[0].scatter(df['x'], df['y'], alpha=0.1, s=1)
axes[0].set_title('Scatter (all points)')
# Hexbin (fast)
hb = axes[1].hexbin(df['x'], df['y'], gridsize=50, cmap='viridis', mincnt=1)
axes[1].set_title('Hexbin Aggregation')
plt.colorbar(hb, ax=axes[1], label='Count')
plt.tight_layout()
```
## Interactive Elements for Notebooks
### Adjustable Parameters
```python
from ipywidgets import interact, FloatSlider
@interact(bandwidth=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0))
def plot_kde(bandwidth):
    """Redraw the per-category KDE of ``value`` with the slider's ``bw_adjust``."""
    # A fresh figure per call, drawn on explicit axes
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.kdeplot(data=df, x='value', hue='category',
                bw_adjust=bandwidth, fill=True, ax=ax)
    ax.set_title(f'KDE with bandwidth adjustment = {bandwidth}')
    plt.show()
```
### Dynamic Filtering
```python
from ipywidgets import interact, SelectMultiple
categories = df['category'].unique().tolist()
@interact(selected=SelectMultiple(options=categories, value=[categories[0]]))
def filtered_plot(selected):
    """Draw a violin plot of ``value`` for the categories picked in the widget."""
    if not selected:
        # Nothing selected: skip plotting rather than hand seaborn an empty frame
        return
    filtered_df = df[df['category'].isin(selected)]
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(data=filtered_df, x='category', y='value', ax=ax)
    ax.set_title(f'Showing {len(selected)} categories')
    plt.show()
```