import time
import json
import os
import hashlib
import logging
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2
logger = logging.getLogger(__name__)
class TaskAssistant:
"""The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""
    def __init__(self, registry, journal, pool, mirror=None):
        """Wire the assistant to its collaborators.

        Args:
            registry: Node registry used to resolve node_id -> live node handle
                and to emit task lifecycle events.
            journal: Task journal used to register tasks and await results.
            pool: Worker pool handle (stored for callers; not used directly here).
            mirror: Optional server-side workspace mirror; file-sync methods
                no-op or skip hub-side mirror updates when this is None.
        """
        self.registry = registry
        self.journal = journal
        self.pool = pool
        self.mirror = mirror
        # Mesh membership tracking: which nodes participate in each
        # session's file-sync mesh (used for recovery and broadcast).
        self.memberships = {} # session_id -> list(node_id)
def push_workspace(self, node_id, session_id):
"""Initial unidirectional push from server ghost mirror to a node."""
node = self.registry.get_node(node_id)
if not node or not self.mirror: return
print(f"[๐๐ค] Initiating Workspace Push for Session {session_id} to {node_id}")
# Track for recovery
if session_id not in self.memberships:
self.memberships[session_id] = []
if node_id not in self.memberships[session_id]:
self.memberships[session_id].append(node_id)
manifest = self.mirror.generate_manifest(session_id)
# 1. Send Manifest
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
manifest=manifest
)
))
# 2. Send File Data
for file_info in manifest.files:
if not file_info.is_dir:
self.push_file(node_id, session_id, file_info.path)
def push_file(self, node_id, session_id, rel_path):
"""Pushes a specific file to a node (used for drift recovery)."""
node = self.registry.get_node(node_id)
if not node: return
workspace = self.mirror.get_workspace_path(session_id)
abs_path = os.path.join(workspace, rel_path)
if not os.path.exists(abs_path):
print(f" [๐โ] Requested file {rel_path} not found in mirror")
return
with open(abs_path, "rb") as f:
full_data = f.read()
full_hash = hashlib.sha256(full_data).hexdigest()
f.seek(0)
index = 0
while True:
chunk = f.read(1024 * 1024) # 1MB chunks
is_final = len(chunk) < 1024 * 1024
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
file_data=agent_pb2.FilePayload(
path=rel_path,
chunk=chunk,
chunk_index=index,
is_final=is_final,
hash=full_hash if is_final else ""
)
)
))
if is_final or not chunk:
break
index += 1
def clear_workspace(self, node_id, session_id):
"""Sends a SyncControl command to purge the local sync directory on a node, and removes from active mesh."""
print(f" [๐๐งน] Instructing node {node_id} to purge workspace for session {session_id}")
if session_id in self.memberships and node_id in self.memberships[session_id]:
self.memberships[session_id].remove(node_id)
node = self.registry.get_node(node_id)
if not node: return
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE, path=".")
)
))
def reconcile_node(self, node_id):
"""Forces a re-sync check for all sessions this node belongs to and purges dead sessions."""
print(f" [๐๐] Triggering Resync Check for {node_id}...")
active_sessions = []
for sid, nodes in self.memberships.items():
if node_id in nodes:
active_sessions.append(sid)
# Send proactive cleanup payload with the active sessions whitelist
node = self.registry.get_node(node_id)
if node:
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id="global",
control=agent_pb2.SyncControl(
action=agent_pb2.SyncControl.CLEANUP,
request_paths=active_sessions
)
)
))
for sid in active_sessions:
# Re-push manifest to trigger node-side drift check
self.push_workspace(node_id, sid)
def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
"""Broadcasts a file chunk received from one node to all other nodes in the mesh."""
session_members = self.memberships.get(session_id, [])
destinations = [n for n in session_members if n != sender_node_id]
if destinations:
print(f" [๐๐ข] Broadcasting {file_payload.path} from {sender_node_id} to: {', '.join(destinations)}")
for node_id in destinations:
node = self.registry.get_node(node_id)
if not node:
continue
# Forward the exact same FileSyncMessage
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
file_data=file_payload
)
))
    def lock_workspace(self, node_id, session_id):
        """Disables user-side synchronization from a node during AI refactors.

        Thin convenience wrapper over control_sync with action="LOCK";
        also registers the node in the session mesh (see control_sync).
        """
        self.control_sync(node_id, session_id, action="LOCK")
    def unlock_workspace(self, node_id, session_id):
        """Re-enables user-side synchronization from a node.

        Thin convenience wrapper over control_sync with action="UNLOCK".
        """
        self.control_sync(node_id, session_id, action="UNLOCK")
def request_manifest(self, node_id, session_id, path="."):
"""Requests a full directory manifest from a node for drift checking."""
node = self.registry.get_node(node_id)
if not node: return
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path)
)
))
def control_sync(self, node_id, session_id, action="START", path="."):
"""Sends a SyncControl command to a node (e.g. START_WATCHING, LOCK)."""
node = self.registry.get_node(node_id)
if not node: return
action_map = {
"START": agent_pb2.SyncControl.START_WATCHING,
"STOP": agent_pb2.SyncControl.STOP_WATCHING,
"LOCK": agent_pb2.SyncControl.LOCK,
"UNLOCK": agent_pb2.SyncControl.UNLOCK
}
proto_action = action_map.get(action, agent_pb2.SyncControl.START_WATCHING)
# Track for recovery & broadcast
if session_id not in self.memberships:
self.memberships[session_id] = []
if node_id not in self.memberships[session_id]:
self.memberships[session_id].append(node_id)
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=proto_action, path=path)
)
))
# ==================================================================
# Modular FS Explorer / Mesh Navigation
# ==================================================================
def ls(self, node_id: str, path: str = ".", timeout=10, session_id="__fs_explorer__"):
"""Requests a directory listing from a node (waits for response)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
tid = f"fs-ls-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.LIST, path=path)
)
))
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# Proactive Mirroring: start fetching content so dots turn green
# (Only for user sessions, not for node management explorer)
if res and "files" in res and session_id != "__fs_explorer__":
self._proactive_explorer_sync(node_id, res["files"], session_id)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def _proactive_explorer_sync(self, node_id, files, session_id):
"""Starts background tasks to mirror files to Hub so dots turn green."""
import threading
for f in files:
if f.get("is_dir"): continue
if not f.get("is_synced") and f.get("size", 0) < 1024 * 512: # Skip large files
threading.Thread(target=self.cat, args=(node_id, f["path"], 15, session_id), daemon=True).start()
def cat(self, node_id: str, path: str, timeout=15, session_id="__fs_explorer__"):
"""Requests file content from a node (waits for result)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
# For 'cat', we might get multiple chunks, but TaskJournal fulfill
# usually happens on the final chunk. We'll handle chunking in server.
tid = f"fs-cat-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.READ, path=path)
)
))
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# res usually contains {content, path}. grpc_server already writes it to mirror.
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def write(self, node_id: str, path: str, content: bytes = b"", is_dir: bool = False, timeout=10, session_id="__fs_explorer__"):
"""Creates or updates a file/directory on a node (waits for status)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
tid = f"fs-write-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(
action=agent_pb2.SyncControl.WRITE,
path=path,
content=content,
is_dir=is_dir
)
)
))
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# M6: Update mirror locally on hub so ls sees it as synced (Only for real sessions)
if self.mirror and res.get("success") and session_id != "__fs_explorer__":
workspace_mirror = self.mirror.get_workspace_path(session_id)
dest = os.path.join(workspace_mirror, path)
if is_dir:
os.makedirs(dest, exist_ok=True)
else:
os.makedirs(os.path.dirname(dest), exist_ok=True)
with open(dest, "wb") as f:
f.write(content)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def rm(self, node_id: str, path: str, timeout=10, session_id="__fs_explorer__"):
"""Deletes a file or directory on a node (waits for status)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
tid = f"fs-rm-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.queue.put(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
)
))
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# M6: remove from mirror if successful (Only for real sessions)
if self.mirror and res.get("success") and session_id != "__fs_explorer__":
import shutil
dest = os.path.join(self.mirror.get_workspace_path(session_id), path)
if os.path.isdir(dest): shutil.rmtree(dest)
elif os.path.exists(dest): os.remove(dest)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def dispatch_swarm(self, node_ids, cmd, timeout=30, session_id=None, no_abort=False):
"""Dispatches a command to multiple nodes in parallel and waits for all results."""
from concurrent.futures import ThreadPoolExecutor
results = {}
with ThreadPoolExecutor(max_workers=len(node_ids)) as executor:
future_to_node = {
executor.submit(self.dispatch_single, nid, cmd, timeout, session_id, no_abort): nid
for nid in node_ids
}
for future in future_to_node:
node_id = future_to_node[future]
try:
results[node_id] = future.result()
except Exception as exc:
results[node_id] = {"error": str(exc)}
return results
def dispatch_single(self, node_id, cmd, timeout=30, session_id=None, no_abort=False):
"""Dispatches a shell command to a specific node."""
node = self.registry.get_node(node_id)
if not node: return {"error": f"Node {node_id} Offline"}
tid = f"task-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
# 12-Factor Signing Logic
sig = sign_payload(cmd)
req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
task_id=tid, payload_json=cmd, signature=sig, session_id=session_id,
timeout_ms=timeout * 1000))
logger.info(f"[๐ค] Dispatching shell {tid} to {node_id}")
self.registry.emit(node_id, "task_assigned", {"command": cmd, "session_id": session_id}, task_id=tid)
node.queue.put(req)
self.registry.emit(node_id, "task_start", {"command": cmd}, task_id=tid)
# Immediate peek if timeout is 0
if timeout == 0:
return {"status": "RUNNING", "stdout": "", "task_id": tid}
if event.wait(timeout):
res = self.journal.get_result(tid)
# pop only if fully done
if res.get("status") != "RUNNING":
self.journal.pop(tid)
return res
# M6: Timeout recovery.
if no_abort:
logger.info(f"[โณ] Shell task {tid} TIMEOUT (no_abort=True). Leaving alive on {node_id}.")
res = self.journal.get_result(tid) or {}
res["task_id"] = tid
res["status"] = "TIMEOUT_PENDING"
return res
logger.warning(f"[โ ๏ธ] Shell task {tid} TIMEOUT after {timeout}s on {node_id}. Sending ABORT.")
try:
node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)))
except: pass
# Return partial result captured in buffer before popping
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res if res else {"error": "Timeout", "stdout": "", "stderr": "", "status": "TIMEOUT", "task_id": tid}
def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
"""Dispatches a browser action to a directed session node."""
node = self.registry.get_node(node_id)
if not node: return {"error": f"Node {node_id} Offline"}
tid = f"br-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
# Secure Browser Signing
sig = sign_browser_action(
agent_pb2.BrowserAction.ActionType.Name(action.action),
action.url,
action.session_id
)
req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
task_id=tid, browser_action=action, signature=sig, session_id=session_id))
logger.info(f"[๐๐ค] Dispatching browser {tid} to {node_id}")
self.registry.emit(node_id, "task_assigned", {"browser_action": action.action, "url": action.url}, task_id=tid)
node.queue.put(req)
self.registry.emit(node_id, "task_start", {"browser_action": action.action}, task_id=tid)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def wait_for_swarm(self, task_map, timeout=30, no_abort=False):
"""Waits for multiple tasks (map of node_id -> task_id) in parallel."""
from concurrent.futures import ThreadPoolExecutor
results = {}
with ThreadPoolExecutor(max_workers=max(1, len(task_map))) as executor:
# item = (node_id, task_id)
future_to_node = {
executor.submit(self.wait_for_task, nid, tid, timeout, no_abort): nid
for nid, tid in task_map.items()
}
for fut in future_to_node:
nid = future_to_node[fut]
try: results[nid] = fut.result()
except Exception as e: results[nid] = {"error": str(e)}
return results
def wait_for_task(self, node_id, task_id, timeout=30, no_abort=False):
"""Waits for an existing task in the journal."""
# Check journal first
with self.journal.lock:
data = self.journal.tasks.get(task_id)
if not data:
return {"error": f"Task {task_id} not found in journal (finished or expired)", "status": "NOT_FOUND"}
event = data["event"]
# Immediate peek if timeout is 0 or event is already set
if timeout == 0 or event.is_set():
res = self.journal.get_result(task_id)
if res.get("status") != "RUNNING":
self.journal.pop(task_id)
return res
logger.info(f"[โณ] Re-waiting for task {task_id} on {node_id} for {timeout}s")
if event.wait(timeout):
res = self.journal.get_result(task_id)
if res.get("status") != "RUNNING":
self.journal.pop(task_id)
return res
if no_abort:
res = self.journal.get_result(task_id) or {}
res["task_id"] = task_id
res["status"] = "TIMEOUT_PENDING"
return res
logger.warning(f"[โ ๏ธ] Wait for task {task_id} TIMEOUT again. Sending ABORT.")
node = self.registry.get_node(node_id)
if node:
try: node.queue.put(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=task_id)))
except: pass
res = self.journal.get_result(task_id)
self.journal.pop(task_id)
return res if res else {"error": "Timeout", "status": "TIMEOUT", "task_id": task_id}