import asyncio
import logging
import json
import sys
from typing import List, Dict, Any, AsyncGenerator

class ToolExecutor:
    """Dispatches tool calls as concurrent asyncio tasks and streams their events.

    ``run_tools`` is the entry point: it starts one task per tool call, then
    polls until every task has finished (or cancellation is requested), relaying
    sub-agent thoughts and mesh observations to the consumer in the meantime.
    """

    # Upper bound, in seconds, between two drain / cancellation polls while
    # tools are still running.
    POLL_INTERVAL = 0.1

    # Fallback text used when a mesh observation carries no usable message.
    _DEFAULT_MESH_MESSAGE = "Unspecified drift observed."

    def __init__(self, tool_service: Any, user_id: str, db: Any, sync_workspace_id: str):
        """Store the collaborators that are forwarded to every tool invocation.

        Args:
            tool_service: Object exposing an awaitable ``call_tool(...)``.
            user_id: Forwarded to ``call_tool`` as ``user_id``.
            db: Forwarded to ``call_tool`` as ``db``.
            sync_workspace_id: Forwarded to ``call_tool`` as ``session_id``.
        """
        self.tool_service = tool_service
        self.user_id = user_id
        self.db = db
        self.sync_workspace_id = sync_workspace_id
        # Events pushed by running tools through ``_subagent_event_handler``.
        self.event_queue = asyncio.Queue()

    async def _subagent_event_handler(self, event):
        """Queue an event emitted by a running tool (passed as ``on_event``)."""
        await self.event_queue.put(event)

    async def run_tools(self, tool_calls: List[Any], safety_guard: Any, mesh_bridge: Any) -> AsyncGenerator[Dict[str, Any], None]:
        """Dispatch and monitor tools until all are complete or cancelled.

        Args:
            tool_calls: Objects exposing ``id``, ``function.name`` and
                ``function.arguments`` (a JSON string).
            safety_guard: Object whose ``check_cancellation()`` returns truthy
                when the run must be interrupted.
            mesh_bridge: Optional queue-like object (``empty()`` /
                ``get_nowait()``) carrying mesh events; falsy to disable.

        Yields:
            UI events (``reasoning``, ``tool_start``, ``status``,
            ``tool_result``) and, per tool, a ``role: tool`` history message.
            If cancellation is requested, a ``status`` event is yielded and
            the generator ends without emitting any results.

        Tool tasks that are still running when the generator exits for any
        reason (cancellation, an exception, the consumer closing it early)
        are cancelled rather than left running in the background.
        """
        tool_tasks = []
        try:
            for tc in tool_calls:
                func_name = tc.function.name
                func_args = self._parse_args(tc)

                # Surface tool call for UI
                for ev in self._yield_tool_details(func_name, func_args):
                    yield ev
                yield {"type": "tool_start", "name": func_name, "args": func_args}

                # Create async task
                task = asyncio.create_task(
                    self.tool_service.call_tool(
                        func_name,
                        func_args,
                        db=self.db,
                        user_id=self.user_id,
                        session_id=self.sync_workspace_id,
                        on_event=self._subagent_event_handler,
                    )
                )
                tool_tasks.append((tc, task))

            # --- Wait & Monitor loop ---
            while True:
                # Sampled *before* draining so that events emitted just before
                # the last task finished are still flushed on the final pass.
                all_done = all(task.done() for _, task in tool_tasks)

                # Drain UI events (Thoughts, Mesh Observations)
                for ev in self._drain_subagent_events():
                    yield ev
                for ev in self._drain_mesh_bridge(mesh_bridge):
                    yield ev

                # Cancellation check
                if safety_guard.check_cancellation():
                    yield {"type": "status", "content": "Cancellation requested. Interrupting..."}
                    return  # pending tasks are cancelled in the ``finally`` below

                if all_done:
                    break

                # Wake up as soon as every task is finished instead of always
                # sleeping a full interval; the timeout keeps the drain and
                # cancellation checks running at least every POLL_INTERVAL.
                pending = [task for _, task in tool_tasks if not task.done()]
                if pending:
                    await asyncio.wait(pending, timeout=self.POLL_INTERVAL)

            # Yield results for AI history
            for tc, task in tool_tasks:
                func_name = tc.function.name
                if task.cancelled():
                    # Awaiting a cancelled task raises CancelledError, which is
                    # not an ``Exception`` and would abort the whole stream.
                    result = {"success": False, "error": "Tool was cancelled before completing"}
                else:
                    try:
                        result = await task
                    except Exception as e:
                        result = {"success": False, "error": f"Tool crashed: {str(e)}"}

                yield {"type": "tool_result", "name": func_name, "result": result}
                yield {
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "name": func_name,
                    "content": self._truncate_result(result),
                }
        finally:
            await self._cancel_pending([task for _, task in tool_tasks])

    @staticmethod
    async def _cancel_pending(tasks):
        """Cancel every unfinished task in ``tasks`` and wait for them to settle."""
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True: collect CancelledError / failures instead
            # of raising them into the caller's cleanup path.
            await asyncio.gather(*pending, return_exceptions=True)

    def _drain_subagent_events(self):
        """Lazily pop queued sub-agent events, yielding a UI event per thought.

        Events that are not dicts, or whose ``type`` is not
        ``subagent_thought``, are consumed and dropped.
        """
        while True:
            try:
                ev = self.event_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            if isinstance(ev, dict) and ev.get("type") == "subagent_thought":
                yield {
                    "type": "reasoning",
                    "content": f"\n\n> **🧠 Sub-Agent [{ev.get('node_id', 'Swarm')}]:** {ev.get('content')}\n\n"
                }

    def _drain_mesh_bridge(self, mesh_bridge):
        """Lazily pop events from ``mesh_bridge``, yielding a UI event per observation.

        Malformed events are skipped instead of aborting the drain pass.
        """
        if not mesh_bridge:
            return
        while not mesh_bridge.empty():
            try:
                ev = mesh_bridge.get_nowait()
            except Exception:
                # Best effort: the bridge's concrete queue type is not known
                # here, so any retrieval failure (e.g. an "empty" race) simply
                # ends this pass. Only the retrieval is guarded, so
                # GeneratorExit / CancelledError are never swallowed.
                return
            if not isinstance(ev, dict) or ev.get("event") != "mesh_observation":
                continue
            data = ev.get("data", {})
            message = self._DEFAULT_MESH_MESSAGE
            if isinstance(data, dict):
                message = data.get("message", self._DEFAULT_MESH_MESSAGE)
            yield {
                "type": "reasoning",
                "content": f"\n\n> **📡 Mesh Observation:** {message}\n\n"
            }

    def _parse_args(self, tc) -> Dict[str, Any]:
        """Decode a tool call's JSON arguments into a dict.

        Missing, invalid, or non-object JSON (e.g. ``null`` or a list) yields
        an empty dict. ``mesh_terminal_control`` calls that carry no
        ``session_id`` get one derived from the first 8 characters of the
        tool call id.
        """
        try:
            args = json.loads(tc.function.arguments)
        except (TypeError, ValueError, RecursionError):
            # TypeError: arguments is None / not text; ValueError: invalid JSON
            # (JSONDecodeError subclasses it); RecursionError: absurd nesting.
            args = {}
        if not isinstance(args, dict):
            args = {}
        # Parallel PTY Optimization preserved
        if tc.function.name == "mesh_terminal_control" and "session_id" not in args:
            args["session_id"] = f"subagent-{tc.id[:8]}"
        return args

    def _yield_tool_details(self, name, args):
        """Yield one ``reasoning`` event summarising the tool call for the UI.

        The summary lists the tool name plus the ``command``, ``node_id`` and
        ``node_ids`` arguments when they are present and truthy.
        """
        lines = [f"🔧 **Tool Call: `{name}`**"]
        if args.get("command"):
            lines.append(f"- Command: `{args['command']}`")
        if args.get("node_id"):
            lines.append(f"- Node: `{args['node_id']}`")
        node_ids = args.get("node_ids")
        if node_ids:
            if isinstance(node_ids, (list, tuple, set, frozenset)):
                joined = ", ".join(str(node) for node in node_ids)
            else:
                # A bare string (or any other scalar) is shown as-is rather
                # than being split into characters or raising on join.
                joined = str(node_ids)
            lines.append(f"- Nodes: `{joined}`")
        yield {"type": "reasoning", "content": "\n" + "\n".join(lines) + "\n"}

    def _truncate_result(self, result: Any, limit: int = 8000) -> str:
        """Serialise ``result`` to text, capped at ``limit`` characters.

        Dicts are JSON-encoded, falling back to ``str()`` when they contain
        values JSON cannot encode; everything else uses ``str()``. When the
        text is cut, a marker stating the number of dropped characters is
        appended.
        """
        if isinstance(result, dict):
            try:
                s = json.dumps(result)
            except (TypeError, ValueError):
                s = str(result)
        else:
            s = str(result)
        if len(s) > limit:
            return s[:limit] + f"\n...[truncated {len(s)-limit} chars]"
        return s
