import json
import logging
import os
from typing import Any, Awaitable, Callable, Dict, Optional

import grpc

from app.protos import browser_pb2, browser_pb2_grpc

# Module-scoped logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)

class BrowserServiceClient:
    """
    Client for the dedicated Browser Service.

    Uses gRPC for control and Sidecar Handoff via shared volumes (/dev/shm) for
    large data: responses carry paths to DOM, accessibility-tree and
    screenshot files under ``shm_base``, and this client reads them back into
    the plain dictionaries returned by the public methods.

    For the RPC-backed methods, exceptions raised while building the request,
    awaiting the RPC, or hydrating the response are caught, logged, and
    returned as ``{"success": False, "error": "<message>"}``. Exceptions raised
    by the ``on_event`` callback are not caught.

    Call :meth:`aclose` when done to release the gRPC channel.
    """

    # Size limits applied while hydrating SHM payloads so results stay small
    # enough to hand to an LLM. Override on a subclass/instance to tune them.
    DOM_TRUNCATE_THRESHOLD = 30000  # DOMs longer than this many chars are truncated
    DOM_HEAD_CHARS = 15000  # chars kept from the start of a truncated DOM
    DOM_TAIL_CHARS = 5000  # chars kept from the end of a truncated DOM
    A11Y_RAW_MAX_CHARS = 10000  # raw a11y JSON is included only when shorter than this
    A11Y_SUMMARY_MAX_LINES = 150  # element lines kept in the a11y summary

    def __init__(self, endpoint: str = "browser-service:50052"):
        """Open an insecure async gRPC channel to *endpoint*.

        Args:
            endpoint: gRPC target, either ``host:port`` or ``unix:/path/to.sock``;
                ``grpc.aio.insecure_channel`` accepts both forms directly, so no
                special-casing is needed for Unix sockets.
        """
        self.channel = grpc.aio.insecure_channel(endpoint)
        self.stub = browser_pb2_grpc.BrowserServiceStub(self.channel)
        # Root directory for sidecar handoff files; all reads are confined to
        # it by _resolve_shm_path.
        self.shm_base = os.getenv("BROWSER_SHM_PATH", "/dev/shm/cortex_browser")

    async def aclose(self) -> None:
        """Close the underlying gRPC channel; the client must not be used afterwards."""
        await self.channel.close()

    # ------------------------------------------------------------------ #
    # Shared-memory handoff
    # ------------------------------------------------------------------ #

    def _resolve_shm_path(self, path: str) -> Optional[str]:
        """Return the canonical form of *path* if it lies inside ``shm_base``, else None.

        Symlinks and ``..`` segments are resolved before the containment check,
        so neither traversal (``<base>/../etc/passwd``) nor sibling-prefix
        names (``<base>_evil/x``) can escape the shared-memory directory. A
        plain ``str.startswith`` check would accept both.
        """
        if not path:
            return None
        base = os.path.realpath(self.shm_base)
        resolved = os.path.realpath(path)
        try:
            if os.path.commonpath([base, resolved]) != base:
                return None
        except ValueError:
            # commonpath raises for paths on different drives (Windows).
            return None
        return resolved

    def _read_shm_file(self, path: str, binary: bool) -> Any:
        """Read one handoff file as text (UTF-8) or bytes; None if absent, blocked, or unreadable."""
        if not path:
            return None

        safe_path = self._resolve_shm_path(path)
        if safe_path is None:
            logger.warning("BLOCKED: Attempted to read browser data outside SHM: %s", path)
            return None

        try:
            # EAFP: open directly instead of exists()+open(), which can race.
            if binary:
                with open(safe_path, "rb") as f:
                    return f.read()
            with open(safe_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            logger.error("SHM file missing: %s", path)
        except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
            kind = "binary data" if binary else "data"
            logger.error("Failed to read browser SHM %s: %s", kind, e)
        return None

    def _read_shm(self, path: str) -> Optional[str]:
        """Read a UTF-8 text handoff file; None if missing, blocked, or unreadable."""
        return self._read_shm_file(path, binary=False)

    def _read_shm_bytes(self, path: str) -> Optional[bytes]:
        """Read a binary handoff file (used for screenshots); None on any failure."""
        return self._read_shm_file(path, binary=True)

    # ------------------------------------------------------------------ #
    # Response shaping
    # ------------------------------------------------------------------ #

    def _summarize_a11y(self, a11y_json: str, max_lines: Optional[int] = None) -> str:
        """Create a compact, readable summary of the accessibility tree.

        Args:
            a11y_json: JSON text expected to decode to a list of node objects
                with optional ``ref``, ``role`` and ``name`` keys.
            max_lines: maximum element lines to emit before appending a
                "... and N more elements" note; defaults to
                ``A11Y_SUMMARY_MAX_LINES``.

        Returns:
            One ``- role "name" [ref=...]`` line per node (the quoted name is
            omitted when empty), a fixed message when no nodes are present, or
            an ``Error summarizing a11y tree: ...`` message if decoding fails.
        """
        limit = self.A11Y_SUMMARY_MAX_LINES if max_lines is None else max_lines
        try:
            nodes = json.loads(a11y_json)
        except (TypeError, ValueError) as e:
            return f"Error summarizing a11y tree: {e}"
        if not isinstance(nodes, list):
            return f"Error summarizing a11y tree: expected a JSON list, got {type(nodes).__name__}"

        lines = []
        for node in nodes:
            if not isinstance(node, dict):
                continue  # skip malformed entries rather than failing the whole summary
            ref = node.get("ref", "??")
            role = node.get("role", "unknown")
            # `or ""` also covers an explicit JSON null; str() guards non-string names.
            name = str(node.get("name") or "").strip()
            lines.append(f'- {role} "{name}" [ref={ref}]' if name else f"- {role} [ref={ref}]")

        if not lines:
            return "No interactive or landmark elements found."

        # Limit summary size to avoid blowing context.
        if len(lines) > limit:
            hidden = len(lines) - limit
            lines = lines[:limit]
            lines.append(
                f"... and {hidden} more elements. Use 'snapshot' again after scrolling or targeted extraction."
            )
        return "\n".join(lines)

    def _process_response(self, resp: "browser_pb2.BrowserResponse") -> Dict[str, Any]:
        """Convert a gRPC BrowserResponse to a plain dict, hydrating SHM payloads.

        Always-present keys: ``success``, ``session_id``, ``url``, ``title``,
        ``error``, ``eval_result``. Optional keys: ``dom`` (possibly truncated),
        ``a11y_summary``, ``a11y_raw`` (only for small trees) and
        ``_screenshot_path``.
        """
        result = {
            "success": resp.status == "success",
            "session_id": resp.session_id,
            "url": resp.url,
            "title": resp.title,
            "error": resp.error_message,
            "eval_result": resp.eval_result,
        }

        # Hydrate DOM, keeping only head + tail of very large pages.
        dom_content = self._read_shm(resp.dom_path) if resp.dom_path else None
        if dom_content:
            if len(dom_content) > self.DOM_TRUNCATE_THRESHOLD:
                result["dom"] = (
                    dom_content[: self.DOM_HEAD_CHARS]
                    + "\n... (DOM TRUNCATED) ...\n"
                    + dom_content[-self.DOM_TAIL_CHARS:]
                )
            else:
                result["dom"] = dom_content

        # Hydrate a11y: always the summary, the raw JSON only when it is small.
        a11y_content = self._read_shm(resp.a11y_path) if resp.a11y_path else None
        if a11y_content:
            result["a11y_summary"] = self._summarize_a11y(a11y_content)
            if len(a11y_content) < self.A11Y_RAW_MAX_CHARS:
                result["a11y_raw"] = a11y_content

        # Internal metadata (underscore key) kept for downstream screenshot
        # extraction rather than direct display.
        if resp.screenshot_path:
            result["_screenshot_path"] = resp.screenshot_path

        return result

    def _hydrate(self, resp: "browser_pb2.BrowserResponse", on_event: Any) -> Dict[str, Any]:
        """Process *resp* and, when an on_event callback is supplied, attach the screenshot bytes."""
        result = self._process_response(resp)
        if resp.screenshot_path and on_event:
            result["_screenshot_bytes"] = self._read_shm_bytes(resp.screenshot_path)
        return result

    async def _invoke(
        self,
        action: str,
        make_call: Callable[[], Awaitable[Any]],
        on_event: Any,
    ) -> Dict[str, Any]:
        """Run one BrowserResponse-returning RPC and hydrate it, converting errors to a result dict.

        *make_call* builds the request and starts the RPC. It is invoked inside
        the ``try`` so request-construction errors are reported the same way
        as transport or service errors. *action* is used only in the log line.
        """
        try:
            resp = await make_call()
            return self._hydrate(resp, on_event)
        except Exception as e:  # deliberate best-effort boundary: report, don't raise
            logger.warning("Browser %s failed: %s", action, e)
            return {"success": False, "error": str(e)}

    async def _report_status(self, content: str, on_event: Any):
        """Stream a status line to the UI as a ``subagent_thought`` event, if *on_event* is set.

        *on_event* is awaited with a single dict, so it must be an async callable.
        """
        if on_event:
            await on_event({
                "type": "subagent_thought",
                "node_id": "browser",
                "content": content,
            })

    # ------------------------------------------------------------------ #
    # Public browser actions
    # ------------------------------------------------------------------ #

    async def navigate(self, url: str, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Navigate the session to *url* and return the hydrated response dict."""
        await self._report_status(f"🌐 Navigating to `{url}`...", on_event)
        return await self._invoke(
            "navigate",
            lambda: self.stub.Navigate(browser_pb2.NavigateRequest(url=url, session_id=session_id)),
            on_event,
        )

    async def click(self, selector: str, session_id: str = "default", x: int = 0, y: int = 0, on_event: Any = None) -> Dict[str, Any]:
        """Click *selector*; *x*/*y* are forwarded as-is (their meaning is defined server-side)."""
        await self._report_status(f"🖱️ Clicking on `{selector}`...", on_event)
        return await self._invoke(
            "click",
            lambda: self.stub.Click(browser_pb2.ClickRequest(selector=selector, session_id=session_id, x=x, y=y)),
            on_event,
        )

    async def type(self, text: str, selector: str = "", session_id: str = "default", press_enter: bool = True, on_event: Any = None) -> Dict[str, Any]:
        """Type *text*, optionally into *selector*; *press_enter* is forwarded to the service."""
        target = f" on `{selector}`" if selector else ""
        await self._report_status(f"⌨️ Typing `{text}`{target}...", on_event)
        return await self._invoke(
            "type",
            lambda: self.stub.Type(
                browser_pb2.TypeRequest(selector=selector, text=text, session_id=session_id, press_enter=press_enter)
            ),
            on_event,
        )

    async def get_snapshot(self, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Request a snapshot including DOM, a11y tree and screenshot."""
        await self._report_status("📸 Capturing page snapshot...", on_event)
        return await self._invoke(
            "snapshot",
            lambda: self.stub.GetSnapshot(
                browser_pb2.SnapshotRequest(
                    session_id=session_id, include_dom=True, include_a11y=True, include_screenshot=True
                )
            ),
            on_event,
        )

    async def scroll(self, delta_x: int = 0, delta_y: int = 0, selector: str = "", session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Scroll by (*delta_x*, *delta_y*), optionally scoped to *selector*."""
        await self._report_status("📜 Scrolling...", on_event)
        return await self._invoke(
            "scroll",
            lambda: self.stub.Scroll(
                browser_pb2.ScrollRequest(delta_x=delta_x, delta_y=delta_y, selector=selector, session_id=session_id)
            ),
            on_event,
        )

    async def eval(self, script: str, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Evaluate *script* in the page; the service's result is returned under ``eval_result``."""
        await self._report_status("🧪 Evaluating script...", on_event)
        return await self._invoke(
            "eval",
            lambda: self.stub.Evaluate(browser_pb2.EvalRequest(script=script, session_id=session_id)),
            on_event,
        )

    async def hover(self, selector: str, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Hover over *selector*."""
        await self._report_status(f"🖱️ Hovering on `{selector}`...", on_event)
        return await self._invoke(
            "hover",
            lambda: self.stub.Hover(browser_pb2.HoverRequest(selector=selector, session_id=session_id)),
            on_event,
        )

    async def close(self, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        """Close a browser session on the service (the gRPC channel stays open; see aclose)."""
        await self._report_status("🛑 Closing browser session...", on_event)
        try:
            resp = await self.stub.CloseSession(browser_pb2.CloseRequest(session_id=session_id))
            return {"success": resp.success, "session_id": session_id}
        except Exception as e:  # best-effort boundary, same contract as _invoke
            logger.warning("Browser close failed: %s", e)
            return {"success": False, "error": str(e)}

    async def parallel_fetch(self, urls: list, session_id: str = "default", max_concurrent: int = 5, extract_markdown: bool = True, on_event: Any = None) -> Dict[str, Any]:
        """Fetch many *urls* concurrently on the service's worker pool.

        Returns ``{"success": True, "results": [...]}`` with one dict per fetched
        page (url, title, content, success, error), or an error dict if the RPC
        itself fails.
        """
        urls = list(urls)  # accept any iterable; also materialized for len() below
        await self._report_status(f"🚀 Dispatching {len(urls)} research tasks to browser worker pool...", on_event)
        try:
            req = browser_pb2.ParallelFetchRequest(
                urls=urls,
                session_id=session_id,
                max_concurrent=max_concurrent,
                extract_markdown=extract_markdown,
            )
            resp = await self.stub.ParallelFetch(req)
            results = [
                {
                    "url": r.url,
                    "title": r.title,
                    "content": r.content_markdown,
                    "success": r.success,
                    "error": r.error,
                }
                for r in resp.results
            ]
            return {"success": True, "results": results}
        except Exception as e:  # best-effort boundary, same contract as _invoke
            logger.warning("Browser parallel_fetch failed: %s", e)
            return {"success": False, "error": str(e)}

    # Alias for AI compatibility: returns the full snapshot (DOM + a11y + screenshot).
    async def screenshot(self, session_id: str = "default", on_event: Any = None) -> Dict[str, Any]:
        return await self.get_snapshot(session_id=session_id, on_event=on_event)
