173 lines
6.1 KiB
Python
173 lines
6.1 KiB
Python
"""State management for ADW composable architecture.

Provides persistent state management via file storage and
transient state passing between scripts via stdin/stdout.
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
from adw_modules.data_types import ADWStateData
|
|
|
|
|
|
class ADWState:
    """Container for ADW workflow state with file persistence.

    State lives in ``self.data`` (a plain dict). It can be persisted to
    ``<project_root>/agents/<adw_id>/adw_state.json`` via ``save``/``load``,
    or passed between scripts as JSON over stdin/stdout via
    ``to_stdout``/``from_stdin``.
    """

    STATE_FILENAME = "adw_state.json"

    # Fields accepted by update() and emitted by save()/to_stdout(), in the
    # order they are written out.
    CORE_FIELDS = (
        "adw_id",
        "issue_number",
        "branch_name",
        "plan_file",
        "issue_class",
        "worktree_path",
        "backend_port",
        "frontend_port",
        "model_set",
        "all_adws",
    )

    def __init__(self, adw_id: str):
        """Initialize ADWState with a required ADW ID.

        Args:
            adw_id: The ADW ID for this state (required).

        Raises:
            ValueError: If ``adw_id`` is empty or otherwise falsy.
        """
        if not adw_id:
            raise ValueError("adw_id is required for ADWState")

        self.adw_id = adw_id
        # Start with minimal state
        self.data: Dict[str, Any] = {"adw_id": self.adw_id}
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _project_root() -> str:
        """Return the directory three ``dirname`` levels above this module.

        NOTE(review): assumes this module lives at
        ``<project_root>/adws/adw_modules/`` so the result is the repo root
        (the parent of the adws directory).
        """
        return os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )

    def update(self, **kwargs: Any) -> None:
        """Update state with new key-value pairs.

        Only keys listed in ``CORE_FIELDS`` are stored; any other keyword
        is silently ignored.
        """
        for key, value in kwargs.items():
            if key in self.CORE_FIELDS:
                self.data[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        """Get value from state by key, or ``default`` if it is absent."""
        return self.data.get(key, default)

    def append_adw_id(self, adw_id: str) -> None:
        """Append an ADW ID to the all_adws list if not already present.

        A missing or ``None`` ``all_adws`` entry is treated as an empty list.
        """
        all_adws = self.data.get("all_adws")
        if all_adws is None:
            # State piped in via from_stdin may carry an explicit null here.
            all_adws = []
        if adw_id not in all_adws:
            all_adws.append(adw_id)
        self.data["all_adws"] = all_adws

    def get_working_directory(self) -> str:
        """Get the working directory for this ADW instance.

        Returns worktree_path if set (for isolated workflows),
        otherwise returns the main repo path.
        """
        worktree_path = self.data.get("worktree_path")
        if worktree_path:
            return worktree_path

        # Return main repo path (parent of adws directory)
        return self._project_root()

    def get_state_path(self) -> str:
        """Get path to state file: ``<project_root>/agents/<adw_id>/adw_state.json``."""
        return os.path.join(
            self._project_root(), "agents", self.adw_id, self.STATE_FILENAME
        )

    def _core_snapshot(self) -> Dict[str, Any]:
        """Return the core fields of ``self.data`` with persistence defaults.

        Missing fields become ``None``, except ``model_set`` (defaults to
        ``"base"``) and ``all_adws`` (defaults to a fresh empty list). Both
        ``save`` and ``to_stdout`` use this so file and pipe output agree.
        """
        snapshot = {name: self.data.get(name) for name in self.CORE_FIELDS}
        # Reassigning existing keys keeps the CORE_FIELDS ordering.
        snapshot["model_set"] = self.data.get("model_set", "base")
        snapshot["all_adws"] = self.data.get("all_adws", [])
        return snapshot

    def save(self, workflow_step: Optional[str] = None) -> None:
        """Save state to file in agents/{adw_id}/adw_state.json.

        The core fields are run through ``ADWStateData`` (for validation)
        before the state directory is created or the file is written.

        Args:
            workflow_step: Optional name of the step saving the state; used
                only in a log message.
        """
        # Create ADWStateData for validation
        state_data = ADWStateData(**self._core_snapshot())

        state_path = self.get_state_path()
        os.makedirs(os.path.dirname(state_path), exist_ok=True)

        # Save as JSON
        with open(state_path, "w", encoding="utf-8") as f:
            json.dump(state_data.model_dump(), f, indent=2)

        self.logger.info(f"Saved state to {state_path}")
        if workflow_step:
            self.logger.info(f"State updated by: {workflow_step}")

    @classmethod
    def load(
        cls, adw_id: str, logger: Optional[logging.Logger] = None
    ) -> Optional["ADWState"]:
        """Load state from file if it exists.

        Args:
            adw_id: ID whose ``agents/<adw_id>/adw_state.json`` should be read.
            logger: Optional logger for progress and failure messages.

        Returns:
            The loaded state, or ``None`` if ``adw_id`` is empty, the file
            does not exist, or reading/parsing/validating it fails (the
            failure is logged when ``logger`` is given).
        """
        if not adw_id:
            # An empty ID would resolve to agents/adw_state.json, which
            # belongs to no workflow.
            return None

        state_path = os.path.join(
            cls._project_root(), "agents", adw_id, cls.STATE_FILENAME
        )
        if not os.path.exists(state_path):
            return None

        try:
            with open(state_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Validate with ADWStateData
            state_data = ADWStateData(**data)
            dumped = state_data.model_dump()

            # Create ADWState instance
            state = cls(state_data.adw_id)
            state.data = dumped

            if logger:
                logger.info(f"🔍 Found existing state from {state_path}")
                logger.info(f"State: {json.dumps(dumped, indent=2)}")

            return state
        except Exception as e:
            # Deliberately best-effort: any failure yields None.
            if logger:
                logger.error(f"Failed to load state from {state_path}: {e}")
            return None

    @classmethod
    def from_stdin(cls) -> Optional["ADWState"]:
        """Read state from stdin if available (for piped input).

        Returns:
            The parsed state, or ``None`` when stdin is a TTY, the input is
            blank, is not valid JSON, is not a JSON object, or lacks a
            truthy ``adw_id``. All keys of the object are kept in ``data``.
        """
        if sys.stdin.isatty():
            return None
        try:
            input_data = sys.stdin.read()
            if not input_data.strip():
                return None
            data = json.loads(input_data)
        except (json.JSONDecodeError, EOFError):
            return None

        # Piped JSON may be any value; only an object can carry state.
        if not isinstance(data, dict):
            return None

        adw_id = data.get("adw_id")
        if not adw_id:
            return None  # No valid state without adw_id

        state = cls(adw_id)
        state.data = data
        return state

    def to_stdout(self) -> None:
        """Write state to stdout as JSON (for piping to next script).

        Emits the same core fields and defaults that ``save`` persists,
        including ``model_set``, so the next script keeps the chosen model set.
        """
        print(json.dumps(self._core_snapshot(), indent=2))
|