#!/usr/bin/env python3 """ PyDESeq2 Analysis Script This script performs a complete differential expression analysis using PyDESeq2. It can be used as a template for standard RNA-seq DEA workflows. Usage: python run_deseq2_analysis.py --counts counts.csv --metadata metadata.csv \ --design "~condition" --contrast condition treated control \ --output results/ Requirements: - pydeseq2 - pandas - matplotlib (optional, for plots) """ import argparse import os import pickle import sys from pathlib import Path import pandas as pd try: from pydeseq2.dds import DeseqDataSet from pydeseq2.ds import DeseqStats except ImportError: print("Error: pydeseq2 not installed. Install with: pip install pydeseq2") sys.exit(1) def load_and_validate_data(counts_path, metadata_path, transpose_counts=True): """Load count matrix and metadata, perform basic validation.""" print(f"Loading count data from {counts_path}...") counts_df = pd.read_csv(counts_path, index_col=0) if transpose_counts: print("Transposing count matrix to samples × genes format...") counts_df = counts_df.T print(f"Loading metadata from {metadata_path}...") metadata = pd.read_csv(metadata_path, index_col=0) print(f"\nData loaded:") print(f" Counts shape: {counts_df.shape} (samples × genes)") print(f" Metadata shape: {metadata.shape} (samples × variables)") # Validate if not all(counts_df.index == metadata.index): print("\nWarning: Sample indices don't match perfectly. Taking intersection...") common_samples = counts_df.index.intersection(metadata.index) counts_df = counts_df.loc[common_samples] metadata = metadata.loc[common_samples] print(f" Using {len(common_samples)} common samples") # Check for negative or non-integer values if (counts_df < 0).any().any(): raise ValueError("Count matrix contains negative values") return counts_df, metadata def filter_data(counts_df, metadata, min_counts=10, condition_col=None): """Filter low-count genes and samples with missing data.""" print(f"\nFiltering data...") initial_genes = counts_df.shape[1] initial_samples = counts_df.shape[0] # Filter genes genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= min_counts] counts_df = counts_df[genes_to_keep] genes_removed = initial_genes - counts_df.shape[1] print(f" Removed {genes_removed} genes with < {min_counts} total counts") # Filter samples with missing condition data if condition_col and condition_col in metadata.columns: samples_to_keep = ~metadata[condition_col].isna() counts_df = counts_df.loc[samples_to_keep] metadata = metadata.loc[samples_to_keep] samples_removed = initial_samples - counts_df.shape[0] if samples_removed > 0: print(f" Removed {samples_removed} samples with missing '{condition_col}' data") print(f" Final data shape: {counts_df.shape[0]} samples × {counts_df.shape[1]} genes") return counts_df, metadata def run_deseq2(counts_df, metadata, design, n_cpus=1): """Run DESeq2 normalization and fitting.""" print(f"\nInitializing DeseqDataSet with design: {design}") dds = DeseqDataSet( counts=counts_df, metadata=metadata, design=design, refit_cooks=True, n_cpus=n_cpus, quiet=False ) print("\nRunning DESeq2 pipeline...") print(" Step 1/7: Computing size factors...") print(" Step 2/7: Fitting genewise dispersions...") print(" Step 3/7: Fitting dispersion trend curve...") print(" Step 4/7: Computing dispersion priors...") print(" Step 5/7: Fitting MAP dispersions...") print(" Step 6/7: Fitting log fold changes...") print(" Step 7/7: Calculating Cook's distances...") dds.deseq2() print("\n✓ DESeq2 fitting complete") return dds def run_statistical_tests(dds, contrast, alpha=0.05, shrink_lfc=True): """Perform Wald tests and compute p-values.""" print(f"\nPerforming statistical tests...") print(f" Contrast: {contrast}") print(f" Significance threshold: {alpha}") ds = DeseqStats( dds, contrast=contrast, alpha=alpha, cooks_filter=True, independent_filter=True, quiet=False ) print("\n Running Wald tests...") print(" Filtering outliers based on Cook's distance...") print(" Applying independent filtering...") print(" Adjusting p-values (Benjamini-Hochberg)...") ds.summary() print("\n✓ Statistical testing complete") # Optional LFC shrinkage if shrink_lfc: print("\nApplying LFC shrinkage for visualization...") ds.lfc_shrink() print("✓ LFC shrinkage complete") return ds def save_results(ds, dds, output_dir, shrink_lfc=True): """Save results and intermediate objects.""" output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) print(f"\nSaving results to {output_dir}/") # Save statistical results results_path = output_dir / "deseq2_results.csv" ds.results_df.to_csv(results_path) print(f" Saved: {results_path}") # Save significant genes significant = ds.results_df[ds.results_df.padj < 0.05] sig_path = output_dir / "significant_genes.csv" significant.to_csv(sig_path) print(f" Saved: {sig_path} ({len(significant)} significant genes)") # Save sorted results sorted_results = ds.results_df.sort_values("padj") sorted_path = output_dir / "results_sorted_by_padj.csv" sorted_results.to_csv(sorted_path) print(f" Saved: {sorted_path}") # Save DeseqDataSet as pickle dds_path = output_dir / "deseq_dataset.pkl" with open(dds_path, "wb") as f: pickle.dump(dds.to_picklable_anndata(), f) print(f" Saved: {dds_path}") # Print summary print(f"\n{'='*60}") print("ANALYSIS SUMMARY") print(f"{'='*60}") print(f"Total genes tested: {len(ds.results_df)}") print(f"Significant genes (padj < 0.05): {len(significant)}") print(f"Upregulated: {len(significant[significant.log2FoldChange > 0])}") print(f"Downregulated: {len(significant[significant.log2FoldChange < 0])}") print(f"{'='*60}") # Show top genes print("\nTop 10 most significant genes:") print(sorted_results.head(10)[["baseMean", "log2FoldChange", "pvalue", "padj"]]) return results_path def create_plots(ds, output_dir): """Create basic visualization plots.""" try: import matplotlib.pyplot as plt import numpy as np except ImportError: print("\nNote: matplotlib not installed. Skipping plot generation.") return output_dir = Path(output_dir) results = ds.results_df.copy() print("\nGenerating plots...") # Volcano plot results["-log10(padj)"] = -np.log10(results.padj.fillna(1)) plt.figure(figsize=(10, 6)) significant = results.padj < 0.05 plt.scatter( results.loc[~significant, "log2FoldChange"], results.loc[~significant, "-log10(padj)"], alpha=0.3, s=10, c='gray', label='Not significant' ) plt.scatter( results.loc[significant, "log2FoldChange"], results.loc[significant, "-log10(padj)"], alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)' ) plt.axhline(-np.log10(0.05), color='blue', linestyle='--', linewidth=1, alpha=0.5) plt.axvline(1, color='gray', linestyle='--', linewidth=1, alpha=0.5) plt.axvline(-1, color='gray', linestyle='--', linewidth=1, alpha=0.5) plt.xlabel("Log2 Fold Change", fontsize=12) plt.ylabel("-Log10(Adjusted P-value)", fontsize=12) plt.title("Volcano Plot", fontsize=14, fontweight='bold') plt.legend() plt.tight_layout() volcano_path = output_dir / "volcano_plot.png" plt.savefig(volcano_path, dpi=300) plt.close() print(f" Saved: {volcano_path}") # MA plot plt.figure(figsize=(10, 6)) plt.scatter( np.log10(results.loc[~significant, "baseMean"] + 1), results.loc[~significant, "log2FoldChange"], alpha=0.3, s=10, c='gray', label='Not significant' ) plt.scatter( np.log10(results.loc[significant, "baseMean"] + 1), results.loc[significant, "log2FoldChange"], alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)' ) plt.axhline(0, color='blue', linestyle='--', linewidth=1, alpha=0.5) plt.xlabel("Log10(Base Mean + 1)", fontsize=12) plt.ylabel("Log2 Fold Change", fontsize=12) plt.title("MA Plot", fontsize=14, fontweight='bold') plt.legend() plt.tight_layout() ma_path = output_dir / "ma_plot.png" plt.savefig(ma_path, dpi=300) plt.close() print(f" Saved: {ma_path}") def main(): parser = argparse.ArgumentParser( description="Run PyDESeq2 differential expression analysis", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Basic analysis python run_deseq2_analysis.py \\ --counts counts.csv \\ --metadata metadata.csv \\ --design "~condition" \\ --contrast condition treated control \\ --output results/ # Multi-factor analysis python run_deseq2_analysis.py \\ --counts counts.csv \\ --metadata metadata.csv \\ --design "~batch + condition" \\ --contrast condition treated control \\ --output results/ \\ --n-cpus 4 """ ) parser.add_argument("--counts", required=True, help="Path to count matrix CSV file") parser.add_argument("--metadata", required=True, help="Path to metadata CSV file") parser.add_argument("--design", required=True, help="Design formula (e.g., '~condition')") parser.add_argument("--contrast", nargs=3, required=True, metavar=("VARIABLE", "TEST", "REFERENCE"), help="Contrast specification: variable test_level reference_level") parser.add_argument("--output", default="results", help="Output directory (default: results)") parser.add_argument("--min-counts", type=int, default=10, help="Minimum total counts for gene filtering (default: 10)") parser.add_argument("--alpha", type=float, default=0.05, help="Significance threshold (default: 0.05)") parser.add_argument("--no-transpose", action="store_true", help="Don't transpose count matrix (use if already samples × genes)") parser.add_argument("--no-shrink", action="store_true", help="Skip LFC shrinkage") parser.add_argument("--n-cpus", type=int, default=1, help="Number of CPUs for parallel processing (default: 1)") parser.add_argument("--plots", action="store_true", help="Generate volcano and MA plots") args = parser.parse_args() # Load data counts_df, metadata = load_and_validate_data( args.counts, args.metadata, transpose_counts=not args.no_transpose ) # Filter data condition_col = args.contrast[0] counts_df, metadata = filter_data( counts_df, metadata, min_counts=args.min_counts, condition_col=condition_col ) # Run DESeq2 dds = run_deseq2(counts_df, metadata, args.design, n_cpus=args.n_cpus) # Statistical testing ds = run_statistical_tests( dds, contrast=args.contrast, alpha=args.alpha, shrink_lfc=not args.no_shrink ) # Save results save_results(ds, dds, args.output, shrink_lfc=not args.no_shrink) # Create plots if requested if args.plots: create_plots(ds, args.output) print(f"\n✓ Analysis complete! Results saved to {args.output}/") if __name__ == "__main__": main()