823 lines
17 KiB
Markdown
823 lines
17 KiB
Markdown
# Seaborn Common Use Cases and Examples
|
|
|
|
This document provides practical examples for common data visualization scenarios using seaborn.
|
|
|
|
## Exploratory Data Analysis
|
|
|
|
### Quick Dataset Overview
|
|
|
|
```python
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
# Load data
|
|
df = pd.read_csv('data.csv')
|
|
|
|
# Pairwise relationships for all numeric variables
|
|
sns.pairplot(df, hue='target_variable', corner=True, diag_kind='kde')
|
|
plt.suptitle('Dataset Overview', y=1.01)
|
|
plt.savefig('overview.png', dpi=300, bbox_inches='tight')
|
|
```
|
|
|
|
### Distribution Exploration
|
|
|
|
```python
|
|
# Multiple distributions across categories
|
|
g = sns.displot(
|
|
data=df,
|
|
x='measurement',
|
|
hue='condition',
|
|
col='timepoint',
|
|
kind='kde',
|
|
fill=True,
|
|
height=3,
|
|
aspect=1.5,
|
|
col_wrap=3,
|
|
common_norm=False
|
|
)
|
|
g.set_axis_labels('Measurement Value', 'Density')
|
|
g.set_titles('{col_name}')
|
|
```
|
|
|
|
### Correlation Analysis
|
|
|
|
```python
|
|
# Compute correlation matrix
|
|
corr = df.select_dtypes(include='number').corr()
|
|
|
|
# Create mask for upper triangle
|
|
mask = np.triu(np.ones_like(corr, dtype=bool))
|
|
|
|
# Plot heatmap
|
|
fig, ax = plt.subplots(figsize=(10, 8))
|
|
sns.heatmap(
|
|
corr,
|
|
mask=mask,
|
|
annot=True,
|
|
fmt='.2f',
|
|
cmap='coolwarm',
|
|
center=0,
|
|
square=True,
|
|
linewidths=1,
|
|
cbar_kws={'shrink': 0.8}
|
|
)
|
|
plt.title('Correlation Matrix')
|
|
plt.tight_layout()
|
|
```
|
|
|
|
## Scientific Publications
|
|
|
|
### Multi-Panel Figure with Different Plot Types
|
|
|
|
```python
|
|
# Set publication style
|
|
sns.set_theme(style='ticks', context='paper', font_scale=1.1)
|
|
sns.set_palette('colorblind')
|
|
|
|
# Create figure with custom layout
|
|
fig = plt.figure(figsize=(12, 8))
|
|
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
|
|
|
|
# Panel A: Time series
|
|
ax1 = fig.add_subplot(gs[0, :2])
|
|
sns.lineplot(
|
|
data=timeseries_df,
|
|
x='time',
|
|
y='expression',
|
|
hue='gene',
|
|
style='treatment',
|
|
markers=True,
|
|
dashes=False,
|
|
ax=ax1
|
|
)
|
|
ax1.set_title('A. Gene Expression Over Time', loc='left', fontweight='bold')
|
|
ax1.set_xlabel('Time (hours)')
|
|
ax1.set_ylabel('Expression Level (AU)')
|
|
|
|
# Panel B: Distribution comparison
|
|
ax2 = fig.add_subplot(gs[0, 2])
|
|
sns.violinplot(
|
|
data=expression_df,
|
|
x='treatment',
|
|
y='expression',
|
|
inner='box',
|
|
ax=ax2
|
|
)
|
|
ax2.set_title('B. Expression Distribution', loc='left', fontweight='bold')
|
|
ax2.set_xlabel('Treatment')
|
|
ax2.set_ylabel('')
|
|
|
|
# Panel C: Correlation
|
|
ax3 = fig.add_subplot(gs[1, 0])
|
|
sns.scatterplot(
|
|
data=correlation_df,
|
|
x='gene1',
|
|
y='gene2',
|
|
hue='cell_type',
|
|
alpha=0.6,
|
|
ax=ax3
|
|
)
|
|
sns.regplot(
|
|
data=correlation_df,
|
|
x='gene1',
|
|
y='gene2',
|
|
scatter=False,
|
|
color='black',
|
|
ax=ax3
|
|
)
|
|
ax3.set_title('C. Gene Correlation', loc='left', fontweight='bold')
|
|
ax3.set_xlabel('Gene 1 Expression')
|
|
ax3.set_ylabel('Gene 2 Expression')
|
|
|
|
# Panel D: Heatmap
|
|
ax4 = fig.add_subplot(gs[1, 1:])
|
|
sns.heatmap(
|
|
sample_matrix,
|
|
cmap='RdBu_r',
|
|
center=0,
|
|
annot=True,
|
|
fmt='.1f',
|
|
cbar_kws={'label': 'Log2 Fold Change'},
|
|
ax=ax4
|
|
)
|
|
ax4.set_title('D. Treatment Effects', loc='left', fontweight='bold')
|
|
ax4.set_xlabel('Sample')
|
|
ax4.set_ylabel('Gene')
|
|
|
|
# Clean up
|
|
sns.despine()
|
|
plt.savefig('figure.pdf', dpi=300, bbox_inches='tight')
|
|
plt.savefig('figure.png', dpi=300, bbox_inches='tight')
|
|
```
|
|
|
|
### Box Plot with Significance Annotations
|
|
|
|
```python
|
|
import numpy as np
|
|
from scipy import stats
|
|
|
|
# Create plot
|
|
fig, ax = plt.subplots(figsize=(8, 6))
|
|
sns.boxplot(
|
|
data=df,
|
|
x='treatment',
|
|
y='response',
|
|
order=['Control', 'Low', 'Medium', 'High'],
|
|
palette='Set2',
|
|
ax=ax
|
|
)
|
|
|
|
# Add individual points
|
|
sns.stripplot(
|
|
data=df,
|
|
x='treatment',
|
|
y='response',
|
|
order=['Control', 'Low', 'Medium', 'High'],
|
|
color='black',
|
|
alpha=0.3,
|
|
size=3,
|
|
ax=ax
|
|
)
|
|
|
|
# Add significance bars
|
|
def add_significance_bar(ax, x1, x2, y, h, text):
|
|
ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], 'k-', lw=1.5)
|
|
ax.text((x1+x2)/2, y+h, text, ha='center', va='bottom')
|
|
|
|
y_max = df['response'].max()
|
|
add_significance_bar(ax, 0, 3, y_max + 1, 0.5, '***')
|
|
add_significance_bar(ax, 0, 1, y_max + 3, 0.5, 'ns')
|
|
|
|
ax.set_ylabel('Response (μM)')
|
|
ax.set_xlabel('Treatment Condition')
|
|
ax.set_title('Treatment Response Analysis')
|
|
sns.despine()
|
|
```
|
|
|
|
## Time Series Analysis
|
|
|
|
### Multiple Time Series with Confidence Bands
|
|
|
|
```python
|
|
# Plot with automatic aggregation
|
|
fig, ax = plt.subplots(figsize=(10, 6))
|
|
sns.lineplot(
|
|
data=timeseries_df,
|
|
x='timestamp',
|
|
y='value',
|
|
hue='sensor',
|
|
style='location',
|
|
markers=True,
|
|
dashes=False,
|
|
errorbar=('ci', 95),
|
|
ax=ax
|
|
)
|
|
|
|
# Customize
|
|
ax.set_xlabel('Date')
|
|
ax.set_ylabel('Measurement (units)')
|
|
ax.set_title('Sensor Measurements Over Time')
|
|
ax.legend(title='Sensor & Location', bbox_to_anchor=(1.05, 1), loc='upper left')
|
|
|
|
# Format x-axis for dates
|
|
import matplotlib.dates as mdates
|
|
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
|
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
|
|
plt.xticks(rotation=45, ha='right')
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
### Faceted Time Series
|
|
|
|
```python
|
|
# Create faceted time series
|
|
g = sns.relplot(
|
|
data=long_timeseries,
|
|
x='date',
|
|
y='measurement',
|
|
hue='device',
|
|
col='location',
|
|
row='metric',
|
|
kind='line',
|
|
height=3,
|
|
aspect=2,
|
|
errorbar='sd',
|
|
facet_kws={'sharex': True, 'sharey': False}
|
|
)
|
|
|
|
# Customize facet titles
|
|
g.set_titles('{row_name} - {col_name}')
|
|
g.set_axis_labels('Date', 'Value')
|
|
|
|
# Rotate x-axis labels
|
|
for ax in g.axes.flat:
|
|
ax.tick_params(axis='x', rotation=45)
|
|
|
|
g.tight_layout()
|
|
```
|
|
|
|
## Categorical Comparisons
|
|
|
|
### Nested Categorical Variables
|
|
|
|
```python
|
|
# Create figure
|
|
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
|
|
|
# Left panel: Grouped bar plot
|
|
sns.barplot(
|
|
data=df,
|
|
x='category',
|
|
y='value',
|
|
hue='subcategory',
|
|
errorbar=('ci', 95),
|
|
capsize=0.1,
|
|
ax=axes[0]
|
|
)
|
|
axes[0].set_title('Mean Values with 95% CI')
|
|
axes[0].set_ylabel('Value (units)')
|
|
axes[0].legend(title='Subcategory')
|
|
|
|
# Right panel: Strip + violin plot
|
|
sns.violinplot(
|
|
data=df,
|
|
x='category',
|
|
y='value',
|
|
hue='subcategory',
|
|
inner=None,
|
|
alpha=0.3,
|
|
ax=axes[1]
|
|
)
|
|
sns.stripplot(
|
|
data=df,
|
|
x='category',
|
|
y='value',
|
|
hue='subcategory',
|
|
dodge=True,
|
|
size=3,
|
|
alpha=0.6,
|
|
ax=axes[1]
|
|
)
|
|
axes[1].set_title('Distribution of Individual Values')
|
|
axes[1].set_ylabel('')
|
|
axes[1].get_legend().remove()
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
### Point Plot for Trends
|
|
|
|
```python
|
|
# Show how values change across categories
|
|
sns.pointplot(
|
|
data=df,
|
|
x='timepoint',
|
|
y='score',
|
|
hue='treatment',
|
|
markers=['o', 's', '^'],
|
|
linestyles=['-', '--', '-.'],
|
|
dodge=0.3,
|
|
capsize=0.1,
|
|
errorbar=('ci', 95)
|
|
)
|
|
|
|
plt.xlabel('Timepoint')
|
|
plt.ylabel('Performance Score')
|
|
plt.title('Treatment Effects Over Time')
|
|
plt.legend(title='Treatment', bbox_to_anchor=(1.05, 1), loc='upper left')
|
|
sns.despine()
|
|
plt.tight_layout()
|
|
```
|
|
|
|
## Regression and Relationships
|
|
|
|
### Linear Regression with Facets
|
|
|
|
```python
|
|
# Fit separate regressions for each category
|
|
g = sns.lmplot(
|
|
data=df,
|
|
x='predictor',
|
|
y='response',
|
|
hue='treatment',
|
|
col='cell_line',
|
|
height=4,
|
|
aspect=1.2,
|
|
scatter_kws={'alpha': 0.5, 's': 50},
|
|
ci=95,
|
|
palette='Set2'
|
|
)
|
|
|
|
g.set_axis_labels('Predictor Variable', 'Response Variable')
|
|
g.set_titles('{col_name}')
|
|
g.tight_layout()
|
|
```
|
|
|
|
### Polynomial Regression
|
|
|
|
```python
|
|
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
|
|
|
for idx, order in enumerate([1, 2, 3]):
|
|
sns.regplot(
|
|
data=df,
|
|
x='x',
|
|
y='y',
|
|
order=order,
|
|
scatter_kws={'alpha': 0.5},
|
|
line_kws={'color': 'red'},
|
|
ci=95,
|
|
ax=axes[idx]
|
|
)
|
|
axes[idx].set_title(f'Order {order} Polynomial Fit')
|
|
axes[idx].set_xlabel('X Variable')
|
|
axes[idx].set_ylabel('Y Variable')
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
### Residual Analysis
|
|
|
|
```python
|
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
|
|
|
# Main regression
|
|
sns.regplot(data=df, x='x', y='y', ax=axes[0, 0])
|
|
axes[0, 0].set_title('Regression Fit')
|
|
|
|
# Residuals vs fitted
|
|
sns.residplot(data=df, x='x', y='y', lowess=True,
|
|
scatter_kws={'alpha': 0.5},
|
|
line_kws={'color': 'red', 'lw': 2},
|
|
ax=axes[0, 1])
|
|
axes[0, 1].set_title('Residuals vs Fitted')
|
|
axes[0, 1].axhline(0, ls='--', color='gray')
|
|
|
|
# Q-Q plot (using scipy)
|
|
from scipy import stats as sp_stats
|
|
residuals = df['y'] - np.poly1d(np.polyfit(df['x'], df['y'], 1))(df['x'])
|
|
sp_stats.probplot(residuals, dist="norm", plot=axes[1, 0])
|
|
axes[1, 0].set_title('Q-Q Plot')
|
|
|
|
# Histogram of residuals
|
|
sns.histplot(residuals, kde=True, ax=axes[1, 1])
|
|
axes[1, 1].set_title('Residual Distribution')
|
|
axes[1, 1].set_xlabel('Residuals')
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
## Bivariate and Joint Distributions
|
|
|
|
### Joint Plot with Multiple Representations
|
|
|
|
```python
|
|
# Scatter with marginals
|
|
g = sns.jointplot(
|
|
data=df,
|
|
x='var1',
|
|
y='var2',
|
|
hue='category',
|
|
kind='scatter',
|
|
height=8,
|
|
ratio=4,
|
|
space=0.1,
|
|
joint_kws={'alpha': 0.5, 's': 50},
|
|
marginal_kws={'kde': True, 'bins': 30}
|
|
)
|
|
|
|
# Add reference lines
|
|
g.ax_joint.axline((0, 0), slope=1, color='r', ls='--', alpha=0.5, label='y=x')
|
|
g.ax_joint.legend()
|
|
|
|
g.set_axis_labels('Variable 1', 'Variable 2', fontsize=12)
|
|
```
|
|
|
|
### KDE Contour Plot
|
|
|
|
```python
|
|
fig, ax = plt.subplots(figsize=(8, 8))
|
|
|
|
# Bivariate KDE with filled contours
|
|
sns.kdeplot(
|
|
data=df,
|
|
x='x',
|
|
y='y',
|
|
fill=True,
|
|
levels=10,
|
|
cmap='viridis',
|
|
thresh=0.05,
|
|
ax=ax
|
|
)
|
|
|
|
# Overlay scatter
|
|
sns.scatterplot(
|
|
data=df,
|
|
x='x',
|
|
y='y',
|
|
color='white',
|
|
edgecolor='black',
|
|
s=50,
|
|
alpha=0.6,
|
|
ax=ax
|
|
)
|
|
|
|
ax.set_xlabel('X Variable')
|
|
ax.set_ylabel('Y Variable')
|
|
ax.set_title('Bivariate Distribution')
|
|
```
|
|
|
|
### Hexbin with Marginals
|
|
|
|
```python
|
|
# For large datasets
|
|
g = sns.jointplot(
|
|
data=large_df,
|
|
x='x',
|
|
y='y',
|
|
kind='hex',
|
|
height=8,
|
|
ratio=5,
|
|
space=0.1,
|
|
joint_kws={'gridsize': 30, 'cmap': 'viridis'},
|
|
marginal_kws={'bins': 50, 'color': 'skyblue'}
|
|
)
|
|
|
|
g.set_axis_labels('X Variable', 'Y Variable')
|
|
```
|
|
|
|
## Matrix and Heatmap Visualizations
|
|
|
|
### Hierarchical Clustering Heatmap
|
|
|
|
```python
|
|
# Prepare data (samples x features)
|
|
data_matrix = df.set_index('sample_id')[feature_columns]
|
|
|
|
# Create color annotations
|
|
row_colors = df.set_index('sample_id')['condition'].map({
|
|
'control': '#1f77b4',
|
|
'treatment': '#ff7f0e'
|
|
})
|
|
|
|
col_colors = pd.Series(['#2ca02c' if 'gene' in col else '#d62728'
|
|
for col in data_matrix.columns])
|
|
|
|
# Plot
|
|
g = sns.clustermap(
|
|
data_matrix,
|
|
method='ward',
|
|
metric='euclidean',
|
|
z_score=0, # Normalize rows
|
|
cmap='RdBu_r',
|
|
center=0,
|
|
row_colors=row_colors,
|
|
col_colors=col_colors,
|
|
figsize=(12, 10),
|
|
dendrogram_ratio=(0.1, 0.1),
|
|
cbar_pos=(0.02, 0.8, 0.03, 0.15),
|
|
linewidths=0.5
|
|
)
|
|
|
|
g.ax_heatmap.set_xlabel('Features')
|
|
g.ax_heatmap.set_ylabel('Samples')
|
|
plt.savefig('clustermap.png', dpi=300, bbox_inches='tight')
|
|
```
|
|
|
|
### Annotated Heatmap with Custom Colorbar
|
|
|
|
```python
|
|
# Pivot data for heatmap
|
|
pivot_data = df.pivot(index='row_var', columns='col_var', values='value')
|
|
|
|
# Create heatmap
|
|
fig, ax = plt.subplots(figsize=(10, 8))
|
|
sns.heatmap(
|
|
pivot_data,
|
|
annot=True,
|
|
fmt='.1f',
|
|
cmap='RdYlGn',
|
|
center=pivot_data.mean().mean(),
|
|
vmin=pivot_data.min().min(),
|
|
vmax=pivot_data.max().max(),
|
|
linewidths=0.5,
|
|
linecolor='gray',
|
|
cbar_kws={
|
|
'label': 'Value (units)',
|
|
'orientation': 'vertical',
|
|
'shrink': 0.8,
|
|
'aspect': 20
|
|
},
|
|
ax=ax
|
|
)
|
|
|
|
ax.set_title('Variable Relationships', fontsize=14, pad=20)
|
|
ax.set_xlabel('Column Variable', fontsize=12)
|
|
ax.set_ylabel('Row Variable', fontsize=12)
|
|
|
|
plt.xticks(rotation=45, ha='right')
|
|
plt.yticks(rotation=0)
|
|
plt.tight_layout()
|
|
```
|
|
|
|
## Statistical Comparisons
|
|
|
|
### Before/After Comparison
|
|
|
|
```python
|
|
# Reshape data for paired comparison
|
|
df_paired = df.melt(
|
|
id_vars='subject',
|
|
value_vars=['before', 'after'],
|
|
var_name='timepoint',
|
|
value_name='measurement'
|
|
)
|
|
|
|
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
|
|
|
|
# Left: Individual trajectories
|
|
for subject in df_paired['subject'].unique():
|
|
subject_data = df_paired[df_paired['subject'] == subject]
|
|
axes[0].plot(subject_data['timepoint'], subject_data['measurement'],
|
|
'o-', alpha=0.3, color='gray')
|
|
|
|
sns.pointplot(
|
|
data=df_paired,
|
|
x='timepoint',
|
|
y='measurement',
|
|
color='red',
|
|
markers='D',
|
|
scale=1.5,
|
|
errorbar=('ci', 95),
|
|
capsize=0.2,
|
|
ax=axes[0]
|
|
)
|
|
axes[0].set_title('Individual Changes')
|
|
axes[0].set_ylabel('Measurement')
|
|
|
|
# Right: Distribution comparison
|
|
sns.violinplot(
|
|
data=df_paired,
|
|
x='timepoint',
|
|
y='measurement',
|
|
inner='box',
|
|
ax=axes[1]
|
|
)
|
|
sns.swarmplot(
|
|
data=df_paired,
|
|
x='timepoint',
|
|
y='measurement',
|
|
color='black',
|
|
alpha=0.5,
|
|
size=3,
|
|
ax=axes[1]
|
|
)
|
|
axes[1].set_title('Distribution Comparison')
|
|
axes[1].set_ylabel('')
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
### Dose-Response Curve
|
|
|
|
```python
|
|
# Create dose-response plot
|
|
fig, ax = plt.subplots(figsize=(8, 6))
|
|
|
|
# Plot individual points
|
|
sns.stripplot(
|
|
data=dose_df,
|
|
x='dose',
|
|
y='response',
|
|
order=sorted(dose_df['dose'].unique()),
|
|
color='gray',
|
|
alpha=0.3,
|
|
jitter=0.2,
|
|
ax=ax
|
|
)
|
|
|
|
# Overlay mean with CI
|
|
sns.pointplot(
|
|
data=dose_df,
|
|
x='dose',
|
|
y='response',
|
|
order=sorted(dose_df['dose'].unique()),
|
|
color='blue',
|
|
markers='o',
|
|
scale=1.2,
|
|
errorbar=('ci', 95),
|
|
capsize=0.1,
|
|
ax=ax
|
|
)
|
|
|
|
# Fit sigmoid curve
|
|
from scipy.optimize import curve_fit
|
|
|
|
def sigmoid(x, bottom, top, ec50, hill):
|
|
return bottom + (top - bottom) / (1 + (ec50 / x) ** hill)
|
|
|
|
doses_numeric = dose_df['dose'].astype(float)
|
|
params, _ = curve_fit(sigmoid, doses_numeric, dose_df['response'])
|
|
|
|
x_smooth = np.logspace(np.log10(doses_numeric.min()),
|
|
np.log10(doses_numeric.max()), 100)
|
|
y_smooth = sigmoid(x_smooth, *params)
|
|
|
|
ax.plot(range(len(sorted(dose_df['dose'].unique()))),
|
|
sigmoid(sorted(doses_numeric.unique()), *params),
|
|
'r-', linewidth=2, label='Sigmoid Fit')
|
|
|
|
ax.set_xlabel('Dose')
|
|
ax.set_ylabel('Response')
|
|
ax.set_title('Dose-Response Analysis')
|
|
ax.legend()
|
|
sns.despine()
|
|
```
|
|
|
|
## Custom Styling
|
|
|
|
### Custom Color Palette from Hex Codes
|
|
|
|
```python
|
|
# Define custom palette
|
|
custom_palette = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']
|
|
sns.set_palette(custom_palette)
|
|
|
|
# Or use for specific plot
|
|
sns.scatterplot(
|
|
data=df,
|
|
x='x',
|
|
y='y',
|
|
hue='category',
|
|
palette=custom_palette
|
|
)
|
|
```
|
|
|
|
### Publication-Ready Theme
|
|
|
|
```python
|
|
# Set comprehensive theme
|
|
sns.set_theme(
|
|
context='paper',
|
|
style='ticks',
|
|
palette='colorblind',
|
|
font='Arial',
|
|
font_scale=1.1,
|
|
rc={
|
|
'figure.dpi': 300,
|
|
'savefig.dpi': 300,
|
|
'savefig.format': 'pdf',
|
|
'axes.linewidth': 1.0,
|
|
'axes.labelweight': 'bold',
|
|
'xtick.major.width': 1.0,
|
|
'ytick.major.width': 1.0,
|
|
'xtick.direction': 'out',
|
|
'ytick.direction': 'out',
|
|
'legend.frameon': False,
|
|
'pdf.fonttype': 42, # True Type fonts for PDFs
|
|
}
|
|
)
|
|
```
|
|
|
|
### Diverging Colormap Centered on Zero
|
|
|
|
```python
|
|
# For data with meaningful zero point (e.g., log fold change)
|
|
from matplotlib.colors import TwoSlopeNorm
|
|
|
|
# Find data range
|
|
vmin, vmax = df['value'].min(), df['value'].max()
|
|
vcenter = 0
|
|
|
|
# Create norm
|
|
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
|
|
|
|
# Plot
|
|
sns.heatmap(
|
|
pivot_data,
|
|
cmap='RdBu_r',
|
|
norm=norm,
|
|
center=0,
|
|
annot=True,
|
|
fmt='.2f'
|
|
)
|
|
```
|
|
|
|
## Large Datasets
|
|
|
|
### Downsampling Strategy
|
|
|
|
```python
|
|
# For very large datasets, sample intelligently
|
|
def smart_sample(df, target_size=10000, category_col=None):
|
|
if len(df) <= target_size:
|
|
return df
|
|
|
|
if category_col:
|
|
# Stratified sampling
|
|
return df.groupby(category_col, group_keys=False).apply(
|
|
lambda x: x.sample(min(len(x), target_size // df[category_col].nunique()))
|
|
)
|
|
else:
|
|
# Simple random sampling
|
|
return df.sample(target_size)
|
|
|
|
# Use sampled data for visualization
|
|
df_sampled = smart_sample(large_df, target_size=5000, category_col='category')
|
|
|
|
sns.scatterplot(data=df_sampled, x='x', y='y', hue='category', alpha=0.5)
|
|
```
|
|
|
|
### Hexbin for Dense Scatter Plots
|
|
|
|
```python
|
|
# For millions of points
|
|
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
|
|
|
# Regular scatter (slow)
|
|
axes[0].scatter(df['x'], df['y'], alpha=0.1, s=1)
|
|
axes[0].set_title('Scatter (all points)')
|
|
|
|
# Hexbin (fast)
|
|
hb = axes[1].hexbin(df['x'], df['y'], gridsize=50, cmap='viridis', mincnt=1)
|
|
axes[1].set_title('Hexbin Aggregation')
|
|
plt.colorbar(hb, ax=axes[1], label='Count')
|
|
|
|
plt.tight_layout()
|
|
```
|
|
|
|
## Interactive Elements for Notebooks
|
|
|
|
### Adjustable Parameters
|
|
|
|
```python
|
|
from ipywidgets import interact, FloatSlider
|
|
|
|
@interact(bandwidth=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0))
|
|
def plot_kde(bandwidth):
|
|
plt.figure(figsize=(10, 6))
|
|
sns.kdeplot(data=df, x='value', hue='category',
|
|
bw_adjust=bandwidth, fill=True)
|
|
plt.title(f'KDE with bandwidth adjustment = {bandwidth}')
|
|
plt.show()
|
|
```
|
|
|
|
### Dynamic Filtering
|
|
|
|
```python
|
|
from ipywidgets import interact, SelectMultiple
|
|
|
|
categories = df['category'].unique().tolist()
|
|
|
|
@interact(selected=SelectMultiple(options=categories, value=[categories[0]]))
|
|
def filtered_plot(selected):
|
|
filtered_df = df[df['category'].isin(selected)]
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 6))
|
|
sns.violinplot(data=filtered_df, x='category', y='value', ax=ax)
|
|
ax.set_title(f'Showing {len(selected)} categories')
|
|
plt.show()
|
|
```
|