# cortex-hub / mesh-sdk / mesh_core / engines / node.py
import threading
import time
import logging
from typing import Any, Optional, Dict, Callable
from ..transport import IMeshTransport, IMeshListener

logger = logging.getLogger(__name__)

class MeshNodeCore(IMeshListener):
    """
    Portable state machine for a Mesh Node.

    Drives the node lifecycle over an abstract transport: performs the
    handshake, opens the connection, watches connectivity from a background
    thread, and routes inbound messages to application callbacks. All I/O
    goes through the injected ``IMeshTransport``; this class contains no
    application logic of its own.
    """

    # Single source of truth: proto field name → (callback_attr, extractor).
    # To add a new message kind: add one entry here. Kinds without an entry
    # are logged at WARNING level by on_message() and then ignored.
    _DISPATCH = {
        'policy_update':    ('on_policy',    lambda m: m.policy_update),
        'policy':           ('on_policy',    lambda m: m.policy),
        'task_request':     ('on_task',      lambda m: m.task_request),
        'task_cancel':      ('on_cancel',    lambda m: m.task_cancel),
        'file_sync':        ('on_sync',      lambda m: m.file_sync),
        'work_pool_update': ('on_work_pool', lambda m: m.work_pool_update),
    }

    def __init__(self, node_id: str, transport: IMeshTransport):
        """Bind the node to *transport* and register as its listener.

        Args:
            node_id: Identifier used in log messages for this node.
            transport: Transport used for handshake, connect, send and close.
        """
        self.node_id = node_id
        self.transport = transport
        self.transport.set_listener(self)

        self._stop_event = threading.Event()
        # Becomes True on the first inbound policy message; reset on
        # disconnect or close.
        self._is_ready = False

        # Callbacks to be hooked by the application (e.g., AgentNode).
        # Each stays None until assigned; None callbacks are skipped.
        self.on_task: Optional[Callable[[Any], None]] = None
        self.on_cancel: Optional[Callable[[Any], None]] = None
        self.on_policy: Optional[Callable[[Any], None]] = None
        self.on_sync: Optional[Callable[[Any], None]] = None
        self.on_ready: Optional[Callable[[Any], None]] = None
        self.on_disconnect: Optional[Callable[[], None]] = None
        self.on_work_pool: Optional[Callable[[Any], None]] = None

    def start(self) -> bool:
        """Handshake, connect, and launch the management thread.

        Returns:
            False if the transport handshake returns a falsy value (nothing
            else is started in that case), True otherwise.
        """
        logger.info(f"[MeshCore] Starting Node {self.node_id}...")

        # 1. Perform Handshake
        policy = self.transport.handshake()
        if not policy:
            logger.error(f"[MeshCore] Handshake failed. Node {self.node_id} cannot start.")
            return False

        # Forward the handshake result only when it looks like a policy
        # object. This does not mark the node ready; readiness is set by the
        # first policy message seen in on_message().
        if self.on_policy and hasattr(policy, "mode"):
            self.on_policy(policy)

        # 2. Connect TaskStream
        self.transport.connect()
        threading.Thread(
            target=self._management_loop,
            daemon=True,
            name="MeshNodeMgmt",
        ).start()
        return True

    def _management_loop(self) -> None:
        """Poll the transport about once per second until stop() is called.

        When a ready node is found disconnected, the ready flag is cleared
        and ``on_disconnect`` is fired once.
        """
        while not self._stop_event.is_set():
            if not self.transport.is_connected() and self._is_ready:
                self._is_ready = False
                if self.on_disconnect:
                    self.on_disconnect()

            # Wait on the event instead of sleeping so stop() ends the loop
            # promptly rather than after the remainder of the interval.
            self._stop_event.wait(1)

    def send_message(self, message: Any, priority: int = 1) -> None:
        """Send *message* through the transport if it is connected.

        Messages are dropped (with a WARNING log) while disconnected.

        Args:
            message: Payload handed to ``transport.send`` unchanged.
            priority: Forwarded to ``transport.send`` as ``priority``.
        """
        if self.transport.is_connected():
            self.transport.send(message, priority=priority)
        else:
            logger.warning("[MeshCore] Dropped message: Transport disconnected.")

    # IMeshListener Implementation
    def on_message(self, message: Any) -> None:
        """Route an inbound message to the callback named in _DISPATCH.

        The message kind is read via ``message.WhichOneof('payload')``. To
        handle a new proto field, add one entry to _DISPATCH — no other
        changes needed. Any exception raised while dispatching (including
        from a callback) is logged and swallowed.
        """
        # Bound before the try so the error handler can always reference it,
        # even when WhichOneof itself raises.
        kind = None
        try:
            kind = message.WhichOneof('payload')
            logger.debug("[MeshCore] Inbound: %s", kind)

            entry = self._DISPATCH.get(kind)
            if entry is None:
                logger.warning(
                    "[MeshCore] Unhandled message kind: '%s' — add entry to _DISPATCH",
                    kind,
                )
                return

            cb_attr, extractor = entry
            payload = extractor(message)

            # Lifecycle hook: first policy message marks the node as ready.
            if cb_attr == 'on_policy' and not self._is_ready:
                self._is_ready = True
                if self.on_ready:
                    self.on_ready(payload)

            cb = getattr(self, cb_attr)
            if cb:
                cb(payload)

            logger.debug("[MeshCore] Dispatched: %s → %s", kind, cb_attr)
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error()
            # call would discard.
            logger.exception("[MeshCore] Dispatch Error (%s): %s", kind, e)

    def on_error(self, error: Exception) -> None:
        """Log a transport-level error; no state is changed."""
        logger.error(f"[MeshCore] Transport Error: {error}")

    def on_close(self) -> None:
        """Clear the ready flag and fire ``on_disconnect`` if one is set."""
        logger.warning("[MeshCore] Transport Closed.")
        self._is_ready = False
        if self.on_disconnect:
            self.on_disconnect()

    def stop(self) -> None:
        """Graceful shutdown: signal the management loop, then close the transport."""
        self._stop_event.set()
        self.transport.close()