Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 08:48:52 +08:00
commit 6ec3196ecc
434 changed files with 125248 additions and 0 deletions

View File

@@ -0,0 +1,521 @@
#!/usr/bin/env python3
"""
Better Auth Initialization Script
Interactive script to initialize Better Auth configuration.
Supports multiple databases, ORMs, and authentication methods.
.env loading order: process.env > skill/.env > skills/.env > .claude/.env
"""
import os
import sys
import json
import secrets
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
@dataclass
class EnvConfig:
"""Environment configuration holder."""
secret: str
url: str
database_url: Optional[str] = None
github_client_id: Optional[str] = None
github_client_secret: Optional[str] = None
google_client_id: Optional[str] = None
google_client_secret: Optional[str] = None
class BetterAuthInit:
"""Better Auth configuration initializer."""
def __init__(self, project_root: Optional[Path] = None):
"""
Initialize the Better Auth configuration tool.
Args:
project_root: Project root directory. Auto-detected if not provided.
"""
self.project_root = project_root or self._find_project_root()
self.env_config: Optional[EnvConfig] = None
@staticmethod
def _find_project_root() -> Path:
"""
Find project root by looking for package.json.
Returns:
Path to project root.
Raises:
RuntimeError: If project root cannot be found.
"""
current = Path.cwd()
while current != current.parent:
if (current / "package.json").exists():
return current
current = current.parent
raise RuntimeError("Could not find project root (no package.json found)")
def _load_env_files(self) -> Dict[str, str]:
"""
Load environment variables from .env files in order.
Loading order: process.env > skill/.env > skills/.env > .claude/.env
Returns:
Dictionary of environment variables.
"""
env_vars = {}
# Define search paths in reverse priority order
skill_dir = Path(__file__).parent.parent
env_paths = [
self.project_root / ".claude" / ".env",
self.project_root / ".claude" / "skills" / ".env",
skill_dir / ".env",
]
# Load from files (lowest priority first)
for env_path in env_paths:
if env_path.exists():
env_vars.update(self._parse_env_file(env_path))
# Override with process environment (highest priority)
env_vars.update(os.environ)
return env_vars
@staticmethod
def _parse_env_file(path: Path) -> Dict[str, str]:
"""
Parse .env file into dictionary.
Args:
path: Path to .env file.
Returns:
Dictionary of key-value pairs.
"""
env_vars = {}
try:
with open(path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
# Remove quotes if present
value = value.strip().strip('"').strip("'")
env_vars[key.strip()] = value
except Exception as e:
print(f"Warning: Could not parse {path}: {e}")
return env_vars
@staticmethod
def generate_secret(length: int = 32) -> str:
"""
Generate cryptographically secure random secret.
Args:
length: Length of secret in bytes.
Returns:
Hex-encoded secret string.
"""
return secrets.token_hex(length)
def prompt_database(self) -> Dict[str, Any]:
"""
Prompt user for database configuration.
Returns:
Database configuration dictionary.
"""
print("\nDatabase Configuration")
print("=" * 50)
print("1. Direct Connection (PostgreSQL/MySQL/SQLite)")
print("2. Drizzle ORM")
print("3. Prisma")
print("4. Kysely")
print("5. MongoDB")
choice = input("\nSelect database option (1-5): ").strip()
db_configs = {
"1": self._prompt_direct_db,
"2": self._prompt_drizzle,
"3": self._prompt_prisma,
"4": self._prompt_kysely,
"5": self._prompt_mongodb,
}
handler = db_configs.get(choice)
if not handler:
print("Invalid choice. Defaulting to direct PostgreSQL.")
return self._prompt_direct_db()
return handler()
def _prompt_direct_db(self) -> Dict[str, Any]:
"""Prompt for direct database connection."""
print("\nDatabase Type:")
print("1. PostgreSQL")
print("2. MySQL")
print("3. SQLite")
db_type = input("Select (1-3): ").strip()
if db_type == "3":
db_path = input("SQLite file path [./dev.db]: ").strip() or "./dev.db"
return {
"type": "sqlite",
"import": "import Database from 'better-sqlite3';",
"config": f'database: new Database("{db_path}")'
}
elif db_type == "2":
db_url = input("MySQL connection string: ").strip()
return {
"type": "mysql",
"import": "import { createPool } from 'mysql2/promise';",
"config": f"database: createPool({{ connectionString: process.env.DATABASE_URL }})",
"env_var": ("DATABASE_URL", db_url)
}
else:
db_url = input("PostgreSQL connection string: ").strip()
return {
"type": "postgresql",
"import": "import { Pool } from 'pg';",
"config": "database: new Pool({ connectionString: process.env.DATABASE_URL })",
"env_var": ("DATABASE_URL", db_url)
}
def _prompt_drizzle(self) -> Dict[str, Any]:
"""Prompt for Drizzle ORM configuration."""
print("\nDrizzle Provider:")
print("1. PostgreSQL")
print("2. MySQL")
print("3. SQLite")
provider = input("Select (1-3): ").strip()
provider_map = {"1": "pg", "2": "mysql", "3": "sqlite"}
provider_name = provider_map.get(provider, "pg")
return {
"type": "drizzle",
"provider": provider_name,
"import": "import { drizzleAdapter } from 'better-auth/adapters/drizzle';\nimport { db } from '@/db';",
"config": f"database: drizzleAdapter(db, {{ provider: '{provider_name}' }})"
}
def _prompt_prisma(self) -> Dict[str, Any]:
"""Prompt for Prisma configuration."""
print("\nPrisma Provider:")
print("1. PostgreSQL")
print("2. MySQL")
print("3. SQLite")
provider = input("Select (1-3): ").strip()
provider_map = {"1": "postgresql", "2": "mysql", "3": "sqlite"}
provider_name = provider_map.get(provider, "postgresql")
return {
"type": "prisma",
"provider": provider_name,
"import": "import { prismaAdapter } from 'better-auth/adapters/prisma';\nimport { PrismaClient } from '@prisma/client';\n\nconst prisma = new PrismaClient();",
"config": f"database: prismaAdapter(prisma, {{ provider: '{provider_name}' }})"
}
def _prompt_kysely(self) -> Dict[str, Any]:
"""Prompt for Kysely configuration."""
return {
"type": "kysely",
"import": "import { kyselyAdapter } from 'better-auth/adapters/kysely';\nimport { db } from '@/db';",
"config": "database: kyselyAdapter(db, { provider: 'pg' })"
}
def _prompt_mongodb(self) -> Dict[str, Any]:
"""Prompt for MongoDB configuration."""
mongo_uri = input("MongoDB connection string: ").strip()
db_name = input("Database name: ").strip()
return {
"type": "mongodb",
"import": "import { mongodbAdapter } from 'better-auth/adapters/mongodb';\nimport { client } from '@/db';",
"config": f"database: mongodbAdapter(client, {{ databaseName: '{db_name}' }})",
"env_var": ("MONGODB_URI", mongo_uri)
}
def prompt_auth_methods(self) -> List[str]:
"""
Prompt user for authentication methods.
Returns:
List of selected auth method codes.
"""
print("\nAuthentication Methods")
print("=" * 50)
print("Select authentication methods (space-separated, e.g., '1 2 3'):")
print("1. Email/Password")
print("2. GitHub OAuth")
print("3. Google OAuth")
print("4. Discord OAuth")
print("5. Two-Factor Authentication (2FA)")
print("6. Passkeys (WebAuthn)")
print("7. Magic Link")
print("8. Username")
choices = input("\nYour selection: ").strip().split()
return [c for c in choices if c in "12345678"]
def generate_auth_config(
self,
db_config: Dict[str, Any],
auth_methods: List[str],
) -> str:
"""
Generate auth.ts configuration file content.
Args:
db_config: Database configuration.
auth_methods: Selected authentication methods.
Returns:
Generated TypeScript configuration code.
"""
imports = ["import { betterAuth } from 'better-auth';"]
plugins = []
plugin_imports = []
config_parts = []
# Database import
if db_config.get("import"):
imports.append(db_config["import"])
# Email/Password
if "1" in auth_methods:
config_parts.append(""" emailAndPassword: {
enabled: true,
autoSignIn: true
}""")
# OAuth providers
social_providers = []
if "2" in auth_methods:
social_providers.append(""" github: {
clientId: process.env.GITHUB_CLIENT_ID!,
clientSecret: process.env.GITHUB_CLIENT_SECRET!,
}""")
if "3" in auth_methods:
social_providers.append(""" google: {
clientId: process.env.GOOGLE_CLIENT_ID!,
clientSecret: process.env.GOOGLE_CLIENT_SECRET!,
}""")
if "4" in auth_methods:
social_providers.append(""" discord: {
clientId: process.env.DISCORD_CLIENT_ID!,
clientSecret: process.env.DISCORD_CLIENT_SECRET!,
}""")
if social_providers:
config_parts.append(f" socialProviders: {{\n{',\\n'.join(social_providers)}\n }}")
# Plugins
if "5" in auth_methods:
plugin_imports.append("import { twoFactor } from 'better-auth/plugins';")
plugins.append("twoFactor()")
if "6" in auth_methods:
plugin_imports.append("import { passkey } from 'better-auth/plugins';")
plugins.append("passkey()")
if "7" in auth_methods:
plugin_imports.append("import { magicLink } from 'better-auth/plugins';")
plugins.append("""magicLink({
sendMagicLink: async ({ email, url }) => {
// TODO: Implement email sending
console.log(`Magic link for ${email}: ${url}`);
}
})""")
if "8" in auth_methods:
plugin_imports.append("import { username } from 'better-auth/plugins';")
plugins.append("username()")
# Combine all imports
all_imports = imports + plugin_imports
# Build config
config_body = ",\n".join(config_parts)
if plugins:
plugins_str = ",\n ".join(plugins)
config_body += f",\n plugins: [\n {plugins_str}\n ]"
# Final output
return f"""{chr(10).join(all_imports)}
export const auth = betterAuth({{
{db_config["config"]},
{config_body}
}});
"""
def generate_env_file(
self,
db_config: Dict[str, Any],
auth_methods: List[str]
) -> str:
"""
Generate .env file content.
Args:
db_config: Database configuration.
auth_methods: Selected authentication methods.
Returns:
Generated .env file content.
"""
env_vars = [
f"BETTER_AUTH_SECRET={self.generate_secret()}",
"BETTER_AUTH_URL=http://localhost:3000",
]
# Database URL
if db_config.get("env_var"):
key, value = db_config["env_var"]
env_vars.append(f"{key}={value}")
# OAuth credentials
if "2" in auth_methods:
env_vars.extend([
"GITHUB_CLIENT_ID=your_github_client_id",
"GITHUB_CLIENT_SECRET=your_github_client_secret",
])
if "3" in auth_methods:
env_vars.extend([
"GOOGLE_CLIENT_ID=your_google_client_id",
"GOOGLE_CLIENT_SECRET=your_google_client_secret",
])
if "4" in auth_methods:
env_vars.extend([
"DISCORD_CLIENT_ID=your_discord_client_id",
"DISCORD_CLIENT_SECRET=your_discord_client_secret",
])
return "\n".join(env_vars) + "\n"
def run(self) -> None:
"""Run interactive initialization."""
print("=" * 50)
print("Better Auth Configuration Generator")
print("=" * 50)
# Load existing env
env_vars = self._load_env_files()
# Prompt for configuration
db_config = self.prompt_database()
auth_methods = self.prompt_auth_methods()
# Generate files
auth_config = self.generate_auth_config(db_config, auth_methods)
env_content = self.generate_env_file(db_config, auth_methods)
# Display output
print("\n" + "=" * 50)
print("Generated Configuration")
print("=" * 50)
print("\n--- auth.ts ---")
print(auth_config)
print("\n--- .env ---")
print(env_content)
# Offer to save
save = input("\nSave configuration files? (y/N): ").strip().lower()
if save == "y":
self._save_files(auth_config, env_content)
else:
print("Configuration not saved.")
def _save_files(self, auth_config: str, env_content: str) -> None:
"""
Save generated configuration files.
Args:
auth_config: auth.ts content.
env_content: .env content.
"""
# Save auth.ts
auth_locations = [
self.project_root / "lib" / "auth.ts",
self.project_root / "src" / "lib" / "auth.ts",
self.project_root / "utils" / "auth.ts",
self.project_root / "auth.ts",
]
print("\nWhere to save auth.ts?")
for i, loc in enumerate(auth_locations, 1):
print(f"{i}. {loc}")
print("5. Custom path")
choice = input("Select (1-5): ").strip()
if choice == "5":
custom_path = input("Enter path: ").strip()
auth_path = Path(custom_path)
else:
idx = int(choice) - 1 if choice.isdigit() else 0
auth_path = auth_locations[idx]
auth_path.parent.mkdir(parents=True, exist_ok=True)
auth_path.write_text(auth_config)
print(f"Saved: {auth_path}")
# Save .env
env_path = self.project_root / ".env"
if env_path.exists():
backup = self.project_root / ".env.backup"
env_path.rename(backup)
print(f"Backed up existing .env to {backup}")
env_path.write_text(env_content)
print(f"Saved: {env_path}")
print("\nNext steps:")
print("1. Run: npx @better-auth/cli generate")
print("2. Apply database migrations")
print("3. Mount API handler in your framework")
print("4. Create client instance")
def main() -> int:
"""
Main entry point.
Returns:
Exit code (0 for success, 1 for error).
"""
try:
initializer = BetterAuthInit()
initializer.run()
return 0
except KeyboardInterrupt:
print("\n\nOperation cancelled.")
return 1
except Exception as e:
print(f"\nError: {e}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,15 @@
# Better Auth Skill Dependencies
# Python 3.10+ required
# No Python package dependencies - uses only standard library
# Testing dependencies (dev)
pytest>=8.0.0
pytest-cov>=4.1.0
pytest-mock>=3.12.0
# Note: This script generates Better Auth configuration
# The actual Better Auth library is installed via npm/pnpm/yarn:
# npm install better-auth
# pnpm add better-auth
# yarn add better-auth

View File

@@ -0,0 +1,421 @@
"""
Tests for better_auth_init.py
Covers main functionality with mocked I/O and file operations.
Target: >80% coverage
"""
import sys
import pytest
from pathlib import Path
from unittest.mock import Mock, patch, mock_open, MagicMock
from io import StringIO
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from better_auth_init import BetterAuthInit, EnvConfig, main
@pytest.fixture
def mock_project_root(tmp_path):
"""Create mock project root with package.json."""
(tmp_path / "package.json").write_text("{}")
return tmp_path
@pytest.fixture
def auth_init(mock_project_root):
"""Create BetterAuthInit instance with mock project root."""
return BetterAuthInit(project_root=mock_project_root)
class TestBetterAuthInit:
"""Test BetterAuthInit class."""
def test_init_with_project_root(self, mock_project_root):
"""Test initialization with explicit project root."""
init = BetterAuthInit(project_root=mock_project_root)
assert init.project_root == mock_project_root
assert init.env_config is None
def test_find_project_root_success(self, mock_project_root, monkeypatch):
"""Test finding project root successfully."""
monkeypatch.chdir(mock_project_root)
init = BetterAuthInit()
assert init.project_root == mock_project_root
def test_find_project_root_failure(self, tmp_path, monkeypatch):
"""Test failure to find project root."""
# Create path without package.json
no_package_dir = tmp_path / "no-package"
no_package_dir.mkdir()
monkeypatch.chdir(no_package_dir)
# Mock parent to stop infinite loop
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
with pytest.raises(RuntimeError, match="Could not find project root"):
BetterAuthInit()
def test_generate_secret(self):
"""Test secret generation."""
secret = BetterAuthInit.generate_secret()
assert len(secret) == 64 # 32 bytes = 64 hex chars
assert all(c in "0123456789abcdef" for c in secret)
# Test custom length
secret = BetterAuthInit.generate_secret(length=16)
assert len(secret) == 32 # 16 bytes = 32 hex chars
def test_parse_env_file(self, tmp_path):
"""Test parsing .env file."""
env_content = """
# Comment
KEY1=value1
KEY2="value2"
KEY3='value3'
INVALID LINE
KEY4=value=with=equals
"""
env_file = tmp_path / ".env"
env_file.write_text(env_content)
result = BetterAuthInit._parse_env_file(env_file)
assert result["KEY1"] == "value1"
assert result["KEY2"] == "value2"
assert result["KEY3"] == "value3"
assert result["KEY4"] == "value=with=equals"
assert "INVALID" not in result
def test_parse_env_file_missing(self, tmp_path):
"""Test parsing missing .env file."""
result = BetterAuthInit._parse_env_file(tmp_path / "nonexistent.env")
assert result == {}
def test_load_env_files(self, auth_init, mock_project_root):
"""Test loading environment variables from multiple files."""
# Create .env files
claude_env = mock_project_root / ".claude" / ".env"
claude_env.parent.mkdir(parents=True, exist_ok=True)
claude_env.write_text("BASE_VAR=base\nOVERRIDE=claude")
skills_env = mock_project_root / ".claude" / "skills" / ".env"
skills_env.parent.mkdir(parents=True, exist_ok=True)
skills_env.write_text("OVERRIDE=skills\nSKILLS_VAR=skills")
# Mock process env (highest priority)
with patch.dict("os.environ", {"OVERRIDE": "process", "PROCESS_VAR": "process"}):
result = auth_init._load_env_files()
assert result["BASE_VAR"] == "base"
assert result["SKILLS_VAR"] == "skills"
assert result["OVERRIDE"] == "process" # Process env wins
assert result["PROCESS_VAR"] == "process"
def test_prompt_direct_db_sqlite(self, auth_init):
"""Test prompting for SQLite database."""
with patch("builtins.input", side_effect=["3", "./test.db"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "sqlite"
assert "better-sqlite3" in config["import"]
assert "./test.db" in config["config"]
def test_prompt_direct_db_postgresql(self, auth_init):
"""Test prompting for PostgreSQL database."""
with patch("builtins.input", side_effect=["1", "postgresql://localhost/test"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "postgresql"
assert "pg" in config["import"]
assert config["env_var"] == ("DATABASE_URL", "postgresql://localhost/test")
def test_prompt_direct_db_mysql(self, auth_init):
"""Test prompting for MySQL database."""
with patch("builtins.input", side_effect=["2", "mysql://localhost/test"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "mysql"
assert "mysql2" in config["import"]
assert config["env_var"][0] == "DATABASE_URL"
def test_prompt_drizzle(self, auth_init):
"""Test prompting for Drizzle ORM."""
with patch("builtins.input", return_value="1"):
config = auth_init._prompt_drizzle()
assert config["type"] == "drizzle"
assert config["provider"] == "pg"
assert "drizzleAdapter" in config["import"]
assert "drizzleAdapter" in config["config"]
def test_prompt_prisma(self, auth_init):
"""Test prompting for Prisma."""
with patch("builtins.input", return_value="2"):
config = auth_init._prompt_prisma()
assert config["type"] == "prisma"
assert config["provider"] == "mysql"
assert "prismaAdapter" in config["import"]
assert "PrismaClient" in config["import"]
def test_prompt_kysely(self, auth_init):
"""Test prompting for Kysely."""
config = auth_init._prompt_kysely()
assert config["type"] == "kysely"
assert "kyselyAdapter" in config["import"]
def test_prompt_mongodb(self, auth_init):
"""Test prompting for MongoDB."""
with patch("builtins.input", side_effect=["mongodb://localhost/test", "mydb"]):
config = auth_init._prompt_mongodb()
assert config["type"] == "mongodb"
assert "mongodbAdapter" in config["import"]
assert "mydb" in config["config"]
assert config["env_var"] == ("MONGODB_URI", "mongodb://localhost/test")
def test_prompt_database(self, auth_init):
"""Test database prompting with different choices."""
# Test valid choice
with patch("builtins.input", side_effect=["3", "1"]):
config = auth_init.prompt_database()
assert config["type"] == "prisma"
# Test invalid choice (defaults to direct DB)
with patch("builtins.input", side_effect=["99", "1", "postgresql://localhost/test"]):
with patch("builtins.print"):
config = auth_init.prompt_database()
assert config["type"] == "postgresql"
def test_prompt_auth_methods(self, auth_init):
"""Test prompting for authentication methods."""
with patch("builtins.input", return_value="1 2 3 5 8"):
with patch("builtins.print"):
methods = auth_init.prompt_auth_methods()
assert methods == ["1", "2", "3", "5", "8"]
def test_prompt_auth_methods_invalid(self, auth_init):
"""Test filtering invalid auth method choices."""
with patch("builtins.input", return_value="1 99 abc 3"):
with patch("builtins.print"):
methods = auth_init.prompt_auth_methods()
assert methods == ["1", "3"]
def test_generate_auth_config_basic(self, auth_init):
"""Test generating basic auth config."""
db_config = {
"import": "import Database from 'better-sqlite3';",
"config": "database: new Database('./dev.db')"
}
auth_methods = ["1"] # Email/password only
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "import { betterAuth }" in config
assert "emailAndPassword" in config
assert "enabled: true" in config
assert "better-sqlite3" in config
def test_generate_auth_config_with_oauth(self, auth_init):
"""Test generating config with OAuth providers."""
db_config = {
"import": "import { Pool } from 'pg';",
"config": "database: new Pool()"
}
auth_methods = ["1", "2", "3", "4"] # Email + GitHub + Google + Discord
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "socialProviders" in config
assert "github:" in config
assert "google:" in config
assert "discord:" in config
assert "GITHUB_CLIENT_ID" in config
assert "GOOGLE_CLIENT_ID" in config
assert "DISCORD_CLIENT_ID" in config
def test_generate_auth_config_with_plugins(self, auth_init):
"""Test generating config with plugins."""
db_config = {"import": "", "config": "database: db"}
auth_methods = ["5", "6", "7", "8"] # 2FA, Passkey, Magic Link, Username
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "plugins:" in config
assert "twoFactor" in config
assert "passkey" in config
assert "magicLink" in config
assert "username" in config
assert "from 'better-auth/plugins'" in config
def test_generate_env_file_basic(self, auth_init):
"""Test generating basic .env file."""
db_config = {"type": "sqlite"}
auth_methods = ["1"]
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "BETTER_AUTH_SECRET=" in env_content
assert "BETTER_AUTH_URL=http://localhost:3000" in env_content
assert len(env_content.split("\n")) >= 2
def test_generate_env_file_with_database_url(self, auth_init):
"""Test generating .env with database URL."""
db_config = {
"env_var": ("DATABASE_URL", "postgresql://localhost/test")
}
auth_methods = []
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "DATABASE_URL=postgresql://localhost/test" in env_content
def test_generate_env_file_with_oauth(self, auth_init):
"""Test generating .env with OAuth credentials."""
db_config = {}
auth_methods = ["2", "3", "4"] # GitHub, Google, Discord
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "GITHUB_CLIENT_ID=" in env_content
assert "GITHUB_CLIENT_SECRET=" in env_content
assert "GOOGLE_CLIENT_ID=" in env_content
assert "GOOGLE_CLIENT_SECRET=" in env_content
assert "DISCORD_CLIENT_ID=" in env_content
assert "DISCORD_CLIENT_SECRET=" in env_content
def test_save_files(self, auth_init, mock_project_root):
"""Test saving configuration files."""
auth_config = "// auth config"
env_content = "SECRET=test"
with patch("builtins.input", side_effect=["1"]):
auth_init._save_files(auth_config, env_content)
# Check auth.ts was saved
auth_path = mock_project_root / "lib" / "auth.ts"
assert auth_path.exists()
assert auth_path.read_text() == auth_config
# Check .env was saved
env_path = mock_project_root / ".env"
assert env_path.exists()
assert env_path.read_text() == env_content
def test_save_files_custom_path(self, auth_init, mock_project_root):
"""Test saving with custom path."""
auth_config = "// config"
env_content = "SECRET=test"
custom_path = str(mock_project_root / "custom" / "auth.ts")
with patch("builtins.input", side_effect=["5", custom_path]):
auth_init._save_files(auth_config, env_content)
assert Path(custom_path).exists()
def test_save_files_backup_existing_env(self, auth_init, mock_project_root):
"""Test backing up existing .env file."""
# Create existing .env
env_path = mock_project_root / ".env"
env_path.write_text("OLD_SECRET=old")
auth_config = "// config"
env_content = "NEW_SECRET=new"
with patch("builtins.input", return_value="1"):
auth_init._save_files(auth_config, env_content)
# Check backup was created
backup_path = mock_project_root / ".env.backup"
assert backup_path.exists()
assert backup_path.read_text() == "OLD_SECRET=old"
# Check new .env
assert env_path.read_text() == "NEW_SECRET=new"
def test_run_full_flow(self, auth_init, mock_project_root):
"""Test complete run flow."""
inputs = [
"1", # Direct DB
"1", # PostgreSQL
"postgresql://localhost/test",
"1 2", # Email + GitHub
"n" # Don't save
]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
auth_init.run()
# Should complete without errors
# Files not saved because user chose 'n'
assert not (mock_project_root / "auth.ts").exists()
def test_run_save_files(self, auth_init, mock_project_root):
"""Test run flow with file saving."""
inputs = [
"1", # Direct DB
"3", # SQLite
"", # Default path
"1", # Email only
"y", # Save
"1" # Save location
]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
auth_init.run()
# Check files were created
assert (mock_project_root / "lib" / "auth.ts").exists()
assert (mock_project_root / ".env").exists()
class TestMainFunction:
"""Test main entry point."""
def test_main_success(self, tmp_path, monkeypatch):
"""Test successful main execution."""
(tmp_path / "package.json").write_text("{}")
monkeypatch.chdir(tmp_path)
inputs = ["1", "3", "", "1", "n"]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
exit_code = main()
assert exit_code == 0
def test_main_keyboard_interrupt(self, tmp_path, monkeypatch):
"""Test main with keyboard interrupt."""
(tmp_path / "package.json").write_text("{}")
monkeypatch.chdir(tmp_path)
with patch("builtins.input", side_effect=KeyboardInterrupt()):
with patch("builtins.print"):
exit_code = main()
assert exit_code == 1
def test_main_error(self, tmp_path, monkeypatch):
"""Test main with error."""
# No package.json - should fail
no_package = tmp_path / "no-package"
no_package.mkdir()
monkeypatch.chdir(no_package)
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
with patch("sys.stderr", new_callable=StringIO):
exit_code = main()
assert exit_code == 1
if __name__ == "__main__":
pytest.main([__file__, "-v", "--cov=better_auth_init", "--cov-report=term-missing"])