import threading
import queue
import time
import json
import re
from playwright.sync_api import sync_playwright
from agent_node.skills.base import BaseSkill
from protos import agent_pb2

# ============================================================
#  Role-Ref Registry
#  Inspired by OpenClaw's pw-role-snapshot.ts
#  Maps `ref=eN` shorthand -> (role, name, nth) for every
#  interactive / content element on the last snapshotted page.
# ============================================================

# Roles that always receive a ref label (the user can act on them).
INTERACTIVE_ROLES = {
    "button", "link", "textbox", "checkbox", "radio", "combobox",
    "listbox", "menuitem", "menuitemcheckbox", "menuitemradio",
    "option", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem",
}
# Roles that receive a ref label only when they carry a non-empty name.
CONTENT_ROLES = {
    "heading", "cell", "gridcell", "columnheader", "rowheader",
    "listitem", "article", "region", "main", "navigation",
}
# Pure layout/container roles; not referenced by the snapshot builder below.
STRUCTURAL_ROLES = {
    "generic", "group", "list", "table", "row", "rowgroup", "grid",
    "treegrid", "menu", "menubar", "toolbar", "tablist", "tree",
    "directory", "document", "application", "presentation", "none",
}


def _build_aria_snapshot(aria_text: str) -> tuple[str, dict]:
    """
    Parse Playwright's ariaSnapshot() output and annotate interactive/content
    elements with stable [ref=eN] labels that the AI can refer back to.

    Each line of the form ``<indent>- <role> "<name>"<suffix>`` whose role is in
    INTERACTIVE_ROLES (or in CONTENT_ROLES with a non-empty name) is rewritten
    as ``<indent>- <role> "<name>" [ref=eN][ [nth=K]]<suffix>``. Every other
    line is passed through unchanged.

    The quoted name may contain backslash escapes (e.g. ``\\"``). The quoted
    token is decoded with ``json.loads`` for the ref map and re-emitted
    verbatim in the annotated text, so the label is never spliced into the
    middle of a name.

    Returns (annotated_snapshot, ref_map) where ref_map maps ``"eN"`` to
    ``{"role": str, "name": str | None, "nth": int | None}``. ``nth`` is the
    zero-based occurrence index among elements sharing the same (role, name),
    stored as ``None`` for the first occurrence.
    """
    # prefix ("  - "), role word, optional quoted name token (escape-aware), rest.
    line_re = re.compile(r'^(\s*-\s*)(\w+)(?:\s+("(?:[^"\\]|\\.)*"))?(.*)$')

    refs = {}
    seen = {}            # (role, name) -> occurrences so far (for nth disambiguation)
    output_lines = []

    for line in aria_text.split("\n"):
        m = line_re.match(line)
        if not m:
            output_lines.append(line)
            continue

        prefix, role_raw, name_token, suffix = m.groups()
        role = role_raw.lower()

        if name_token is None:
            name = None
        else:
            try:
                name = json.loads(name_token)
            except ValueError:
                # Not a valid JSON string literal: fall back to the raw text
                # between the quotes.
                name = name_token[1:-1]

        is_interactive = role in INTERACTIVE_ROLES
        is_content_with_name = role in CONTENT_ROLES and bool(name)
        if not (is_interactive or is_content_with_name):
            output_lines.append(line)
            continue

        ref = f"e{len(refs) + 1}"
        key = (role, name)
        nth = seen.get(key, 0)
        seen[key] = nth + 1

        refs[ref] = {"role": role, "name": name, "nth": nth if nth > 0 else None}

        parts = [f"{prefix}{role_raw}"]
        if name:
            # Re-emit the original quoted token so escapes stay intact.
            parts.append(f" {name_token}")
        parts.append(f" [ref={ref}]")
        if nth > 0:
            parts.append(f" [nth={nth}]")
        if suffix:
            parts.append(suffix)
        output_lines.append("".join(parts))

    return "\n".join(output_lines), refs


def _resolve_ref(page, ref: str, role_refs: dict):
    """
    Resolve a [ref=eN] string to a Playwright Locator.

    Args:
        page: object exposing ``get_by_role(role, ...)`` (a Playwright Page).
        ref: ref label such as ``"e3"``.
        role_refs: mapping of ref label -> ``{"role", "name", "nth"}``.

    Returns:
        The role locator narrowed to a single occurrence with ``.nth()``.

    Raises:
        ValueError: if ``ref`` is not present in ``role_refs``.
    """
    info = role_refs.get(ref)
    if not info:
        raise ValueError(f"Unknown ref '{ref}'. Run aria_snapshot first and use a ref from that output.")

    role = info["role"]
    name = info.get("name")
    # A missing/None index means "first occurrence".
    nth = info.get("nth") or 0

    if name:
        loc = page.get_by_role(role, name=name, exact=True)
    else:
        # NOTE(review): with no name this matches every element of the role,
        # named or not, so the index may not line up with the snapshot's
        # per-(role, name) count — confirm if unnamed refs matter.
        loc = page.get_by_role(role)

    # Always narrow to one element. Without this, the first of several
    # elements sharing the same role and name resolves to a multi-match
    # locator that single-element actions such as click() cannot use.
    return loc.nth(nth)


class BrowserSkill(BaseSkill):
    """
    Persistent Browser Skill — OpenClaw-inspired role-snapshot architecture.

    Key innovation over the prior version:
      - `aria_snapshot` action returns a compact semantic role tree with [ref=eN] labels.
      - `click`, `type` and `hover` actions accept either a CSS/XPath selector OR a
        ref string like 'e3', enabling the AI to address elements without fragile selectors.
      - Page errors and console output are recorded per-session in bounded buffers
        (console messages are also forwarded through `on_event` when provided).

    All Playwright work happens on a single daemon thread ("BrowserActor") that
    drains `task_queue`; `execute()` only enqueues.
    """

    def __init__(self, sync_mgr=None):
        """
        Args:
            sync_mgr: optional object providing `get_session_dir(session_id)`,
                used to choose the directory downloads are saved into.
        """
        self.task_queue = queue.Queue()
        # session_id -> { "context", "page", "role_refs", "console", "errors", "download_dir" }
        self.sessions = {}
        self.sync_mgr = sync_mgr
        # Guards `self.sessions`. It is a plain (non-reentrant) Lock, so it is
        # never held while calling into Playwright: the download handler also
        # takes it, and holding it across a Playwright call risks blocking
        # that handler on a lock its own thread holds.
        self.lock = threading.Lock()
        threading.Thread(target=self._browser_actor, daemon=True, name="BrowserActor").start()

    # ------------------------------------------------------------------
    # Session Management
    # ------------------------------------------------------------------

    def _get_or_create_session(self, browser, sid, task, on_event):
        """
        Return the existing session dict for `sid` or create a new one.

        A new session gets its own browser context and page, empty ref/console/
        error stores and, when a sync manager and `task.session_id` are present,
        a download directory. Page listeners are attached with the `on_event`
        callback of the task that created the session.
        """
        with self.lock:
            existing = self.sessions.get(sid)
        if existing is not None:
            return existing

        download_dir = None
        if self.sync_mgr and task.session_id:
            download_dir = self.sync_mgr.get_session_dir(task.session_id)
            print(f"    [🌐📁] Mapping Browser Context to: {download_dir}")

        # Created outside the lock (see __init__).
        ctx = browser.new_context(accept_downloads=True)
        try:
            page = ctx.new_page()
        except Exception:
            # Do not leak the context if the page cannot be opened.
            ctx.close()
            raise

        sess = {
            "context": ctx,
            "page": page,
            "role_refs": {},   # ref -> {role, name, nth}
            "console": [],
            "errors": [],
            "download_dir": download_dir,
        }
        with self.lock:
            self.sessions[sid] = sess

        self._attach_listeners(sid, page, on_event, sess)
        return sess

    def _close_session(self, sid):
        """Remove session `sid` (if any) and close its browser context."""
        with self.lock:
            sess = self.sessions.pop(sid, None)
        if sess:
            sess["context"].close()

    def _attach_listeners(self, sid, page, on_event, sess):
        """Wire console, page-error, network and download listeners onto `page`."""

        def _on_console(msg):
            # Keep at most the 200 most recent console entries.
            entry = {"level": msg.type, "text": msg.text, "ts": int(time.time() * 1000)}
            sess["console"].append(entry)
            if len(sess["console"]) > 200:
                sess["console"].pop(0)
            if on_event:
                on_event(agent_pb2.BrowserEvent(
                    session_id=sid,
                    console_msg=agent_pb2.ConsoleMessage(
                        level=msg.type, text=msg.text, timestamp_ms=entry["ts"]
                    )
                ))

        def _on_page_error(err):
            # Keep at most the 100 most recent page errors.
            sess["errors"].append({"message": str(err), "ts": int(time.time() * 1000)})
            if len(sess["errors"]) > 100:
                sess["errors"].pop(0)

        def _on_network(req):
            resp = req.response()
            if on_event:
                on_event(agent_pb2.BrowserEvent(
                    session_id=sid,
                    network_req=agent_pb2.NetworkRequest(
                        method=req.method, url=req.url,
                        status=resp.status if resp else 0,
                        resource_type=req.resource_type, latency_ms=0
                    )
                ))

        def _on_download(dl):
            import os
            # Look up the directory under the lock, but save outside it so the
            # lock is not held during a potentially slow transfer.
            with self.lock:
                current = self.sessions.get(sid)
                download_dir = current.get("download_dir") if current else None
            if not download_dir:
                return
            os.makedirs(download_dir, exist_ok=True)
            # Strip any directory components so the file always lands inside
            # the session's download directory.
            filename = os.path.basename(dl.suggested_filename) or "download"
            target = os.path.join(download_dir, filename)
            print(f"    [🌐📥] Download: {dl.suggested_filename} -> {target}")
            dl.save_as(target)

        page.on("console", _on_console)
        page.on("pageerror", _on_page_error)
        page.on("requestfinished", _on_network)
        page.on("download", _on_download)

    # ------------------------------------------------------------------
    # Browser Actor Loop
    # ------------------------------------------------------------------

    def _browser_actor(self):
        """
        Thread body: start Playwright + Chromium, then process queued tasks
        until a `None` sentinel arrives, and finally close everything.

        Each queue item is `(task, sandbox, on_complete, on_event)`. The result
        is delivered via `on_complete(task_id, result_dict, trace_id)` where
        `status` is 1 on success (with `browser_result`) and 2 on failure
        (with `stderr`).
        """
        print("[🌐] Browser Actor Starting...", flush=True)
        pw = browser = None
        try:
            pw = sync_playwright().start()
        except Exception as pe:
            print(f"[!] Playwright failed to start: {pe}", flush=True)
            return

        try:
            browser = pw.chromium.launch(headless=True, args=[
                '--no-sandbox', '--disable-setuid-sandbox',
                '--disable-dev-shm-usage', '--disable-gpu'
            ])
            print("[🌐] Browser Engine Online.", flush=True)
        except Exception as be:
            print(f"[!] Chromium launch failed: {be}", flush=True)
            pw.stop()
            return

        while True:
            # Reset per iteration so the error path below can never report a
            # failure against the previous task.
            task = on_complete = None
            try:
                item = self.task_queue.get()
                if item is None:
                    print("[🌐] Browser Actor Shutting Down...", flush=True)
                    break

                task, _sandbox, on_complete, on_event = item
                action = task.browser_action
                sid = action.session_id or "default"
                action_name = agent_pb2.BrowserAction.ActionType.Name(action.action)
                print(f"    [🌐] {action_name} | Session: {sid}", flush=True)

                if action.action == agent_pb2.BrowserAction.CLOSE:
                    # Handled before session lookup: closing must not create a
                    # session, and there is no live page left to describe.
                    self._close_session(sid)
                    on_complete(
                        task.task_id,
                        {"status": 1, "browser_result": agent_pb2.BrowserResponse()},
                        task.trace_id,
                    )
                    continue

                sess = self._get_or_create_session(browser, sid, task, on_event)
                page = sess["page"]

                res_data = {}
                try:
                    self._dispatch_action(action, page, sess, res_data)
                except Exception as e:
                    on_complete(task.task_id, {"stderr": str(e), "status": 2}, task.trace_id)
                    continue

                # Build BrowserResponse — include aria_snapshot result in eval_result
                br_res = agent_pb2.BrowserResponse(
                    url=page.url,
                    title=page.title(),
                    snapshot=res_data.get("snapshot", b""),
                    dom_content=res_data.get("dom_content", ""),
                    a11y_tree=res_data.get("a11y_tree", ""),
                    eval_result=res_data.get("eval_result", ""),
                )
                on_complete(task.task_id, {"status": 1, "browser_result": br_res}, task.trace_id)

            except Exception as e:
                print(f"    [!] Browser Actor Error: {e}", flush=True)
                if task is not None and on_complete is not None:
                    try:
                        on_complete(task.task_id, {"stderr": str(e), "status": 2}, task.trace_id)
                    except Exception as cb_err:
                        # Best effort: the actor must keep serving other tasks.
                        print(f"    [!] Browser Actor could not report error: {cb_err}", flush=True)

        # Cleanup
        print("[🌐] Cleaning up Browser Engine...", flush=True)
        with self.lock:
            leftover = list(self.sessions.values())
            self.sessions.clear()
        for s in leftover:
            try:
                s["context"].close()
            except Exception as ce:
                print(f"    [!] Failed to close browser context: {ce}", flush=True)
        if browser:
            browser.close()
        if pw:
            pw.stop()

    # ------------------------------------------------------------------
    # Action Dispatcher
    # ------------------------------------------------------------------

    @staticmethod
    def _truncate(text, limit, marker):
        """Return `text` cut to `limit` characters, with `marker` appended if it was cut."""
        if len(text) <= limit:
            return text
        return text[:limit] + marker

    @staticmethod
    def _capture_role_snapshot(page, sess):
        """Snapshot the page's role tree, store its refs on the session, return (snapshot, refs)."""
        aria_raw = page.locator(":root").aria_snapshot()
        snap, refs = _build_aria_snapshot(aria_raw)
        sess["role_refs"] = refs   # remember refs for subsequent click/type/hover calls
        return snap, refs

    def _dispatch_action(self, action, page, sess, res_data):
        """
        Execute one BrowserAction against `page`, writing any output into `res_data`.

        Keys that may be set on `res_data`: "snapshot" (screenshot bytes),
        "dom_content", "a11y_tree" and "eval_result". Exceptions from Playwright
        (or ValueError for a missing selector / unknown ref) propagate to the caller.
        """
        A = agent_pb2.BrowserAction
        role_refs = sess["role_refs"]

        def resolve(selector_or_ref: str):
            """Accept either a CSS selector or a ref like 'e3'."""
            s = (selector_or_ref or "").strip()
            if not s:
                raise ValueError("A selector or ref (e.g. 'e3') is required for this action.")
            if re.match(r'^e\d+$', s):
                return _resolve_ref(page, s, role_refs)
            return page.locator(s)

        kind = action.action

        if kind == A.NAVIGATE:
            page.goto(action.url, wait_until="domcontentloaded", timeout=25000)
            # Auto-snapshot after every navigation: give AI page context immediately
            snap, refs = self._capture_role_snapshot(page, sess)
            # Trim to 8000 chars to avoid bloating the grpc response
            res_data["a11y_tree"] = self._truncate(snap, 8000, "\n\n[...snapshot truncated...]")
            res_data["eval_result"] = f"refs={len(refs)}"

        elif kind == A.CLICK:
            resolve(action.selector).click(timeout=8000)

        elif kind == A.TYPE:
            resolve(action.selector).fill(action.text, timeout=8000)

        elif kind == A.SCREENSHOT:
            res_data["snapshot"] = page.screenshot(full_page=False)

        elif kind == A.GET_DOM:
            res_data["dom_content"] = page.content()

        elif kind == A.HOVER:
            resolve(action.selector).hover(timeout=5000)

        elif kind == A.SCROLL:
            # Positional: Playwright's parameters are delta_x/delta_y, not x/y.
            page.mouse.wheel(0, action.y or 400)

        elif kind == A.EVAL:
            result = page.evaluate(action.text)
            res_data["eval_result"] = str(result)

        elif kind == A.GET_A11Y:
            # OpenClaw-style role snapshot with ref labels — the key feature!
            snap, refs = self._capture_role_snapshot(page, sess)
            stats = {
                "total_refs": len(refs),
                "interactive": sum(1 for r in refs.values() if r["role"] in INTERACTIVE_ROLES),
                "url": page.url,
                "title": page.title(),
            }
            # Trim large snapshots (news pages can be huge)
            res_data["a11y_tree"] = self._truncate(
                snap, 10000, "\n\n[...snapshot truncated - use eval/scroll to see more...]"
            )
            res_data["eval_result"] = json.dumps(stats)

        elif kind == A.CLOSE:
            self._close_session(action.session_id or "default")

    # ------------------------------------------------------------------
    # Public Interface
    # ------------------------------------------------------------------

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Queue `task` for the browser actor; the result arrives via `on_complete`."""
        self.task_queue.put((task, sandbox, on_complete, on_event))

    def cancel(self, task_id):
        """Cancellation is not supported; always returns False."""
        return False

    def shutdown(self):
        """Ask the browser actor to stop after the tasks already queued."""
        self.task_queue.put(None)
