# (removed web-viewer listing header: gh-aojdevstudio-dev-utils
#  hooks/scripts/api-standards-checker.py — 611 lines, executable)
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = []
# ///
"""
API Standards Checker - UV Script Version
Validates API routes follow project conventions
Usage:
uv run api-standards-checker.py <file_path>
uv run api-standards-checker.py --check-dir <directory>
uv run api-standards-checker.py --hook-mode # For Claude Code hook compatibility
"""
# Lazy annotation evaluation (PEP 563): required so the `str | None` and
# `list[Violation]` annotations below work on the declared Python >= 3.8 floor.
from __future__ import annotations

import argparse
import json
import os
import re
import sys
import urllib.parse
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
@dataclass
class Violation:
    """A single finding produced by the API standards checker."""

    # NOTE(review): `str | None` annotations are evaluated at class-creation
    # time and require Python 3.10+ unless `from __future__ import annotations`
    # is in effect, while the script metadata declares requires-python >= 3.8
    # — confirm the intended floor.
    rule: str  # short rule name, e.g. "HTTP Methods"
    message: str  # human-readable description of the finding
    severity: str  # one of "error", "warning", "info"
    file_path: str | None = None  # file the finding applies to, when known
    line_number: int | None = None  # never populated by the current checks
class ApiStandardsChecker:
    """Checks Next.js-style API route files against project conventions.

    Each ``validate_*`` method inspects the route file's source text and
    records findings in ``self.violations`` (problems) or
    ``self.suggestions`` (advisory notes). ``validate_file`` resets state,
    runs every check, and returns the combined list.
    """

    def __init__(self, file_path: str | None = None):
        self.file_path = file_path
        self.violations: list[Violation] = []
        self.suggestions: list[Violation] = []

    def validate_file(self, file_path: str, content: str) -> list[Violation]:
        """Validate a single file's content and return all findings."""
        self.file_path = file_path
        self.violations = []
        self.suggestions = []
        # Only validate API route files; everything else passes trivially.
        if not self.is_api_route(file_path):
            return []
        # Perform validations
        self.validate_file_name(file_path)
        self.validate_http_methods(content)
        self.validate_response_format(content)
        self.validate_error_handling(content)
        self.validate_authentication(content, file_path)
        self.validate_input_validation(content)
        self.validate_multi_tenancy(content)
        return self.violations + self.suggestions

    def is_api_route(self, file_path: str) -> bool:
        """Check if file is an API route (test files are excluded)."""
        return "/app/api/" in file_path and ".test." not in file_path

    def validate_file_name(self, file_path: str):
        """Validate file naming convention (Next.js expects route.ts/js)."""
        file_name = os.path.basename(file_path)
        if file_name not in ["route.ts", "route.js"]:
            self.violations.append(
                Violation(
                    rule="File Naming",
                    message=f"API route files must be named 'route.ts', found: {file_name}",
                    severity="error",
                    file_path=file_path,
                )
            )

    def validate_http_methods(self, content: str):
        """Validate HTTP method exports and suggest async handlers."""
        valid_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
        exported_methods = []
        # Detect each method exported in any of the supported styles
        # (const arrow handler, async function, plain function).
        for method in valid_methods:
            patterns = [
                rf"export\s+const\s+{method}\s*=",
                rf"export\s+async\s+function\s+{method}",
                rf"export\s+function\s+{method}",
            ]
            if any(re.search(pattern, content) for pattern in patterns):
                exported_methods.append(method)
        if not exported_methods:
            self.violations.append(
                Violation(
                    rule="HTTP Methods",
                    message="API routes should export named HTTP method handlers (GET, POST, etc.)",
                    severity="error",
                    file_path=self.file_path,
                )
            )
        # Suggest async for any handler exported without an async form.
        for method in exported_methods:
            async_pattern = rf"export\s+const\s+{method}\s*=\s*async"
            function_pattern = rf"export\s+async\s+function\s+{method}"
            if not re.search(async_pattern, content) and not re.search(
                function_pattern, content
            ):
                self.suggestions.append(
                    Violation(
                        rule="Async Handlers",
                        message=f"Consider making {method} handler async for consistency",
                        severity="info",
                        file_path=self.file_path,
                    )
                )

    def validate_response_format(self, content: str):
        """Validate response format consistency and status-code usage."""
        has_api_utils = any(
            util in content for util in ["apiSuccess", "apiError", "apiPaginated"]
        )
        has_next_response = "NextResponse.json" in content
        has_response_json = "Response.json" in content
        # Raw Response/NextResponse usage without the shared helpers.
        if not has_api_utils and (has_next_response or has_response_json):
            self.suggestions.append(
                Violation(
                    rule="Response Format",
                    message="Consider using standardized API utilities (apiSuccess, apiError, apiPaginated) for consistent responses",
                    severity="warning",
                    file_path=self.file_path,
                )
            )
        # Flag status codes outside the conventional set.
        status_matches = re.findall(r"status[:(]\s*(\d{3})", content)
        valid_codes = ["200", "201", "204", "400", "401", "403", "404", "500"]
        for code in status_matches:
            if code not in valid_codes:
                self.suggestions.append(
                    Violation(
                        rule="Status Codes",
                        message=f"Unusual status code {code} - ensure it's appropriate",
                        severity="info",
                        file_path=self.file_path,
                    )
                )

    def validate_error_handling(self, content: str):
        """Validate error handling patterns."""
        # BUG FIX: the original used bare substring tests ("try" in content),
        # which match inside unrelated identifiers such as "registry" or
        # "dispatcher" and silently suppressed the missing-error-handling
        # violation. Word-boundary regexes match only the actual keywords.
        has_try_catch = bool(re.search(r"\btry\b", content)) and bool(
            re.search(r"\bcatch\b", content)
        )
        has_error_handler = "handleApiError" in content
        if not has_try_catch and not has_error_handler:
            self.violations.append(
                Violation(
                    rule="Error Handling",
                    message="API routes should include proper error handling (try-catch or handleApiError)",
                    severity="warning",
                    file_path=self.file_path,
                )
            )
        # Check that catch blocks return proper error responses.
        # NOTE: the regex only captures flat catch bodies — a nested brace
        # inside the block cuts the match short. Acceptable as a heuristic.
        if has_try_catch:
            catch_blocks = re.findall(r"catch\s*\([^)]*\)\s*{[^}]*}", content)
            for block in catch_blocks:
                if not any(
                    term in block for term in ["apiError", "status", "Response"]
                ):
                    self.violations.append(
                        Violation(
                            rule="Error Response",
                            message="Catch blocks should return proper error responses",
                            severity="warning",
                            file_path=self.file_path,
                        )
                    )

    def validate_authentication(self, content: str, file_path: str):
        """Validate authentication middleware usage and RBAC hints."""
        has_with_auth = "withAuth" in content
        # Public/webhook endpoints are intentionally unauthenticated.
        is_public_route = "/public/" in file_path or "/webhook/" in file_path
        if not has_with_auth and not is_public_route:
            self.suggestions.append(
                Violation(
                    rule="Authentication",
                    message="Consider using withAuth middleware for protected routes",
                    severity="warning",
                    file_path=file_path,
                )
            )
        # Authenticated routes with no mention of permissions/roles.
        if has_with_auth and not any(
            term in content for term in ["permissions", "role"]
        ):
            self.suggestions.append(
                Violation(
                    rule="Authorization",
                    message="Consider implementing role-based access control",
                    severity="info",
                    file_path=file_path,
                )
            )

    def validate_input_validation(self, content: str):
        """Validate request-body validation and raw-SQL safety."""
        has_zod = "z." in content or "zod" in content
        has_request_json = "request.json()" in content
        has_form_data = "formData()" in content
        # Body-consuming handlers without any Zod usage.
        if (has_request_json or has_form_data) and not has_zod:
            self.suggestions.append(
                Violation(
                    rule="Input Validation",
                    message="Consider using Zod schemas for request validation",
                    severity="warning",
                    file_path=self.file_path,
                )
            )
        # Raw Prisma queries should go through Prisma.sql parameterization.
        if "prisma" in content and "$queryRaw" in content:
            if "Prisma.sql" not in content:
                self.suggestions.append(
                    Violation(
                        rule="SQL Safety",
                        message="Ensure raw queries are parameterized to prevent SQL injection",
                        severity="warning",
                        file_path=self.file_path,
                    )
                )

    def validate_multi_tenancy(self, content: str):
        """Validate multi-tenancy patterns (clinic-scoped queries)."""
        has_prisma_query = "prisma." in content
        has_clinic_filter = "clinic_id" in content or "clinicId" in content
        if has_prisma_query and not has_clinic_filter:
            # Only warn for models that hold tenant-scoped data.
            data_models = ["provider", "patient", "appointment", "transaction"]
            has_data_model = any(f"prisma.{model}" in content for model in data_models)
            if has_data_model:
                self.suggestions.append(
                    Violation(
                        rule="Multi-tenancy",
                        message="Ensure data queries are filtered by clinic_id for multi-tenant isolation",
                        severity="warning",
                        file_path=self.file_path,
                    )
                )
def check_file(file_path: str) -> list[Violation]:
    """Check a single file"""
    try:
        text = Path(file_path).read_text(encoding="utf-8")
        return ApiStandardsChecker().validate_file(file_path, text)
    except Exception as exc:
        # Unreadable file (or unexpected checker failure) becomes a
        # single error-severity finding instead of crashing the run.
        return [
            Violation(
                rule="File Error",
                message=f"Error reading file: {str(exc)}",
                severity="error",
                file_path=file_path,
            )
        ]
def check_directory(directory: str) -> list[Violation]:
    """Check all API route files in a directory"""
    findings: list[Violation] = []
    # Walk the tree and validate every Next.js route module found.
    for root, _dirs, names in os.walk(directory):
        findings.extend(
            check_file(os.path.join(root, name))
            for name in names
            if name in ("route.ts", "route.js")
            for _ in ()  # placeholder removed below
        )
    return findings
# Encoded traversal fragments checked against the *lower-cased* raw path
# before any URL decoding. (The original also listed "..%2F"/"..%5C", but
# those uppercase variants could never match a lower-cased string, and
# "....///" / "....\\\\\\" are supersets of "../" / "..\\" — dead entries
# removed.)
_ENCODED_TRAVERSAL_PATTERNS = (
    "%2e%2e%2f",  # double URL encoded ../
    "%2e%2e%5c",  # double URL encoded ..\
    "%2e%2e/",    # mixed encoded ../
    "%2e%2e\\",   # mixed encoded ..\
    "..%2f",      # URL encoded forward slash
    "..%5c",      # URL encoded backslash
    "%00",        # null byte injection
    "%c0%af",     # Unicode overlong encoding for /
    "%c1%9c",     # Unicode overlong encoding for \
    "%c0%ae",     # Unicode overlong encoding for .
)

# Overlong-UTF-8 escape sequences such as %c0%af. Compiled once at import
# time (the original recompiled — and re-imported `re` — on every call).
_OVERLONG_UTF8_RE = re.compile(r"%c[0-1]%[a-f0-9][a-f0-9]")


def is_safe_path(file_path: str, base_dir: str | None = None) -> bool:
    """Robust path traversal security check.

    Defends against: standard traversal (``../../../etc/passwd``), URL
    encoded and double-encoded traversal, Unicode overlong encodings
    (``%c0%af``), null-byte injection (``%00``), Windows backslash
    traversal, and absolute paths escaping *base_dir*.

    Args:
        file_path: The file path to validate.
        base_dir: Optional base directory to restrict access to
            (defaults to the current working directory).

    Returns:
        True if the path is safe, False if it's potentially malicious.
        An empty path is treated as safe (there is nothing to traverse).
    """
    if not file_path:
        return True
    try:
        # Resolve the containment boundary.
        base_dir = os.path.abspath(base_dir if base_dir is not None else os.getcwd())

        # 1) Reject known-dangerous fragments in the still-encoded path.
        lowered = file_path.lower()
        if any(fragment in lowered for fragment in _ENCODED_TRAVERSAL_PATTERNS):
            return False
        if _OVERLONG_UTF8_RE.search(lowered):
            return False

        # 2) Peel up to four layers of URL encoding (bounded to avoid
        #    pathological inputs looping forever).
        decoded = urllib.parse.unquote(file_path)
        previous = decoded
        for _ in range(3):
            decoded = urllib.parse.unquote(decoded)
            if decoded == previous:
                break
            previous = decoded

        # 3) Reject plain traversal sequences in the decoded path.
        decoded_lower = decoded.lower()
        if "../" in decoded_lower or "..\\" in decoded_lower:
            return False

        # 4) Normalize and verify containment within base_dir.
        if os.path.isabs(decoded):
            normalized = os.path.abspath(decoded)
        else:
            normalized = os.path.abspath(os.path.join(base_dir, decoded))
        if os.path.commonpath([normalized, base_dir]) != base_dir:
            return False

        # 5) On Windows, a drive-letter change also escapes the base.
        if os.name == "nt":
            base_drive = os.path.splitdrive(base_dir)[0].lower()
            if base_drive != os.path.splitdrive(normalized)[0].lower():
                return False

        return True
    except (ValueError, OSError):
        # commonpath/abspath can raise on mixed drives or malformed input;
        # treat any failure as unsafe.
        return False
def _load_log(log_path: Path) -> list:
    """Return the existing JSON log list, or [] if the file is missing/corrupt."""
    if log_path.exists():
        with open(log_path) as f:
            try:
                return json.load(f)
            except (json.JSONDecodeError, ValueError):
                return []
    return []


def _write_log(log_path: Path, log_data: list) -> None:
    """Write the full log list back to disk, pretty-printed."""
    with open(log_path, "w") as f:
        json.dump(log_data, f, indent=2)


def _log_and_return(log_path: Path, log_data: list, input_data: dict, result: dict) -> dict:
    """Attach *result* to the log entry, persist it, and return *result*."""
    input_data["validation_result"] = result
    log_data.append(input_data)
    _write_log(log_path, log_data)
    return result


def _log_hook_error(error: Exception) -> None:
    """Best-effort: append a hook failure to the log. Never raises."""
    try:
        log_dir = Path.cwd() / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path = log_dir / "api_standards_checker.json"
        log_data = _load_log(log_path)
        log_data.append(
            {
                "timestamp": datetime.now().strftime("%b %d, %I:%M%p").lower(),
                "error": str(error),
                "validation_result": {
                    "approve": True,
                    "message": f"API checker error: {str(error)}",
                },
            }
        )
        _write_log(log_path, log_data)
    except Exception:
        # Logging must never mask the original error path.
        pass


def hook_mode() -> dict:
    """Claude Code hook compatibility mode.

    Reads the hook payload from stdin, validates the edited file's
    content, appends an entry to ``logs/api_standards_checker.json``,
    and returns an ``{"approve": bool, "message": str}`` result.
    Refactored: the original duplicated the log read/append/write
    sequence in four places; it now lives in the helpers above.
    """
    try:
        input_data = json.loads(sys.stdin.read())

        # Ensure log directory exists and load any prior entries.
        log_dir = Path.cwd() / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path = log_dir / "api_standards_checker.json"
        log_data = _load_log(log_path)

        # Stamp this entry.
        input_data["timestamp"] = datetime.now().strftime("%b %d, %I:%M%p").lower()

        tool_input = input_data.get("tool_input", {})
        output = input_data.get("output", {})
        file_path = tool_input.get("file_path", "")
        content = output.get("content") or tool_input.get("content", "")

        # Security gate: block path-traversal attempts outright.
        if file_path and not is_safe_path(file_path):
            return _log_and_return(
                log_path,
                log_data,
                input_data,
                {
                    "approve": False,
                    "message": "🚨 Security Alert: Potentially unsafe file path detected. Path traversal attempt blocked.",
                },
            )

        # Nothing to validate — approve immediately.
        if not content:
            return _log_and_return(
                log_path,
                log_data,
                input_data,
                {"approve": True, "message": "✅ API standards check passed"},
            )

        # Run the standards checks and record the findings.
        checker = ApiStandardsChecker()
        violations = checker.validate_file(file_path, content)
        input_data["violations_found"] = [asdict(v) for v in violations]

        if violations:
            message_lines = ["⚠️ API Standards Review:"]
            for v in violations:
                message_lines.append(f" - {v.rule}: {v.message}")
            message_lines.append("")
            message_lines.append(
                "Consider addressing these issues to maintain API consistency."
            )
            result = {"approve": True, "message": "\n".join(message_lines)}
        else:
            result = {"approve": True, "message": "✅ API standards check passed"}

        return _log_and_return(log_path, log_data, input_data, result)
    except Exception as e:
        # Fail open: hooks must never block the tool on checker errors.
        _log_hook_error(e)
        return {"approve": True, "message": f"API checker error: {str(e)}"}
def format_violations(violations: list[Violation]) -> str:
    """Format violations for display, grouped by severity.

    Returns a multi-line report with one section per severity level;
    findings with an unrecognized severity are silently dropped
    (matching the original behavior).
    """
    if not violations:
        return "✅ No API standards violations found!"
    lines = ["📋 API Standards Check Results:", ""]
    # Group by severity
    errors = [v for v in violations if v.severity == "error"]
    warnings = [v for v in violations if v.severity == "warning"]
    info = [v for v in violations if v.severity == "info"]
    # BUG FIX: the ERRORS emoji was an empty string, rendering a
    # leading-space " ERRORS:" header inconsistent with the other sections.
    for severity, items, emoji in [
        ("ERRORS", errors, "❌"),
        ("WARNINGS", warnings, "⚠️"),
        ("INFO", info, "💡"),
    ]:
        if items:
            lines.append(f"{emoji} {severity}:")
            for item in items:
                location = f" ({item.file_path})" if item.file_path else ""
                lines.append(f" - {item.rule}: {item.message}{location}")
            lines.append("")
    return "\n".join(lines)
def main():
    """CLI entry point: dispatch to hook mode, directory scan, or single file."""
    parser = argparse.ArgumentParser(description="API Standards Checker")
    parser.add_argument("file_path", nargs="?", help="File to check")
    parser.add_argument("--check-dir", help="Directory to check recursively")
    parser.add_argument(
        "--hook-mode", action="store_true", help="Claude Code hook compatibility mode"
    )
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    args = parser.parse_args()

    # Hook mode short-circuits everything else and prints a JSON result.
    if args.hook_mode:
        print(json.dumps(hook_mode()))
        return

    if args.check_dir:
        violations = check_directory(args.check_dir)
    elif args.file_path:
        violations = check_file(args.file_path)
    else:
        # No target given — show usage and exit cleanly.
        parser.print_help()
        return

    if args.json:
        print(json.dumps([asdict(v) for v in violations], indent=2))
    else:
        print(format_violations(violations))

    # Non-zero exit only for error-severity findings.
    if any(v.severity == "error" for v in violations):
        sys.exit(1)
# Script entry point: run the CLI when executed directly (e.g. via `uv run`).
if __name__ == "__main__":
    main()