# File: cortex-hub/ai-hub/app/core/pipelines/rag_pipeline.py
import asyncio
import json
import logging
import time
from typing import Any, AsyncIterator, Callable, Dict, List, Optional

from sqlalchemy.orm import Session

from app.db import models

# Default system prompt used by RagPipeline.forward() for every feature except
# "voice". forward() renders it with str.format(), supplying the placeholders
# {mesh_context}, {context}, {chat_history} and {question}; any literal brace
# added to this text must therefore be doubled ("{{" / "}}").
PROMPT_TEMPLATE = """You are the Cortex AI Assistant, a powerful orchestrator of decentralized agent nodes. 

## Architecture Highlights:
- You operate within a secure, gRPC-based mesh of Agent Nodes.
- You can execute shell commands, browse the web, and manage files on these nodes.
- You use 'skills' to interact with the physical world.

{mesh_context}

## Task:
Generate a natural and context-aware answer using the provided knowledge, conversation history, and available tools.

Relevant excerpts from the knowledge base:
{context}

Conversation History:
{chat_history}

User Question: {question}

Answer:"""

# Shorter system prompt selected by RagPipeline.forward() when
# feature_name == "voice". It only references {chat_history} and {question};
# the other keyword arguments forward() passes to str.format() are ignored.
VOICE_PROMPT_TEMPLATE = """You are a conversational voice assistant. 
Keep your responses short, natural, and helpful. 
Avoid using technical jargon or listing technical infrastructure details unless specifically asked.
Focus on being a friendly companion.

Conversation History:
{chat_history}

User Question: {question}

Answer:"""

class RagPipeline:
    """
    Streaming, tool-using RAG pipeline (no DSPy dependency).

    ``forward`` renders a system prompt from the knowledge-base context, the
    conversation history and an optional mesh description, streams the LLM's
    answer, and runs an agentic loop: whenever the model requests tool calls
    they are dispatched concurrently through ``tool_service`` and their
    results are fed back to the model for another turn, up to ``max_turns``
    turns.

    Context and history formatting are pluggable via the constructor;
    ``response_postprocessor`` is stored but not applied by ``forward``.
    """

    # Default cap on LLM <-> tool round trips per ``forward`` call; guards
    # against a model that keeps requesting tools forever.
    MAX_TOOL_TURNS = 5

    def __init__(
        self,
        context_postprocessor: Optional[Callable[[List[str]], str]] = None,
        history_formatter: Optional[Callable[[List["models.Message"]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        """
        Args:
            context_postprocessor: Turns the retrieved context chunks into one
                string. Defaults to joining them with blank lines.
            history_formatter: Turns prior messages into a transcript string.
                Defaults to ``Human:`` / ``Assistant:`` prefixed lines.
            response_postprocessor: Stored on the instance; ``forward`` does
                not call it.
        """
        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List["models.Message"],
        llm_provider=None,
        prompt_service=None,
        tool_service=None,
        tools: Optional[List[Dict[str, Any]]] = None,
        mesh_context: str = "",
        db: Optional["Session"] = None,
        user_id: Optional[str] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline",
        *,
        max_turns: int = MAX_TOOL_TURNS,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Answer *question*, streaming events as they are produced.

        This is an async generator. Every yielded item is a dict with a
        ``type`` key:

        * ``"reasoning"`` / ``"content"`` - streamed model tokens (``content``).
        * ``"status"`` - progress text, including a wait heartbeat emitted
          about once per second while tools run.
        * ``"tool_start"`` - a tool is being invoked (``name``, ``args``).
        * ``"tool_result"`` - a tool finished (``name``, ``result``).
        * ``"error"`` - emitted once when all ``max_turns`` turns requested
          tools, i.e. no final answer was produced.

        Args:
            question: The user's question; substituted into the system prompt
                and also sent as the user message.
            context_chunks: Retrieved chunks, passed to ``context_postprocessor``.
            history: Prior messages, passed to ``history_formatter``.
            llm_provider: Object whose awaitable
                ``acompletion(messages=..., **kwargs)`` returns an async
                iterator of stream chunks. Required.
            prompt_service: When given together with ``db`` and ``user_id``,
                ``get_prompt_by_slug(db, prompt_slug, user_id)`` may supply a
                replacement template via its ``content`` attribute.
            tool_service: ``call_tool(name, args, db=..., user_id=...)`` is
                awaited for each requested tool. Without it every tool call
                produces an error result.
            tools: Tool schemas sent to the LLM with ``tool_choice="auto"``.
            mesh_context: Text substituted for ``{mesh_context}``.
            db: Forwarded to ``prompt_service`` and ``tool_service``.
            user_id: Forwarded to ``prompt_service`` and ``tool_service``.
            feature_name: ``"voice"`` selects ``VOICE_PROMPT_TEMPLATE``.
            prompt_slug: Slug looked up through ``prompt_service``.
            max_turns: Maximum number of LLM turns (each may trigger one
                round of tool calls).

        Raises:
            ValueError: If ``llm_provider`` is missing (raised on the first
                iteration, because this is a generator).
        """
        logging.debug(f"[RagPipeline.forward] Received question: '{question}'")

        if not llm_provider:
            raise ValueError("LLM Provider is required.")

        history_text = self.history_formatter(history)
        context_text = self.context_postprocessor(context_chunks)

        default_template = VOICE_PROMPT_TEMPLATE if feature_name == "voice" else PROMPT_TEMPLATE
        template = default_template
        if prompt_service and db and user_id:
            db_prompt = prompt_service.get_prompt_by_slug(db, prompt_slug, user_id)
            if db_prompt:
                template = db_prompt.content

        system_prompt = self._render_system_prompt(
            template,
            default_template,
            {
                "question": question,
                "context": context_text,
                "chat_history": history_text,
                "mesh_context": mesh_context,
            },
        )

        # 1. Prepare initial messages
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        model = getattr(llm_provider, "model_name", "unknown")

        # 2. Agentic tool loop, capped at ``max_turns`` to prevent infinite loops.
        for turn in range(max_turns):
            request_kwargs: Dict[str, Any] = {"stream": True}
            if tools:
                request_kwargs["tools"] = tools
                request_kwargs["tool_choice"] = "auto"

            total_chars = sum(self._content_length(m) for m in messages)
            logging.info(
                f"[RagPipeline] Turn {turn+1} starting (STREAMING). Model: {model}, "
                f"Messages: {len(messages)}, Total Chars: {total_chars}"
            )

            # LiteLLM streaming call
            stream = await llm_provider.acompletion(messages=messages, **request_kwargs)

            content_parts: List[str] = []
            reasoning_parts: List[str] = []
            tool_calls: Dict[Any, Dict[str, Any]] = {}  # stream index -> accumulated call

            async for chunk in stream:
                choices = self._field(chunk, "choices")
                if not choices:
                    continue
                delta = self._field(choices[0], "delta")
                if delta is None:
                    continue

                # A. Reasoning ("thinking") tokens; some OpenAI-compatible
                #    providers (e.g. DeepSeek) stream them as ``reasoning_content``.
                reasoning = self._field(delta, "reasoning_content")
                if reasoning:
                    reasoning_parts.append(reasoning)
                    yield {"type": "reasoning", "content": reasoning}

                # B. Answer tokens.
                content = self._field(delta, "content")
                if content:
                    content_parts.append(content)
                    yield {"type": "content", "content": content}

                # C. Tool-call fragments; a call's arguments arrive split
                #    across several chunks and are stitched together here.
                for position, tc_delta in enumerate(self._field(delta, "tool_calls") or ()):
                    self._merge_tool_call_delta(tool_calls, tc_delta, position)

            # No tool requested: the streamed content was the final answer.
            if not tool_calls:
                return

            # The tool result message must reference the call id; synthesize
            # one if the provider did not send it.
            for idx, call in tool_calls.items():
                if not call["id"]:
                    call["id"] = f"call_{turn}_{idx}"

            processed_tool_calls = list(tool_calls.values())

            # Echo the assistant turn (with its tool calls) back into history.
            assistant_msg: Dict[str, Any] = {
                "role": "assistant",
                "content": "".join(content_parts) or None,
                "tool_calls": processed_tool_calls,
            }
            reasoning_text = "".join(reasoning_parts)
            if reasoning_text:
                assistant_msg["reasoning_content"] = reasoning_text
            messages.append(assistant_msg)

            # 3. Run all requested tools concurrently and feed results back.
            tool_events = self._execute_tool_calls(processed_tool_calls, tool_service, db, user_id, messages)
            try:
                async for event in tool_events:
                    yield event
            finally:
                # Close the helper deterministically so its cleanup
                # (cancelling still-running tool tasks) runs even when our
                # own consumer stops iterating early.
                await tool_events.aclose()

        # --- Loop finished without return ---
        yield {
            "type": "error",
            "content": f"Agent loop reached maximum turns ({max_turns}) without a final response.",
        }

    async def _execute_tool_calls(
        self,
        calls: List[Dict[str, Any]],
        tool_service: Any,
        db: Optional["Session"],
        user_id: Optional[str],
        messages: List[Dict[str, Any]],
    ) -> AsyncIterator[Dict[str, Any]]:
        """Dispatch *calls* in parallel, stream progress events, record results.

        Yields ``status``/``tool_start`` events while dispatching, a
        ``status`` heartbeat about once per second while any call is still
        running, then one ``tool_result`` per call in request order. A
        ``role: tool`` message is appended to *messages* for every call.
        Exceptions raised by a tool become ``{"success": False, ...}``
        results; tasks still running when iteration stops are cancelled.
        """
        tasks: List[asyncio.Task] = []
        try:
            # A. Dispatch all tool calls simultaneously.
            for call in calls:
                func_name = call["function"]["name"]
                func_args = self._parse_tool_arguments(call["function"]["arguments"], func_name)

                # --- M7 Parallel PTY Optimization ---
                # If the tool is terminal control and no session is provided,
                # use a unique session ID per SUBAGENT task to avoid PTY SERIALIZATION.
                if func_name == "mesh_terminal_control" and "session_id" not in func_args:
                    func_args["session_id"] = f"subagent-{str(call['id'])[:8]}"

                yield {"type": "status", "content": f"AI decided to use tool: {func_name}"}
                logging.info(f"[🔧] Agent calling tool (PARALLEL): {func_name} with {func_args}")

                if tool_service:
                    # Notify UI about tool execution start
                    yield {"type": "tool_start", "name": func_name, "args": func_args}
                    coro = tool_service.call_tool(func_name, func_args, db=db, user_id=user_id)
                else:
                    # Treat as failure immediately if no service
                    coro = asyncio.sleep(0, result={"success": False, "error": "Tool service not available"})
                # Always wrap in a Task: the heartbeat below needs ``.done()``,
                # which a bare coroutine does not have.
                tasks.append(asyncio.create_task(coro))

            # B. HEARTBEAT WAIT: show the elapsed wait time in chat until every
            #    task finishes; asyncio.wait returns as soon as they are all
            #    done instead of always sleeping a full second.
            wait_start = time.monotonic()
            pending = {t for t in tasks if not t.done()}
            while pending:
                elapsed = int(time.monotonic() - wait_start)
                yield {"type": "status", "content": f"Waiting for nodes result... ({elapsed}s)"}
                _, pending = await asyncio.wait(pending, timeout=1)

            # C. Collect results (request order) and populate the history turn.
            for call, task in zip(calls, tasks):
                func_name = call["function"]["name"]
                result = self._task_result(task, func_name)

                # Stream the result back so UI can see "behind the scenes"
                yield {"type": "tool_result", "name": func_name, "result": result}

                messages.append({
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "name": func_name,
                    "content": self._serialize_tool_result(result),
                })
        finally:
            # Don't leave orphaned tool calls running if we exit early.
            for task in tasks:
                if not task.done():
                    task.cancel()

    @staticmethod
    def _field(obj: Any, key: str, default: Any = None) -> Any:
        """Read *key* from a dict or from an attribute-style object (e.g. LiteLLM stream types)."""
        if isinstance(obj, dict):
            return obj.get(key, default)
        return getattr(obj, key, default)

    @classmethod
    def _content_length(cls, message: Any) -> int:
        """Return the character count of a message's ``content`` (0 if absent); used for logging only."""
        content = cls._field(message, "content") or ""
        return len(content) if isinstance(content, str) else len(str(content))

    @classmethod
    def _merge_tool_call_delta(
        cls,
        calls: Dict[Any, Dict[str, Any]],
        tc_delta: Any,
        fallback_index: int,
    ) -> None:
        """Fold one streamed tool-call fragment into *calls*, keyed by stream index.

        The first non-empty ``id`` and function ``name`` seen for an index are
        kept; ``arguments`` fragments are concatenated in arrival order. Calls
        are stored as plain OpenAI-format dicts so they can be sent back to
        the model in the assistant message. *fallback_index* (the fragment's
        position in its chunk) is used when the fragment has no ``index``.
        """
        idx = cls._field(tc_delta, "index")
        if idx is None:
            idx = fallback_index
        call = calls.setdefault(
            idx,
            {"id": None, "type": "function", "function": {"name": "", "arguments": ""}},
        )
        if not call["id"]:
            call["id"] = cls._field(tc_delta, "id")
        call_type = cls._field(tc_delta, "type")
        if call_type:
            call["type"] = call_type

        function = cls._field(tc_delta, "function")
        if function is None:
            return
        name = cls._field(function, "name")
        if name and not call["function"]["name"]:
            call["function"]["name"] = name
        arguments = cls._field(function, "arguments")
        if arguments:
            call["function"]["arguments"] += arguments

    @staticmethod
    def _parse_tool_arguments(raw: Optional[str], func_name: str) -> Dict[str, Any]:
        """Decode a tool call's JSON ``arguments`` string into a dict.

        Empty or missing arguments give ``{}``. Malformed JSON or a non-object
        payload is logged and also gives ``{}``, so the tool is still invoked
        rather than failing the whole turn.
        """
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            logging.warning(f"[RagPipeline] Could not decode arguments for tool '{func_name}': {raw!r}")
            return {}
        if not isinstance(parsed, dict):
            logging.warning(f"[RagPipeline] Ignoring non-object arguments for tool '{func_name}': {raw!r}")
            return {}
        return parsed

    @staticmethod
    def _task_result(task: "asyncio.Task", func_name: str) -> Any:
        """Return a finished tool task's result, converting failures into an error dict."""
        try:
            return task.result()
        except asyncio.CancelledError:
            return {"success": False, "error": "Tool call was cancelled"}
        except Exception as exc:  # one failing tool must not abort the whole agent turn
            logging.exception(f"[RagPipeline] Tool '{func_name}' raised an exception")
            return {"success": False, "error": f"{type(exc).__name__}: {exc}"}

    @staticmethod
    def _serialize_tool_result(result: Any) -> str:
        """Render a tool result as the string content of a ``role: tool`` message."""
        if not isinstance(result, dict):
            return str(result)
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            # Values json can't encode (datetimes, bytes, odd keys): still give
            # the model something readable instead of failing the turn.
            return str(result)

    @staticmethod
    def _render_system_prompt(template: str, default_template: str, fields: Dict[str, str]) -> str:
        """Fill *template*'s ``str.format`` placeholders from *fields*.

        *template* may come from the database via ``prompt_service`` and can
        contain unknown placeholders or unescaped braces; in that case the
        problem is logged and *default_template* is rendered instead. Errors in
        the default template itself propagate.
        """
        try:
            return template.format(**fields)
        except (KeyError, IndexError, ValueError, AttributeError) as exc:
            if template == default_template:
                raise
            logging.warning(
                f"[RagPipeline] Stored prompt could not be rendered ({exc!r}); using built-in template."
            )
            return default_template.format(**fields)

    def _build_prompt(self, context, history, question):
        """Build a plain-text prompt from context, history and question (not used by ``forward``)."""
        return f"""Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history.

Relevant excerpts from the knowledge base:
{context}

Conversation History:
{history}

User Question: {question}

Answer:"""

    # Default context processor: concatenate chunks
    def _default_context_postprocessor(self, contexts: List[str]) -> str:
        """Join context chunks with blank lines, or return a placeholder when there is no text.

        ``forward`` annotates its chunks as dicts, so every chunk is passed
        through ``str()`` instead of letting ``str.join`` raise ``TypeError``.
        """
        if not contexts:
            return "No context provided."
        return "\n\n".join(map(str, contexts)) or "No context provided."

    # Default history formatter: simple speaker prefix
    def _default_history_formatter(self, history: List["models.Message"]) -> str:
        """Render each message as ``Human: ...`` (sender ``'user'``) or ``Assistant: ...``, one per line."""
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )