14 KiB
Differential Expression Analysis in scvi-tools
This document provides detailed information about differential expression (DE) analysis using scvi-tools' probabilistic framework.
Overview
scvi-tools implements Bayesian differential expression testing that leverages the learned generative models to estimate expression differences between groups. This approach provides several advantages over traditional methods:
- Batch correction: DE testing on batch-corrected representations
- Uncertainty quantification: Probabilistic estimates of effect sizes
- Zero-inflation handling: Proper modeling of dropout and zeros
- Flexible comparisons: Between any groups or cell types
- Multiple modalities: Works for RNA, proteins (totalVI), and accessibility (PeakVI)
Core Statistical Framework
Problem Definition
The goal is to estimate the log fold-change in expression between two conditions:
log fold-change = log(μ_B) - log(μ_A)
Where μ_A and μ_B are the mean expression levels in conditions A and B.
Three-Stage Process
Stage 1: Estimating Expression Levels
- Sample from posterior distribution of cellular states
- Generate expression values from the learned generative model
- Aggregate across cells to get population-level estimates
Stage 2: Detecting Relevant Features (Hypothesis Testing)
- Test for differential expression using Bayesian framework
- Two testing modes available:
- "vanilla" mode: Point null hypothesis (β = 0)
- "change" mode: Composite hypothesis (|β| ≤ δ)
Stage 3: Controlling False Discovery
- Posterior expected False Discovery Proportion (FDP) control
- Selects maximum number of discoveries ensuring E[FDP] ≤ α
Basic Usage
Simple Two-Group Comparison
import scvi
# After training a model
model = scvi.model.SCVI(adata)
model.train()
# Compare two cell types
de_results = model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells"
)
# View top DE genes
top_genes = de_results.sort_values("lfc_mean", ascending=False).head(20)
print(top_genes[["lfc_mean", "lfc_std", "bayes_factor", "is_de_fdr_0.05"]])
One vs. Rest Comparison
# Compare one group against all others
de_results = model.differential_expression(
groupby="cell_type",
group1="T cells" # No group2 = compare to rest
)
All Pairwise Comparisons
# Compare all cell types pairwise
all_comparisons = {}
cell_types = adata.obs["cell_type"].unique()
for ct1 in cell_types:
for ct2 in cell_types:
if ct1 != ct2:
key = f"{ct1}_vs_{ct2}"
all_comparisons[key] = model.differential_expression(
groupby="cell_type",
group1=ct1,
group2=ct2
)
Key Parameters
groupby (required)
Column in adata.obs defining groups to compare.
# Must be a categorical variable
de_results = model.differential_expression(groupby="cell_type")
group1 and group2
Groups to compare. If group2 is None, compares group1 to all others.
# Specific comparison
de = model.differential_expression(groupby="condition", group1="treated", group2="control")
# One vs rest
de = model.differential_expression(groupby="cell_type", group1="T cells")
mode (Hypothesis Testing Mode)
"vanilla" mode (default): Point null hypothesis
- Tests if β = 0 exactly
- More sensitive, but may find trivially small effects
"change" mode: Composite null hypothesis
- Tests if |β| ≤ δ
- Requires biologically meaningful change
- Reduces false discoveries of tiny effects
# Change mode with minimum effect size
de = model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells",
mode="change",
delta=0.25 # Minimum log fold-change
)
delta
Minimum effect size threshold for "change" mode.
- Typical values: 0.25, 0.5, 0.7 (log scale)
- log2(1.5) ≈ 0.58 (1.5-fold change)
- log2(2) = 1.0 (2-fold change)
# Require at least 1.5-fold change
de = model.differential_expression(
groupby="condition",
group1="disease",
group2="healthy",
mode="change",
delta=0.58 # log2(1.5)
)
fdr_target
False discovery rate threshold (default: 0.05)
# More stringent FDR control
de = model.differential_expression(
groupby="cell_type",
group1="T cells",
fdr_target=0.01
)
batch_correction
Whether to perform batch correction during DE testing (default: True)
# Test within a specific batch
de = model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells",
batch_correction=False
)
n_samples
Number of posterior samples for estimation (default: 5000)
- More samples = more accurate but slower
- Reduce for speed, increase for precision
# High precision analysis
de = model.differential_expression(
groupby="cell_type",
group1="T cells",
n_samples=10000
)
Interpreting Results
Output Columns
The results DataFrame contains several important columns:
Effect Size Estimates:
lfc_mean: Mean log fold-changelfc_median: Median log fold-changelfc_std: Standard deviation of log fold-changelfc_min: Lower bound of effect sizelfc_max: Upper bound of effect size
Statistical Significance:
bayes_factor: Bayes factor for differential expression- Higher values = stronger evidence
-
3 often considered meaningful
is_de_fdr_0.05: Boolean indicating if gene is DE at FDR 0.05is_de_fdr_0.1: Boolean indicating if gene is DE at FDR 0.1
Expression Levels:
mean1: Mean expression in group 1mean2: Mean expression in group 2non_zeros_proportion1: Proportion of non-zero cells in group 1non_zeros_proportion2: Proportion of non-zero cells in group 2
Example Interpretation
de_results = model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells"
)
# Find significantly upregulated genes in T cells
upreg_tcells = de_results[
(de_results["is_de_fdr_0.05"]) &
(de_results["lfc_mean"] > 0)
].sort_values("lfc_mean", ascending=False)
print(f"Upregulated genes in T cells: {len(upreg_tcells)}")
print(upreg_tcells.head(10))
# Find genes with large effect sizes
large_effect = de_results[
(de_results["is_de_fdr_0.05"]) &
(abs(de_results["lfc_mean"]) > 1) # 2-fold change
]
Advanced Usage
DE Within Specific Cells
# Test DE only within a subset of cells
subset_indices = adata.obs["tissue"] == "lung"
de = model.differential_expression(
idx1=adata.obs["cell_type"] == "T cells" & subset_indices,
idx2=adata.obs["cell_type"] == "B cells" & subset_indices
)
Batch-Specific DE
# Test DE within each batch separately
batches = adata.obs["batch"].unique()
batch_de_results = {}
for batch in batches:
batch_idx = adata.obs["batch"] == batch
batch_de_results[batch] = model.differential_expression(
idx1=(adata.obs["condition"] == "treated") & batch_idx,
idx2=(adata.obs["condition"] == "control") & batch_idx
)
Pseudo-bulk DE
# Aggregate cells before DE testing
# Useful for low cell counts per group
de = model.differential_expression(
groupby="cell_type",
group1="rare_cell_type",
group2="common_cell_type",
n_samples=10000, # More samples for stability
batch_correction=True
)
Visualization
Volcano Plot
import matplotlib.pyplot as plt
import numpy as np
de = model.differential_expression(
groupby="condition",
group1="treated",
group2="control"
)
# Volcano plot
plt.figure(figsize=(10, 6))
plt.scatter(
de["lfc_mean"],
-np.log10(1 / (de["bayes_factor"] + 1)),
c=de["is_de_fdr_0.05"],
cmap="coolwarm",
alpha=0.5
)
plt.xlabel("Log Fold Change")
plt.ylabel("-log10(1/Bayes Factor)")
plt.title("Volcano Plot: Treated vs Control")
plt.axvline(x=0, color='k', linestyle='--', linewidth=0.5)
plt.show()
Heatmap of Top DE Genes
import seaborn as sns
# Get top DE genes
top_genes = de.sort_values("lfc_mean", ascending=False).head(50).index
# Get normalized expression
norm_expr = model.get_normalized_expression(
adata,
indices=adata.obs["condition"].isin(["treated", "control"]),
gene_list=top_genes
)
# Plot heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(
norm_expr.T,
cmap="viridis",
xticklabels=False,
yticklabels=top_genes
)
plt.title("Top 50 DE Genes")
plt.show()
Ranked Gene Plot
# Plot genes ranked by effect size
de_sorted = de.sort_values("lfc_mean", ascending=False)
plt.figure(figsize=(12, 6))
plt.plot(range(len(de_sorted)), de_sorted["lfc_mean"].values)
plt.axhline(y=0, color='r', linestyle='--')
plt.xlabel("Gene Rank")
plt.ylabel("Log Fold Change")
plt.title("Genes Ranked by Effect Size")
plt.show()
Comparison with Traditional Methods
scvi-tools vs. Wilcoxon Test
import scanpy as sc
# Traditional Wilcoxon test
sc.tl.rank_genes_groups(
adata,
groupby="cell_type",
method="wilcoxon",
key_added="wilcoxon"
)
# scvi-tools DE
de_scvi = model.differential_expression(
groupby="cell_type",
group1="T cells"
)
# Compare results
wilcox_results = sc.get.rank_genes_groups_df(adata, group="T cells", key="wilcoxon")
Advantages of scvi-tools:
- Accounts for batch effects automatically
- Handles zero-inflation properly
- Provides uncertainty quantification
- No arbitrary pseudocount needed
- Better statistical properties
When to use Wilcoxon:
- Very quick exploratory analysis
- Comparison with published results using Wilcoxon
Multi-Modal DE
Protein DE (totalVI)
# Train totalVI on CITE-seq data
totalvi_model = scvi.model.TOTALVI(adata)
totalvi_model.train()
# RNA differential expression
rna_de = totalvi_model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells",
protein_expression=False # Default
)
# Protein differential expression
protein_de = totalvi_model.differential_expression(
groupby="cell_type",
group1="T cells",
group2="B cells",
protein_expression=True
)
print(f"DE genes: {rna_de['is_de_fdr_0.05'].sum()}")
print(f"DE proteins: {protein_de['is_de_fdr_0.05'].sum()}")
Differential Accessibility (PeakVI)
# Train PeakVI on ATAC-seq data
peakvi_model = scvi.model.PEAKVI(atac_adata)
peakvi_model.train()
# Differential accessibility
da = peakvi_model.differential_accessibility(
groupby="cell_type",
group1="T cells",
group2="B cells"
)
# Same interpretation as DE
Handling Special Cases
Low Cell Count Groups
# Increase posterior samples for stability
de = model.differential_expression(
groupby="cell_type",
group1="rare_type", # e.g., 50 cells
group2="common_type", # e.g., 5000 cells
n_samples=10000
)
Imbalanced Comparisons
# When groups have very different sizes
# Use change mode to avoid tiny effects
de = model.differential_expression(
groupby="condition",
group1="rare_condition",
group2="common_condition",
mode="change",
delta=0.5
)
Multiple Testing Correction
# Already included via FDP control
# But can apply additional corrections
from statsmodels.stats.multitest import multipletests
# Bonferroni correction (very conservative)
_, pvals_corrected, _, _ = multipletests(
1 / (de["bayes_factor"] + 1),
method="bonferroni"
)
Performance Considerations
Speed Optimization
# Faster DE testing for large datasets
de = model.differential_expression(
groupby="cell_type",
group1="T cells",
n_samples=1000, # Reduce samples
batch_size=512 # Increase batch size
)
Memory Management
# For very large datasets
# Test one comparison at a time rather than all pairwise
cell_types = adata.obs["cell_type"].unique()
for ct in cell_types:
de = model.differential_expression(
groupby="cell_type",
group1=ct
)
# Save results
de.to_csv(f"de_results_{ct}.csv")
Best Practices
- Use "change" mode: For biologically interpretable results
- Set appropriate delta: Based on biological significance
- Check expression levels: Filter lowly expressed genes
- Validate findings: Check marker genes for sanity
- Visualize results: Always plot top DE genes
- Report parameters: Document mode, delta, FDR used
- Consider batch effects: Use batch_correction=True
- Multiple comparisons: Be aware of testing many groups
- Sample size: Ensure sufficient cells per group (>50 recommended)
- Biological validation: Follow up with functional experiments
Example: Complete DE Analysis Workflow
import scvi
import scanpy as sc
import matplotlib.pyplot as plt
# 1. Train model
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")
model = scvi.model.SCVI(adata)
model.train()
# 2. Perform DE analysis
de_results = model.differential_expression(
groupby="cell_type",
group1="Disease_T_cells",
group2="Healthy_T_cells",
mode="change",
delta=0.5,
fdr_target=0.05
)
# 3. Filter and analyze
sig_genes = de_results[de_results["is_de_fdr_0.05"]]
upreg = sig_genes[sig_genes["lfc_mean"] > 0].sort_values("lfc_mean", ascending=False)
downreg = sig_genes[sig_genes["lfc_mean"] < 0].sort_values("lfc_mean")
print(f"Significant genes: {len(sig_genes)}")
print(f"Upregulated: {len(upreg)}")
print(f"Downregulated: {len(downreg)}")
# 4. Visualize top genes
top_genes = upreg.head(10).index.tolist() + downreg.head(10).index.tolist()
sc.pl.violin(
adata[adata.obs["cell_type"].isin(["Disease_T_cells", "Healthy_T_cells"])],
keys=top_genes,
groupby="cell_type",
rotation=90
)
# 5. Functional enrichment (using external tools)
# E.g., g:Profiler, DAVID, or gprofiler-official Python package
upreg_genes = upreg.head(100).index.tolist()
# Perform pathway analysis...
# 6. Save results
de_results.to_csv("de_results_disease_vs_healthy.csv")
upreg.to_csv("upregulated_genes.csv")
downreg.to_csv("downregulated_genes.csv")