#!/usr/bin/env python3
"""
PyDESeq2 Analysis Script

This script performs a complete differential expression analysis using PyDESeq2.
It can be used as a template for standard RNA-seq differential expression
analysis (DEA) workflows.

By default the counts CSV is expected to be genes x samples (it is transposed
after loading); pass --no-transpose if it is already samples x genes.

Usage:
    python run_deseq2_analysis.py --counts counts.csv --metadata metadata.csv \\
        --design "~condition" --contrast condition treated control \\
        --output results/

Requirements:
    - pydeseq2
    - pandas
    - matplotlib (optional, for plots)
"""


import argparse
import os
import pickle
import sys
from pathlib import Path

import pandas as pd

try:
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats
except ImportError:
    print("Error: pydeseq2 not installed. Install with: pip install pydeseq2")
    sys.exit(1)


def load_and_validate_data(counts_path, metadata_path, transpose_counts=True):
    """Load the count matrix and sample metadata, and validate them.

    Parameters
    ----------
    counts_path : str or path-like
        CSV file holding the raw count matrix; its first column becomes the
        index. With ``transpose_counts=True`` the file is read as
        genes x samples and transposed to samples x genes.
    metadata_path : str or path-like
        CSV file with per-sample metadata; its first column becomes the
        sample index.
    transpose_counts : bool, default True
        Transpose the count matrix after loading. Pass False when the file is
        already samples x genes.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``(counts_df, metadata)`` restricted to the samples present in both
        inputs, with rows in the same order.

    Raises
    ------
    ValueError
        If the inputs share no sample names, or if the count matrix contains
        missing, negative, or non-integer values.
    """
    print(f"Loading count data from {counts_path}...")
    counts_df = pd.read_csv(counts_path, index_col=0)

    if transpose_counts:
        print("Transposing count matrix to samples × genes format...")
        counts_df = counts_df.T

    print(f"Loading metadata from {metadata_path}...")
    metadata = pd.read_csv(metadata_path, index_col=0)

    print("\nData loaded:")
    print(f" Counts shape: {counts_df.shape} (samples × genes)")
    print(f" Metadata shape: {metadata.shape} (samples × variables)")

    # Align samples. An elementwise ``==`` between two indexes raises when
    # their lengths differ, so use ``Index.equals`` for the comparison.
    if not counts_df.index.equals(metadata.index):
        print("\nWarning: Sample indices don't match perfectly. Taking intersection...")
        common_samples = counts_df.index.intersection(metadata.index)
        if common_samples.empty:
            raise ValueError(
                "Count matrix and metadata share no sample names; check the "
                "index columns of both files (and the --no-transpose setting)."
            )
        counts_df = counts_df.loc[common_samples]
        metadata = metadata.loc[common_samples]
        print(f" Using {len(common_samples)} common samples")

    # PyDESeq2 needs complete, non-negative integer counts. Failing here gives
    # a clearer message than an error raised deep inside model fitting.
    # NaN is checked first because ``NaN % 1`` would also look non-integer.
    if counts_df.isna().any().any():
        raise ValueError("Count matrix contains missing values")
    if (counts_df < 0).any().any():
        raise ValueError("Count matrix contains negative values")
    if (counts_df % 1 != 0).any().any():
        raise ValueError("Count matrix contains non-integer values")

    return counts_df, metadata


def filter_data(counts_df, metadata, min_counts=10, condition_col=None):
    """Drop samples with missing condition data, then drop low-count genes.

    Samples are filtered first so the per-gene totals compared against
    ``min_counts`` only include samples that actually enter the analysis.

    Parameters
    ----------
    counts_df : pandas.DataFrame
        Count matrix, samples x genes.
    metadata : pandas.DataFrame
        Sample metadata indexed like ``counts_df``.
    min_counts : int, default 10
        Genes whose total count over the retained samples is below this
        value are removed.
    condition_col : str, optional
        Metadata column; samples with a missing value in it are removed. If
        the column is not in ``metadata``, a warning is printed and no
        samples are removed.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The filtered ``(counts_df, metadata)``.
    """
    print("\nFiltering data...")

    initial_genes = counts_df.shape[1]
    initial_samples = counts_df.shape[0]

    # Filter samples with missing condition data
    if condition_col:
        if condition_col in metadata.columns:
            has_condition = metadata[condition_col].notna()
            counts_df = counts_df.loc[has_condition]
            metadata = metadata.loc[has_condition]
            samples_removed = initial_samples - counts_df.shape[0]
            if samples_removed > 0:
                print(f" Removed {samples_removed} samples with missing '{condition_col}' data")
        else:
            print(f" Warning: column '{condition_col}' not found in metadata; no samples removed")

    # Filter genes using totals over the retained samples only
    genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= min_counts]
    counts_df = counts_df[genes_to_keep]
    genes_removed = initial_genes - counts_df.shape[1]
    print(f" Removed {genes_removed} genes with < {min_counts} total counts")

    print(f" Final data shape: {counts_df.shape[0]} samples × {counts_df.shape[1]} genes")

    return counts_df, metadata


def run_deseq2(counts_df, metadata, design, n_cpus=1):
    """Build a ``DeseqDataSet`` and run the PyDESeq2 fitting pipeline.

    Parameters
    ----------
    counts_df : pandas.DataFrame
        Count matrix, samples x genes.
    metadata : pandas.DataFrame
        Sample metadata aligned with ``counts_df``.
    design : str
        Design formula, e.g. ``"~condition"`` or ``"~batch + condition"``.
    n_cpus : int, default 1
        Number of CPUs passed to PyDESeq2.

    Returns
    -------
    DeseqDataSet
        The fitted dataset (built with ``refit_cooks=True``).
    """
    print(f"\nInitializing DeseqDataSet with design: {design}")

    dds = DeseqDataSet(
        counts=counts_df,
        metadata=metadata,
        design=design,
        refit_cooks=True,
        n_cpus=n_cpus,
        quiet=False,
    )

    # Every step below runs inside the single ``dds.deseq2()`` call, so list
    # them up front as a plan rather than printing "Step i/7" lines that all
    # appear before any step has actually started.
    steps = (
        "Computing size factors",
        "Fitting genewise dispersions",
        "Fitting dispersion trend curve",
        "Computing dispersion priors",
        "Fitting MAP dispersions",
        "Fitting log fold changes",
        "Calculating Cook's distances",
    )
    print(f"\nRunning DESeq2 pipeline ({len(steps)} steps):")
    for number, step in enumerate(steps, start=1):
        print(f" {number}. {step}")

    dds.deseq2()

    print("\n✓ DESeq2 fitting complete")

    return dds


def run_statistical_tests(dds, contrast, alpha=0.05, shrink_lfc=True):
    """Run Wald tests for ``contrast`` on a fitted ``DeseqDataSet``.

    Parameters
    ----------
    dds : DeseqDataSet
        Dataset already fitted with ``deseq2()``.
    contrast : sequence of str
        ``[variable, test_level, reference_level]``.
    alpha : float, default 0.05
        Significance threshold passed to ``DeseqStats``.
    shrink_lfc : bool, default True
        Also call ``lfc_shrink()`` after testing.

    Returns
    -------
    DeseqStats
        The statistics object on which ``summary()`` has been called.
    """
    print("\nPerforming statistical tests...")
    for heading, setting in (("Contrast", contrast), ("Significance threshold", alpha)):
        print(f" {heading}: {setting}")

    stats_options = dict(
        contrast=contrast,
        alpha=alpha,
        cooks_filter=True,
        independent_filter=True,
        quiet=False,
    )
    ds = DeseqStats(dds, **stats_options)

    # Announce the work performed by the ``ds.summary()`` call that follows.
    announcements = (
        "\n Running Wald tests...",
        " Filtering outliers based on Cook's distance...",
        " Applying independent filtering...",
        " Adjusting p-values (Benjamini-Hochberg)...",
    )
    for line in announcements:
        print(line)
    ds.summary()
    print("\n✓ Statistical testing complete")

    if not shrink_lfc:
        return ds

    # Optional LFC shrinkage; presumably this updates the fold changes in
    # ``ds.results_df`` in place -- confirm against the installed PyDESeq2.
    print("\nApplying LFC shrinkage for visualization...")
    ds.lfc_shrink()
    print("✓ LFC shrinkage complete")
    return ds


def save_results(ds, dds, output_dir, shrink_lfc=True, alpha=None):
    """Write the result tables and the fitted dataset to ``output_dir``.

    Files written:
      - ``deseq2_results.csv``: full ``ds.results_df``.
      - ``significant_genes.csv``: rows with ``padj < alpha``.
      - ``results_sorted_by_padj.csv``: all rows sorted by ``padj``
        (NaN last).
      - ``deseq_dataset.pkl``: pickle of ``dds.to_picklable_anndata()``.
        Only load this pickle from a trusted location.

    Parameters
    ----------
    ds : DeseqStats
        Tested statistics object exposing ``results_df``.
    dds : DeseqDataSet
        Fitted dataset, pickled via ``to_picklable_anndata()``.
    output_dir : str or path-like
        Created (with parents) if it does not exist.
    shrink_lfc : bool, default True
        Accepted for backward compatibility; not used here.
    alpha : float, optional
        Adjusted p-value threshold for calling a gene significant. Defaults
        to ``ds.alpha`` (the threshold used for testing) when available,
        otherwise 0.05.

    Returns
    -------
    pathlib.Path
        Path of ``deseq2_results.csv``.
    """
    # Use the same threshold the tests were run with instead of a fixed 0.05.
    if alpha is None:
        alpha = getattr(ds, "alpha", 0.05)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving results to {output_dir}/")

    results = ds.results_df

    # Save statistical results
    results_path = output_dir / "deseq2_results.csv"
    results.to_csv(results_path)
    print(f" Saved: {results_path}")

    # Save significant genes (NaN padj compares False, so those are excluded)
    significant = results[results["padj"] < alpha]
    sig_path = output_dir / "significant_genes.csv"
    significant.to_csv(sig_path)
    print(f" Saved: {sig_path} ({len(significant)} significant genes)")

    # Save sorted results
    sorted_results = results.sort_values("padj")
    sorted_path = output_dir / "results_sorted_by_padj.csv"
    sorted_results.to_csv(sorted_path)
    print(f" Saved: {sorted_path}")

    # Save DeseqDataSet as pickle
    dds_path = output_dir / "deseq_dataset.pkl"
    with open(dds_path, "wb") as f:
        pickle.dump(dds.to_picklable_anndata(), f)
    print(f" Saved: {dds_path}")

    # Print summary
    n_up = int((significant["log2FoldChange"] > 0).sum())
    n_down = int((significant["log2FoldChange"] < 0).sum())
    rule = "=" * 60
    print(f"\n{rule}")
    print("ANALYSIS SUMMARY")
    print(rule)
    print(f"Total genes tested: {len(results)}")
    print(f"Significant genes (padj < {alpha}): {len(significant)}")
    print(f"Upregulated: {n_up}")
    print(f"Downregulated: {n_down}")
    print(rule)

    # Show top genes
    print("\nTop 10 most significant genes:")
    print(sorted_results.head(10)[["baseMean", "log2FoldChange", "pvalue", "padj"]])

    return results_path


def _scatter_by_significance(ax, x, y, significant, significant_label):
    """Draw non-significant points in grey, then significant points in red."""
    ax.scatter(x[~significant], y[~significant],
               alpha=0.3, s=10, c="gray", label="Not significant")
    ax.scatter(x[significant], y[significant],
               alpha=0.6, s=10, c="red", label=significant_label)


def create_plots(ds, output_dir, alpha=None, dpi=300):
    """Save a volcano plot and an MA plot of ``ds.results_df``.

    If matplotlib is not installed, a note is printed and nothing is written.

    Parameters
    ----------
    ds : DeseqStats
        Tested statistics object exposing ``results_df`` (with ``baseMean``,
        ``log2FoldChange`` and ``padj`` columns).
    output_dir : str or path-like
        Destination directory; created if missing.
    alpha : float, optional
        Adjusted p-value threshold used to colour genes and draw the
        horizontal guide. Defaults to ``ds.alpha`` when available, else 0.05.
    dpi : int, default 300
        Resolution of the saved PNG files.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        print("\nNote: matplotlib not installed. Skipping plot generation.")
        return

    # Colour by the threshold the tests were run with, not a fixed 0.05.
    if alpha is None:
        alpha = getattr(ds, "alpha", 0.05)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    results = ds.results_df.copy()

    print("\nGenerating plots...")

    # Missing padj is plotted as padj = 1. padj == 0 would give -log10 = inf
    # and vanish from the plot, so floor it at the smallest non-zero value.
    padj = results["padj"].fillna(1.0)
    nonzero = padj[padj > 0]
    floor = nonzero.min() if not nonzero.empty else np.finfo(float).tiny
    results["-log10(padj)"] = -np.log10(padj.clip(lower=floor))

    significant = results["padj"] < alpha  # NaN compares False
    significant_label = f"Significant (padj < {alpha})"
    lfc = results["log2FoldChange"]

    # Volcano plot
    volcano_path = output_dir / "volcano_plot.png"
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        _scatter_by_significance(ax, lfc, results["-log10(padj)"],
                                 significant, significant_label)
        ax.axhline(-np.log10(alpha), color="blue", linestyle="--", linewidth=1, alpha=0.5)
        for cutoff in (1, -1):
            ax.axvline(cutoff, color="gray", linestyle="--", linewidth=1, alpha=0.5)
        ax.set_xlabel("Log2 Fold Change", fontsize=12)
        ax.set_ylabel("-Log10(Adjusted P-value)", fontsize=12)
        ax.set_title("Volcano Plot", fontsize=14, fontweight="bold")
        ax.legend()
        fig.tight_layout()
        fig.savefig(volcano_path, dpi=dpi)
    finally:
        plt.close(fig)  # release the figure even if plotting fails
    print(f" Saved: {volcano_path}")

    # MA plot
    ma_path = output_dir / "ma_plot.png"
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        _scatter_by_significance(ax, np.log10(results["baseMean"] + 1), lfc,
                                 significant, significant_label)
        ax.axhline(0, color="blue", linestyle="--", linewidth=1, alpha=0.5)
        ax.set_xlabel("Log10(Base Mean + 1)", fontsize=12)
        ax.set_ylabel("Log2 Fold Change", fontsize=12)
        ax.set_title("MA Plot", fontsize=14, fontweight="bold")
        ax.legend()
        fig.tight_layout()
        fig.savefig(ma_path, dpi=dpi)
    finally:
        plt.close(fig)
    print(f" Saved: {ma_path}")


def _build_parser():
    """Return the argument parser for this command-line script."""
    parser = argparse.ArgumentParser(
        description="Run PyDESeq2 differential expression analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic analysis
python run_deseq2_analysis.py \\
--counts counts.csv \\
--metadata metadata.csv \\
--design "~condition" \\
--contrast condition treated control \\
--output results/

# Multi-factor analysis
python run_deseq2_analysis.py \\
--counts counts.csv \\
--metadata metadata.csv \\
--design "~batch + condition" \\
--contrast condition treated control \\
--output results/ \\
--n-cpus 4
""",
    )

    parser.add_argument("--counts", required=True, help="Path to count matrix CSV file")
    parser.add_argument("--metadata", required=True, help="Path to metadata CSV file")
    parser.add_argument("--design", required=True, help="Design formula (e.g., '~condition')")
    parser.add_argument("--contrast", nargs=3, required=True,
                        metavar=("VARIABLE", "TEST", "REFERENCE"),
                        help="Contrast specification: variable test_level reference_level")
    parser.add_argument("--output", default="results", help="Output directory (default: results)")
    parser.add_argument("--min-counts", type=int, default=10,
                        help="Minimum total counts for gene filtering (default: 10)")
    parser.add_argument("--alpha", type=float, default=0.05,
                        help="Significance threshold (default: 0.05)")
    parser.add_argument("--no-transpose", action="store_true",
                        help="Don't transpose count matrix (use if already samples × genes)")
    parser.add_argument("--no-shrink", action="store_true",
                        help="Skip LFC shrinkage")
    parser.add_argument("--n-cpus", type=int, default=1,
                        help="Number of CPUs for parallel processing (default: 1)")
    parser.add_argument("--plots", action="store_true",
                        help="Generate volcano and MA plots")
    return parser


def main():
    """Command-line entry point: load, filter, fit, test, save, and plot.

    Invalid options and contrasts that do not match the metadata are
    reported through ``parser.error`` (which exits with status 2) before
    model fitting starts.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Reject invalid numeric options before doing any I/O or model fitting.
    if not 0 < args.alpha < 1:
        parser.error(f"--alpha must be strictly between 0 and 1, got {args.alpha}")
    if args.min_counts < 0:
        parser.error(f"--min-counts must be non-negative, got {args.min_counts}")
    if args.n_cpus < 1:
        parser.error(f"--n-cpus must be at least 1, got {args.n_cpus}")

    variable, test_level, ref_level = args.contrast
    if test_level == ref_level:
        parser.error("--contrast test and reference levels must differ")

    # Load data
    counts_df, metadata = load_and_validate_data(
        args.counts,
        args.metadata,
        transpose_counts=not args.no_transpose,
    )
    if variable not in metadata.columns:
        parser.error(
            f"contrast variable '{variable}' is not a metadata column "
            f"(available: {', '.join(map(str, metadata.columns))})"
        )

    # Filter data (samples missing the contrast variable are dropped here)
    counts_df, metadata = filter_data(
        counts_df,
        metadata,
        min_counts=args.min_counts,
        condition_col=variable,
    )

    # Both contrast levels must occur among the retained samples; compare as
    # strings because command-line values are always strings.
    observed_levels = set(metadata[variable].astype(str))
    missing_levels = [lvl for lvl in (test_level, ref_level) if lvl not in observed_levels]
    if missing_levels:
        parser.error(
            f"contrast level(s) {missing_levels} not found in column '{variable}' "
            f"(observed: {sorted(observed_levels)})"
        )

    # Run DESeq2
    dds = run_deseq2(counts_df, metadata, args.design, n_cpus=args.n_cpus)

    # Statistical testing
    ds = run_statistical_tests(
        dds,
        contrast=args.contrast,
        alpha=args.alpha,
        shrink_lfc=not args.no_shrink,
    )

    # Save results
    save_results(ds, dds, args.output, shrink_lfc=not args.no_shrink)

    # Create plots if requested
    if args.plots:
        create_plots(ds, args.output)

    print(f"\n✓ Analysis complete! Results saved to {args.output}/")


if __name__ == "__main__":
    # Runs the CLI only when executed as a script, not when imported.
    sys.exit(main())