import hashlib
import hmac
import os
import queue
import threading
import time
from concurrent import futures

import grpc
import jwt

import agent_pb2
import agent_pb2_grpc
# Shared secret used both to verify node registration JWTs (HS256) and to
# HMAC-SHA256-sign task payloads sent to nodes. In production, supply it via
# the CORTEX_SECRET_KEY environment variable (e.g. loaded from .env); the
# hard-coded fallback exists only for local POC runs.
# NOTE(review): consider separate keys for token auth and task signing.
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY") or "cortex-secret-shared-key"
class AgentOrchestratorServicer(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer that handles the bidirectional ``Connect`` stream with nodes.

    Proof of concept: after a node registers with a valid JWT, the servicer
    immediately dispatches a fixed sequence of HMAC-signed shell tasks. These
    tasks exercise the node's sandbox policy.
    """

    # (task_id, trace_id, payload_json) for the sandbox-policy test sequence.
    # Tasks are dispatched in this order after a successful registration.
    _TEST_TASKS = (
        # Test 1: Allowed Command
        ("task-001-ALLOWED", "trace-001", '{"command": "whoami"}'),
        # Test 2: Sensitive Command (Consent Required)
        ("task-002-SENSITIVE", "trace-002", '{"command": "rm -rf /tmp/node-test"}'),
        # Test 3: Path Traversal Attempt (JAILBREAK)
        ("task-003-TRAVERSAL", "trace-003", '{"command": "cat ../.env.gitbucket"}'),
        # Test 4: STRICTLY Forbidden
        ("task-004-FORBIDDEN", "trace-004", '{"command": "sudo apt update"}'),
        # Test 5: Non-whitelisted but Allowed in PERMISSIVE
        ("task-005-PERMISSIVE", "trace-005", '{"command": "df -h"}'),
    )

    def __init__(self):
        # node_id -> queue for messages to node.
        # NOTE(review): nothing in this class populates this mapping yet.
        self.nodes = {}

    def Connect(self, request_iterator, context):
        """Serve one node's bidirectional stream.

        A background thread reads the node's incoming messages and enqueues any
        replies. This generator yields those replies to the node until the RPC
        is no longer active.

        Args:
            request_iterator: Iterator of the node's client messages.
            context: The gRPC servicer context for this RPC.

        Yields:
            ``agent_pb2.ServerMessage`` objects taken from the per-stream queue.
        """
        # Outgoing messages for this client. In a real app an AI task planner
        # would also enqueue here from another thread.
        send_queue = queue.Queue()

        # Read client messages on a separate thread so this generator stays
        # free to yield outgoing ones. daemon=True: a thread still blocked on
        # the request iterator must not keep the process alive at exit.
        incoming_thread = threading.Thread(
            target=self._handle_incoming,
            args=(request_iterator, context, send_queue),
            daemon=True,
        )
        incoming_thread.start()

        # Poll with a timeout so the loop notices a closed/cancelled RPC
        # within about a second even when nothing is queued.
        while context.is_active():
            try:
                msg = send_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            yield msg

    def _handle_incoming(self, request_iterator, context, send_queue):
        """Consume the node's messages until the stream ends.

        Registrations are authenticated and answered via ``send_queue``,
        heartbeats are ignored, and task responses are logged. Any exception
        is logged and ends this reader thread.
        """
        try:
            for message in request_iterator:
                payload_type = message.WhichOneof('payload')
                if payload_type == 'registration':
                    self._handle_registration(message.registration, send_queue)
                elif payload_type == 'heartbeat':
                    # Heartbeats are accepted but currently not processed.
                    pass
                elif payload_type == 'task_response':
                    res = message.task_response
                    print(f"[!] Task Finished: {res.task_id} | Status: {res.status}")
                    print(f" Stdout: {res.stdout.strip()}")
        except Exception as e:
            # Top-level boundary for this thread: log the failure rather than
            # letting the thread die with an unlogged traceback.
            print(f"[!] Error handling incoming stream: {e}")

    def _handle_registration(self, reg, send_queue):
        """Verify a registration JWT and queue the ack, plus test tasks on success."""
        print(f"[*] Node {reg.node_id} registration request. Verifying token...")
        try:
            # TODO: check that reg.node_id matches the token subject.
            decoded = jwt.decode(reg.auth_token, SECRET_KEY, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, which it subclasses.
            print(f" [FAIL] Token for {reg.node_id} expired.")
            send_queue.put(self._registration_failure("Authentication token expired."))
        except jwt.InvalidTokenError as e:
            print(f" [FAIL] Invalid token for {reg.node_id}: {e}")
            send_queue.put(
                self._registration_failure(f"Invalid authentication token: {e}")
            )
        else:
            print(f" [OK] Token verified for workspace: {decoded.get('workspace_id')}")
            # NOTE(review): this session id is the same for every node; it
            # presumably is a placeholder.
            send_queue.put(agent_pb2.ServerMessage(
                registration_ack=agent_pb2.RegistrationResponse(
                    success=True,
                    session_id="session-secure-123",
                )
            ))
            self._dispatch_test_tasks(send_queue)

    def _dispatch_test_tasks(self, send_queue):
        """Queue every entry of ``_TEST_TASKS`` as a signed shell TaskRequest."""
        for task_id, trace_id, payload in self._TEST_TASKS:
            send_queue.put(agent_pb2.ServerMessage(
                task_request=agent_pb2.TaskRequest(
                    task_id=task_id,
                    task_type="shell",
                    payload_json=payload,
                    trace_id=trace_id,
                    signature=self._sign_payload(payload),
                )
            ))
        print("[*] Sequence of 5 test tasks dispatched to verify Sandbox Policy (PERMISSIVE mode).")

    @staticmethod
    def _sign_payload(payload):
        """Return the hex HMAC-SHA256 of *payload* (str), keyed with SECRET_KEY."""
        return hmac.new(SECRET_KEY.encode(), payload.encode(), hashlib.sha256).hexdigest()

    @staticmethod
    def _registration_failure(error_message):
        """Build a failed RegistrationResponse wrapped in a ServerMessage."""
        return agent_pb2.ServerMessage(
            registration_ack=agent_pb2.RegistrationResponse(
                success=False,
                error_message=error_message,
            )
        )
def serve(port=50051, cert_dir='certs', max_workers=10):
    """Start the mTLS-secured orchestrator gRPC server and block until it stops.

    Args:
        port: TCP port to listen on, on all interfaces (``[::]``).
        cert_dir: Directory holding ``server.key``, ``server.crt`` and ``ca.crt``.
        max_workers: Size of the RPC thread pool. ``Connect`` streams stay
            open for a node's whole session, so this likely also caps the
            number of concurrently connected nodes (confirm for your grpcio).

    Raises:
        OSError: If a certificate file cannot be read.
        RuntimeError: If the secure port cannot be bound.
    """
    from pathlib import Path

    # Load certificates for mTLS
    print("[🔐] Loading mTLS certificates...")
    certs = Path(cert_dir)
    private_key = (certs / 'server.key').read_bytes()
    certificate_chain = (certs / 'server.crt').read_bytes()
    root_certificates = (certs / 'ca.crt').read_bytes()

    # require_client_auth=True enforces bidirectional verification (mTLS):
    # clients must present a certificate that verifies against ca.crt.
    server_credentials = grpc.ssl_server_credentials(
        [(private_key, certificate_chain)],
        root_certificates=root_certificates,
        require_client_auth=True,
    )

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    agent_pb2_grpc.add_AgentOrchestratorServicer_to_server(AgentOrchestratorServicer(), server)

    bound_port = server.add_secure_port(f'[::]:{port}', server_credentials)
    if bound_port == 0:
        # Some grpcio versions report a bind failure by returning 0 instead of
        # raising; don't claim to be listening in that case.
        raise RuntimeError(f"Failed to bind secure port {port}")

    print(f"[*] Cortex Secure Server POC listening on port {bound_port} (mTLS enabled)...")
    server.start()
    server.wait_for_termination()
if __name__ == "__main__":
    # Script entry point: run the POC server in the foreground.
    serve()