import asyncio
import json
import logging
import queue
import sys
from typing import Any, AsyncGenerator, Dict, List, Optional

class ToolExecutor:
    """Dispatch LLM tool calls in parallel and stream progress events.

    ``run_tools`` starts one asyncio task per tool call, relays sub-agent
    thoughts and mesh observations while the tasks run, honours cancellation
    requests, and finally yields each tool's result both as a UI event and as
    a ``role: tool`` history message.
    """

    # Max characters of a serialized tool result forwarded to the model history
    # (kept large for better RAG/context); longer output is truncated with a notice.
    MAX_RESULT_CHARS = 128000
    # Seconds to wait for tool completion between event-drain passes.
    POLL_INTERVAL = 0.1
    # Emit a "Processing task..." heartbeat every this many polling cycles.
    HEARTBEAT_CYCLES = 100

    def __init__(self, tool_service: Any, user_id: str, db: Any, sync_workspace_id: str,
                 session_db_id: int, provider_name: Optional[str] = None):
        """Store the per-session context forwarded to every tool call.

        Args:
            tool_service: Object whose ``call_tool(name, args, **context)`` returns
                an awaitable (it is wrapped with ``asyncio.create_task``).
            user_id: Forwarded to ``call_tool`` as ``user_id``.
            db: Forwarded to ``call_tool`` as ``db``.
            sync_workspace_id: Forwarded to ``call_tool`` as ``session_id``.
            session_db_id: Forwarded to ``call_tool`` as ``session_db_id``.
            provider_name: Forwarded to ``call_tool`` as ``provider_name``.
        """
        self.tool_service = tool_service
        self.user_id = user_id
        self.db = db
        self.sync_workspace_id = sync_workspace_id
        self.session_db_id = session_db_id
        self.provider_name = provider_name
        # Events pushed by tools through ``on_event``; drained by ``_drain_events``.
        self.event_queue = asyncio.Queue()

    async def _subagent_event_handler(self, event):
        """``on_event`` callback handed to tools: queue the event for the monitor loop."""
        await self.event_queue.put(event)

    async def run_tools(self, tool_calls: List[Any], safety_guard: Any, mesh_bridge: Any) -> AsyncGenerator[Dict[str, Any], None]:
        """Dispatch all tool calls concurrently and stream events until they finish.

        Args:
            tool_calls: Tool-call objects exposing ``id``, ``function.name`` and
                ``function.arguments`` (a JSON string).
            safety_guard: Object whose ``check_cancellation()`` returns truthy when
                the run should be interrupted.
            mesh_bridge: Optional queue-like object (``empty()``/``get_nowait()``)
                carrying mesh events, or ``None``.

        Yields:
            UI event dicts (``reasoning``, ``tool_start``, ``status``,
            ``tool_result``) and, per tool, a ``{"role": "tool", ...}`` history
            message. On cancellation a final ``status`` event is yielded and no
            tool results are produced.

        Any tool task still running when this generator stops (cancellation,
        ``aclose()``, or an error) is cancelled so it does not keep running
        in the background.
        """
        tool_tasks: List[Any] = []
        try:
            for tc in tool_calls:
                name, args = tc.function.name, self._parse_args(tc)

                async for ev in self._prepare_and_start_tool(tc, name, args):
                    yield ev

                task = asyncio.create_task(self.tool_service.call_tool(
                    name, args, db=self.db, user_id=self.user_id, session_id=self.sync_workspace_id,
                    session_db_id=self.session_db_id, on_event=self._subagent_event_handler,
                    provider_name=self.provider_name,
                ))
                tool_tasks.append((tc, task))

            # --- Wait & monitor loop ---
            cycles = 0
            while not all(task.done() for _, task in tool_tasks):
                async for ev in self._drain_events(mesh_bridge):
                    yield ev

                if safety_guard.check_cancellation():
                    yield {"type": "status", "content": "Cancellation requested. Interrupting..."}
                    return

                pending = [task for _, task in tool_tasks if not task.done()]
                if pending:
                    # Unlike a fixed sleep, wait() returns as soon as every tool finishes.
                    await asyncio.wait(pending, timeout=self.POLL_INTERVAL)
                cycles += 1
                if cycles % self.HEARTBEAT_CYCLES == 0:
                    yield {"type": "status", "content": "Processing task..."}

            # Events queued just before the last tool finished would otherwise be
            # dropped here (or surface in the next run), so drain once more.
            async for ev in self._drain_events(mesh_bridge):
                yield ev

            # --- Finalize results ---
            async for ev in self._finalize_tool_results(tool_tasks):
                yield ev
        finally:
            for _, task in tool_tasks:
                if not task.done():
                    task.cancel()

    async def _prepare_and_start_tool(self, tc, name, args):
        """Yield the UI preview (``reasoning``) and ``tool_start`` events for one call.

        ``tc`` is accepted for signature compatibility but not used here.
        """
        lines = [f"🔧 **Tool Call: `{name}`**"]
        if args.get("command"):
            lines.append(f"- Command: `{args['command']}`")
        if args.get("node_id"):
            lines.append(f"- Node: `{args['node_id']}`")
        node_ids = args.get("node_ids")
        if node_ids:
            # Model-generated args may hold non-string ids or a bare scalar.
            if isinstance(node_ids, (list, tuple)):
                shown = ", ".join(map(str, node_ids))
            else:
                shown = str(node_ids)
            lines.append(f"- Nodes: `{shown}`")
        yield {"type": "reasoning", "content": "\n" + "\n".join(lines) + "\n"}
        yield {"type": "tool_start", "name": name, "args": args}

    async def _drain_events(self, mesh_bridge):
        """Yield ``reasoning`` events for everything currently queued, without blocking.

        ``subagent_thought`` events come from ``self.event_queue``;
        ``mesh_observation`` events come from ``mesh_bridge`` when it is not
        ``None``. All other (or malformed) events are consumed and dropped.
        """
        while not self.event_queue.empty():
            ev = self.event_queue.get_nowait()
            if isinstance(ev, dict) and ev.get("type") == "subagent_thought":
                node = str(ev.get("node_id") or "Swarm").capitalize()
                yield {"type": "reasoning", "content": f"\n\n> **🧠 Sub-Agent [{node}]:** {ev.get('content')}\n"}

        if mesh_bridge is not None:
            while not mesh_bridge.empty():
                try:
                    ev = mesh_bridge.get_nowait()
                except (asyncio.QueueEmpty, queue.Empty):
                    # Emptied by another consumer between empty() and get_nowait().
                    break
                if not (isinstance(ev, dict) and ev.get("event") == "mesh_observation"):
                    continue
                data = ev.get("data")
                if isinstance(data, dict):
                    message = data.get("message", "Unspecified drift observed.")
                else:
                    message = "Unspecified drift observed."
                yield {"type": "reasoning", "content": f"\n\n> **📡 Mesh Observation:** {message}\n\n"}

    async def _finalize_tool_results(self, tool_tasks):
        """Collect each task's result in dispatch order and yield UI + history events.

        A crashed or cancelled task is reported as ``{"success": False, "error": ...}``.
        Per tool this yields an optional ``❌ Tool Error`` reasoning event, a
        ``tool_result`` UI event, and a ``role: tool`` history message whose
        content is the (possibly truncated) serialized result.
        """
        for tc, task in tool_tasks:
            name = tc.function.name
            if task.cancelled():
                result = {"success": False, "error": "Tool cancelled"}
            else:
                try:
                    result = await task
                except Exception as e:
                    logging.getLogger(__name__).exception("Tool %s crashed", name)
                    # Some exceptions stringify to "" -- fall back to the type name.
                    result = {"success": False, "error": f"Tool crashed: {str(e) or type(e).__name__}"}

            # NOTE: any truthy non-dict result is reported as a failure, while falsy
            # results (None, {}) pass through without an error event.
            if result and (not isinstance(result, dict) or result.get("success") is False):
                err = result.get("error") if isinstance(result, dict) else "Unknown failure"
                yield {"type": "reasoning", "content": f"\n> **❌ Tool Error [{name}]:** {err}\n"}

            yield {"type": "tool_result", "name": name, "result": result}
            yield {"role": "tool", "tool_call_id": tc.id, "name": name, "content": self._truncate_result(result)}

    def _parse_args(self, tc) -> Dict[str, Any]:
        """Decode ``tc.function.arguments`` into a dict, falling back to ``{}``.

        Malformed JSON, a ``None`` payload, or JSON that is not an object all
        yield ``{}`` so downstream ``args.get(...)`` calls are safe.
        ``mesh_terminal_control`` calls without a ``session_id`` get one derived
        from the tool-call id, so parallel calls do not share a session.
        """
        try:
            args = json.loads(tc.function.arguments)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            args = {}
        if not isinstance(args, dict):
            args = {}
        # Parallel PTY optimization: give each terminal call its own session.
        if tc.function.name == "mesh_terminal_control" and "session_id" not in args:
            args["session_id"] = f"subagent-{tc.id[:8]}"
        return args


    def _truncate_result(self, result: Any) -> str:
        """Serialize a tool result for the model history, capped at ``MAX_RESULT_CHARS``.

        Dicts are JSON-encoded, with values JSON cannot encode stringified via
        ``str``. Everything else, and dicts that still cannot be encoded
        (e.g. non-string keys, cycles), uses ``str()``. Over-long output is cut
        and a truncation notice is appended.
        """
        if isinstance(result, dict):
            try:
                text = json.dumps(result, default=str)
            except (TypeError, ValueError):
                text = str(result)
        else:
            text = str(result)
        limit = self.MAX_RESULT_CHARS
        if len(text) > limit:
            return text[:limit] + f"\n...[SYSTEM: Output Truncated at {limit} chars for safety. Use specific filters or file explorers if more detail is needed.]"
        return text
