#!/usr/bin/env python3
"""
PyDESeq2 Analysis Script

This script performs a complete differential expression analysis using PyDESeq2.
It can be used as a template for standard RNA-seq DEA workflows.

Usage:
    python run_deseq2_analysis.py --counts counts.csv --metadata metadata.csv \
        --design "~condition" --contrast condition treated control \
        --output results/

Requirements:
    - pydeseq2
    - pandas
    - matplotlib (optional, for plots)
"""
import argparse
import pickle
import sys
from pathlib import Path

import pandas as pd
try:
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats
except ImportError:
    print("Error: pydeseq2 not installed. Install with: pip install pydeseq2")
    sys.exit(1)

def load_and_validate_data(counts_path, metadata_path, transpose_counts=True):
    """Load count matrix and metadata, perform basic validation."""
    print(f"Loading count data from {counts_path}...")
    counts_df = pd.read_csv(counts_path, index_col=0)
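    # PyDESeq2's DeseqDataSet expects a samples × genes matrix (samples as rows),
    # so count files exported in the common genes × samples layout are transposed.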
    if transpose_counts:
        print("Transposing count matrix to samples × genes format...")
        counts_df = counts_df.T

    print(f"Loading metadata from {metadata_path}...")
    metadata = pd.read_csv(metadata_path, index_col=0)

    print(f"\nData loaded:")
    print(f" Counts shape: {counts_df.shape} (samples × genes)")
    print(f" Metadata shape: {metadata.shape} (samples × variables)")

    # Validate that counts and metadata describe the same samples
    if not counts_df.index.equals(metadata.index):
        print("\nWarning: Sample indices don't match perfectly. Taking intersection...")
        common_samples = counts_df.index.intersection(metadata.index)
        counts_df = counts_df.loc[common_samples]
        metadata = metadata.loc[common_samples]
        print(f" Using {len(common_samples)} common samples")

    # Check for negative or non-integer values (DESeq2 models raw integer counts)
    if (counts_df < 0).any().any():
        raise ValueError("Count matrix contains negative values")
    if not (counts_df % 1 == 0).all().all():
        raise ValueError("Count matrix contains non-integer values")

    return counts_df, metadata

def filter_data(counts_df, metadata, min_counts=10, condition_col=None):
    """Filter low-count genes and samples with missing data."""
    print(f"\nFiltering data...")
    initial_genes = counts_df.shape[1]
    initial_samples = counts_df.shape[0]

    # Filter genes
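    # Genes with very low total counts carry little information and mainly add to
    # the multiple-testing burden, so they are dropped before model fitting.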
    genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= min_counts]
    counts_df = counts_df[genes_to_keep]
    genes_removed = initial_genes - counts_df.shape[1]
    print(f" Removed {genes_removed} genes with < {min_counts} total counts")

    # Filter samples with missing condition data
    if condition_col and condition_col in metadata.columns:
        samples_to_keep = ~metadata[condition_col].isna()
        counts_df = counts_df.loc[samples_to_keep]
        metadata = metadata.loc[samples_to_keep]
        samples_removed = initial_samples - counts_df.shape[0]
        if samples_removed > 0:
            print(f" Removed {samples_removed} samples with missing '{condition_col}' data")

    print(f" Final data shape: {counts_df.shape[0]} samples × {counts_df.shape[1]} genes")
    return counts_df, metadata

def run_deseq2(counts_df, metadata, design, n_cpus=1):
    """Run DESeq2 normalization and fitting."""
    print(f"\nInitializing DeseqDataSet with design: {design}")
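    # refit_cooks=True (also the PyDESeq2 default) re-fits genes whose counts are
    # flagged as outliers by Cook's distance.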
    dds = DeseqDataSet(
        counts=counts_df,
        metadata=metadata,
        design=design,
        refit_cooks=True,
        n_cpus=n_cpus,
        quiet=False
    )

    print("\nRunning DESeq2 pipeline...")
    print(" Step 1/7: Computing size factors...")
    print(" Step 2/7: Fitting genewise dispersions...")
    print(" Step 3/7: Fitting dispersion trend curve...")
    print(" Step 4/7: Computing dispersion priors...")
    print(" Step 5/7: Fitting MAP dispersions...")
    print(" Step 6/7: Fitting log fold changes...")
    print(" Step 7/7: Calculating Cook's distances...")
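    # dds.deseq2() runs the whole pipeline in one call; the step list above only
    # outlines what it does internally (it also logs progress since quiet=False).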
    dds.deseq2()
    print("\n✓ DESeq2 fitting complete")
    return dds

def run_statistical_tests(dds, contrast, alpha=0.05, shrink_lfc=True):
    """Perform Wald tests and compute p-values."""
    print(f"\nPerforming statistical tests...")
    print(f" Contrast: {contrast}")
    print(f" Significance threshold: {alpha}")
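    # contrast is a [variable, tested_level, reference_level] triple: reported
    # log2 fold changes are tested_level vs reference_level. cooks_filter drops
    # p-values for genes with Cook's-distance outliers; independent_filter applies
    # DESeq2-style independent filtering before the Benjamini-Hochberg correction.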
    ds = DeseqStats(
        dds,
        contrast=contrast,
        alpha=alpha,
        cooks_filter=True,
        independent_filter=True,
        quiet=False
    )

    print("\n Running Wald tests...")
    print(" Filtering outliers based on Cook's distance...")
    print(" Applying independent filtering...")
    print(" Adjusting p-values (Benjamini-Hochberg)...")
    ds.summary()
    print("\n✓ Statistical testing complete")

    # Optional LFC shrinkage
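    # Shrinkage pulls noisy LFC estimates from low-count genes toward zero, which
    # stabilises rankings and plots; p-values and padj are left unchanged.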
    if shrink_lfc:
        print("\nApplying LFC shrinkage for visualization...")
        ds.lfc_shrink()
        print("✓ LFC shrinkage complete")

    return ds

def save_results(ds, dds, output_dir, shrink_lfc=True):
    """Save results and intermediate objects."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving results to {output_dir}/")

    # Save statistical results
    results_path = output_dir / "deseq2_results.csv"
    ds.results_df.to_csv(results_path)
    print(f" Saved: {results_path}")

    # Save significant genes
    significant = ds.results_df[ds.results_df.padj < 0.05]
    sig_path = output_dir / "significant_genes.csv"
    significant.to_csv(sig_path)
    print(f" Saved: {sig_path} ({len(significant)} significant genes)")

    # Save sorted results
    sorted_results = ds.results_df.sort_values("padj")
    sorted_path = output_dir / "results_sorted_by_padj.csv"
    sorted_results.to_csv(sorted_path)
    print(f" Saved: {sorted_path}")

    # Save DeseqDataSet as pickle
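    # to_picklable_anndata() converts the DeseqDataSet into a plain AnnData object
    # that pickles cleanly and can be reloaded later with pickle.load().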
    dds_path = output_dir / "deseq_dataset.pkl"
    with open(dds_path, "wb") as f:
        pickle.dump(dds.to_picklable_anndata(), f)
    print(f" Saved: {dds_path}")

    # Print summary
    print(f"\n{'='*60}")
    print("ANALYSIS SUMMARY")
    print(f"{'='*60}")
    print(f"Total genes tested: {len(ds.results_df)}")
    print(f"Significant genes (padj < 0.05): {len(significant)}")
    print(f"Upregulated: {len(significant[significant.log2FoldChange > 0])}")
    print(f"Downregulated: {len(significant[significant.log2FoldChange < 0])}")
    print(f"{'='*60}")

    # Show top genes
    print("\nTop 10 most significant genes:")
    print(sorted_results.head(10)[["baseMean", "log2FoldChange", "pvalue", "padj"]])

    return results_path

def create_plots(ds, output_dir):
    """Create basic visualization plots."""
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        print("\nNote: matplotlib not installed. Skipping plot generation.")
        return

    output_dir = Path(output_dir)
    results = ds.results_df.copy()
    print("\nGenerating plots...")

    # Volcano plot
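    # Genes with missing padj (removed by independent filtering) are assigned
    # padj = 1 so they plot at -log10(1) = 0 instead of being dropped.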
    results["-log10(padj)"] = -np.log10(results.padj.fillna(1))
    plt.figure(figsize=(10, 6))
    significant = results.padj < 0.05
    plt.scatter(
        results.loc[~significant, "log2FoldChange"],
        results.loc[~significant, "-log10(padj)"],
        alpha=0.3, s=10, c='gray', label='Not significant'
    )
    plt.scatter(
        results.loc[significant, "log2FoldChange"],
        results.loc[significant, "-log10(padj)"],
        alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)'
    )
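    # Guide lines: horizontal at the padj = 0.05 cutoff, vertical at ±1 log2FC
    # (a two-fold change in either direction).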
    plt.axhline(-np.log10(0.05), color='blue', linestyle='--', linewidth=1, alpha=0.5)
    plt.axvline(1, color='gray', linestyle='--', linewidth=1, alpha=0.5)
    plt.axvline(-1, color='gray', linestyle='--', linewidth=1, alpha=0.5)
    plt.xlabel("Log2 Fold Change", fontsize=12)
    plt.ylabel("-Log10(Adjusted P-value)", fontsize=12)
    plt.title("Volcano Plot", fontsize=14, fontweight='bold')
    plt.legend()
    plt.tight_layout()
    volcano_path = output_dir / "volcano_plot.png"
    plt.savefig(volcano_path, dpi=300)
    plt.close()
    print(f" Saved: {volcano_path}")

    # MA plot
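    # MA plot: average expression (log10 scale) on x against log2 fold change on y.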
    plt.figure(figsize=(10, 6))
    plt.scatter(
        np.log10(results.loc[~significant, "baseMean"] + 1),
        results.loc[~significant, "log2FoldChange"],
        alpha=0.3, s=10, c='gray', label='Not significant'
    )
    plt.scatter(
        np.log10(results.loc[significant, "baseMean"] + 1),
        results.loc[significant, "log2FoldChange"],
        alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)'
    )
    plt.axhline(0, color='blue', linestyle='--', linewidth=1, alpha=0.5)
    plt.xlabel("Log10(Base Mean + 1)", fontsize=12)
    plt.ylabel("Log2 Fold Change", fontsize=12)
    plt.title("MA Plot", fontsize=14, fontweight='bold')
    plt.legend()
    plt.tight_layout()
    ma_path = output_dir / "ma_plot.png"
    plt.savefig(ma_path, dpi=300)
    plt.close()
    print(f" Saved: {ma_path}")

def main():
    parser = argparse.ArgumentParser(
        description="Run PyDESeq2 differential expression analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic analysis
  python run_deseq2_analysis.py \\
      --counts counts.csv \\
      --metadata metadata.csv \\
      --design "~condition" \\
      --contrast condition treated control \\
      --output results/

  # Multi-factor analysis
  python run_deseq2_analysis.py \\
      --counts counts.csv \\
      --metadata metadata.csv \\
      --design "~batch + condition" \\
      --contrast condition treated control \\
      --output results/ \\
      --n-cpus 4
"""
    )
    parser.add_argument("--counts", required=True, help="Path to count matrix CSV file")
    parser.add_argument("--metadata", required=True, help="Path to metadata CSV file")
    parser.add_argument("--design", required=True, help="Design formula (e.g., '~condition')")
    parser.add_argument("--contrast", nargs=3, required=True,
                        metavar=("VARIABLE", "TEST", "REFERENCE"),
                        help="Contrast specification: variable test_level reference_level")
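    # Note: TEST and REFERENCE must exactly match levels present in the metadata
    # column named by VARIABLE (e.g. "treated" and "control").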
    parser.add_argument("--output", default="results", help="Output directory (default: results)")
    parser.add_argument("--min-counts", type=int, default=10,
                        help="Minimum total counts for gene filtering (default: 10)")
    parser.add_argument("--alpha", type=float, default=0.05,
                        help="Significance threshold (default: 0.05)")
    parser.add_argument("--no-transpose", action="store_true",
                        help="Don't transpose count matrix (use if already samples × genes)")
    parser.add_argument("--no-shrink", action="store_true",
                        help="Skip LFC shrinkage")
    parser.add_argument("--n-cpus", type=int, default=1,
                        help="Number of CPUs for parallel processing (default: 1)")
    parser.add_argument("--plots", action="store_true",
                        help="Generate volcano and MA plots")
    args = parser.parse_args()

    # Load data
    counts_df, metadata = load_and_validate_data(
        args.counts,
        args.metadata,
        transpose_counts=not args.no_transpose
    )

    # Filter data
    condition_col = args.contrast[0]
    counts_df, metadata = filter_data(
        counts_df,
        metadata,
        min_counts=args.min_counts,
        condition_col=condition_col
    )

    # Run DESeq2
    dds = run_deseq2(counts_df, metadata, args.design, n_cpus=args.n_cpus)

    # Statistical testing
    ds = run_statistical_tests(
        dds,
        contrast=args.contrast,
        alpha=args.alpha,
        shrink_lfc=not args.no_shrink
    )

    # Save results
    save_results(ds, dds, args.output, shrink_lfc=not args.no_shrink)

    # Create plots if requested
    if args.plots:
        create_plots(ds, args.output)

    print(f"\n✓ Analysis complete! Results saved to {args.output}/")


if __name__ == "__main__":
    main()