# cortex-hub / ai-hub / app / core / grpc / services / assistant.py
import time
import json
import os
import hashlib
import logging
import shutil
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2

logger = logging.getLogger(__name__)

class TaskAssistant:
    """The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""
    def __init__(self, registry, journal, pool, mirror=None):
        """Store the collaborators this assistant dispatches through.

        Args:
            registry: Node registry; the methods of this class call its
                ``get_node`` and ``emit``.
            journal: Task journal used to register tasks and read results.
            pool: Stored unchanged on the instance.
            mirror: Optional hub-side workspace mirror. When it is ``None``
                the mirror-backed code paths are skipped.
        """
        # Maps session_id -> list of node_ids participating in that session.
        self.memberships = {}
        self.mirror = mirror
        self.pool = pool
        self.journal = journal
        self.registry = registry

    def push_workspace(self, node_id, session_id):
        """Initial unidirectional push from server ghost mirror to a node."""
        node = self.registry.get_node(node_id)
        if not node or not self.mirror: return
        
        print(f"[๐Ÿ“๐Ÿ“ค] Initiating Workspace Push for Session {session_id} to {node_id}")
        
        # Track for recovery
        if session_id not in self.memberships:
            self.memberships[session_id] = []
        if node_id not in self.memberships[session_id]:
            self.memberships[session_id].append(node_id)

        manifest = self.mirror.generate_manifest(session_id)
        
        # 1. Send Manifest
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                manifest=manifest
            )
        ))
        
        # 2. Send File Data
        for file_info in manifest.files:
            if not file_info.is_dir:
                self.push_file(node_id, session_id, file_info.path)

    def push_file(self, node_id, session_id, rel_path):
        """Pushes a specific file to a node (used for drift recovery)."""
        node = self.registry.get_node(node_id)
        if not node: return
        
        workspace = self.mirror.get_workspace_path(session_id)
        abs_path = os.path.join(workspace, rel_path)
        
        if not os.path.exists(abs_path):
            print(f"    [๐Ÿ“โ“] Requested file {rel_path} not found in mirror")
            return

        with open(abs_path, "rb") as f:
            full_data = f.read()
            full_hash = hashlib.sha256(full_data).hexdigest()
            f.seek(0)
            
            index = 0
            while True:
                chunk = f.read(1024 * 1024) # 1MB chunks
                is_final = len(chunk) < 1024 * 1024
                
                node.queue.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path,
                            chunk=chunk,
                            chunk_index=index,
                            is_final=is_final,
                            hash=full_hash if is_final else ""
                        )
                    )
                ))
                
                if is_final or not chunk:
                    break
                index += 1

    def clear_workspace(self, node_id, session_id):
        """Sends a SyncControl command to purge the local sync directory on a node, and removes from active mesh."""
        print(f"    [๐Ÿ“๐Ÿงน] Instructing node {node_id} to purge workspace for session {session_id}")
        if session_id in self.memberships and node_id in self.memberships[session_id]:
            self.memberships[session_id].remove(node_id)
        
        node = self.registry.get_node(node_id)
        if not node: return
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE, path=".")
            )
        ))

    def reconcile_node(self, node_id):
        """Forces a re-sync check for all sessions this node belongs to and purges dead sessions."""
        print(f"    [๐Ÿ“๐Ÿ”„] Triggering Resync Check for {node_id}...")
        
        active_sessions = []
        for sid, nodes in self.memberships.items():
            if node_id in nodes:
                active_sessions.append(sid)
                
        # Send proactive cleanup payload with the active sessions whitelist
        node = self.registry.get_node(node_id)
        if node:
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id="global",
                    control=agent_pb2.SyncControl(
                        action=agent_pb2.SyncControl.CLEANUP,
                        request_paths=active_sessions
                    )
                )
            ))
            
        for sid in active_sessions:
            # Re-push manifest to trigger node-side drift check
            self.push_workspace(node_id, sid)

    def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
        """Broadcasts a file chunk received from one node to all other nodes in the mesh."""
        session_members = self.memberships.get(session_id, [])
        destinations = [n for n in session_members if n != sender_node_id]
        
        if destinations:
            print(f"    [๐Ÿ“๐Ÿ“ข] Broadcasting {file_payload.path} from {sender_node_id} to: {', '.join(destinations)}")
        
        for node_id in destinations:
            node = self.registry.get_node(node_id)
            if not node:
                continue
            
            # Forward the exact same FileSyncMessage
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    file_data=file_payload
                )
            ))

    def lock_workspace(self, node_id, session_id):
        """Disables user-side synchronization from a node during AI refactors."""
        self.control_sync(node_id, session_id, action="LOCK")

    def unlock_workspace(self, node_id, session_id):
        """Re-enables user-side synchronization from a node."""
        self.control_sync(node_id, session_id, action="UNLOCK")

    def request_manifest(self, node_id, session_id, path="."):
        """Requests a full directory manifest from a node for drift checking."""
        node = self.registry.get_node(node_id)
        if not node: return
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path)
            )
        ))

    def control_sync(self, node_id, session_id, action="START", path="."):
        """Sends a SyncControl command to a node (e.g. START_WATCHING, LOCK)."""
        node = self.registry.get_node(node_id)
        if not node: return

        action_map = {
            "START": agent_pb2.SyncControl.START_WATCHING,
            "STOP": agent_pb2.SyncControl.STOP_WATCHING,
            "LOCK": agent_pb2.SyncControl.LOCK,
            "UNLOCK": agent_pb2.SyncControl.UNLOCK,
            "RESYNC": agent_pb2.SyncControl.RESYNC
        }
        proto_action = action_map.get(action, agent_pb2.SyncControl.START_WATCHING)

        # Track for recovery & broadcast
        if session_id not in self.memberships:
            self.memberships[session_id] = []
        if node_id not in self.memberships[session_id]:
            self.memberships[session_id].append(node_id)
        
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=proto_action, path=path)
            )
        ))

    # ==================================================================
    #  Modular FS Explorer / Mesh Navigation
    # ==================================================================

    def ls(self, node_id: str, path: str = ".", timeout=10, session_id="__fs_explorer__", force_remote: bool = False):
        """Requests a directory listing from a node (waits for response)."""
        # Phase 1: Local Mirror Fast-Path
        if session_id != "__fs_explorer__" and self.mirror and not force_remote:
            workspace = self.mirror.get_workspace_path(session_id)
            abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
            if os.path.exists(abs_path) and os.path.isdir(abs_path):
                files = []
                try:
                    for entry in os.scandir(abs_path):
                        rel = os.path.relpath(entry.path, workspace)
                        files.append({
                            "path": rel,
                            "name": entry.name,
                            "is_dir": entry.is_dir(),
                            "size": entry.stat().st_size if entry.is_file() else 0,
                            "is_synced": True
                        })
                    return {"files": files, "path": path}
                except Exception as e:
                    logger.error(f"[๐Ÿ“๐Ÿ“‚] Local ls error for {session_id}/{path}: {e}")

        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        tid = f"fs-ls-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.LIST, path=path)
            )
        ))

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            
            # Proactive Mirroring: start fetching content so dots turn green
            if res and "files" in res and session_id != "__fs_explorer__":
                self._proactive_explorer_sync(node_id, res["files"], session_id)

            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def _proactive_explorer_sync(self, node_id, files, session_id):
        """Starts background tasks to mirror files to Hub so dots turn green."""
        import threading
        for f in files:
            if f.get("is_dir"): continue
            if not f.get("is_synced") and f.get("size", 0) < 1024 * 512: # Skip large files
                threading.Thread(target=self.cat, args=(node_id, f["path"], 15, session_id), daemon=True).start()

    def cat(self, node_id: str, path: str, timeout=15, session_id="__fs_explorer__", force_remote: bool = False):
        """Requests file content from a node (waits for result)."""
        # Phase 1: Local Mirror Fast-Path
        if session_id != "__fs_explorer__" and self.mirror and not force_remote:
            workspace = self.mirror.get_workspace_path(session_id)
            abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
            if os.path.exists(abs_path) and os.path.isfile(abs_path):
                try:
                    # Try reading as text
                    with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()
                    return {"content": content, "path": path}
                except Exception as e:
                    logger.error(f"[๐Ÿ“๐Ÿ“„] Local cat error for {session_id}/{path}: {e}")

        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # For 'cat', we might get multiple chunks, but TaskJournal fulfill 
        # usually happens on the final chunk. We'll handle chunking in server.
        tid = f"fs-cat-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.READ, path=path)
            )
        ))

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            # res usually contains {content, path}. grpc_server already writes it to mirror.
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def write(self, node_id: str, path: str, content: bytes = b"", is_dir: bool = False, timeout=10, session_id="__fs_explorer__"):
        """Creates or updates a file/directory on a node (waits for status)."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # Phase 1: Sync local mirror ON HUB instantly (Zero Latency)
        if self.mirror and session_id != "__fs_explorer__":
            workspace_mirror = self.mirror.get_workspace_path(session_id)
            dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
            try:
                if is_dir:
                    os.makedirs(dest, exist_ok=True)
                else:
                    os.makedirs(os.path.dirname(dest), exist_ok=True)
                    with open(dest, "wb") as f:
                        f.write(content)
                
                # Fire and forget synchronization to the edge node
                tid = f"fs-write-{int(time.time()*1000)}"
                node.queue.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        task_id=tid,
                        control=agent_pb2.SyncControl(
                            action=agent_pb2.SyncControl.WRITE, 
                            path=path, 
                            content=content,
                            is_dir=is_dir
                        )
                    )
                ))
                return {"success": True, "message": "Synchronized to local mirror and dispatched to node"}
            except Exception as e:
                logger.error(f"[๐Ÿ“โœ๏ธ] Local mirror write error: {e}")
                return {"error": str(e)}

        # Legacy/Explorer path: await node confirmation
        tid = f"fs-write-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(
                    action=agent_pb2.SyncControl.WRITE, 
                    path=path, 
                    content=content,
                    is_dir=is_dir
                )
            )
        ))

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def inspect_drift(self, node_id: str, path: str, session_id: str):
        """Returns a unified diff between Hub local mirror and Node's actual file."""
        if not self.mirror: return {"error": "Mirror not available"}
        
        # 1. Get Local Content
        workspace = self.mirror.get_workspace_path(session_id)
        local_abs = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
        local_content = ""
        if os.path.exists(local_abs) and os.path.isfile(local_abs):
            try:
                with open(local_abs, 'r', encoding='utf-8', errors='ignore') as f:
                    local_content = f.read()
            except: pass

        # 2. Get Remote Content (Force Bypass Fast-Path)
        print(f"    [๐Ÿ“๐Ÿ”] Inspecting Drift: Fetching remote content for {path} on {node_id}")
        remote_res = self.cat(node_id, path, session_id=session_id, force_remote=True)
        if "error" in remote_res:
             return {"error": f"Failed to fetch remote content: {remote_res['error']}"}
        
        remote_content = remote_res.get("content", "")

        # 3. Create Diff
        import difflib
        diff = difflib.unified_diff(
            local_content.splitlines(keepends=True),
            remote_content.splitlines(keepends=True),
            fromfile=f"hub://{session_id}/{path}",
            tofile=f"node://{node_id}/{path}"
        )
        
        diff_text = "".join(diff)
        return {
            "path": path,
            "has_drift": local_content != remote_content,
            "diff": diff_text,
            "local_size": len(local_content),
            "remote_size": len(remote_content)
        }

    def rm(self, node_id: str, path: str, timeout=10, session_id="__fs_explorer__"):
        """Deletes a file or directory on a node (waits for status)."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # Phase 1: Sync local mirror ON HUB instantly
        if self.mirror and session_id != "__fs_explorer__":
            workspace_mirror = self.mirror.get_workspace_path(session_id)
            dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
            try:
                if os.path.isdir(dest):
                    shutil.rmtree(dest)
                elif os.path.exists(dest):
                    os.remove(dest)
                
                # Fire and forget to edge node
                tid = f"fs-rm-{int(time.time()*1000)}"
                node.queue.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        task_id=tid,
                        control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
                    )
                ))
                return {"success": True, "message": "Removed from local mirror and dispatched delete to node"}
            except Exception as e:
                logger.error(f"[๐Ÿ“๐Ÿ—‘๏ธ] Local mirror rm error: {e}")
                return {"error": str(e)}

        # Legacy/Explorer path: await node confirmation
        tid = f"fs-rm-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
            )
        ))

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def dispatch_swarm(self, node_ids, cmd, timeout=30, session_id=None, no_abort=False):
        """Dispatches a command to multiple nodes in parallel and waits for all results."""
        from concurrent.futures import ThreadPoolExecutor, as_completed
        
        results = {}
        with ThreadPoolExecutor(max_workers=max(1, len(node_ids))) as executor:
            future_to_node = {
                executor.submit(self.dispatch_single, nid, cmd, timeout, session_id, no_abort): nid 
                for nid in node_ids
            }
            # Use as_completed to avoid blocking on a slow node when others are finished
            for future in as_completed(future_to_node):
                node_id = future_to_node[future]
                try:
                    results[node_id] = future.result()
                except Exception as exc:
                    results[node_id] = {"error": str(exc)}
        
        return results

    def dispatch_single(self, node_id, cmd, timeout=30, session_id=None, no_abort=False):
        """Dispatches a shell command to a specific node."""
        import uuid
        node = self.registry.get_node(node_id)
        if not node: return {"error": f"Node {node_id} Offline"}

        # Use UUID to prevent timestamp collisions in high-speed swarm dispatch
        tid = f"task-{uuid.uuid4().hex[:12]}"
        event = self.journal.register(tid, node_id)
        
        # 12-Factor Signing Logic
        sig = sign_payload(cmd)
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, task_type="shell", payload_json=cmd, signature=sig, session_id=session_id,
            timeout_ms=timeout * 1000))
        
        logger.info(f"[๐Ÿ“ค] Dispatching shell {tid} to {node_id}")
        self.registry.emit(node_id, "task_assigned", {"command": cmd, "session_id": session_id}, task_id=tid)
        node.queue.put(req)
        self.registry.emit(node_id, "task_start", {"command": cmd}, task_id=tid)
        
        # Immediate peek if timeout is 0
        if timeout == 0:
            return {"status": "RUNNING", "stdout": "", "task_id": tid}

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            # pop only if fully done
            if res.get("status") != "RUNNING":
                self.journal.pop(tid)
            return res
        
        # M6: Timeout recovery.
        if no_abort:
            logger.info(f"[โณ] Shell task {tid} TIMEOUT (no_abort=True). Leaving alive on {node_id}.")
            res = self.journal.get_result(tid) or {}
            res["task_id"] = tid
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[โš ๏ธ] Shell task {tid} TIMEOUT after {timeout}s on {node_id}. Sending ABORT.")
        try:
            node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)))
        except: pass
        
        # Return partial result captured in buffer before popping
        res = self.journal.get_result(tid)
        self.journal.pop(tid)
        return res if res else {"error": "Timeout", "stdout": "", "stderr": "", "status": "TIMEOUT", "task_id": tid}

    def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
        """Dispatches a browser action to a directed session node."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": f"Node {node_id} Offline"}

        tid = f"br-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)

        # Secure Browser Signing
        sig = sign_browser_action(
            agent_pb2.BrowserAction.ActionType.Name(action.action), 
            action.url, 
            action.session_id
        )

        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig, session_id=session_id))
        
        logger.info(f"[๐ŸŒ๐Ÿ“ค] Dispatching browser {tid} to {node_id}")
        self.registry.emit(node_id, "task_assigned", {"browser_action": action.action, "url": action.url}, task_id=tid)
        node.queue.put(req)
        self.registry.emit(node_id, "task_start", {"browser_action": action.action}, task_id=tid)
        
        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def wait_for_swarm(self, task_map, timeout=30, no_abort=False):
        """Waits for multiple tasks (map of node_id -> task_id) in parallel."""
        from concurrent.futures import ThreadPoolExecutor
        
        results = {}
        with ThreadPoolExecutor(max_workers=max(1, len(task_map))) as executor:
            # item = (node_id, task_id)
            future_to_node = {
                executor.submit(self.wait_for_task, nid, tid, timeout, no_abort): nid 
                for nid, tid in task_map.items()
            }
            for fut in future_to_node:
                nid = future_to_node[fut]
                try: results[nid] = fut.result()
                except Exception as e: results[nid] = {"error": str(e)}
        return results

    def wait_for_task(self, node_id, task_id, timeout=30, no_abort=False):
        """Waits for an existing task in the journal."""
        # Check journal first
        with self.journal.lock:
            data = self.journal.tasks.get(task_id)
            if not data: 
                return {"error": f"Task {task_id} not found in journal (finished or expired)", "status": "NOT_FOUND"}
            event = data["event"]
        
        # Immediate peek if timeout is 0 or event is already set
        if timeout == 0 or event.is_set():
            res = self.journal.get_result(task_id)
            if res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res

        logger.info(f"[โณ] Re-waiting for task {task_id} on {node_id} for {timeout}s")
        if event.wait(timeout):
            res = self.journal.get_result(task_id)
            if res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res
        
        if no_abort:
            res = self.journal.get_result(task_id) or {}
            res["task_id"] = task_id
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[โš ๏ธ] Wait for task {task_id} TIMEOUT again. Sending ABORT.")
        node = self.registry.get_node(node_id)
        if node:
            try: node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=task_id)))
            except: pass
        
        res = self.journal.get_result(task_id)
        self.journal.pop(task_id)
        return res if res else {"error": "Timeout", "status": "TIMEOUT", "task_id": task_id}