import time
import json
import os
import hashlib
import zlib
import logging
import shutil
import threading
from app.core.grpc.utils.crypto import sign_payload, sign_browser_action
from app.protos import agent_pb2
from app.db.session import get_db_session
from app.db.models import Session
logger = logging.getLogger(__name__)
class TaskAssistant:
"""The 'Brain' of the Orchestrator: High-Level AI API for Dispatching Tasks."""
def __init__(self, registry, journal, pool, mirror=None):
self.registry = registry
self.journal = journal
self.pool = pool
self.mirror = mirror
self.memberships = {} # session_id -> list(node_id)
self.membership_lock = threading.Lock()
def push_workspace(self, node_id, session_id):
"""Initial unidirectional push from server ghost mirror to a node."""
if not self.mirror: return
# 1. Ensure Server Mirror exists immediately
manifest = self.mirror.generate_manifest(session_id)
# 2. Track relationship for recovery/reconciliation
with self.membership_lock:
if session_id not in self.memberships:
self.memberships[session_id] = []
if node_id not in self.memberships[session_id]:
self.memberships[session_id].append(node_id)
# 3. If node is online, push actual data
node = self.registry.get_node(node_id)
if not node:
logger.info(f"[๐๐ค] Workspace {session_id} prepared on server for offline node {node_id}")
return
print(f"[๐๐ค] Initiating Workspace Push for Session {session_id} to {node_id}")
# Send Manifest to Node. The node will compare this with its local state
# and send back RECONCILE_REQUIRED for any files it is missing.
# This prevents the "Double Push" race where the server blasts data
# while the node is still trying to decide what it needs.
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
manifest=manifest
)
), priority=1)
# NOTE: Proactive parallel push removed. Manifest-driven reactive sync is cleaner.
def push_file(self, node_id, session_id, rel_path):
"""Pushes a specific file to a node (used for drift recovery)."""
node = self.registry.get_node(node_id)
if not node: return
workspace = self.mirror.get_workspace_path(session_id)
abs_path = os.path.join(workspace, rel_path)
if not os.path.exists(abs_path):
print(f" [๐โ] Requested file {rel_path} not found in mirror")
return
# Line-rate Optimization: 4MB chunks + No Software Throttling
hasher = hashlib.sha256()
file_size = os.path.getsize(abs_path)
try:
with open(abs_path, "rb") as f:
index = 0
while True:
chunk = f.read(4 * 1024 * 1024) # 4MB chunks (optimal for gRPC)
if not chunk: break
hasher.update(chunk)
offset = f.tell() - len(chunk)
is_final = f.tell() >= file_size
# Compress Chunk for transit
compressed_chunk = zlib.compress(chunk)
# Put into priority dispatcher (priority 2 for sync data)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
file_data=agent_pb2.FilePayload(
path=rel_path,
chunk=compressed_chunk,
chunk_index=index,
is_final=is_final,
hash=hasher.hexdigest() if is_final else "",
offset=offset,
compressed=True
)
)
), priority=2)
if is_final: break
index += 1
except Exception as e:
logger.error(f"[๐๐ค] Line-rate push error for {rel_path}: {e}")
def clear_workspace(self, node_id, session_id):
"""Sends a SyncControl command to purge the local sync directory on a node, and removes from active mesh."""
print(f" [๐๐งน] Instructing node {node_id} to purge workspace for session {session_id}")
if session_id in self.memberships and node_id in self.memberships[session_id]:
self.memberships[session_id].remove(node_id)
node = self.registry.get_node(node_id)
if not node: return
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE, path=".")
)
), priority=1)
def reconcile_node(self, node_id):
    """Forces a re-sync check for all sessions this node belongs to and purges dead sessions.

    Flow: rebuild the in-memory membership map from the DB (source of
    truth), drop memberships for sessions no longer active, then send the
    node a CLEANUP whitelist and re-push each active workspace manifest.
    Falls back to the in-memory map if the DB query fails.
    """
    print(f" [๐๐] Triggering Resync Check for {node_id}...")
    active_sessions = []
    try:
        with get_db_session() as db:
            # Only non-archived sessions with a sync workspace participate
            # in the mesh.
            sessions = db.query(Session).filter(
                Session.is_archived == False,
                Session.sync_workspace_id.isnot(None)
            ).all()
            with self.membership_lock:
                for s in sessions:
                    attached = s.attached_node_ids or []
                    if node_id in attached:
                        active_sessions.append(s.sync_workspace_id)
                        # Rebuild membership from DB truth (idempotent).
                        if s.sync_workspace_id not in self.memberships:
                            self.memberships[s.sync_workspace_id] = []
                        if node_id not in self.memberships[s.sync_workspace_id]:
                            self.memberships[s.sync_workspace_id].append(node_id)
            # Aggressive memory cleanup: Purge orphaned session memberships
            current_active_workspace_ids = {s.sync_workspace_id for s in sessions}
            with self.membership_lock:
                # Materialize the id list first: deleting while iterating
                # the dict would raise RuntimeError.
                to_purge = [sid for sid in self.memberships.keys() if sid not in current_active_workspace_ids]
                for sid in to_purge:
                    del self.memberships[sid]
    except Exception as e:
        print(f" [๐โ ๏ธ] Failed to fetch active sessions for node reconciliation: {e}")
        # Fallback to in-memory if DB fails
        with self.membership_lock:
            for sid, nodes in self.memberships.items():
                if node_id in nodes:
                    active_sessions.append(sid)
    # Send proactive cleanup payload with the active sessions whitelist
    node = self.registry.get_node(node_id)
    if node:
        node.send_message(agent_pb2.ServerTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id="global",
                control=agent_pb2.SyncControl(
                    action=agent_pb2.SyncControl.CLEANUP,
                    # Whitelist: node deletes any local workspace not listed.
                    request_paths=active_sessions
                )
            )
        ), priority=0)
        for sid in active_sessions:
            # Re-push manifest to trigger node-side drift check
            self.push_workspace(node_id, sid)
            # Add a small delay to prevent saturating the gRPC stream for multiple sessions
            time.sleep(0.5)
def broadcast_file_chunk(self, session_id: str, sender_node_id: str, file_payload):
"""Broadcasts a file chunk received from one node to all other nodes in the mesh."""
with self.membership_lock:
session_members = self.memberships.get(session_id, [])
destinations = [n for n in session_members if n != sender_node_id]
if destinations:
print(f" [๐๐ข] Broadcasting {file_payload.path} from {sender_node_id} to: {', '.join(destinations)}")
def _send_to_node(nid):
node = self.registry.get_node(nid)
if node:
# Forward the exact same FileSyncMessage (Priority 2 for Sync Data)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
file_data=file_payload
)
), priority=2)
# M6: Use registry executor if available for parallel mesh broadcast
if self.registry.executor:
for nid in destinations:
self.registry.executor.submit(_send_to_node, nid)
else:
for nid in destinations:
_send_to_node(nid)
def lock_workspace(self, node_id, session_id):
"""Disables user-side synchronization from a node during AI refactors."""
self.control_sync(node_id, session_id, action="LOCK")
def unlock_workspace(self, node_id, session_id):
"""Re-enables user-side synchronization from a node."""
self.control_sync(node_id, session_id, action="UNLOCK")
def request_manifest(self, node_id, session_id, path="."):
"""Requests a full directory manifest from a node for drift checking."""
node = self.registry.get_node(node_id)
if not node: return
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.REFRESH_MANIFEST, path=path)
)
), priority=1)
def control_sync(self, node_id, session_id, action="START", path="."):
"""Sends a SyncControl command to a node (e.g. START_WATCHING, LOCK)."""
node = self.registry.get_node(node_id)
if not node: return
action_map = {
"START": agent_pb2.SyncControl.START_WATCHING,
"STOP": agent_pb2.SyncControl.STOP_WATCHING,
"LOCK": agent_pb2.SyncControl.LOCK,
"UNLOCK": agent_pb2.SyncControl.UNLOCK,
"RESYNC": agent_pb2.SyncControl.RESYNC
}
proto_action = action_map.get(action, agent_pb2.SyncControl.START_WATCHING)
# Track for recovery & broadcast
if session_id not in self.memberships:
self.memberships[session_id] = []
if node_id not in self.memberships[session_id]:
self.memberships[session_id].append(node_id)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
control=agent_pb2.SyncControl(action=proto_action, path=path)
)
), priority=1)
# ==================================================================
# Modular FS Explorer / Mesh Navigation
# ==================================================================
def ls(self, node_id: str, path: str = ".", timeout=10, session_id="__fs_explorer__", force_remote: bool = False):
"""Requests a directory listing from a node (waits for response)."""
# Phase 1: Local Mirror Fast-Path
if session_id != "__fs_explorer__" and self.mirror and not force_remote:
workspace = self.mirror.get_workspace_path(session_id)
abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
if os.path.exists(abs_path) and os.path.isdir(abs_path):
files = []
try:
for entry in os.scandir(abs_path):
rel = os.path.relpath(entry.path, workspace)
files.append({
"path": rel,
"name": entry.name,
"is_dir": entry.is_dir(),
"size": entry.stat().st_size if entry.is_file() else 0,
"is_synced": True
})
return {"files": files, "path": path}
except Exception as e:
logger.error(f"[๐๐] Local ls error for {session_id}/{path}: {e}")
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
tid = f"fs-ls-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.LIST, path=path)
)
), priority=1)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# Proactive Mirroring: start fetching content so dots turn green
if res and "files" in res and session_id != "__fs_explorer__":
self._proactive_explorer_sync(node_id, res["files"], session_id)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def _proactive_explorer_sync(self, node_id, files, session_id):
"""Starts background tasks to mirror files to Hub so dots turn green."""
for f in files:
if f.get("is_dir"): continue
if not f.get("is_synced") and f.get("size", 0) < 1024 * 512: # Skip large files
# M6: Use shared registry executor instead of spawning loose threads
if self.registry.executor:
self.registry.executor.submit(self.cat, node_id, f["path"], 15, session_id)
def cat(self, node_id: str, path: str, timeout=15, session_id="__fs_explorer__", force_remote: bool = False):
"""Requests file content from a node (waits for result)."""
# Phase 1: Local Mirror Fast-Path
if session_id != "__fs_explorer__" and self.mirror and not force_remote:
workspace = self.mirror.get_workspace_path(session_id)
abs_path = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
if os.path.exists(abs_path) and os.path.isfile(abs_path):
try:
# Try reading as text
with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
return {"content": content, "path": path}
except Exception as e:
logger.error(f"[๐๐] Local cat error for {session_id}/{path}: {e}")
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
# For 'cat', we might get multiple chunks, but TaskJournal fulfill
# usually happens on the final chunk. We'll handle chunking in server.
tid = f"fs-cat-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.READ, path=path)
)
), priority=1)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
# res usually contains {content, path}. grpc_server already writes it to mirror.
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def write(self, node_id: str, path: str, content: bytes = b"", is_dir: bool = False, timeout=10, session_id="__fs_explorer__"):
"""Creates or updates a file/directory on a node (waits for status)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
# Phase 1: Sync local mirror ON HUB instantly (Zero Latency)
if self.mirror and session_id != "__fs_explorer__":
workspace_mirror = self.mirror.get_workspace_path(session_id)
dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
try:
if is_dir:
os.makedirs(dest, exist_ok=True)
else:
os.makedirs(os.path.dirname(dest), exist_ok=True)
with open(dest, "wb") as f:
f.write(content)
# Multi-node broadcast for sessions
targets = []
if session_id != "__fs_explorer__":
targets = self.memberships.get(session_id, [node_id])
else:
targets = [node_id]
print(f"[๐โ๏ธ] AI Write: {path} (Session: {session_id}) -> Dispatching to {len(targets)} nodes")
for target_nid in targets:
target_node = self.registry.get_node(target_nid)
if not target_node: continue
tid = f"fs-write-{int(time.time()*1000)}"
target_node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(
action=agent_pb2.SyncControl.WRITE,
path=path,
content=content,
is_dir=is_dir
)
)
), priority=2)
return {"success": True, "message": f"Synchronized to local mirror and dispatched to {len(targets)} nodes"}
except Exception as e:
logger.error(f"[๐โ๏ธ] Local mirror write error: {e}")
return {"error": str(e)}
# Legacy/Explorer path: await node confirmation
tid = f"fs-write-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(
action=agent_pb2.SyncControl.WRITE,
path=path,
content=content,
is_dir=is_dir
)
)
), priority=2)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def inspect_drift(self, node_id: str, path: str, session_id: str):
"""Returns a unified diff between Hub local mirror and Node's actual file."""
if not self.mirror: return {"error": "Mirror not available"}
# 1. Get Local Content
workspace = self.mirror.get_workspace_path(session_id)
local_abs = os.path.normpath(os.path.join(workspace, path.lstrip("/")))
local_content = ""
if os.path.exists(local_abs) and os.path.isfile(local_abs):
try:
with open(local_abs, 'r', encoding='utf-8', errors='ignore') as f:
local_content = f.read()
except: pass
# 2. Get Remote Content (Force Bypass Fast-Path)
print(f" [๐๐] Inspecting Drift: Fetching remote content for {path} on {node_id}")
remote_res = self.cat(node_id, path, session_id=session_id, force_remote=True)
if "error" in remote_res:
return {"error": f"Failed to fetch remote content: {remote_res['error']}"}
remote_content = remote_res.get("content", "")
# 3. Create Diff
import difflib
diff = difflib.unified_diff(
local_content.splitlines(keepends=True),
remote_content.splitlines(keepends=True),
fromfile=f"hub://{session_id}/{path}",
tofile=f"node://{node_id}/{path}"
)
diff_text = "".join(diff)
return {
"path": path,
"has_drift": local_content != remote_content,
"diff": diff_text,
"local_size": len(local_content),
"remote_size": len(remote_content)
}
def rm(self, node_id: str, path: str, timeout=10, session_id="__fs_explorer__"):
"""Deletes a file or directory on a node (waits for status)."""
node = self.registry.get_node(node_id)
if not node: return {"error": "Offline"}
# Phase 1: Sync local mirror ON HUB instantly
if self.mirror and session_id != "__fs_explorer__":
workspace_mirror = self.mirror.get_workspace_path(session_id)
dest = os.path.normpath(os.path.join(workspace_mirror, path.lstrip("/")))
try:
if os.path.isdir(dest):
shutil.rmtree(dest)
elif os.path.exists(dest):
os.remove(dest)
# Multi-node broadcast for sessions
targets = []
if session_id != "__fs_explorer__":
targets = self.memberships.get(session_id, [node_id])
else:
targets = [node_id]
print(f"[๐๐๏ธ] AI Remove: {path} (Session: {session_id}) -> Dispatching to {len(targets)} nodes")
for target_nid in targets:
target_node = self.registry.get_node(target_nid)
if not target_node: continue
tid = f"fs-rm-{int(time.time()*1000)}"
target_node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
)
), priority=2)
return {"success": True, "message": f"Removed from local mirror and dispatched delete to {len(targets)} nodes"}
except Exception as e:
logger.error(f"[๐๐๏ธ] Local mirror rm error: {e}")
return {"error": str(e)}
# Legacy/Explorer path: await node confirmation
tid = f"fs-rm-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
node.send_message(agent_pb2.ServerTaskMessage(
file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=tid,
control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=path)
)
), priority=2)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def dispatch_swarm(self, node_ids, cmd, timeout=120, session_id=None, no_abort=False):
"""Dispatches a command to multiple nodes in parallel and waits for all results."""
from concurrent.futures import ThreadPoolExecutor, as_completed
results = {}
with ThreadPoolExecutor(max_workers=max(1, len(node_ids))) as executor:
future_to_node = {
executor.submit(self.dispatch_single, nid, cmd, timeout, session_id, no_abort): nid
for nid in node_ids
}
# Use as_completed to avoid blocking on a slow node when others are finished
for future in as_completed(future_to_node):
node_id = future_to_node[future]
try:
results[node_id] = future.result()
except Exception as exc:
results[node_id] = {"error": str(exc)}
return results
def dispatch_single(self, node_id, cmd, timeout=120, session_id=None, no_abort=False):
"""Dispatches a shell command to a specific node."""
import uuid
node = self.registry.get_node(node_id)
if not node: return {"error": f"Node {node_id} Offline"}
# Use UUID to prevent timestamp collisions in high-speed swarm dispatch
tid = f"task-{uuid.uuid4().hex[:12]}"
event = self.journal.register(tid, node_id)
# 12-Factor Signing Logic
sig = sign_payload(cmd)
req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
task_id=tid, task_type="shell", payload_json=cmd, signature=sig, session_id=session_id,
timeout_ms=timeout * 1000))
logger.info(f"[๐ค] Dispatching shell {tid} to {node_id}")
self.registry.emit(node_id, "task_assigned", {"command": cmd, "session_id": session_id}, task_id=tid)
node.send_message(req, priority=1)
self.registry.emit(node_id, "task_start", {"command": cmd}, task_id=tid)
# Immediate peek if timeout is 0
if timeout == 0:
return {"status": "RUNNING", "stdout": "", "task_id": tid}
if event.wait(timeout):
res = self.journal.get_result(tid)
# pop only if fully done
if res.get("status") != "RUNNING":
self.journal.pop(tid)
return res
# M6: Timeout recovery.
if no_abort:
logger.info(f"[โณ] Shell task {tid} TIMEOUT (no_abort=True). Leaving alive on {node_id}.")
res = self.journal.get_result(tid) or {}
res["task_id"] = tid
res["status"] = "TIMEOUT_PENDING"
return res
logger.warning(f"[โ ๏ธ] Shell task {tid} TIMEOUT after {timeout}s on {node_id}. Sending ABORT.")
try:
node.send_message(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)), priority=0)
except: pass
# Return partial result captured in buffer before popping
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res if res else {"error": "Timeout", "stdout": "", "stderr": "", "status": "TIMEOUT", "task_id": tid}
def dispatch_browser(self, node_id, action, timeout=60, session_id=None):
"""Dispatches a browser action to a directed session node."""
node = self.registry.get_node(node_id)
if not node: return {"error": f"Node {node_id} Offline"}
tid = f"br-{int(time.time()*1000)}"
event = self.journal.register(tid, node_id)
# Secure Browser Signing
sig = sign_browser_action(
agent_pb2.BrowserAction.ActionType.Name(action.action),
action.url,
action.session_id
)
req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
task_id=tid, browser_action=action, signature=sig, session_id=session_id))
logger.info(f"[๐๐ค] Dispatching browser {tid} to {node_id}")
self.registry.emit(node_id, "task_assigned", {"browser_action": action.action, "url": action.url}, task_id=tid)
node.send_message(req, priority=1)
self.registry.emit(node_id, "task_start", {"browser_action": action.action}, task_id=tid)
if event.wait(timeout):
res = self.journal.get_result(tid)
self.journal.pop(tid)
return res
self.journal.pop(tid)
return {"error": "Timeout"}
def wait_for_swarm(self, task_map, timeout=30, no_abort=False):
"""Waits for multiple tasks (map of node_id -> task_id) in parallel."""
from concurrent.futures import ThreadPoolExecutor
results = {}
with ThreadPoolExecutor(max_workers=max(1, len(task_map))) as executor:
# item = (node_id, task_id)
future_to_node = {
executor.submit(self.wait_for_task, nid, tid, timeout, no_abort): nid
for nid, tid in task_map.items()
}
for fut in future_to_node:
nid = future_to_node[fut]
try: results[nid] = fut.result()
except Exception as e: results[nid] = {"error": str(e)}
return results
def wait_for_task(self, node_id, task_id, timeout=30, no_abort=False):
"""Waits for an existing task in the journal."""
# Check journal first
with self.journal.lock:
data = self.journal.tasks.get(task_id)
if not data:
return {"error": f"Task {task_id} not found in journal (finished or expired)", "status": "NOT_FOUND"}
event = data["event"]
# Immediate peek if timeout is 0 or event is already set
if timeout == 0 or event.is_set():
res = self.journal.get_result(task_id)
if res.get("status") != "RUNNING":
self.journal.pop(task_id)
return res
logger.info(f"[โณ] Re-waiting for task {task_id} on {node_id} for {timeout}s")
if event.wait(timeout):
res = self.journal.get_result(task_id)
if res.get("status") != "RUNNING":
self.journal.pop(task_id)
return res
if no_abort:
res = self.journal.get_result(task_id) or {}
res["task_id"] = task_id
res["status"] = "TIMEOUT_PENDING"
return res
logger.warning(f"[โ ๏ธ] Wait for task {task_id} TIMEOUT again. Sending ABORT.")
node = self.registry.get_node(node_id)
if node:
try: node.send_message(agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=task_id)), priority=0)
except: pass
res = self.journal.get_result(task_id)
self.journal.pop(task_id)
return res if res else {"error": "Timeout", "status": "TIMEOUT", "task_id": task_id}