Initial commit

This commit is contained in:
Zhongwei Li
2025-11-30 08:48:52 +08:00
commit 6ec3196ecc
434 changed files with 125248 additions and 0 deletions

View File

@@ -0,0 +1,421 @@
"""
Tests for better_auth_init.py
Covers main functionality with mocked I/O and file operations.
Target: >80% coverage
"""
import sys
import pytest
from pathlib import Path
from unittest.mock import Mock, patch, mock_open, MagicMock
from io import StringIO
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from better_auth_init import BetterAuthInit, EnvConfig, main
@pytest.fixture
def mock_project_root(tmp_path):
"""Create mock project root with package.json."""
(tmp_path / "package.json").write_text("{}")
return tmp_path
@pytest.fixture
def auth_init(mock_project_root):
"""Create BetterAuthInit instance with mock project root."""
return BetterAuthInit(project_root=mock_project_root)
class TestBetterAuthInit:
"""Test BetterAuthInit class."""
def test_init_with_project_root(self, mock_project_root):
"""Test initialization with explicit project root."""
init = BetterAuthInit(project_root=mock_project_root)
assert init.project_root == mock_project_root
assert init.env_config is None
def test_find_project_root_success(self, mock_project_root, monkeypatch):
"""Test finding project root successfully."""
monkeypatch.chdir(mock_project_root)
init = BetterAuthInit()
assert init.project_root == mock_project_root
def test_find_project_root_failure(self, tmp_path, monkeypatch):
"""Test failure to find project root."""
# Create path without package.json
no_package_dir = tmp_path / "no-package"
no_package_dir.mkdir()
monkeypatch.chdir(no_package_dir)
# Mock parent to stop infinite loop
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
with pytest.raises(RuntimeError, match="Could not find project root"):
BetterAuthInit()
def test_generate_secret(self):
"""Test secret generation."""
secret = BetterAuthInit.generate_secret()
assert len(secret) == 64 # 32 bytes = 64 hex chars
assert all(c in "0123456789abcdef" for c in secret)
# Test custom length
secret = BetterAuthInit.generate_secret(length=16)
assert len(secret) == 32 # 16 bytes = 32 hex chars
def test_parse_env_file(self, tmp_path):
"""Test parsing .env file."""
env_content = """
# Comment
KEY1=value1
KEY2="value2"
KEY3='value3'
INVALID LINE
KEY4=value=with=equals
"""
env_file = tmp_path / ".env"
env_file.write_text(env_content)
result = BetterAuthInit._parse_env_file(env_file)
assert result["KEY1"] == "value1"
assert result["KEY2"] == "value2"
assert result["KEY3"] == "value3"
assert result["KEY4"] == "value=with=equals"
assert "INVALID" not in result
def test_parse_env_file_missing(self, tmp_path):
"""Test parsing missing .env file."""
result = BetterAuthInit._parse_env_file(tmp_path / "nonexistent.env")
assert result == {}
def test_load_env_files(self, auth_init, mock_project_root):
"""Test loading environment variables from multiple files."""
# Create .env files
claude_env = mock_project_root / ".claude" / ".env"
claude_env.parent.mkdir(parents=True, exist_ok=True)
claude_env.write_text("BASE_VAR=base\nOVERRIDE=claude")
skills_env = mock_project_root / ".claude" / "skills" / ".env"
skills_env.parent.mkdir(parents=True, exist_ok=True)
skills_env.write_text("OVERRIDE=skills\nSKILLS_VAR=skills")
# Mock process env (highest priority)
with patch.dict("os.environ", {"OVERRIDE": "process", "PROCESS_VAR": "process"}):
result = auth_init._load_env_files()
assert result["BASE_VAR"] == "base"
assert result["SKILLS_VAR"] == "skills"
assert result["OVERRIDE"] == "process" # Process env wins
assert result["PROCESS_VAR"] == "process"
def test_prompt_direct_db_sqlite(self, auth_init):
"""Test prompting for SQLite database."""
with patch("builtins.input", side_effect=["3", "./test.db"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "sqlite"
assert "better-sqlite3" in config["import"]
assert "./test.db" in config["config"]
def test_prompt_direct_db_postgresql(self, auth_init):
"""Test prompting for PostgreSQL database."""
with patch("builtins.input", side_effect=["1", "postgresql://localhost/test"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "postgresql"
assert "pg" in config["import"]
assert config["env_var"] == ("DATABASE_URL", "postgresql://localhost/test")
def test_prompt_direct_db_mysql(self, auth_init):
"""Test prompting for MySQL database."""
with patch("builtins.input", side_effect=["2", "mysql://localhost/test"]):
config = auth_init._prompt_direct_db()
assert config["type"] == "mysql"
assert "mysql2" in config["import"]
assert config["env_var"][0] == "DATABASE_URL"
def test_prompt_drizzle(self, auth_init):
"""Test prompting for Drizzle ORM."""
with patch("builtins.input", return_value="1"):
config = auth_init._prompt_drizzle()
assert config["type"] == "drizzle"
assert config["provider"] == "pg"
assert "drizzleAdapter" in config["import"]
assert "drizzleAdapter" in config["config"]
def test_prompt_prisma(self, auth_init):
"""Test prompting for Prisma."""
with patch("builtins.input", return_value="2"):
config = auth_init._prompt_prisma()
assert config["type"] == "prisma"
assert config["provider"] == "mysql"
assert "prismaAdapter" in config["import"]
assert "PrismaClient" in config["import"]
def test_prompt_kysely(self, auth_init):
"""Test prompting for Kysely."""
config = auth_init._prompt_kysely()
assert config["type"] == "kysely"
assert "kyselyAdapter" in config["import"]
def test_prompt_mongodb(self, auth_init):
"""Test prompting for MongoDB."""
with patch("builtins.input", side_effect=["mongodb://localhost/test", "mydb"]):
config = auth_init._prompt_mongodb()
assert config["type"] == "mongodb"
assert "mongodbAdapter" in config["import"]
assert "mydb" in config["config"]
assert config["env_var"] == ("MONGODB_URI", "mongodb://localhost/test")
def test_prompt_database(self, auth_init):
"""Test database prompting with different choices."""
# Test valid choice
with patch("builtins.input", side_effect=["3", "1"]):
config = auth_init.prompt_database()
assert config["type"] == "prisma"
# Test invalid choice (defaults to direct DB)
with patch("builtins.input", side_effect=["99", "1", "postgresql://localhost/test"]):
with patch("builtins.print"):
config = auth_init.prompt_database()
assert config["type"] == "postgresql"
def test_prompt_auth_methods(self, auth_init):
"""Test prompting for authentication methods."""
with patch("builtins.input", return_value="1 2 3 5 8"):
with patch("builtins.print"):
methods = auth_init.prompt_auth_methods()
assert methods == ["1", "2", "3", "5", "8"]
def test_prompt_auth_methods_invalid(self, auth_init):
"""Test filtering invalid auth method choices."""
with patch("builtins.input", return_value="1 99 abc 3"):
with patch("builtins.print"):
methods = auth_init.prompt_auth_methods()
assert methods == ["1", "3"]
def test_generate_auth_config_basic(self, auth_init):
"""Test generating basic auth config."""
db_config = {
"import": "import Database from 'better-sqlite3';",
"config": "database: new Database('./dev.db')"
}
auth_methods = ["1"] # Email/password only
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "import { betterAuth }" in config
assert "emailAndPassword" in config
assert "enabled: true" in config
assert "better-sqlite3" in config
def test_generate_auth_config_with_oauth(self, auth_init):
"""Test generating config with OAuth providers."""
db_config = {
"import": "import { Pool } from 'pg';",
"config": "database: new Pool()"
}
auth_methods = ["1", "2", "3", "4"] # Email + GitHub + Google + Discord
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "socialProviders" in config
assert "github:" in config
assert "google:" in config
assert "discord:" in config
assert "GITHUB_CLIENT_ID" in config
assert "GOOGLE_CLIENT_ID" in config
assert "DISCORD_CLIENT_ID" in config
def test_generate_auth_config_with_plugins(self, auth_init):
"""Test generating config with plugins."""
db_config = {"import": "", "config": "database: db"}
auth_methods = ["5", "6", "7", "8"] # 2FA, Passkey, Magic Link, Username
config = auth_init.generate_auth_config(db_config, auth_methods)
assert "plugins:" in config
assert "twoFactor" in config
assert "passkey" in config
assert "magicLink" in config
assert "username" in config
assert "from 'better-auth/plugins'" in config
def test_generate_env_file_basic(self, auth_init):
"""Test generating basic .env file."""
db_config = {"type": "sqlite"}
auth_methods = ["1"]
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "BETTER_AUTH_SECRET=" in env_content
assert "BETTER_AUTH_URL=http://localhost:3000" in env_content
assert len(env_content.split("\n")) >= 2
def test_generate_env_file_with_database_url(self, auth_init):
"""Test generating .env with database URL."""
db_config = {
"env_var": ("DATABASE_URL", "postgresql://localhost/test")
}
auth_methods = []
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "DATABASE_URL=postgresql://localhost/test" in env_content
def test_generate_env_file_with_oauth(self, auth_init):
"""Test generating .env with OAuth credentials."""
db_config = {}
auth_methods = ["2", "3", "4"] # GitHub, Google, Discord
env_content = auth_init.generate_env_file(db_config, auth_methods)
assert "GITHUB_CLIENT_ID=" in env_content
assert "GITHUB_CLIENT_SECRET=" in env_content
assert "GOOGLE_CLIENT_ID=" in env_content
assert "GOOGLE_CLIENT_SECRET=" in env_content
assert "DISCORD_CLIENT_ID=" in env_content
assert "DISCORD_CLIENT_SECRET=" in env_content
def test_save_files(self, auth_init, mock_project_root):
"""Test saving configuration files."""
auth_config = "// auth config"
env_content = "SECRET=test"
with patch("builtins.input", side_effect=["1"]):
auth_init._save_files(auth_config, env_content)
# Check auth.ts was saved
auth_path = mock_project_root / "lib" / "auth.ts"
assert auth_path.exists()
assert auth_path.read_text() == auth_config
# Check .env was saved
env_path = mock_project_root / ".env"
assert env_path.exists()
assert env_path.read_text() == env_content
def test_save_files_custom_path(self, auth_init, mock_project_root):
"""Test saving with custom path."""
auth_config = "// config"
env_content = "SECRET=test"
custom_path = str(mock_project_root / "custom" / "auth.ts")
with patch("builtins.input", side_effect=["5", custom_path]):
auth_init._save_files(auth_config, env_content)
assert Path(custom_path).exists()
def test_save_files_backup_existing_env(self, auth_init, mock_project_root):
"""Test backing up existing .env file."""
# Create existing .env
env_path = mock_project_root / ".env"
env_path.write_text("OLD_SECRET=old")
auth_config = "// config"
env_content = "NEW_SECRET=new"
with patch("builtins.input", return_value="1"):
auth_init._save_files(auth_config, env_content)
# Check backup was created
backup_path = mock_project_root / ".env.backup"
assert backup_path.exists()
assert backup_path.read_text() == "OLD_SECRET=old"
# Check new .env
assert env_path.read_text() == "NEW_SECRET=new"
def test_run_full_flow(self, auth_init, mock_project_root):
"""Test complete run flow."""
inputs = [
"1", # Direct DB
"1", # PostgreSQL
"postgresql://localhost/test",
"1 2", # Email + GitHub
"n" # Don't save
]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
auth_init.run()
# Should complete without errors
# Files not saved because user chose 'n'
assert not (mock_project_root / "auth.ts").exists()
def test_run_save_files(self, auth_init, mock_project_root):
"""Test run flow with file saving."""
inputs = [
"1", # Direct DB
"3", # SQLite
"", # Default path
"1", # Email only
"y", # Save
"1" # Save location
]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
auth_init.run()
# Check files were created
assert (mock_project_root / "lib" / "auth.ts").exists()
assert (mock_project_root / ".env").exists()
class TestMainFunction:
"""Test main entry point."""
def test_main_success(self, tmp_path, monkeypatch):
"""Test successful main execution."""
(tmp_path / "package.json").write_text("{}")
monkeypatch.chdir(tmp_path)
inputs = ["1", "3", "", "1", "n"]
with patch("builtins.input", side_effect=inputs):
with patch("builtins.print"):
exit_code = main()
assert exit_code == 0
def test_main_keyboard_interrupt(self, tmp_path, monkeypatch):
"""Test main with keyboard interrupt."""
(tmp_path / "package.json").write_text("{}")
monkeypatch.chdir(tmp_path)
with patch("builtins.input", side_effect=KeyboardInterrupt()):
with patch("builtins.print"):
exit_code = main()
assert exit_code == 1
def test_main_error(self, tmp_path, monkeypatch):
"""Test main with error."""
# No package.json - should fail
no_package = tmp_path / "no-package"
no_package.mkdir()
monkeypatch.chdir(no_package)
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
with patch("sys.stderr", new_callable=StringIO):
exit_code = main()
assert exit_code == 1
if __name__ == "__main__":
pytest.main([__file__, "-v", "--cov=better_auth_init", "--cov-report=term-missing"])