import threading
import queue
import time
import logging
from typing import Any, Optional, Callable, Union
from ..models import agent_pb2, agent_pb2_grpc
from .base import IMeshTransport, IMeshListener
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class GrpcMeshTransport(IMeshTransport):
    """
    gRPC implementation of the Mesh Transport (client side).

    Encapsulates the bidirectional task stream, the optional health stream
    and the auto-reconnection logic for one node.

    Outbound messages are queued as ``(priority, stamp, message)`` tuples in
    ``send_queue``; lower priority numbers are sent first and, within one
    priority, messages leave in the order they were queued.
    """

    # Smallest gap forced between two consecutive queue stamps. It only has to
    # be large enough to survive float rounding at current epoch magnitudes.
    _STAMP_STEP = 1e-6
    # Pause (seconds) before reopening a stream that ended without an error,
    # so a server that keeps closing the stream cannot cause a busy loop.
    _REOPEN_DELAY = 1.0

    def __init__(self, node_id: str, stub_factory: Callable[[], tuple], auth_token: str = ""):
        """
        Args:
            node_id: Identifier announced to the server for this node.
            stub_factory: Callable returning a ``(stub, channel)`` pair; it is
                invoked again every time the connection is refreshed.
            auth_token: Token sent in the registration request.
        """
        self.node_id = node_id
        self.stub_factory = stub_factory  # Callable returning (stub, channel)
        self.auth_token = auth_token
        self.listener: Optional[IMeshListener] = None
        self.stub = None
        self.channel = None
        self.send_queue = queue.PriorityQueue()
        self.health_queue = queue.Queue()
        self._stop_event = threading.Event()
        self._connected = False
        self._health_thread_started = False
        self.last_activity = 0
        # Guards _last_stamp and the health-thread start flag.
        self._lock = threading.Lock()
        self._last_stamp = 0.0

    def handshake(self) -> bool:
        """
        Refresh the stub and register this node via ``SyncConfiguration``.

        Returns:
            True when the server reports success; False when it rejects the
            registration or when any step (including building the stub) fails.
        """
        try:
            # Inside the try so a failing stub_factory yields False like any
            # other handshake failure instead of escaping to the caller.
            self._refresh_stub()
            req = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=self.auth_token,
                node_description="Portable Mesh Node"
            )
            res = self.stub.SyncConfiguration(req)
            if res.success:
                logger.info(f"[Mesh] Handshake successful for {self.node_id}")
                # Optional: Handle policy res.policy
                return True
            logger.error(f"[Mesh] Handshake REJECTED for {self.node_id}: {res.error_message}")
            return False
        except Exception as e:
            logger.error(f"[Mesh] Handshake FAILED for {self.node_id}: {e}")
            return False

    def connect(self):
        """Clear the stop flag, refresh the stub and start the stream thread."""
        self._stop_event.clear()
        self._refresh_stub()
        threading.Thread(target=self._run_stream, daemon=True, name="GrpcTransportStream").start()

    def set_listener(self, listener: IMeshListener):
        """Register the listener that receives messages, errors and close events."""
        self.listener = listener

    def _refresh_stub(self):
        """Close the current channel (best effort) and build a fresh stub/channel pair."""
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                # Best effort: the channel is being discarded anyway.
                logger.debug(f"[Mesh] Ignoring error while closing stale channel: {e}")
        self.stub, self.channel = self.stub_factory()

    def _next_stamp(self) -> float:
        """
        Return a wall-clock stamp strictly greater than the previous one.

        Two queue entries with equal priority AND equal stamp would make the
        PriorityQueue compare the message objects themselves, which raises
        TypeError for types that do not define ordering. Forcing the stamp to
        increase keeps the comparison from ever reaching the message and
        preserves FIFO order within a priority even if the clock steps back.
        """
        with self._lock:
            stamp = time.time()
            if stamp <= self._last_stamp:
                stamp = self._last_stamp + self._STAMP_STEP
            self._last_stamp = stamp
            return stamp

    def _outbound_messages(self, stream_done: threading.Event):
        """
        Yield the request messages for one TaskStream call.

        Emits the node announcement first, then queued messages, plus a
        keep-alive whenever 10 seconds have passed since the last one.

        Args:
            stream_done: Set by the stream loop when this particular call has
                ended, so this generator stops taking messages meant for the
                next call.
        """
        # Initial announcement
        yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))
        last_heartbeat = time.time()
        while not self._stop_event.is_set() and not stream_done.is_set():
            try:
                # PriorityQueue returns (priority, stamp, msg)
                item = self.send_queue.get(timeout=1.0)
            except queue.Empty:
                item = None
            if item is not None:
                if stream_done.is_set():
                    # NOTE(review): presumably the request iterator of a
                    # finished call can still be polled once by gRPC - confirm.
                    # Put the entry back (same stamp, so same position) rather
                    # than handing it to a dead stream.
                    self.send_queue.put(item)
                    return
                yield item[2]
            # Transport-level KeepAlive
            if time.time() - last_heartbeat >= 10.0:
                yield agent_pb2.ClientTaskMessage(
                    skill_event=agent_pb2.SkillEvent(keep_alive=True)
                )
                last_heartbeat = time.time()

    def _run_stream(self):
        """
        Thread body: keep a TaskStream open until the stop event is set.

        Incoming messages are forwarded to the listener. On an error the
        listener is notified, the stub is rebuilt and the loop retries after a
        linear backoff capped at 30 seconds. All waits use the stop event so
        ``close()`` interrupts them immediately.
        """
        retry_count = 0
        while not self._stop_event.is_set():
            stream_done = threading.Event()
            try:
                logger.info(f"[*] Opening gRPC TaskStream for {self.node_id}...")
                responses = self.stub.TaskStream(self._outbound_messages(stream_done))
                self._connected = True
                self.last_activity = time.time()
                logger.info(f"[✅] gRPC Mesh Transport Online for {self.node_id}")
                retry_count = 0
                for msg in responses:
                    self.last_activity = time.time()
                    if self.listener:
                        self.listener.on_message(msg)
                stream_done.set()
                self._connected = False
                logger.warning(f"[📶] gRPC Stream closed by server for {self.node_id}")
                # Short pause so repeated clean closes do not spin the CPU.
                self._stop_event.wait(self._REOPEN_DELAY)
            except Exception as e:
                stream_done.set()
                self._connected = False
                if self.listener:
                    self.listener.on_error(e)
                if not self._stop_event.is_set():
                    retry_count += 1
                    backoff = min(30, 2 * retry_count)
                    logger.error(f"[❌] gRPC Stream Error ({type(e).__name__}): {e}. Reconnecting in {backoff}s... (Attempt {retry_count})")
                    try:
                        self._refresh_stub()
                    except Exception as refresh_error:
                        # Keep the thread alive; the next attempt fails fast
                        # and comes back here to try the refresh again.
                        logger.error(f"[Mesh] Stub refresh failed for {self.node_id}: {refresh_error}")
                    self._stop_event.wait(backoff)

    def send(self, message: Any, priority: int = 1):
        """
        Queue a message for the task stream.

        Args:
            message: The client message to send.
            priority: Lower values are sent first.
        """
        # PriorityQueue entries are (priority, stamp, item).
        # Note: Client messages are not currently signed by the SDK,
        # but this is where we would add it if needed.
        self.send_queue.put((priority, self._next_stamp(), message))

    def send_health(self, heartbeat: Any):
        """Sends a heartbeat via the dedicated health stream."""
        if not self._health_thread_started:
            self._start_health_stream()
        self.health_queue.put(heartbeat)

    def _start_health_stream(self):
        """Start the health thread unless one is running or the transport is stopped."""
        with self._lock:
            # While stopped, a new thread would exit immediately; the next
            # send_health() after connect() starts it instead.
            if self._health_thread_started or self._stop_event.is_set():
                return
            self._health_thread_started = True
        threading.Thread(target=self._run_health_stream, daemon=True, name="GrpcHealthStream").start()

    def _health_messages(self):
        """Yield queued heartbeats until the stop event is set."""
        while not self._stop_event.is_set():
            try:
                yield self.health_queue.get(timeout=1.0)
            except queue.Empty:
                pass

    def _run_health_stream(self):
        """
        Thread body: keep a ReportHealth stream open until the stop event is set.

        Waits for a stub to exist, retries 5 seconds after an error, and
        clears the started flag on exit so the stream can be started again
        after a later ``connect()``.
        """
        try:
            while not self._stop_event.is_set():
                try:
                    if not self.stub:
                        self._stop_event.wait(1)
                        continue
                    logger.info("[*] Opening gRPC HealthStream...")
                    responses = self.stub.ReportHealth(self._health_messages())
                    for _ in responses:
                        # Optional: Handle HealthCheckResponse (server_time_ms)
                        pass
                    # Stream ended without error: pause before reopening.
                    self._stop_event.wait(self._REOPEN_DELAY)
                except Exception as e:
                    logger.error(f"[Mesh] Health Stream Error: {e}")
                    self._stop_event.wait(5)
        finally:
            with self._lock:
                self._health_thread_started = False

    def close(self):
        """Stop both streams, close the channel and notify the listener."""
        self._stop_event.set()
        if self.channel:
            try:
                self.channel.close()
            except Exception as e:
                # Still reset state and notify the listener below.
                logger.warning(f"[Mesh] Error while closing channel for {self.node_id}: {e}")
        self._connected = False
        if self.listener:
            self.listener.on_close()

    def is_connected(self) -> bool:
        """Return True while a task stream is considered open."""
        return self._connected
class GrpcServerTransport(IMeshTransport):
    """
    gRPC implementation of IMeshTransport for the server side.

    Wraps a single bi-directional stream context. Outbound messages are
    queued as ``(priority, stamp, message)`` tuples in ``send_queue``; this
    class only fills the queue and does not drain it itself.
    """

    # Smallest gap forced between two consecutive queue stamps. It only has to
    # be large enough to survive float rounding at current epoch magnitudes.
    _STAMP_STEP = 1e-6

    def __init__(self, context: Any, signer: Optional[Callable[[bytes], str]] = None):
        """
        Args:
            context: The stream context this transport wraps.
            signer: Optional callable that maps the serialized message bytes
                to a signature string.
        """
        self.context = context
        self.send_queue = queue.PriorityQueue()
        self.listener: Optional[IMeshListener] = None
        self.signer = signer
        # Guards _last_stamp.
        self._stamp_lock = threading.Lock()
        self._last_stamp = 0.0

    def handshake(self) -> bool:
        """Always True: the handshake is handled by the Servicer."""
        return True  # Handshake handled by Servicer

    def connect(self):
        """No-op: there is nothing to connect on this side."""
        pass

    def set_listener(self, listener: IMeshListener):
        """Register the listener for this transport."""
        self.listener = listener

    def _next_stamp(self) -> float:
        """
        Return a wall-clock stamp strictly greater than the previous one.

        Two queue entries with equal priority AND equal stamp would make the
        PriorityQueue compare the message objects themselves, which raises
        TypeError for types that do not define ordering. Forcing the stamp to
        increase keeps the comparison from ever reaching the message.
        """
        with self._stamp_lock:
            stamp = time.time()
            if stamp <= self._last_stamp:
                stamp = self._last_stamp + self._STAMP_STEP
            self._last_stamp = stamp
            return stamp

    def send(self, message: Any, priority: int = 1):
        """
        Sign the message when possible and queue it for sending.

        Args:
            message: The message to send. It is signed in place only when a
                signer is configured and the message has a ``signature``
                attribute; otherwise it is queued unchanged.
            priority: Lower values are sent first.
        """
        if self.signer and hasattr(message, 'signature'):
            # Blank the field first so the signature covers the same bytes
            # regardless of any signature already present on the message.
            message.signature = ""
            msg_bytes = message.SerializeToString(deterministic=True)
            message.signature = self.signer(msg_bytes)
        # PriorityQueue entries are (priority, stamp, item); the strictly
        # increasing stamp ensures stable ordering within one priority.
        self.send_queue.put((priority, self._next_stamp(), message))

    def close(self):
        """Abort the wrapped context with CANCELLED; failures are ignored (best effort)."""
        try:
            import grpc
            self.context.abort(grpc.StatusCode.CANCELLED, "Transport closed")
        except Exception as e:
            # Deliberately best effort. Exception (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            logger.debug(f"[Mesh] Ignoring error while closing server transport: {e}")

    def is_connected(self) -> bool:
        """Return whether the wrapped context reports itself as active."""
        return self.context.is_active()