# cortex-hub / mesh-sdk / mesh_core / transport / grpc.py
# (The stray "Newer" / "Older" lines were page-navigation residue, not part of the module.)
import threading
import queue
import time
import logging
import urllib.request
import urllib.error
import json
import os
import itertools
from typing import Any, Optional, Callable, Union
from ..models import agent_pb2, agent_pb2_grpc
from .base import IMeshTransport, IMeshListener

logger = logging.getLogger(__name__)

class GrpcMeshTransport(IMeshTransport):
    """
    gRPC implementation of the Mesh Transport.
    Encapsulates the bidirectional stream and auto-reconnection logic.

    Two daemon threads may be started by an instance:

    * ``GrpcTransportStream`` (started by :meth:`connect`) drains ``send_queue``
      into ``stub.TaskStream`` and forwards every received message to the
      registered listener, reconnecting with a linear, capped backoff.
    * ``GrpcHealthStream`` (started lazily by :meth:`send_health`) drains
      ``health_queue`` into ``stub.ReportHealth``.

    Both threads stop once :meth:`close` sets the internal stop event.
    """

    # Seconds between transport-level keep-alive messages on the task stream.
    _KEEPALIVE_INTERVAL_S = 10.0
    # How long the outbound generators block on their queue before re-checking
    # the stop event.
    _QUEUE_POLL_TIMEOUT_S = 1.0
    # Upper bound (seconds) for the reconnect backoff of the task stream.
    _MAX_BACKOFF_S = 30
    # Pause (seconds) before re-opening a stream the server closed cleanly, so
    # a server that closes immediately cannot drive a busy reconnect loop.
    _CLEAN_CLOSE_DELAY_S = 1
    # Pause (seconds) before re-opening the health stream after an error.
    _HEALTH_RETRY_DELAY_S = 5

    def __init__(self, node_id: str, stub_factory: Callable[[], tuple], auth_token: str = "",
                 hub_http_url: str = "", secret_key: str = ""):
        """
        Args:
            node_id: Identifier announced to the hub for this node.
            stub_factory: Callable returning a ``(stub, channel)`` tuple.
            auth_token: Token sent in the registration request.
            hub_http_url: Base HTTP URL of the hub, used for token recovery.
                A trailing slash is stripped.
            secret_key: Sent as ``X-Agent-Token`` when recovering a token.
        """
        self.node_id = node_id
        self.stub_factory = stub_factory  # Callable returning (stub, channel)
        self.auth_token = auth_token
        self.hub_http_url = hub_http_url.rstrip("/")
        self.secret_key = secret_key
        self.listener = None
        self.stub = None
        self.channel = None
        self.send_queue = queue.PriorityQueue()
        self.health_queue = queue.Queue()
        self._stop_event = threading.Event()
        self._connected = False
        self._health_thread_started = False
        self.last_activity = 0
        # Monotonic sequence number used as tie-breaker in the priority queue;
        # avoids protobuf comparison in heapq when priorities are equal.
        self._send_counter = itertools.count()

    def _sync_configuration(self):
        """Sends a RegistrationRequest with the current auth_token and returns the response."""
        req = agent_pb2.RegistrationRequest(
            node_id=self.node_id,
            auth_token=self.auth_token,
            node_description="Portable Mesh Node"
        )
        return self.stub.SyncConfiguration(req)

    def handshake(self) -> bool:
        """
        Registers this node with the hub.

        Returns:
            True when the hub accepts the registration (directly, or after a
            successful token self-recovery); False on rejection or any error,
            including a failure to build the stub.
        """
        try:
            # Inside the try so a failing stub_factory yields False like any
            # other handshake failure instead of escaping to the caller.
            self._refresh_stub()
            res = self._sync_configuration()
            if res.success:
                logger.info(f"[Mesh] Handshake successful for {self.node_id}")
                # Optional: Handle policy res.policy
                return True

            logger.error(f"[Mesh] Handshake REJECTED for {self.node_id}: {res.error_message}")
            if self.hub_http_url and self.secret_key:
                logger.info(f"[Mesh] Attempting token self-recovery for {self.node_id}...")
                return self._try_token_recovery()
            return False
        except Exception as e:
            logger.error(f"[Mesh] Handshake FAILED for {self.node_id}: {e}")
            return False

    def _try_token_recovery(self) -> bool:
        """
        Fetches a fresh invite_token from the hub using the stable secret_key.

        On success the token replaces ``self.auth_token``, is written to the
        known config files, and the handshake is retried once.

        Returns:
            True only if the retried handshake succeeds; False otherwise
            (every exception is logged and swallowed).
        """
        from urllib.parse import quote

        try:
            # Percent-encode node_id so characters such as '&', '#' or spaces
            # cannot corrupt the query string.
            node_q = quote(str(self.node_id), safe="")
            url = f"{self.hub_http_url}/api/v1/agent/token-sync?node_id={node_q}"
            req = urllib.request.Request(url, headers={"X-Agent-Token": self.secret_key})
            with urllib.request.urlopen(req, timeout=10) as resp:
                data = json.loads(resp.read().decode())
                new_token = data.get("invite_token", "")
            if not new_token:
                logger.error("[Mesh] Token recovery: hub returned empty token")
                return False
            logger.info("[Mesh] Token recovery successful — updating auth_token")
            self.auth_token = new_token
            self._persist_token(new_token)
            # Retry handshake with fresh token
            res2 = self._sync_configuration()
            if res2.success:
                logger.info(f"[Mesh] Handshake successful after token recovery for {self.node_id}")
                return True
            logger.error(f"[Mesh] Handshake still rejected after token recovery: {res2.error_message}")
            return False
        except Exception as e:
            logger.error(f"[Mesh] Token recovery failed: {e}")
            return False

    def _persist_token(self, new_token: str):
        """
        Writes the recovered token to all known config file locations.

        Every existing candidate file has the value of each ``auth_token:`` /
        ``invite_token:`` line replaced by ``new_token``. Files without such a
        line are left untouched. Failures are logged per file and never raised.
        """
        import re

        candidates = []
        # Config file four directories above this module's directory.
        try:
            src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
            candidates.append(os.path.join(src_dir, "agent_config.yaml"))
        except Exception:
            pass
        # Ghost/user path
        candidates.append(os.path.expanduser("~/.cortex/agent-node/agent_config.yaml"))
        candidates.append(os.path.expanduser("~/.cortex/agent.yaml"))

        token_re = re.compile(r'^(\s*(?:auth_token|invite_token)\s*:\s*).*$', re.MULTILINE)
        for path in candidates:
            if not os.path.exists(path):
                continue
            try:
                with open(path, "r", encoding="utf-8") as f:
                    content = f.read()
                # A callable replacement keeps backslashes in the token literal.
                updated, replaced = token_re.subn(lambda m: m.group(1) + new_token, content)
                if replaced == 0:
                    # Nothing to update: do not rewrite the file or claim success.
                    logger.warning(f"[Mesh] No auth_token/invite_token entry in {path}; token not persisted there")
                    continue
                if updated != content:
                    with open(path, "w", encoding="utf-8") as f:
                        f.write(updated)
                logger.info(f"[Mesh] Persisted new token to {path}")
            except Exception as e:
                logger.warning(f"[Mesh] Could not persist token to {path}: {e}")

    def connect(self):
        """Clears the stop flag, builds a fresh stub and starts the task-stream thread."""
        self._stop_event.clear()
        self._refresh_stub()
        threading.Thread(target=self._run_stream, daemon=True, name="GrpcTransportStream").start()

    def set_listener(self, listener: IMeshListener):
        """Registers the listener that receives messages, errors and close events."""
        self.listener = listener

    def _refresh_stub(self):
        """Closes the current channel (best effort) and obtains a new (stub, channel) pair."""
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                # Best effort: a channel that fails to close must not block
                # the creation of its replacement.
                logger.debug(f"[Mesh] Ignoring error while closing old channel: {e}")
        self.stub, self.channel = self.stub_factory()

    def _outbound_messages(self):
        """
        Request iterator for the task stream.

        Yields the initial node announcement, then queued messages in priority
        order, interleaved with a keep-alive at least every
        ``_KEEPALIVE_INTERVAL_S`` seconds. Ends once the stop event is set.
        """
        # Initial announcement
        yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))

        last_heartbeat = time.monotonic()
        while not self._stop_event.is_set():
            try:
                # PriorityQueue returns (priority, seq, msg)
                item = self.send_queue.get(timeout=self._QUEUE_POLL_TIMEOUT_S)
                yield item[2]
            except queue.Empty:
                pass

            # Transport-level KeepAlive
            if time.monotonic() - last_heartbeat >= self._KEEPALIVE_INTERVAL_S:
                yield agent_pb2.ClientTaskMessage(
                    skill_event=agent_pb2.SkillEvent(keep_alive=True)
                )
                last_heartbeat = time.monotonic()

    def _notify_error(self, error: Exception):
        """Forwards a stream error to the listener without letting a listener failure kill the stream thread."""
        if not self.listener:
            return
        try:
            self.listener.on_error(error)
        except Exception as cb_err:
            logger.error(f"[Mesh] Listener on_error callback failed: {cb_err}")

    def _run_stream(self):
        """
        Thread body: keeps the bidirectional task stream open until stopped.

        Received messages update ``last_activity`` and are passed to the
        listener. On error the stub is rebuilt and the stream re-opened after
        ``min(_MAX_BACKOFF_S, 2 * attempt)`` seconds.
        """
        retry_count = 0
        while not self._stop_event.is_set():
            opened_at = time.monotonic()
            received_any = False
            try:
                logger.info(f"[*] Opening gRPC TaskStream for {self.node_id}...")
                responses = self.stub.TaskStream(self._outbound_messages())
                self._connected = True
                self.last_activity = time.time()
                logger.info(f"[✅] gRPC Mesh Transport Online for {self.node_id}")

                for msg in responses:
                    if not received_any:
                        # NOTE(review): the streaming call above is expected to
                        # return before the connection is proven, so the retry
                        # counter is only reset once the server actually answers.
                        received_any = True
                        retry_count = 0
                    self.last_activity = time.time()
                    if self.listener:
                        self.listener.on_message(msg)

                self._connected = False
                logger.warning(f"[📶] gRPC Stream closed by server for {self.node_id}")
                # Short, interruptible pause so an immediately-closing server
                # cannot cause a hot reconnect loop.
                self._stop_event.wait(self._CLEAN_CLOSE_DELAY_S)
            except Exception as e:
                self._connected = False
                self._notify_error(e)

                if self._stop_event.is_set():
                    break

                # A stream that stayed up for a while counts as healthy even
                # if the server never sent anything.
                if time.monotonic() - opened_at >= self._MAX_BACKOFF_S:
                    retry_count = 0
                retry_count += 1
                backoff = min(self._MAX_BACKOFF_S, 2 * retry_count)
                logger.error(f"[❌] gRPC Stream Error ({type(e).__name__}): {e}. Reconnecting in {backoff}s... (Attempt {retry_count})")
                try:
                    self._refresh_stub()
                except Exception as refresh_err:
                    # Keep the reconnect loop alive; the next attempt retries.
                    logger.error(f"[Mesh] Could not rebuild gRPC stub: {refresh_err}")
                # Event.wait instead of sleep so close() ends the backoff early.
                self._stop_event.wait(backoff)

    def send(self, message: Any, priority: int = 1):
        """Queues a message for the task stream; lower ``priority`` values are sent first."""
        self.send_queue.put((priority, next(self._send_counter), message))

    def send_health(self, heartbeat: Any):
        """Sends a heartbeat via the dedicated health stream."""
        if not self._health_thread_started:
            self._start_health_stream()
        self.health_queue.put(heartbeat)

    def _start_health_stream(self):
        """Starts the health-stream thread unless one is running or the transport is stopped."""
        if self._health_thread_started or self._stop_event.is_set():
            return
        self._health_thread_started = True
        threading.Thread(target=self._run_health_stream, daemon=True, name="GrpcHealthStream").start()

    def _health_messages(self):
        """Request iterator for the health stream: yields queued heartbeats until stopped."""
        while not self._stop_event.is_set():
            try:
                yield self.health_queue.get(timeout=self._QUEUE_POLL_TIMEOUT_S)
            except queue.Empty:
                pass

    def _run_health_stream(self):
        """
        Thread body: keeps the health stream open until the stop event is set.

        Responses are consumed and discarded. The started-flag is cleared on
        exit so a later ``send_health`` (after ``close()`` + ``connect()``) can
        start a new health thread.
        """
        try:
            while not self._stop_event.is_set():
                try:
                    if not self.stub:
                        self._stop_event.wait(1)
                        continue

                    logger.info("[*] Opening gRPC HealthStream...")
                    responses = self.stub.ReportHealth(self._health_messages())
                    for _res in responses:
                        # Optional: Handle HealthCheckResponse (server_time_ms)
                        pass
                    # Stream ended without error: pause before re-opening.
                    self._stop_event.wait(self._CLEAN_CLOSE_DELAY_S)
                except Exception as e:
                    if self._stop_event.is_set():
                        # Expected while shutting down (channel was closed).
                        break
                    logger.error(f"[Mesh] Health Stream Error: {e}")
                    self._stop_event.wait(self._HEALTH_RETRY_DELAY_S)
        finally:
            self._health_thread_started = False

    def close(self):
        """Stops both streams, closes the channel and notifies the listener."""
        self._stop_event.set()
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                # Still mark the transport closed and notify the listener.
                logger.warning(f"[Mesh] Error while closing gRPC channel: {e}")
        self._connected = False
        if self.listener:
            self.listener.on_close()

    def is_connected(self) -> bool:
        """Returns True while the task stream is considered open."""
        return self._connected

class GrpcServerTransport(IMeshTransport):
    """
    gRPC implementation of IMeshTransport for the server side.
    Wraps a single bi-directional stream context.

    Outgoing messages are only placed on ``send_queue``; nothing in this class
    drains the queue, so the code that owns the stream is expected to do so.
    """
    def __init__(self, context: Any, signer: Optional[Callable[[bytes], str]] = None):
        """
        Args:
            context: Stream context object; this class calls ``abort`` and
                ``is_active`` on it.
            signer: Optional callable mapping the serialized message bytes to
                a signature string. When given, messages that have a
                ``signature`` attribute are signed before being queued.
        """
        self.context = context
        self.send_queue = queue.PriorityQueue()
        self.listener = None
        self.signer = signer
        # Sequence number used as tie-breaker so queued messages with equal
        # priority are never compared with each other.
        self._send_counter = itertools.count()

    def handshake(self) -> bool:
        """Always succeeds; the handshake is handled by the Servicer."""
        return True  # Handshake handled by Servicer

    def connect(self):
        """No-op: the stream already exists when this transport is created."""
        pass

    def set_listener(self, listener: IMeshListener):
        """Stores the listener on the transport."""
        self.listener = listener

    def send(self, message: Any, priority: int = 1):
        """
        Signs the message (when possible) and queues it for sending.

        Args:
            message: Message to send. If a signer is configured and the
                message has a ``signature`` attribute, the signature is
                computed over the deterministic serialization of the message
                with an empty signature field.
            priority: Lower values are dequeued first.
        """
        if self.signer and hasattr(message, 'signature'):
            # Clear first so the signature never covers a previous signature.
            message.signature = ""
            msg_bytes = message.SerializeToString(deterministic=True)
            message.signature = self.signer(msg_bytes)

        self.send_queue.put((priority, next(self._send_counter), message))

    def close(self):
        """Aborts the stream with status CANCELLED; errors are logged, never raised."""
        try:
            import grpc
            self.context.abort(grpc.StatusCode.CANCELLED, "Transport closed")
        except Exception as e:
            # NOTE(review): abort() is expected to raise in order to terminate
            # the RPC, so this is the normal path. Only Exception is caught so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            logger.debug(f"[Mesh] Server transport close: {type(e).__name__}: {e}")

    def is_connected(self) -> bool:
        """Returns whether the underlying stream context is still active."""
        return self.context.is_active()