# cortex-hub / poc-grpc-agent / client.py
import json
import platform
import queue
import subprocess
import threading
import time

import grpc

import agent_pb2
import agent_pb2_grpc

class AgentNode:
    """Proof-of-concept agent node that talks to the orchestrator over gRPC.

    The node opens a bi-directional ``Connect`` stream, registers itself,
    sends periodic heartbeats and executes ``"shell"`` tasks pushed by the
    server, returning each result as a ``TaskResponse`` on the same stream.

    Messages produced while handling the response stream (task results) are
    handed to the request stream through a thread-safe queue. The request
    iterator is pulled by gRPC independently of the response loop.
    """

    def __init__(self, node_id="agent-007", server_address="localhost:50051",
                 heartbeat_interval=30):
        """Create the gRPC channel and orchestrator stub.

        Args:
            node_id: Identifier reported in registration and heartbeats.
            server_address: ``host:port`` of the orchestrator (insecure channel).
            heartbeat_interval: Seconds between heartbeats on the request stream.
        """
        self.node_id = node_id
        self.heartbeat_interval = heartbeat_interval
        self.channel = grpc.insecure_channel(server_address)
        self.stub = agent_pb2_grpc.AgentOrchestratorStub(self.channel)
        # Messages waiting to be written to the Connect request stream.
        # ``None`` is the stop sentinel posted by run() when the stream ends.
        self._outbox = queue.Queue()
        print(f"[*] Agent Node {self.node_id} initialized.")

    def run(self):
        """Connect to the orchestrator and process server messages until the stream ends.

        RPC errors are reported and end the session instead of propagating.
        """
        # Bi-directional stream connection
        responses = self.stub.Connect(self.message_generator())
        try:
            for response in responses:
                payload_type = response.WhichOneof('payload')
                if payload_type == 'registration_ack':
                    ack = response.registration_ack
                    print(f"[*] Server ACK: Success={ack.success}, Session={ack.session_id}")
                elif payload_type == 'task_request':
                    self.execute_task(response.task_request)
        except grpc.RpcError as e:
            print(f"[!] RPC Error: {e}")
        finally:
            # Unblock message_generator so the request stream can finish.
            self._outbox.put(None)

    def message_generator(self):
        """Yield outbound NodeMessages for the Connect request stream.

        The registration message is yielded first. Afterwards, queued messages
        (e.g. task responses) are yielded in FIFO order. A heartbeat is yielded
        whenever ``heartbeat_interval`` seconds have elapsed since the previous
        one. The generator returns when the ``None`` stop sentinel is dequeued.
        """
        # 1. Registration
        print(f"[*] Sending Registration for {self.node_id}...")
        yield agent_pb2.NodeMessage(
            registration=agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                version="1.0.0",
                platform=platform.system() + "-" + platform.machine(),
                capabilities={"shell": True, "browser": False},
            )
        )

        # 2. Drain the outbox, interleaving heartbeats on a monotonic deadline
        #    so a steady stream of task responses cannot starve them.
        next_heartbeat = time.monotonic() + self.heartbeat_interval
        while True:
            remaining = next_heartbeat - time.monotonic()
            if remaining <= 0:
                next_heartbeat = time.monotonic() + self.heartbeat_interval
                yield self._heartbeat_message()
                continue
            try:
                msg = self._outbox.get(timeout=remaining)
            except queue.Empty:
                continue  # deadline reached; the check above emits the heartbeat
            if msg is None:  # stop sentinel from run()
                return
            yield msg

    def _heartbeat_message(self):
        """Build a heartbeat NodeMessage (metric values are fixed placeholders)."""
        return agent_pb2.NodeMessage(
            heartbeat=agent_pb2.Heartbeat(
                node_id=self.node_id,
                cpu_usage_percent=3.5,
                active_task_count=0,
            )
        )

    @staticmethod
    def _parse_command(payload_json):
        """Return the ``"command"`` field of a task's JSON payload.

        Falls back to a harmless echo when the key is absent. Raises
        ``ValueError`` for invalid JSON and ``AttributeError`` when the payload
        is not a JSON object.
        """
        payload = json.loads(payload_json)
        return payload.get("command", "echo 'No command'")

    @staticmethod
    def _run_shell(cmd):
        """Run *cmd* through the system shell and capture its output.

        Returns:
            ``(completed_process, duration_ms)`` where ``completed_process`` is
            the ``subprocess.CompletedProcess`` with text stdout/stderr.
        """
        # SECURITY: the command comes from the orchestrator and runs unsandboxed
        # through the shell; only connect this agent to a trusted server.
        start = time.monotonic()
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        duration_ms = int((time.monotonic() - start) * 1000)
        return result, duration_ms

    def execute_task(self, task):
        """Execute a server-pushed task and queue its TaskResponse for sending.

        Only ``"shell"`` tasks are supported; other types are logged and skipped.
        Failures while parsing or launching the command are logged and no
        response is queued.
        """
        print(f"[?] Received Task: {task.task_id} ({task.task_type})")

        # Dispatch to execution engine
        if task.task_type != "shell":
            print(f"    [WARN] Unsupported task type for {task.task_id}: {task.task_type}")
            return

        try:
            cmd = self._parse_command(task.payload_json)
            print(f"    Executing local shell: {cmd}")
            result, duration_ms = self._run_shell(cmd)
        except Exception as e:  # task boundary: one bad task must not kill the agent
            print(f"    [ERROR] Task {task.task_id} failed: {e}")
            return

        # NOTE(review): only TaskResponse.SUCCESS is visible from here; map a
        # non-zero result.returncode to a failure status once the proto enum is
        # confirmed.
        self._outbox.put(
            agent_pb2.NodeMessage(
                task_response=agent_pb2.TaskResponse(
                    task_id=task.task_id,
                    status=agent_pb2.TaskResponse.SUCCESS,
                    stdout=result.stdout,
                    stderr=result.stderr,
                    duration_ms=duration_ms,
                )
            )
        )
        print(f"    [OK] Task {task.task_id} completed. Sending response...")

if __name__ == '__main__':
    # Standalone queue-driven agent: the heartbeat thread and the task handler
    # both enqueue NodeMessages, and a generator drains the queue into the
    # Connect request stream.
    import queue

    msg_queue = queue.Queue()

    node = AgentNode()

    # 1. Registration is queued first so it is the first message on the stream.
    msg_queue.put(
        agent_pb2.NodeMessage(
            registration=agent_pb2.RegistrationRequest(
                node_id=node.node_id,
                version="1.0.0",
                platform=platform.system() + "-" + platform.machine(),
                capabilities={"shell": True, "browser": False},
            )
        )
    )

    def heartbeat_thread():
        """Enqueue a heartbeat every 10 seconds (metrics are placeholders)."""
        while True:
            time.sleep(10)
            msg_queue.put(
                agent_pb2.NodeMessage(
                    heartbeat=agent_pb2.Heartbeat(
                        node_id=node.node_id,
                        cpu_usage_percent=1.2,
                        active_task_count=0,
                    )
                )
            )

    threading.Thread(target=heartbeat_thread, daemon=True).start()

    def generator():
        """Yield queued messages until the ``None`` stop sentinel is dequeued."""
        while True:
            msg = msg_queue.get()
            if msg is None:
                return
            yield msg

    def handle_task(task):
        """Run a shell task and enqueue its TaskResponse; log and skip on failure."""
        if task.task_type != "shell":
            print(f"[!] Skipping task {task.task_id}: unsupported type {task.task_type!r}")
            return
        print(f"[*] Executing {task.task_id}: {task.payload_json}")
        try:
            payload = json.loads(task.payload_json)
            # Same fallback as AgentNode.execute_task; a missing key previously
            # passed None to subprocess.run and crashed the agent.
            cmd = payload.get("command", "echo 'No command'")
            # SECURITY: server-supplied command executed through the shell.
            start = time.monotonic()
            res = subprocess.run(cmd, shell=True, capture_output=True, text=True)
            duration_ms = int((time.monotonic() - start) * 1000)
        except Exception as e:  # task boundary: one bad task must not kill the agent
            print(f"[!] Task {task.task_id} failed: {e}")
            return

        # Send result back
        # NOTE(review): status is SUCCESS even for a non-zero res.returncode;
        # map it to a failure status once the proto enum is confirmed.
        msg_queue.put(
            agent_pb2.NodeMessage(
                task_response=agent_pb2.TaskResponse(
                    task_id=task.task_id,
                    status=agent_pb2.TaskResponse.SUCCESS,
                    stdout=res.stdout,
                    stderr=res.stderr,
                    duration_ms=duration_ms,
                )
            )
        )

    try:
        for response in node.stub.Connect(generator()):
            payload_type = response.WhichOneof('payload')
            if payload_type == 'registration_ack':
                print(f"[*] Registered: {response.registration_ack.session_id}")
            elif payload_type == 'task_request':
                handle_task(response.task_request)
    except grpc.RpcError as e:
        print(f"[!] RPC Error: {e}")
        raise SystemExit(1)
    finally:
        msg_queue.put(None)  # let the request generator finish
        node.channel.close()