Initial commit

This commit is contained in:
Zhongwei Li
2025-11-29 18:47:40 +08:00
commit 14c678ceac
22 changed files with 7501 additions and 0 deletions

378
scripts/load_balancer.py Normal file
View File

@@ -0,0 +1,378 @@
#!/usr/bin/env python3
"""
Load balancer for Tailscale SSH Sync Agent.
Intelligent task distribution based on machine resources.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import logging
# Add utils to path
sys.path.insert(0, str(Path(__file__).parent))
from utils.helpers import parse_cpu_load, parse_memory_usage, parse_disk_usage, calculate_load_score, classify_load_status
from sshsync_wrapper import execute_on_host
logger = logging.getLogger(__name__)
@dataclass
class MachineMetrics:
    """Resource metrics for a machine.

    Percentages are on a 0-100 scale. ``load_score`` is the composite
    produced by ``calculate_load_score`` (lower is better) and ``status``
    is the label returned by ``classify_load_status``.
    """
    host: str  # host alias as used by SSH/sshsync
    cpu_pct: float  # CPU utilisation estimate (scaled load average), percent
    mem_pct: float  # memory usage, percent
    disk_pct: float  # root filesystem usage, percent
    load_score: float  # weighted composite of the three metrics above
    status: str  # load label from classify_load_status (e.g. 'low'/'moderate'/'high')
def get_machine_load(host: str, timeout: int = 10, cores: int = 4) -> Optional[MachineMetrics]:
    """
    Get CPU, memory, disk metrics for a machine.

    Args:
        host: Host to check
        timeout: Per-command timeout in seconds
        cores: Core count used to normalise the 1-minute load average into
            a CPU percentage. Previously hard-coded to 4; exposed as a
            parameter (default unchanged) so callers can match real hardware.

    Returns:
        MachineMetrics object or None on failure

    Raises:
        ValueError: if ``cores`` is not positive.

    Example:
        >>> metrics = get_machine_load("web-01")
        >>> metrics.cpu_pct
        45.2
        >>> metrics.load_score
        0.49
    """
    if cores <= 0:
        raise ValueError("cores must be positive")
    try:
        # Each probe is independent; a failed probe falls back to a neutral
        # 50% estimate below rather than aborting the whole measurement.
        cpu_result = execute_on_host(host, "uptime", timeout=timeout)
        cpu_data = parse_cpu_load(cpu_result['stdout']) if cpu_result.get('success') else {}

        mem_result = execute_on_host(host, "free -m 2>/dev/null || vm_stat", timeout=timeout)
        mem_data = parse_memory_usage(mem_result['stdout']) if mem_result.get('success') else {}

        disk_result = execute_on_host(host, "df -h / | tail -1", timeout=timeout)
        disk_data = parse_disk_usage(disk_result['stdout']) if disk_result.get('success') else {}

        # CPU: normalise the 1-minute load average by the core count.
        cpu_pct = (cpu_data.get('load_1min', 0) / cores) * 100 if cpu_data else 50.0
        # Memory / disk: direct percentages, neutral default when unknown.
        mem_pct = mem_data.get('use_pct', 50.0)
        disk_pct = disk_data.get('use_pct', 50.0)

        score = calculate_load_score(cpu_pct, mem_pct, disk_pct)
        return MachineMetrics(
            host=host,
            cpu_pct=cpu_pct,
            mem_pct=mem_pct,
            disk_pct=disk_pct,
            load_score=score,
            status=classify_load_status(score),
        )
    except Exception as e:
        logger.error(f"Error getting load for {host}: {e}")
        return None
def select_optimal_host(candidates: List[str],
                        prefer_group: Optional[str] = None,
                        timeout: int = 10) -> Tuple[Optional[str], Optional[MachineMetrics]]:
    """
    Pick the best host from candidates based on current load.

    Args:
        candidates: List of candidate hosts
        prefer_group: Prefer hosts from this group when their load is close
        timeout: Timeout for metric gathering

    Returns:
        Tuple of (selected_host, metrics); (None, None) when no candidate
        produced usable metrics.

    Example:
        >>> host, metrics = select_optimal_host(["web-01", "web-02", "web-03"])
        >>> host
        "web-03"
        >>> metrics.load_score
        0.28
    """
    if not candidates:
        return None, None

    # Probe every candidate, dropping hosts that returned no metrics.
    collected = [m for m in (get_machine_load(h, timeout=timeout) for h in candidates) if m]
    if not collected:
        logger.warning("No valid metrics collected from candidates")
        return None, None

    # Ascending load score: index 0 is the absolute best.
    collected.sort(key=lambda m: m.load_score)
    best = collected[0]

    if prefer_group:
        from utils.helpers import parse_sshsync_config, get_groups_for_host
        groups_config = parse_sshsync_config()
        # Accept the best preferred-group host whose score is within 20%
        # of the overall winner.
        threshold = best.load_score * 1.2
        for candidate in collected:
            if prefer_group not in get_groups_for_host(candidate.host, groups_config):
                continue
            if candidate.load_score <= threshold:
                return candidate.host, candidate

    return best.host, best
def get_group_capacity(group: str, timeout: int = 10) -> Dict:
    """
    Get aggregate capacity of a group.

    Args:
        group: Group name
        timeout: Timeout for metric gathering

    Returns:
        Dict with aggregate metrics ('hosts', 'total_hosts', 'avg_cpu',
        'avg_mem', 'avg_disk', 'avg_load_score', 'total_capacity'), or a
        dict with an 'error' key when the group is unknown or unreachable.

    Example:
        >>> capacity = get_group_capacity("production")
        >>> capacity['avg_load_score']
        0.45
    """
    from utils.helpers import parse_sshsync_config

    members = parse_sshsync_config().get(group, [])
    if not members:
        return {
            'error': f'Group {group} not found or has no members',
            'hosts': []
        }

    # Probe every member; keep only hosts that produced metrics.
    sampled = [m for m in (get_machine_load(h, timeout=timeout) for h in members) if m]
    if not sampled:
        return {
            'error': f'Could not get metrics for any hosts in {group}',
            'hosts': []
        }

    count = len(sampled)
    avg_score = sum(m.load_score for m in sampled) / count

    # Map the average score onto a coarse capacity description.
    if avg_score < 0.4:
        capacity_desc = "High capacity available"
    elif avg_score < 0.7:
        capacity_desc = "Moderate capacity"
    else:
        capacity_desc = "Limited capacity"

    return {
        'group': group,
        'hosts': sampled,
        'total_hosts': count,
        'available_hosts': len(members),
        'avg_cpu': sum(m.cpu_pct for m in sampled) / count,
        'avg_mem': sum(m.mem_pct for m in sampled) / count,
        'avg_disk': sum(m.disk_pct for m in sampled) / count,
        'avg_load_score': avg_score,
        'total_capacity': capacity_desc
    }
def distribute_tasks(tasks: List[Dict], hosts: List[str],
                     timeout: int = 10) -> Dict[str, List[Dict]]:
    """
    Distribute multiple tasks optimally across hosts.

    Args:
        tasks: List of task dicts (each with 'command', optional 'weight', etc)
        hosts: Available hosts
        timeout: Timeout for metric gathering

    Returns:
        Dict mapping hosts to assigned tasks

    Algorithm:
        Greedy: measure current load, hand heaviest tasks out first, always
        to the host whose simulated load is lowest at that moment.

    Example:
        >>> tasks = [
        ...     {'command': 'npm run build', 'weight': 3},
        ...     {'command': 'npm test', 'weight': 2}
        ... ]
        >>> distribution = distribute_tasks(tasks, ["web-01", "web-02"])
        >>> distribution["web-01"]
        [{'command': 'npm run build', 'weight': 3}]
    """
    if not tasks or not hosts:
        return {}

    probed: Dict[str, MachineMetrics] = {}
    for candidate in hosts:
        snapshot = get_machine_load(candidate, timeout=timeout)
        if snapshot:
            probed[candidate] = snapshot
    if not probed:
        logger.error("No valid host metrics available")
        return {}

    assignment: Dict[str, List[Dict]] = {h: [] for h in probed}
    # Simulated per-host load, starting from the measured score.
    simulated = {h: m.load_score for h, m in probed.items()}

    # Heavy tasks first so they land on the least-loaded machines.
    for task in sorted(tasks, key=lambda t: t.get('weight', 1), reverse=True):
        target = min(simulated, key=simulated.get)
        assignment[target].append(task)
        # Bump the simulated load by the task weight (0.1 scaling factor).
        simulated[target] += task.get('weight', 1) * 0.1

    return assignment
def format_load_report(metrics: MachineMetrics, compare_to_avg: Optional[Dict] = None) -> str:
    """
    Format load metrics as a human-readable report.

    Args:
        metrics: Machine metrics
        compare_to_avg: Optional dict with avg_cpu, avg_mem, avg_disk for comparison

    Returns:
        Formatted report string

    Example:
        >>> metrics = MachineMetrics('web-01', 45, 60, 40, 0.49, 'moderate')
        >>> print(format_load_report(metrics))
        web-01: Load Score: 0.49 (moderate)
        CPU: 45.0% | Memory: 60.0% | Disk: 40.0%
    """
    report = [
        f"{metrics.host}: Load Score: {metrics.load_score:.2f} ({metrics.status})",
        f" CPU: {metrics.cpu_pct:.1f}% | Memory: {metrics.mem_pct:.1f}% | Disk: {metrics.disk_pct:.1f}%"
    ]
    if compare_to_avg:
        deltas = (
            ("CPU", metrics.cpu_pct - compare_to_avg.get('avg_cpu', 0)),
            ("Mem", metrics.mem_pct - compare_to_avg.get('avg_mem', 0)),
            ("Disk", metrics.disk_pct - compare_to_avg.get('avg_disk', 0)),
        )
        # Only mention metrics that differ from the average by more than 10 points.
        notable = [
            f"{label} {'+' if delta > 0 else ''}{delta:.0f}% vs avg"
            for label, delta in deltas
            if abs(delta) > 10
        ]
        if notable:
            report.append(f" vs Average: {' | '.join(notable)}")
    return "\n".join(report)
def main():
    """Exercise the load-balancer helpers against simulated metrics."""
    print("Testing load balancer...\n")

    print("1. Testing select_optimal_host:")
    print(" (Requires configured hosts - using dry-run simulation)")
    # Fabricated metrics stand in for live hosts.
    simulated = sorted(
        [
            MachineMetrics('web-01', 45, 60, 40, 0.49, 'moderate'),
            MachineMetrics('web-02', 85, 70, 65, 0.75, 'high'),
            MachineMetrics('web-03', 20, 35, 30, 0.28, 'low'),
        ],
        key=lambda m: m.load_score,
    )
    winner = simulated[0]
    print(f" ✓ Best host: {winner.host} (score: {winner.load_score:.2f})")
    print(f" Reason: {winner.status} load")

    print("\n2. Format load report:")
    print(format_load_report(winner, {
        'avg_cpu': 50,
        'avg_mem': 55,
        'avg_disk': 45
    }))
    print("\n✅ Load balancer tested")


if __name__ == "__main__":
    main()

409
scripts/sshsync_wrapper.py Normal file
View File

@@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
SSH Sync wrapper for Tailscale SSH Sync Agent.
Python interface to sshsync CLI operations.
"""
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import json
import logging
# Add utils to path
sys.path.insert(0, str(Path(__file__).parent))
from utils.helpers import parse_ssh_config, parse_sshsync_config, format_bytes, format_duration
from utils.validators import validate_host, validate_group, validate_path_exists, validate_timeout, validate_command
logger = logging.getLogger(__name__)
def get_host_status(group: Optional[str] = None) -> Dict:
    """
    Get online/offline status of hosts.

    Args:
        group: Optional group to filter (None = all hosts)

    Returns:
        Dict with status info, or {'error': ..., 'hosts': []} on failure

    Example:
        >>> status = get_host_status()
        >>> status['online_count']
        8
    """
    try:
        proc = subprocess.run(
            ["sshsync", "ls", "--with-status"],
            capture_output=True, text=True, timeout=30,
        )
        if proc.returncode != 0:
            return {'error': proc.stderr, 'hosts': []}

        hosts = []
        for raw in proc.stdout.strip().split('\n'):
            # Skip blanks and table header/divider lines.
            if not raw or raw.startswith(('Host', '---')):
                continue
            fields = raw.split()
            if len(fields) < 2:
                continue
            state = fields[1]
            hosts.append({
                'host': fields[0],
                'online': state.lower() in ['online', 'reachable', ''],
                'status': state,
            })

        if group:
            members = parse_sshsync_config().get(group, [])
            hosts = [entry for entry in hosts if entry['host'] in members]

        online = sum(1 for entry in hosts if entry['online'])
        return {
            'hosts': hosts,
            'total_count': len(hosts),
            'online_count': online,
            'offline_count': len(hosts) - online,
            'availability_pct': (online / len(hosts) * 100) if hosts else 0,
        }
    except Exception as e:
        logger.error(f"Error getting host status: {e}")
        return {'error': str(e), 'hosts': []}
def execute_on_all(command: str, timeout: int = 10, dry_run: bool = False) -> Dict:
    """
    Execute a command on all hosts.

    Args:
        command: Command to execute
        timeout: Timeout in seconds
        dry_run: If True, don't actually execute

    Returns:
        Dict with results, or {'error': ...} on timeout/failure

    Example:
        >>> result = execute_on_all("uptime", timeout=15)
        >>> len(result['results'])
        10
    """
    validate_command(command)
    validate_timeout(timeout)

    if dry_run:
        return {
            'dry_run': True,
            'command': command,
            'message': 'Would execute on all hosts'
        }

    try:
        completed = subprocess.run(
            ["sshsync", "all", f"--timeout={timeout}", command],
            capture_output=True,
            text=True,
            timeout=timeout + 30,  # headroom over the per-host timeout
        )
    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}

    # sshsync's per-host output format varies; pass raw streams through.
    return {
        'success': completed.returncode == 0,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'command': command
    }
def execute_on_group(group: str, command: str, timeout: int = 10, dry_run: bool = False) -> Dict:
    """
    Execute a command on a specific group.

    Args:
        group: Group name
        command: Command to execute
        timeout: Timeout in seconds
        dry_run: Preview without executing

    Returns:
        Dict with execution results, or {'error': ...} on timeout/failure

    Example:
        >>> result = execute_on_group("web-servers", "df -h /var/www")
        >>> result['success']
        True
    """
    groups_config = parse_sshsync_config()
    validate_group(group, list(groups_config.keys()))
    validate_command(command)
    validate_timeout(timeout)

    if dry_run:
        members = groups_config.get(group, [])
        return {
            'dry_run': True,
            'group': group,
            'hosts': members,
            'command': command,
            'message': f'Would execute on {len(members)} hosts in group {group}'
        }

    try:
        completed = subprocess.run(
            ["sshsync", "group", f"--timeout={timeout}", group, command],
            capture_output=True,
            text=True,
            timeout=timeout + 30,
        )
    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}

    return {
        'success': completed.returncode == 0,
        'group': group,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'command': command
    }
def execute_on_host(host: str, command: str, timeout: int = 10) -> Dict:
    """
    Execute a command on a single host over plain ssh.

    Args:
        host: Host name
        command: Command to execute
        timeout: Timeout in seconds

    Returns:
        Dict with result, or {'error': ...} on timeout/failure

    Example:
        >>> result = execute_on_host("web-01", "hostname")
        >>> result['stdout']
        "web-01"
    """
    validate_host(host, list(parse_ssh_config().keys()))
    validate_command(command)
    validate_timeout(timeout)

    ssh_cmd = ["ssh", "-o", f"ConnectTimeout={timeout}", host, command]
    try:
        completed = subprocess.run(
            ssh_cmd, capture_output=True, text=True, timeout=timeout + 5
        )
    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}

    return {
        'success': completed.returncode == 0,
        'host': host,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'command': command
    }
def push_to_hosts(local_path: str, remote_path: str,
                  hosts: Optional[List[str]] = None,
                  group: Optional[str] = None,
                  recurse: bool = False,
                  dry_run: bool = False) -> Dict:
    """
    Push files to hosts.

    Args:
        local_path: Local file/directory path
        remote_path: Remote destination path
        hosts: Specific hosts (None = all if group also None)
        group: Group name
        recurse: Recursive copy
        dry_run: Preview without executing

    Returns:
        Dict with push results, or {'error': ...} on timeout/failure

    Example:
        >>> result = push_to_hosts("./dist", "/var/www/app", group="production", recurse=True)
        >>> result['success']
        True
    """
    validate_path_exists(local_path)

    if dry_run:
        return {
            'dry_run': True,
            'local_path': local_path,
            'remote_path': remote_path,
            'hosts': hosts,
            'group': group,
            'recurse': recurse,
            'message': 'Would push files'
        }

    # Target selection: explicit hosts beat group, which beats --all.
    cmd = ["sshsync", "push"]
    if hosts:
        for target in hosts:
            cmd += ["--host", target]
    elif group:
        cmd += ["--group", group]
    else:
        cmd.append("--all")
    if recurse:
        cmd.append("--recurse")
    cmd += [local_path, remote_path]

    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        return {'error': 'Push operation timed out'}
    except Exception as e:
        return {'error': str(e)}

    return {
        'success': completed.returncode == 0,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'local_path': local_path,
        'remote_path': remote_path
    }
def pull_from_host(host: str, remote_path: str, local_path: str,
                   recurse: bool = False, dry_run: bool = False) -> Dict:
    """
    Pull files from a host.

    Args:
        host: Host to pull from
        remote_path: Remote file/directory path
        local_path: Local destination path
        recurse: Recursive copy
        dry_run: Preview without executing

    Returns:
        Dict with pull results, or {'error': ...} on timeout/failure

    Example:
        >>> result = pull_from_host("web-01", "/var/log/nginx", "./logs", recurse=True)
        >>> result['success']
        True
    """
    validate_host(host, list(parse_ssh_config().keys()))

    if dry_run:
        return {
            'dry_run': True,
            'host': host,
            'remote_path': remote_path,
            'local_path': local_path,
            'recurse': recurse,
            'message': f'Would pull from {host}'
        }

    cmd = ["sshsync", "pull", "--host", host]
    if recurse:
        cmd.append("--recurse")
    cmd += [remote_path, local_path]

    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        return {'error': 'Pull operation timed out'}
    except Exception as e:
        return {'error': str(e)}

    return {
        'success': completed.returncode == 0,
        'host': host,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'remote_path': remote_path,
        'local_path': local_path
    }
def list_hosts(with_status: bool = True) -> Dict:
    """
    List all configured hosts.

    Args:
        with_status: Include online/offline status (delegates to get_host_status)

    Returns:
        Dict with hosts info

    Example:
        >>> result = list_hosts(with_status=True)
        >>> len(result['hosts'])
        10
    """
    if with_status:
        return get_host_status()

    configured = parse_ssh_config()
    return {
        'hosts': [{'host': alias} for alias in configured],
        'count': len(configured),
    }
def get_groups() -> Dict[str, List[str]]:
    """
    Return every configured sshsync group and its member hosts.

    Thin wrapper over parse_sshsync_config().

    Returns:
        Dict mapping group names to host lists

    Example:
        >>> get_groups()['production']
        ['prod-web-01', 'prod-db-01']
    """
    return parse_sshsync_config()
def main():
    """Smoke-test the sshsync wrapper functions."""
    print("Testing sshsync wrapper...\n")

    print("1. List hosts:")
    listing = list_hosts(with_status=False)
    print(f" Found {listing.get('count', 0)} hosts")

    print("\n2. Get groups:")
    groups = get_groups()
    print(f" Found {len(groups)} groups")
    for group, hosts in groups.items():
        print(f" - {group}: {len(hosts)} hosts")

    print("\n3. Test dry-run:")
    preview = execute_on_all("uptime", dry_run=True)
    print(f" Dry-run: {preview.get('message', 'OK')}")
    print("\n✅ sshsync wrapper tested")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,426 @@
#!/usr/bin/env python3
"""
Tailscale manager for Tailscale SSH Sync Agent.
Tailscale-specific operations and status management.
"""
import subprocess
import re
import json
from typing import Dict, List, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class TailscalePeer:
    """Represents a Tailscale peer as parsed from `tailscale status` output."""
    hostname: str  # peer hostname as reported by the CLI
    ip: str  # first Tailscale IP of the peer ('unknown' when none reported)
    online: bool  # whether the peer is currently reachable
    last_seen: Optional[str] = None  # not populated by this module's parsers
    os: Optional[str] = None  # peer OS (filled from the JSON status only)
    relay: Optional[str] = None  # not populated by this module's parsers
def get_tailscale_status() -> Dict:
    """
    Get Tailscale network status (all peers).

    Prefers `tailscale status --json`; falls back to the plain-text
    listing when the JSON form fails.

    Returns:
        Dict with network status:
        {
            'connected': bool,
            'peers': List[TailscalePeer],
            'online_count': int,
            'total_count': int,
            'self_ip': str
        }
        On failure, 'connected' is False and 'error' describes the cause.

    Example:
        >>> status = get_tailscale_status()
        >>> status['online_count']
        8
        >>> status['peers'][0].hostname
        'homelab-1'
    """
    def _run(args):
        return subprocess.run(args, capture_output=True, text=True, timeout=10)

    try:
        json_proc = _run(["tailscale", "status", "--json"])
        if json_proc.returncode == 0:
            return _parse_json_status(json.loads(json_proc.stdout))

        # JSON mode failed; fall back to the plain-text listing.
        text_proc = _run(["tailscale", "status"])
        if text_proc.returncode == 0:
            return _parse_text_status(text_proc.stdout)
        return {
            'connected': False,
            'error': 'Tailscale not running or accessible',
            'peers': []
        }
    except FileNotFoundError:
        return {
            'connected': False,
            'error': 'Tailscale not installed',
            'peers': []
        }
    except subprocess.TimeoutExpired:
        return {
            'connected': False,
            'error': 'Timeout getting Tailscale status',
            'peers': []
        }
    except Exception as e:
        logger.error(f"Error getting Tailscale status: {e}")
        return {
            'connected': False,
            'error': str(e),
            'peers': []
        }
def _parse_json_status(data: Dict) -> Dict:
    """Parse `tailscale status --json` output into the summary dict.

    Args:
        data: Decoded JSON object from the Tailscale CLI.

    Returns:
        Dict with 'connected', 'peers', 'online_count', 'total_count', 'self_ip'.
    """
    # `or ['']` guards both a missing key and an explicitly empty IP list;
    # the original `.get('TailscaleIPs', [''])[0]` raised IndexError when the
    # key was present but the list was empty.
    self_ips = data.get('Self', {}).get('TailscaleIPs') or ['']
    self_ip = self_ips[0]

    peers = []
    for peer_data in data.get('Peer', {}).values():
        ips = peer_data.get('TailscaleIPs', [])
        peers.append(TailscalePeer(
            hostname=peer_data.get('HostName', 'unknown'),
            ip=ips[0] if ips else 'unknown',
            online=peer_data.get('Online', False),
            # local renamed from `os` to avoid shadowing the module name
            os=peer_data.get('OS', 'unknown'),
        ))

    online_count = sum(1 for p in peers if p.online)
    return {
        'connected': True,
        'peers': peers,
        'online_count': online_count,
        'total_count': len(peers),
        'self_ip': self_ip,
    }
def _parse_text_status(output: str) -> Dict:
    """Parse plain-text `tailscale status` output into the summary dict.

    NOTE(review): this parser assumes each row is "<hostname> <ip> ...".
    Current Tailscale CLI versions print "<ip> <hostname> ..." — confirm
    the column order against the CLI version actually deployed.
    """
    peers = []
    self_ip = None
    for line in output.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        # Assumed row format: hostname ip status ...
        parts = line.split()
        if len(parts) >= 2:
            hostname = parts[0]
            ip = parts[1] if len(parts) > 1 else 'unknown'
            # Treat a first column ending in '-' as the local node: record its
            # IP and do not add it as a peer.
            # NOTE(review): the self row is commonly marked with a trailing
            # '*' or a '-' in the LAST column, not a '-' suffix on the first
            # column — verify this heuristic against real CLI output.
            if hostname.endswith('-'):
                self_ip = ip
                continue
            # Any row not containing the word 'offline' is counted as online.
            online = 'offline' not in line.lower()
            peers.append(TailscalePeer(
                hostname=hostname,
                ip=ip,
                online=online
            ))
    online_count = sum(1 for p in peers if p.online)
    return {
        'connected': True,
        'peers': peers,
        'online_count': online_count,
        'total_count': len(peers),
        'self_ip': self_ip or 'unknown'
    }
def check_connectivity(host: str, timeout: int = 5) -> bool:
    """
    Ping a host via Tailscale.

    Args:
        host: Hostname to ping
        timeout: Timeout in seconds

    Returns:
        True if the host responds to a Tailscale ping

    Example:
        >>> check_connectivity("homelab-1")
        True
    """
    ping_cmd = ["tailscale", "ping", "--timeout", f"{timeout}s", "--c", "1", host]
    try:
        proc = subprocess.run(
            ping_cmd,
            capture_output=True,
            text=True,
            timeout=timeout + 2,  # small margin over the CLI's own timeout
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return False
    except Exception as e:
        logger.error(f"Error pinging {host}: {e}")
        return False
    # Success on a zero exit code or an explicit pong in the output.
    return proc.returncode == 0 or 'pong' in proc.stdout.lower()
def get_peer_info(hostname: str) -> Optional[TailscalePeer]:
    """
    Get detailed info about a specific peer.

    Matches on exact hostname or on `hostname` being a substring of the
    peer's hostname; the first match in peer order wins.

    Args:
        hostname: Peer hostname

    Returns:
        TailscalePeer object or None if not found / not connected

    Example:
        >>> peer = get_peer_info("homelab-1")
        >>> peer.ip
        '100.64.1.10'
    """
    status = get_tailscale_status()
    if not status.get('connected'):
        return None
    return next(
        (p for p in status.get('peers', [])
         if p.hostname == hostname or hostname in p.hostname),
        None,
    )
def list_online_machines() -> List[str]:
    """
    List all online Tailscale machines.

    Returns:
        List of online machine hostnames (empty when not connected)

    Example:
        >>> machines = list_online_machines()
        >>> len(machines)
        8
    """
    status = get_tailscale_status()
    if not status.get('connected'):
        return []
    online = []
    for peer in status.get('peers', []):
        if peer.online:
            online.append(peer.hostname)
    return online
def get_machine_ip(hostname: str) -> Optional[str]:
    """
    Get the Tailscale IP for a machine.

    Args:
        hostname: Machine hostname

    Returns:
        IP address or None if the peer is not found

    Example:
        >>> get_machine_ip("homelab-1")
        '100.64.1.10'
    """
    peer = get_peer_info(hostname)
    if peer is None:
        return None
    return peer.ip
def validate_tailscale_ssh(host: str, timeout: int = 10) -> Dict:
    """
    Check whether Tailscale SSH works for a host.

    Runs a cascade of checks: peer presence, online flag, Tailscale ping,
    and finally a real `tailscale ssh` echo; the first failing stage is
    reported.

    Args:
        host: Host to check
        timeout: Connection timeout

    Returns:
        Dict with validation results:
        {
            'working': bool,
            'message': str,
            'details': Dict
        }

    Example:
        >>> result = validate_tailscale_ssh("homelab-1")
        >>> result['working']
        True
    """
    def _fail(message, details):
        # Uniform shape for every failure outcome.
        return {'working': False, 'message': message, 'details': details}

    peer = get_peer_info(host)
    if peer is None:
        return _fail(f'Host {host} not found in Tailscale network',
                     {'peer_found': False})
    if not peer.online:
        return _fail(f'Host {host} is offline in Tailscale',
                     {'peer_found': True, 'online': False})
    if not check_connectivity(host, timeout=timeout):
        return _fail(f'Cannot ping {host} via Tailscale',
                     {'peer_found': True, 'online': True, 'ping': False})

    try:
        proc = subprocess.run(
            ["tailscale", "ssh", host, "echo", "test"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return _fail(f'Tailscale SSH timed out after {timeout}s',
                     {'timeout': True})
    except Exception as e:
        return _fail(f'Error testing Tailscale SSH: {e}',
                     {'error': str(e)})

    if proc.returncode == 0:
        return {
            'working': True,
            'message': f'Tailscale SSH to {host} is working',
            'details': {
                'peer_found': True,
                'online': True,
                'ping': True,
                'ssh': True,
                'ip': peer.ip
            }
        }
    return _fail(
        f'Tailscale SSH failed: {proc.stderr}',
        {
            'peer_found': True,
            'online': True,
            'ping': True,
            'ssh': False,
            'error': proc.stderr
        },
    )
def get_network_summary() -> str:
    """
    Get a human-readable network summary.

    Returns:
        Formatted summary string

    Example:
        >>> print(get_network_summary())
        Tailscale Network: Connected
        Online: 8/10 machines (80%)
        Self IP: 100.64.1.5
    """
    status = get_tailscale_status()
    if not status.get('connected'):
        return "Tailscale Network: Not connected\nError: {}".format(
            status.get('error', 'Unknown error')
        )
    total = status['total_count']
    online = status['online_count']
    # Guard a connected network with zero peers: the original expression
    # divided by total_count unconditionally and raised ZeroDivisionError.
    pct = (online / total * 100) if total else 0.0
    return "\n".join([
        "Tailscale Network: Connected",
        f"Online: {online}/{total} machines ({pct:.0f}%)",
        f"Self IP: {status.get('self_ip', 'unknown')}",
    ])
def main():
    """Smoke-test the Tailscale manager functions."""
    print("Testing Tailscale manager...\n")

    print("1. Get Tailscale status:")
    net = get_tailscale_status()
    if not net.get('connected'):
        print(f" ✗ Not connected: {net.get('error', 'Unknown error')}")
    else:
        print(f" ✓ Connected")
        print(f" Peers: {net['total_count']} total, {net['online_count']} online")

    print("\n2. List online machines:")
    online = list_online_machines()
    print(f" Found {len(online)} online machines")
    for name in online[:5]:  # show at most the first five
        print(f" - {name}")

    print("\n3. Network summary:")
    print(get_network_summary())
    print("\n✅ Tailscale manager tested")


if __name__ == "__main__":
    main()

628
scripts/utils/helpers.py Normal file
View File

@@ -0,0 +1,628 @@
#!/usr/bin/env python3
"""
Helper utilities for Tailscale SSH Sync Agent.
Provides common formatting, parsing, and utility functions.
"""
import os
import re
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import yaml
import logging
logger = logging.getLogger(__name__)
def format_bytes(bytes_value: int) -> str:
    """
    Format a byte count as a human-readable string.

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string (e.g., "12.3 MB", "1.5 GB")

    Example:
        >>> format_bytes(12582912)
        "12.0 MB"
        >>> format_bytes(1610612736)
        "1.5 GB"
    """
    size = bytes_value
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return f"{size:.1f} {unit}"
        size /= 1024.0
    # Anything past TB is reported in petabytes.
    return f"{size:.1f} PB"
def format_duration(seconds: float) -> str:
    """
    Format a duration as a human-readable string.

    Seconds are dropped once the duration reaches one hour.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string (e.g., "2m 15s", "1h 30m")

    Example:
        >>> format_duration(135)
        "2m 15s"
        >>> format_duration(5430)
        "1h 30m"
    """
    if seconds < 60:
        return f"{int(seconds)}s"
    minutes = int(seconds // 60)
    secs = int(seconds % 60)
    if minutes < 60:
        return f"{minutes}m {secs}s" if secs > 0 else f"{minutes}m"
    hours, minutes = divmod(minutes, 60)
    # The original appended seconds only when `hours == 0`, which is
    # unreachable in this branch (hours >= 1 here); that dead code and the
    # docstring example claiming "1h 30m 30s" have been removed/corrected.
    parts = [f"{hours}h"]
    if minutes > 0:
        parts.append(f"{minutes}m")
    return " ".join(parts)
def format_percentage(value: float, decimals: int = 1) -> str:
    """
    Format a percentage with the requested number of decimals.

    Args:
        value: Percentage value (0-100)
        decimals: Number of decimal places

    Returns:
        Formatted string (e.g., "45.5%")

    Example:
        >>> format_percentage(45.567)
        "45.6%"
    """
    return format(value, f".{decimals}f") + "%"
def parse_ssh_config(config_path: Optional[Path] = None) -> Dict[str, Dict[str, str]]:
    """
    Parse an SSH config file for host definitions.

    Wildcard stanzas (``Host *`` / ``Host web-?``) are skipped entirely.

    Args:
        config_path: Path to SSH config (default: ~/.ssh/config)

    Returns:
        Dict mapping host aliases to their configuration:
        {
            'host-alias': {
                'hostname': '100.64.1.10',
                'user': 'admin',
                'port': '22',
                'identityfile': '~/.ssh/id_ed25519'
            }
        }
        Keys are lower-cased directive names. Empty dict on error.

    Example:
        >>> hosts = parse_ssh_config()
        >>> hosts['homelab-1']['hostname']
        '100.64.1.10'
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'
    if not config_path.exists():
        logger.warning(f"SSH config not found: {config_path}")
        return {}
    hosts: Dict[str, Dict[str, str]] = {}
    current_host: Optional[str] = None
    try:
        with open(config_path, 'r') as f:
            for line in f:
                line = line.strip()
                # Skip comments and empty lines
                if not line or line.startswith('#'):
                    continue
                if line.lower().startswith('host '):
                    host_alias = line.split(maxsplit=1)[1]
                    if '*' in host_alias or '?' in host_alias:
                        # Bug fix: reset current_host for wildcard stanzas;
                        # previously their directives leaked into whichever
                        # concrete host was parsed just before them.
                        current_host = None
                    else:
                        current_host = host_alias
                        hosts[current_host] = {}
                elif current_host:
                    parts = line.split(maxsplit=1)
                    if len(parts) == 2:
                        key, value = parts
                        hosts[current_host][key.lower()] = value
        return hosts
    except Exception as e:
        logger.error(f"Error parsing SSH config: {e}")
        return {}
def parse_sshsync_config(config_path: Optional[Path] = None) -> Dict[str, List[str]]:
    """
    Parse the sshsync config file for group definitions.

    Args:
        config_path: Path to sshsync config (default: ~/.config/sshsync/config.yaml)

    Returns:
        Dict mapping group names to lists of hosts:
        {
            'production': ['prod-web-01', 'prod-db-01'],
            'development': ['dev-laptop', 'dev-desktop']
        }
        Empty dict when the file is missing or unparseable.

    Example:
        >>> groups = parse_sshsync_config()
        >>> groups['production']
        ['prod-web-01', 'prod-db-01']
    """
    if config_path is None:
        config_path = Path.home() / '.config' / 'sshsync' / 'config.yaml'
    if not config_path.exists():
        logger.warning(f"sshsync config not found: {config_path}")
        return {}
    try:
        with open(config_path, 'r') as f:
            loaded = yaml.safe_load(f)
        return loaded.get('groups', {})
    except Exception as e:
        logger.error(f"Error parsing sshsync config: {e}")
        return {}
def get_timestamp(iso: bool = True) -> str:
    """
    Get the current timestamp.

    Args:
        iso: If True, return ISO-8601 UTC format (trailing 'Z');
            otherwise a human-readable local time.

    Returns:
        Timestamp string

    Example:
        >>> get_timestamp(iso=True)
        "2025-10-19T19:43:41Z"
        >>> get_timestamp(iso=False)
        "2025-10-19 19:43:41"
    """
    if iso:
        # Bug fix: 'Z' designates UTC, but the original formatted naive
        # local time. Use an actual UTC clock for the ISO form.
        from datetime import timezone
        return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def safe_execute(func, *args, default=None, **kwargs) -> Any:
    """
    Execute a function with error handling.

    Args:
        func: Function to execute
        *args: Positional arguments
        default: Value to return on error
        **kwargs: Keyword arguments

    Returns:
        Function result or ``default`` on error

    Example:
        >>> safe_execute(int, "not_a_number", default=0)
        0
        >>> safe_execute(int, "42")
        42
    """
    try:
        return func(*args, **kwargs)
    except Exception as e:
        # getattr: callables such as functools.partial lack __name__; the
        # original accessed func.__name__ directly and raised AttributeError
        # from inside the error handler.
        name = getattr(func, '__name__', repr(func))
        logging.getLogger(__name__).error(f"Error executing {name}: {e}")
        return default
def validate_path(path: str, must_exist: bool = True) -> bool:
    """
    Check whether a path is valid and accessible.

    Args:
        path: Path to validate (``~`` is expanded)
        must_exist: If True, the path itself must exist; otherwise only
            its parent directory must (for paths about to be created)

    Returns:
        True if valid, False otherwise

    Example:
        >>> validate_path("/tmp")
        True
        >>> validate_path("/nonexistent", must_exist=True)
        False
    """
    candidate = Path(path).expanduser()
    target = candidate if must_exist else candidate.parent
    return target.exists()
def parse_disk_usage(df_output: str) -> Dict[str, Any]:
    """
    Parse 'df' command output.

    Args:
        df_output: Output from 'df -h' command (header plus one data line)

    Returns:
        Dict with disk usage info (empty dict if unparseable):
        {
            'filesystem': '/dev/sda1',
            'size': '100G',
            'used': '45G',
            'available': '50G',
            'use_pct': 45,
            'mount': '/'
        }

    Example:
        >>> output = "Filesystem Size Used Avail Use% Mounted on\\n/dev/sda1 100G 45G 50G 45% /"
        >>> parse_disk_usage(output)
        {'filesystem': '/dev/sda1', 'size': '100G', ...}
    """
    lines = df_output.strip().split('\n')
    if len(lines) < 2:
        return {}
    # Parse last line (actual data, not header)
    parts = lines[-1].split()
    if len(parts) < 6:
        return {}
    try:
        return {
            'filesystem': parts[0],
            'size': parts[1],
            'used': parts[2],
            'available': parts[3],
            'use_pct': int(parts[4].rstrip('%')),
            # Mount points may contain spaces; rejoin the trailing fields
            # (the previous parts[5] silently truncated such mounts).
            'mount': ' '.join(parts[5:])
        }
    except (ValueError, IndexError) as e:
        logger.error(f"Error parsing disk usage: {e}")
        return {}
def parse_memory_usage(free_output: str) -> Dict[str, Any]:
    """
    Parse 'free' command output (Linux).

    Scans for the 'Mem:' row and extracts total/used/free in MB plus a
    derived usage percentage.

    Args:
        free_output: Output from 'free -m' command

    Returns:
        Dict with memory info (empty dict if no parseable 'Mem:' row):
        {
            'total': 16384,  # MB
            'used': 8192,
            'free': 8192,
            'use_pct': 50.0
        }

    Example:
        >>> output = "Mem: 16384 8192 8192 0 0 0"
        >>> parse_memory_usage(output)
        {'total': 16384, 'used': 8192, ...}
    """
    for row in free_output.strip().splitlines():
        if not row.startswith('Mem:'):
            continue
        fields = row.split()
        if len(fields) < 3:
            continue
        try:
            total_mb = int(fields[1])
            used_mb = int(fields[2])
            # Some 'free' variants omit the free column; derive it then.
            free_mb = int(fields[3]) if len(fields) > 3 else (total_mb - used_mb)
        except (ValueError, IndexError) as e:
            logger.error(f"Error parsing memory usage: {e}")
            continue
        return {
            'total': total_mb,
            'used': used_mb,
            'free': free_mb,
            'use_pct': (used_mb / total_mb * 100) if total_mb > 0 else 0
        }
    return {}
def parse_cpu_load(uptime_output: str) -> Dict[str, float]:
    """
    Parse 'uptime' command output for load averages.

    Handles both the Linux format ("load average: 0.45, 0.38, 0.32") and
    the BSD/macOS format ("load averages: 0.45 0.38 0.32") — the metric
    collector also probes macOS hosts (vm_stat fallback).

    Args:
        uptime_output: Output from 'uptime' command

    Returns:
        Dict with load averages (empty dict if no match):
        {
            'load_1min': 0.45,
            'load_5min': 0.38,
            'load_15min': 0.32
        }

    Example:
        >>> output = "19:43:41 up 5 days, 2:15, 3 users, load average: 0.45, 0.38, 0.32"
        >>> parse_cpu_load(output)
        {'load_1min': 0.45, 'load_5min': 0.38, 'load_15min': 0.32}
    """
    # "averages?" plus optional commas covers both Linux and BSD/macOS.
    match = re.search(
        r'load averages?:\s+([\d.]+),?\s+([\d.]+),?\s+([\d.]+)',
        uptime_output
    )
    if match:
        try:
            return {
                'load_1min': float(match.group(1)),
                'load_5min': float(match.group(2)),
                'load_15min': float(match.group(3))
            }
        except ValueError as e:
            logger.error(f"Error parsing CPU load: {e}")
    return {}
def format_host_status(host: str, online: bool, groups: List[str],
                       latency: Optional[int] = None,
                       tailscale_connected: bool = False) -> str:
    """
    Format host status as display string.

    Args:
        host: Host name
        online: Whether host is online
        groups: List of groups host belongs to
        latency: Latency in ms (shown only when online)
        tailscale_connected: Tailscale connection status

    Returns:
        Formatted status string, segments joined by " - "

    Example:
        >>> format_host_status("web-01", True, ["production", "web"], 25, True)
        "🟢 web-01 (production, web) - Online - Tailscale: Connected - Latency: 25ms"
    """
    group_str = ", ".join(groups) if groups else "no group"
    segments = [
        f"{'🟢' if online else '🔴'} {host} ({group_str}) - {'Online' if online else 'Offline'}"
    ]
    if tailscale_connected:
        segments.append("Tailscale: Connected")
    if online and latency is not None:
        segments.append(f"Latency: {latency}ms")
    return " - ".join(segments)
def calculate_load_score(cpu_pct: float, mem_pct: float, disk_pct: float) -> float:
    """
    Compute a composite load score for a machine.

    CPU is weighted at 40%, memory and disk at 30% each; the weighted sum
    is then scaled from the 0-100 percentage range down to 0-1.

    Args:
        cpu_pct: CPU usage percentage (0-100)
        mem_pct: Memory usage percentage (0-100)
        disk_pct: Disk usage percentage (0-100)

    Returns:
        Load score (0-1, lower is better)

    Example:
        >>> calculate_load_score(45, 60, 40)
        0.48  # (0.45*0.4 + 0.60*0.3 + 0.40*0.3)
    """
    weighted = cpu_pct * 0.4 + mem_pct * 0.3 + disk_pct * 0.3
    return weighted / 100
def classify_load_status(score: float) -> str:
    """
    Classify load score into status category.

    Args:
        score: Load score (0-1); below 0.4 is "low", below 0.7 "moderate",
               otherwise "high"

    Returns:
        Status string: "low", "moderate", or "high"

    Example:
        >>> classify_load_status(0.28)
        "low"
        >>> classify_load_status(0.55)
        "moderate"
        >>> classify_load_status(0.82)
        "high"
    """
    for upper_bound, label in ((0.4, "low"), (0.7, "moderate")):
        if score < upper_bound:
            return label
    return "high"
def classify_latency(latency_ms: int) -> Tuple[str, str]:
    """
    Classify network latency into a band with a short description.

    Args:
        latency_ms: Latency in milliseconds

    Returns:
        Tuple of (status, description)

    Example:
        >>> classify_latency(25)
        ("excellent", "Ideal for interactive tasks")
        >>> classify_latency(150)
        ("fair", "May impact interactive workflows")
    """
    bands = (
        (50, ("excellent", "Ideal for interactive tasks")),
        (100, ("good", "Suitable for most operations")),
        (200, ("fair", "May impact interactive workflows")),
    )
    for upper_bound, verdict in bands:
        if latency_ms < upper_bound:
            return verdict
    return ("poor", "Investigate network issues")
def get_hosts_from_groups(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Get list of hosts in a group.

    Args:
        group: Group name
        groups_config: Groups configuration dict

    Returns:
        List of host names in group (empty list for unknown groups)

    Example:
        >>> groups = {'production': ['web-01', 'db-01']}
        >>> get_hosts_from_groups('production', groups)
        ['web-01', 'db-01']
    """
    try:
        return groups_config[group]
    except KeyError:
        return []
def get_groups_for_host(host: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Get list of groups a host belongs to.

    Args:
        host: Host name
        groups_config: Groups configuration dict

    Returns:
        List of group names (in the dict's iteration order)

    Example:
        >>> groups = {'production': ['web-01'], 'web': ['web-01', 'web-02']}
        >>> get_groups_for_host('web-01', groups)
        ['production', 'web']
    """
    memberships = []
    for name, members in groups_config.items():
        if host in members:
            memberships.append(name)
    return memberships
def run_command(command: str, timeout: int = 10) -> Tuple[bool, str, str]:
    """
    Run shell command with timeout.

    Args:
        command: Command to execute (passed to the shell)
        timeout: Timeout in seconds

    Returns:
        Tuple of (success, stdout, stderr); on timeout or any other
        failure, success is False and stderr carries the reason.

    Example:
        >>> success, stdout, stderr = run_command("echo hello")
        >>> success
        True
        >>> stdout.strip()
        "hello"
    """
    try:
        proc = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=timeout
        )
    except subprocess.TimeoutExpired:
        return (False, "", f"Command timed out after {timeout}s")
    except Exception as e:
        return (False, "", str(e))
    return (proc.returncode == 0, proc.stdout, proc.stderr)
def main():
    """Smoke-test the helper functions (formatting, scoring, parsing).

    Prints each helper's output for visual inspection; exercises the
    formatting helpers, the load/latency classifiers, and the config
    parsers defined elsewhere in this module.
    """
    print("Testing helper functions...\n")
    # Test formatting
    print("1. Format bytes:")
    print(f" 12582912 bytes = {format_bytes(12582912)}")
    print(f" 1610612736 bytes = {format_bytes(1610612736)}")
    print("\n2. Format duration:")
    print(f" 135 seconds = {format_duration(135)}")
    print(f" 5430 seconds = {format_duration(5430)}")
    print("\n3. Format percentage:")
    print(f" 45.567 = {format_percentage(45.567)}")
    print("\n4. Calculate load score:")
    score = calculate_load_score(45, 60, 40)
    print(f" CPU 45%, Mem 60%, Disk 40% = {score:.2f}")
    print(f" Status: {classify_load_status(score)}")
    print("\n5. Classify latency:")
    latencies = [25, 75, 150, 250]
    for lat in latencies:
        status, desc = classify_latency(lat)
        print(f" {lat}ms: {status} - {desc}")
    # The two parsers below read real local config files, so host counts
    # depend on the machine this runs on.
    print("\n6. Parse SSH config:")
    ssh_hosts = parse_ssh_config()
    print(f" Found {len(ssh_hosts)} hosts")
    print("\n7. Parse sshsync config:")
    groups = parse_sshsync_config()
    print(f" Found {len(groups)} groups")
    for group, hosts in groups.items():
        print(f" - {group}: {len(hosts)} hosts")
    print("\n✅ All helpers tested successfully")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,43 @@
"""
Validators package for Tailscale SSH Sync Agent.
"""
from .parameter_validator import (
ValidationError,
validate_host,
validate_group,
validate_path_exists,
validate_timeout,
validate_command
)
from .host_validator import (
validate_ssh_config,
validate_host_reachable,
validate_group_members,
get_invalid_hosts
)
from .connection_validator import (
validate_ssh_connection,
validate_tailscale_connection,
validate_ssh_key,
get_connection_diagnostics
)
__all__ = [
'ValidationError',
'validate_host',
'validate_group',
'validate_path_exists',
'validate_timeout',
'validate_command',
'validate_ssh_config',
'validate_host_reachable',
'validate_group_members',
'get_invalid_hosts',
'validate_ssh_connection',
'validate_tailscale_connection',
'validate_ssh_key',
'get_connection_diagnostics',
]

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python3
"""
Connection validators for Tailscale SSH Sync Agent.
Validates SSH and Tailscale connections.
"""
import subprocess
from typing import Dict, Optional
import logging
from .parameter_validator import ValidationError
logger = logging.getLogger(__name__)
def validate_ssh_connection(host: str, timeout: int = 10) -> bool:
    """
    Test SSH connection works.

    Runs a trivial remote command (``echo test``) over SSH in batch mode
    and translates common failure modes into actionable error messages.

    Args:
        host: Host to connect to (name or SSH-config alias)
        timeout: Connection timeout in seconds

    Returns:
        True if SSH connection successful

    Raises:
        ValidationError: If connection fails, with a hint describing the
            most likely cause (auth, refused, or timeout)

    Example:
        >>> validate_ssh_connection("web-01")
        True
    """
    try:
        # Try to execute a simple command via SSH.
        # BatchMode=yes prevents interactive password prompts;
        # StrictHostKeyChecking=no skips host-key confirmation.
        # NOTE(review): disabling host-key checking weakens MITM
        # protection — confirm it is acceptable for this tailnet-only use.
        result = subprocess.run(
            ["ssh", "-o", "ConnectTimeout={}".format(timeout),
             "-o", "BatchMode=yes",
             "-o", "StrictHostKeyChecking=no",
             host, "echo", "test"],
            capture_output=True,
            text=True,
            # Outer timeout exceeds ConnectTimeout so ssh can fail first.
            timeout=timeout + 5
        )
        if result.returncode == 0:
            return True
        else:
            # Parse stderr to classify the failure and suggest fixes.
            error_msg = result.stderr.strip()
            if "Permission denied" in error_msg:
                raise ValidationError(
                    f"SSH authentication failed for '{host}'\n"
                    "Check:\n"
                    "1. SSH key is added: ssh-add -l\n"
                    "2. Public key is on remote: cat ~/.ssh/authorized_keys\n"
                    "3. User/key in SSH config is correct"
                )
            elif "Connection refused" in error_msg:
                raise ValidationError(
                    f"SSH connection refused for '{host}'\n"
                    "Check:\n"
                    "1. SSH server is running on remote\n"
                    "2. Port 22 is not blocked by firewall"
                )
            elif "Connection timed out" in error_msg or "timeout" in error_msg.lower():
                raise ValidationError(
                    f"SSH connection timed out for '{host}'\n"
                    "Check:\n"
                    "1. Host is reachable (ping test)\n"
                    "2. Tailscale is connected\n"
                    "3. Network connectivity"
                )
            else:
                raise ValidationError(
                    f"SSH connection failed for '{host}': {error_msg}"
                )
    except subprocess.TimeoutExpired:
        raise ValidationError(
            f"SSH connection timed out for '{host}' (>{timeout}s)"
        )
    except Exception as e:
        # NOTE(review): this clause also catches and re-wraps the
        # ValidationErrors raised above, prefixing their messages —
        # confirm whether that double-wrapping is intended.
        raise ValidationError(f"Error testing SSH connection to '{host}': {e}")
def validate_tailscale_connection(host: str) -> bool:
    """
    Test Tailscale connectivity to host.

    Args:
        host: Host to check (machine name; dashes are also tried as dots
              to match IP-style entries)

    Returns:
        True if Tailscale connection active

    Raises:
        ValidationError: If Tailscale is not installed/running or the
            host is not present in the tailnet

    Example:
        >>> validate_tailscale_connection("web-01")
        True
    """
    try:
        # Check if tailscale is running
        result = subprocess.run(
            ["tailscale", "status"],
            capture_output=True,
            text=True,
            timeout=5
        )
        if result.returncode != 0:
            raise ValidationError(
                "Tailscale is not running\n"
                "Start Tailscale: sudo tailscale up"
            )
        # Match against whole whitespace-separated fields of each status
        # line. The previous plain substring test false-positived on any
        # host whose name merely contains the query (e.g. "web" matched
        # "web-01").
        candidates = {host, host.replace('-', '.')}
        for line in result.stdout.splitlines():
            if candidates.intersection(line.split()):
                return True
        raise ValidationError(
            f"Host '{host}' not found in Tailscale network\n"
            "Ensure host is:\n"
            "1. Connected to Tailscale\n"
            "2. In the same tailnet\n"
            "3. Not expired/offline"
        )
    except ValidationError:
        # Let our own, already-descriptive errors propagate unwrapped
        # (previously they were re-wrapped by the generic handler below).
        raise
    except FileNotFoundError:
        raise ValidationError(
            "Tailscale not installed\n"
            "Install: https://tailscale.com/download"
        )
    except subprocess.TimeoutExpired:
        raise ValidationError("Timeout checking Tailscale status")
    except Exception as e:
        raise ValidationError(f"Error checking Tailscale connection: {e}")
def validate_ssh_key(host: str) -> bool:
    """
    Check SSH key authentication is working.

    Args:
        host: Host to check

    Returns:
        True if SSH key auth works

    Raises:
        ValidationError: If key auth fails or the check times out

    Example:
        >>> validate_ssh_key("web-01")
        True
    """
    try:
        # Force key-only auth: BatchMode disables interactive prompts and
        # PasswordAuthentication=no rules out password fallback.
        result = subprocess.run(
            ["ssh", "-o", "BatchMode=yes",
             "-o", "PasswordAuthentication=no",
             "-o", "ConnectTimeout=5",
             host, "echo", "test"],
            capture_output=True,
            text=True,
            timeout=10
        )
        if result.returncode == 0:
            return True
        error_msg = result.stderr.strip()
        if "Permission denied" in error_msg:
            # BUGFIX: the previous version applied .format() to only the
            # last string literal (adjacent-literal concatenation binds
            # looser than the method call), so the '{}' in step 2 was
            # never substituted. f-strings fix both placeholders.
            raise ValidationError(
                f"SSH key authentication failed for '{host}'\n"
                "Fix:\n"
                "1. Add your SSH key: ssh-add ~/.ssh/id_ed25519\n"
                f"2. Copy public key to remote: ssh-copy-id {host}\n"
                f"3. Verify: ssh -v {host} 2>&1 | grep -i 'offering public key'"
            )
        raise ValidationError(
            f"SSH key validation failed for '{host}': {error_msg}"
        )
    except ValidationError:
        # Propagate our own errors without re-wrapping them below.
        raise
    except subprocess.TimeoutExpired:
        raise ValidationError(f"Timeout validating SSH key for '{host}'")
    except Exception as e:
        raise ValidationError(f"Error validating SSH key for '{host}': {e}")
def get_connection_diagnostics(host: str) -> Dict[str, Dict[str, object]]:
    """
    Comprehensive connection testing.

    Runs four independent probes (ping, SSH, Tailscale, SSH key) and
    collects every outcome instead of failing fast, so the caller gets
    the full picture in one call.

    Note: the annotation was `Dict[str, any]`, which referenced the
    builtin `any()` function rather than a type; it now states the real
    shape of the returned mapping.

    Args:
        host: Host to diagnose

    Returns:
        Dict with diagnostic results:
        {
            'ping': {'success': bool, 'message': str},
            'ssh': {'success': bool, 'message': str},
            'tailscale': {'success': bool, 'message': str},
            'ssh_key': {'success': bool, 'message': str}
        }

    Example:
        >>> diag = get_connection_diagnostics("web-01")
        >>> diag['ssh']['success']
        True
    """
    diagnostics = {}
    # Test 1: Ping (single packet, 2s wait).
    # NOTE(review): '-W 2' means seconds on Linux ping but milliseconds
    # on macOS ping — confirm the intended client platform.
    try:
        result = subprocess.run(
            ["ping", "-c", "1", "-W", "2", host],
            capture_output=True,
            timeout=3
        )
        diagnostics['ping'] = {
            'success': result.returncode == 0,
            'message': 'Host is reachable' if result.returncode == 0 else 'Host not reachable'
        }
    except Exception as e:
        diagnostics['ping'] = {'success': False, 'message': str(e)}
    # Tests 2-4 reuse the single-purpose validators; only the first line
    # of their multi-line error messages is kept for compact display.
    # Test 2: SSH connection
    try:
        validate_ssh_connection(host, timeout=5)
        diagnostics['ssh'] = {'success': True, 'message': 'SSH connection works'}
    except ValidationError as e:
        diagnostics['ssh'] = {'success': False, 'message': str(e).split('\n')[0]}
    # Test 3: Tailscale
    try:
        validate_tailscale_connection(host)
        diagnostics['tailscale'] = {'success': True, 'message': 'Tailscale connected'}
    except ValidationError as e:
        diagnostics['tailscale'] = {'success': False, 'message': str(e).split('\n')[0]}
    # Test 4: SSH key
    try:
        validate_ssh_key(host)
        diagnostics['ssh_key'] = {'success': True, 'message': 'SSH key authentication works'}
    except ValidationError as e:
        diagnostics['ssh_key'] = {'success': False, 'message': str(e).split('\n')[0]}
    return diagnostics
def main():
    """Smoke-test the connection validators against localhost.

    Prints each diagnostic probe's result; tolerates total failure so it
    can run on machines without ssh/tailscale configured.
    """
    print("Testing connection validators...\n")
    print("1. Testing connection diagnostics:")
    try:
        diag = get_connection_diagnostics("localhost")
        print(" Results:")
        for test, result in diag.items():
            # NOTE(review): both branches assign an empty string — the
            # status glyphs (e.g. check/cross marks) appear to have been
            # lost; confirm the intended markers.
            status = "" if result['success'] else ""
            print(f" {status} {test}: {result['message']}")
    except Exception as e:
        print(f" Error: {e}")
    print("\n✅ Connection validators tested")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
Host validators for Tailscale SSH Sync Agent.
Validates host configuration and availability.
"""
import subprocess
from typing import List, Dict, Optional
from pathlib import Path
import logging
from .parameter_validator import ValidationError
logger = logging.getLogger(__name__)
def validate_ssh_config(host: str, config_path: Optional[Path] = None) -> bool:
    """
    Check if host has SSH config entry.

    Args:
        host: Host name to check (compared against whole 'Host' aliases)
        config_path: Path to SSH config (default: ~/.ssh/config)

    Returns:
        True if host is in SSH config

    Raises:
        ValidationError: If the config file is missing or unreadable, or
            has no entry for the host

    Example:
        >>> validate_ssh_config("web-01")
        True
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'
    if not config_path.exists():
        raise ValidationError(
            f"SSH config file not found: {config_path}\n"
            "Create ~/.ssh/config with your host definitions"
        )
    # Parse SSH config for this host
    try:
        with open(config_path, 'r') as f:
            for line in f:
                line = line.strip()
                # 'Host' lines may list several aliases; compare against
                # whole tokens so 'web' does not match 'web-01' (the
                # previous substring test had that false positive).
                if line.lower().startswith('host ') and host in line.split()[1:]:
                    return True
    except IOError as e:
        raise ValidationError(f"Error reading SSH config: {e}")
    raise ValidationError(
        f"Host '{host}' not found in SSH config: {config_path}\n"
        "Add host to SSH config:\n"
        f"Host {host}\n"
        f"  HostName <IP_ADDRESS>\n"
        f"  User <USERNAME>"
    )
def validate_host_reachable(host: str, timeout: int = 5) -> bool:
    """
    Check if host is reachable via ping.

    Resolves the SSH-config alias to its real HostName first (via
    ``ssh -G``) so the ping targets the actual address.

    Args:
        host: Host name to check
        timeout: Timeout in seconds

    Returns:
        True if host is reachable

    Raises:
        ValidationError: If host is not reachable or the check times out

    Example:
        >>> validate_host_reachable("web-01", timeout=5)
        True
    """
    try:
        # Try to resolve via SSH config first: 'ssh -G' prints the fully
        # resolved client configuration without connecting.
        result = subprocess.run(
            ["ssh", "-G", host],
            capture_output=True,
            text=True,
            timeout=2
        )
        if result.returncode == 0:
            # Extract hostname from the resolved -G output; fall back to
            # the alias itself if no 'hostname' line is present.
            for line in result.stdout.split('\n'):
                if line.startswith('hostname '):
                    actual_host = line.split()[1]
                    break
            else:
                actual_host = host
        else:
            actual_host = host
        # Ping the host (single probe).
        # NOTE(review): '-W' is seconds on Linux ping but milliseconds on
        # macOS — confirm the intended client platform.
        ping_result = subprocess.run(
            ["ping", "-c", "1", "-W", str(timeout), actual_host],
            capture_output=True,
            text=True,
            timeout=timeout + 1
        )
        if ping_result.returncode == 0:
            return True
        else:
            raise ValidationError(
                f"Host '{host}' ({actual_host}) is not reachable\n"
                "Check:\n"
                "1. Host is powered on\n"
                "2. Tailscale is connected\n"
                "3. Network connectivity"
            )
    except subprocess.TimeoutExpired:
        raise ValidationError(f"Timeout checking host '{host}' (>{timeout}s)")
    except Exception as e:
        # NOTE(review): this also catches and re-wraps the ValidationError
        # raised above — confirm whether the prefixed message is intended.
        raise ValidationError(f"Error checking host '{host}': {e}")
def validate_group_members(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Ensure group has valid members.

    Args:
        group: Group name
        groups_config: Groups configuration dict

    Returns:
        List of valid hosts in group

    Raises:
        ValidationError: If the group is unknown, malformed (non-list
            value), or empty

    Example:
        >>> groups = {'production': ['web-01', 'db-01']}
        >>> validate_group_members('production', groups)
        ['web-01', 'db-01']
    """
    if group not in groups_config:
        raise ValidationError(
            f"Group '{group}' not found in configuration\n"
            f"Available groups: {', '.join(groups_config.keys())}"
        )
    members = groups_config[group]
    # Check the type before emptiness so a malformed non-list value is
    # reported as such instead of being misreported as "no members".
    if not isinstance(members, list):
        raise ValidationError(
            f"Invalid group configuration for '{group}': members must be a list"
        )
    if not members:
        raise ValidationError(
            f"Group '{group}' has no members\n"
            f"Add hosts to group with: sshsync gadd {group}"
        )
    return members
def get_invalid_hosts(hosts: List[str], config_path: Optional[Path] = None) -> List[str]:
    """
    Find hosts without valid SSH config.

    Args:
        hosts: List of host names
        config_path: Path to SSH config (default: ~/.ssh/config)

    Returns:
        List of hosts without valid config; every host is returned when
        the config file is missing or unreadable

    Example:
        >>> get_invalid_hosts(["web-01", "nonexistent"])
        ["nonexistent"]
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'
    if not config_path.exists():
        # Without a config file, no host can be considered valid.
        return hosts
    known = set()
    try:
        with open(config_path, 'r') as f:
            for raw in f:
                stripped = raw.strip()
                if not stripped.lower().startswith('host '):
                    continue
                alias = stripped.split(maxsplit=1)[1]
                # Skip wildcard patterns - they are not concrete hosts.
                if '*' not in alias and '?' not in alias:
                    known.add(alias)
    except IOError:
        return hosts
    return [h for h in hosts if h not in known]
def main():
    """Smoke-test the host validators against localhost.

    Tolerates a missing SSH config entry for localhost by printing the
    first line of the validation error instead of failing.
    """
    print("Testing host validators...\n")
    print("1. Testing validate_ssh_config():")
    try:
        validate_ssh_config("localhost")
        print(" ✓ localhost has SSH config")
    except ValidationError as e:
        # chr(10) is '\n': show only the first line of the message.
        print(f" Note: {e.args[0].split(chr(10))[0]}")
    print("\n2. Testing get_invalid_hosts():")
    test_hosts = ["localhost", "nonexistent-host-12345"]
    invalid = get_invalid_hosts(test_hosts)
    print(f" Invalid hosts: {invalid}")
    print("\n✅ Host validators tested")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,363 @@
#!/usr/bin/env python3
"""
Parameter validators for Tailscale SSH Sync Agent.
Validates user inputs before making operations.
"""
from typing import List, Optional
from pathlib import Path
import re
import logging
logger = logging.getLogger(__name__)
class ValidationError(Exception):
    """Raised when validation fails.

    Raised by every validator in this package with a human-readable,
    often multi-line message: the first line states what failed and the
    remaining lines suggest how to fix it.
    """
    pass
def validate_host(host: str, valid_hosts: Optional[List[str]] = None) -> str:
    """
    Validate host parameter.

    Args:
        host: Host name or alias
        valid_hosts: List of valid hosts (None to skip check)

    Returns:
        str: Validated and normalized (stripped) host name; when a
        case-insensitive match is found, the canonical entry from
        valid_hosts is returned

    Raises:
        ValidationError: If host is empty, malformed, or not in valid_hosts

    Example:
        >>> validate_host("web-01")
        "web-01"
        >>> validate_host("web-01", ["web-01", "web-02"])
        "web-01"
    """
    if not host:
        raise ValidationError("Host cannot be empty")
    if not isinstance(host, str):
        raise ValidationError(f"Host must be string, got {type(host)}")
    host = host.strip()
    # Allow only letters, digits, dots, dashes, and underscores.
    if re.match(r'^[a-zA-Z0-9._-]+$', host) is None:
        raise ValidationError(
            f"Invalid host name format: {host}\n"
            "Host names must contain only letters, numbers, dots, dashes, and underscores"
        )
    if valid_hosts:
        if host in valid_hosts:
            return host
        # Fall back to case-insensitive matching, returning the
        # canonical spelling from valid_hosts.
        lowered = host.lower()
        for candidate in valid_hosts:
            if lowered == candidate.lower():
                return candidate
        # Unknown host - suggest entries sharing the leading characters.
        suggestions = [h for h in valid_hosts if host[:3].lower() in h.lower()]
        raise ValidationError(
            f"Invalid host: {host}\n"
            f"Valid options: {', '.join(valid_hosts[:10])}\n"
            + (f"Did you mean: {', '.join(suggestions[:3])}?" if suggestions else "")
        )
    return host
def validate_group(group: str, valid_groups: Optional[List[str]] = None) -> str:
    """
    Validate group parameter.

    Args:
        group: Group name
        valid_groups: List of valid groups (None to skip check)

    Returns:
        str: Validated group name (stripped and lowercased)

    Raises:
        ValidationError: If group is empty, malformed, or not in valid_groups

    Example:
        >>> validate_group("production")
        "production"
        >>> validate_group("prod", ["production", "development"])
        ValidationError: Invalid group: prod
    """
    if not group:
        raise ValidationError("Group cannot be empty")
    if not isinstance(group, str):
        raise ValidationError(f"Group must be string, got {type(group)}")
    # Normalize: groups are always compared in lowercase.
    group = group.strip().lower()
    if re.match(r'^[a-z0-9_-]+$', group) is None:
        raise ValidationError(
            f"Invalid group name format: {group}\n"
            "Group names must contain only lowercase letters, numbers, dashes, and underscores"
        )
    if valid_groups and group not in valid_groups:
        # Suggest groups sharing the first few characters.
        suggestions = [g for g in valid_groups if group[:3] in g]
        raise ValidationError(
            f"Invalid group: {group}\n"
            f"Valid groups: {', '.join(valid_groups)}\n"
            + (f"Did you mean: {', '.join(suggestions[:3])}?" if suggestions else "")
        )
    return group
def validate_path_exists(path: str, must_be_file: bool = False,
                         must_be_dir: bool = False) -> Path:
    """
    Validate path exists and is accessible.

    Args:
        path: Path to validate (``~`` expanded, symlinks resolved)
        must_be_file: If True, path must be a file
        must_be_dir: If True, path must be a directory

    Returns:
        Path: Validated, fully resolved Path object

    Raises:
        ValidationError: If path is empty, missing, or of the wrong kind

    Example:
        >>> validate_path_exists("/tmp", must_be_dir=True)
        Path('/tmp')
        >>> validate_path_exists("/nonexistent")
        ValidationError: Path does not exist: /nonexistent
    """
    if not path:
        raise ValidationError("Path cannot be empty")
    resolved = Path(path).expanduser().resolve()
    if not resolved.exists():
        raise ValidationError(
            f"Path does not exist: {path}\n"
            f"Resolved to: {resolved}"
        )
    if must_be_file and not resolved.is_file():
        raise ValidationError(f"Path must be a file: {path}")
    if must_be_dir and not resolved.is_dir():
        raise ValidationError(f"Path must be a directory: {path}")
    return resolved
def validate_timeout(timeout: int, min_timeout: int = 1,
                     max_timeout: int = 600) -> int:
    """
    Validate timeout parameter.

    Args:
        timeout: Timeout in seconds (must be an int)
        min_timeout: Minimum allowed timeout
        max_timeout: Maximum allowed timeout

    Returns:
        int: Validated timeout

    Raises:
        ValidationError: If timeout is not an integer or out of range

    Example:
        >>> validate_timeout(10)
        10
        >>> validate_timeout(0)
        ValidationError: Timeout must be between 1 and 600 seconds
    """
    if not isinstance(timeout, int):
        raise ValidationError(f"Timeout must be integer, got {type(timeout)}")
    if not min_timeout <= timeout <= max_timeout:
        if timeout < min_timeout:
            raise ValidationError(
                f"Timeout too low: {timeout}s (minimum: {min_timeout}s)"
            )
        raise ValidationError(
            f"Timeout too high: {timeout}s (maximum: {max_timeout}s)"
        )
    return timeout
def validate_command(command: str, allow_dangerous: bool = False) -> str:
    """
    Basic command safety validation.

    Args:
        command: Command to validate
        allow_dangerous: If False, block potentially dangerous commands

    Returns:
        str: Validated command (whitespace-stripped)

    Raises:
        ValidationError: If command is empty, not a string, or matches a
            known-dangerous pattern (unless allow_dangerous=True)

    Example:
        >>> validate_command("ls -la")
        "ls -la"
        >>> validate_command("rm -rf /", allow_dangerous=False)
        ValidationError: Potentially dangerous command blocked: rm -rf
    """
    if not command:
        raise ValidationError("Command cannot be empty")
    if not isinstance(command, str):
        raise ValidationError(f"Command must be string, got {type(command)}")
    command = command.strip()
    if not allow_dangerous:
        # Each entry is (regex, human-readable description).
        dangerous_patterns = [
            (r'\brm\s+-rf\s+/', "rm -rf on root directory"),
            (r'\bmkfs\.', "filesystem formatting"),
            (r'\bdd\s+.*of=/dev/', "disk writing with dd"),
            # BUGFIX: the previous raw pattern ':(){:|:&};:' treated '()'
            # as an empty regex group, so it never matched an actual fork
            # bomb; the metacharacters must be escaped.
            (re.escape(':(){:|:&};:'), "fork bomb"),
            (r'>\s*/dev/sd[a-z]', "direct disk writing"),
        ]
        for pattern, description in dangerous_patterns:
            if re.search(pattern, command, re.IGNORECASE):
                raise ValidationError(
                    f"Potentially dangerous command blocked: {description}\n"
                    f"Command: {command}\n"
                    "Use allow_dangerous=True if you really want to execute this"
                )
    return command
def validate_hosts_list(hosts: List[str], valid_hosts: Optional[List[str]] = None) -> List[str]:
    """
    Validate a list of hosts.

    Each entry is checked with validate_host(); all failures are gathered
    and reported together rather than stopping at the first bad entry.

    Args:
        hosts: List of host names
        valid_hosts: List of valid hosts (None to skip check)

    Returns:
        List[str]: Validated host names

    Raises:
        ValidationError: If the list is empty/not a list, or any host fails

    Example:
        >>> validate_hosts_list(["web-01", "web-02"])
        ["web-01", "web-02"]
    """
    if not hosts:
        raise ValidationError("Hosts list cannot be empty")
    if not isinstance(hosts, list):
        raise ValidationError(f"Hosts must be list, got {type(hosts)}")
    checked, problems = [], []
    for entry in hosts:
        try:
            checked.append(validate_host(entry, valid_hosts))
        except ValidationError as exc:
            problems.append(str(exc))
    if problems:
        raise ValidationError(
            f"Invalid hosts in list:\n" + "\n".join(problems)
        )
    return checked
def main():
    """Smoke-test the parameter validators (happy paths and rejections).

    Each section exercises one validator with an accepted input and,
    where relevant, an input that must be rejected.
    """
    print("Testing parameter validators...\n")
    # Test host validation
    print("1. Testing validate_host():")
    try:
        host = validate_host("web-01", ["web-01", "web-02", "db-01"])
        print(f" ✓ Valid host: {host}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")
    try:
        host = validate_host("invalid-host", ["web-01", "web-02"])
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        # chr(10) is '\n': show only the first line of the message.
        print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")
    # Test group validation
    print("\n2. Testing validate_group():")
    try:
        group = validate_group("production", ["production", "development"])
        print(f" ✓ Valid group: {group}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")
    # Test path validation
    print("\n3. Testing validate_path_exists():")
    try:
        path = validate_path_exists("/tmp", must_be_dir=True)
        print(f" ✓ Valid path: {path}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")
    # Test timeout validation
    print("\n4. Testing validate_timeout():")
    try:
        timeout = validate_timeout(10)
        print(f" ✓ Valid timeout: {timeout}s")
    except ValidationError as e:
        print(f" ✗ Error: {e}")
    try:
        timeout = validate_timeout(0)
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")
    # Test command validation
    print("\n5. Testing validate_command():")
    try:
        cmd = validate_command("ls -la")
        print(f" ✓ Safe command: {cmd}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")
    try:
        cmd = validate_command("rm -rf /", allow_dangerous=False)
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        print(f" ✓ Correctly blocked: {e.args[0].split(chr(10))[0]}")
    print("\n✅ All parameter validators tested")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,445 @@
#!/usr/bin/env python3
"""
Workflow executor for Tailscale SSH Sync Agent.
Common multi-machine workflow automation.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional
import time
import logging
# Add utils to path
sys.path.insert(0, str(Path(__file__).parent))
from utils.helpers import format_duration, get_timestamp
from sshsync_wrapper import execute_on_group, execute_on_host, push_to_hosts
from load_balancer import get_group_capacity
logger = logging.getLogger(__name__)
def deploy_workflow(code_path: str,
                    staging_group: str,
                    prod_group: str,
                    run_tests: bool = True) -> Dict:
    """
    Full deployment pipeline: staging → test → production.

    Stops at the first failed gating stage; partial results for the
    stages that did run are always returned under 'stages'.

    Args:
        code_path: Path to code to deploy
        staging_group: Staging server group
        prod_group: Production server group
        run_tests: Whether to run tests on staging

    Returns:
        Dict with deployment results: per-stage results under 'stages',
        overall 'success', 'start_time', 'duration' (set on completion or
        exception, not on early stage failure), and 'error' naming the
        failed stage, if any.

    Example:
        >>> result = deploy_workflow("./dist", "staging", "production")
        >>> result['success']
        True
        >>> result['duration']
        "12m 45s"
    """
    start_time = time.time()
    results = {
        'stages': {},
        'success': False,
        'start_time': get_timestamp()
    }
    try:
        # Stage 1: Deploy to staging
        logger.info("Stage 1: Deploying to staging...")
        stage1 = push_to_hosts(
            local_path=code_path,
            remote_path="/var/www/app",
            group=staging_group,
            recurse=True
        )
        results['stages']['staging_deploy'] = stage1
        if not stage1.get('success'):
            results['error'] = 'Staging deployment failed'
            return results
        # Build on staging
        logger.info("Building on staging...")
        build_result = execute_on_group(
            staging_group,
            "cd /var/www/app && npm run build",
            timeout=300
        )
        results['stages']['staging_build'] = build_result
        if not build_result.get('success'):
            results['error'] = 'Staging build failed'
            return results
        # Stage 2: Run tests (if enabled)
        if run_tests:
            logger.info("Stage 2: Running tests...")
            test_result = execute_on_group(
                staging_group,
                "cd /var/www/app && npm test",
                timeout=600
            )
            results['stages']['tests'] = test_result
            if not test_result.get('success'):
                results['error'] = 'Tests failed on staging'
                return results
        # Stage 3: Validation — the health-check result is recorded but
        # does not gate the production rollout (the '|| echo' keeps the
        # command itself from failing).
        logger.info("Stage 3: Validating staging...")
        health_result = execute_on_group(
            staging_group,
            "curl -f http://localhost:3000/health || echo 'Health check failed'",
            timeout=10
        )
        results['stages']['staging_validation'] = health_result
        # Stage 4: Deploy to production
        logger.info("Stage 4: Deploying to production...")
        prod_deploy = push_to_hosts(
            local_path=code_path,
            remote_path="/var/www/app",
            group=prod_group,
            recurse=True
        )
        results['stages']['production_deploy'] = prod_deploy
        if not prod_deploy.get('success'):
            results['error'] = 'Production deployment failed'
            return results
        # Build and restart on production.
        # NOTE(review): unlike the earlier stages, a failed production
        # build/restart is recorded but does not abort the workflow or
        # clear 'success' — confirm this is intended.
        logger.info("Building and restarting production...")
        prod_build = execute_on_group(
            prod_group,
            "cd /var/www/app && npm run build && pm2 restart app",
            timeout=300
        )
        results['stages']['production_build'] = prod_build
        # Stage 5: Production verification (recorded, not gating)
        logger.info("Stage 5: Verifying production...")
        prod_health = execute_on_group(
            prod_group,
            "curl -f http://localhost:3000/health",
            timeout=15
        )
        results['stages']['production_verification'] = prod_health
        # Success!
        results['success'] = True
        results['duration'] = format_duration(time.time() - start_time)
        return results
    except Exception as e:
        logger.error(f"Deployment workflow error: {e}")
        results['error'] = str(e)
        results['duration'] = format_duration(time.time() - start_time)
        return results
def backup_workflow(hosts: List[str],
                    backup_paths: List[str],
                    destination: str) -> Dict:
    """
    Backup files from multiple hosts.

    Each (host, path) pair is pulled into a timestamped directory under
    *destination*; a host counts as backed up only when every one of its
    paths transferred successfully.

    Args:
        hosts: List of hosts to backup from
        backup_paths: Paths to backup on each host
        destination: Local destination directory

    Returns:
        Dict with backup results: per-host transfer results under
        'hosts', overall 'success', 'backed_up_hosts' count, 'duration'.

    Example:
        >>> result = backup_workflow(
        ...     ["db-01", "db-02"],
        ...     ["/var/lib/mysql"],
        ...     "./backups"
        ... )
        >>> result['backed_up_hosts']
        2
    """
    from sshsync_wrapper import pull_from_host
    start_time = time.time()
    results = {
        'hosts': {},
        'success': True,
        'backed_up_hosts': 0
    }
    for host in hosts:
        host_results = []
        for backup_path in backup_paths:
            # Create timestamped backup directory so repeated runs never
            # overwrite earlier backups.
            # NOTE(review): two paths pulled within the same second share
            # one directory — confirm that is acceptable.
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            host_dest = f"{destination}/{host}_{timestamp}"
            result = pull_from_host(
                host=host,
                remote_path=backup_path,
                local_path=host_dest,
                recurse=True
            )
            host_results.append(result)
            if not result.get('success'):
                results['success'] = False
        results['hosts'][host] = host_results
        # Count the host only if all of its paths succeeded.
        if all(r.get('success') for r in host_results):
            results['backed_up_hosts'] += 1
    results['duration'] = format_duration(time.time() - start_time)
    return results
def sync_workflow(source_host: str,
                  target_group: str,
                  paths: List[str]) -> Dict:
    """
    Sync files from one host to many.

    Each path is pulled from ``source_host`` into a temporary staging
    directory, then pushed out to every host in ``target_group``.

    Args:
        source_host: Host to pull from
        target_group: Group to push to
        paths: Paths to sync

    Returns:
        Dict with sync results

    Example:
        >>> result = sync_workflow(
        ...     "master-db",
        ...     "replica-dbs",
        ...     ["/var/lib/mysql/config"]
        ... )
        >>> result['success']
        True
    """
    from sshsync_wrapper import pull_from_host, push_to_hosts
    import tempfile
    import shutil
    started = time.time()
    outcome = {'paths': {}, 'success': True}
    # Staging area is removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as staging:
        for remote_path in paths:
            local_copy = f"{staging}/{Path(remote_path).name}"
            # Stage a local copy from the source host first.
            pulled = pull_from_host(
                host=source_host,
                remote_path=remote_path,
                local_path=local_copy,
                recurse=True
            )
            if not pulled.get('success'):
                outcome['paths'][remote_path] = {
                    'success': False,
                    'error': 'Pull from source failed'
                }
                outcome['success'] = False
                continue
            # Fan the staged copy out to every host in the target group.
            pushed = push_to_hosts(
                local_path=local_copy,
                remote_path=remote_path,
                group=target_group,
                recurse=True
            )
            pushed_ok = pushed.get('success', False)
            outcome['paths'][remote_path] = {
                'pull': pulled,
                'push': pushed,
                'success': pushed_ok
            }
            if not pushed_ok:
                outcome['success'] = False
    outcome['duration'] = format_duration(time.time() - started)
    return outcome
def rolling_restart(group: str,
                    service_name: str,
                    wait_between: int = 30) -> Dict:
    """
    Zero-downtime rolling restart of a service across group.

    Restarts the service one host at a time, health-checking each host
    before moving on, with a configurable pause between hosts.

    Args:
        group: Group to restart
        service_name: Service name (e.g., "nginx", "app")
        wait_between: Seconds to wait between restarts

    Returns:
        Dict with restart results:
        - 'hosts': per-host restart/health results
        - 'restarted_count' / 'failed_count': tallies
        - 'success': True only if every host restarted cleanly
        - 'duration': human-readable elapsed time

    Example:
        >>> result = rolling_restart("web-servers", "nginx")
        >>> result['restarted_count']
        3
    """
    from utils.helpers import parse_sshsync_config
    start_time = time.time()
    groups_config = parse_sshsync_config()
    hosts = groups_config.get(group, [])
    if not hosts:
        return {
            'success': False,
            'error': f'Group {group} not found or empty'
        }
    results = {
        'hosts': {},
        'restarted_count': 0,
        'failed_count': 0,
        'success': True
    }
    for index, host in enumerate(hosts):
        logger.info(f"Restarting {service_name} on {host}...")
        # Restart service (systemd first, SysV fallback)
        restart_result = execute_on_host(
            host,
            f"sudo systemctl restart {service_name} || sudo service {service_name} restart",
            timeout=30
        )
        # Health check
        time.sleep(5)  # Give the service a moment to come up
        health_result = execute_on_host(
            host,
            f"sudo systemctl is-active {service_name} || sudo service {service_name} status",
            timeout=10
        )
        success = restart_result.get('success') and health_result.get('success')
        results['hosts'][host] = {
            'restart': restart_result,
            'health': health_result,
            'success': success
        }
        if success:
            results['restarted_count'] += 1
            logger.info(f"{host} restarted successfully")
        else:
            results['failed_count'] += 1
            results['success'] = False
            logger.error(f"{host} restart failed")
        # Wait before the next restart (skip after the last host).
        # Compare by position, not value: `host != hosts[-1]` would skip
        # the wait for every occurrence of a duplicated hostname.
        if index < len(hosts) - 1:
            time.sleep(wait_between)
    results['duration'] = format_duration(time.time() - start_time)
    return results
def health_check_workflow(group: str,
                          endpoint: str = "/health",
                          timeout: int = 10,
                          port: int = 3000) -> Dict:
    """
    Check health endpoint across group.

    Curls ``http://localhost:<port><endpoint>`` on each host in the
    group and reports per-host health plus aggregate counts.

    Args:
        group: Group to check
        endpoint: Health endpoint path
        timeout: Request timeout
        port: Local port the service listens on (default 3000)

    Returns:
        Dict with health check results:
        - 'hosts': per-host {'healthy': bool, 'response': http code}
        - 'healthy_count' / 'unhealthy_count': tallies
        - 'success': True only if every host is healthy

    Example:
        >>> result = health_check_workflow("production", "/health")
        >>> result['healthy_count']
        3
    """
    from utils.helpers import parse_sshsync_config
    groups_config = parse_sshsync_config()
    hosts = groups_config.get(group, [])
    if not hosts:
        return {
            'success': False,
            'error': f'Group {group} not found or empty'
        }
    results = {
        'hosts': {},
        'healthy_count': 0,
        'unhealthy_count': 0
    }
    for host in hosts:
        # -w '%{http_code}' prints only the status code on success
        health_result = execute_on_host(
            host,
            f"curl -f -s -o /dev/null -w '%{{http_code}}' http://localhost:{port}{endpoint}",
            timeout=timeout
        )
        response = health_result.get('stdout', '').strip()
        # Strict equality: a substring test ('200' in stdout) would also
        # match codes like "1200" or stray output around the code.
        is_healthy = (
            health_result.get('success') and
            response == '200'
        )
        results['hosts'][host] = {
            'healthy': is_healthy,
            'response': response
        }
        if is_healthy:
            results['healthy_count'] += 1
        else:
            results['unhealthy_count'] += 1
    results['success'] = results['unhealthy_count'] == 0
    return results
def main():
    """Smoke-test entry point for the workflow executor module."""
    # Real workflows would touch live hosts, so only announce readiness.
    banner = (
        "Testing workflow executor...\n",
        "Note: Workflow executor requires configured hosts and groups.",
        "Tests would execute real operations, so showing dry-run simulations.\n",
        "✅ Workflow executor ready",
    )
    for message in banner:
        print(message)
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    main()