Newer
Older
cortex-hub / browser-service / src / api / servicer.py
import logging
import os
import uuid
import json
import asyncio
import traceback
from protos import browser_pb2, browser_pb2_grpc

from core.browser import BrowserManager
from core.resolver import SelectorResolver
from extraction.a11y import A11yProcessor
from utils.responses import ResponseBuilder

logger = logging.getLogger(__name__)

class BrowserServiceServicer(browser_pb2_grpc.BrowserServiceServicer):
    """gRPC servicer that performs browser actions on per-session Playwright pages.

    Each RPC obtains the page for ``request.session_id`` from
    :class:`BrowserManager`, performs one action, and replies with a
    ``BrowserResponse`` built by :class:`ResponseBuilder`. Failures are reported
    in-band (``status="error"``) rather than raised to gRPC. Snapshot artefacts
    (DOM, accessibility elements, screenshot) are written to files under
    ``SHM_PATH`` and returned by path.
    """

    def __init__(self):
        self.shm_base = os.getenv("SHM_PATH", "/dev/shm/cortex_browser")
        os.makedirs(self.shm_base, exist_ok=True)
        self.browser = BrowserManager()
        self.responses = ResponseBuilder(self.shm_base)
        # session_id -> {ref -> node_data}; filled by _capture_snapshot and read
        # by SelectorResolver and the ref-coordinate fallbacks below.
        self.a11y_maps = {}

    async def init_playwright(self):
        """Initialise the underlying browser (delegates to ``BrowserManager.init``)."""
        await self.browser.init()

    # ------------------------------------------------------------------ helpers

    async def _error_response(self, page, session_id, msg):
        """Return an error ``BrowserResponse`` carrying ``msg``.

        When a page exists the response is built through ResponseBuilder
        (presumably so page state is still reported); otherwise a bare
        ``BrowserResponse`` is returned.
        """
        if page is not None:
            return await self.responses.build(page, session_id, status="error", error_message=msg)
        return browser_pb2.BrowserResponse(session_id=session_id, status="error", error_message=msg)

    def _ref_center(self, session_id, selector):
        """Return the centre ``(x, y)`` of the cached bounding rect for ``selector``.

        Used when SelectorResolver returns ``page.mouse``, i.e. the selector can
        only be acted on through coordinates from the last a11y snapshot.

        Raises:
            ValueError: if the session has no node for ``selector`` or the node
                has no ``rect`` (the cached map may be stale).
        """
        node = self.a11y_maps.get(session_id, {}).get(selector)
        rect = node.get("rect") if node else None
        if not rect:
            raise ValueError(
                f"Element '{selector}' has no cached bounding box; "
                "capture a fresh snapshot and retry."
            )
        return rect["x"] + rect["width"] / 2, rect["y"] + rect["height"] / 2

    # --------------------------------------------------------------------- RPCs

    async def Click(self, request, context):
        """Click an element (``request.selector``) or a point (``request.x``, ``request.y``).

        Selectors are resolved with SelectorResolver. If it returns
        ``page.mouse``, the centre of the ref's cached rect is clicked. For a
        real locator a normal click is attempted first, then a forced click.
        """
        session_id = request.session_id
        logger.info(f"Clicking in session {session_id}")
        page = None
        try:
            page = await self.browser.get_page(session_id)
            if request.selector:
                resolver = SelectorResolver(page, session_id, self.a11y_maps)
                locator = await resolver.resolve(request.selector)
                if locator == page.mouse:
                    # Fail loudly on a missing rect instead of reporting success without clicking.
                    x, y = self._ref_center(session_id, request.selector)
                    await page.mouse.click(x, y)
                else:
                    try:
                        await locator.scroll_into_view_if_needed(timeout=5000)
                        await locator.click(timeout=8000)
                    except Exception as ce:
                        # force=True skips Playwright's actionability checks; the
                        # retry targets a point 5px inside the element's top-left corner.
                        logger.warning(f"Standard click failed ({ce}), attempting forced click at element offset (5, 5)...")
                        await locator.click(force=True, timeout=5000, position={"x": 5, "y": 5})
            else:
                await page.mouse.click(request.x, request.y)
            return await self.responses.build(page, session_id)
        except Exception as e:
            msg = self.responses.error_to_ai_message(e, request.selector)
            logger.warning(f"Click failed: {msg}")
            return await self._error_response(page, session_id, msg)

    async def Type(self, request, context):
        """Type ``request.text`` into an element, or at the current focus if no selector.

        For a real locator the element is clicked and filled. If that fails, the
        first editable descendant is tried, and finally raw keystrokes are sent.
        ``request.press_enter`` presses Enter after a short pause.
        """
        session_id = request.session_id
        logger.info(f"Typing in session {session_id}")
        page = None
        try:
            page = await self.browser.get_page(session_id)
            if request.selector:
                resolver = SelectorResolver(page, session_id, self.a11y_maps)
                locator = await resolver.resolve(request.selector)

                if locator == page.mouse:
                    x, y = self._ref_center(session_id, request.selector)
                    await page.mouse.click(x, y)
                    await page.keyboard.type(request.text, delay=50)
                else:
                    try:
                        await locator.click(timeout=5000, force=True)
                        await locator.fill(request.text, timeout=10000)
                    except Exception as fe:
                        # The resolved element may be a wrapper around the real input.
                        logger.info(f"Fill failed ({fe}), trying inner input search...")
                        inner_target = locator.locator("input, textarea, [role='textbox'], [role='searchbox'], [contenteditable='true']").first
                        if await inner_target.count() > 0:
                            await inner_target.click(timeout=3000, force=True)
                            await inner_target.fill(request.text, timeout=5000)
                        else:
                            await page.keyboard.type(request.text, delay=30)
            else:
                await page.keyboard.type(request.text, delay=30)

            if request.press_enter:
                await asyncio.sleep(0.5)
                await page.keyboard.press("Enter")

            return await self.responses.build(page, session_id)
        except Exception as e:
            msg = self.responses.error_to_ai_message(e, request.selector)
            logger.warning(f"Type failed: {msg}")
            return await self._error_response(page, session_id, msg)

    async def Scroll(self, request, context):
        """Scroll by (``delta_x``, ``delta_y``).

        With a selector, the element is scrolled via ``Element.scrollBy``.
        Otherwise a mouse-wheel event is sent.
        """
        session_id = request.session_id
        page = None
        try:
            page = await self.browser.get_page(session_id)
            if request.selector:
                # Pass deltas as an evaluate argument instead of splicing them into JS source.
                await page.locator(request.selector).evaluate(
                    "(el, delta) => el.scrollBy(delta[0], delta[1])",
                    [request.delta_x, request.delta_y],
                )
            else:
                await page.mouse.wheel(request.delta_x, request.delta_y)
            return await self.responses.build(page, session_id)
        except Exception as e:
            return await self._error_response(page, session_id, str(e))

    async def Evaluate(self, request, context):
        """Run ``request.script`` in the page and return ``str(result)`` as ``eval_result``.

        NOTE(review): ``request.script`` is executed verbatim in the page; this
        RPC should only be reachable by trusted callers.
        """
        session_id = request.session_id
        page = None
        try:
            page = await self.browser.get_page(session_id)
            result = await page.evaluate(request.script)
            return await self.responses.build(page, session_id, eval_result=str(result))
        except Exception as e:
            return await self._error_response(page, session_id, str(e))

    async def _capture_snapshot(self, page, session_id, include_dom=False, include_a11y=True, include_screenshot=True):
        """Capture page state into files under ``self.shm_base``.

        Each part is best-effort: a failure is logged and that part is omitted.
        A non-empty element list also replaces ``self.a11y_maps[session_id]``.

        Returns:
            dict with any of ``dom_path``, ``a11y_path``, ``screenshot_path``,
            suitable as keyword arguments to ``ResponseBuilder.build``.
        """
        resp_kwargs = {}

        # 1. DOM
        if include_dom:
            try:
                dom_content = await page.content()
                dom_file = os.path.join(self.shm_base, f"{session_id}_dom.html")
                with open(dom_file, "w", encoding="utf-8") as f:
                    f.write(dom_content)
                resp_kwargs["dom_path"] = dom_file
            except Exception as de:
                logger.warning(f"DOM content capture failed: {de}")

        # 2. A11y elements
        if include_a11y:
            try:
                processor = A11yProcessor(page, session_id)
                # Primary path: JS-based element discovery (main frame + iframes).
                # Normalise to a list so the fallbacks below can append into it.
                flat_a11y = await processor.get_all_elements() or []

                if not flat_a11y:
                    tracker = {}
                    # Fallback 1: Playwright's native accessibility snapshot, if this version has it.
                    if hasattr(page, "accessibility"):
                        try:
                            a11y_tree = await page.accessibility.snapshot()
                            if a11y_tree:
                                processor.flatten_tree(a11y_tree, flat_a11y, tracker)
                        except Exception as ne:
                            logger.debug(f"Native accessibility snapshot failed: {ne}")

                    # Fallback 2: CDP accessibility tree.
                    if not flat_a11y:
                        flat_a11y = await processor.get_cdp_tree(tracker)

                if flat_a11y:
                    self.a11y_maps[session_id] = {node["ref"]: node for node in flat_a11y}
                    a11y_file = os.path.join(self.shm_base, f"{session_id}_a11y.json")
                    with open(a11y_file, "w", encoding="utf-8") as f:
                        json.dump(flat_a11y, f, indent=2)
                    resp_kwargs["a11y_path"] = a11y_file
            except Exception as ae:
                logger.warning(f"Element discovery failed: {ae}")

        # 3. Screenshot (viewport only)
        if include_screenshot:
            try:
                screenshot_file = os.path.join(self.shm_base, f"{session_id}_screen.png")
                await page.screenshot(path=screenshot_file, full_page=False)
                resp_kwargs["screenshot_path"] = screenshot_file
            except Exception as se:
                logger.warning(f"Screenshot capture failed: {se}")

        return resp_kwargs

    async def Navigate(self, request, context):
        """Navigate to ``request.url`` and return an a11y + screenshot snapshot.

        A new session id is generated when none is supplied. Navigation errors
        (e.g. timeouts) are logged and ignored, so a partially loaded page is
        still snapshotted.
        """
        session_id = request.session_id or str(uuid.uuid4())
        page = None
        try:
            page = await self.browser.get_page(session_id)
            try:
                await page.goto(request.url, wait_until="domcontentloaded", timeout=25000)
            except Exception as te:
                logger.warning(f"Navigation did not complete cleanly (continuing): {te}")

            # Fixed settle delay after DOMContentLoaded before snapshotting.
            await asyncio.sleep(2.0)
            # Capture state proactively so the caller gets immediate context.
            snap_kwargs = await self._capture_snapshot(page, session_id, include_a11y=True, include_screenshot=True)
            return await self.responses.build(page, session_id, **snap_kwargs)
        except Exception as e:
            return await self._error_response(page, session_id, str(e))

    async def Hover(self, request, context):
        """Hover over ``request.selector``.

        If the resolver returns ``page.mouse``, the mouse is moved to the centre
        of the ref's cached rect.
        """
        session_id = request.session_id
        page = None
        try:
            page = await self.browser.get_page(session_id)
            resolver = SelectorResolver(page, session_id, self.a11y_maps)
            locator = await resolver.resolve(request.selector)
            if locator == page.mouse:
                # Mouse has no hover(); move to the cached coordinates instead.
                x, y = self._ref_center(session_id, request.selector)
                await page.mouse.move(x, y)
            else:
                await locator.hover(timeout=30000)
            return await self.responses.build(page, session_id)
        except Exception as e:
            return await self._error_response(page, session_id, str(e))

    async def GetSnapshot(self, request, context):
        """Capture the parts of page state selected by the request's ``include_*`` flags."""
        session_id = request.session_id
        # Must exist before the try: the except clause reads it even when get_page raises.
        page = None
        try:
            page = await self.browser.get_page(session_id)
            if not page:
                return browser_pb2.BrowserResponse(session_id=session_id, status="error", error_message="No page found")

            await asyncio.sleep(0.5)
            snap_kwargs = await self._capture_snapshot(
                page, session_id,
                include_dom=request.include_dom,
                include_a11y=request.include_a11y,
                include_screenshot=request.include_screenshot,
            )
            return await self.responses.build(page, session_id, **snap_kwargs)
        except Exception as e:
            return await self._error_response(page, session_id, str(e))

    async def CloseSession(self, request, context):
        """Close the session's browser resources and drop its cached a11y map.

        Exceptions from ``BrowserManager.close_session`` propagate to gRPC; the
        cached map is dropped either way.
        """
        try:
            await self.browser.close_session(request.session_id)
        finally:
            # Without this, a11y_maps keeps one entry per session ever opened.
            self.a11y_maps.pop(request.session_id, None)
        return browser_pb2.CloseResponse(success=True)

    async def ParallelFetch(self, request, context):
        """Fetch ``request.urls`` concurrently via ``BrowserManager.parallel_fetch``.

        ``max_concurrent`` values that are unset or non-positive fall back to 5.
        Each result dict is converted to a ``FetchResult`` message.
        """
        urls = list(request.urls)
        max_concurrent = request.max_concurrent if request.max_concurrent > 0 else 5
        extract_markdown = request.extract_markdown

        logger.info(f"Parallel fetching {len(urls)} URLs (max_concurrent={max_concurrent})")

        results = await self.browser.parallel_fetch(
            urls,
            max_concurrent=max_concurrent,
            extract_markdown=extract_markdown,
        )

        proto_results = [
            browser_pb2.ParallelFetchResponse.FetchResult(
                url=r["url"],
                title=r.get("title", ""),
                content_markdown=r.get("content_markdown", ""),
                success=r["success"],
                error=r.get("error", ""),
            )
            for r in results
        ]

        return browser_pb2.ParallelFetchResponse(results=proto_results)