import asyncio
import time
import logging

logger = logging.getLogger(__name__)

class SubAgent:
    """
    A stateful watcher for a specific task on an agent node.
    Handles execution, result accumulation, and state monitoring.

    ``task_fn`` is a blocking callable invoked as ``task_fn(**args)`` in a
    worker thread. It is retried when it raises, or when it returns a dict
    whose ``"error"`` value looks transient (see ``TRANSIENT_ERROR_MARKERS``).

    Observable state:
      - ``status``: "PENDING", "RUNNING", "RETRYING (i/n)",
        "ERROR_RETRYING (i/n)", "COMPLETED", "FAILED" or "CANCELLED".
      - ``result``: the last value returned by ``task_fn`` (None if it never
        returned).
      - ``error``: message of the exception that caused the final failure,
        or None on completion.
      - ``start_time`` / ``end_time``: ``time.time()`` stamps of the run.
    """

    # Substrings (matched case-insensitively) in a returned {"error": ...}
    # payload that mark the failure as transient and therefore retryable.
    TRANSIENT_ERROR_MARKERS = ("timeout", "offline", "disconnected")

    def __init__(self, name: str, task_fn, args: dict, retries: int = 1,
                 retry_delay: float = 2.0):
        """
        Args:
            name: Label used in log messages.
            task_fn: Blocking callable executed as ``task_fn(**args)``.
            args: Keyword arguments passed to ``task_fn``.
            retries: Extra attempts after the first one (negative means 0).
            retry_delay: Seconds to wait between attempts.
        """
        self.name = name
        self.task_fn = task_fn
        self.args = args
        self.retries = retries
        self.retry_delay = retry_delay
        self.status = "PENDING"
        self.result = None
        self.start_time = None
        self.end_time = None
        self.error = None
        self.task_id = None

    def _transient_error(self, result):
        """Return the lower-cased error text if *result* reports a retryable failure, else None."""
        if not isinstance(result, dict) or not result.get("error"):
            return None
        err_msg = str(result.get("error")).lower()
        if any(marker in err_msg for marker in self.TRANSIENT_ERROR_MARKERS):
            return err_msg
        return None

    async def run(self):
        """
        Execute ``task_fn`` with retries and return its last result.

        A returned error dict on the final attempt is NOT treated as a failure:
        the status becomes "COMPLETED" and the dict is returned as the result.
        An exception on the final attempt sets status "FAILED" and ``error``.

        Raises:
            asyncio.CancelledError: re-raised after marking status "CANCELLED".
        """
        # Reset per-run state so a re-run never reports a stale end_time,
        # result or error left over from a previous invocation.
        self.start_time = time.time()
        self.end_time = None
        self.result = None
        self.error = None
        self.status = "RUNNING"
        # A negative count would skip the loop entirely and leave the agent
        # stuck in RUNNING; treat it as "no retries".
        max_retries = max(0, self.retries)

        try:
            for attempt in range(max_retries + 1):
                try:
                    # Execute the blocking assistant method (which uses
                    # TaskJournal/Event) in a worker thread to keep the
                    # async loop free.
                    self.result = await asyncio.to_thread(self.task_fn, **self.args)

                    # Only retry on potentially transient network/node issues.
                    transient = self._transient_error(self.result)
                    if transient is not None and attempt < max_retries:
                        self.status = f"RETRYING ({attempt+1}/{max_retries})"
                        logger.info("[SubAgent] %s retrying due to: %s", self.name, transient)
                        await asyncio.sleep(self.retry_delay)
                        continue

                    # Drop any exception message left by an earlier attempt.
                    self.error = None
                    self.status = "COMPLETED"
                    break
                except Exception as e:
                    logger.exception("SubAgent %s execution error: %s", self.name, e)
                    self.error = str(e)
                    if attempt < max_retries:
                        self.status = f"ERROR_RETRYING ({attempt+1}/{max_retries})"
                        await asyncio.sleep(self.retry_delay)
                    else:
                        self.status = "FAILED"
        except asyncio.CancelledError:
            # CancelledError is a BaseException and bypasses the handler above;
            # record it instead of leaving the agent reported as RUNNING.
            self.status = "CANCELLED"
            raise
        finally:
            # Always stamp the end so get_elapsed() stops counting.
            self.end_time = time.time()
        return self.result

    def get_elapsed(self) -> int:
        """Return whole seconds since run() started (up to end_time if finished); 0 if never started."""
        if self.start_time is None:
            return 0
        end = self.end_time if self.end_time is not None else time.time()
        return int(end - self.start_time)
