Files
2025-11-30 08:30:10 +08:00

17 KiB

Seaborn Common Use Cases and Examples

This document provides practical examples for common data visualization scenarios using seaborn.

Exploratory Data Analysis

Quick Dataset Overview

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Read the dataset to explore
df = pd.read_csv('data.csv')

# Scatter matrix of the numeric columns: lower triangle only, KDEs on the
# diagonal, points coloured by the target column
overview_kws = dict(hue='target_variable', corner=True, diag_kind='kde')
sns.pairplot(df, **overview_kws)

# y > 1 lifts the title just above the top panel
plt.suptitle('Dataset Overview', y=1.01)
plt.savefig('overview.png', bbox_inches='tight', dpi=300)

Distribution Exploration

# One KDE panel per timepoint (wrapped after three columns), with one filled
# curve per condition
kde_facet_kws = {
    'x': 'measurement',
    'hue': 'condition',
    'col': 'timepoint',
    'col_wrap': 3,
    'kind': 'kde',
    'fill': True,
    'common_norm': False,  # each condition's density is normalised on its own
    'height': 3,
    'aspect': 1.5,
}
g = sns.displot(data=df, **kde_facet_kws)

# Panel titles show only the timepoint value
g.set_titles('{col_name}')
g.set_axis_labels('Measurement Value', 'Density')

Correlation Analysis

# numpy is needed for the triangle mask below
import numpy as np

# Pairwise correlations between the numeric columns only
corr = df.select_dtypes(include='number').corr()

# Boolean mask that hides the upper triangle (diagonal included), so each
# pair of variables is shown once
mask = np.triu(np.ones_like(corr, dtype=bool))

# Diverging colormap centred on zero so the sign of a correlation is readable
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={'shrink': 0.8},
    ax=ax,
)
ax.set_title('Correlation Matrix')
fig.tight_layout()

Scientific Publications

Multi-Panel Figure with Different Plot Types

# Publication look: tick-style axes, paper-sized text, colour-blind-safe palette
sns.set_theme(style='ticks', context='paper', font_scale=1.1)
sns.set_palette('colorblind')

# 2 x 3 grid; panels A and D each span two columns
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

# Every panel title is a bold letter label aligned to the left edge
panel_title_kws = {'loc': 'left', 'fontweight': 'bold'}

# Panel A (top row, first two columns): expression over time per gene
ax1 = fig.add_subplot(gs[0, :2])
timecourse_kws = dict(
    x='time',
    y='expression',
    hue='gene',
    style='treatment',
    markers=True,
    dashes=False,
)
sns.lineplot(data=timeseries_df, ax=ax1, **timecourse_kws)
ax1.set_title('A. Gene Expression Over Time', **panel_title_kws)
ax1.set(xlabel='Time (hours)', ylabel='Expression Level (AU)')

# Panel B (top row, last column): expression distribution per treatment
ax2 = fig.add_subplot(gs[0, 2])
sns.violinplot(data=expression_df, x='treatment', y='expression',
               inner='box', ax=ax2)
ax2.set_title('B. Expression Distribution', **panel_title_kws)
ax2.set(xlabel='Treatment', ylabel='')

# Panel C (bottom row, first column): gene-gene scatter with one overall
# regression line drawn on top (scatter=False suppresses duplicate points)
ax3 = fig.add_subplot(gs[1, 0])
gene_pair_kws = dict(data=correlation_df, x='gene1', y='gene2', ax=ax3)
sns.scatterplot(hue='cell_type', alpha=0.6, **gene_pair_kws)
sns.regplot(scatter=False, color='black', **gene_pair_kws)
ax3.set_title('C. Gene Correlation', **panel_title_kws)
ax3.set(xlabel='Gene 1 Expression', ylabel='Gene 2 Expression')

# Panel D (bottom row, last two columns): annotated heatmap centred on zero
ax4 = fig.add_subplot(gs[1, 1:])
sns.heatmap(
    sample_matrix,
    cmap='RdBu_r',
    center=0,
    annot=True,
    fmt='.1f',
    cbar_kws={'label': 'Log2 Fold Change'},
    ax=ax4,
)
ax4.set_title('D. Treatment Effects', **panel_title_kws)
ax4.set(xlabel='Sample', ylabel='Gene')

# Drop the top/right spines, then export a vector and a raster copy
sns.despine()
for output_path in ('figure.pdf', 'figure.png'):
    plt.savefig(output_path, dpi=300, bbox_inches='tight')

Box Plot with Significance Annotations

import numpy as np
from scipy import stats

# Boxes and raw observations share the same data, mapping and category order
treatment_order = ['Control', 'Low', 'Medium', 'High']
response_kws = dict(data=df, x='treatment', y='response', order=treatment_order)

fig, ax = plt.subplots(figsize=(8, 6))

# Summary boxes
sns.boxplot(palette='Set2', ax=ax, **response_kws)

# Individual observations, small and translucent so the boxes stay readable
sns.stripplot(color='black', alpha=0.3, size=3, ax=ax, **response_kws)
# Helper for drawing significance brackets
def add_significance_bar(ax, x1, x2, y, h, text):
    """Draw a bracket between x1 and x2 and label it.

    The bracket starts at height ``y``, rises by ``h``, runs horizontally from
    ``x1`` to ``x2`` and drops back to ``y``. ``text`` is centred on the
    horizontal segment, sitting just above it.
    """
    top = y + h
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, 'k-', lw=1.5)
    ax.text((x1 + x2) / 2, top, text, ha='center', va='bottom')

# Place the brackets relative to the spread of the data rather than at fixed
# offsets, so the layout holds whatever units the response is measured in
y_max = df['response'].max()
y_span = y_max - df['response'].min()
bar_step = 0.05 * y_span if y_span > 0 else 1.0  # constant data: fall back to 1

# Control vs High just above the data, Control vs Low stacked higher up
add_significance_bar(ax, 0, 3, y_max + bar_step, bar_step / 2, '***')
add_significance_bar(ax, 0, 1, y_max + 3 * bar_step, bar_step / 2, 'ns')

ax.set_ylabel('Response (μM)')
ax.set_xlabel('Treatment Condition')
ax.set_title('Treatment Response Analysis')
sns.despine()

Time Series Analysis

Multiple Time Series with Confidence Bands

import matplotlib.dates as mdates

# One line per sensor (colour) and location (marker), with a 95% confidence
# band around each line
fig, ax = plt.subplots(figsize=(10, 6))
sensor_line_kws = dict(
    x='timestamp',
    y='value',
    hue='sensor',
    style='location',
    markers=True,
    dashes=False,
    errorbar=('ci', 95),
)
sns.lineplot(data=timeseries_df, ax=ax, **sensor_line_kws)

# Labels, title, and a legend parked outside the right edge of the axes
ax.set(
    xlabel='Date',
    ylabel='Measurement (units)',
    title='Sensor Measurements Over Time',
)
ax.legend(title='Sensor & Location', loc='upper left', bbox_to_anchor=(1.05, 1))

# Weekly ticks rendered as ISO dates, slanted so they do not overlap
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.xticks(rotation=45, ha='right')

plt.tight_layout()

Faceted Time Series

# Grid of line plots: one row per metric, one column per location, one line
# per device, with a standard-deviation band
facet_line_kws = {
    'x': 'date',
    'y': 'measurement',
    'hue': 'device',
    'row': 'metric',
    'col': 'location',
    'kind': 'line',
    'errorbar': 'sd',
    'height': 3,
    'aspect': 2,
    # Shared time axis, but each panel keeps its own y-range
    'facet_kws': {'sharex': True, 'sharey': False},
}
g = sns.relplot(data=long_timeseries, **facet_line_kws)

# Axis labels and "<metric> - <location>" panel titles
g.set_axis_labels('Date', 'Value')
g.set_titles('{row_name} - {col_name}')

# Slant the date ticks on every panel
for ax in g.axes.flat:
    ax.tick_params(axis='x', rotation=45)

g.tight_layout()

Categorical Comparisons

Nested Categorical Variables

# Two views of the same nested grouping, side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
nested_kws = dict(data=df, x='category', y='value', hue='subcategory')

# Left: grouped bars with capped 95% confidence intervals
sns.barplot(errorbar=('ci', 95), capsize=0.1, ax=axes[0], **nested_kws)
axes[0].set_title('Mean Values with 95% CI')
axes[0].set_ylabel('Value (units)')
axes[0].legend(title='Subcategory')

# Right: faint violin outlines with the individual points dodged on top
sns.violinplot(inner=None, alpha=0.3, ax=axes[1], **nested_kws)
sns.stripplot(dodge=True, size=3, alpha=0.6, ax=axes[1], **nested_kws)
axes[1].set_title('Distribution of Individual Values')
axes[1].set_ylabel('')
# The left panel already carries the legend
axes[1].get_legend().remove()

plt.tight_layout()
# Point plot for trends: show how values change across categories.
# The plot gets its own figure and axes; without them it would be drawn on
# top of whichever axes happen to be current from an earlier plot.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x='timepoint',
    y='score',
    hue='treatment',
    markers=['o', 's', '^'],
    linestyles=['-', '--', '-.'],
    dodge=0.3,  # offset the treatments so their error bars do not overlap
    capsize=0.1,
    errorbar=('ci', 95),
    ax=ax,
)

ax.set_xlabel('Timepoint')
ax.set_ylabel('Performance Score')
ax.set_title('Treatment Effects Over Time')
# Legend sits outside the right edge of the axes
ax.legend(title='Treatment', bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine(fig=fig)
fig.tight_layout()

Regression and Relationships

Linear Regression with Facets

# One panel per cell line; within a panel, a separate regression line and
# 95% confidence band per treatment
regression_kws = {
    'x': 'predictor',
    'y': 'response',
    'hue': 'treatment',
    'col': 'cell_line',
    'palette': 'Set2',
    'ci': 95,
    'scatter_kws': {'alpha': 0.5, 's': 50},
    'height': 4,
    'aspect': 1.2,
}
g = sns.lmplot(data=df, **regression_kws)

g.set_titles('{col_name}')
g.set_axis_labels('Predictor Variable', 'Response Variable')
g.tight_layout()

Polynomial Regression

# Linear, quadratic and cubic fits of the same data, one panel each
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for panel, order in zip(axes, (1, 2, 3)):
    sns.regplot(
        data=df,
        x='x',
        y='y',
        order=order,
        ci=95,
        scatter_kws={'alpha': 0.5},
        line_kws={'color': 'red'},
        ax=panel,
    )
    panel.set_title(f'Order {order} Polynomial Fit')
    panel.set_xlabel('X Variable')
    panel.set_ylabel('Y Variable')

plt.tight_layout()

Residual Analysis

from scipy import stats as sp_stats

# 2 x 2 diagnostic grid for a straight-line fit of y on x
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
(ax_fit, ax_resid), (ax_qq, ax_hist) = axes

# Top-left: the regression itself
sns.regplot(data=df, x='x', y='y', ax=ax_fit)
ax_fit.set_title('Regression Fit')

# Top-right: residuals with a LOWESS trend and a dashed zero line
sns.residplot(
    data=df,
    x='x',
    y='y',
    lowess=True,
    scatter_kws={'alpha': 0.5},
    line_kws={'color': 'red', 'lw': 2},
    ax=ax_resid,
)
ax_resid.set_title('Residuals vs Fitted')
ax_resid.axhline(0, ls='--', color='gray')

# Residuals of a degree-1 least-squares fit, computed by hand for the two
# remaining panels
linear_fit = np.poly1d(np.polyfit(df['x'], df['y'], 1))
residuals = df['y'] - linear_fit(df['x'])

# Bottom-left: normal Q-Q plot of the residuals
sp_stats.probplot(residuals, dist="norm", plot=ax_qq)
ax_qq.set_title('Q-Q Plot')

# Bottom-right: residual histogram with a KDE overlay
sns.histplot(residuals, kde=True, ax=ax_hist)
ax_hist.set_title('Residual Distribution')
ax_hist.set_xlabel('Residuals')

plt.tight_layout()

Bivariate and Joint Distributions

Joint Plot with Multiple Representations

# Scatter with marginals, coloured by category
g = sns.jointplot(
    data=df,
    x='var1',
    y='var2',
    hue='category',
    kind='scatter',
    height=8,
    ratio=4,
    space=0.1,
    joint_kws={'alpha': 0.5, 's': 50},
    # With hue set, the marginals are KDE curves rather than histograms, so
    # histogram-only options such as 'kde' and 'bins' must not be passed here
    marginal_kws={'fill': True},
)

# Add a y = x reference line through the origin
g.ax_joint.axline((0, 0), slope=1, color='r', ls='--', alpha=0.5, label='y=x')
# Rebuild the legend so it includes the reference line
g.ax_joint.legend()

g.set_axis_labels('Variable 1', 'Variable 2', fontsize=12)

KDE Contour Plot

# Filled density contours with the raw points drawn over them
fig, ax = plt.subplots(figsize=(8, 8))
xy_kws = dict(data=df, x='x', y='y', ax=ax)

# Ten filled contour levels; density below the 0.05 threshold is left blank
sns.kdeplot(fill=True, levels=10, thresh=0.05, cmap='viridis', **xy_kws)

# White markers with black outlines stay visible against the colormap
sns.scatterplot(color='white', edgecolor='black', s=50, alpha=0.6, **xy_kws)

ax.set_title('Bivariate Distribution')
ax.set_xlabel('X Variable')
ax.set_ylabel('Y Variable')

Hexbin with Marginals

# Hexagonal binning keeps a large point cloud readable; marginal histograms
# show each variable on its own
hexbin_kws = {'gridsize': 30, 'cmap': 'viridis'}
histogram_kws = {'bins': 50, 'color': 'skyblue'}

g = sns.jointplot(
    data=large_df,
    x='x',
    y='y',
    kind='hex',
    joint_kws=hexbin_kws,
    marginal_kws=histogram_kws,
    height=8,
    ratio=5,
    space=0.1,
)

g.set_axis_labels('X Variable', 'Y Variable')

Matrix and Heatmap Visualizations

Hierarchical Clustering Heatmap

# Prepare data (samples x features)
data_matrix = df.set_index('sample_id')[feature_columns]

# Row annotation: one colour per sample, keyed by sample_id so it lines up
# with the rows of data_matrix
row_colors = df.set_index('sample_id')['condition'].map({
    'control': '#1f77b4',
    'treatment': '#ff7f0e'
})

# Column annotation: green for columns whose name contains 'gene', red
# otherwise. The Series must be indexed by the column labels; a default
# integer index would not match them and the colours would be lost.
col_colors = pd.Series(
    ['#2ca02c' if 'gene' in col else '#d62728' for col in data_matrix.columns],
    index=data_matrix.columns,
)

# Plot
g = sns.clustermap(
    data_matrix,
    method='ward',
    metric='euclidean',
    z_score=0,  # Normalize rows
    cmap='RdBu_r',
    center=0,
    row_colors=row_colors,
    col_colors=col_colors,
    figsize=(12, 10),
    dendrogram_ratio=(0.1, 0.1),
    cbar_pos=(0.02, 0.8, 0.03, 0.15),
    linewidths=0.5
)

g.ax_heatmap.set_xlabel('Features')
g.ax_heatmap.set_ylabel('Samples')
plt.savefig('clustermap.png', dpi=300, bbox_inches='tight')

Annotated Heatmap with Custom Colorbar

# Long table -> matrix with row_var down the side and col_var across the top
pivot_data = df.pivot(index='row_var', columns='col_var', values='value')

# Colour limits span the data; the colormap midpoint is the mean of the
# column means
colour_floor = pivot_data.min().min()
colour_ceiling = pivot_data.max().max()
colour_midpoint = pivot_data.mean().mean()

colorbar_kws = {
    'label': 'Value (units)',
    'orientation': 'vertical',
    'shrink': 0.8,
    'aspect': 20,
}

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    pivot_data,
    annot=True,
    fmt='.1f',
    cmap='RdYlGn',
    center=colour_midpoint,
    vmin=colour_floor,
    vmax=colour_ceiling,
    linewidths=0.5,
    linecolor='gray',
    cbar_kws=colorbar_kws,
    ax=ax,
)

ax.set_title('Variable Relationships', fontsize=14, pad=20)
ax.set_xlabel('Column Variable', fontsize=12)
ax.set_ylabel('Row Variable', fontsize=12)

# Slanted column labels, upright row labels
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

Statistical Comparisons

Before/After Comparison

# Wide (before / after columns) -> long (one row per subject and timepoint)
df_paired = df.melt(
    id_vars='subject',
    value_vars=['before', 'after'],
    var_name='timepoint',
    value_name='measurement',
)
paired_kws = dict(data=df_paired, x='timepoint', y='measurement')

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: one faint grey line per subject ...
for subject in df_paired['subject'].unique():
    is_subject = df_paired['subject'] == subject
    subject_data = df_paired[is_subject]
    axes[0].plot(
        subject_data['timepoint'],
        subject_data['measurement'],
        'o-',
        alpha=0.3,
        color='gray',
    )

# ... with the group mean and its 95% CI drawn on top in red
sns.pointplot(
    color='red',
    markers='D',
    scale=1.5,
    errorbar=('ci', 95),
    capsize=0.2,
    ax=axes[0],
    **paired_kws,
)
axes[0].set_title('Individual Changes')
axes[0].set_ylabel('Measurement')

# Right panel: violins with every observation overlaid as a swarm
sns.violinplot(inner='box', ax=axes[1], **paired_kws)
sns.swarmplot(color='black', alpha=0.5, size=3, ax=axes[1], **paired_kws)
axes[1].set_title('Distribution Comparison')
axes[1].set_ylabel('')

plt.tight_layout()

Dose-Response Curve

# Raw observations and per-dose means share one categorical x-axis, with the
# dose levels in ascending order
fig, ax = plt.subplots(figsize=(8, 6))
dose_order = sorted(dose_df['dose'].unique())
dose_kws = dict(data=dose_df, x='dose', y='response', order=dose_order, ax=ax)

# Jittered grey points for the individual measurements
sns.stripplot(color='gray', alpha=0.3, jitter=0.2, **dose_kws)

# Blue markers for the mean at each dose, with capped 95% CI bars
sns.pointplot(
    color='blue',
    markers='o',
    scale=1.2,
    errorbar=('ci', 95),
    capsize=0.1,
    **dose_kws,
)

# Fit sigmoid curve
from scipy.optimize import curve_fit

def sigmoid(x, bottom, top, ec50, hill):
    """Four-parameter dose-response curve.

    Evaluates ``bottom + (top - bottom) / (1 + (ec50 / x) ** hill)``.

    Args:
        x: Dose or doses; a scalar or any sequence/array of numbers.
        bottom: First curve parameter (added to the scaled term).
        top: Second curve parameter; ``top - bottom`` is the numerator.
        ec50: Dose at which the result is halfway between bottom and top.
        hill: Exponent applied to ``ec50 / x``.

    Returns:
        The curve value(s), as a NumPy scalar or array matching ``x``.
    """
    # Coerce so that plain lists and Python floats work as well as arrays
    doses = np.asarray(x, dtype=float)
    # A zero dose makes ec50 / x infinite; let NumPy carry the infinity
    # through quietly instead of raising or warning
    with np.errstate(divide='ignore'):
        ratio = (ec50 / doses) ** hill
    return bottom + (top - bottom) / (1 + ratio)

# Fit the sigmoid to the raw observations. A data-driven starting point
# (observed response range, median dose, unit slope) is supplied because the
# optimiser's default start is a poor match for a dose-response curve.
doses_numeric = dose_df['dose'].astype(float)
responses = dose_df['response'].astype(float)
initial_guess = [responses.min(), responses.max(), doses_numeric.median(), 1.0]
params, _ = curve_fit(sigmoid, doses_numeric, responses,
                      p0=initial_guess, maxfev=10000)

# Evaluate the fitted curve on a fine log-spaced grid (doses must be > 0)
dose_levels = np.sort(doses_numeric.unique())
x_smooth = np.logspace(np.log10(dose_levels[0]),
                       np.log10(dose_levels[-1]), 100)
y_smooth = sigmoid(x_smooth, *params)

# The points above sit on a categorical axis (positions 0, 1, 2, ...), so map
# each grid dose to a fractional position by interpolating in log-dose
# between neighbouring dose levels
curve_positions = np.interp(np.log10(x_smooth),
                            np.log10(dose_levels),
                            np.arange(len(dose_levels)))
ax.plot(curve_positions, y_smooth, 'r-', linewidth=2, label='Sigmoid Fit')

ax.set_xlabel('Dose')
ax.set_ylabel('Response')
ax.set_title('Dose-Response Analysis')
ax.legend()
sns.despine()

Custom Styling

Custom Color Palette from Hex Codes

# Five brand colours, given as hex codes
custom_palette = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']

# Option 1: make them the default colour cycle for subsequent plots
sns.set_palette(custom_palette)

# Option 2: hand them to a single plot
category_scatter_kws = dict(x='x', y='y', hue='category', palette=custom_palette)
sns.scatterplot(data=df, **category_scatter_kws)

Publication-Ready Theme

# Matplotlib overrides layered on top of the seaborn theme
publication_rc = {
    # Output resolution and default file format
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.format': 'pdf',
    'pdf.fonttype': 42,  # embed TrueType fonts in PDFs
    # Axes frame and labels
    'axes.linewidth': 1.0,
    'axes.labelweight': 'bold',
    # Outward-pointing ticks, as thick as the frame
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    # Legend without a surrounding box
    'legend.frameon': False,
}

sns.set_theme(
    context='paper',
    style='ticks',
    palette='colorblind',
    font='Arial',
    font_scale=1.1,
    rc=publication_rc,
)

Diverging Colormap Centered on Zero

# For data with meaningful zero point (e.g., log fold change)
from matplotlib.colors import TwoSlopeNorm

# Take the range from the matrix that is actually drawn
vmin, vmax = pivot_data.min().min(), pivot_data.max().max()
vcenter = 0

# TwoSlopeNorm requires vmin < vcenter < vmax. When the data lie entirely on
# one side of zero, widen to a symmetric range around the centre instead.
if not vmin < vcenter < vmax:
    limit = max(abs(vmin), abs(vmax))
    vmin, vmax = vcenter - limit, vcenter + limit

# Create norm: zero maps to the middle of the colormap, and each side is
# scaled independently to its own extreme
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

# Plot. `center` is deliberately not passed: the norm already pins zero to
# the colormap midpoint, and supplying both would apply the centring twice.
sns.heatmap(
    pivot_data,
    cmap='RdBu_r',
    norm=norm,
    annot=True,
    fmt='.2f'
)

Large Datasets

Downsampling Strategy

# For very large datasets, sample intelligently
def smart_sample(df, target_size=10000, category_col=None, random_state=None):
    """Return at most roughly ``target_size`` rows of ``df`` for plotting.

    Args:
        df: DataFrame to downsample.
        target_size: Desired number of rows in the result.
        category_col: Optional column name. When given, rows are drawn per
            category with an equal quota of ``target_size // n_categories``
            (never less than one, and never more than the category holds).
        random_state: Optional seed (int or NumPy Generator) for a
            reproducible sample; ``None`` gives a different draw each call.

    Returns:
        ``df`` itself when it already has ``target_size`` rows or fewer,
        otherwise a sampled DataFrame with the same columns. With
        ``category_col`` the result is ordered by category and may exceed
        ``target_size`` when there are more categories than ``target_size``,
        because every category keeps at least one row.
    """
    if len(df) <= target_size:
        return df

    n_categories = df[category_col].nunique() if category_col else 0
    if n_categories == 0:
        # No category column, or it holds only missing values:
        # simple random sampling
        return df.sample(target_size, random_state=random_state)

    # One shared generator so the groups are sampled independently while the
    # overall result is still reproducible from a single seed
    rng = np.random.default_rng(random_state)
    # Floor the quota at one row so no category silently disappears
    per_category = max(1, target_size // n_categories)

    # Stratified sampling
    sampled_groups = [
        group.sample(min(len(group), per_category), random_state=rng)
        for _, group in df.groupby(category_col)
    ]
    return pd.concat(sampled_groups)

# Downsample to about 5000 rows, stratified by category, then plot the sample
sampling_kws = dict(target_size=5000, category_col='category')
df_sampled = smart_sample(large_df, **sampling_kws)

sns.scatterplot(x='x', y='y', hue='category', alpha=0.5, data=df_sampled)

Hexbin for Dense Scatter Plots

# Same data drawn two ways, to compare a raw scatter with binned counts
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
scatter_ax, hexbin_ax = axes
x_values, y_values = df['x'], df['y']

# Left: every point as a tiny, nearly transparent marker (slow)
scatter_ax.scatter(x_values, y_values, s=1, alpha=0.1)
scatter_ax.set_title('Scatter (all points)')

# Right: counts per hexagonal cell (fast); mincnt=1 leaves empty cells blank
hb = hexbin_ax.hexbin(x_values, y_values, gridsize=50, cmap='viridis', mincnt=1)
hexbin_ax.set_title('Hexbin Aggregation')
plt.colorbar(hb, ax=hexbin_ax, label='Count')

plt.tight_layout()

Interactive Elements for Notebooks

Adjustable Parameters

from ipywidgets import interact, FloatSlider

@interact(bandwidth=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0))
def plot_kde(bandwidth):
    """Redraw the per-category KDE of ``value`` for the slider's bandwidth.

    ``bandwidth`` is passed to seaborn as ``bw_adjust``.
    """
    kde_kws = dict(x='value', hue='category', bw_adjust=bandwidth, fill=True)

    plt.figure(figsize=(10, 6))
    sns.kdeplot(data=df, **kde_kws)
    plt.title(f'KDE with bandwidth adjustment = {bandwidth}')
    plt.show()

Dynamic Filtering

from ipywidgets import interact, SelectMultiple

categories = df['category'].unique().tolist()

# Pre-select the first category when there is one; slicing (rather than
# indexing) keeps this safe when the column has no values at all
@interact(selected=SelectMultiple(options=categories, value=categories[:1]))
def filtered_plot(selected):
    """Draw violin plots of ``value`` for the categories picked in the widget.

    Args:
        selected: The category values currently selected in the widget.
    """
    if not selected:
        # Nothing selected: there is no data to draw, so say so and stop
        print('Select at least one category to plot.')
        return

    filtered_df = df[df['category'].isin(selected)]

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(data=filtered_df, x='category', y='value', ax=ax)
    ax.set_title(f'Showing {len(selected)} categories')
    plt.show()