""" Comprehensive statistical assumption checking utilities. This module provides functions to check common statistical assumptions: - Normality - Homogeneity of variance - Independence - Linearity - Outliers """ import numpy as np import pandas as pd from scipy import stats import matplotlib.pyplot as plt import seaborn as sns from typing import Dict, List, Tuple, Optional, Union def check_normality( data: Union[np.ndarray, pd.Series, List], name: str = "data", alpha: float = 0.05, plot: bool = True ) -> Dict: """ Check normality assumption using Shapiro-Wilk test and visualizations. Parameters ---------- data : array-like Data to check for normality name : str Name of the variable (for labeling) alpha : float Significance level for Shapiro-Wilk test plot : bool Whether to create Q-Q plot and histogram Returns ------- dict Results including test statistic, p-value, and interpretation """ data = np.asarray(data) data_clean = data[~np.isnan(data)] # Shapiro-Wilk test statistic, p_value = stats.shapiro(data_clean) # Interpretation is_normal = p_value > alpha interpretation = ( f"Data {'appear' if is_normal else 'do not appear'} normally distributed " f"(W = {statistic:.3f}, p = {p_value:.3f})" ) # Visual checks if plot: fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) # Q-Q plot stats.probplot(data_clean, dist="norm", plot=ax1) ax1.set_title(f"Q-Q Plot: {name}") ax1.grid(alpha=0.3) # Histogram with normal curve ax2.hist(data_clean, bins='auto', density=True, alpha=0.7, color='steelblue', edgecolor='black') mu, sigma = data_clean.mean(), data_clean.std() x = np.linspace(data_clean.min(), data_clean.max(), 100) ax2.plot(x, stats.norm.pdf(x, mu, sigma), 'r-', linewidth=2, label='Normal curve') ax2.set_xlabel('Value') ax2.set_ylabel('Density') ax2.set_title(f'Histogram: {name}') ax2.legend() ax2.grid(alpha=0.3) plt.tight_layout() plt.show() return { 'test': 'Shapiro-Wilk', 'statistic': statistic, 'p_value': p_value, 'is_normal': is_normal, 'interpretation': interpretation, 'n': len(data_clean), 'recommendation': ( "Proceed with parametric test" if is_normal else "Consider non-parametric alternative or transformation" ) } def check_normality_per_group( data: pd.DataFrame, value_col: str, group_col: str, alpha: float = 0.05, plot: bool = True ) -> pd.DataFrame: """ Check normality assumption for each group separately. Parameters ---------- data : pd.DataFrame Data containing values and group labels value_col : str Column name for values to check group_col : str Column name for group labels alpha : float Significance level plot : bool Whether to create Q-Q plots for each group Returns ------- pd.DataFrame Results for each group """ groups = data[group_col].unique() results = [] if plot: n_groups = len(groups) fig, axes = plt.subplots(1, n_groups, figsize=(5 * n_groups, 4)) if n_groups == 1: axes = [axes] for idx, group in enumerate(groups): group_data = data[data[group_col] == group][value_col].dropna() stat, p = stats.shapiro(group_data) results.append({ 'Group': group, 'N': len(group_data), 'W': stat, 'p-value': p, 'Normal': 'Yes' if p > alpha else 'No' }) if plot: stats.probplot(group_data, dist="norm", plot=axes[idx]) axes[idx].set_title(f"Q-Q Plot: {group}") axes[idx].grid(alpha=0.3) if plot: plt.tight_layout() plt.show() return pd.DataFrame(results) def check_homogeneity_of_variance( data: pd.DataFrame, value_col: str, group_col: str, alpha: float = 0.05, plot: bool = True ) -> Dict: """ Check homogeneity of variance using Levene's test. 


def check_homogeneity_of_variance(
    data: pd.DataFrame,
    value_col: str,
    group_col: str,
    alpha: float = 0.05,
    plot: bool = True
) -> Dict:
    """
    Check homogeneity of variance using Levene's test.

    Parameters
    ----------
    data : pd.DataFrame
        Data containing values and group labels
    value_col : str
        Column name for values
    group_col : str
        Column name for group labels
    alpha : float
        Significance level
    plot : bool
        Whether to create box plots

    Returns
    -------
    dict
        Results including test statistic, p-value, and interpretation
    """
    # Keep group names alongside values so that plot labels match the
    # groupby order used for the test (unique() can return a different order)
    group_names = []
    groups = []
    for g_name, g_df in data.groupby(group_col):
        group_names.append(g_name)
        groups.append(g_df[value_col].dropna().values)

    # Levene's test (robust to non-normality)
    statistic, p_value = stats.levene(*groups)

    # Variance ratio (max/min)
    variances = [np.var(g, ddof=1) for g in groups]
    var_ratio = max(variances) / min(variances)

    is_homogeneous = p_value > alpha
    interpretation = (
        f"Variances {'appear' if is_homogeneous else 'do not appear'} homogeneous "
        f"(F = {statistic:.3f}, p = {p_value:.3f}, variance ratio = {var_ratio:.2f})"
    )

    if plot:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # Box plot
        data.boxplot(column=value_col, by=group_col, ax=ax1)
        ax1.set_title('Box Plots by Group')
        ax1.set_xlabel(group_col)
        ax1.set_ylabel(value_col)
        plt.sca(ax1)
        plt.xticks(rotation=45)

        # Variance plot (same group order as the test above)
        ax2.bar(range(len(variances)), variances,
                color='steelblue', edgecolor='black')
        ax2.set_xticks(range(len(variances)))
        ax2.set_xticklabels(group_names, rotation=45)
        ax2.set_ylabel('Variance')
        ax2.set_title('Variance by Group')
        ax2.grid(alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

    return {
        'test': 'Levene',
        'statistic': statistic,
        'p_value': p_value,
        'is_homogeneous': is_homogeneous,
        'variance_ratio': var_ratio,
        'interpretation': interpretation,
        'recommendation': (
            "Proceed with standard test" if is_homogeneous
            else "Consider Welch's correction or transformation"
        )
    }
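
# Example: a minimal usage sketch for the homogeneity check (hypothetical
# data with deliberately unequal spreads; plot=False for a non-interactive run):
#
#     rng = np.random.default_rng(1)
#     df = pd.DataFrame({'score': np.concatenate([rng.normal(70, 5, 40),
#                                                 rng.normal(70, 15, 40)]),
#                        'group': ['A'] * 40 + ['B'] * 40})
#     res = check_homogeneity_of_variance(df, 'score', 'group', plot=False)
#     print(res['interpretation'])   # the unequal variances should be flagged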


def check_linearity(
    x: Union[np.ndarray, pd.Series],
    y: Union[np.ndarray, pd.Series],
    x_name: str = "X",
    y_name: str = "Y"
) -> Dict:
    """
    Check the linearity assumption for regression.

    Parameters
    ----------
    x : array-like
        Predictor variable
    y : array-like
        Outcome variable
    x_name : str
        Name of predictor
    y_name : str
        Name of outcome

    Returns
    -------
    dict
        Fit statistics, interpretation, and recommendations
        (diagnostic plots are shown as a side effect)
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Drop pairs with missing values before fitting
    mask = np.isfinite(x) & np.isfinite(y)
    x, y = x[mask], y[mask]

    # Fit linear regression
    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
    y_pred = intercept + slope * x

    # Calculate residuals
    residuals = y - y_pred

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Scatter plot with regression line
    ax1.scatter(x, y, alpha=0.6, s=50, edgecolors='black', linewidths=0.5)
    ax1.plot(x, y_pred, 'r-', linewidth=2,
             label=f'y = {intercept:.2f} + {slope:.2f}x')
    ax1.set_xlabel(x_name)
    ax1.set_ylabel(y_name)
    ax1.set_title('Scatter Plot with Regression Line')
    ax1.legend()
    ax1.grid(alpha=0.3)

    # Residuals vs fitted
    ax2.scatter(y_pred, residuals, alpha=0.6, s=50,
                edgecolors='black', linewidths=0.5)
    ax2.axhline(y=0, color='r', linestyle='--', linewidth=2)
    ax2.set_xlabel('Fitted values')
    ax2.set_ylabel('Residuals')
    ax2.set_title('Residuals vs Fitted Values')
    ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'r': r_value,
        'r_squared': r_value ** 2,
        'interpretation': (
            "Examine the residual plot. Points should be randomly scattered "
            "around zero. Patterns (curves, funnels) suggest non-linearity "
            "or heteroscedasticity."
        ),
        'recommendation': (
            "If a non-linear pattern is detected: consider polynomial terms, "
            "transformations, or non-linear models"
        )
    }


def detect_outliers(
    data: Union[np.ndarray, pd.Series, List],
    name: str = "data",
    method: str = "iqr",
    threshold: float = 1.5,
    plot: bool = True
) -> Dict:
    """
    Detect outliers using the IQR method or the z-score method.

    Parameters
    ----------
    data : array-like
        Data to check for outliers
    name : str
        Name of variable
    method : str
        Method to use: 'iqr' or 'zscore'
    threshold : float
        Threshold for outlier detection.
        For IQR: typically 1.5 (mild) or 3 (extreme).
        For z-score: typically 3.
    plot : bool
        Whether to create visualizations

    Returns
    -------
    dict
        Outlier indices (positions in the NaN-free data), values, and bounds
    """
    data = np.asarray(data)
    data_clean = data[~np.isnan(data)]

    if method == "iqr":
        q1 = np.percentile(data_clean, 25)
        q3 = np.percentile(data_clean, 75)
        iqr = q3 - q1
        lower_bound = q1 - threshold * iqr
        upper_bound = q3 + threshold * iqr
        outlier_mask = (data_clean < lower_bound) | (data_clean > upper_bound)
    elif method == "zscore":
        z_scores = np.abs(stats.zscore(data_clean))
        outlier_mask = z_scores > threshold
        lower_bound = data_clean.mean() - threshold * data_clean.std()
        upper_bound = data_clean.mean() + threshold * data_clean.std()
    else:
        raise ValueError("method must be 'iqr' or 'zscore'")

    # Note: indices refer to positions in the NaN-free array, not the raw input
    outlier_indices = np.where(outlier_mask)[0]
    outlier_values = data_clean[outlier_mask]
    n_outliers = len(outlier_indices)
    pct_outliers = (n_outliers / len(data_clean)) * 100

    if plot:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # Box plot
        bp = ax1.boxplot(data_clean, vert=True, patch_artist=True)
        bp['boxes'][0].set_facecolor('steelblue')
        ax1.set_ylabel('Value')
        ax1.set_title(f'Box Plot: {name}')
        ax1.grid(alpha=0.3, axis='y')

        # Scatter plot highlighting outliers
        x_coords = np.arange(len(data_clean))
        ax2.scatter(x_coords[~outlier_mask], data_clean[~outlier_mask],
                    alpha=0.6, s=50, color='steelblue', label='Normal',
                    edgecolors='black', linewidths=0.5)
        if n_outliers > 0:
            ax2.scatter(x_coords[outlier_mask], data_clean[outlier_mask],
                        alpha=0.8, s=100, color='red', label='Outliers',
                        marker='D', edgecolors='black', linewidths=0.5)
        ax2.axhline(y=lower_bound, color='orange', linestyle='--',
                    linewidth=1.5, label='Bounds')
        ax2.axhline(y=upper_bound, color='orange', linestyle='--', linewidth=1.5)
        ax2.set_xlabel('Index')
        ax2.set_ylabel('Value')
        ax2.set_title(f'Outlier Detection: {name}')
        ax2.legend()
        ax2.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

    return {
        'method': method,
        'threshold': threshold,
        'n_outliers': n_outliers,
        'pct_outliers': pct_outliers,
        'outlier_indices': outlier_indices,
        'outlier_values': outlier_values,
        'lower_bound': lower_bound,
        'upper_bound': upper_bound,
        'interpretation': f"Found {n_outliers} outliers ({pct_outliers:.1f}% of data)",
        'recommendation': (
            "Investigate outliers for data entry errors. "
            "Consider: (1) removing if errors, (2) winsorizing, "
            "(3) keeping if legitimate, (4) using robust methods"
        )
    }
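
# The module docstring lists independence among the assumptions covered, but
# no checker is defined above. The sketch below is one minimal, assumption-laden
# way to screen ordered data for serial dependence: the lag-1 autocorrelation.
# Values near 0 are consistent with independence; independence itself is a
# property of the study design and cannot be confirmed by any statistic.
def check_independence(
    data: Union[np.ndarray, pd.Series, List],
    name: str = "data"
) -> Dict:
    """
    Rough screen for serial dependence via lag-1 autocorrelation.

    Only meaningful if observations have a natural order (e.g., time of
    collection). Treat this as a screening aid, not a formal test.
    """
    data = np.asarray(data, dtype=float)
    data_clean = data[~np.isnan(data)]

    # Pearson correlation between the series and itself shifted by one step
    lag1_corr = np.corrcoef(data_clean[:-1], data_clean[1:])[0, 1]

    return {
        'lag1_autocorrelation': lag1_corr,
        'n': len(data_clean),
        'interpretation': (
            f"Lag-1 autocorrelation for {name}: {lag1_corr:.3f}. "
            "Values far from 0 suggest serially dependent observations."
        ),
        'recommendation': (
            "If observations are dependent (repeated measures, clusters, "
            "time series), use paired/repeated-measures or mixed models."
        )
    }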


def comprehensive_assumption_check(
    data: pd.DataFrame,
    value_col: str,
    group_col: Optional[str] = None,
    alpha: float = 0.05
) -> Dict:
    """
    Perform comprehensive assumption checking for common statistical tests.

    Parameters
    ----------
    data : pd.DataFrame
        Data to check
    value_col : str
        Column name for dependent variable
    group_col : str, optional
        Column name for grouping variable (if applicable)
    alpha : float
        Significance level

    Returns
    -------
    dict
        Summary of all assumption checks
    """
    print("=" * 70)
    print("COMPREHENSIVE ASSUMPTION CHECK")
    print("=" * 70)

    results = {}

    # Outlier detection
    print("\n1. OUTLIER DETECTION")
    print("-" * 70)
    outlier_results = detect_outliers(
        data[value_col].dropna(), name=value_col, method='iqr', plot=True
    )
    results['outliers'] = outlier_results
    print(f"  {outlier_results['interpretation']}")
    print(f"  {outlier_results['recommendation']}")

    # Grouped data: check normality per group, then homogeneity of variance
    if group_col is not None:
        # Normality per group
        print(f"\n2. NORMALITY CHECK (by {group_col})")
        print("-" * 70)
        normality_results = check_normality_per_group(
            data, value_col, group_col, alpha=alpha, plot=True
        )
        results['normality_per_group'] = normality_results
        print(normality_results.to_string(index=False))

        all_normal = normality_results['Normal'].eq('Yes').all()
        print(f"\n  All groups normal: {'Yes' if all_normal else 'No'}")
        if not all_normal:
            print("  → Consider non-parametric alternative "
                  "(Mann-Whitney, Kruskal-Wallis)")

        # Homogeneity of variance
        print("\n3. HOMOGENEITY OF VARIANCE")
        print("-" * 70)
        homogeneity_results = check_homogeneity_of_variance(
            data, value_col, group_col, alpha=alpha, plot=True
        )
        results['homogeneity'] = homogeneity_results
        print(f"  {homogeneity_results['interpretation']}")
        print(f"  {homogeneity_results['recommendation']}")
    else:
        # Overall normality
        print("\n2. NORMALITY CHECK")
        print("-" * 70)
        normality_results = check_normality(
            data[value_col].dropna(), name=value_col, alpha=alpha, plot=True
        )
        results['normality'] = normality_results
        print(f"  {normality_results['interpretation']}")
        print(f"  {normality_results['recommendation']}")

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    if group_col is not None:
        all_normal = (
            results.get('normality_per_group', pd.DataFrame())
            .get('Normal', pd.Series(dtype=object))
            .eq('Yes')
            .all()
        )
        is_homogeneous = results.get('homogeneity', {}).get('is_homogeneous', False)
        if all_normal and is_homogeneous:
            print("✓ All assumptions met. Proceed with parametric test "
                  "(t-test, ANOVA).")
        elif not all_normal:
            print("✗ Normality violated. Use non-parametric alternative.")
        elif not is_homogeneous:
            print("✗ Homogeneity violated. Use Welch's correction "
                  "or transformation.")
    else:
        is_normal = results.get('normality', {}).get('is_normal', False)
        if is_normal:
            print("✓ Normality assumption met.")
        else:
            print("✗ Normality violated. Consider transformation "
                  "or non-parametric method.")
    print("=" * 70)

    return results


if __name__ == "__main__":
    # Example usage
    np.random.seed(42)

    # Simulate data
    group_a = np.random.normal(75, 8, 50)
    group_b = np.random.normal(68, 10, 50)

    df = pd.DataFrame({
        'score': np.concatenate([group_a, group_b]),
        'group': ['A'] * 50 + ['B'] * 50
    })

    # Run comprehensive check
    results = comprehensive_assumption_check(
        df, value_col='score', group_col='group', alpha=0.05
    )
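
    # The individual checkers can also be run standalone. Quick demos on the
    # same simulated data, plus a hypothetical linear predictor for the
    # linearity check (names here are illustrative only):
    outliers = detect_outliers(df['score'].values, name='score', plot=False)
    print(outliers['interpretation'])

    hours = np.linspace(0, 10, 100)                        # hypothetical predictor
    score = 60 + 2 * hours + np.random.normal(0, 3, 100)   # linear trend + noise
    lin = check_linearity(hours, score, x_name='hours', y_name='score')
    print(f"R^2 = {lin['r_squared']:.3f}")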