Initial commit
This commit is contained in:
378
scripts/load_balancer.py
Normal file
378
scripts/load_balancer.py
Normal file
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Load balancer for Tailscale SSH Sync Agent.
|
||||
Intelligent task distribution based on machine resources.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
# Add utils to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from utils.helpers import parse_cpu_load, parse_memory_usage, parse_disk_usage, calculate_load_score, classify_load_status
|
||||
from sshsync_wrapper import execute_on_host
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MachineMetrics:
    """Resource metrics for a machine."""
    host: str          # hostname as known to SSH/sshsync
    cpu_pct: float     # CPU utilization estimate, percent (derived from load average)
    mem_pct: float     # memory usage, percent
    disk_pct: float    # root-filesystem usage, percent
    load_score: float  # composite score from calculate_load_score (lower = less loaded)
    status: str        # label from classify_load_status (e.g. 'low'/'moderate'/'high')
|
||||
|
||||
|
||||
def get_machine_load(host: str, timeout: int = 10,
                     assumed_cores: int = 4) -> Optional[MachineMetrics]:
    """
    Get CPU, memory, disk metrics for a machine.

    Args:
        host: Host to check
        timeout: Command timeout (seconds, applied per remote command)
        assumed_cores: Core count used to normalize the 1-minute load
            average into a CPU percentage. Previously hard-coded to 4;
            now a parameter so callers can match the real host size.

    Returns:
        MachineMetrics object or None on failure

    Example:
        >>> metrics = get_machine_load("web-01")
        >>> metrics.cpu_pct
        45.2
        >>> metrics.load_score
        0.49
    """
    try:
        # Get CPU load from uptime's load averages
        cpu_result = execute_on_host(host, "uptime", timeout=timeout)
        cpu_data = {}
        if cpu_result.get('success'):
            cpu_data = parse_cpu_load(cpu_result['stdout'])

        # Get memory usage (Linux `free`, falling back to macOS `vm_stat`)
        mem_result = execute_on_host(host, "free -m 2>/dev/null || vm_stat", timeout=timeout)
        mem_data = {}
        if mem_result.get('success'):
            mem_data = parse_memory_usage(mem_result['stdout'])

        # Get disk usage for the root filesystem
        disk_result = execute_on_host(host, "df -h / | tail -1", timeout=timeout)
        disk_data = {}
        if disk_result.get('success'):
            disk_data = parse_disk_usage(disk_result['stdout'])

        # CPU: normalize the 1-min load average by the assumed core count.
        # Falls back to a neutral 50% when the metric is unavailable.
        cores = max(assumed_cores, 1)  # guard against a zero/negative argument
        cpu_pct = (cpu_data.get('load_1min', 0) / cores) * 100 if cpu_data else 50.0

        # Memory/disk: parsers already yield percentages; neutral 50% default.
        mem_pct = mem_data.get('use_pct', 50.0)
        disk_pct = disk_data.get('use_pct', 50.0)

        # Composite score and human-readable classification
        score = calculate_load_score(cpu_pct, mem_pct, disk_pct)
        status = classify_load_status(score)

        return MachineMetrics(
            host=host,
            cpu_pct=cpu_pct,
            mem_pct=mem_pct,
            disk_pct=disk_pct,
            load_score=score,
            status=status
        )

    except Exception as e:
        # Best-effort metric gathering: report failure as None, never raise.
        logger.error(f"Error getting load for {host}: {e}")
        return None
|
||||
|
||||
|
||||
def select_optimal_host(candidates: List[str],
                        prefer_group: Optional[str] = None,
                        timeout: int = 10) -> Tuple[Optional[str], Optional[MachineMetrics]]:
    """
    Pick best host from candidates based on load.

    Args:
        candidates: List of candidate hosts
        prefer_group: Prefer hosts from this group if available
        timeout: Timeout for metric gathering

    Returns:
        Tuple of (selected_host, metrics); (None, None) when there are
        no candidates or no metrics could be collected.

    Example:
        >>> host, metrics = select_optimal_host(["web-01", "web-02", "web-03"])
        >>> host
        "web-03"
        >>> metrics.load_score
        0.28
    """
    if not candidates:
        return None, None

    # Get metrics for all candidates; hosts whose metric collection
    # fails are silently dropped from consideration.
    metrics_list: List[MachineMetrics] = []

    for host in candidates:
        metrics = get_machine_load(host, timeout=timeout)
        if metrics:
            metrics_list.append(metrics)

    if not metrics_list:
        logger.warning("No valid metrics collected from candidates")
        return None, None

    # Sort by load score (lower is better)
    metrics_list.sort(key=lambda m: m.load_score)

    # If prefer_group specified, prioritize those hosts if load is similar
    if prefer_group:
        # Lazy import keeps the common (no-preference) path cheap
        from utils.helpers import parse_sshsync_config, get_groups_for_host
        groups_config = parse_sshsync_config()

        # Find hosts in preferred group. Filtering preserves the sorted
        # order, so preferred_metrics is still best-first.
        preferred_metrics = [
            m for m in metrics_list
            if prefer_group in get_groups_for_host(m.host, groups_config)
        ]

        # Use preferred if load score within 20% of absolute best
        if preferred_metrics:
            best_score = metrics_list[0].load_score
            for m in preferred_metrics:
                if m.load_score <= best_score * 1.2:
                    return m.host, m

    # Return absolute best
    best = metrics_list[0]
    return best.host, best
|
||||
|
||||
|
||||
def get_group_capacity(group: str, timeout: int = 10) -> Dict:
    """
    Get aggregate capacity of a group.

    Args:
        group: Group name
        timeout: Timeout for metric gathering

    Returns:
        Dict with aggregate metrics:
        {
            'hosts': List[MachineMetrics],
            'total_hosts': int,      # hosts that actually returned metrics
            'available_hosts': int,  # hosts configured in the group
            'avg_cpu': float,
            'avg_mem': float,
            'avg_disk': float,
            'avg_load_score': float,
            'total_capacity': str    # descriptive
        }
        On failure, a dict with 'error' and an empty 'hosts' list.

    Example:
        >>> capacity = get_group_capacity("production")
        >>> capacity['avg_load_score']
        0.45
    """
    # Imported at call time, mirroring select_optimal_host's lazy import
    from utils.helpers import parse_sshsync_config

    groups_config = parse_sshsync_config()
    group_hosts = groups_config.get(group, [])

    if not group_hosts:
        return {
            'error': f'Group {group} not found or has no members',
            'hosts': []
        }

    # Get metrics for all hosts in group; unreachable hosts are skipped
    metrics_list: List[MachineMetrics] = []

    for host in group_hosts:
        metrics = get_machine_load(host, timeout=timeout)
        if metrics:
            metrics_list.append(metrics)

    if not metrics_list:
        return {
            'error': f'Could not get metrics for any hosts in {group}',
            'hosts': []
        }

    # Calculate aggregates: plain arithmetic means over reachable hosts only
    avg_cpu = sum(m.cpu_pct for m in metrics_list) / len(metrics_list)
    avg_mem = sum(m.mem_pct for m in metrics_list) / len(metrics_list)
    avg_disk = sum(m.disk_pct for m in metrics_list) / len(metrics_list)
    avg_score = sum(m.load_score for m in metrics_list) / len(metrics_list)

    # Determine overall capacity description from the average score
    if avg_score < 0.4:
        capacity_desc = "High capacity available"
    elif avg_score < 0.7:
        capacity_desc = "Moderate capacity"
    else:
        capacity_desc = "Limited capacity"

    return {
        'group': group,
        'hosts': metrics_list,
        'total_hosts': len(metrics_list),
        'available_hosts': len(group_hosts),
        'avg_cpu': avg_cpu,
        'avg_mem': avg_mem,
        'avg_disk': avg_disk,
        'avg_load_score': avg_score,
        'total_capacity': capacity_desc
    }
|
||||
|
||||
|
||||
def distribute_tasks(tasks: List[Dict], hosts: List[str],
                     timeout: int = 10) -> Dict[str, List[Dict]]:
    """
    Distribute multiple tasks optimally across hosts.

    Args:
        tasks: List of task dicts (each with 'command', 'priority', etc);
            an optional 'weight' key (default 1) drives the balancing
        hosts: Available hosts
        timeout: Timeout for metric gathering

    Returns:
        Dict mapping hosts to assigned tasks

    Algorithm:
        - Get current load for all hosts
        - Assign tasks to least loaded hosts
        - Balance by estimated task weight
        (greedy, heaviest-task-first heuristic)

    Example:
        >>> tasks = [
        ...     {'command': 'npm run build', 'weight': 3},
        ...     {'command': 'npm test', 'weight': 2}
        ... ]
        >>> distribution = distribute_tasks(tasks, ["web-01", "web-02"])
        >>> distribution["web-01"]
        [{'command': 'npm run build', 'weight': 3}]
    """
    if not tasks or not hosts:
        return {}

    # Get current load for all hosts; unreachable hosts receive no tasks
    host_metrics = {}
    for host in hosts:
        metrics = get_machine_load(host, timeout=timeout)
        if metrics:
            host_metrics[host] = metrics

    if not host_metrics:
        logger.error("No valid host metrics available")
        return {}

    # Initialize assignment and a simulated running load per host
    assignment: Dict[str, List[Dict]] = {host: [] for host in host_metrics.keys()}
    host_loads = {host: m.load_score for host, m in host_metrics.items()}

    # Sort tasks by weight (descending) to assign heavy tasks first
    sorted_tasks = sorted(
        tasks,
        key=lambda t: t.get('weight', 1),
        reverse=True
    )

    # Assign each task to least loaded host
    for task in sorted_tasks:
        # Find host with minimum current (simulated) load
        min_host = min(host_loads.keys(), key=lambda h: host_loads[h])

        # Assign task
        assignment[min_host].append(task)

        # Update simulated load (add task weight normalized)
        task_weight = task.get('weight', 1)
        host_loads[min_host] += (task_weight * 0.1)  # 0.1 = scaling factor

    return assignment
|
||||
|
||||
|
||||
def format_load_report(metrics: MachineMetrics, compare_to_avg: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Format load metrics as human-readable report.
|
||||
|
||||
Args:
|
||||
metrics: Machine metrics
|
||||
compare_to_avg: Optional dict with avg_cpu, avg_mem, avg_disk for comparison
|
||||
|
||||
Returns:
|
||||
Formatted report string
|
||||
|
||||
Example:
|
||||
>>> metrics = MachineMetrics('web-01', 45, 60, 40, 0.49, 'moderate')
|
||||
>>> print(format_load_report(metrics))
|
||||
web-01: Load Score: 0.49 (moderate)
|
||||
CPU: 45.0% | Memory: 60.0% | Disk: 40.0%
|
||||
"""
|
||||
lines = [
|
||||
f"{metrics.host}: Load Score: {metrics.load_score:.2f} ({metrics.status})",
|
||||
f" CPU: {metrics.cpu_pct:.1f}% | Memory: {metrics.mem_pct:.1f}% | Disk: {metrics.disk_pct:.1f}%"
|
||||
]
|
||||
|
||||
if compare_to_avg:
|
||||
cpu_vs = metrics.cpu_pct - compare_to_avg.get('avg_cpu', 0)
|
||||
mem_vs = metrics.mem_pct - compare_to_avg.get('avg_mem', 0)
|
||||
disk_vs = metrics.disk_pct - compare_to_avg.get('avg_disk', 0)
|
||||
|
||||
comparisons = []
|
||||
if abs(cpu_vs) > 10:
|
||||
comparisons.append(f"CPU {'+' if cpu_vs > 0 else ''}{cpu_vs:.0f}% vs avg")
|
||||
if abs(mem_vs) > 10:
|
||||
comparisons.append(f"Mem {'+' if mem_vs > 0 else ''}{mem_vs:.0f}% vs avg")
|
||||
if abs(disk_vs) > 10:
|
||||
comparisons.append(f"Disk {'+' if disk_vs > 0 else ''}{disk_vs:.0f}% vs avg")
|
||||
|
||||
if comparisons:
|
||||
lines.append(f" vs Average: {' | '.join(comparisons)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
    """Smoke-test load balancer sorting and formatting with simulated metrics."""
    print("Testing load balancer...\n")

    print("1. Testing select_optimal_host:")
    print("   (Requires configured hosts - using dry-run simulation)")

    # Simulate metrics so no remote calls are needed
    test_metrics = [
        MachineMetrics('web-01', 45, 60, 40, 0.49, 'moderate'),
        MachineMetrics('web-02', 85, 70, 65, 0.75, 'high'),
        MachineMetrics('web-03', 20, 35, 30, 0.28, 'low'),
    ]

    # Sort by score (mirrors select_optimal_host's selection rule)
    test_metrics.sort(key=lambda m: m.load_score)
    best = test_metrics[0]

    print(f"   ✓ Best host: {best.host} (score: {best.load_score:.2f})")
    print(f"   Reason: {best.status} load")

    print("\n2. Format load report:")
    report = format_load_report(test_metrics[0], {
        'avg_cpu': 50,
        'avg_mem': 55,
        'avg_disk': 45
    })
    print(report)

    print("\n✅ Load balancer tested")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
409
scripts/sshsync_wrapper.py
Normal file
409
scripts/sshsync_wrapper.py
Normal file
@@ -0,0 +1,409 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SSH Sync wrapper for Tailscale SSH Sync Agent.
|
||||
Python interface to sshsync CLI operations.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Add utils to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from utils.helpers import parse_ssh_config, parse_sshsync_config, format_bytes, format_duration
|
||||
from utils.validators import validate_host, validate_group, validate_path_exists, validate_timeout, validate_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_host_status(group: Optional[str] = None) -> Dict:
    """
    Get online/offline status of hosts.

    Shells out to `sshsync ls --with-status` and parses its table output.

    Args:
        group: Optional group to filter (None = all hosts)

    Returns:
        Dict with status info: 'hosts', 'total_count', 'online_count',
        'offline_count', 'availability_pct' (or 'error' on failure)

    Example:
        >>> status = get_host_status()
        >>> status['online_count']
        8
    """
    try:
        # Run sshsync ls --with-status
        cmd = ["sshsync", "ls", "--with-status"]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)

        if result.returncode != 0:
            return {'error': result.stderr, 'hosts': []}

        # Parse output table, skipping header and separator lines
        hosts = []
        for line in result.stdout.strip().split('\n'):
            if not line or line.startswith('Host') or line.startswith('---'):
                continue

            parts = line.split()
            if len(parts) >= 2:
                host_name = parts[0]
                status = parts[1] if len(parts) > 1 else 'unknown'

                hosts.append({
                    'host': host_name,
                    # NOTE(review): assumes sshsync prints one of these tokens
                    # for reachable hosts -- confirm against the installed
                    # sshsync version's output format.
                    'online': status.lower() in ['online', 'reachable', '✓'],
                    'status': status
                })

        # Filter by group if specified
        if group:
            groups_config = parse_sshsync_config()
            group_hosts = groups_config.get(group, [])
            hosts = [h for h in hosts if h['host'] in group_hosts]

        online_count = sum(1 for h in hosts if h['online'])

        return {
            'hosts': hosts,
            'total_count': len(hosts),
            'online_count': online_count,
            'offline_count': len(hosts) - online_count,
            # Guarded against empty host list
            'availability_pct': (online_count / len(hosts) * 100) if hosts else 0
        }

    except Exception as e:
        logger.error(f"Error getting host status: {e}")
        return {'error': str(e), 'hosts': []}
|
||||
|
||||
|
||||
def execute_on_all(command: str, timeout: int = 10, dry_run: bool = False) -> Dict:
    """
    Execute a command on every configured host via `sshsync all`.

    Args:
        command: Command to execute
        timeout: Timeout in seconds (per-host, passed to sshsync)
        dry_run: If True, don't actually execute

    Returns:
        Dict with results (or an 'error' key on failure)

    Example:
        >>> result = execute_on_all("uptime", timeout=15)
        >>> len(result['results'])
        10
    """
    # Validate inputs up front; these raise before anything runs
    validate_command(command)
    validate_timeout(timeout)

    if dry_run:
        return {
            'dry_run': True,
            'command': command,
            'message': 'Would execute on all hosts'
        }

    try:
        # Give the wrapper process 30s of headroom over the per-host timeout
        proc = subprocess.run(
            ["sshsync", "all", f"--timeout={timeout}", command],
            capture_output=True,
            text=True,
            timeout=timeout + 30,
        )
    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}

    # sshsync's output format varies between versions; pass it through raw
    return {
        'success': proc.returncode == 0,
        'stdout': proc.stdout,
        'stderr': proc.stderr,
        'command': command
    }
|
||||
|
||||
|
||||
def execute_on_group(group: str, command: str, timeout: int = 10, dry_run: bool = False) -> Dict:
    """
    Execute command on specific group via `sshsync group`.

    Args:
        group: Group name (validated against the sshsync config)
        command: Command to execute
        timeout: Timeout in seconds
        dry_run: Preview without executing

    Returns:
        Dict with execution results (or an 'error' key on failure)

    Example:
        >>> result = execute_on_group("web-servers", "df -h /var/www")
        >>> result['success']
        True
    """
    # Validate everything up front; validators raise before execution
    groups_config = parse_sshsync_config()
    validate_group(group, list(groups_config.keys()))
    validate_command(command)
    validate_timeout(timeout)

    if dry_run:
        group_hosts = groups_config.get(group, [])
        return {
            'dry_run': True,
            'group': group,
            'hosts': group_hosts,
            'command': command,
            'message': f'Would execute on {len(group_hosts)} hosts in group {group}'
        }

    try:
        cmd = ["sshsync", "group", f"--timeout={timeout}", group, command]
        # 30s of headroom so the wrapper does not time out before sshsync does
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 30)

        return {
            'success': result.returncode == 0,
            'group': group,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'command': command
        }

    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}
|
||||
|
||||
|
||||
def execute_on_host(host: str, command: str, timeout: int = 10) -> Dict:
    """
    Execute command on single host (plain `ssh`, not sshsync).

    Args:
        host: Host name (must exist in the parsed SSH config)
        command: Command to execute
        timeout: Timeout in seconds (also used as SSH ConnectTimeout)

    Returns:
        Dict with result (or an 'error' key on failure)

    Example:
        >>> result = execute_on_host("web-01", "hostname")
        >>> result['stdout']
        "web-01"
    """
    # Validate host against the SSH config before connecting
    ssh_hosts = parse_ssh_config()
    validate_host(host, list(ssh_hosts.keys()))
    validate_command(command)
    validate_timeout(timeout)

    try:
        cmd = ["ssh", "-o", f"ConnectTimeout={timeout}", host, command]
        # +5s of headroom over the SSH connect timeout
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 5)

        return {
            'success': result.returncode == 0,
            'host': host,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'command': command
        }

    except subprocess.TimeoutExpired:
        return {'error': f'Command timed out after {timeout}s'}
    except Exception as e:
        return {'error': str(e)}
|
||||
|
||||
|
||||
def push_to_hosts(local_path: str, remote_path: str,
                  hosts: Optional[List[str]] = None,
                  group: Optional[str] = None,
                  recurse: bool = False,
                  dry_run: bool = False) -> Dict:
    """
    Push files to hosts via `sshsync push`.

    Target selection priority: explicit hosts, then group, then all.

    Args:
        local_path: Local file/directory path (must exist)
        remote_path: Remote destination path
        hosts: Specific hosts (None = all if group also None)
        group: Group name
        recurse: Recursive copy
        dry_run: Preview without executing

    Returns:
        Dict with push results (or an 'error' key on failure)

    Example:
        >>> result = push_to_hosts("./dist", "/var/www/app", group="production", recurse=True)
        >>> result['success']
        True
    """
    validate_path_exists(local_path)

    if dry_run:
        return {
            'dry_run': True,
            'local_path': local_path,
            'remote_path': remote_path,
            'hosts': hosts,
            'group': group,
            'recurse': recurse,
            'message': 'Would push files'
        }

    try:
        cmd = ["sshsync", "push"]

        # Explicit hosts win over group; default is every host
        if hosts:
            for host in hosts:
                cmd.extend(["--host", host])
        elif group:
            cmd.extend(["--group", group])
        else:
            cmd.append("--all")

        if recurse:
            cmd.append("--recurse")

        cmd.extend([local_path, remote_path])

        # File transfers can be slow; allow up to 5 minutes
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

        return {
            'success': result.returncode == 0,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'local_path': local_path,
            'remote_path': remote_path
        }

    except subprocess.TimeoutExpired:
        return {'error': 'Push operation timed out'}
    except Exception as e:
        return {'error': str(e)}
|
||||
|
||||
|
||||
def pull_from_host(host: str, remote_path: str, local_path: str,
                   recurse: bool = False, dry_run: bool = False) -> Dict:
    """
    Pull files from host via `sshsync pull`.

    Args:
        host: Host to pull from (must exist in the parsed SSH config)
        remote_path: Remote file/directory path
        local_path: Local destination path
        recurse: Recursive copy
        dry_run: Preview without executing

    Returns:
        Dict with pull results (or an 'error' key on failure)

    Example:
        >>> result = pull_from_host("web-01", "/var/log/nginx", "./logs", recurse=True)
        >>> result['success']
        True
    """
    # Validate host before doing anything
    ssh_hosts = parse_ssh_config()
    validate_host(host, list(ssh_hosts.keys()))

    if dry_run:
        return {
            'dry_run': True,
            'host': host,
            'remote_path': remote_path,
            'local_path': local_path,
            'recurse': recurse,
            'message': f'Would pull from {host}'
        }

    try:
        cmd = ["sshsync", "pull", "--host", host]

        if recurse:
            cmd.append("--recurse")

        cmd.extend([remote_path, local_path])

        # File transfers can be slow; allow up to 5 minutes
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

        return {
            'success': result.returncode == 0,
            'host': host,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'remote_path': remote_path,
            'local_path': local_path
        }

    except subprocess.TimeoutExpired:
        return {'error': 'Pull operation timed out'}
    except Exception as e:
        return {'error': str(e)}
|
||||
|
||||
|
||||
def list_hosts(with_status: bool = True) -> Dict:
    """
    List all configured hosts.

    Args:
        with_status: When True, include online/offline status (slower:
            shells out via get_host_status); otherwise just read the
            SSH config.

    Returns:
        Dict with hosts info

    Example:
        >>> result = list_hosts(with_status=True)
        >>> len(result['hosts'])
        10
    """
    if with_status:
        # Delegates entirely to the status probe
        return get_host_status()

    # Status not requested: enumerate hosts straight from the SSH config
    ssh_hosts = parse_ssh_config()
    return {
        'hosts': [{'host': name} for name in ssh_hosts],
        'count': len(ssh_hosts),
    }
|
||||
|
||||
|
||||
def get_groups() -> Dict[str, List[str]]:
    """
    Get all defined groups and their members.

    Thin wrapper over parse_sshsync_config().

    Returns:
        Dict mapping group names to host lists

    Example:
        >>> groups = get_groups()
        >>> groups['production']
        ['prod-web-01', 'prod-db-01']
    """
    return parse_sshsync_config()
|
||||
|
||||
|
||||
def main():
    """Smoke-test sshsync wrapper functions (read-only and dry-run calls only)."""
    print("Testing sshsync wrapper...\n")

    print("1. List hosts:")
    result = list_hosts(with_status=False)
    print(f"   Found {result.get('count', 0)} hosts")

    print("\n2. Get groups:")
    groups = get_groups()
    print(f"   Found {len(groups)} groups")
    for group, hosts in groups.items():
        print(f"   - {group}: {len(hosts)} hosts")

    print("\n3. Test dry-run:")
    result = execute_on_all("uptime", dry_run=True)
    print(f"   Dry-run: {result.get('message', 'OK')}")

    print("\n✅ sshsync wrapper tested")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
426
scripts/tailscale_manager.py
Normal file
426
scripts/tailscale_manager.py
Normal file
@@ -0,0 +1,426 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tailscale manager for Tailscale SSH Sync Agent.
|
||||
Tailscale-specific operations and status management.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TailscalePeer:
    """Represents a Tailscale peer."""
    hostname: str                    # peer machine name
    ip: str                          # first Tailscale IP of the peer
    online: bool                     # True if peer is currently online
    last_seen: Optional[str] = None  # last-seen timestamp, when known
    os: Optional[str] = None         # peer operating system, when known
    relay: Optional[str] = None      # relay in use, when known
|
||||
|
||||
|
||||
def get_tailscale_status() -> Dict:
    """
    Get Tailscale network status (all peers).

    Tries `tailscale status --json` first (structured), falling back to
    the plain-text `tailscale status` output if the JSON form fails.

    Returns:
        Dict with network status:
        {
            'connected': bool,
            'peers': List[TailscalePeer],
            'online_count': int,
            'total_count': int,
            'self_ip': str
        }
        On failure: {'connected': False, 'error': str, 'peers': []}

    Example:
        >>> status = get_tailscale_status()
        >>> status['online_count']
        8
        >>> status['peers'][0].hostname
        'homelab-1'
    """
    try:
        # Get status in JSON format
        result = subprocess.run(
            ["tailscale", "status", "--json"],
            capture_output=True,
            text=True,
            timeout=10
        )

        if result.returncode != 0:
            # Try text format if JSON fails
            result = subprocess.run(
                ["tailscale", "status"],
                capture_output=True,
                text=True,
                timeout=10
            )

            if result.returncode != 0:
                # Neither form worked: daemon down or not logged in
                return {
                    'connected': False,
                    'error': 'Tailscale not running or accessible',
                    'peers': []
                }

            # Parse text format
            return _parse_text_status(result.stdout)

        # Parse JSON format
        data = json.loads(result.stdout)
        return _parse_json_status(data)

    except FileNotFoundError:
        # `tailscale` binary not found on PATH
        return {
            'connected': False,
            'error': 'Tailscale not installed',
            'peers': []
        }
    except subprocess.TimeoutExpired:
        return {
            'connected': False,
            'error': 'Timeout getting Tailscale status',
            'peers': []
        }
    except Exception as e:
        logger.error(f"Error getting Tailscale status: {e}")
        return {
            'connected': False,
            'error': str(e),
            'peers': []
        }
|
||||
|
||||
|
||||
def _parse_json_status(data: Dict) -> Dict:
|
||||
"""Parse Tailscale JSON status."""
|
||||
peers = []
|
||||
|
||||
self_data = data.get('Self', {})
|
||||
self_ip = self_data.get('TailscaleIPs', [''])[0]
|
||||
|
||||
for peer_id, peer_data in data.get('Peer', {}).items():
|
||||
hostname = peer_data.get('HostName', 'unknown')
|
||||
ips = peer_data.get('TailscaleIPs', [])
|
||||
ip = ips[0] if ips else 'unknown'
|
||||
online = peer_data.get('Online', False)
|
||||
os = peer_data.get('OS', 'unknown')
|
||||
|
||||
peers.append(TailscalePeer(
|
||||
hostname=hostname,
|
||||
ip=ip,
|
||||
online=online,
|
||||
os=os
|
||||
))
|
||||
|
||||
online_count = sum(1 for p in peers if p.online)
|
||||
|
||||
return {
|
||||
'connected': True,
|
||||
'peers': peers,
|
||||
'online_count': online_count,
|
||||
'total_count': len(peers),
|
||||
'self_ip': self_ip
|
||||
}
|
||||
|
||||
|
||||
def _parse_text_status(output: str) -> Dict:
    """Parse plain-text `tailscale status` output (fallback path).

    NOTE(review): this assumes each line is "<hostname> <ip> ...";
    verify the field order against the installed tailscale version's
    actual output before trusting the parsed IPs.
    """
    peers = []
    self_ip = None

    for line in output.strip().split('\n'):
        line = line.strip()
        if not line:
            continue

        # Parse format: hostname ip status ...
        parts = line.split()
        if len(parts) >= 2:
            hostname = parts[0]
            ip = parts[1] if len(parts) > 1 else 'unknown'

            # Check for self (usually marked with *)
            # NOTE(review): the test below looks for a trailing '-',
            # not the '*' the comment mentions -- confirm which marker
            # the installed tailscale actually emits for self.
            if hostname.endswith('-'):
                self_ip = ip
                continue

            # Determine online status from additional fields
            online = 'offline' not in line.lower()

            peers.append(TailscalePeer(
                hostname=hostname,
                ip=ip,
                online=online
            ))

    online_count = sum(1 for p in peers if p.online)

    return {
        'connected': True,
        'peers': peers,
        'online_count': online_count,
        'total_count': len(peers),
        'self_ip': self_ip or 'unknown'
    }
|
||||
|
||||
|
||||
def check_connectivity(host: str, timeout: int = 5) -> bool:
    """
    Ping host via Tailscale.

    Args:
        host: Hostname to ping
        timeout: Timeout in seconds

    Returns:
        True if host responds to ping (exit code 0 or 'pong' in output)

    Example:
        >>> check_connectivity("homelab-1")
        True
    """
    try:
        # NOTE(review): '--c 1' is intended as the ping count; confirm
        # the flag spelling against `tailscale ping --help` for the
        # installed version.
        result = subprocess.run(
            ["tailscale", "ping", "--timeout", f"{timeout}s", "--c", "1", host],
            capture_output=True,
            text=True,
            timeout=timeout + 2
        )

        # Check if ping succeeded
        return result.returncode == 0 or 'pong' in result.stdout.lower()

    except (FileNotFoundError, subprocess.TimeoutExpired):
        # Tailscale missing or ping hung: treat as unreachable
        return False
    except Exception as e:
        logger.error(f"Error pinging {host}: {e}")
        return False
|
||||
|
||||
|
||||
def get_peer_info(hostname: str) -> Optional[TailscalePeer]:
    """
    Get detailed info about a specific peer.

    Matches either an exact hostname or a substring of a peer hostname.

    Args:
        hostname: Peer hostname (exact or partial)

    Returns:
        TailscalePeer object, or None when not connected / not found

    Example:
        >>> peer = get_peer_info("homelab-1")
        >>> peer.ip
        '100.64.1.10'
    """
    status = get_tailscale_status()

    if not status.get('connected'):
        return None

    # First peer whose hostname matches exactly or contains the query
    return next(
        (p for p in status.get('peers', [])
         if p.hostname == hostname or hostname in p.hostname),
        None,
    )
|
||||
|
||||
|
||||
def list_online_machines() -> List[str]:
    """
    List all online Tailscale machines.

    Returns:
        List of online machine hostnames (empty when not connected)

    Example:
        >>> machines = list_online_machines()
        >>> len(machines)
        8
    """
    status = get_tailscale_status()

    if not status.get('connected'):
        return []

    # Collect hostnames of peers currently reported online
    online = []
    for peer in status.get('peers', []):
        if peer.online:
            online.append(peer.hostname)
    return online
|
||||
|
||||
|
||||
def get_machine_ip(hostname: str) -> Optional[str]:
    """
    Get Tailscale IP for a machine.

    Args:
        hostname: Machine hostname

    Returns:
        IP address string, or None when the peer is not found

    Example:
        >>> ip = get_machine_ip("homelab-1")
        >>> ip
        '100.64.1.10'
    """
    # Delegate peer lookup, then project out the IP
    peer = get_peer_info(hostname)
    if peer is None:
        return None
    return peer.ip
|
||||
|
||||
|
||||
def validate_tailscale_ssh(host: str, timeout: int = 10) -> Dict:
    """
    Check if Tailscale SSH is working for a host.

    Runs a staged probe: peer lookup -> online check -> ping ->
    actual `tailscale ssh ... echo test`. Stops at the first failure.

    Args:
        host: Host to check
        timeout: Connection timeout (seconds, per stage)

    Returns:
        Dict with validation results:
        {
            'working': bool,
            'message': str,
            'details': Dict   # which stages passed/failed
        }

    Example:
        >>> result = validate_tailscale_ssh("homelab-1")
        >>> result['working']
        True
    """
    # First check if host is in Tailscale network
    peer = get_peer_info(host)

    if not peer:
        return {
            'working': False,
            'message': f'Host {host} not found in Tailscale network',
            'details': {'peer_found': False}
        }

    if not peer.online:
        return {
            'working': False,
            'message': f'Host {host} is offline in Tailscale',
            'details': {'peer_found': True, 'online': False}
        }

    # Check connectivity
    if not check_connectivity(host, timeout=timeout):
        return {
            'working': False,
            'message': f'Cannot ping {host} via Tailscale',
            'details': {'peer_found': True, 'online': True, 'ping': False}
        }

    # Try SSH connection (runs a trivial echo on the remote side)
    try:
        result = subprocess.run(
            ["tailscale", "ssh", host, "echo", "test"],
            capture_output=True,
            text=True,
            timeout=timeout
        )

        if result.returncode == 0:
            return {
                'working': True,
                'message': f'Tailscale SSH to {host} is working',
                'details': {
                    'peer_found': True,
                    'online': True,
                    'ping': True,
                    'ssh': True,
                    'ip': peer.ip
                }
            }
        else:
            return {
                'working': False,
                'message': f'Tailscale SSH failed: {result.stderr}',
                'details': {
                    'peer_found': True,
                    'online': True,
                    'ping': True,
                    'ssh': False,
                    'error': result.stderr
                }
            }

    except subprocess.TimeoutExpired:
        return {
            'working': False,
            'message': f'Tailscale SSH timed out after {timeout}s',
            'details': {'timeout': True}
        }
    except Exception as e:
        return {
            'working': False,
            'message': f'Error testing Tailscale SSH: {e}',
            'details': {'error': str(e)}
        }
|
||||
|
||||
|
||||
def get_network_summary() -> str:
    """
    Get human-readable network summary.

    Returns:
        Formatted summary string

    Example:
        >>> print(get_network_summary())
        Tailscale Network: Connected
        Online: 8/10 machines (80%)
        Self IP: 100.64.1.5
    """
    status = get_tailscale_status()

    if not status.get('connected'):
        return "Tailscale Network: Not connected\nError: {}".format(
            status.get('error', 'Unknown error')
        )

    online = status['online_count']
    total = status['total_count']
    # Bug fix: guard against a tailnet with zero peers, which previously
    # raised ZeroDivisionError while formatting the percentage.
    pct = (online / total * 100) if total else 0.0

    lines = [
        "Tailscale Network: Connected",
        f"Online: {online}/{total} machines ({pct:.0f}%)",
        f"Self IP: {status.get('self_ip', 'unknown')}"
    ]

    return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
    """Exercise the Tailscale manager functions and print the results."""
    print("Testing Tailscale manager...\n")

    # Step 1: overall daemon/tailnet status.
    print("1. Get Tailscale status:")
    status = get_tailscale_status()
    if not status.get('connected'):
        print(f"   ✗ Not connected: {status.get('error', 'Unknown error')}")
    else:
        print("   ✓ Connected")
        total, online = status['total_count'], status['online_count']
        print(f"   Peers: {total} total, {online} online")

    # Step 2: enumerate reachable peers (capped at five for brevity).
    print("\n2. List online machines:")
    machines = list_online_machines()
    print(f"   Found {len(machines)} online machines")
    for name in machines[:5]:
        print(f"   - {name}")

    # Step 3: the human-readable roll-up.
    print("\n3. Network summary:")
    print(get_network_summary())

    print("\n✅ Tailscale manager tested")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
628
scripts/utils/helpers.py
Normal file
628
scripts/utils/helpers.py
Normal file
@@ -0,0 +1,628 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Helper utilities for Tailscale SSH Sync Agent.
|
||||
Provides common formatting, parsing, and utility functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import yaml
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_bytes(bytes_value: int) -> str:
    """
    Format bytes as human-readable string.

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string (e.g., "12.3 MB", "1.5 GB")

    Example:
        >>> format_bytes(12582912)
        "12.0 MB"
        >>> format_bytes(1610612736)
        "1.5 GB"
    """
    size = float(bytes_value)
    # Walk up the unit ladder until the value drops below 1024.
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return f"{size:.1f} {unit}"
        size /= 1024.0
    # Anything beyond terabytes is reported in petabytes.
    return f"{size:.1f} PB"
|
||||
|
||||
|
||||
def format_duration(seconds: float) -> str:
    """
    Format duration as human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string (e.g., "2m 15s", "1h 30m 30s")

    Example:
        >>> format_duration(135)
        "2m 15s"
        >>> format_duration(5430)
        "1h 30m 30s"
    """
    if seconds < 60:
        return f"{int(seconds)}s"

    minutes = int(seconds // 60)
    secs = int(seconds % 60)

    if minutes < 60:
        return f"{minutes}m {secs}s" if secs > 0 else f"{minutes}m"

    hours = minutes // 60
    minutes = minutes % 60

    parts = [f"{hours}h"]
    if minutes > 0:
        parts.append(f"{minutes}m")
    # Bug fix: the original guarded this with `hours == 0`, which is always
    # False on this branch, so seconds were silently dropped ("1h 30m"
    # instead of the documented "1h 30m 30s").
    if secs > 0:
        parts.append(f"{secs}s")

    return " ".join(parts)
|
||||
|
||||
|
||||
def format_percentage(value: float, decimals: int = 1) -> str:
    """
    Format percentage with specified decimals.

    Args:
        value: Percentage value (0-100)
        decimals: Number of decimal places

    Returns:
        Formatted string (e.g., "45.5%")

    Example:
        >>> format_percentage(45.567)
        "45.6%"
    """
    # Nested format spec: the precision itself is parameterised.
    return "{:.{prec}f}%".format(value, prec=decimals)
|
||||
|
||||
|
||||
def parse_ssh_config(config_path: Optional[Path] = None) -> Dict[str, Dict[str, str]]:
    """
    Parse SSH config file for host definitions.

    Args:
        config_path: Path to SSH config (default: ~/.ssh/config)

    Returns:
        Dict mapping host aliases to their configuration:
        {
            'host-alias': {
                'hostname': '100.64.1.10',
                'user': 'admin',
                'port': '22',
                'identityfile': '~/.ssh/id_ed25519'
            }
        }

    Example:
        >>> hosts = parse_ssh_config()
        >>> hosts['homelab-1']['hostname']
        '100.64.1.10'
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'

    if not config_path.exists():
        logger.warning(f"SSH config not found: {config_path}")
        return {}

    hosts: Dict[str, Dict[str, str]] = {}
    # Aliases named by the most recent 'Host' directive; subsequent option
    # lines apply to all of them ('Host a b' declares several at once).
    current_hosts: List[str] = []

    try:
        with open(config_path, 'r') as f:
            for line in f:
                line = line.strip()

                # Skip comments and empty lines
                if not line or line.startswith('#'):
                    continue

                # Host directive
                if line.lower().startswith('host '):
                    # Bug fix: a wildcard 'Host *' section used to leave the
                    # previous host active, so the wildcard's options were
                    # wrongly merged into the preceding host's entry. Reset
                    # the active aliases on every Host line instead.
                    aliases = line.split()[1:]
                    current_hosts = [
                        a for a in aliases if '*' not in a and '?' not in a
                    ]
                    for alias in current_hosts:
                        hosts.setdefault(alias, {})

                # Configuration directives
                elif current_hosts:
                    parts = line.split(maxsplit=1)
                    if len(parts) == 2:
                        key, value = parts
                        for alias in current_hosts:
                            hosts[alias][key.lower()] = value

        return hosts

    except Exception as e:
        logger.error(f"Error parsing SSH config: {e}")
        return {}
|
||||
|
||||
|
||||
def parse_sshsync_config(config_path: Optional[Path] = None) -> Dict[str, List[str]]:
    """
    Parse sshsync config file for group definitions.

    Args:
        config_path: Path to sshsync config (default: ~/.config/sshsync/config.yaml)

    Returns:
        Dict mapping group names to list of hosts:
        {
            'production': ['prod-web-01', 'prod-db-01'],
            'development': ['dev-laptop', 'dev-desktop']
        }

    Example:
        >>> groups = parse_sshsync_config()
        >>> groups['production']
        ['prod-web-01', 'prod-db-01']
    """
    if config_path is None:
        config_path = Path.home() / '.config' / 'sshsync' / 'config.yaml'

    if not config_path.exists():
        logger.warning(f"sshsync config not found: {config_path}")
        return {}

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # Bug fix: safe_load returns None for an empty document, which used
        # to trip the except handler with a spurious AttributeError and an
        # error log; treat it as an empty config instead.
        return (config or {}).get('groups', {})

    except Exception as e:
        logger.error(f"Error parsing sshsync config: {e}")
        return {}
|
||||
|
||||
|
||||
def get_timestamp(iso: bool = True) -> str:
    """
    Get current timestamp.

    Args:
        iso: If True, return ISO-8601 UTC format (trailing "Z");
             otherwise a human-readable local time.

    Returns:
        Timestamp string

    Example:
        >>> get_timestamp(iso=True)
        "2025-10-19T19:43:41Z"
        >>> get_timestamp(iso=False)
        "2025-10-19 19:43:41"
    """
    if iso:
        # Bug fix: the original stamped *local* time with a "Z" (UTC)
        # designator; use an actual UTC clock so the suffix is truthful.
        from datetime import timezone
        return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def safe_execute(func, *args, default=None, **kwargs) -> Any:
    """
    Execute function with error handling.

    Args:
        func: Function to execute
        *args: Positional arguments
        default: Value to return on error
        **kwargs: Keyword arguments

    Returns:
        Function result or default on error

    Example:
        >>> safe_execute(int, "not_a_number", default=0)
        0
        >>> safe_execute(int, "42")
        42
    """
    try:
        return func(*args, **kwargs)
    except Exception as e:
        # Bug fix: not every callable has __name__ (functools.partial,
        # callable instances) — fall back to repr() instead of raising
        # AttributeError inside the error handler.
        name = getattr(func, '__name__', repr(func))
        logger.error(f"Error executing {name}: {e}")
        return default
|
||||
|
||||
|
||||
def validate_path(path: str, must_exist: bool = True) -> bool:
    """
    Check if path is valid and accessible.

    Args:
        path: Path to validate
        must_exist: If True, path must exist

    Returns:
        True if valid, False otherwise

    Example:
        >>> validate_path("/tmp")
        True
        >>> validate_path("/nonexistent", must_exist=True)
        False
    """
    candidate = Path(path).expanduser()
    if must_exist:
        return candidate.exists()
    # The path may be created later; only its parent needs to exist now.
    return candidate.parent.exists()
|
||||
|
||||
|
||||
def parse_disk_usage(df_output: str) -> Dict[str, Any]:
    """
    Parse 'df' command output.

    Args:
        df_output: Output from 'df -h' command

    Returns:
        Dict with disk usage info:
        {
            'filesystem': '/dev/sda1',
            'size': '100G',
            'used': '45G',
            'available': '50G',
            'use_pct': 45,
            'mount': '/'
        }

    Example:
        >>> output = "Filesystem Size Used Avail Use% Mounted on\\n/dev/sda1 100G 45G 50G 45% /"
        >>> parse_disk_usage(output)
        {'filesystem': '/dev/sda1', 'size': '100G', ...}
    """
    lines = df_output.strip().split('\n')
    if len(lines) < 2:
        return {}

    # The final line carries the data row; the first is the header.
    fields = lines[-1].split()
    if len(fields) < 6:
        return {}

    try:
        filesystem, size, used, available, pct, mount = fields[:6]
        return {
            'filesystem': filesystem,
            'size': size,
            'used': used,
            'available': available,
            'use_pct': int(pct.rstrip('%')),
            'mount': mount,
        }
    except (ValueError, IndexError) as e:
        logger.error(f"Error parsing disk usage: {e}")
        return {}
|
||||
|
||||
|
||||
def parse_memory_usage(free_output: str) -> Dict[str, Any]:
    """
    Parse 'free' command output (Linux).

    Args:
        free_output: Output from 'free -m' command

    Returns:
        Dict with memory info:
        {
            'total': 16384,  # MB
            'used': 8192,
            'free': 8192,
            'use_pct': 50.0
        }

    Example:
        >>> output = "Mem: 16384 8192 8192 0 0 0"
        >>> parse_memory_usage(output)
        {'total': 16384, 'used': 8192, ...}
    """
    for line in free_output.strip().split('\n'):
        # Only the "Mem:" row carries the figures we need.
        if not line.startswith('Mem:'):
            continue
        fields = line.split()
        if len(fields) < 3:
            continue
        try:
            total = int(fields[1])
            used = int(fields[2])
            # 'free' column may be absent; derive it when missing.
            free = int(fields[3]) if len(fields) > 3 else (total - used)
        except (ValueError, IndexError) as e:
            logger.error(f"Error parsing memory usage: {e}")
            continue
        return {
            'total': total,
            'used': used,
            'free': free,
            'use_pct': (used / total * 100) if total > 0 else 0
        }

    return {}
|
||||
|
||||
|
||||
def parse_cpu_load(uptime_output: str) -> Dict[str, float]:
    """
    Parse 'uptime' command output for load averages.

    Args:
        uptime_output: Output from 'uptime' command

    Returns:
        Dict with load averages:
        {
            'load_1min': 0.45,
            'load_5min': 0.38,
            'load_15min': 0.32
        }

    Example:
        >>> output = "19:43:41 up 5 days, 2:15, 3 users, load average: 0.45, 0.38, 0.32"
        >>> parse_cpu_load(output)
        {'load_1min': 0.45, 'load_5min': 0.38, 'load_15min': 0.32}
    """
    # Accept both Linux ("load average: a, b, c") and macOS/BSD
    # ("load averages: a b c") formats — the caller runs `uptime` on
    # either platform, so the plural form and commas are optional.
    match = re.search(
        r'load averages?:\s*([\d.]+),?\s+([\d.]+),?\s+([\d.]+)',
        uptime_output
    )

    if match:
        try:
            return {
                'load_1min': float(match.group(1)),
                'load_5min': float(match.group(2)),
                'load_15min': float(match.group(3))
            }
        except ValueError as e:
            logger.error(f"Error parsing CPU load: {e}")

    return {}
|
||||
|
||||
|
||||
def format_host_status(host: str, online: bool, groups: List[str],
                       latency: Optional[int] = None,
                       tailscale_connected: bool = False) -> str:
    """
    Format host status as display string.

    Args:
        host: Host name
        online: Whether host is online
        groups: List of groups host belongs to
        latency: Latency in ms (optional)
        tailscale_connected: Tailscale connection status

    Returns:
        Formatted status string; all segments are joined with " - "

    Example:
        >>> format_host_status("web-01", True, ["production", "web"], 25, True)
        "🟢 web-01 (production, web) - Online - Tailscale: Connected - Latency: 25ms"
    """
    icon = "🟢" if online else "🔴"
    status = "Online" if online else "Offline"
    group_str = ", ".join(groups) if groups else "no group"

    parts = [f"{icon} {host} ({group_str}) - {status}"]

    if tailscale_connected:
        parts.append("Tailscale: Connected")

    # Latency is only shown for hosts that are actually reachable.
    if latency is not None and online:
        parts.append(f"Latency: {latency}ms")

    return " - ".join(parts)
|
||||
|
||||
|
||||
def calculate_load_score(cpu_pct: float, mem_pct: float, disk_pct: float) -> float:
    """
    Calculate composite load score for a machine.

    Args:
        cpu_pct: CPU usage percentage (0-100)
        mem_pct: Memory usage percentage (0-100)
        disk_pct: Disk usage percentage (0-100)

    Returns:
        Load score (0-1, lower is better)

    Formula:
        score = (cpu * 0.4) + (mem * 0.3) + (disk * 0.3)

    Example:
        >>> calculate_load_score(45, 60, 40)
        0.48  # (0.45*0.4 + 0.60*0.3 + 0.40*0.3)
    """
    # CPU is weighted most heavily; memory and disk share the remainder.
    weighted = cpu_pct * 0.4 + mem_pct * 0.3 + disk_pct * 0.3
    return weighted / 100
|
||||
|
||||
|
||||
def classify_load_status(score: float) -> str:
    """
    Classify load score into status category.

    Args:
        score: Load score (0-1)

    Returns:
        Status string: "low", "moderate", or "high"

    Example:
        >>> classify_load_status(0.28)
        "low"
        >>> classify_load_status(0.55)
        "moderate"
        >>> classify_load_status(0.82)
        "high"
    """
    # Upper-bound table: first threshold the score falls below wins.
    for limit, label in ((0.4, "low"), (0.7, "moderate")):
        if score < limit:
            return label
    return "high"
|
||||
|
||||
|
||||
def classify_latency(latency_ms: int) -> Tuple[str, str]:
    """
    Classify network latency.

    Args:
        latency_ms: Latency in milliseconds

    Returns:
        Tuple of (status, description)

    Example:
        >>> classify_latency(25)
        ("excellent", "Ideal for interactive tasks")
        >>> classify_latency(150)
        ("fair", "May impact interactive workflows")
    """
    # Band table: (exclusive upper bound in ms, classification).
    bands = (
        (50, ("excellent", "Ideal for interactive tasks")),
        (100, ("good", "Suitable for most operations")),
        (200, ("fair", "May impact interactive workflows")),
    )
    for limit, classification in bands:
        if latency_ms < limit:
            return classification
    return ("poor", "Investigate network issues")
|
||||
|
||||
|
||||
def get_hosts_from_groups(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Get list of hosts in a group.

    Args:
        group: Group name
        groups_config: Groups configuration dict

    Returns:
        List of host names in group

    Example:
        >>> groups = {'production': ['web-01', 'db-01']}
        >>> get_hosts_from_groups('production', groups)
        ['web-01', 'db-01']
    """
    # An unknown group yields an empty list rather than raising.
    if group in groups_config:
        return groups_config[group]
    return []
|
||||
|
||||
|
||||
def get_groups_for_host(host: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Get list of groups a host belongs to.

    Args:
        host: Host name
        groups_config: Groups configuration dict

    Returns:
        List of group names

    Example:
        >>> groups = {'production': ['web-01'], 'web': ['web-01', 'web-02']}
        >>> get_groups_for_host('web-01', groups)
        ['production', 'web']
    """
    memberships = []
    # Preserve the configuration's own group ordering.
    for group_name, members in groups_config.items():
        if host in members:
            memberships.append(group_name)
    return memberships
|
||||
|
||||
|
||||
def run_command(command: str, timeout: int = 10) -> Tuple[bool, str, str]:
    """
    Run shell command with timeout.

    Args:
        command: Command to execute
        timeout: Timeout in seconds

    Returns:
        Tuple of (success, stdout, stderr)

    Example:
        >>> success, stdout, stderr = run_command("echo hello")
        >>> success
        True
        >>> stdout.strip()
        "hello"
    """
    # NOTE(review): shell=True executes through the shell — callers must
    # not pass untrusted input into `command`.
    try:
        proc = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return (False, "", f"Command timed out after {timeout}s")
    except Exception as e:
        return (False, "", str(e))

    return (proc.returncode == 0, proc.stdout, proc.stderr)
|
||||
|
||||
|
||||
def main():
    """Smoke-test the helper functions, printing each result."""
    print("Testing helper functions...\n")

    # Formatting helpers.
    print("1. Format bytes:")
    for raw in (12582912, 1610612736):
        print(f"   {raw} bytes = {format_bytes(raw)}")

    print("\n2. Format duration:")
    for secs in (135, 5430):
        print(f"   {secs} seconds = {format_duration(secs)}")

    print("\n3. Format percentage:")
    print(f"   45.567 = {format_percentage(45.567)}")

    # Load scoring.
    print("\n4. Calculate load score:")
    composite = calculate_load_score(45, 60, 40)
    print(f"   CPU 45%, Mem 60%, Disk 40% = {composite:.2f}")
    print(f"   Status: {classify_load_status(composite)}")

    print("\n5. Classify latency:")
    for latency_ms in (25, 75, 150, 250):
        level, description = classify_latency(latency_ms)
        print(f"   {latency_ms}ms: {level} - {description}")

    # Config parsing.
    print("\n6. Parse SSH config:")
    ssh_hosts = parse_ssh_config()
    print(f"   Found {len(ssh_hosts)} hosts")

    print("\n7. Parse sshsync config:")
    group_map = parse_sshsync_config()
    print(f"   Found {len(group_map)} groups")
    for group_name, members in group_map.items():
        print(f"   - {group_name}: {len(members)} hosts")

    print("\n✅ All helpers tested successfully")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
scripts/utils/validators/__init__.py
Normal file
43
scripts/utils/validators/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Validators package for Tailscale SSH Sync Agent.
|
||||
"""
|
||||
|
||||
from .parameter_validator import (
|
||||
ValidationError,
|
||||
validate_host,
|
||||
validate_group,
|
||||
validate_path_exists,
|
||||
validate_timeout,
|
||||
validate_command
|
||||
)
|
||||
|
||||
from .host_validator import (
|
||||
validate_ssh_config,
|
||||
validate_host_reachable,
|
||||
validate_group_members,
|
||||
get_invalid_hosts
|
||||
)
|
||||
|
||||
from .connection_validator import (
|
||||
validate_ssh_connection,
|
||||
validate_tailscale_connection,
|
||||
validate_ssh_key,
|
||||
get_connection_diagnostics
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'ValidationError',
|
||||
'validate_host',
|
||||
'validate_group',
|
||||
'validate_path_exists',
|
||||
'validate_timeout',
|
||||
'validate_command',
|
||||
'validate_ssh_config',
|
||||
'validate_host_reachable',
|
||||
'validate_group_members',
|
||||
'get_invalid_hosts',
|
||||
'validate_ssh_connection',
|
||||
'validate_tailscale_connection',
|
||||
'validate_ssh_key',
|
||||
'get_connection_diagnostics',
|
||||
]
|
||||
275
scripts/utils/validators/connection_validator.py
Normal file
275
scripts/utils/validators/connection_validator.py
Normal file
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Connection validators for Tailscale SSH Sync Agent.
|
||||
Validates SSH and Tailscale connections.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
|
||||
from .parameter_validator import ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_ssh_connection(host: str, timeout: int = 10) -> bool:
    """
    Test SSH connection works.

    Args:
        host: Host to connect to
        timeout: Connection timeout in seconds

    Returns:
        True if SSH connection successful

    Raises:
        ValidationError: If connection fails

    Example:
        >>> validate_ssh_connection("web-01")
        True
    """
    try:
        # Try to execute a simple command via SSH.
        # BatchMode=yes disables interactive prompts so a missing key fails
        # fast; StrictHostKeyChecking=no avoids blocking on first contact.
        result = subprocess.run(
            ["ssh", "-o", "ConnectTimeout={}".format(timeout),
             "-o", "BatchMode=yes",
             "-o", "StrictHostKeyChecking=no",
             host, "echo", "test"],
            capture_output=True,
            text=True,
            timeout=timeout + 5  # let ssh's own ConnectTimeout fire first
        )

        if result.returncode == 0:
            return True
        else:
            # Map common ssh failure strings to actionable guidance.
            error_msg = result.stderr.strip()

            if "Permission denied" in error_msg:
                raise ValidationError(
                    f"SSH authentication failed for '{host}'\n"
                    "Check:\n"
                    "1. SSH key is added: ssh-add -l\n"
                    "2. Public key is on remote: cat ~/.ssh/authorized_keys\n"
                    "3. User/key in SSH config is correct"
                )
            elif "Connection refused" in error_msg:
                raise ValidationError(
                    f"SSH connection refused for '{host}'\n"
                    "Check:\n"
                    "1. SSH server is running on remote\n"
                    "2. Port 22 is not blocked by firewall"
                )
            elif "Connection timed out" in error_msg or "timeout" in error_msg.lower():
                raise ValidationError(
                    f"SSH connection timed out for '{host}'\n"
                    "Check:\n"
                    "1. Host is reachable (ping test)\n"
                    "2. Tailscale is connected\n"
                    "3. Network connectivity"
                )
            else:
                raise ValidationError(
                    f"SSH connection failed for '{host}': {error_msg}"
                )

    except ValidationError:
        # Bug fix: without this re-raise, the specific ValidationErrors
        # above were swallowed by the generic `except Exception` below and
        # re-wrapped as "Error testing SSH connection ...", hiding the
        # actionable message.
        raise
    except subprocess.TimeoutExpired:
        raise ValidationError(
            f"SSH connection timed out for '{host}' (>{timeout}s)"
        )
    except Exception as e:
        raise ValidationError(f"Error testing SSH connection to '{host}': {e}")
|
||||
|
||||
|
||||
def validate_tailscale_connection(host: str) -> bool:
    """
    Test Tailscale connectivity to host.

    Args:
        host: Host to check

    Returns:
        True if Tailscale connection active

    Raises:
        ValidationError: If Tailscale not connected

    Example:
        >>> validate_tailscale_connection("web-01")
        True
    """
    try:
        # Check if tailscale is running
        result = subprocess.run(
            ["tailscale", "status"],
            capture_output=True,
            text=True,
            timeout=5
        )

        if result.returncode != 0:
            raise ValidationError(
                "Tailscale is not running\n"
                "Start Tailscale: sudo tailscale up"
            )

        # Check if specific host is in the network.
        # NOTE(review): substring match can false-positive (e.g. "web"
        # matches "web-01") — consider matching whole status-line fields.
        if host in result.stdout or host.replace('-', '.') in result.stdout:
            return True
        else:
            raise ValidationError(
                f"Host '{host}' not found in Tailscale network\n"
                "Ensure host is:\n"
                "1. Connected to Tailscale\n"
                "2. In the same tailnet\n"
                "3. Not expired/offline"
            )

    except ValidationError:
        # Bug fix: re-raise our own errors before the generic handler
        # below re-wraps them as "Error checking Tailscale connection".
        raise
    except FileNotFoundError:
        raise ValidationError(
            "Tailscale not installed\n"
            "Install: https://tailscale.com/download"
        )
    except subprocess.TimeoutExpired:
        raise ValidationError("Timeout checking Tailscale status")
    except Exception as e:
        raise ValidationError(f"Error checking Tailscale connection: {e}")
|
||||
|
||||
|
||||
def validate_ssh_key(host: str) -> bool:
    """
    Check SSH key authentication is working.

    Args:
        host: Host to check

    Returns:
        True if SSH key auth works

    Raises:
        ValidationError: If key auth fails

    Example:
        >>> validate_ssh_key("web-01")
        True
    """
    try:
        # Test connection with explicit key-only auth:
        # PasswordAuthentication=no ensures success proves key auth works.
        result = subprocess.run(
            ["ssh", "-o", "BatchMode=yes",
             "-o", "PasswordAuthentication=no",
             "-o", "ConnectTimeout=5",
             host, "echo", "test"],
            capture_output=True,
            text=True,
            timeout=10
        )

        if result.returncode == 0:
            return True
        else:
            error_msg = result.stderr.strip()

            if "Permission denied" in error_msg:
                raise ValidationError(
                    f"SSH key authentication failed for '{host}'\n"
                    "Fix:\n"
                    "1. Add your SSH key: ssh-add ~/.ssh/id_ed25519\n"
                    "2. Copy public key to remote: ssh-copy-id {}\n"
                    "3. Verify: ssh -v {} 2>&1 | grep -i 'offering public key'".format(host, host)
                )
            else:
                raise ValidationError(
                    f"SSH key validation failed for '{host}': {error_msg}"
                )

    except ValidationError:
        # Bug fix: re-raise our own errors so the generic handler below
        # does not re-wrap them into a less specific message.
        raise
    except subprocess.TimeoutExpired:
        raise ValidationError(f"Timeout validating SSH key for '{host}'")
    except Exception as e:
        raise ValidationError(f"Error validating SSH key for '{host}': {e}")
|
||||
|
||||
|
||||
def get_connection_diagnostics(host: str) -> Dict[str, Dict[str, object]]:
    """
    Comprehensive connection testing.

    Runs four independent probes (ping, SSH, Tailscale, SSH key) and
    collects each outcome; a failing probe never aborts the others.

    Args:
        host: Host to diagnose

    Returns:
        Dict with diagnostic results:
        {
            'ping': {'success': bool, 'message': str},
            'ssh': {'success': bool, 'message': str},
            'tailscale': {'success': bool, 'message': str},
            'ssh_key': {'success': bool, 'message': str}
        }

    Example:
        >>> diag = get_connection_diagnostics("web-01")
        >>> diag['ssh']['success']
        True
    """
    diagnostics = {}

    # Test 1: Ping
    # NOTE(review): '-W 2' is seconds on Linux ping but milliseconds on
    # macOS/BSD — confirm the target platform.
    try:
        result = subprocess.run(
            ["ping", "-c", "1", "-W", "2", host],
            capture_output=True,
            timeout=3
        )
        diagnostics['ping'] = {
            'success': result.returncode == 0,
            'message': 'Host is reachable' if result.returncode == 0 else 'Host not reachable'
        }
    except Exception as e:
        diagnostics['ping'] = {'success': False, 'message': str(e)}

    # Test 2: SSH connection
    try:
        validate_ssh_connection(host, timeout=5)
        diagnostics['ssh'] = {'success': True, 'message': 'SSH connection works'}
    except ValidationError as e:
        # Keep only the first line of the multi-line guidance message.
        diagnostics['ssh'] = {'success': False, 'message': str(e).split('\n')[0]}

    # Test 3: Tailscale
    try:
        validate_tailscale_connection(host)
        diagnostics['tailscale'] = {'success': True, 'message': 'Tailscale connected'}
    except ValidationError as e:
        diagnostics['tailscale'] = {'success': False, 'message': str(e).split('\n')[0]}

    # Test 4: SSH key
    try:
        validate_ssh_key(host)
        diagnostics['ssh_key'] = {'success': True, 'message': 'SSH key authentication works'}
    except ValidationError as e:
        diagnostics['ssh_key'] = {'success': False, 'message': str(e).split('\n')[0]}

    return diagnostics
|
||||
|
||||
|
||||
def main():
    """Run the connection diagnostics against localhost and print results."""
    print("Testing connection validators...\n")

    print("1. Testing connection diagnostics:")
    try:
        report = get_connection_diagnostics("localhost")
        print("   Results:")
        for check_name, outcome in report.items():
            mark = "✓" if outcome['success'] else "✗"
            print(f"   {mark} {check_name}: {outcome['message']}")
    except Exception as e:
        print(f"   Error: {e}")

    print("\n✅ Connection validators tested")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
232
scripts/utils/validators/host_validator.py
Normal file
232
scripts/utils/validators/host_validator.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Host validators for Tailscale SSH Sync Agent.
|
||||
Validates host configuration and availability.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from .parameter_validator import ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_ssh_config(host: str, config_path: Optional[Path] = None) -> bool:
    """
    Check if host has SSH config entry.

    Args:
        host: Host name to check
        config_path: Path to SSH config (default: ~/.ssh/config)

    Returns:
        True if host is in SSH config

    Raises:
        ValidationError: If host not found in config

    Example:
        >>> validate_ssh_config("web-01")
        True
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'

    if not config_path.exists():
        raise ValidationError(
            f"SSH config file not found: {config_path}\n"
            "Create ~/.ssh/config with your host definitions"
        )

    # Parse SSH config for this host
    host_found = False

    try:
        with open(config_path, 'r') as f:
            for line in f:
                line = line.strip()
                # Bug fix: the original used a substring test (`host in line`),
                # so looking up "web-01" falsely matched "Host web-01-backup".
                # Compare against the whitespace-separated aliases instead
                # ('Host a b' may declare several at once).
                if line.lower().startswith('host ') and host in line.split()[1:]:
                    host_found = True
                    break

        if not host_found:
            raise ValidationError(
                f"Host '{host}' not found in SSH config: {config_path}\n"
                "Add host to SSH config:\n"
                f"Host {host}\n"
                f"  HostName <IP_ADDRESS>\n"
                f"  User <USERNAME>"
            )

        return True

    except IOError as e:
        raise ValidationError(f"Error reading SSH config: {e}")
|
||||
|
||||
|
||||
def validate_host_reachable(host: str, timeout: int = 5) -> bool:
    """
    Check if host is reachable via ping.

    Args:
        host: Host name to check
        timeout: Timeout in seconds

    Returns:
        True if host is reachable

    Raises:
        ValidationError: If host is not reachable

    Example:
        >>> validate_host_reachable("web-01", timeout=5)
        True
    """
    try:
        # 'ssh -G' prints the effective configuration without connecting,
        # letting us map an alias to its real HostName before pinging.
        result = subprocess.run(
            ["ssh", "-G", host],
            capture_output=True,
            text=True,
            timeout=2
        )

        if result.returncode == 0:
            # Extract hostname from SSH config
            for line in result.stdout.split('\n'):
                if line.startswith('hostname '):
                    actual_host = line.split()[1]
                    break
            else:
                actual_host = host
        else:
            actual_host = host

        # Ping the host.
        # NOTE(review): '-W' takes seconds on Linux ping but milliseconds
        # on macOS/BSD — confirm the target platform.
        ping_result = subprocess.run(
            ["ping", "-c", "1", "-W", str(timeout), actual_host],
            capture_output=True,
            text=True,
            timeout=timeout + 1
        )

        if ping_result.returncode == 0:
            return True
        else:
            raise ValidationError(
                f"Host '{host}' ({actual_host}) is not reachable\n"
                "Check:\n"
                "1. Host is powered on\n"
                "2. Tailscale is connected\n"
                "3. Network connectivity"
            )

    except ValidationError:
        # Bug fix: without this re-raise, the "not reachable" error above
        # was caught by the generic `except Exception` below and re-wrapped
        # as "Error checking host ...", obscuring the guidance.
        raise
    except subprocess.TimeoutExpired:
        raise ValidationError(f"Timeout checking host '{host}' (>{timeout}s)")
    except Exception as e:
        raise ValidationError(f"Error checking host '{host}': {e}")
|
||||
|
||||
|
||||
def validate_group_members(group: str, groups_config: Dict[str, List[str]]) -> List[str]:
    """
    Ensure a group exists and has valid members.

    Args:
        group: Group name
        groups_config: Groups configuration dict (group -> list of hosts)

    Returns:
        List of hosts in the group

    Raises:
        ValidationError: If the group is unknown, malformed, or empty

    Example:
        >>> groups = {'production': ['web-01', 'db-01']}
        >>> validate_group_members('production', groups)
        ['web-01', 'db-01']
    """
    if group not in groups_config:
        raise ValidationError(
            f"Group '{group}' not found in configuration\n"
            f"Available groups: {', '.join(groups_config.keys())}"
        )

    members = groups_config[group]

    # Bug fix: type-check before the emptiness check, so a malformed
    # config (e.g. a string or dict value) is reported as a config
    # error instead of an empty/valid group.
    if not isinstance(members, list):
        raise ValidationError(
            f"Invalid group configuration for '{group}': members must be a list"
        )

    if not members:
        raise ValidationError(
            f"Group '{group}' has no members\n"
            f"Add hosts to group with: sshsync gadd {group}"
        )

    return members
|
||||
|
||||
|
||||
def get_invalid_hosts(hosts: List[str], config_path: Optional[Path] = None) -> List[str]:
    """
    Find hosts without a valid SSH config entry.

    Args:
        hosts: List of host names
        config_path: Path to SSH config (defaults to ~/.ssh/config)

    Returns:
        List of hosts without a matching Host entry

    Example:
        >>> get_invalid_hosts(["web-01", "nonexistent"])
        ["nonexistent"]
    """
    if config_path is None:
        config_path = Path.home() / '.ssh' / 'config'

    if not config_path.exists():
        return hosts  # All invalid if no config

    # Parse SSH config, collecting every concrete Host alias.
    valid_hosts = set()
    try:
        with open(config_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line.lower().startswith('host '):
                    # Bug fix: a "Host" line may declare several aliases
                    # separated by whitespace ("Host web-01 web-02").
                    # The old code kept the remainder as ONE alias, so
                    # multi-alias lines never matched any real host.
                    for host_alias in line.split()[1:]:
                        # Skip wildcard patterns — they are not
                        # addressable host names.
                        if '*' not in host_alias and '?' not in host_alias:
                            valid_hosts.add(host_alias)
    except IOError:
        return hosts

    # Anything not declared in the config is invalid.
    return [h for h in hosts if h not in valid_hosts]
|
||||
|
||||
|
||||
def main():
    """Smoke-test the host validators against the local machine.

    Exercises validate_ssh_config() and get_invalid_hosts() with
    localhost plus a deliberately bogus host. Results are printed,
    not asserted, because the available hosts vary per machine.
    """
    print("Testing host validators...\n")

    print("1. Testing validate_ssh_config():")
    try:
        validate_ssh_config("localhost")
        print(" ✓ localhost has SSH config")
    except ValidationError as e:
        # chr(10) == "\n": show only the first line of the multi-line error.
        print(f" Note: {e.args[0].split(chr(10))[0]}")

    print("\n2. Testing get_invalid_hosts():")
    # The second entry is intentionally bogus and should come back invalid.
    test_hosts = ["localhost", "nonexistent-host-12345"]
    invalid = get_invalid_hosts(test_hosts)
    print(f" Invalid hosts: {invalid}")

    print("\n✅ Host validators tested")


if __name__ == "__main__":
    main()
|
||||
363
scripts/utils/validators/parameter_validator.py
Normal file
363
scripts/utils/validators/parameter_validator.py
Normal file
@@ -0,0 +1,363 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Parameter validators for Tailscale SSH Sync Agent.
|
||||
Validates user inputs before making operations.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Signals that a user-supplied parameter failed validation."""
|
||||
|
||||
|
||||
def validate_host(host: str, valid_hosts: Optional[List[str]] = None) -> str:
    """
    Validate and normalize a host parameter.

    Args:
        host: Host name or alias
        valid_hosts: Known-good hosts; pass None (or empty) to skip the
            membership check

    Returns:
        str: The normalized host name — the canonical spelling from
        valid_hosts when a case-insensitive match is found

    Raises:
        ValidationError: If the host is empty, malformed, or unknown

    Example:
        >>> validate_host("web-01")
        "web-01"
        >>> validate_host("web-01", ["web-01", "web-02"])
        "web-01"
    """
    if not host:
        raise ValidationError("Host cannot be empty")

    if not isinstance(host, str):
        raise ValidationError(f"Host must be string, got {type(host)}")

    host = host.strip()

    # Only letters, digits, dots, dashes and underscores are accepted.
    if re.match(r'^[a-zA-Z0-9._-]+$', host) is None:
        raise ValidationError(
            f"Invalid host name format: {host}\n"
            "Host names must contain only letters, numbers, dots, dashes, and underscores"
        )

    # No membership list supplied: the format check is all we can do.
    if not valid_hosts:
        return host

    # Exact spelling wins; otherwise accept a case-insensitive match
    # and return the canonical form from the list.
    if host in valid_hosts:
        return host
    lowered = host.lower()
    for candidate in valid_hosts:
        if candidate.lower() == lowered:
            return candidate

    # Unknown host: suggest entries sharing the first three letters.
    near_misses = [h for h in valid_hosts if host[:3].lower() in h.lower()]
    raise ValidationError(
        f"Invalid host: {host}\n"
        f"Valid options: {', '.join(valid_hosts[:10])}\n"
        + (f"Did you mean: {', '.join(near_misses[:3])}?" if near_misses else "")
    )
|
||||
|
||||
|
||||
def validate_group(group: str, valid_groups: Optional[List[str]] = None) -> str:
    """
    Validate and normalize a group parameter.

    Args:
        group: Group name
        valid_groups: Known groups; None (or empty) skips the membership
            check

    Returns:
        str: Normalized (stripped, lowercased) group name

    Raises:
        ValidationError: If the group is empty, malformed, or unknown

    Example:
        >>> validate_group("production")
        "production"
        >>> validate_group("prod", ["production", "development"])
        ValidationError: Invalid group: prod
    """
    if not group:
        raise ValidationError("Group cannot be empty")

    if not isinstance(group, str):
        raise ValidationError(f"Group must be string, got {type(group)}")

    # Groups are case-insensitive: always compare the lowercase form.
    group = group.strip().lower()

    if re.match(r'^[a-z0-9_-]+$', group) is None:
        raise ValidationError(
            f"Invalid group name format: {group}\n"
            "Group names must contain only lowercase letters, numbers, dashes, and underscores"
        )

    # Membership check (when a list was provided), with suggestions
    # based on the first three characters.
    if valid_groups and group not in valid_groups:
        near_misses = [g for g in valid_groups if group[:3] in g]
        raise ValidationError(
            f"Invalid group: {group}\n"
            f"Valid groups: {', '.join(valid_groups)}\n"
            + (f"Did you mean: {', '.join(near_misses[:3])}?" if near_misses else "")
        )

    return group
|
||||
|
||||
|
||||
def validate_path_exists(path: str, must_be_file: bool = False,
                         must_be_dir: bool = False) -> Path:
    """
    Validate that a path exists, optionally requiring a file/directory.

    Args:
        path: Path to validate (``~`` is expanded, symlinks resolved)
        must_be_file: Require the path to be a regular file
        must_be_dir: Require the path to be a directory

    Returns:
        Path: The resolved Path object

    Raises:
        ValidationError: If the path is empty, missing, or the wrong kind

    Example:
        >>> validate_path_exists("/tmp", must_be_dir=True)
        Path('/tmp')
        >>> validate_path_exists("/nonexistent")
        ValidationError: Path does not exist: /nonexistent
    """
    if not path:
        raise ValidationError("Path cannot be empty")

    resolved = Path(path).expanduser().resolve()

    if not resolved.exists():
        # Report both spellings: what the caller passed and what it
        # resolved to after ~ expansion / symlink resolution.
        raise ValidationError(
            f"Path does not exist: {path}\n"
            f"Resolved to: {resolved}"
        )

    if must_be_file and not resolved.is_file():
        raise ValidationError(f"Path must be a file: {path}")

    if must_be_dir and not resolved.is_dir():
        raise ValidationError(f"Path must be a directory: {path}")

    return resolved
|
||||
|
||||
|
||||
def validate_timeout(timeout: int, min_timeout: int = 1,
                     max_timeout: int = 600) -> int:
    """
    Validate timeout parameter.

    Args:
        timeout: Timeout in seconds
        min_timeout: Minimum allowed timeout
        max_timeout: Maximum allowed timeout

    Returns:
        int: Validated timeout

    Raises:
        ValidationError: If timeout is not an int or is out of range

    Example:
        >>> validate_timeout(10)
        10
        >>> validate_timeout(0)
        ValidationError: Timeout too low: 0s (minimum: 1s)
    """
    # Bug fix: bool is a subclass of int, so validate_timeout(True)
    # used to slip through as the timeout 1. Reject bools explicitly.
    if isinstance(timeout, bool) or not isinstance(timeout, int):
        raise ValidationError(f"Timeout must be integer, got {type(timeout)}")

    if timeout < min_timeout:
        raise ValidationError(
            f"Timeout too low: {timeout}s (minimum: {min_timeout}s)"
        )

    if timeout > max_timeout:
        raise ValidationError(
            f"Timeout too high: {timeout}s (maximum: {max_timeout}s)"
        )

    return timeout
|
||||
|
||||
|
||||
def validate_command(command: str, allow_dangerous: bool = False) -> str:
    """
    Basic command safety validation.

    This is a best-effort denylist of obviously destructive shell
    commands, not a security boundary.

    Args:
        command: Command to validate
        allow_dangerous: If False, block potentially dangerous commands

    Returns:
        str: Validated (stripped) command

    Raises:
        ValidationError: If command is empty, not a string, or dangerous

    Example:
        >>> validate_command("ls -la")
        "ls -la"
        >>> validate_command("rm -rf /", allow_dangerous=False)
        ValidationError: Potentially dangerous command blocked: rm -rf
    """
    if not command:
        raise ValidationError("Command cannot be empty")

    if not isinstance(command, str):
        raise ValidationError(f"Command must be string, got {type(command)}")

    command = command.strip()

    if not allow_dangerous:
        # Each entry is (regex, human-readable description).
        dangerous_patterns = [
            (r'\brm\s+-rf\s+/', "rm -rf on root directory"),
            (r'\bmkfs\.', "filesystem formatting"),
            (r'\bdd\s+.*of=/dev/', "disk writing with dd"),
            # Fork bomb ":(){ :|:& };:". Bug fix: the old pattern left
            # the regex metacharacters ()|{} unescaped, so it never
            # matched the common spaced form and instead created a
            # stray alternation. Escape them and tolerate whitespace.
            (r':\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:', "fork bomb"),
            (r'>\s*/dev/sd[a-z]', "direct disk writing"),
        ]

        for pattern, description in dangerous_patterns:
            if re.search(pattern, command, re.IGNORECASE):
                raise ValidationError(
                    f"Potentially dangerous command blocked: {description}\n"
                    f"Command: {command}\n"
                    "Use allow_dangerous=True if you really want to execute this"
                )

    return command
|
||||
|
||||
|
||||
def validate_hosts_list(hosts: List[str], valid_hosts: Optional[List[str]] = None) -> List[str]:
    """
    Validate every host in a list, collecting all failures at once.

    Args:
        hosts: List of host names
        valid_hosts: Known-good hosts; None skips the membership check

    Returns:
        List[str]: The validated (normalized) host names

    Raises:
        ValidationError: If the input is empty / not a list, or if any
            entry fails validation — per-host errors are reported
            together in a single message

    Example:
        >>> validate_hosts_list(["web-01", "web-02"])
        ["web-01", "web-02"]
    """
    if not hosts:
        raise ValidationError("Hosts list cannot be empty")

    if not isinstance(hosts, list):
        raise ValidationError(f"Hosts must be list, got {type(hosts)}")

    validated, problems = [], []

    # Validate every entry before failing, so the caller sees the full
    # set of bad hosts in one round trip.
    for entry in hosts:
        try:
            validated.append(validate_host(entry, valid_hosts))
        except ValidationError as err:
            problems.append(str(err))

    if problems:
        raise ValidationError(
            "Invalid hosts in list:\n" + "\n".join(problems)
        )

    return validated
|
||||
|
||||
|
||||
def main():
    """Smoke-test every parameter validator with good and bad inputs.

    Prints ✓/✗ per case instead of asserting; a "Should have failed!"
    line in the output indicates a validator regression.
    """
    print("Testing parameter validators...\n")

    # Test host validation: known host accepted, unknown host rejected.
    print("1. Testing validate_host():")
    try:
        host = validate_host("web-01", ["web-01", "web-02", "db-01"])
        print(f" ✓ Valid host: {host}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")

    try:
        host = validate_host("invalid-host", ["web-01", "web-02"])
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        # chr(10) == "\n": show only the first line of the multi-line error.
        print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")

    # Test group validation
    print("\n2. Testing validate_group():")
    try:
        group = validate_group("production", ["production", "development"])
        print(f" ✓ Valid group: {group}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")

    # Test path validation (/tmp should exist on any POSIX system)
    print("\n3. Testing validate_path_exists():")
    try:
        path = validate_path_exists("/tmp", must_be_dir=True)
        print(f" ✓ Valid path: {path}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")

    # Test timeout validation: in-range accepted, 0 rejected.
    print("\n4. Testing validate_timeout():")
    try:
        timeout = validate_timeout(10)
        print(f" ✓ Valid timeout: {timeout}s")
    except ValidationError as e:
        print(f" ✗ Error: {e}")

    try:
        timeout = validate_timeout(0)
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        print(f" ✓ Correctly rejected: {e.args[0].split(chr(10))[0]}")

    # Test command validation: benign command accepted, destructive blocked.
    print("\n5. Testing validate_command():")
    try:
        cmd = validate_command("ls -la")
        print(f" ✓ Safe command: {cmd}")
    except ValidationError as e:
        print(f" ✗ Error: {e}")

    try:
        cmd = validate_command("rm -rf /", allow_dangerous=False)
        print(f" ✗ Should have failed!")
    except ValidationError as e:
        print(f" ✓ Correctly blocked: {e.args[0].split(chr(10))[0]}")

    print("\n✅ All parameter validators tested")


if __name__ == "__main__":
    main()
|
||||
445
scripts/workflow_executor.py
Normal file
445
scripts/workflow_executor.py
Normal file
@@ -0,0 +1,445 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Workflow executor for Tailscale SSH Sync Agent.
|
||||
Common multi-machine workflow automation.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Add utils to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from utils.helpers import format_duration, get_timestamp
|
||||
from sshsync_wrapper import execute_on_group, execute_on_host, push_to_hosts
|
||||
from load_balancer import get_group_capacity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def deploy_workflow(code_path: str,
                    staging_group: str,
                    prod_group: str,
                    run_tests: bool = True) -> Dict:
    """
    Full deployment pipeline: staging → test → production.

    Stages: push to staging, build, (optionally) run tests, health-check
    staging, push to production, build+restart, verify production.
    Stops at the first failing gated stage and records 'error'.

    Args:
        code_path: Path to code to deploy
        staging_group: Staging server group
        prod_group: Production server group
        run_tests: Whether to run tests on staging

    Returns:
        Dict with per-stage results under 'stages', plus 'success',
        'start_time', 'duration', and 'error' when a stage failed.

    Example:
        >>> result = deploy_workflow("./dist", "staging", "production")
        >>> result['success']
        True
        >>> result['duration']
        "12m 45s"
    """
    start_time = time.time()
    results = {
        'stages': {},
        'success': False,
        'start_time': get_timestamp()
    }

    try:
        # Stage 1: Deploy to staging
        logger.info("Stage 1: Deploying to staging...")
        stage1 = push_to_hosts(
            local_path=code_path,
            remote_path="/var/www/app",
            group=staging_group,
            recurse=True
        )
        results['stages']['staging_deploy'] = stage1
        if not stage1.get('success'):
            results['error'] = 'Staging deployment failed'
            return results

        # Build on staging
        logger.info("Building on staging...")
        build_result = execute_on_group(
            staging_group,
            "cd /var/www/app && npm run build",
            timeout=300
        )
        results['stages']['staging_build'] = build_result
        if not build_result.get('success'):
            results['error'] = 'Staging build failed'
            return results

        # Stage 2: Run tests (if enabled)
        if run_tests:
            logger.info("Stage 2: Running tests...")
            test_result = execute_on_group(
                staging_group,
                "cd /var/www/app && npm test",
                timeout=600
            )
            results['stages']['tests'] = test_result
            if not test_result.get('success'):
                results['error'] = 'Tests failed on staging'
                return results

        # Stage 3: Validation. Best-effort: the result is recorded but
        # does not gate the production rollout (the command itself
        # swallows failure with "|| echo").
        logger.info("Stage 3: Validating staging...")
        health_result = execute_on_group(
            staging_group,
            "curl -f http://localhost:3000/health || echo 'Health check failed'",
            timeout=10
        )
        results['stages']['staging_validation'] = health_result

        # Stage 4: Deploy to production
        logger.info("Stage 4: Deploying to production...")
        prod_deploy = push_to_hosts(
            local_path=code_path,
            remote_path="/var/www/app",
            group=prod_group,
            recurse=True
        )
        results['stages']['production_deploy'] = prod_deploy
        if not prod_deploy.get('success'):
            results['error'] = 'Production deployment failed'
            return results

        # Build and restart on production
        logger.info("Building and restarting production...")
        prod_build = execute_on_group(
            prod_group,
            "cd /var/www/app && npm run build && pm2 restart app",
            timeout=300
        )
        results['stages']['production_build'] = prod_build

        # Stage 5: Production verification (recorded, not gated)
        logger.info("Stage 5: Verifying production...")
        prod_health = execute_on_group(
            prod_group,
            "curl -f http://localhost:3000/health",
            timeout=15
        )
        results['stages']['production_verification'] = prod_health

        # Success!
        results['success'] = True
        return results

    except Exception as e:
        logger.error(f"Deployment workflow error: {e}")
        results['error'] = str(e)
        return results
    finally:
        # Bug fix: 'duration' was previously set only on full success
        # or on exception — the early failure returns above lacked it.
        # A finally block stamps it on every exit path.
        results['duration'] = format_duration(time.time() - start_time)
|
||||
|
||||
|
||||
def backup_workflow(hosts: List[str],
                    backup_paths: List[str],
                    destination: str) -> Dict:
    """
    Backup files from multiple hosts into timestamped local folders.

    Each host's data lands under ``{destination}/{host}_{timestamp}/``
    with one subfolder per backed-up path.

    Args:
        hosts: List of hosts to backup from
        backup_paths: Paths to backup on each host
        destination: Local destination directory

    Returns:
        Dict with per-host pull results under 'hosts', a 'success'
        flag, 'backed_up_hosts' (hosts whose every path succeeded),
        and 'duration'.

    Example:
        >>> result = backup_workflow(
        ...     ["db-01", "db-02"],
        ...     ["/var/lib/mysql"],
        ...     "./backups"
        ... )
        >>> result['backed_up_hosts']
        2
    """
    from sshsync_wrapper import pull_from_host

    start_time = time.time()
    results = {
        'hosts': {},
        'success': True,
        'backed_up_hosts': 0
    }

    # One timestamp for the whole run so all folders of a run sort
    # together (the old per-iteration timestamp could also collide).
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    for host in hosts:
        host_results = []

        for backup_path in backup_paths:
            # Bug fix: include the path's basename in the destination.
            # Previously every backup_path of a host mapped to the SAME
            # "{host}_{timestamp}" folder, so multiple paths overwrote
            # each other.
            host_dest = f"{destination}/{host}_{timestamp}/{Path(backup_path).name}"

            result = pull_from_host(
                host=host,
                remote_path=backup_path,
                local_path=host_dest,
                recurse=True
            )

            host_results.append(result)

            if not result.get('success'):
                results['success'] = False

        results['hosts'][host] = host_results

        # A host counts as backed up only if every path succeeded.
        if all(r.get('success') for r in host_results):
            results['backed_up_hosts'] += 1

    results['duration'] = format_duration(time.time() - start_time)

    return results
|
||||
|
||||
|
||||
def sync_workflow(source_host: str,
                  target_group: str,
                  paths: List[str]) -> Dict:
    """
    Sync files from one host to many via a local staging directory.

    Each path is pulled from source_host into a temporary directory,
    then pushed to every member of target_group at the same remote path.

    Args:
        source_host: Host to pull from
        target_group: Group to push to
        paths: Paths to sync

    Returns:
        Dict with per-path pull/push results under 'paths', a 'success'
        flag, and 'duration'.

    Example:
        >>> result = sync_workflow(
        ...     "master-db",
        ...     "replica-dbs",
        ...     ["/var/lib/mysql/config"]
        ... )
        >>> result['success']
        True
    """
    from sshsync_wrapper import pull_from_host, push_to_hosts
    import tempfile  # (removed unused "import shutil")

    start_time = time.time()
    results = {'paths': {}, 'success': True}

    # Staging directory is removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        for index, path in enumerate(paths):
            # Bug fix: prefix with the index so two paths sharing a
            # basename (e.g. /a/config and /b/config) no longer
            # overwrite each other in the staging directory.
            staged = f"{temp_dir}/{index}_{Path(path).name}"

            # Pull from source
            pull_result = pull_from_host(
                host=source_host,
                remote_path=path,
                local_path=staged,
                recurse=True
            )

            if not pull_result.get('success'):
                results['paths'][path] = {
                    'success': False,
                    'error': 'Pull from source failed'
                }
                results['success'] = False
                continue

            # Push to targets
            push_result = push_to_hosts(
                local_path=staged,
                remote_path=path,
                group=target_group,
                recurse=True
            )

            results['paths'][path] = {
                'pull': pull_result,
                'push': push_result,
                'success': push_result.get('success', False)
            }

            if not push_result.get('success'):
                results['success'] = False

    results['duration'] = format_duration(time.time() - start_time)

    return results
|
||||
|
||||
|
||||
def rolling_restart(group: str,
                    service_name: str,
                    wait_between: int = 30) -> Dict:
    """
    Zero-downtime rolling restart of a service across a group.

    Restarts one host at a time, health-checks it, then waits
    wait_between seconds before moving on to the next host.

    Args:
        group: Group to restart
        service_name: Service name (e.g., "nginx", "app")
        wait_between: Seconds to wait between restarts

    Returns:
        Dict with per-host restart/health results, restarted/failed
        counts, a 'success' flag and 'duration'; or an 'error' when
        the group is unknown or empty.

    Example:
        >>> result = rolling_restart("web-servers", "nginx")
        >>> result['restarted_count']
        3
    """
    from utils.helpers import parse_sshsync_config

    start_time = time.time()
    groups_config = parse_sshsync_config()
    hosts = groups_config.get(group, [])

    if not hosts:
        return {
            'success': False,
            'error': f'Group {group} not found or empty'
        }

    results = {
        'hosts': {},
        'restarted_count': 0,
        'failed_count': 0,
        'success': True
    }

    for index, host in enumerate(hosts):
        logger.info(f"Restarting {service_name} on {host}...")

        # systemctl first, SysV "service" as a fallback for older hosts.
        restart_result = execute_on_host(
            host,
            f"sudo systemctl restart {service_name} || sudo service {service_name} restart",
            timeout=30
        )

        # Health check
        time.sleep(5)  # Wait for service to start

        health_result = execute_on_host(
            host,
            f"sudo systemctl is-active {service_name} || sudo service {service_name} status",
            timeout=10
        )

        success = restart_result.get('success') and health_result.get('success')

        results['hosts'][host] = {
            'restart': restart_result,
            'health': health_result,
            'success': success
        }

        if success:
            results['restarted_count'] += 1
            logger.info(f"✓ {host} restarted successfully")
        else:
            results['failed_count'] += 1
            results['success'] = False
            logger.error(f"✗ {host} restart failed")

        # Wait before the next restart, skipping the wait after the
        # last host. Bug fix: compare by position, not value — the old
        # "host != hosts[-1]" skipped waits for duplicate host entries.
        if index < len(hosts) - 1:
            time.sleep(wait_between)

    results['duration'] = format_duration(time.time() - start_time)

    return results
|
||||
|
||||
|
||||
def health_check_workflow(group: str,
                          endpoint: str = "/health",
                          timeout: int = 10,
                          port: int = 3000) -> Dict:
    """
    Check a health endpoint across every host in a group.

    Each host curls its own ``http://localhost:{port}{endpoint}`` and
    is considered healthy when curl succeeds with an HTTP 200.

    Args:
        group: Group to check
        endpoint: Health endpoint path
        timeout: Request timeout
        port: Local service port to probe (generalized from the
            previously hard-coded 3000; default unchanged)

    Returns:
        Dict with per-host health under 'hosts', healthy/unhealthy
        counts and a 'success' flag; or an 'error' when the group is
        unknown or empty.

    Example:
        >>> result = health_check_workflow("production", "/health")
        >>> result['healthy_count']
        3
    """
    from utils.helpers import parse_sshsync_config

    groups_config = parse_sshsync_config()
    hosts = groups_config.get(group, [])

    if not hosts:
        return {
            'success': False,
            'error': f'Group {group} not found or empty'
        }

    results = {
        'hosts': {},
        'healthy_count': 0,
        'unhealthy_count': 0
    }

    for host in hosts:
        # -w '%{http_code}' prints only the status code; the doubled
        # braces escape the f-string.
        health_result = execute_on_host(
            host,
            f"curl -f -s -o /dev/null -w '%{{http_code}}' http://localhost:{port}{endpoint}",
            timeout=timeout
        )

        is_healthy = (
            health_result.get('success') and
            '200' in health_result.get('stdout', '')
        )

        results['hosts'][host] = {
            'healthy': is_healthy,
            'response': health_result.get('stdout', '').strip()
        }

        if is_healthy:
            results['healthy_count'] += 1
        else:
            results['unhealthy_count'] += 1

    # The whole group is successful only when no host is unhealthy.
    results['success'] = results['unhealthy_count'] == 0

    return results
|
||||
|
||||
|
||||
def main():
    """Describe (rather than execute) the workflow executor self-test.

    Real workflow runs would touch live hosts, so this only prints a
    readiness notice.
    """
    notices = (
        "Testing workflow executor...\n",
        "Note: Workflow executor requires configured hosts and groups.",
        "Tests would execute real operations, so showing dry-run simulations.\n",
        "✅ Workflow executor ready",
    )
    for notice in notices:
        print(notice)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user