# cortex-hub / ai-hub / app / core / grpc / services / assistant.py
import time
import json
import os
import hashlib
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2

class TaskAssistant:
    """The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks.

    Responsibilities visible in this module:
      * Workspace sync: push the server-side "ghost mirror" of a session's
        files to nodes, chunk large files, broadcast node-originated chunks
        to the rest of the mesh, and send SyncControl commands (lock/unlock,
        start/stop watching, manifest refresh).
      * Task dispatch: sign and enqueue shell commands / browser actions to a
        node's queue, then block on the journal for the result.
    """

    # Size of each streamed file chunk (1 MiB).
    CHUNK_SIZE = 1024 * 1024

    def __init__(self, registry, journal, pool, mirror=None):
        # registry: node lookup — get_node(node_id) / list_nodes()
        # journal: result rendezvous — register()/get_result()/pop()
        # pool: worker pool (held but not used in this module's visible code)
        # mirror: server-side workspace mirror; may be None (sync disabled)
        self.registry = registry
        self.journal = journal
        self.pool = pool
        self.mirror = mirror
        self.memberships = {}  # session_id -> list(node_id)

    def push_workspace(self, node_id, session_id):
        """Initial unidirectional push from server ghost mirror to a node.

        Sends the session manifest first, then the contents of every
        non-directory entry. No-op if the node is offline or no mirror is
        configured.
        """
        node = self.registry.get_node(node_id)
        if not node or not self.mirror:
            return

        print(f"[📁📤] Initiating Workspace Push for Session {session_id} to {node_id}")

        # Track membership so reconcile_node() can re-sync after reconnects.
        members = self.memberships.setdefault(session_id, [])
        if node_id not in members:
            members.append(node_id)

        manifest = self.mirror.generate_manifest(session_id)

        # 1. Send Manifest
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                manifest=manifest
            )
        ))

        # 2. Send File Data (directories carry no payload)
        for file_info in manifest.files:
            if not file_info.is_dir:
                self.push_file(node_id, session_id, file_info.path)

    def push_file(self, node_id, session_id, rel_path):
        """Pushes a specific file to a node (used for drift recovery).

        Streams the file in CHUNK_SIZE pieces; the final chunk carries the
        SHA-256 hex digest of the whole file so the node can verify
        integrity. A zero-length file still sends one empty final chunk.
        """
        node = self.registry.get_node(node_id)
        # Guard mirror too: __init__ allows mirror=None and this method is
        # callable directly (drift recovery), not only via push_workspace.
        if not node or not self.mirror:
            return

        workspace = self.mirror.get_workspace_path(session_id)
        abs_path = os.path.join(workspace, rel_path)

        if not os.path.exists(abs_path):
            print(f"    [📁❓] Requested file {rel_path} not found in mirror")
            return

        with open(abs_path, "rb") as f:
            # Hash the whole file up front, then rewind to stream it.
            full_data = f.read()
            full_hash = hashlib.sha256(full_data).hexdigest()
            f.seek(0)

            index = 0
            while True:
                chunk = f.read(self.CHUNK_SIZE)
                # A short read means EOF; an exact-multiple-sized file emits
                # one trailing empty chunk marked final, matching the
                # original framing.
                is_final = len(chunk) < self.CHUNK_SIZE

                node.queue.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path,
                            chunk=chunk,
                            chunk_index=index,
                            is_final=is_final,
                            hash=full_hash if is_final else ""
                        )
                    )
                ))

                if is_final:  # empty chunk implies is_final, so one check suffices
                    break
                index += 1

    def reconcile_node(self, node_id):
        """Forces a re-sync check for all sessions this node belongs to."""
        print(f"    [📁🔄] Triggering Resync Check for {node_id}...")
        for sid, nodes in self.memberships.items():
            if node_id in nodes:
                # Re-push manifest to trigger node-side drift check
                self.push_workspace(node_id, sid)

    def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
        """Broadcasts a file chunk received from one node to all other nodes in the mesh."""
        print(f"    [📁📢] Broadcasting {file_payload.path} from {sender_node_id} to other nodes...")
        for node_id in self.registry.list_nodes():
            if node_id == sender_node_id:
                continue  # never echo a chunk back to its origin

            node = self.registry.get_node(node_id)
            if not node:
                continue  # node went offline between list and lookup

            # Forward the exact same FileSyncMessage
            node.queue.put(agent_pb2.ServerTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(
                    session_id=session_id,
                    file_data=file_payload
                )
            ))

    def lock_workspace(self, node_id, session_id):
        """Disables user-side synchronization from a node during AI refactors."""
        self.control_sync(node_id, session_id, action="LOCK")

    def unlock_workspace(self, node_id, session_id):
        """Re-enables user-side synchronization from a node."""
        self.control_sync(node_id, session_id, action="UNLOCK")

    def request_manifest(self, node_id, session_id, path="."):
        """Requests a full directory manifest from a node for drift checking."""
        node = self.registry.get_node(node_id)
        if not node:
            return
        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path)
            )
        ))

    def control_sync(self, node_id, session_id, action="START", path="."):
        """Sends a SyncControl command to a node (e.g. START_WATCHING, LOCK).

        Unknown action strings fall back to START_WATCHING (original
        behavior, preserved).
        """
        node = self.registry.get_node(node_id)
        if not node:
            return

        action_map = {
            "START": agent_pb2.SyncControl.START_WATCHING,
            "STOP": agent_pb2.SyncControl.STOP_WATCHING,
            "LOCK": agent_pb2.SyncControl.LOCK,
            "UNLOCK": agent_pb2.SyncControl.UNLOCK
        }
        proto_action = action_map.get(action, agent_pb2.SyncControl.START_WATCHING)

        node.queue.put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                control=agent_pb2.SyncControl(action=proto_action, path=path)
            )
        ))

    def _await_result(self, event, tid, timeout):
        """Blocks on a journaled task event; returns its result or a timeout error.

        Always pops the journal entry so abandoned tasks do not accumulate.
        """
        if event.wait(timeout):
            res = self.journal.get_result(tid)
            self.journal.pop(tid)
            return res
        self.journal.pop(tid)
        return {"error": "Timeout"}

    def dispatch_single(self, node_id, cmd, timeout=30, session_id=None):
        """Dispatches a shell command to a specific node.

        Returns the node's result, or {"error": ...} if the node is offline
        or the result does not arrive within `timeout` seconds.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}

        # Millisecond timestamp task id; registered before enqueue so the
        # result can't race the wait.
        tid = f"task-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)

        # 12-Factor Signing Logic
        sig = sign_payload(cmd)
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, payload_json=cmd, signature=sig, session_id=session_id))

        print(f"[📤] Dispatching shell {tid} to {node_id}")
        # FIX: was node["queue"] — every other call site in this class uses
        # attribute access on the registry's node object.
        node.queue.put(req)

        return self._await_result(event, tid, timeout)

    def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
        """Dispatches a browser action to a directed session node.

        Returns the node's result, or {"error": ...} on offline/timeout.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}

        tid = f"br-{int(time.time()*1000)}"
        event = self.journal.register(tid, node_id)

        # Secure Browser Signing: sign the symbolic action name, target URL,
        # and the browser session id carried on the action itself.
        sig = sign_browser_action(
            agent_pb2.BrowserAction.ActionType.Name(action.action),
            action.url,
            action.session_id
        )

        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig, session_id=session_id))

        print(f"[🌐📤] Dispatching browser {tid} to {node_id}")
        # FIX: was node["queue"] — unified on attribute access (see dispatch_single).
        node.queue.put(req)

        return self._await_result(event, tid, timeout)