import time
import json
import os
import hashlib
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2
class TaskAssistant:
    """The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""

    # Size of each file_data chunk streamed by push_file (1 MiB).
    CHUNK_SIZE = 1024 * 1024

    def __init__(self, registry, journal, pool, mirror=None):
        self.registry = registry
        self.journal = journal
        self.pool = pool
        self.mirror = mirror
        self.memberships = {}  # session_id -> list(node_id); read by reconcile_node

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _queue_of(node):
        """Return a node record's outbound queue.

        The registry's node record may expose its queue as an attribute
        (``node.queue``) or as a mapping key (``node["queue"]``); accept either
        so every send path in this class works against the same registry.
        """
        outbox = getattr(node, "queue", None)
        if outbox is None:
            outbox = node["queue"]
        return outbox

    def _send_sync(self, node, session_id, **fields):
        """Enqueue a FileSyncMessage for ``session_id`` on ``node``.

        ``fields`` are passed straight through to ``agent_pb2.FileSyncMessage``
        (e.g. ``manifest=``, ``file_data=`` or ``control=``).
        """
        self._queue_of(node).put(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(session_id=session_id, **fields)
        ))

    @staticmethod
    def _resolve_in_workspace(workspace, rel_path):
        """Return the real absolute path of ``rel_path`` inside ``workspace``.

        Returns None when the resolved path (after ``..`` segments, absolute
        components and symlinks are resolved) lies outside ``workspace``, so a
        caller-supplied path cannot read files outside the session mirror.
        """
        root = os.path.realpath(workspace)
        candidate = os.path.realpath(os.path.join(root, rel_path))
        try:
            if os.path.commonpath([root, candidate]) != root:
                return None
        except ValueError:  # e.g. paths on different drives (Windows)
            return None
        return candidate

    @staticmethod
    def _new_task_id(prefix):
        """Build a task id of the form ``<prefix>-<epoch ms>-<random hex>``.

        The random suffix keeps two dispatches issued in the same millisecond
        from sharing (and clobbering) one journal entry.
        """
        return f"{prefix}-{int(time.time() * 1000)}-{os.urandom(4).hex()}"

    def _run_task(self, node, node_id, tid, req, timeout):
        """Register ``tid`` in the journal, enqueue ``req`` on ``node`` and wait.

        Returns the journal's result for ``tid`` if the registered event is set
        within ``timeout`` seconds, otherwise ``{"error": "Timeout"}``. The
        journal entry is always popped, even if enqueuing raises.
        """
        # Register before enqueuing so a fast reply cannot arrive unobserved.
        event = self.journal.register(tid, node_id)
        try:
            self._queue_of(node).put(req)
            if event.wait(timeout):
                return self.journal.get_result(tid)
            return {"error": "Timeout"}
        finally:
            self.journal.pop(tid)

    # ------------------------------------------------------------------
    # Workspace synchronisation
    # ------------------------------------------------------------------
    def push_workspace(self, node_id, session_id):
        """Initial unidirectional push from server ghost mirror to a node.

        Records ``node_id`` as a member of ``session_id`` (used by
        reconcile_node), sends the mirror's manifest, then streams every
        non-directory entry via push_file. Does nothing when the node is
        unknown or no mirror is configured.
        """
        node = self.registry.get_node(node_id)
        if not node or not self.mirror:
            return
        print(f"[📁📤] Initiating Workspace Push for Session {session_id} to {node_id}")
        # Track for recovery
        members = self.memberships.setdefault(session_id, [])
        if node_id not in members:
            members.append(node_id)
        manifest = self.mirror.generate_manifest(session_id)
        # 1. Send Manifest
        self._send_sync(node, session_id, manifest=manifest)
        # 2. Send File Data
        for file_info in manifest.files:
            if not file_info.is_dir:
                self.push_file(node_id, session_id, file_info.path)

    def push_file(self, node_id, session_id, rel_path):
        """Push one mirrored file to a node in CHUNK_SIZE pieces (used for drift recovery).

        Chunks are numbered from 0. The last chunk has ``is_final=True`` and
        carries the SHA-256 hex digest of the whole file. A file whose size is
        an exact multiple of CHUNK_SIZE therefore ends with an empty final
        chunk. Paths resolving outside the session workspace are refused.
        Does nothing when the node is unknown or no mirror is configured.
        """
        node = self.registry.get_node(node_id)
        if not node or not self.mirror:
            return
        workspace = self.mirror.get_workspace_path(session_id)
        abs_path = self._resolve_in_workspace(workspace, rel_path)
        if abs_path is None:
            print(f" [📁⛔] Refusing to push {rel_path}: path escapes the session workspace")
            return
        if not os.path.isfile(abs_path):
            print(f" [📁❓] Requested file {rel_path} not found in mirror")
            return

        outbox = self._queue_of(node)
        digest = hashlib.sha256()  # hashed incrementally: digest matches the bytes sent
        with open(abs_path, "rb") as f:
            index = 0
            while True:
                chunk = f.read(self.CHUNK_SIZE)
                digest.update(chunk)
                # A short (possibly empty) read means EOF has been reached.
                is_final = len(chunk) < self.CHUNK_SIZE
                outbox.put(agent_pb2.ServerTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path,
                            chunk=chunk,
                            chunk_index=index,
                            is_final=is_final,
                            hash=digest.hexdigest() if is_final else "",
                        ),
                    )
                ))
                if is_final:
                    break
                index += 1

    def reconcile_node(self, node_id):
        """Forces a re-sync check for all sessions this node belongs to."""
        print(f" [📁🔄] Triggering Resync Check for {node_id}...")
        # Iterate a snapshot: push_workspace writes to self.memberships.
        for sid, nodes in list(self.memberships.items()):
            if node_id in nodes:
                # Re-push manifest to trigger node-side drift check
                self.push_workspace(node_id, sid)

    def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
        """Forward a file chunk received from one node to every other registered node.

        NOTE(review): targets every id from ``registry.list_nodes()``, not only
        members of ``session_id``. Confirm that this is intended.
        """
        print(f" [📁📢] Broadcasting {file_payload.path} from {sender_node_id} to other nodes...")
        for node_id in self.registry.list_nodes():
            if node_id == sender_node_id:
                continue
            node = self.registry.get_node(node_id)
            if not node:
                continue
            # Forward the exact same FilePayload
            self._send_sync(node, session_id, file_data=file_payload)

    def lock_workspace(self, node_id, session_id):
        """Disables user-side synchronization from a node during AI refactors."""
        self.control_sync(node_id, session_id, action="LOCK")

    def unlock_workspace(self, node_id, session_id):
        """Re-enables user-side synchronization from a node."""
        self.control_sync(node_id, session_id, action="UNLOCK")

    def request_manifest(self, node_id, session_id, path="."):
        """Ask a node to send back a fresh directory manifest of ``path`` for drift checking."""
        node = self.registry.get_node(node_id)
        if not node:
            return
        self._send_sync(
            node,
            session_id,
            control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path),
        )

    def control_sync(self, node_id, session_id, action="START", path="."):
        """Send a SyncControl command to a node.

        ``action`` is one of "START", "STOP", "LOCK" or "UNLOCK". Any other
        value falls back to START_WATCHING, as before, but now logs a warning
        so that typos are visible.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return
        action_map = {
            "START": agent_pb2.SyncControl.START_WATCHING,
            "STOP": agent_pb2.SyncControl.STOP_WATCHING,
            "LOCK": agent_pb2.SyncControl.LOCK,
            "UNLOCK": agent_pb2.SyncControl.UNLOCK,
        }
        proto_action = action_map.get(action)
        if proto_action is None:
            print(f" [📁⚠️] Unknown sync action {action!r}; defaulting to START")
            proto_action = agent_pb2.SyncControl.START_WATCHING
        self._send_sync(
            node,
            session_id,
            control=agent_pb2.SyncControl(action=proto_action, path=path),
        )

    # ------------------------------------------------------------------
    # Task dispatch
    # ------------------------------------------------------------------
    def dispatch_single(self, node_id, cmd, timeout=30, session_id=None):
        """Dispatch a signed shell command to one node and wait up to ``timeout`` s.

        Returns the journal's result for the task, ``{"error": "Timeout"}``
        on timeout, or an "Offline" error dict when the node is unknown.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}
        tid = self._new_task_id("task")
        # 12-Factor Signing Logic
        sig = sign_payload(cmd)
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, payload_json=cmd, signature=sig, session_id=session_id))
        print(f"[📤] Dispatching shell {tid} to {node_id}")
        return self._run_task(node, node_id, tid, req, timeout)

    def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
        """Dispatch a signed browser action to one node and wait up to ``timeout`` s.

        The signature covers the action type name, URL and the action's own
        ``session_id``. Returns the journal's result for the task,
        ``{"error": "Timeout"}`` on timeout, or an "Offline" error dict when
        the node is unknown.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": f"Node {node_id} Offline"}
        tid = self._new_task_id("br")
        # Secure Browser Signing
        sig = sign_browser_action(
            agent_pb2.BrowserAction.ActionType.Name(action.action),
            action.url,
            action.session_id,
        )
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig, session_id=session_id))
        print(f"[🌐📤] Dispatching browser {tid} to {node_id}")
        return self._run_task(node, node_id, tid, req, timeout)