# cortex-hub / mesh-sdk / tests / test_core.py
# Last change by yangyang xie (17 days ago, 2 KB): "refactor done"
import pytest
import time
from mesh_core.server_engine import MeshServerCore
from mesh_core.node_engine import MeshNodeCore
from mesh_core.transport_mock import MockMeshTransport
from mesh_core import agent_pb2

class MockListener:
    """Test double that records every ``(node_id, msg)`` pair handed to it, in arrival order."""

    def __init__(self):
        # Chronological log of received messages, each stored as a (node_id, msg) tuple.
        self.messages = []

    def on_message(self, node_id, msg):
        """Record a single message *msg* attributed to *node_id*."""
        entry = (node_id, msg)
        self.messages += [entry]

def test_server_registration():
    """Registering a node fires ``on_node_online`` and makes the node queryable by id."""
    seen_online = []
    server = MeshServerCore()
    server.on_node_online = lambda record: seen_online.append(record.node_id)

    node_id, owner = "test-node", "user-1"
    server.register_node(node_id, owner, {"desc": "test"}, MockMeshTransport(node_id))

    assert node_id in seen_online
    assert len(server.list_nodes()) == 1
    registered = server.get_node(node_id)
    assert registered.user_id == owner

def test_server_dispatch():
    """``dispatch`` hands the message to the registered node's transport exactly once."""
    server = MeshServerCore()
    transport = MockMeshTransport("test-node")
    server.register_node("test-node", "user-1", {}, transport)

    # Replace the transport's send so outgoing messages are captured instead of delivered.
    captured = []

    def capture(msg, priority=1):
        captured.append(msg)

    transport.send = capture

    update = agent_pb2.WorkPoolUpdate(available_task_ids=["task-1"])
    server.dispatch("test-node", agent_pb2.ServerTaskMessage(work_pool_update=update))

    assert len(captured) == 1
    first = captured[0]
    assert first.work_pool_update.available_task_ids == ["task-1"]

def test_node_lifecycle():
    """Exercise a node's start -> receive task -> stop cycle over a mock transport.

    The node is stopped in a ``finally`` clause. A failing assertion in the middle
    of the test therefore cannot leave the node started, or its transport connected,
    for the tests that run after it.
    """
    transport = MockMeshTransport("node-1")
    node = MeshNodeCore("node-1", transport)

    received_tasks = []
    node.on_task = lambda task: received_tasks.append(task)

    # Mock handshake to always succeed (presumably consulted by node.start()).
    transport.handshake = lambda: True

    node.start()
    try:
        assert transport.is_connected()

        # Simulate inbound message from server
        msg = agent_pb2.ServerTaskMessage(
            task_request=agent_pb2.TaskRequest(task_id="t1", payload_json="{}")
        )
        transport.simulate_server_message(msg)

        assert len(received_tasks) == 1
        assert received_tasks[0].task_id == "t1"
    finally:
        # Always tear down; previously a failed assertion above skipped stop().
        node.stop()

    assert not transport.is_connected()

def test_server_inbound_handling():
    """Inbound client messages are surfaced via ``on_message_received`` with their sender id."""
    server = MeshServerCore()
    inbox = []

    def record(node_id, msg):
        inbox.append((node_id, msg))

    server.on_message_received = record

    announce = agent_pb2.NodeAnnounce(node_id="node-1")
    server.handle_inbound("node-1", agent_pb2.ClientTaskMessage(announce=announce))

    assert len(inbox) == 1
    sender, payload = inbox[0]
    assert sender == "node-1"
    assert payload.announce.node_id == "node-1"