import threading
import queue
import time
import logging
import urllib.request
import urllib.error
import json
import os
import itertools
from typing import Any, Optional, Callable, Union
from ..models import agent_pb2, agent_pb2_grpc
from .base import IMeshTransport, IMeshListener
logger = logging.getLogger(__name__)
class GrpcMeshTransport(IMeshTransport):
    """
    gRPC implementation of the Mesh Transport.

    Encapsulates the bidirectional TaskStream with auto-reconnection and
    capped linear backoff, a dedicated lazily-started health stream, and a
    one-shot token self-recovery path that fetches a fresh invite token
    over HTTP when the hub rejects the registration handshake.
    """

    def __init__(self, node_id: str, stub_factory: Callable[[], tuple], auth_token: str = "",
                 hub_http_url: str = "", secret_key: str = ""):
        """
        :param node_id: Unique identifier announced to the hub.
        :param stub_factory: Callable returning a fresh (stub, channel) pair.
        :param auth_token: Invite token presented during the handshake.
        :param hub_http_url: Base HTTP URL of the hub, used for token recovery.
        :param secret_key: Stable agent secret used to fetch a new invite token.
        """
        self.node_id = node_id
        self.stub_factory = stub_factory  # Callable returning (stub, channel)
        self.auth_token = auth_token
        self.hub_http_url = hub_http_url.rstrip("/")
        self.secret_key = secret_key
        self.listener = None
        self.stub = None
        self.channel = None
        self.send_queue = queue.PriorityQueue()
        self.health_queue = queue.Queue()
        self._stop_event = threading.Event()
        self._connected = False
        self._health_thread_started = False
        self._health_thread = None
        self.last_activity = 0
        # Thread-safe monotonic counter used as a tie-breaker in the
        # PriorityQueue tuples, so heapq never tries to compare protobuf
        # messages with equal priorities.
        self._send_counter = itertools.count()

    def handshake(self, _is_retry: bool = False) -> Any:
        """
        Performs the registration handshake via SyncConfiguration.

        Returns the hub policy on success, or False on rejection/failure
        (callers rely on this mixed return, so it is kept as-is).

        On rejection, attempts token self-recovery and retries the
        handshake exactly once. ``_is_retry`` guards the recursive retry:
        without it, a hub that keeps rejecting while recovery keeps
        "succeeding" would cause unbounded recursion.
        """
        self._refresh_stub()
        try:
            req = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=self.auth_token,
                node_description="Portable Mesh Node"
            )
            res = self.stub.SyncConfiguration(req)
            if res.success:
                logger.info(f"[Mesh] Handshake successful for {self.node_id}")
                return res.policy
            else:
                logger.error(f"[Mesh] Handshake REJECTED for {self.node_id}: {res.error_message}")
                if not _is_retry and self.hub_http_url and self.secret_key:
                    logger.info(f"[Mesh] Attempting token self-recovery for {self.node_id}...")
                    if self._try_token_recovery():
                        # Retry the handshake (once) to obtain the policy.
                        return self.handshake(_is_retry=True)
                return False
        except Exception as e:
            logger.error(f"[Mesh] Handshake FAILED for {self.node_id}: {e}")
            return False

    def _try_token_recovery(self) -> bool:
        """
        Fetches a fresh invite_token from the hub using the stable secret_key,
        persists it, and verifies it with an immediate SyncConfiguration call.

        :return: True if a new token was obtained AND the verification
                 handshake succeeded; False otherwise.
        """
        try:
            url = f"{self.hub_http_url}/api/v1/agent/token-sync?node_id={self.node_id}"
            req = urllib.request.Request(url, headers={"X-Agent-Token": self.secret_key})
            with urllib.request.urlopen(req, timeout=10) as resp:
                data = json.loads(resp.read().decode())
            new_token = data.get("invite_token", "")
            if not new_token:
                logger.error("[Mesh] Token recovery: hub returned empty token")
                return False
            logger.info("[Mesh] Token recovery successful — updating auth_token")
            self.auth_token = new_token
            self._persist_token(new_token)
            # Verify the fresh token works before reporting success.
            req2 = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=self.auth_token,
                node_description="Portable Mesh Node"
            )
            res2 = self.stub.SyncConfiguration(req2)
            if res2.success:
                logger.info(f"[Mesh] Handshake successful after token recovery for {self.node_id}")
                return True
            logger.error(f"[Mesh] Handshake still rejected after token recovery: {res2.error_message}")
            return False
        except Exception as e:
            logger.error(f"[Mesh] Token recovery failed: {e}")
            return False

    def _persist_token(self, new_token: str):
        """
        Writes the recovered token to all known config file locations.

        Best-effort: a failure on one candidate path is logged and does not
        stop the others from being updated.
        """
        candidates = []
        # Authoritative path: next to src/ (two levels above this install)
        try:
            src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
            candidates.append(os.path.join(src_dir, "agent_config.yaml"))
        except Exception:
            pass
        # Ghost/user path
        candidates.append(os.path.expanduser("~/.cortex/agent-node/agent_config.yaml"))
        candidates.append(os.path.expanduser("~/.cortex/agent.yaml"))
        import re
        token_re = re.compile(r'^(\s*(?:auth_token|invite_token)\s*:\s*).*$', re.MULTILINE)
        for path in candidates:
            if not os.path.exists(path):
                continue
            try:
                with open(path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Lambda replacement avoids backslash/group-ref expansion of
                # the token text that a plain sub() pattern string would do.
                updated = token_re.sub(lambda m: m.group(1) + new_token, content)
                with open(path, "w", encoding="utf-8") as f:
                    f.write(updated)
                logger.info(f"[Mesh] Persisted new token to {path}")
            except Exception as e:
                logger.warning(f"[Mesh] Could not persist token to {path}: {e}")

    def connect(self):
        """Starts the background stream thread (daemon; dies with the process)."""
        self._stop_event.clear()
        self._refresh_stub()
        threading.Thread(target=self._run_stream, daemon=True, name="GrpcTransportStream").start()

    def set_listener(self, listener: IMeshListener):
        self.listener = listener

    def _refresh_stub(self):
        """Closes the current channel (if any) and creates a fresh stub/channel pair."""
        if self.channel:
            try:
                self.channel.close()
            except Exception:
                # Best-effort close of a possibly already-dead channel.
                pass
        self.stub, self.channel = self.stub_factory()

    def _run_stream(self):
        """
        Main stream loop: opens the bidirectional TaskStream, dispatches
        inbound messages to the listener, and reconnects with a capped
        linear backoff (2s per attempt, max 30s) on any error.
        """
        retry_count = 0
        while not self._stop_event.is_set():
            try:
                def _gen():
                    # Initial announcement
                    yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))
                    last_heartbeat = time.time()
                    while not self._stop_event.is_set():
                        try:
                            # PriorityQueue returns (priority, seq, msg)
                            item = self.send_queue.get(timeout=1.0)
                            yield item[2]
                        except queue.Empty:
                            pass
                        # Transport-level KeepAlive
                        if time.time() - last_heartbeat >= 10.0:
                            yield agent_pb2.ClientTaskMessage(
                                skill_event=agent_pb2.SkillEvent(keep_alive=True)
                            )
                            last_heartbeat = time.time()
                logger.info(f"[*] Opening gRPC TaskStream for {self.node_id}...")
                responses = self.stub.TaskStream(_gen())
                self._connected = True
                self.last_activity = time.time()
                logger.info(f"[✅] gRPC Mesh Transport Online for {self.node_id}")
                retry_count = 0
                for msg in responses:
                    self.last_activity = time.time()
                    if self.listener:
                        self.listener.on_message(msg)
                # Iterator exhausted without raising: server closed the stream.
                self._connected = False
                logger.warning(f"[📶] gRPC Stream closed by server for {self.node_id}")
            except Exception as e:
                self._connected = False
                if self.listener:
                    self.listener.on_error(e)
                if not self._stop_event.is_set():
                    retry_count += 1
                    backoff = min(30, 2 * retry_count)
                    logger.error(f"[❌] gRPC Stream Error ({type(e).__name__}): {e}. Reconnecting in {backoff}s... (Attempt {retry_count})")
                    self._refresh_stub()
                    time.sleep(backoff)

    def send(self, message: Any, priority: int = 1):
        """Enqueues a message for the stream; lower priority values are sent first."""
        self.send_queue.put((priority, next(self._send_counter), message))

    def send_health(self, heartbeat: Any):
        """Sends a heartbeat via the dedicated health stream (started lazily)."""
        self._start_health_stream()
        self.health_queue.put(heartbeat)

    def _start_health_stream(self):
        """Idempotently starts the health stream thread if it is not running."""
        if self._health_thread is not None and self._health_thread.is_alive():
            return
        self._health_thread_started = True
        self._health_thread = threading.Thread(target=self._run_health_stream, daemon=True, name="GrpcHealthStream")
        self._health_thread.start()

    def _run_health_stream(self):
        """Health stream loop: forwards queued heartbeats; retries every 5s on error."""
        while not self._stop_event.is_set():
            try:
                if not self.stub:
                    # Stub not ready yet (connect() races the first heartbeat).
                    time.sleep(1)
                    continue
                def _gen():
                    while not self._stop_event.is_set():
                        try:
                            hb = self.health_queue.get(timeout=1.0)
                            yield hb
                        except queue.Empty:
                            pass
                logger.info("[*] Opening gRPC HealthStream...")
                responses = self.stub.ReportHealth(_gen())
                for res in responses:
                    # Optional: Handle HealthCheckResponse (server_time_ms)
                    pass
            except Exception as e:
                logger.error(f"[Mesh] Health Stream Error: {e}")
                time.sleep(5)
        self._health_thread_started = False

    def close(self):
        """Stops all stream threads, closes the channel, and notifies the listener."""
        self._stop_event.set()
        self._health_thread_started = False
        if self.channel:
            self.channel.close()
        self._connected = False
        if self.listener:
            self.listener.on_close()

    def is_connected(self) -> bool:
        return self._connected
class GrpcServerTransport(IMeshTransport):
    """
    gRPC implementation of IMeshTransport for the server side.

    Wraps a single bi-directional stream context. Outbound messages are
    optionally signed before being placed on the priority queue.
    """

    def __init__(self, context: Any, signer: Optional[Callable[[bytes], str]] = None):
        """
        :param context: The gRPC servicer stream context for this peer.
        :param signer: Optional callable mapping serialized message bytes
                       to a signature string.
        """
        self.context = context
        self.send_queue = queue.PriorityQueue()
        self.listener = None
        self.signer = signer
        # Tie-breaker so heapq never compares protobuf messages directly.
        self._send_counter = itertools.count()

    def handshake(self) -> bool:
        return True  # Handshake handled by Servicer

    def connect(self):
        pass

    def set_listener(self, listener: IMeshListener):
        self.listener = listener

    def send(self, message: Any, priority: int = 1):
        """
        Signs (if a signer is configured) and enqueues a message.

        The signature field is cleared BEFORE deterministic serialization so
        the signed bytes never include a stale signature, then set afterwards.
        """
        if self.signer:
            if hasattr(message, 'signature'):
                message.signature = ""
            msg_bytes = message.SerializeToString(deterministic=True)
            message.signature = self.signer(msg_bytes)
        self.send_queue.put((priority, next(self._send_counter), message))

    def close(self):
        """Aborts the stream context; best-effort (context may already be dead)."""
        try:
            import grpc
            self.context.abort(grpc.StatusCode.CANCELLED, "Transport closed")
        except Exception:
            # Narrowed from a bare except: never swallow SystemExit/KeyboardInterrupt.
            pass

    def is_connected(self) -> bool:
        return self.context.is_active()