# cortex-hub / ai-hub / app / core / services / mesh.py
import os
import secrets
import json
import uuid
import logging
from typing import Optional, List
from fastapi import HTTPException
from sqlalchemy.orm import Session
import jinja2

from app.db import models
from app.api import schemas
from app.api.dependencies import ServiceContainer
from app.core.grpc.utils.crypto import sign_payload
from app.protos import agent_pb2

logger = logging.getLogger(__name__)

class MeshService:
    """Service layer for agent mesh nodes.

    Groups the admin operations (register / update / delete nodes, grant group
    access), the user-facing operations (list accessible nodes, access checks,
    task dispatch) and the helpers that render node configuration files and
    provisioning scripts.
    """

    def __init__(self, services: Optional[ServiceContainer] = None):
        """Initialise the service.

        Args:
            services: Container providing ``node_registry_service`` and
                ``orchestrator``. Defaults to ``None``; methods that use the
                registry or orchestrator need it to be set.
        """
        self.services = services
        # Provisioning templates are resolved relative to this module:
        # <this dir>/../templates/provisioning
        self.templates_dir = os.path.join(os.path.dirname(__file__), "..", "templates", "provisioning")
        # When the directory is missing, jinja_env stays None and the render
        # helpers below return an error string / "" instead of raising.
        # Autoescaping is left off: the templates render scripts, not HTML.
        self.jinja_env = (
            jinja2.Environment(loader=jinja2.FileSystemLoader(self.templates_dir))
            if os.path.exists(self.templates_dir)
            else None
        )

    # --- Admin Logic ---

    def register_node(self, request: schemas.AgentNodeCreate, admin_id: str, db: Session) -> models.AgentNode:
        """Create a new, initially offline, agent node with a fresh invite token.

        Args:
            request: Node definition (id, display name, description, skill config).
            admin_id: Id of the registering admin; stored as ``registered_by``.
            db: Active database session; the new row is committed.

        Returns:
            The persisted ``AgentNode``, refreshed after commit.

        Raises:
            HTTPException: 409 if a node with ``request.node_id`` already exists.
        """
        existing = db.query(models.AgentNode).filter(models.AgentNode.node_id == request.node_id).first()
        if existing:
            raise HTTPException(status_code=409, detail=f"Node '{request.node_id}' already exists.")

        node = models.AgentNode(
            node_id=request.node_id,
            display_name=request.display_name,
            description=request.description,
            registered_by=admin_id,
            skill_config=request.skill_config.model_dump(),
            # Random URL-safe token; it is embedded in the generated node
            # config as both invite_token and auth_token.
            invite_token=secrets.token_urlsafe(32),
            last_status="offline",
        )
        db.add(node)
        db.commit()
        db.refresh(node)
        return node

    def update_node(self, node_id: str, update: schemas.AgentNodeUpdate, db: Session) -> models.AgentNode:
        """Apply the non-None fields of ``update`` to a node and persist them.

        When ``skill_config`` changes, the new policy is pushed to the live
        node through the orchestrator on a best-effort basis (failures are
        logged, not raised).

        Raises:
            HTTPException: 404 if the node does not exist.
        """
        node = self.get_node_or_404(node_id, db)
        if update.display_name is not None:
            node.display_name = update.display_name
        if update.description is not None:
            node.description = update.description
        policy_changed = update.skill_config is not None
        if policy_changed:
            node.skill_config = update.skill_config.model_dump()
        if update.is_active is not None:
            node.is_active = update.is_active

        db.commit()
        db.refresh(node)

        # Push only after the commit succeeded, so a failed commit never
        # leaves a live node enforcing a policy the database does not record.
        if policy_changed:
            try:
                self.services.orchestrator.push_policy(node_id, node.skill_config)
            except Exception as e:
                logger.warning("Could not push live policy to %s: %s", node_id, e)
        return node

    def delete_node(self, node_id: str, db: Session):
        """Deregister a node from the live registry, then delete its DB row.

        Raises:
            HTTPException: 404 if the node does not exist.
        """
        node = self.get_node_or_404(node_id, db)
        self.services.node_registry_service.deregister(node_id)
        db.delete(node)
        db.commit()

    def grant_access(self, node_id: str, grant: schemas.NodeAccessGrant, admin_id: str, db: Session):
        """Create or update a group's access level on a node (upsert), then commit."""
        existing = db.query(models.NodeGroupAccess).filter(
            models.NodeGroupAccess.node_id == node_id,
            models.NodeGroupAccess.group_id == grant.group_id,
        ).first()
        if existing:
            existing.access_level = grant.access_level
            existing.granted_by = admin_id
        else:
            db.add(models.NodeGroupAccess(
                node_id=node_id,
                group_id=grant.group_id,
                access_level=grant.access_level,
                granted_by=admin_id,
            ))
        db.commit()

    # --- User Logic ---

    def list_accessible_nodes(self, user_id: str, db: Session) -> List[models.AgentNode]:
        """Return the active nodes the user's group may use.

        A node is accessible when the user's group has a ``NodeGroupAccess``
        grant for it or the group policy lists it under ``"nodes"``.

        Raises:
            HTTPException: 404 if the user does not exist.
        """
        user = self._get_user_or_404(user_id, db)

        # NOTE: there is intentionally no admin bypass; admins only see nodes
        # assigned to their group, so group assignments can be tested for
        # every role.
        node_ids = set(self._policy_node_ids(user))
        # Skip the grant lookup for users without a group: `== None` renders
        # as `IS NULL` and would otherwise match ungrouped grant rows.
        if user.group_id is not None:
            accesses = db.query(models.NodeGroupAccess).filter(
                models.NodeGroupAccess.group_id == user.group_id
            ).all()
            node_ids.update(a.node_id for a in accesses)

        if not node_ids:
            return []

        return db.query(models.AgentNode).filter(
            models.AgentNode.node_id.in_(list(node_ids)),
            models.AgentNode.is_active == True,  # noqa: E712 - SQL expression, not a Python truth test
        ).all()

    def dispatch_task(self, node_id: str, command: str, user_id: str, db: Session, session_id: str = "", task_id: Optional[str] = None, timeout_ms: int = 30000):
        """Send ``command`` to a connected node as a signed ``TaskRequest``.

        Emits a ``task_assigned`` registry event before sending and a
        ``task_start`` event after the message is queued.

        Args:
            task_id: Id to use for the task; a random UUID4 is generated when omitted.

        Returns:
            The task id.

        Raises:
            HTTPException: 404/403 from ``require_node_access``; 503 if the
                node is not currently connected.
        """
        self.require_node_access(user_id, node_id, db)
        registry = self.services.node_registry_service
        live = registry.get_node(node_id)
        if not live:
            raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.")

        t_id = task_id or str(uuid.uuid4())
        registry.emit(node_id, "task_assigned", {"command": command, "session_id": session_id}, task_id=t_id)

        task_req = agent_pb2.TaskRequest(
            task_id=t_id,
            payload_json=command,
            signature=sign_payload(command),
            timeout_ms=timeout_ms,
            session_id=session_id or "",
        )
        live.send_message(agent_pb2.ServerTaskMessage(task_request=task_req), priority=1)
        registry.emit(node_id, "task_start", {"command": command}, task_id=t_id)
        return t_id

    # --- Utilities ---

    def get_node_or_404(self, node_id: str, db: Session) -> models.AgentNode:
        """Fetch a node by id.

        Raises:
            HTTPException: 404 if no node has this id.
        """
        node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first()
        if not node:
            raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.")
        return node

    def node_to_admin_detail(self, node: models.AgentNode) -> schemas.AgentNodeAdminDetail:
        """Build the admin view of a node.

        Status and stats come from the live registry entry when the node is
        connected; otherwise ``last_status`` (or ``"offline"``) and default stats.
        """
        live = self.services.node_registry_service.get_node(node.node_id)
        status = live._compute_status() if live else (node.last_status or "offline")
        # Guard against a live entry without stats, as node_to_user_view does.
        live_stats = live.stats if live else None
        stats = schemas.AgentNodeStats(**live_stats) if live_stats else schemas.AgentNodeStats()
        return schemas.AgentNodeAdminDetail(
            node_id=node.node_id,
            display_name=node.display_name,
            description=node.description,
            skill_config=self._load_skill_config(node.skill_config),
            capabilities=node.capabilities or {},
            invite_token=node.invite_token,
            is_active=node.is_active,
            last_status=status,
            last_seen_at=node.last_seen_at,
            created_at=node.created_at,
            registered_by=node.registered_by,
            group_access=[
                schemas.NodeAccessResponse(
                    id=a.id, node_id=a.node_id, group_id=a.group_id,
                    access_level=a.access_level, granted_at=a.granted_at,
                ) for a in (node.group_access or [])
            ],
            stats=stats,
        )

    def generate_node_config_yaml(self, node: models.AgentNode, skill_overrides: Optional[dict] = None, is_windows: bool = False) -> str:
        """Render the YAML configuration file for an agent node.

        Args:
            node: Node whose id, display name, invite token and skill config
                are embedded.
            skill_overrides: Optional ``{skill: dict}`` merged over the stored
                skill config (override keys win). Non-dict values are skipped
                with a warning. ``node.skill_config`` itself is never modified.
            is_windows: Use the Windows sync/fs roots instead of the Linux ones.

        Returns:
            A comment header followed by the YAML document.
        """
        from app.config import settings
        import yaml

        hub_url = settings.GRPC_EXTERNAL_ENDPOINT or os.getenv("HUB_PUBLIC_URL", "http://127.0.0.1:8000")
        hub_grpc = settings.GRPC_TARGET_ORIGIN or os.getenv("HUB_GRPC_ENDPOINT", "127.0.0.1:50051")
        # NOTE(review): falls back to a hard-coded development secret when
        # SECRET_KEY is unset, and that value is written into every generated
        # config. Make sure SECRET_KEY is set in production.
        secret_key = os.getenv("SECRET_KEY", "dev-secret-key-1337")

        # Merge into fresh dicts so the overrides never mutate the dict held
        # by the ORM object (node.skill_config).
        skill_cfg = dict(self._load_skill_config(node.skill_config))
        for skill, cfg in (skill_overrides or {}).items():
            if not isinstance(cfg, dict):
                logger.warning("Skipping non-dict override for skill '%s': %s", skill, cfg)
                continue
            base = skill_cfg.get(skill)
            skill_cfg[skill] = {**base, **cfg} if isinstance(base, dict) else dict(cfg)

        config_data = {
            "node_id": node.node_id,
            "node_description": node.display_name,
            "hub_url": hub_url,
            "grpc_endpoint": hub_grpc,
            "invite_token": node.invite_token,
            "auth_token": node.invite_token,
            "secret_key": secret_key,
            "skills": skill_cfg,
            "sync_root": settings.AGENT_SYNC_ROOT_WINDOWS if is_windows else settings.AGENT_SYNC_ROOT_LINUX,
            "fs_root": settings.AGENT_FS_ROOT_WINDOWS if is_windows else settings.AGENT_FS_ROOT_LINUX,
            "tls": settings.GRPC_TLS_ENABLED,
        }

        header = f"# Cortex Hub - Agent Node Configuration\n# Generated for node '{node.node_id}'\n\n"
        # sort_keys=False keeps the key order of config_data in the output.
        return header + yaml.dump(config_data, sort_keys=False, default_flow_style=False)

    # Extracted from nodes.py
    def require_node_access(self, user_id: str, node_id: str, db: Session):
        """Ensure the user's group may access ``node_id``.

        Access is granted by a ``NodeGroupAccess`` row for the user's group or
        by the node id appearing in the group policy's ``"nodes"`` list.

        Returns:
            The ``User`` row when access is allowed.

        Raises:
            HTTPException: 404 if the user does not exist, 403 if access is denied.
        """
        user = self._get_user_or_404(user_id, db)

        # NOTE: no admin bypass; assignment rules are enforced for all roles.
        # Users without a group skip the grant lookup (`== None` -> `IS NULL`).
        if user.group_id is not None:
            access = db.query(models.NodeGroupAccess).filter(
                models.NodeGroupAccess.node_id == node_id,
                models.NodeGroupAccess.group_id == user.group_id,
            ).first()
            if access:
                return user

        if node_id in self._policy_node_ids(user):
            return user

        raise HTTPException(status_code=403, detail=f"Access Denied: You do not have permission to access node '{node_id}'.")

    def node_to_user_view(self, node: models.AgentNode, registry) -> schemas.AgentNodeUserView:
        """Build the end-user view of a node.

        ``available_skills`` lists the skills whose config is a dict with
        ``enabled`` truthy (missing ``enabled`` counts as enabled). Status is
        ``"offline"`` unless the node is present in ``registry``.
        """
        live = registry.get_node(node.node_id)
        status = live._compute_status() if live else "offline"

        skill_cfg = self._load_skill_config(node.skill_config)
        available = [
            skill for skill, cfg in skill_cfg.items()
            if isinstance(cfg, dict) and cfg.get("enabled", True)
        ]
        stats = live.stats if live else None

        return schemas.AgentNodeUserView(
            node_id=node.node_id,
            display_name=node.display_name,
            description=node.description,
            capabilities=node.capabilities or {},
            available_skills=available,
            last_status=status,
            last_seen_at=node.last_seen_at,
            stats=schemas.AgentNodeStats(**stats) if stats else schemas.AgentNodeStats(),
        )

    def generate_provisioning_script(self, node: models.AgentNode, config_yaml: str, base_url: str) -> str:
        """Render the Python provisioning script (``provision.py.j2``)."""
        return self._render_provision_template("provision.py.j2", node, config_yaml, base_url)

    def generate_provisioning_sh(self, node: models.AgentNode, config_yaml: str, base_url: str) -> str:
        """Render the shell provisioning script (``provision.sh.j2``)."""
        return self._render_provision_template("provision.sh.j2", node, config_yaml, base_url)

    def generate_provisioning_ps1(self, node: models.AgentNode, config_yaml: str, base_url: str, grpc_url: str = "") -> str:
        """Render the PowerShell provisioning script (``provision.ps1.j2``).

        ``grpc_url`` defaults to ``base_url`` with its ``http://`` /
        ``https://`` scheme text removed.
        """
        params = {"grpc_url": grpc_url or base_url.replace("http://", "").replace("https://", "")}
        return self._render_provision_template("provision.ps1.j2", node, config_yaml, base_url, **params)

    def _render_provision_template(self, template_name: str, node: models.AgentNode, config_yaml: str, base_url: str, **kwargs) -> str:
        """Render a provisioning template with node id, config, base URL and invite token.

        Best-effort: returns an ``"Error: ..."`` string instead of raising when
        the templates directory is missing or rendering fails.
        """
        if not self.jinja_env:
            return "Error: Templates directory not found."
        try:
            template = self.jinja_env.get_template(template_name)
            return template.render(
                node_id=node.node_id,
                config_yaml=config_yaml,
                base_url=base_url,
                invite_token=node.invite_token,
                **kwargs,
            )
        except Exception as e:
            logger.exception("Failed to render %s: %s", template_name, e)
            return f"Error: {e}"

    def get_template_content(self, filename: str) -> str:
        """Render a template without context; returns ``""`` if unavailable or broken."""
        if not self.jinja_env:
            return ""
        try:
            return self.jinja_env.get_template(filename).render()
        except Exception as e:
            # Best-effort lookup, but no longer swallows KeyboardInterrupt/SystemExit.
            logger.warning("Failed to load template %s: %s", filename, e)
            return ""

    # --- Private helpers ---

    @staticmethod
    def _get_user_or_404(user_id: str, db: Session) -> models.User:
        """Fetch a user by id, raising HTTP 404 when absent."""
        user = db.query(models.User).filter(models.User.id == user_id).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found.")
        return user

    @staticmethod
    def _policy_node_ids(user) -> list:
        """Return the node ids in the user's group policy under ``"nodes"``.

        A missing group/policy, or a ``"nodes"`` value that is not a list,
        yields ``[]`` so both access paths apply the same rule.
        """
        group = user.group
        if not group or not group.policy:
            return []
        nodes = group.policy.get("nodes", [])
        return nodes if isinstance(nodes, list) else []

    @staticmethod
    def _load_skill_config(raw) -> dict:
        """Return a stored ``skill_config`` value as a dict.

        Accepts a dict or a JSON-encoded string; ``None``/empty values,
        undecodable JSON and any non-dict result yield ``{}``.
        """
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return {}
        return raw if isinstance(raw, dict) else {}