import time
import json
import os
import hashlib
import logging
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2

logger = logging.getLogger(__name__)

class TaskAssistant:
    """The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""
    def __init__(self, registry, journal, pool, mirror=None):
        """Store the collaborators and start with an empty membership table.

        Args:
            registry: Node lookup object (used via ``get_node`` / ``emit``).
            journal: Task journal (used via ``register`` / ``get_result`` / ``pop``).
            pool: Kept on the instance; not used by the methods of this class.
            mirror: Optional server-side workspace mirror; file pushes are
                skipped when it is ``None``.
        """
        # session_id -> list(node_id)
        self.memberships = {}
        self.registry, self.journal = registry, journal
        self.pool, self.mirror = pool, mirror

    def push_workspace(self, node_id, session_id):
        """Push the mirrored workspace of a session to one node (server -> node).

        Does nothing when the node is not registered or no mirror is
        configured. Otherwise the node is recorded as a member of the session,
        the mirror's manifest is queued, and every non-directory manifest
        entry is sent through ``push_file``.
        """
        target = self.registry.get_node(node_id)
        if not (target and self.mirror):
            return

        print(f"[📁📤] Initiating Workspace Push for Session {session_id} to {node_id}")

        # Remember the membership so the node can be re-synced later.
        members = self.memberships.setdefault(session_id, [])
        if node_id not in members:
            members.append(node_id)

        manifest = self.mirror.generate_manifest(session_id)

        # Step 1: the manifest goes first.
        manifest_msg = agent_pb2.FileSyncMessage(session_id=session_id, manifest=manifest)
        target.queue.put(agent_pb2.ServerTaskMessage(file_sync=manifest_msg))

        # Step 2: then the content of every regular entry.
        for entry in manifest.files:
            if entry.is_dir:
                continue
            self.push_file(node_id, session_id, entry.path)

    def push_file(self, node_id, session_id, rel_path):
        """Pushes a specific file to a node (used for drift recovery).

        The file is read from the session's mirror workspace and queued to the
        node in 1 MiB chunks. The last chunk (the first one shorter than 1 MiB,
        possibly empty) carries ``is_final=True`` and the SHA-256 hex digest of
        the whole file; all other chunks carry an empty hash.

        Returns ``None`` without sending anything when the node is not
        registered, no mirror is configured, or ``rel_path`` is not a regular
        file inside the mirror.
        """
        node = self.registry.get_node(node_id)
        # Without a mirror there is nothing to read from.
        if not node or not self.mirror:
            return

        workspace = self.mirror.get_workspace_path(session_id)
        abs_path = os.path.join(workspace, rel_path)

        # isfile (not exists): a directory here would make open() raise.
        if not os.path.isfile(abs_path):
            print(f"    [📁❓] Requested file {rel_path} not found in mirror")
            return

        chunk_size = 1024 * 1024  # 1MB chunks
        # Hash incrementally while streaming: the file is read once, never
        # held in memory as a whole, and the digest always matches the bytes
        # that were actually sent.
        hasher = hashlib.sha256()

        with open(abs_path, "rb") as f:
            index = 0
            while True:
                chunk = f.read(chunk_size)
                hasher.update(chunk)
                is_final = len(chunk) < chunk_size

                node.queue.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path,
                            chunk=chunk,
                            chunk_index=index,
                            is_final=is_final,
                            hash=hasher.hexdigest() if is_final else ""
                        )
                    )
                ))

                if is_final:
                    break
                index += 1

    def clear_workspace(self, node_id, session_id):
        """Drop a node from a session and tell it to purge its local sync dir.

        The membership entry is removed even when the node is currently not
        registered; the PURGE control message is only queued if it is.
        """
        print(f"    [📁🧹] Instructing node {node_id} to purge workspace for session {session_id}")

        members = self.memberships.get(session_id)
        if members is not None and node_id in members:
            members.remove(node_id)

        target = self.registry.get_node(node_id)
        if not target:
            return

        purge = agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE, path=".")
        sync_msg = agent_pb2.FileSyncMessage(session_id=session_id, control=purge)
        target.queue.put(agent_pb2.ServerTaskMessage(file_sync=sync_msg))

    def reconcile_node(self, node_id):
        """Re-sync every session the node is a member of.

        If the node is registered, a CLEANUP control message listing the
        sessions it still belongs to is queued first. Afterwards the workspace
        of each of those sessions is pushed again via ``push_workspace``.
        """
        print(f"    [📁🔄] Triggering Resync Check for {node_id}...")

        active_sessions = [
            sid for sid, members in self.memberships.items() if node_id in members
        ]

        # Proactive cleanup: hand the node the whitelist of live sessions.
        target = self.registry.get_node(node_id)
        if target:
            cleanup = agent_pb2.SyncControl(
                action=agent_pb2.SyncControl.CLEANUP,
                request_paths=active_sessions,
            )
            target.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(session_id="global", control=cleanup)
            ))

        # Pushing the manifest again triggers the node-side drift check.
        for sid in active_sessions:
            self.push_workspace(node_id, sid)

    def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
        """Relay a file chunk from one node to every other member of the session.

        Members that are not currently registered are skipped silently.
        """
        destinations = [
            member
            for member in self.memberships.get(session_id, [])
            if member != sender_node_id
        ]

        if destinations:
            print(f"    [📁📢] Broadcasting {file_payload.path} from {sender_node_id} to: {', '.join(destinations)}")

        for node_id in destinations:
            target = self.registry.get_node(node_id)
            if target:
                # The payload is forwarded unchanged inside a new envelope.
                relay = agent_pb2.FileSyncMessage(session_id=session_id, file_data=file_payload)
                target.queue.put(agent_pb2.ServerTaskMessage(file_sync=relay))

    def lock_workspace(self, node_id, session_id):
        """Send a LOCK sync-control command to the node for this session.

        Intended to pause user-side synchronization while the AI refactors.
        """
        self.control_sync(node_id, session_id, "LOCK")

    def unlock_workspace(self, node_id, session_id):
        """Send an UNLOCK sync-control command to the node for this session.

        Counterpart of ``lock_workspace``: resumes user-side synchronization.
        """
        self.control_sync(node_id, session_id, "UNLOCK")

    def request_manifest(self, node_id, session_id, path="."):
        """Ask a node to send a fresh directory manifest (for drift checking).

        Fire-and-forget: the REFRESH_MANIFEST control message is queued and
        nothing is awaited. Does nothing when the node is not registered.
        """
        target = self.registry.get_node(node_id)
        if not target:
            return

        refresh = agent_pb2.SyncControl(
            action=agent_pb2.SyncControl.REFRESH_MANIFEST,
            path=path,
        )
        envelope = agent_pb2.FileSyncMessage(session_id=session_id, control=refresh)
        target.queue.put(agent_pb2.ServerTaskMessage(file_sync=envelope))

    def control_sync(self, node_id, session_id, action="START", path="."):
        """Queue a SyncControl command for a node and record its membership.

        ``action`` is one of ``"START"``, ``"STOP"``, ``"LOCK"``, ``"UNLOCK"``;
        any other value is sent as START_WATCHING. Nothing happens (and no
        membership is recorded) when the node is not registered.
        """
        target = self.registry.get_node(node_id)
        if not target:
            return

        ctl = agent_pb2.SyncControl
        known_actions = {
            "START": ctl.START_WATCHING,
            "STOP": ctl.STOP_WATCHING,
            "LOCK": ctl.LOCK,
            "UNLOCK": ctl.UNLOCK,
        }
        if action in known_actions:
            proto_action = known_actions[action]
        else:
            proto_action = ctl.START_WATCHING

        # Membership is what recovery and broadcasting iterate over.
        members = self.memberships.setdefault(session_id, [])
        if node_id not in members:
            members.append(node_id)

        command = agent_pb2.SyncControl(action=proto_action, path=path)
        envelope = agent_pb2.FileSyncMessage(session_id=session_id, control=command)
        target.queue.put(agent_pb2.ServerTaskMessage(file_sync=envelope))

    # ==================================================================
    #  Modular FS Explorer / Mesh Navigation
    # ==================================================================

    def ls(self, node_id: str, path: str = ".", timeout=10, session_id="__fs_explorer__"):
        """Requests a directory listing from a node (waits for response).

        Args:
            node_id: Target node.
            path: Directory to list on the node.
            timeout: Seconds to wait for the node's answer.
            session_id: Session the request belongs to. For anything other
                than the ``"__fs_explorer__"`` default, unsynced files in the
                listing are fetched in the background afterwards.

        Returns:
            The result recorded in the journal for this task,
            ``{"error": "Offline"}`` if the node is not registered, or
            ``{"error": "Timeout"}`` if no answer arrived in time.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        # A millisecond timestamp alone collides when two requests are issued
        # within the same millisecond; the random suffix keeps ids unique.
        tid = f"fs-ls-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    task_id=tid,
                    control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.LIST, path=path)
                )
            ))

            if not event.wait(timeout):
                return {"error": "Timeout"}
            res = self.journal.get_result(tid)
        finally:
            # Always release the journal entry, including when sending fails.
            self.journal.pop(tid)

        # Proactive Mirroring: start fetching content so dots turn green
        # (Only for user sessions, not for node management explorer)
        if res and "files" in res and session_id != "__fs_explorer__":
            self._proactive_explorer_sync(node_id, res["files"], session_id)

        return res

    def _proactive_explorer_sync(self, node_id, files, session_id):
        """Fetch small, not-yet-synced files from a node in the background.

        For every listing entry that is not a directory, is not marked
        ``is_synced`` and is smaller than 512 KiB, a daemon thread running
        ``cat`` (15 s timeout) is started. The call itself does not wait.
        """
        import threading

        small_file_limit = 1024 * 512  # larger files are skipped
        for entry in files:
            if entry.get("is_dir"):
                continue
            needs_mirror = not entry.get("is_synced") and entry.get("size", 0) < small_file_limit
            if needs_mirror:
                fetcher = threading.Thread(
                    target=self.cat,
                    args=(node_id, entry["path"], 15, session_id),
                    daemon=True,
                )
                fetcher.start()

    def cat(self, node_id: str, path: str, timeout=15, session_id="__fs_explorer__"):
        """Requests file content from a node (waits for result).

        Args:
            node_id: Target node.
            path: File to read on the node.
            timeout: Seconds to wait for the node's answer.
            session_id: Session the request belongs to.

        Returns:
            The result recorded in the journal for this task,
            ``{"error": "Offline"}`` if the node is not registered, or
            ``{"error": "Timeout"}`` if no answer arrived in time.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        # For 'cat', we might get multiple chunks, but TaskJournal fulfill
        # usually happens on the final chunk. We'll handle chunking in server.
        #
        # The random suffix matters here: _proactive_explorer_sync starts many
        # cat() threads back to back, so a millisecond timestamp alone would
        # hand the same task id to several of them.
        tid = f"fs-cat-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    task_id=tid,
                    control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.READ, path=path)
                )
            ))

            if not event.wait(timeout):
                return {"error": "Timeout"}
            # res usually contains {content, path}. grpc_server already writes it to mirror.
            return self.journal.get_result(tid)
        finally:
            # Always release the journal entry, including when sending fails.
            self.journal.pop(tid)

    def write(self, node_id: str, path: str, content: bytes = b"", is_dir: bool = False, timeout=10, session_id="__fs_explorer__"):
        """Creates or updates a file/directory on a node (waits for status).

        Args:
            node_id: Target node.
            path: File or directory path on the node.
            content: Bytes to write (ignored by the hub mirror for directories).
            is_dir: Create a directory instead of a file.
            timeout: Seconds to wait for the node's answer.
            session_id: Session the request belongs to. For anything other
                than the ``"__fs_explorer__"`` default, a successful write is
                also applied to the hub-side mirror.

        Returns:
            The result recorded in the journal for this task,
            ``{"error": "Offline"}`` if the node is not registered, or
            ``{"error": "Timeout"}`` if no answer arrived in time.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        # Random suffix: a millisecond timestamp alone collides for
        # concurrent requests.
        tid = f"fs-write-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    task_id=tid,
                    control=agent_pb2.SyncControl(
                        action=agent_pb2.SyncControl.WRITE,
                        path=path,
                        content=content,
                        is_dir=is_dir
                    )
                )
            ))

            if not event.wait(timeout):
                return {"error": "Timeout"}
            res = self.journal.get_result(tid)
        finally:
            # Always release the journal entry, including when sending fails.
            self.journal.pop(tid)

        # M6: Update mirror locally on hub so ls sees it as synced (Only for real sessions)
        # `res` is checked first so an empty/None result cannot raise here.
        if self.mirror and res and res.get("success") and session_id != "__fs_explorer__":
            workspace_mirror = self.mirror.get_workspace_path(session_id)
            # NOTE(review): `path` is joined unchecked; an absolute path or
            # ".." would land outside the mirror — confirm callers only pass
            # workspace-relative paths.
            dest = os.path.join(workspace_mirror, path)
            if is_dir:
                os.makedirs(dest, exist_ok=True)
            else:
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                with open(dest, "wb") as f:
                    f.write(content)

        return res

    def rm(self, node_id: str, path: str, timeout=10, session_id="__fs_explorer__"):
        """Deletes a file or directory on a node (waits for status).

        Args:
            node_id: Target node.
            path: File or directory path on the node.
            timeout: Seconds to wait for the node's answer.
            session_id: Session the request belongs to. For anything other
                than the ``"__fs_explorer__"`` default, a successful delete is
                also applied to the hub-side mirror.

        Returns:
            The result recorded in the journal for this task,
            ``{"error": "Offline"}`` if the node is not registered, or
            ``{"error": "Timeout"}`` if no answer arrived in time.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        # Random suffix: a millisecond timestamp alone collides for
        # concurrent requests.
        tid = f"fs-rm-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    task_id=tid,
                    control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
                )
            ))

            if not event.wait(timeout):
                return {"error": "Timeout"}
            res = self.journal.get_result(tid)
        finally:
            # Always release the journal entry, including when sending fails.
            self.journal.pop(tid)

        # M6: remove from mirror if successful (Only for real sessions)
        # `res` is checked first so an empty/None result cannot raise here.
        if self.mirror and res and res.get("success") and session_id != "__fs_explorer__":
            import shutil
            dest = os.path.join(self.mirror.get_workspace_path(session_id), path)
            if os.path.islink(dest) or os.path.isfile(dest):
                # Links are unlinked, never followed: shutil.rmtree refuses
                # a symlink that points at a directory.
                os.remove(dest)
            elif os.path.isdir(dest):
                shutil.rmtree(dest)

        return res

    def dispatch_swarm(self, node_ids, cmd, timeout=30, session_id=None, no_abort=False):
        """Dispatches a command to multiple nodes in parallel and waits for all results.

        Args:
            node_ids: Iterable of node ids; may be empty.
            cmd: Command forwarded to ``dispatch_single`` for every node.
            timeout: Per-node timeout in seconds.
            session_id: Optional session the tasks belong to.
            no_abort: Forwarded to ``dispatch_single``.

        Returns:
            Dict mapping each node id to its ``dispatch_single`` result, or to
            ``{"error": <message>}`` if that dispatch raised. Empty input
            yields ``{}``.
        """
        from concurrent.futures import ThreadPoolExecutor

        # Materialize once so generators work and len() is available.
        node_ids = list(node_ids)
        if not node_ids:
            # ThreadPoolExecutor rejects max_workers=0 with ValueError.
            return {}

        results = {}
        with ThreadPoolExecutor(max_workers=len(node_ids)) as executor:
            future_to_node = {
                executor.submit(self.dispatch_single, nid, cmd, timeout, session_id, no_abort): nid
                for nid in node_ids
            }
            for future, node_id in future_to_node.items():
                try:
                    results[node_id] = future.result()
                except Exception as exc:
                    results[node_id] = {"error": str(exc)}

        return results

    def dispatch_single(self, node_id, cmd, timeout=30, session_id=None, no_abort=False):
        """Dispatches a shell command to a specific node.

        Args:
            node_id: Target node.
            cmd: Command payload; it is signed with ``sign_payload`` and sent
                as ``payload_json``.
            timeout: Seconds to wait for a result. ``0`` returns immediately
                with a RUNNING placeholder containing the task id.
            session_id: Optional session the task belongs to.
            no_abort: On timeout, leave the task running on the node and
                return whatever the journal holds with status
                ``"TIMEOUT_PENDING"`` instead of sending a cancel request.

        Returns:
            The journal result for the task, a placeholder dict for the
            peek/timeout cases, or ``{"error": "Node <id> Offline"}`` when the
            node is not registered.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}

        # dispatch_swarm calls this from several threads at once, so a
        # millisecond timestamp alone would give different nodes the same id.
        tid = f"task-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            # 12-Factor Signing Logic
            sig = sign_payload(cmd)
            # int(): tolerate fractional-second timeouts for timeout_ms.
            req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
                task_id=tid, payload_json=cmd, signature=sig, session_id=session_id,
                timeout_ms=int(timeout * 1000)))

            logger.info(f"[📤] Dispatching shell {tid} to {node_id}")
            self.registry.emit(node_id, "task_assigned", {"command": cmd, "session_id": session_id}, task_id=tid)
            node.queue.put(req)
            self.registry.emit(node_id, "task_start", {"command": cmd}, task_id=tid)
        except Exception:
            # The caller never learns the task id, so nobody could ever
            # collect this journal entry; drop it before propagating.
            self.journal.pop(tid)
            raise

        # Immediate peek if timeout is 0
        if timeout == 0:
            return {"status": "RUNNING", "stdout": "", "task_id": tid}

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            # pop only if fully done (an empty result cannot be "RUNNING")
            if not res or res.get("status") != "RUNNING":
                self.journal.pop(tid)
            return res

        # M6: Timeout recovery.
        if no_abort:
            logger.info(f"[⏳] Shell task {tid} TIMEOUT (no_abort=True). Leaving alive on {node_id}.")
            res = self.journal.get_result(tid) or {}
            res["task_id"] = tid
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[⚠️] Shell task {tid} TIMEOUT after {timeout}s on {node_id}. Sending ABORT.")
        try:
            node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)))
        except Exception as exc:
            # Best effort: a failed cancel must not hide the timeout result,
            # but it should not vanish silently either.
            logger.warning("[⚠️] Failed to queue ABORT for task %s on %s: %s", tid, node_id, exc)

        # Return partial result captured in buffer before popping
        res = self.journal.get_result(tid)
        self.journal.pop(tid)
        return res if res else {"error": "Timeout", "stdout": "", "stderr": "", "status": "TIMEOUT", "task_id": tid}

    def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
        """Dispatches a browser action to a directed session node.

        Args:
            node_id: Target node.
            action: Browser action message; its ``action``, ``url`` and
                ``session_id`` fields are signed with ``sign_browser_action``.
            timeout: Seconds to wait for the node's answer.
            session_id: Optional session the task belongs to.

        Returns:
            The result recorded in the journal for this task,
            ``{"error": "Node <id> Offline"}`` if the node is not registered,
            or ``{"error": "Timeout"}`` if no answer arrived in time.
        """
        import uuid

        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}

        # Random suffix: a millisecond timestamp alone collides for
        # concurrent requests.
        tid = f"br-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
        event = self.journal.register(tid, node_id)

        try:
            # Secure Browser Signing
            sig = sign_browser_action(
                agent_pb2.BrowserAction.ActionType.Name(action.action),
                action.url,
                action.session_id
            )

            req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
                task_id=tid, browser_action=action, signature=sig, session_id=session_id))

            logger.info(f"[🌐📤] Dispatching browser {tid} to {node_id}")
            self.registry.emit(node_id, "task_assigned", {"browser_action": action.action, "url": action.url}, task_id=tid)
            node.queue.put(req)
            self.registry.emit(node_id, "task_start", {"browser_action": action.action}, task_id=tid)

            if not event.wait(timeout):
                return {"error": "Timeout"}
            return self.journal.get_result(tid)
        finally:
            # Always release the journal entry, including when signing or
            # sending fails.
            self.journal.pop(tid)

    def wait_for_swarm(self, task_map, timeout=30, no_abort=False):
        """Wait on several already-dispatched tasks at once.

        ``task_map`` maps node_id -> task_id. Each pair is handed to
        ``wait_for_task`` on its own worker thread. The return value maps every
        node id to that call's result, or to ``{"error": <message>}`` when the
        call raised.
        """
        from concurrent.futures import ThreadPoolExecutor

        worker_count = max(1, len(task_map))
        outcome = {}
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            # Submit everything first so the waits overlap.
            pending = {}
            for nid, tid in task_map.items():
                fut = executor.submit(self.wait_for_task, nid, tid, timeout, no_abort)
                pending[fut] = nid

            for fut, nid in pending.items():
                try:
                    outcome[nid] = fut.result()
                except Exception as e:
                    outcome[nid] = {"error": str(e)}
        return outcome

    def wait_for_task(self, node_id, task_id, timeout=30, no_abort=False):
        """Waits for an existing task in the journal.

        Args:
            node_id: Node the task runs on (used for the cancel request).
            task_id: Id of a task previously registered in the journal.
            timeout: Seconds to wait. ``0`` only peeks at the current state.
            no_abort: On timeout, leave the task running and return the
                journal's current result with status ``"TIMEOUT_PENDING"``
                instead of sending a cancel request.

        Returns:
            The journal result for the task; a ``NOT_FOUND`` error dict when
            the journal has no such task; a RUNNING placeholder when peeking
            at a task that has produced nothing yet; or a TIMEOUT dict when
            the wait expired and nothing was captured.
        """
        # Check journal first
        with self.journal.lock:
            data = self.journal.tasks.get(task_id)
            if not data:
                return {"error": f"Task {task_id} not found in journal (finished or expired)", "status": "NOT_FOUND"}
            event = data["event"]

        # Immediate peek if timeout is 0 or event is already set
        finished = event.is_set()
        if timeout == 0 or finished:
            res = self.journal.get_result(task_id)
            if not res and not finished:
                # Still running with nothing captured yet: keep the journal
                # entry and answer the same way dispatch_single does for a
                # zero-timeout dispatch.
                return {"status": "RUNNING", "stdout": "", "task_id": task_id}
            if not res or res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res

        logger.info(f"[⏳] Re-waiting for task {task_id} on {node_id} for {timeout}s")
        if event.wait(timeout):
            res = self.journal.get_result(task_id)
            # An empty result cannot be "RUNNING", so it is popped as done.
            if not res or res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res

        if no_abort:
            res = self.journal.get_result(task_id) or {}
            res["task_id"] = task_id
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[⚠️] Wait for task {task_id} TIMEOUT again. Sending ABORT.")
        node = self.registry.get_node(node_id)
        if node:
            try:
                node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=task_id)))
            except Exception as exc:
                # Best effort: a failed cancel must not hide the timeout
                # result, but it should not vanish silently either.
                logger.warning("[⚠️] Failed to queue ABORT for task %s on %s: %s", task_id, node_id, exc)

        res = self.journal.get_result(task_id)
        self.journal.pop(task_id)
        return res if res else {"error": "Timeout", "status": "TIMEOUT", "task_id": task_id}
