Initial commit

This commit is contained in:
Zhongwei Li
2025-11-29 18:47:40 +08:00
commit 14c678ceac
22 changed files with 7501 additions and 0 deletions

628
scripts/utils/helpers.py Normal file
View File

@@ -0,0 +1,628 @@
#!/usr/bin/env python3
"""
Helper utilities for Tailscale SSH Sync Agent.
Provides common formatting, parsing, and utility functions.
"""
import os
import re
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import yaml
import logging
logger = logging.getLogger(__name__)
def format_bytes(bytes_value: int) -> str:
"""
Format bytes as human-readable string.
Args:
bytes_value: Number of bytes
Returns:
Formatted string (e.g., "12.3 MB", "1.5 GB")
Example:
>>> format_bytes(12582912)
"12.0 MB"
>>> format_bytes(1610612736)
"1.5 GB"
"""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_value < 1024.0:
return f"{bytes_value:.1f} {unit}"
bytes_value /= 1024.0
return f"{bytes_value:.1f} PB"
def format_duration(seconds: float) -> str:
"""
Format duration as human-readable string.
Args:
seconds: Duration in seconds
Returns:
Formatted string (e.g., "2m 15s", "1h 30m")
Example:
>>> format_duration(135)
"2m 15s"
>>> format_duration(5430)
"1h 30m 30s"
"""
if seconds < 60:
return f"{int(seconds)}s"
minutes = int(seconds // 60)
secs = int(seconds % 60)
if minutes < 60:
return f"{minutes}m {secs}s" if secs > 0 else f"{minutes}m"
hours = minutes // 60
minutes = minutes % 60
parts = [f"{hours}h"]
if minutes > 0:
parts.append(f"{minutes}m")
if secs > 0 and hours == 0: # Only show seconds if < 1 hour
parts.append(f"{secs}s")
return " ".join(parts)
def format_percentage(value: float, decimals: int = 1) -> str:
"""
Format percentage with specified decimals.
Args:
value: Percentage value (0-100)
decimals: Number of decimal places
Returns:
Formatted string (e.g., "45.5%")
Example:
>>> format_percentage(45.567)
"45.6%"
"""
return f"{value:.{decimals}f}%"
def parse_ssh_config(config_path: Optional[Path] = None) -> Dict[str, Dict[str, str]]:
"""
Parse SSH config file for host definitions.
Args:
config_path: Path to SSH config (default: ~/.ssh/config)
Returns:
Dict mapping host aliases to their configuration:
{
'host-alias': {
'hostname': '100.64.1.10',
'user': 'admin',
'port': '22',
'identityfile': '~/.ssh/id_ed25519'
}
}
Example:
>>> hosts = parse_ssh_config()
>>> hosts['homelab-1']['hostname']
'100.64.1.10'
"""
if config_path is None:
config_path = Path.home() / '.ssh' / 'config'
if not config_path.exists():
logger.warning(f"SSH config not found: {config_path}")
return {}
hosts = {}
current_host = None
try:
with open(config_path, 'r') as f:
for line in f:
line = line.strip()
# Skip comments and empty lines
if not line or line.startswith('#'):
continue
# Host directive
if line.lower().startswith('host '):
host_alias = line.split(maxsplit=1)[1]
# Skip wildcards
if '*' not in host_alias and '?' not in host_alias:
current_host = host_alias
hosts[current_host] = {}
# Configuration directives
elif current_host:
parts = line.split(maxsplit=1)
if len(parts) == 2:
key, value = parts
hosts[current_host][key.lower()] = value
return hosts
except Exception as e:
logger.error(f"Error parsing SSH config: {e}")
return {}
def parse_sshsync_config(config_path: Optional[Path] = None) -> Dict[str, List[str]]:
"""
Parse sshsync config file for group definitions.
Args:
config_path: Path to sshsync config (default: ~/.config/sshsync/config.yaml)
Returns:
Dict mapping group names to list of hosts:
{
'production': ['prod-web-01', 'prod-db-01'],
'development': ['dev-laptop', 'dev-desktop']
}
Example:
>>> groups = parse_sshsync_config()
>>> groups['production']
['prod-web-01', 'prod-db-01']
"""
if config_path is None:
config_path = Path.home() / '.config' / 'sshsync' / 'config.yaml'
if not config_path.exists():
logger.warning(f"sshsync config not found: {config_path}")
return {}
try:
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
return config.get('groups', {})
except Exception as e:
logger.error(f"Error parsing sshsync config: {e}")
return {}
def get_timestamp(iso: bool = True) -> str:
"""
Get current timestamp.
Args:
iso: If True, return ISO format; otherwise human-readable
Returns:
Timestamp string
Example:
>>> get_timestamp(iso=True)
"2025-10-19T19:43:41Z"
>>> get_timestamp(iso=False)
"2025-10-19 19:43:41"
"""
now = datetime.now()
if iso:
return now.strftime("%Y-%m-%dT%H:%M:%SZ")
else:
return now.strftime("%Y-%m-%d %H:%M:%S")
def safe_execute(func, *args, default=None, **kwargs) -> Any:
"""
Execute function with error handling.
Args:
func: Function to execute
*args: Positional arguments
default: Value to return on error
**kwargs: Keyword arguments
Returns:
Function result or default on error
Example:
>>> safe_execute(int, "not_a_number", default=0)
0
>>> safe_execute(int, "42")
42
"""
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(f"Error executing {func.__name__}: {e}")
return default
def validate_path(path: str, must_exist: bool = True) -> bool:
"""
Check if path is valid and accessible.
Args:
path: Path to validate
must_exist: If True, path must exist
Returns:
True if valid, False otherwise
Example:
>>> validate_path("/tmp")
True
>>> validate_path("/nonexistent", must_exist=True)
False
"""
p = Path(path).expanduser()
if must_exist:
return p.exists()
else:
# Check if parent directory exists (for paths that will be created)
return p.parent.exists()
def parse_disk_usage(df_output: str) -> Dict[str, Any]:
"""
Parse 'df' command output.
Args:
df_output: Output from 'df -h' command
Returns:
Dict with disk usage info:
{
'filesystem': '/dev/sda1',
'size': '100G',
'used': '45G',
'available': '50G',
'use_pct': 45,
'mount': '/'
}
Example:
>>> output = "Filesystem Size Used Avail Use% Mounted on\\n/dev/sda1 100G 45G 50G 45% /"
>>> parse_disk_usage(output)
{'filesystem': '/dev/sda1', 'size': '100G', ...}
"""
lines = df_output.strip().split('\n')
if len(lines) < 2:
return {}
# Parse last line (actual data, not header)
data_line = lines[-1]
parts = data_line.split()
if len(parts) < 6:
return {}
try:
return {
'filesystem': parts[0],
'size': parts[1],
'used': parts[2],
'available': parts[3],
'use_pct': int(parts[4].rstrip('%')),
'mount': parts[5]
}
except (ValueError, IndexError) as e:
logger.error(f"Error parsing disk usage: {e}")
return {}
def parse_memory_usage(free_output: str) -> Dict[str, Any]:
"""
Parse 'free' command output (Linux).
Args:
free_output: Output from 'free -m' command
Returns:
Dict with memory info:
{
'total': 16384, # MB
'used': 8192,
'free': 8192,
'use_pct': 50.0
}
Example:
>>> output = "Mem: 16384 8192 8192 0 0 0"
>>> parse_memory_usage(output)
{'total': 16384, 'used': 8192, ...}
"""
lines = free_output.strip().split('\n')
for line in lines:
if line.startswith('Mem:'):
parts = line.split()
if len(parts) >= 3:
try:
total = int(parts[1])
used = int(parts[2])
free = int(parts[3]) if len(parts) > 3 else (total - used)
return {
'total': total,
'used': used,
'free': free,
'use_pct': (used / total * 100) if total > 0 else 0
}
except (ValueError, IndexError) as e:
logger.error(f"Error parsing memory usage: {e}")
return {}
def parse_cpu_load(uptime_output: str) -> Dict[str, float]:
"""
Parse 'uptime' command output for load averages.
Args:
uptime_output: Output from 'uptime' command
Returns:
Dict with load averages:
{
'load_1min': 0.45,
'load_5min': 0.38,
'load_15min': 0.32
}
Example:
>>> output = "19:43:41 up 5 days, 2:15, 3 users, load average: 0.45, 0.38, 0.32"
>>> parse_cpu_load(output)
{'load_1min': 0.45, 'load_5min': 0.38, 'load_15min': 0.32}
"""
# Find "load average:" part
match = re.search(r'load average:\s+([\d.]+),\s+([\d.]+),\s+([\d.]+)', uptime_output)
if match:
try:
return {
'load_1min': float(match.group(1)),
'load_5min': float(match.group(2)),
'load_15min': float(match.group(3))
}
except ValueError as e:
logger.error(f"Error parsing CPU load: {e}")
return {}
def format_host_status(host: str, online: bool, groups: List[str],
latency: Optional[int] = None,
tailscale_connected: bool = False) -> str:
"""
Format host status as display string.
Args:
host: Host name
online: Whether host is online
groups: List of groups host belongs to
latency: Latency in ms (optional)
tailscale_connected: Tailscale connection status
Returns:
Formatted status string
Example:
>>> format_host_status("web-01", True, ["production", "web"], 25, True)
"🟢 web-01 (production, web) - Online - Tailscale: Connected | Latency: 25ms"
"""
icon = "🟢" if online else "🔴"
status = "Online" if online else "Offline"
group_str = ", ".join(groups) if groups else "no group"
parts = [f"{icon} {host} ({group_str}) - {status}"]
if tailscale_connected:
parts.append("Tailscale: Connected")
if latency is not None and online:
parts.append(f"Latency: {latency}ms")
return " - ".join(parts)
def calculate_load_score(cpu_pct: float, mem_pct: float, disk_pct: float) -> float:
"""
Calculate composite load score for a machine.
Args:
cpu_pct: CPU usage percentage (0-100)
mem_pct: Memory usage percentage (0-100)
disk_pct: Disk usage percentage (0-100)
Returns:
Load score (0-1, lower is better)
Formula:
score = (cpu * 0.4) + (mem * 0.3) + (disk * 0.3)
Example:
>>> calculate_load_score(45, 60, 40)
0.48 # (0.45*0.4 + 0.60*0.3 + 0.40*0.3)
"""
return (cpu_pct * 0.4 + mem_pct * 0.3 + disk_pct * 0.3) / 100
def classify_load_status(score: float) -> str:
"""
Classify load score into status category.
Args:
score: Load score (0-1)
Returns:
Status string: "low", "moderate", or "high"
Example:
>>> classify_load_status(0.28)
"low"
>>> classify_load_status(0.55)
"moderate"
>>> classify_load_status(0.82)
"high"
"""
if score < 0.4:
return "low"
elif score < 0.7:
return "moderate"
else:
return "high"
def classify_latency(latency_ms: int) -> Tuple[str, str]:
"""
Classify network latency.
Args:
latency_ms: Latency in milliseconds
Returns:
Tuple of (status, description)
Example:
>>> classify_latency(25)
("excellent", "Ideal for interactive tasks")
>>> classify_latency(150)
("fair", "May impact interactive workflows")
"""
if latency_ms < 50:
return ("excellent", "Ideal for interactive tasks")
elif latency_ms < 100:
return ("good", "Suitable for most operations")
elif latency_ms < 200:
return ("fair", "May impact interactive workflows")
else:
return ("poor", "Investigate network issues")
def get_hosts_from_groups(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
"""
Get list of hosts in a group.
Args:
group: Group name
groups_config: Groups configuration dict
Returns:
List of host names in group
Example:
>>> groups = {'production': ['web-01', 'db-01']}
>>> get_hosts_from_groups('production', groups)
['web-01', 'db-01']
"""
return groups_config.get(group, [])
def get_groups_for_host(host: str, groups_config: Dict[str, List[str]]) -> List[str]:
"""
Get list of groups a host belongs to.
Args:
host: Host name
groups_config: Groups configuration dict
Returns:
List of group names
Example:
>>> groups = {'production': ['web-01'], 'web': ['web-01', 'web-02']}
>>> get_groups_for_host('web-01', groups)
['production', 'web']
"""
return [group for group, hosts in groups_config.items() if host in hosts]
def run_command(command: str, timeout: int = 10) -> Tuple[bool, str, str]:
"""
Run shell command with timeout.
Args:
command: Command to execute
timeout: Timeout in seconds
Returns:
Tuple of (success, stdout, stderr)
Example:
>>> success, stdout, stderr = run_command("echo hello")
>>> success
True
>>> stdout.strip()
"hello"
"""
try:
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=timeout
)
return (
result.returncode == 0,
result.stdout,
result.stderr
)
except subprocess.TimeoutExpired:
return (False, "", f"Command timed out after {timeout}s")
except Exception as e:
return (False, "", str(e))
def main():
"""Test helper functions."""
print("Testing helper functions...\n")
# Test formatting
print("1. Format bytes:")
print(f" 12582912 bytes = {format_bytes(12582912)}")
print(f" 1610612736 bytes = {format_bytes(1610612736)}")
print("\n2. Format duration:")
print(f" 135 seconds = {format_duration(135)}")
print(f" 5430 seconds = {format_duration(5430)}")
print("\n3. Format percentage:")
print(f" 45.567 = {format_percentage(45.567)}")
print("\n4. Calculate load score:")
score = calculate_load_score(45, 60, 40)
print(f" CPU 45%, Mem 60%, Disk 40% = {score:.2f}")
print(f" Status: {classify_load_status(score)}")
print("\n5. Classify latency:")
latencies = [25, 75, 150, 250]
for lat in latencies:
status, desc = classify_latency(lat)
print(f" {lat}ms: {status} - {desc}")
print("\n6. Parse SSH config:")
ssh_hosts = parse_ssh_config()
print(f" Found {len(ssh_hosts)} hosts")
print("\n7. Parse sshsync config:")
groups = parse_sshsync_config()
print(f" Found {len(groups)} groups")
for group, hosts in groups.items():
print(f" - {group}: {len(hosts)} hosts")
print("\n✅ All helpers tested successfully")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,43 @@
"""
Validators package for Tailscale SSH Sync Agent.
"""
from .parameter_validator import (
ValidationError,
validate_host,
validate_group,
validate_path_exists,
validate_timeout,
validate_command
)
from .host_validator import (
validate_ssh_config,
validate_host_reachable,
validate_group_members,
get_invalid_hosts
)
from .connection_validator import (
validate_ssh_connection,
validate_tailscale_connection,
validate_ssh_key,
get_connection_diagnostics
)
__all__ = [
'ValidationError',
'validate_host',
'validate_group',
'validate_path_exists',
'validate_timeout',
'validate_command',
'validate_ssh_config',
'validate_host_reachable',
'validate_group_members',
'get_invalid_hosts',
'validate_ssh_connection',
'validate_tailscale_connection',
'validate_ssh_key',
'get_connection_diagnostics',
]

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python3
"""
Connection validators for Tailscale SSH Sync Agent.
Validates SSH and Tailscale connections.
"""
import subprocess
from typing import Dict, Optional
import logging
from .parameter_validator import ValidationError
logger = logging.getLogger(__name__)
def validate_ssh_connection(host: str, timeout: int = 10) -> bool:
"""
Test SSH connection works.
Args:
host: Host to connect to
timeout: Connection timeout in seconds
Returns:
True if SSH connection successful
Raises:
ValidationError: If connection fails
Example:
>>> validate_ssh_connection("web-01")
True
"""
try:
# Try to execute a simple command via SSH
result = subprocess.run(
["ssh", "-o", "ConnectTimeout={}".format(timeout),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
host, "echo", "test"],
capture_output=True,
text=True,
timeout=timeout + 5
)
if result.returncode == 0:
return True
else:
# Parse error message
error_msg = result.stderr.strip()
if "Permission denied" in error_msg:
raise ValidationError(
f"SSH authentication failed for '{host}'\n"
"Check:\n"
"1. SSH key is added: ssh-add -l\n"
"2. Public key is on remote: cat ~/.ssh/authorized_keys\n"
"3. User/key in SSH config is correct"
)
elif "Connection refused" in error_msg:
raise ValidationError(
f"SSH connection refused for '{host}'\n"
"Check:\n"
"1. SSH server is running on remote\n"
"2. Port 22 is not blocked by firewall"
)
elif "Connection timed out" in error_msg or "timeout" in error_msg.lower():
raise ValidationError(
f"SSH connection timed out for '{host}'\n"
"Check:\n"
"1. Host is reachable (ping test)\n"
"2. Tailscale is connected\n"
"3. Network connectivity"
)
else:
raise ValidationError(
f"SSH connection failed for '{host}': {error_msg}"
)
except subprocess.TimeoutExpired:
raise ValidationError(
f"SSH connection timed out for '{host}' (>{timeout}s)"
)
except Exception as e:
raise ValidationError(f"Error testing SSH connection to '{host}': {e}")
def validate_tailscale_connection(host: str) -> bool:
"""
Test Tailscale connectivity to host.
Args:
host: Host to check
Returns:
True if Tailscale connection active
Raises:
ValidationError: If Tailscale not connected
Example:
>>> validate_tailscale_connection("web-01")
True
"""
try:
# Check if tailscale is running
result = subprocess.run(
["tailscale", "status"],
capture_output=True,
text=True,
timeout=5
)
if result.returncode != 0:
raise ValidationError(
"Tailscale is not running\n"
"Start Tailscale: sudo tailscale up"
)
# Check if specific host is in the network
if host in result.stdout or host.replace('-', '.') in result.stdout:
return True
else:
raise ValidationError(
f"Host '{host}' not found in Tailscale network\n"
"Ensure host is:\n"
"1. Connected to Tailscale\n"
"2. In the same tailnet\n"
"3. Not expired/offline"
)
except FileNotFoundError:
raise ValidationError(
"Tailscale not installed\n"
"Install: https://tailscale.com/download"
)
except subprocess.TimeoutExpired:
raise ValidationError("Timeout checking Tailscale status")
except Exception as e:
raise ValidationError(f"Error checking Tailscale connection: {e}")
def validate_ssh_key(host: str) -> bool:
"""
Check SSH key authentication is working.
Args:
host: Host to check
Returns:
True if SSH key auth works
Raises:
ValidationError: If key auth fails
Example:
>>> validate_ssh_key("web-01")
True
"""
try:
# Test connection with explicit key-only auth
result = subprocess.run(
["ssh", "-o", "BatchMode=yes",
"-o", "PasswordAuthentication=no",
"-o", "ConnectTimeout=5",
host, "echo", "test"],
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
return True
else:
error_msg = result.stderr.strip()
if "Permission denied" in error_msg:
raise ValidationError(
f"SSH key authentication failed for '{host}'\n"
"Fix:\n"
"1. Add your SSH key: ssh-add ~/.ssh/id_ed25519\n"
"2. Copy public key to remote: ssh-copy-id {}\n"
"3. Verify: ssh -v {} 2>&1 | grep -i 'offering public key'".format(host, host)
)
else:
raise ValidationError(
f"SSH key validation failed for '{host}': {error_msg}"
)
except subprocess.TimeoutExpired:
raise ValidationError(f"Timeout validating SSH key for '{host}'")
except Exception as e:
raise ValidationError(f"Error validating SSH key for '{host}': {e}")
def get_connection_diagnostics(host: str) -> Dict[str, any]:
"""
Comprehensive connection testing.
Args:
host: Host to diagnose
Returns:
Dict with diagnostic results:
{
'ping': {'success': bool, 'message': str},
'ssh': {'success': bool, 'message': str},
'tailscale': {'success': bool, 'message': str},
'ssh_key': {'success': bool, 'message': str}
}
Example:
>>> diag = get_connection_diagnostics("web-01")
>>> diag['ssh']['success']
True
"""
diagnostics = {}
# Test 1: Ping
try:
result = subprocess.run(
["ping", "-c", "1", "-W", "2", host],
capture_output=True,
timeout=3
)
diagnostics['ping'] = {
'success': result.returncode == 0,
'message': 'Host is reachable' if result.returncode == 0 else 'Host not reachable'
}
except Exception as e:
diagnostics['ping'] = {'success': False, 'message': str(e)}
# Test 2: SSH connection
try:
validate_ssh_connection(host, timeout=5)
diagnostics['ssh'] = {'success': True, 'message': 'SSH connection works'}
except ValidationError as e:
diagnostics['ssh'] = {'success': False, 'message': str(e).split('\n')[0]}
# Test 3: Tailscale
try:
validate_tailscale_connection(host)
diagnostics['tailscale'] = {'success': True, 'message': 'Tailscale connected'}
except ValidationError as e:
diagnostics['tailscale'] = {'success': False, 'message': str(e).split('\n')[0]}
# Test 4: SSH key
try:
validate_ssh_key(host)
diagnostics['ssh_key'] = {'success': True, 'message': 'SSH key authentication works'}
except ValidationError as e:
diagnostics['ssh_key'] = {'success': False, 'message': str(e).split('\n')[0]}
return diagnostics
def main():
"""Test connection validators."""
print("Testing connection validators...\n")
print("1. Testing connection diagnostics:")
try:
diag = get_connection_diagnostics("localhost")
print(" Results:")
for test, result in diag.items():
status = "" if result['success'] else ""
print(f" {status} {test}: {result['message']}")
except Exception as e:
print(f" Error: {e}")
print("\n✅ Connection validators tested")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
Host validators for Tailscale SSH Sync Agent.
Validates host configuration and availability.
"""
import subprocess
from typing import List, Dict, Optional
from pathlib import Path
import logging
from .parameter_validator import ValidationError
logger = logging.getLogger(__name__)
def validate_ssh_config(host: str, config_path: Optional[Path] = None) -> bool:
"""
Check if host has SSH config entry.
Args:
host: Host name to check
config_path: Path to SSH config (default: ~/.ssh/config)
Returns:
True if host is in SSH config
Raises:
ValidationError: If host not found in config
Example:
>>> validate_ssh_config("web-01")
True
"""
if config_path is None:
config_path = Path.home() / '.ssh' / 'config'
if not config_path.exists():
raise ValidationError(
f"SSH config file not found: {config_path}\n"
"Create ~/.ssh/config with your host definitions"
)
# Parse SSH config for this host
host_found = False
try:
with open(config_path, 'r') as f:
for line in f:
line = line.strip()
if line.lower().startswith('host ') and host in line:
host_found = True
break
if not host_found:
raise ValidationError(
f"Host '{host}' not found in SSH config: {config_path}\n"
"Add host to SSH config:\n"
f"Host {host}\n"
f" HostName <IP_ADDRESS>\n"
f" User <USERNAME>"
)
return True
except IOError as e:
raise ValidationError(f"Error reading SSH config: {e}")
def validate_host_reachable(host: str, timeout: int = 5) -> bool:
"""
Check if host is reachable via ping.
Args:
host: Host name to check
timeout: Timeout in seconds
Returns:
True if host is reachable
Raises:
ValidationError: If host is not reachable
Example:
>>> validate_host_reachable("web-01", timeout=5)
True
"""
try:
# Try to resolve via SSH config first
result = subprocess.run(
["ssh", "-G", host],
capture_output=True,
text=True,
timeout=2
)
if result.returncode == 0:
# Extract hostname from SSH config
for line in result.stdout.split('\n'):
if line.startswith('hostname '):
actual_host = line.split()[1]
break
else:
actual_host = host
else:
actual_host = host
# Ping the host
ping_result = subprocess.run(
["ping", "-c", "1", "-W", str(timeout), actual_host],
capture_output=True,
text=True,
timeout=timeout + 1
)
if ping_result.returncode == 0:
return True
else:
raise ValidationError(
f"Host '{host}' ({actual_host}) is not reachable\n"
"Check:\n"
"1. Host is powered on\n"
"2. Tailscale is connected\n"
"3. Network connectivity"
)
except subprocess.TimeoutExpired:
raise ValidationError(f"Timeout checking host '{host}' (>{timeout}s)")
except Exception as e:
raise ValidationError(f"Error checking host '{host}': {e}")
def validate_group_members(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
"""
Ensure group has valid members.
Args:
group: Group name
groups_config: Groups configuration dict
Returns:
List of valid hosts in group
Raises:
ValidationError: If group is empty or has no valid members
Example:
>>> groups = {'production': ['web-01', 'db-01']}
>>> validate_group_members('production', groups)
['web-01', 'db-01']
"""
if group not in groups_config:
raise ValidationError(
f"Group '{group}' not found in configuration\n"
f"Available groups: {', '.join(groups_config.keys())}"
)
members = groups_config[group]
if not members:
raise ValidationError(
f"Group '{group}' has no members\n"
f"Add hosts to group with: sshsync gadd {group}"
)
if not isinstance(members, list):
raise ValidationError(
f"Invalid group configuration for '{group}': members must be a list"
)
return members
def get_invalid_hosts(hosts: List[str], config_path: Optional[Path] = None) -> List[str]:
"""
Find hosts without valid SSH config.
Args:
hosts: List of host names
config_path: Path to SSH config
Returns:
List of hosts without valid config
Example:
>>> get_invalid_hosts(["web-01", "nonexistent"])
["nonexistent"]
"""
if config_path is None:
config_path = Path.home() / '.ssh' / 'config'
if not config_path.exists():
return hosts # All invalid if no config
# Parse SSH config
valid_hosts = set()
try:
with open(config_path, 'r') as f:
for line in f:
line = line.strip()
if line.lower().startswith('host '):
host_alias = line.split(maxsplit=1)[1]
if '*' not in host_alias and '?' not in host_alias:
valid_hosts.add(host_alias)
except IOError:
return hosts
# Find invalid hosts
return [h for h in hosts if h not in valid_hosts]
def main():
"""Test host validators."""
print("Testing host validators...\n")
print("1. Testing validate_ssh_config():")
try:
validate_ssh_config("localhost")
print(" ✓ localhost has SSH config")
except ValidationError as e:
print(f" Note: {e.args[0].split(chr(10))[0]}")
print("\n2. Testing get_invalid_hosts():")
test_hosts = ["localhost", "nonexistent-host-12345"]
invalid = get_invalid_hosts(test_hosts)
print(f" Invalid hosts: {invalid}")
print("\n✅ Host validators tested")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,363 @@
#!/usr/bin/env python3
"""
Parameter validators for Tailscale SSH Sync Agent.
Validates user inputs before making operations.
"""
from typing import List, Optional
from pathlib import Path
import re
import logging
logger = logging.getLogger(__name__)
class ValidationError(Exception):
"""Raised when validation fails."""
pass
def validate_host(host: str, valid_hosts: Optional[List[str]] = None) -> str:
"""
Validate host parameter.
Args:
host: Host name or alias
valid_hosts: List of valid hosts (None to skip check)
Returns:
str: Validated and normalized host name
Raises:
ValidationError: If host is invalid
Example:
>>> validate_host("web-01")
"web-01"
>>> validate_host("web-01", ["web-01", "web-02"])
"web-01"
"""
if not host:
raise ValidationError("Host cannot be empty")
if not isinstance(host, str):
raise ValidationError(f"Host must be string, got {type(host)}")
# Normalize (strip whitespace, lowercase for comparison)
host = host.strip()
# Basic validation: alphanumeric, dash, underscore, dot
if not re.match(r'^[a-zA-Z0-9._-]+$', host):
raise ValidationError(
f"Invalid host name format: {host}\n"
"Host names must contain only letters, numbers, dots, dashes, and underscores"
)
# Check if valid (if list provided)
if valid_hosts:
# Try exact match first
if host in valid_hosts:
return host
# Try case-insensitive match
for valid_host in valid_hosts:
if host.lower() == valid_host.lower():
return valid_host
# Not found - provide suggestions
suggestions = [h for h in valid_hosts if host[:3].lower() in h.lower()]
raise ValidationError(
f"Invalid host: {host}\n"
f"Valid options: {', '.join(valid_hosts[:10])}\n"
+ (f"Did you mean: {', '.join(suggestions[:3])}?" if suggestions else "")
)
return host
def validate_group(group: str, valid_groups: Optional[List[str]] = None) -> str:
"""
Validate group parameter.
Args:
group: Group name
valid_groups: List of valid groups (None to skip check)
Returns:
str: Validated group name
Raises:
ValidationError: If group is invalid
Example:
>>> validate_group("production")
"production"
>>> validate_group("prod", ["production", "development"])
ValidationError: Invalid group: prod
"""
if not group:
raise ValidationError("Group cannot be empty")
if not isinstance(group, str):
raise ValidationError(f"Group must be string, got {type(group)}")
# Normalize
group = group.strip().lower()
# Basic validation
if not re.match(r'^[a-z0-9_-]+$', group):
raise ValidationError(
f"Invalid group name format: {group}\n"
"Group names must contain only lowercase letters, numbers, dashes, and underscores"
)
# Check if valid (if list provided)
if valid_groups:
if group not in valid_groups:
suggestions = [g for g in valid_groups if group[:3] in g]
raise ValidationError(
f"Invalid group: {group}\n"
f"Valid groups: {', '.join(valid_groups)}\n"
+ (f"Did you mean: {', '.join(suggestions[:3])}?" if suggestions else "")
)
return group
def validate_path_exists(path: str, must_be_file: bool = False,
must_be_dir: bool = False) -> Path:
"""
Validate path exists and is accessible.
Args:
path: Path to validate
must_be_file: If True, path must be a file
must_be_dir: If True, path must be a directory
Returns:
Path: Validated Path object
Raises:
ValidationError: If path is invalid
Example:
>>> validate_path_exists("/tmp", must_be_dir=True)
Path('/tmp')
>>> validate_path_exists("/nonexistent")
ValidationError: Path does not exist: /nonexistent
"""
if not path:
raise ValidationError("Path cannot be empty")
p = Path(path).expanduser().resolve()
if not p.exists():
raise ValidationError(
f"Path does not exist: {path}\n"
f"Resolved to: {p}"
)
if must_be_file and not p.is_file():
raise ValidationError(f"Path must be a file: {path}")
if must_be_dir and not p.is_dir():
raise ValidationError(f"Path must be a directory: {path}")
return p
def validate_timeout(timeout: int, min_timeout: int = 1,
max_timeout: int = 600) -> int:
"""
Validate timeout parameter.
Args:
timeout: Timeout in seconds
min_timeout: Minimum allowed timeout
max_timeout: Maximum allowed timeout
Returns:
int: Validated timeout
Raises:
ValidationError: If timeout is invalid
Example:
>>> validate_timeout(10)
10
>>> validate_timeout(0)
ValidationError: Timeout must be between 1 and 600 seconds
"""
if not isinstance(timeout, int):
raise ValidationError(f"Timeout must be integer, got {type(timeout)}")
if timeout < min_timeout:
raise ValidationError(
f"Timeout too low: {timeout}s (minimum: {min_timeout}s)"
)
if timeout > max_timeout:
raise ValidationError(
f"Timeout too high: {timeout}s (maximum: {max_timeout}s)"
)
return timeout
def validate_command(command: str, allow_dangerous: bool = False) -> str:
"""
Basic command safety validation.
Args:
command: Command to validate
allow_dangerous: If False, block potentially dangerous commands
Returns:
str: Validated command
Raises:
ValidationError: If command is invalid or dangerous
Example:
>>> validate_command("ls -la")
"ls -la"
>>> validate_command("rm -rf /", allow_dangerous=False)
ValidationError: Potentially dangerous command blocked: rm -rf
"""
if not command:
raise ValidationError("Command cannot be empty")
if not isinstance(command, str):
raise ValidationError(f"Command must be string, got {type(command)}")
command = command.strip()
if not allow_dangerous:
# Check for dangerous patterns
dangerous_patterns = [
(r'\brm\s+-rf\s+/', "rm -rf on root directory"),
(r'\bmkfs\.', "filesystem formatting"),
(r'\bdd\s+.*of=/dev/', "disk writing with dd"),
(r':(){:|:&};:', "fork bomb"),
(r'>\s*/dev/sd[a-z]', "direct disk writing"),
]
for pattern, description in dangerous_patterns:
if re.search(pattern, command, re.IGNORECASE):
raise ValidationError(
f"Potentially dangerous command blocked: {description}\n"
f"Command: {command}\n"
"Use allow_dangerous=True if you really want to execute this"
)
return command
def validate_hosts_list(hosts: List[str], valid_hosts: Optional[List[str]] = None) -> List[str]:
"""
Validate a list of hosts.
Args:
hosts: List of host names
valid_hosts: List of valid hosts (None to skip check)
Returns:
List[str]: Validated host names
Raises:
ValidationError: If any host is invalid
Example:
>>> validate_hosts_list(["web-01", "web-02"])
["web-01", "web-02"]
"""
if not hosts:
raise ValidationError("Hosts list cannot be empty")
if not isinstance(hosts, list):
raise ValidationError(f"Hosts must be list, got {type(hosts)}")
validated = []
errors = []
for host in hosts:
try:
validated.append(validate_host(host, valid_hosts))
except ValidationError as e:
errors.append(str(e))
if errors:
raise ValidationError(
f"Invalid hosts in list:\n" + "\n".join(errors)
)
return validated
def main():
"""Test validators."""
print("Testing parameter validators...\n")
# Test host validation
print("1. Testing validate_host():")
try:
host = validate_host("web-01", ["web-01", "web-02", "db-01"])
print(f" ✓ Valid host: {host}")
except ValidationError as e:
print(f" ✗ Error: {e}")
try:
host = validate_host("invalid-host", ["web-01", "web-02"])
print(f" ✗ Should have failed!")
except ValidationError as e:
print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")
# Test group validation
print("\n2. Testing validate_group():")
try:
group = validate_group("production", ["production", "development"])
print(f" ✓ Valid group: {group}")
except ValidationError as e:
print(f" ✗ Error: {e}")
# Test path validation
print("\n3. Testing validate_path_exists():")
try:
path = validate_path_exists("/tmp", must_be_dir=True)
print(f" ✓ Valid path: {path}")
except ValidationError as e:
print(f" ✗ Error: {e}")
# Test timeout validation
print("\n4. Testing validate_timeout():")
try:
timeout = validate_timeout(10)
print(f" ✓ Valid timeout: {timeout}s")
except ValidationError as e:
print(f" ✗ Error: {e}")
try:
timeout = validate_timeout(0)
print(f" ✗ Should have failed!")
except ValidationError as e:
print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")
# Test command validation
print("\n5. Testing validate_command():")
try:
cmd = validate_command("ls -la")
print(f" ✓ Safe command: {cmd}")
except ValidationError as e:
print(f" ✗ Error: {e}")
try:
cmd = validate_command("rm -rf /", allow_dangerous=False)
print(f" ✗ Should have failed!")
except ValidationError as e:
print(f" ✓ Correctly blocked: {e.args[0].split(chr(10))[0]}")
print("\n✅ All parameter validators tested")
if __name__ == "__main__":
main()