# cortex-hub / ai-hub / app / core / grpc / services / assistant.py
import time
import json
import os
import hashlib
import zlib
import logging
import shutil
import threading
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2
from app.db.session import get_db_session
from app.db.models import Session

logger = logging.getLogger(__name__)

class TaskAssistant:
    """The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""

    def __init__(self, registry, journal, pool, mirror=None):
        # Collaborators injected by the server wiring.
        self.registry = registry
        self.journal = journal
        self.pool = pool
        self.mirror = mirror
        # Session mesh membership: session_id -> list of attached node_ids.
        self.memberships = {}
        # Guards concurrent access to self.memberships.
        self.membership_lock = threading.Lock()

    def push_workspace(self, node_id, session_id):
        """Initial unidirectional push from server ghost mirror to a node."""
        if not self.mirror: return
        
        # 1. Ensure Server Mirror exists immediately
        manifest = self.mirror.generate_manifest(session_id)

        # 2. Track relationship for recovery/reconciliation
        with self.membership_lock:
            if session_id not in self.memberships:
                self.memberships[session_id] = []
            if node_id not in self.memberships[session_id]:
                self.memberships[session_id].append(node_id)

        # 3. If node is online, push actual data
        node = self.registry.get_node(node_id)
        if not node:
            logger.info(f"[๐Ÿ“๐Ÿ“ค] Workspace {session_id} prepared on server for offline node {node_id}")
            return
            
        print(f"[๐Ÿ“๐Ÿ“ค] Initiating Workspace Push for Session {session_id} to {node_id}")
        
        # Send Manifest to Node. The node will compare this with its local state
        # and send back RECONCILE_REQUIRED for any files it is missing.
        # This prevents the "Double Push" race where the server blasts data
        # while the node is still trying to decide what it needs.
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                manifest=manifest
            )
        ), priority=1)
        
        # NOTE: Proactive parallel push removed. Manifest-driven reactive sync is cleaner.

    def push_file(self, node_id, session_id, rel_path):
        """Pushes a specific file to a node (used for drift recovery)."""
        node = self.registry.get_node(node_id)
        if not node: return
        
        workspace = self.mirror.get_workspace_path(session_id)
        abs_path = os.path.join(workspace, rel_path)
        
        if not os.path.exists(abs_path):
            print(f"    [๐Ÿ“โ“] Requested file {rel_path} not found in mirror")
            return

        # Line-rate Optimization: 4MB chunks + No Software Throttling
        hasher = hashlib.sha256()
        file_size = os.path.getsize(abs_path)
        
        try:
            with open(abs_path, "rb") as f:
                index = 0
                while True:
                    chunk = f.read(4 * 1024 * 1024) # 4MB chunks (optimal for gRPC)
                    if not chunk: break
                    
                    hasher.update(chunk)
                    offset = f.tell() - len(chunk)
                    is_final = f.tell() >= file_size
                    
                    # Compress Chunk for transit
                    compressed_chunk = zlib.compress(chunk)

                    # Put into priority dispatcher (priority 2 for sync data)
                    node.send_message(agent_pb2.ServerTaskMessage(
                        file_sync=agent_pb2.FileSyncMessage(
                            session_id=session_id,
                            file_data=agent_pb2.FilePayload(
                                path=rel_path,
                                chunk=compressed_chunk,
                                chunk_index=index,
                                is_final=is_final,
                                hash=hasher.hexdigest() if is_final else "",
                                offset=offset,
                                compressed=True
                            )
                        )
                    ), priority=2)
                    
                    if is_final: break
                    index += 1
        except Exception as e:
            logger.error(f"[๐Ÿ“๐Ÿ“ค] Line-rate push error for {rel_path}: {e}")

    def clear_workspace(self, node_id, session_id):
        """Sends a SyncControl command to purge the local sync directory on a node, and removes from active mesh."""
        print(f"    [๐Ÿ“๐Ÿงน] Instructing node {node_id} to purge workspace for session {session_id}")
        if session_id in self.memberships and node_id in self.memberships[session_id]:
            self.memberships[session_id].remove(node_id)
        
        node = self.registry.get_node(node_id)
        if not node: return
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE, path=".")
            )
        ), priority=1)

    def reconcile_node(self, node_id):
        """Forces a re-sync check for all sessions this node belongs to and purges dead sessions.

        Flow:
          1. Query the DB for non-archived sessions with a sync workspace and
             rebuild this node's in-memory membership entries from DB truth.
          2. Purge membership entries for workspaces no longer active.
          3. Send the node a CLEANUP control carrying the whitelist of active
             workspace ids (priority 0).
          4. Re-push each active workspace manifest to trigger node-side drift
             checks, pausing between sessions.

        Falls back to the in-memory membership map if the DB query fails.
        """
        print(f"    [๐Ÿ“๐Ÿ”„] Triggering Resync Check for {node_id}...")
        
        active_sessions = []
        try:
            with get_db_session() as db:
                # Only live sessions that actually have a sync workspace matter.
                sessions = db.query(Session).filter(
                    Session.is_archived == False,
                    Session.sync_workspace_id.isnot(None)
                ).all()
                
                with self.membership_lock:
                    for s in sessions:
                        attached = s.attached_node_ids or []
                        if node_id in attached:
                            active_sessions.append(s.sync_workspace_id)
                            # Rebuild in-memory membership from the DB record.
                            if s.sync_workspace_id not in self.memberships:
                                self.memberships[s.sync_workspace_id] = []
                            if node_id not in self.memberships[s.sync_workspace_id]:
                                self.memberships[s.sync_workspace_id].append(node_id)
                
                # Aggressive memory cleanup: Purge orphaned session memberships
                current_active_workspace_ids = {s.sync_workspace_id for s in sessions}
                with self.membership_lock:
                    to_purge = [sid for sid in self.memberships.keys() if sid not in current_active_workspace_ids]
                    for sid in to_purge:
                        del self.memberships[sid]
        except Exception as e:
            print(f"    [๐Ÿ“โš ๏ธ] Failed to fetch active sessions for node reconciliation: {e}")
            # Fallback to in-memory if DB fails
            with self.membership_lock:
                for sid, nodes in self.memberships.items():
                    if node_id in nodes:
                        active_sessions.append(sid)
                
        # Send proactive cleanup payload with the active sessions whitelist
        node = self.registry.get_node(node_id)
        if node:
            node.send_message(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id="global",
                    control=agent_pb2.SyncControl(
                        action=agent_pb2.SyncControl.CLEANUP,
                        request_paths=active_sessions
                    )
                )
            ), priority=0)
            
        for sid in active_sessions:
            # Re-push manifest to trigger node-side drift check
            self.push_workspace(node_id, sid)
            # Add a small delay to prevent saturating the gRPC stream for multiple sessions
            time.sleep(0.5)

    def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
        """Broadcasts a file chunk received from one node to all other nodes in the mesh."""
        with self.membership_lock:
            session_members = self.memberships.get(session_id, [])
            destinations = [n for n in session_members if n != sender_node_id]
        
        if destinations:
            print(f"    [๐Ÿ“๐Ÿ“ข] Broadcasting {file_payload.path} from {sender_node_id} to: {', '.join(destinations)}")
        
        def _send_to_node(nid):
            node = self.registry.get_node(nid)
            if node:
                # Forward the exact same FileSyncMessage (Priority 2 for Sync Data)
                node.send_message(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=file_payload
                    )
                ), priority=2)

        # M6: Use registry executor if available for parallel mesh broadcast
        if self.registry.executor:
            for nid in destinations:
                self.registry.executor.submit(_send_to_node, nid)
        else:
            for nid in destinations:
                _send_to_node(nid)

    def lock_workspace(self, node_id, session_id):
        """Disables user-side synchronization from a node during AI refactors."""
        self.control_sync(node_id, session_id, action="LOCK")

    def unlock_workspace(self, node_id, session_id):
        """Re-enables user-side synchronization from a node."""
        self.control_sync(node_id, session_id, action="UNLOCK")

    def request_manifest(self, node_id, session_id, path="."):
        """Requests a full directory manifest from a node for drift checking."""
        node = self.registry.get_node(node_id)
        if not node: return
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path)
            )
        ), priority=1)

    def control_sync(self, node_id, session_id, action="START", path="."):
        """Sends a SyncControl command to a node (e.g. START_WATCHING, LOCK)."""
        node = self.registry.get_node(node_id)
        if not node: return

        action_map = {
            "START": agent_pb2.SyncControl.START_WATCHING,
            "STOP": agent_pb2.SyncControl.STOP_WATCHING,
            "LOCK": agent_pb2.SyncControl.LOCK,
            "UNLOCK": agent_pb2.SyncControl.UNLOCK,
            "RESYNC": agent_pb2.SyncControl.RESYNC
        }
        proto_action = action_map.get(action, agent_pb2.SyncControl.START_WATCHING)

        # Track for recovery & broadcast
        if session_id not in self.memberships:
            self.memberships[session_id] = []
        if node_id not in self.memberships[session_id]:
            self.memberships[session_id].append(node_id)
        
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=proto_action, path=path)
            )
        ), priority=1)

    # ==================================================================
    #  Modular FS Explorer / Mesh Navigation
    # ==================================================================

    def ls(self, node_id: str, path: str = ".", timeout=10, session_id="__fs_explorer__", force_remote: bool = False):
        """Requests a directory listing from a node (waits for response)."""
        # Phase 1: Local Mirror Fast-Path
        if session_id != "__fs_explorer__" and self.mirror and not force_remote:
            workspace = self.mirror.get_workspace_path(session_id)
            abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
            if os.path.exists(abs_path) and os.path.isdir(abs_path):
                files = []
                try:
                    for entry in os.scandir(abs_path):
                        rel = os.path.relpath(entry.path, workspace)
                        files.append({
                            "path": rel,
                            "name": entry.name,
                            "is_dir": entry.is_dir(),
                            "size": entry.stat().st_size if entry.is_file() else 0,
                            "is_synced": True
                        })
                    return {"files": files, "path": path}
                except Exception as e:
                    logger.error(f"[๐Ÿ“๐Ÿ“‚] Local ls error for {session_id}/{path}: {e}")

        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        tid = f"fs-ls-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.LIST, path=path)
            )
        ), priority=1)

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            
            # Proactive Mirroring: start fetching content so dots turn green
            if res and "files" in res and session_id != "__fs_explorer__":
                self._proactive_explorer_sync(node_id, res["files"], session_id)

            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def _proactive_explorer_sync(self, node_id, files, session_id):
        """Starts background tasks to mirror files to Hub so dots turn green."""
        for f in files:
            if f.get("is_dir"): continue
            if not f.get("is_synced") and f.get("size", 0) < 1024 * 512: # Skip large files
                # M6: Use shared registry executor instead of spawning loose threads
                if self.registry.executor:
                    self.registry.executor.submit(self.cat, node_id, f["path"], 15, session_id)

    def cat(self, node_id: str, path: str, timeout=15, session_id="__fs_explorer__", force_remote: bool = False):
        """Requests file content from a node (waits for result)."""
        # Phase 1: Local Mirror Fast-Path
        if session_id != "__fs_explorer__" and self.mirror and not force_remote:
            workspace = self.mirror.get_workspace_path(session_id)
            abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
            if os.path.exists(abs_path) and os.path.isfile(abs_path):
                try:
                    # Try reading as text
                    with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()
                    return {"content": content, "path": path}
                except Exception as e:
                    logger.error(f"[๐Ÿ“๐Ÿ“„] Local cat error for {session_id}/{path}: {e}")

        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # For 'cat', we might get multiple chunks, but TaskJournal fulfill 
        # usually happens on the final chunk. We'll handle chunking in server.
        tid = f"fs-cat-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.READ, path=path)
            )
        ), priority=1)

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            # res usually contains {content, path}. grpc_server already writes it to mirror.
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def write(self, node_id: str, path: str, content: bytes = b"", is_dir: bool = False, timeout=10, session_id="__fs_explorer__"):
        """Creates or updates a file/directory on a node (waits for status)."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # Phase 1: Sync local mirror ON HUB instantly (Zero Latency)
        if self.mirror and session_id != "__fs_explorer__":
            workspace_mirror = self.mirror.get_workspace_path(session_id)
            dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
            try:
                if is_dir:
                    os.makedirs(dest, exist_ok=True)
                else:
                    os.makedirs(os.path.dirname(dest), exist_ok=True)
                    with open(dest, "wb") as f:
                        f.write(content)
                
                # Multi-node broadcast for sessions
                targets = []
                if session_id != "__fs_explorer__":
                    targets = self.memberships.get(session_id, [node_id])
                else:
                    targets = [node_id]

                print(f"[๐Ÿ“โœ๏ธ] AI Write: {path} (Session: {session_id}) -> Dispatching to {len(targets)} nodes")
                
                for target_nid in targets:
                    target_node = self.registry.get_node(target_nid)
                    if not target_node: continue
                    
                    tid = f"fs-write-{int(time.time()*1000)}"
                    target_node.send_message(agent_pb2.ServerTaskMessage(
                        file_sync=agent_pb2.FileSyncMessage(
                            session_id=session_id,
                            task_id=tid,
                            control=agent_pb2.SyncControl(
                                action=agent_pb2.SyncControl.WRITE, 
                                path=path, 
                                content=content,
                                is_dir=is_dir
                            )
                        )
                    ), priority=2)
                
                return {"success": True, "message": f"Synchronized to local mirror and dispatched to {len(targets)} nodes"}
            except Exception as e:
                logger.error(f"[๐Ÿ“โœ๏ธ] Local mirror write error: {e}")
                return {"error": str(e)}

        # Legacy/Explorer path: await node confirmation
        tid = f"fs-write-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(
                    action=agent_pb2.SyncControl.WRITE, 
                    path=path, 
                    content=content,
                    is_dir=is_dir
                )
            )
        ), priority=2)

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def inspect_drift(self, node_id: str, path: str, session_id: str):
        """Returns a unified diff between Hub local mirror and Node's actual file."""
        if not self.mirror: return {"error": "Mirror not available"}
        
        # 1. Get Local Content
        workspace = self.mirror.get_workspace_path(session_id)
        local_abs = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
        local_content = ""
        if os.path.exists(local_abs) and os.path.isfile(local_abs):
            try:
                with open(local_abs, 'r', encoding='utf-8', errors='ignore') as f:
                    local_content = f.read()
            except: pass

        # 2. Get Remote Content (Force Bypass Fast-Path)
        print(f"    [๐Ÿ“๐Ÿ”] Inspecting Drift: Fetching remote content for {path} on {node_id}")
        remote_res = self.cat(node_id, path, session_id=session_id, force_remote=True)
        if "error" in remote_res:
             return {"error": f"Failed to fetch remote content: {remote_res['error']}"}
        
        remote_content = remote_res.get("content", "")

        # 3. Create Diff
        import difflib
        diff = difflib.unified_diff(
            local_content.splitlines(keepends=True),
            remote_content.splitlines(keepends=True),
            fromfile=f"hub://{session_id}/{path}",
            tofile=f"node://{node_id}/{path}"
        )
        
        diff_text = "".join(diff)
        return {
            "path": path,
            "has_drift": local_content != remote_content,
            "diff": diff_text,
            "local_size": len(local_content),
            "remote_size": len(remote_content)
        }

    def rm(self, node_id: str, path: str, timeout=10, session_id="__fs_explorer__"):
        """Deletes a file or directory on a node (waits for status)."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": "Offline"}

        # Phase 1: Sync local mirror ON HUB instantly
        if self.mirror and session_id != "__fs_explorer__":
            workspace_mirror = self.mirror.get_workspace_path(session_id)
            dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
            try:
                if os.path.isdir(dest):
                    shutil.rmtree(dest)
                elif os.path.exists(dest):
                    os.remove(dest)
                
                # Multi-node broadcast for sessions
                targets = []
                if session_id != "__fs_explorer__":
                    targets = self.memberships.get(session_id, [node_id])
                else:
                    targets = [node_id]

                print(f"[๐Ÿ“๐Ÿ—‘๏ธ] AI Remove: {path} (Session: {session_id}) -> Dispatching to {len(targets)} nodes")

                for target_nid in targets:
                    target_node = self.registry.get_node(target_nid)
                    if not target_node: continue
                    
                    tid = f"fs-rm-{int(time.time()*1000)}"
                    target_node.send_message(agent_pb2.ServerTaskMessage(
                        file_sync=agent_pb2.FileSyncMessage(
                            session_id=session_id,
                            task_id=tid,
                            control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
                        )
                    ), priority=2)
                return {"success": True, "message": f"Removed from local mirror and dispatched delete to {len(targets)} nodes"}
            except Exception as e:
                logger.error(f"[๐Ÿ“๐Ÿ—‘๏ธ] Local mirror rm error: {e}")
                return {"error": str(e)}

        # Legacy/Explorer path: await node confirmation
        tid = f"fs-rm-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)
        
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                task_id=tid,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
            )
        ), priority=2)

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def dispatch_swarm(self, node_ids, cmd, timeout=120, session_id=None, no_abort=False):
        """Dispatches a command to multiple nodes in parallel and waits for all results."""
        from concurrent.futures import ThreadPoolExecutor, as_completed

        results = {}
        worker_count = max(1, len(node_ids))
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = {}
            for nid in node_ids:
                fut = pool.submit(self.dispatch_single, nid, cmd, timeout, session_id, no_abort)
                pending[fut] = nid
            # Harvest in completion order so one slow node never delays the rest.
            for done in as_completed(pending):
                nid = pending[done]
                try:
                    results[nid] = done.result()
                except Exception as exc:
                    results[nid] = {"error": str(exc)}

        return results

    def dispatch_single(self, node_id, cmd, timeout=120, session_id=None, no_abort=False):
        """Dispatches a shell command to a specific node."""
        import uuid
        node = self.registry.get_node(node_id)
        if not node: return {"error": f"Node {node_id} Offline"}

        # Use UUID to prevent timestamp collisions in high-speed swarm dispatch
        tid = f"task-{uuid.uuid4().hex[:12]}"
        event = self.journal.register(tid, node_id)
        
        # 12-Factor Signing Logic
        sig = sign_payload(cmd)
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, task_type="shell", payload_json=cmd, signature=sig, session_id=session_id,
            timeout_ms=timeout * 1000))
        
        logger.info(f"[๐Ÿ“ค] Dispatching shell {tid} to {node_id}")
        self.registry.emit(node_id, "task_assigned", {"command": cmd, "session_id": session_id}, task_id=tid)
        node.send_message(req, priority=1)
        self.registry.emit(node_id, "task_start", {"command": cmd}, task_id=tid)
        
        # Immediate peek if timeout is 0
        if timeout == 0:
            return {"status": "RUNNING", "stdout": "", "task_id": tid}

        if event.wait(timeout):
            res = self.journal.get_result(tid)
            # pop only if fully done
            if res.get("status") != "RUNNING":
                self.journal.pop(tid)
            return res
        
        # M6: Timeout recovery.
        if no_abort:
            logger.info(f"[โณ] Shell task {tid} TIMEOUT (no_abort=True). Leaving alive on {node_id}.")
            res = self.journal.get_result(tid) or {}
            res["task_id"] = tid
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[โš ๏ธ] Shell task {tid} TIMEOUT after {timeout}s on {node_id}. Sending ABORT.")
        try:
            node.send_message(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)), priority=0)
        except: pass
        
        # Return partial result captured in buffer before popping
        res = self.journal.get_result(tid)
        self.journal.pop(tid)
        return res if res else {"error": "Timeout", "stdout": "", "stderr": "", "status": "TIMEOUT", "task_id": tid}

    def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
        """Dispatches a browser action to a directed session node."""
        node = self.registry.get_node(node_id)
        if not node: return {"error": f"Node {node_id} Offline"}

        tid = f"br-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)

        # Secure Browser Signing
        sig = sign_browser_action(
            agent_pb2.BrowserAction.ActionType.Name(action.action), 
            action.url, 
            action.session_id
        )

        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig, session_id=session_id))
        
        logger.info(f"[๐ŸŒ๐Ÿ“ค] Dispatching browser {tid} to {node_id}")
        self.registry.emit(node_id, "task_assigned", {"browser_action": action.action, "url": action.url}, task_id=tid)
        node.send_message(req, priority=1)
        self.registry.emit(node_id, "task_start", {"browser_action": action.action}, task_id=tid)
        
        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def wait_for_swarm(self, task_map, timeout=30, no_abort=False):
        """Waits for multiple tasks (map of node_id -> task_id) in parallel.

        CONSISTENCY FIX: harvest futures with as_completed (like
        dispatch_swarm does) instead of iterating in submission order, so one
        slow node does not delay collecting results from finished nodes.
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed
        
        results = {}
        with ThreadPoolExecutor(max_workers=max(1, len(task_map))) as executor:
            # item = (node_id, task_id)
            future_to_node = {
                executor.submit(self.wait_for_task, nid, tid, timeout, no_abort): nid 
                for nid, tid in task_map.items()
            }
            for fut in as_completed(future_to_node):
                nid = future_to_node[fut]
                try: results[nid] = fut.result()
                except Exception as e: results[nid] = {"error": str(e)}
        return results

    def wait_for_task(self, node_id, task_id, timeout=30, no_abort=False):
        """Waits for an existing task in the journal."""
        # Check journal first
        with self.journal.lock:
            data = self.journal.tasks.get(task_id)
            if not data: 
                return {"error": f"Task {task_id} not found in journal (finished or expired)", "status": "NOT_FOUND"}
            event = data["event"]
        
        # Immediate peek if timeout is 0 or event is already set
        if timeout == 0 or event.is_set():
            res = self.journal.get_result(task_id)
            if res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res

        logger.info(f"[โณ] Re-waiting for task {task_id} on {node_id} for {timeout}s")
        if event.wait(timeout):
            res = self.journal.get_result(task_id)
            if res.get("status") != "RUNNING":
                self.journal.pop(task_id)
            return res
        
        if no_abort:
            res = self.journal.get_result(task_id) or {}
            res["task_id"] = task_id
            res["status"] = "TIMEOUT_PENDING"
            return res

        logger.warning(f"[โš ๏ธ] Wait for task {task_id} TIMEOUT again. Sending ABORT.")
        node = self.registry.get_node(node_id)
        if node:
            try: node.send_message(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=task_id)), priority=0)
            except: pass
        
        res = self.journal.get_result(task_id)
        self.journal.pop(task_id)
        return res if res else {"error": "Timeout", "status": "TIMEOUT", "task_id": task_id}