# cortex-hub / ai-hub / app / api / routes / agents.py
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import List
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.db.models.agent import AgentTemplate, AgentInstance, AgentTrigger
from app.db.models import Message
from app.api.schemas import (
    AgentTemplateCreate, AgentTemplateResponse,
    AgentInstanceCreate, AgentInstanceResponse, AgentInstanceStatusUpdate
)
import uuid
import json
import os
import logging
from app.core.orchestration.agent_loop import AgentExecutor

from sqlalchemy.orm import joinedload

def create_agents_router(services: ServiceContainer) -> APIRouter:
    """
    Build the router exposing agent template / instance / trigger endpoints.

    Every route below is a closure over ``services``; they use
    ``services.orchestrator`` (optional), ``services.session_service``,
    ``services.rag_service`` and ``services.user_service``.
    """
    router = APIRouter()

    def _workspace_id_from_jail(jail_path: str | None, fallback_session_id: int | None = None) -> str:
        """
        Derive a workspace ID for an agent.

        Returns the last path component of ``jail_path`` (trailing slashes
        ignored) when there is one, else ``session-<fallback_session_id>``, and
        as a last resort a random ``agent-<8 hex chars>`` ID (not stable across
        calls).
        """
        if jail_path:
            base = os.path.basename(jail_path.rstrip("/"))
            if base:
                return base
        if fallback_session_id is not None:
            return f"session-{fallback_session_id}"
        return f"agent-{uuid.uuid4().hex[:8]}"

    def _start_workspace_sync(node_id: str | None, workspace_id: str) -> None:
        """
        Push ``workspace_id`` to mesh node ``node_id`` and START + UNLOCK its sync.

        No-op when no orchestrator is configured or no node is given.
        Orchestrator errors propagate; callers decide how to report them.
        """
        orchestrator = getattr(services, "orchestrator", None)
        if not orchestrator or not node_id:
            return
        orchestrator.assistant.push_workspace(node_id, workspace_id)
        orchestrator.assistant.control_sync(node_id, workspace_id, action="START")
        orchestrator.assistant.control_sync(node_id, workspace_id, action="UNLOCK")

    def _ensure_agent_workspace_binding(instance: AgentInstance, db: Session) -> bool:
        """
        Keep the session's sync workspace and the agent's jail path aligned.

        Heals legacy records where ``sync_workspace_id`` is null/mismatched.
        On change, the session is flushed (not committed) and workspace sync is
        restarted on the agent's mesh node on a best-effort basis.

        Returns:
            True if the instance or its session was modified, otherwise False.
        """
        if not instance or not instance.session:
            return False

        desired_workspace_id = _workspace_id_from_jail(instance.current_workspace_jail, instance.session_id)
        desired_jail = f"/tmp/cortex/{desired_workspace_id}/"

        changed = False
        if instance.session.sync_workspace_id != desired_workspace_id:
            instance.session.sync_workspace_id = desired_workspace_id
            changed = True
        if instance.current_workspace_jail != desired_jail:
            instance.current_workspace_jail = desired_jail
            changed = True

        if changed:
            db.flush()
            try:
                _start_workspace_sync(instance.mesh_node_id, desired_workspace_id)
            except Exception as e:
                # Best-effort: a sync failure must not break listing agents.
                logging.error(f"Failed to heal workspace binding for agent {instance.id}: {e}")
        return changed

    @router.get("", response_model=List[AgentInstanceResponse])
    def get_agents(db: Session = Depends(get_db)):
        """List all agent instances, healing workspace bindings (committed only if something changed)."""
        agents = db.query(AgentInstance).options(joinedload(AgentInstance.template), joinedload(AgentInstance.session)).all()
        changed = False
        for instance in agents:
            # `|=` (not `or`) so every instance is healed, not just those before the first change.
            changed |= _ensure_agent_workspace_binding(instance, db)
        if changed:
            db.commit()
        return agents

    @router.post("/templates", response_model=AgentTemplateResponse)
    def create_template(request: AgentTemplateCreate, db: Session = Depends(get_db)):
        """Create and return a new agent template."""
        template = AgentTemplate(**request.model_dump())
        db.add(template)
        db.commit()
        db.refresh(template)
        return template

    @router.post("/instances", response_model=AgentInstanceResponse)
    def create_instance(request: AgentInstanceCreate, db: Session = Depends(get_db)):
        """Create an agent instance from an existing template (404 if the template is missing)."""
        template = db.query(AgentTemplate).filter(AgentTemplate.id == request.template_id).first()
        if not template:
            raise HTTPException(status_code=404, detail="Template not found")

        instance = AgentInstance(**request.model_dump())
        db.add(instance)
        db.commit()
        db.refresh(instance)
        return instance

    @router.patch("/{id}/status", response_model=AgentInstanceResponse)
    def update_status(id: str, request: AgentInstanceStatusUpdate, db: Session = Depends(get_db)):
        """Set an agent instance's status field."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        instance.status = request.status
        db.commit()
        db.refresh(instance)
        return instance

    @router.patch("/{id}/config", response_model=AgentInstanceResponse)
    def update_config(id: str, request: schemas.AgentConfigUpdate, db: Session = Depends(get_db)):
        """
        Update an agent's template, mesh node and session settings.

        Name / system prompt / loop limit are written to the template row, which
        other instances created from the same template also reference. Session
        fields are written to the agent's session so the running loop picks
        them up.
        """
        from app.db.models.session import Session as SessionModel

        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        template = db.query(AgentTemplate).filter(AgentTemplate.id == instance.template_id).first()
        if template:
            if request.name is not None:
                template.name = request.name
            if request.system_prompt is not None:
                template.system_prompt_path = request.system_prompt
            if request.max_loop_iterations is not None:
                template.max_loop_iterations = request.max_loop_iterations

        if request.mesh_node_id is not None:
            instance.mesh_node_id = request.mesh_node_id

        # Update the Session overriding prompt so the running loop picks it up instantly!
        if instance.session_id:
            session = db.query(SessionModel).filter(SessionModel.id == instance.session_id).first()
            if session:
                if request.system_prompt is not None:
                    session.system_prompt_override = request.system_prompt
                if hasattr(request, 'provider_name') and request.provider_name is not None:
                    session.provider_name = request.provider_name
                if request.mesh_node_id is not None:
                    old_nodes = session.attached_node_ids or []
                    # Re-attach through the session service unless the session is
                    # already bound to exactly this one node.
                    if not old_nodes or request.mesh_node_id not in old_nodes or len(old_nodes) > 1:
                        try:
                            services.session_service.attach_nodes(db, session.id, schemas.NodeAttachRequest(node_ids=[request.mesh_node_id] if request.mesh_node_id else []))
                        except Exception as e:
                            logging.error(f"Failed to attach session node: {e}")
                    else:
                        session.attached_node_ids = [request.mesh_node_id] if request.mesh_node_id else []
                if hasattr(request, 'restrict_skills') and request.restrict_skills is not None:
                    session.restrict_skills = request.restrict_skills
                if hasattr(request, 'allowed_skill_ids') and request.allowed_skill_ids is not None:
                    from app.db.models.asset import Skill
                    skills = db.query(Skill).filter(Skill.id.in_(request.allowed_skill_ids)).all()
                    session.skills = skills
                if hasattr(request, 'is_locked') and request.is_locked is not None:
                    session.is_locked = request.is_locked
                if hasattr(request, 'auto_clear_history') and request.auto_clear_history is not None:
                    session.auto_clear_history = request.auto_clear_history

        db.commit()
        db.refresh(instance)
        return instance

    @router.post("/{id}/webhook", status_code=202)
    def webhook_receiver(id: str, payload: dict, background_tasks: BackgroundTasks, token: str = None, db: Session = Depends(get_db)):
        """
        Accept an external webhook event and run the agent on it in the background.

        ``token`` must equal the ``webhook_secret`` of one of this instance's
        webhook triggers; otherwise the request is rejected with 401. Without this
        check, any caller could inject arbitrary prompts into the agent loop.
        """
        import secrets

        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        webhook_triggers = db.query(AgentTrigger).filter(
            AgentTrigger.instance_id == id,
            AgentTrigger.trigger_type == "webhook",
        ).all()
        # Constant-time compare on bytes (compare_digest rejects non-ASCII str input).
        token_ok = bool(token) and any(
            t.webhook_secret and secrets.compare_digest(token.encode(), t.webhook_secret.encode())
            for t in webhook_triggers
        )
        if not token_ok:
            raise HTTPException(status_code=401, detail="Invalid or missing webhook token")

        # Pass webhook event directly to the Agent Executor to process
        prompt = f"Webhook Event: {json.dumps(payload)}"
        background_tasks.add_task(AgentExecutor.run, instance.id, prompt, services.rag_service, services.user_service)
        return {"message": "Accepted"}

    @router.post("/{id}/run", status_code=202)
    def manual_trigger(id: str, payload: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
        """Run the agent once in the background with ``payload["prompt"]`` (or a default prompt)."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        prompt = payload.get("prompt") or f"Manual triggered execution for agent {id}."
        background_tasks.add_task(AgentExecutor.run, instance.id, prompt, services.rag_service, services.user_service)
        return {"message": "Accepted"}

    @router.get("/{id}/triggers", response_model=List[schemas.AgentTriggerResponse])
    def get_agent_triggers(id: str, db: Session = Depends(get_db)):
        """List the triggers attached to an agent instance."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")
        return db.query(AgentTrigger).filter(AgentTrigger.instance_id == id).all()

    @router.post("/{id}/triggers", response_model=schemas.AgentTriggerResponse)
    def create_agent_trigger(id: str, request: schemas.AgentTriggerCreate, db: Session = Depends(get_db)):
        """
        Attach a trigger to an agent instance (404 if the instance is missing).

        Webhook triggers without an explicit secret get a random one.
        """
        import secrets

        # Reject unknown instances up front instead of inserting an orphan trigger.
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        trigger = AgentTrigger(**request.model_dump())
        trigger.instance_id = id  # path parameter wins over any instance_id in the body

        if trigger.trigger_type == "webhook" and not trigger.webhook_secret:
            trigger.webhook_secret = secrets.token_hex(16)

        db.add(trigger)
        db.commit()
        db.refresh(trigger)
        return trigger

    @router.delete("/triggers/{trigger_id}")
    def delete_agent_trigger(trigger_id: str, db: Session = Depends(get_db)):
        """Delete a single trigger by id."""
        trigger = db.query(AgentTrigger).filter(AgentTrigger.id == trigger_id).first()
        if not trigger:
            raise HTTPException(status_code=404, detail="Trigger not found")
        db.delete(trigger)
        db.commit()
        return {"message": "Trigger deleted successfully"}

    @router.post("/{id}/metrics/reset")
    def reset_agent_metrics(id: str, db: Session = Depends(get_db)):
        """Zero an agent instance's run/token/time counters and tool-call counts."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")

        instance.total_runs = 0
        instance.successful_runs = 0
        instance.total_tokens_accumulated = 0
        instance.total_running_time_seconds = 0.0
        # Assign a new dict (rather than clearing the existing one in place) so the
        # ORM registers the change even if the column doesn't track in-place mutation.
        instance.tool_call_counts = {}

        db.commit()
        return {"message": "Metrics reset successfully"}

    @router.get("/{id}/telemetry")
    def get_telemetry(id: str, db: Session = Depends(get_db)):
        """Return placeholder resource telemetry for an existing agent instance."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")
        # For MVP/Area 3, return mock telemetry data (e.g. baseline or from cgroup)
        # Real cgroup-based metrics will come in Phase 2
        return {
            "cpu_usage": 2.5,
            "memory_usage": 512,
            "network_tx": 120,
            "network_rx": 450
        }

    @router.get("/{id}/dependencies")
    def get_dependencies(id: str, db: Session = Depends(get_db)):
        """Return the (currently always empty) dependency graph of an agent instance."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Instance not found")
        return {
            "dependencies": [],
            "edges": []
        }

    @router.post("/deploy")
    def deploy_agent(
        request: schemas.DeployAgentRequest,
        background_tasks: BackgroundTasks,
        db: Session = Depends(get_db)
    ):
        """
        One-click agent deployment (Design Doc CUJ 1).

        In a single transaction creates: Template → locked Session → Instance →
        primary Trigger. After the commit it bootstraps workspace sync on the
        mesh node (best-effort) and, if an initial prompt was supplied, starts
        the agent loop as a background task.
        """
        from app.db import models as db_models
        import secrets

        mesh_node_id = getattr(request, "mesh_node_id", None)

        # 1. Create Template
        template = AgentTemplate(
            name=request.name,
            description=request.description,
            system_prompt_path=request.system_prompt,
            max_loop_iterations=request.max_loop_iterations
        )
        db.add(template)
        db.flush()  # assigns template.id, needed for the workspace id below

        # Resolve default provider mapping if user didn't select one
        resolved_provider = request.provider_name
        if not resolved_provider:
            sys_prefs = services.user_service.get_system_settings(db) or {}
            # `or {}` also tolerates an explicit null "llm" section.
            llm_prefs = sys_prefs.get('llm') or {}
            resolved_provider = llm_prefs.get('default_provider') or 'gemini'

        # 2. Create a locked Session for the agent, bound to the agent's workspace
        workspace_id = f"agent_{template.id[:8]}"
        workspace_jail = f"/tmp/cortex/{workspace_id}/"
        new_session = db_models.Session(
            user_id="agent-system",
            provider_name=resolved_provider,
            feature_name="agent_harness",
            is_locked=True,
            system_prompt_override=request.system_prompt,
            attached_node_ids=[mesh_node_id] if mesh_node_id else []
        )
        new_session.sync_workspace_id = workspace_id
        db.add(new_session)
        db.flush()

        # 3. Create AgentInstance
        instance = AgentInstance(
            template_id=template.id,
            session_id=new_session.id,
            mesh_node_id=mesh_node_id,
            status="active" if request.initial_prompt else "idle",
            current_workspace_jail=workspace_jail
        )
        db.add(instance)
        db.flush()

        # 4. Create primary trigger
        trigger = AgentTrigger(
            instance_id=instance.id,
            trigger_type=request.trigger_type or "manual",
            cron_expression=request.cron_expression,
            interval_seconds=request.interval_seconds,
            default_prompt=request.default_prompt
        )
        if trigger.trigger_type == "webhook":
            trigger.webhook_secret = secrets.token_hex(16)
        db.add(trigger)
        db.flush()

        # Snapshot plain values before commit: commit expires ORM attributes, and the
        # background task must not depend on lazily reloading a session-bound object.
        instance_id = instance.id
        result = {
            "template_id": template.id,
            "template_name": template.name,
            "instance_id": instance_id,
            "session_id": new_session.id,
            "status": instance.status,
            "workspace_jail": workspace_jail,
            "message": f"Agent '{request.name}' deployed successfully"
        }
        db.commit()

        # 4.5 Bootstrap orchestrator sync only once the DB rows exist, so a failed
        # transaction never leaves an orphaned workspace on the node.
        try:
            _start_workspace_sync(mesh_node_id, workspace_id)
        except Exception as e:
            logging.error(f"Failed to bootstrap Orchestrator Sync for Agent Deploy: {e}")

        # 5. Kick off agent loop if initial prompt was provided
        # (Message insertion is handled automatically by the RAG service execution)
        if request.initial_prompt:
            background_tasks.add_task(AgentExecutor.run, instance_id, request.initial_prompt, services.rag_service, services.user_service)

        return result

    @router.delete("/{id}")
    def delete_agent(id: str, db: Session = Depends(get_db)):
        """Delete an agent instance; a loop that is still running finds no instance and returns."""
        instance = db.query(AgentInstance).filter(AgentInstance.id == id).first()
        if not instance:
            raise HTTPException(status_code=404, detail="Agent not found")

        # NOTE(review): the instance's triggers and session are not deleted here;
        # confirm the ORM/DB cascade configuration handles them.
        db.delete(instance)
        db.commit()
        return {"message": "Agent deleted successfully"}

    return router