Newer
Older
cortex-hub / mesh-sdk / mesh_core / transport / grpc.py
@yangyang xie yangyang xie 17 days ago 7 KB refactor done
import itertools
import logging
import queue
import threading
import time
from typing import Any, Optional, Callable, Union

from ..models import agent_pb2, agent_pb2_grpc
from .base import IMeshTransport, IMeshListener

logger = logging.getLogger(__name__)

class GrpcMeshTransport(IMeshTransport):
    """
    gRPC implementation of the Mesh Transport (client side).
    Encapsulates the bidirectional stream and auto-reconnection logic.

    Background daemon threads:
      * ``GrpcTransportStream`` - started by :meth:`connect`; runs the
        ``TaskStream`` RPC, draining ``send_queue`` outbound and forwarding
        inbound messages to the listener, reconnecting on errors.
      * ``GrpcHealthStream`` - started lazily by :meth:`send_health`; runs the
        ``ReportHealth`` RPC fed from ``health_queue``.
    Both loops exit once :meth:`close` sets the stop event.
    """

    # Seconds between transport-level keep-alive messages on the TaskStream.
    _KEEPALIVE_INTERVAL = 10.0
    # Reconnect backoff grows linearly by this many seconds per consecutive failure...
    _BACKOFF_STEP = 2
    # ...and is capped at this many seconds.
    _MAX_BACKOFF = 30
    # Pause before reopening a stream the server closed cleanly, so a server that
    # ends streams immediately cannot drive a hot reconnect loop.
    _REOPEN_DELAY = 1.0
    # Pause before reopening the health stream after an error.
    _HEALTH_RETRY_DELAY = 5

    def __init__(self, node_id: str, stub_factory: Callable[[], tuple], auth_token: str = ""):
        """
        Args:
            node_id: Identifier announced on the TaskStream and sent at registration.
            stub_factory: Zero-argument callable returning a ``(stub, channel)`` pair;
                invoked again every time the connection is rebuilt.
            auth_token: Token sent in the ``RegistrationRequest`` during :meth:`handshake`.
        """
        self.node_id = node_id
        self.stub_factory = stub_factory  # Callable returning (stub, channel)
        self.auth_token = auth_token
        self.listener = None
        self.stub = None
        self.channel = None
        # Entries are (priority, seq, message); lower priority values are dequeued
        # first. The monotonically increasing seq keeps FIFO order within a priority
        # and guarantees tuple comparison never falls through to the protobuf
        # message itself (messages are not orderable -> TypeError).
        self.send_queue = queue.PriorityQueue()
        self._send_seq = itertools.count()
        self.health_queue = queue.Queue()
        self._stop_event = threading.Event()
        self._connected = False
        # Guarded by _health_lock: it is set by callers of send_health() and
        # cleared by the health thread itself when it exits.
        self._health_thread_started = False
        self._health_lock = threading.Lock()
        self.last_activity = 0

    def handshake(self) -> bool:
        """
        Register this node with the server via the ``SyncConfiguration`` RPC.

        A fresh stub/channel is built first. Every failure - including a failure
        of ``stub_factory`` - is logged and reported as ``False`` rather than raised.

        Returns:
            True if the server reported success, False otherwise.
        """
        try:
            self._refresh_stub()
            req = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=self.auth_token,
                node_description="Portable Mesh Node",
            )
            res = self.stub.SyncConfiguration(req)
            if res.success:
                logger.info("[Mesh] Handshake successful for %s", self.node_id)
                # Optional: Handle policy res.policy
                return True
            logger.error("[Mesh] Handshake REJECTED for %s: %s", self.node_id, res.error_message)
            return False
        except Exception as e:
            logger.error("[Mesh] Handshake FAILED for %s: %s", self.node_id, e)
            return False

    def connect(self):
        """Start the background TaskStream thread on a freshly built stub/channel."""
        self._stop_event.clear()
        self._refresh_stub()
        threading.Thread(target=self._run_stream, daemon=True, name="GrpcTransportStream").start()

    def set_listener(self, listener: IMeshListener):
        """Register the listener that receives inbound messages, errors and close events."""
        self.listener = listener

    def _refresh_stub(self):
        """Close the current channel (best effort) and build a new (stub, channel) pair.

        Raises whatever ``stub_factory`` raises.
        """
        if self.channel is not None:
            try:
                self.channel.close()
            except Exception as e:
                # A channel that fails to close must not prevent reconnecting.
                logger.debug("[Mesh] Ignoring error while closing old channel: %s", e)
        self.stub, self.channel = self.stub_factory()

    def _rebuild_stub_quietly(self):
        """Rebuild the stub for the next reconnect attempt, logging instead of raising.

        Raising here would escape the reconnect loop's ``except`` handler and kill
        the stream thread; a stale stub simply fails again on the next attempt.
        """
        try:
            self._refresh_stub()
        except Exception as e:
            logger.error("[Mesh] Failed to rebuild gRPC stub for %s: %s", self.node_id, e)

    def _notify_error(self, error: Exception):
        """Forward *error* to the listener without letting a listener failure kill the thread."""
        if not self.listener:
            return
        try:
            self.listener.on_error(error)
        except Exception:
            logger.exception("[Mesh] Listener on_error handler raised")

    def _outbound_messages(self):
        """Yield outbound TaskStream messages until the stop event is set.

        Sends the initial ``NodeAnnounce`` first, then drains ``send_queue`` in
        priority order, interleaving a ``SkillEvent(keep_alive=True)`` whenever
        ``_KEEPALIVE_INTERVAL`` seconds have passed since the previous one.
        """
        # Initial announcement
        yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))

        last_heartbeat = time.time()
        while not self._stop_event.is_set():
            try:
                # Queue entries are (priority, seq, message).
                item = self.send_queue.get(timeout=1.0)
            except queue.Empty:
                pass
            else:
                yield item[2]

            # Transport-level KeepAlive
            if time.time() - last_heartbeat >= self._KEEPALIVE_INTERVAL:
                yield agent_pb2.ClientTaskMessage(
                    skill_event=agent_pb2.SkillEvent(keep_alive=True)
                )
                last_heartbeat = time.time()

    def _run_stream(self):
        """
        Keep the TaskStream open until :meth:`close` is called.

        A clean server-side close is retried after ``_REOPEN_DELAY`` seconds. On
        an error the listener is notified, a fresh stub/channel is built, and the
        stream is retried after a linear backoff (2s, 4s, 6s, ... capped at
        ``_MAX_BACKOFF``). The attempt counter resets once a stream is reopened.
        Waits use the stop event so :meth:`close` interrupts them immediately.
        """
        retry_count = 0
        while not self._stop_event.is_set():
            try:
                logger.info("[*] Opening gRPC TaskStream for %s...", self.node_id)
                responses = self.stub.TaskStream(self._outbound_messages())
                self._connected = True
                self.last_activity = time.time()
                logger.info("[✅] gRPC Mesh Transport Online for %s", self.node_id)
                retry_count = 0

                for msg in responses:
                    self.last_activity = time.time()
                    if self.listener:
                        self.listener.on_message(msg)

                self._connected = False
                logger.warning("[📶] gRPC Stream closed by server for %s", self.node_id)
                self._stop_event.wait(self._REOPEN_DELAY)
            except Exception as e:
                self._connected = False
                self._notify_error(e)
                if self._stop_event.is_set():
                    break
                retry_count += 1
                backoff = min(self._MAX_BACKOFF, self._BACKOFF_STEP * retry_count)
                logger.error(
                    "[❌] gRPC Stream Error (%s): %s. Reconnecting in %ss... (Attempt %d)",
                    type(e).__name__, e, backoff, retry_count,
                )
                self._rebuild_stub_quietly()
                self._stop_event.wait(backoff)

    def send(self, message: Any, priority: int = 1):
        """Queue *message* for the TaskStream; lower *priority* values are sent first.

        Note: Client messages are not currently signed by the SDK,
        but this is where we would add it if needed.
        """
        self.send_queue.put((priority, next(self._send_seq), message))

    def send_health(self, heartbeat: Any):
        """Sends a heartbeat via the dedicated health stream, starting it if needed."""
        self._start_health_stream()
        self.health_queue.put(heartbeat)

    def _start_health_stream(self):
        """Start the health-stream thread unless one is already running."""
        with self._health_lock:
            if self._health_thread_started:
                return
            self._health_thread_started = True
        threading.Thread(target=self._run_health_stream, daemon=True, name="GrpcHealthStream").start()

    def _health_messages(self):
        """Yield queued heartbeats until the stop event is set."""
        while not self._stop_event.is_set():
            try:
                hb = self.health_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            yield hb

    def _run_health_stream(self):
        """
        Keep the ReportHealth stream open until :meth:`close` is called.

        Waits for a stub to exist, then reopens the stream after a clean close
        (``_REOPEN_DELAY``) or an error (``_HEALTH_RETRY_DELAY``). On exit the
        started flag is cleared so a later :meth:`send_health` (for example after
        ``close()`` + ``connect()``) can start a new health thread.
        """
        try:
            while not self._stop_event.is_set():
                if not self.stub:
                    self._stop_event.wait(1)
                    continue
                try:
                    logger.info("[*] Opening gRPC HealthStream...")
                    responses = self.stub.ReportHealth(self._health_messages())
                    for _res in responses:
                        # Optional: Handle HealthCheckResponse (server_time_ms)
                        pass
                    self._stop_event.wait(self._REOPEN_DELAY)
                except Exception as e:
                    logger.error("[Mesh] Health Stream Error: %s", e)
                    self._stop_event.wait(self._HEALTH_RETRY_DELAY)
        finally:
            with self._health_lock:
                self._health_thread_started = False

    def close(self):
        """Signal the background loops to stop, close the channel, and notify the listener."""
        self._stop_event.set()
        if self.channel is not None:
            self.channel.close()
        self._connected = False
        if self.listener:
            self.listener.on_close()

    def is_connected(self) -> bool:
        """Return True while a TaskStream is believed to be open."""
        return self._connected

class GrpcServerTransport(IMeshTransport):
    """
    gRPC implementation of IMeshTransport for the server side.
    Wraps a single bi-directional stream context.

    Outbound messages are only queued here (``send_queue``); presumably the
    servicer drains that queue into the stream - not visible in this file.
    """
    def __init__(self, context: Any, signer: Optional[Callable[[bytes], str]] = None):
        """
        Args:
            context: The gRPC servicer context of the wrapped stream.
            signer: Optional callable mapping serialized message bytes to a
                signature string; applied in :meth:`send`.
        """
        self.context = context
        # Entries are (priority, seq, message); see send().
        self.send_queue = queue.PriorityQueue()
        self._send_seq = itertools.count()
        self.listener = None
        self.signer = signer

    def handshake(self) -> bool:
        """Always succeeds: the handshake is handled by the servicer."""
        return True  # Handshake handled by Servicer

    def connect(self):
        """No-op: the wrapped stream context already exists when this transport is built."""
        pass

    def set_listener(self, listener: IMeshListener):
        """Register the listener for this stream."""
        self.listener = listener

    def send(self, message: Any, priority: int = 1):
        """Sign (if a signer is configured) and queue *message*.

        When a signer is set and the message has a ``signature`` attribute, the
        signature is computed over the deterministic serialization of the message
        with ``signature`` cleared, then written back onto the message in place.
        Lower *priority* values are dequeued first.
        """
        if self.signer and hasattr(message, "signature"):
            message.signature = ""
            msg_bytes = message.SerializeToString(deterministic=True)
            message.signature = self.signer(msg_bytes)

        # A monotonically increasing seq breaks priority ties (FIFO within a
        # priority) so the queue never compares two protobuf messages, which are
        # not orderable. time.time() could collide and trigger that comparison.
        self.send_queue.put((priority, next(self._send_seq), message))

    def close(self):
        """Abort the wrapped RPC with CANCELLED (best effort)."""
        try:
            import grpc
            self.context.abort(grpc.StatusCode.CANCELLED, "Transport closed")
        except Exception:
            # ServicerContext.abort() always raises to signal the abort, and the
            # context may already be finished; closing is deliberately best effort.
            pass

    def is_connected(self) -> bool:
        """Return whether the wrapped RPC context is still active."""
        return self.context.is_active()