import asyncio
import json
import logging
import time
from typing import List, Dict, Any, Optional, Callable
from sqlalchemy.orm import Session
from app.db import models
# Default prompt templates, defined at module level so they can be overridden
# by a user-specific prompt loaded through prompt_service.
PROMPT_TEMPLATE = """You are the Cortex AI Assistant, a powerful orchestrator of decentralized agent nodes.
## Architecture Highlights:
- You operate within a secure, gRPC-based mesh of Agent Nodes.
- You can execute shell commands, browse the web, and manage files on these nodes.
- You use 'skills' to interact with the physical world.
{mesh_context}
## Task:
Generate a natural and context-aware answer using the provided knowledge, conversation history, and available tools.
Relevant excerpts from the knowledge base:
{context}
Conversation History:
{chat_history}
User Question: {question}
Answer:"""
VOICE_PROMPT_TEMPLATE = """You are a conversational voice assistant.
Keep your responses short, natural, and helpful.
Avoid using technical jargon or listing technical infrastructure details unless specifically asked.
Focus on being a friendly companion.
Conversation History:
{chat_history}
User Question: {question}
Answer:"""
class RagPipeline:
"""
A flexible, extensible RAG pipeline with a streaming agentic tool loop (no DSPy dependency).
"""
def __init__(
self,
context_postprocessor: Optional[Callable[[List[Any]], str]] = None,
history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
response_postprocessor: Optional[Callable[[str], str]] = None,
):
self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
self.history_formatter = history_formatter or self._default_history_formatter
self.response_postprocessor = response_postprocessor
async def forward(
self,
question: str,
context_chunks: List[Dict[str, Any]],
history: List[models.Message],
llm_provider=None,
prompt_service=None,
tool_service=None,
tools: Optional[List[Dict[str, Any]]] = None,
mesh_context: str = "",
db: Optional[Session] = None,
user_id: Optional[str] = None,
feature_name: str = "chat",
prompt_slug: str = "rag-pipeline"
):
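"""
Run one retrieval-augmented generation pass as an async generator.
Yields event dicts for the caller/UI, each with a "type" of "reasoning",
"content", "status", "tool_start", "tool_result", or "error", plus the
associated payload fields.
"""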
logging.debug(f"[RagPipeline.forward] Received question: '{question}'")
if not llm_provider:
raise ValueError("LLM Provider is required.")
history_text = self.history_formatter(history)
context_text = self.context_postprocessor(context_chunks)
template = PROMPT_TEMPLATE
if feature_name == "voice":
template = VOICE_PROMPT_TEMPLATE
if prompt_service and db and user_id:
db_prompt = prompt_service.get_prompt_by_slug(db, prompt_slug, user_id)
if db_prompt:
template = db_prompt.content
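# Custom templates should reuse the placeholders below ({context}, {chat_history},
# {question}, {mesh_context}); unused keyword arguments are ignored by str.format(),
# but unknown placeholders or stray literal braces will raise.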
system_prompt = template.format(
question=question,
context=context_text,
chat_history=history_text,
mesh_context=mesh_context
)
# 1. Prepare initial messages
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
]
# 2. Agentic Tool Loop (Max 5 turns to prevent infinite loops)
for turn in range(5):
request_kwargs = {"stream": True}
if tools:
request_kwargs["tools"] = tools
request_kwargs["tool_choice"] = "auto"
model = getattr(llm_provider, "model_name", "unknown")
msg_lens = []
for m in messages:
content = ""
if hasattr(m, "content") and m.content is not None:
content = m.content
elif isinstance(m, dict):
content = m.get("content") or ""
msg_lens.append(len(content))
total_chars = sum(msg_lens)
logging.info(f"[RagPipeline] Turn {turn+1} starting (STREAMING). Model: {model}, Messages: {len(messages)}, Total Chars: {total_chars}")
# LiteLLM streaming call
prediction = await llm_provider.acompletion(messages=messages, **request_kwargs)
accumulated_content = ""
accumulated_reasoning = ""
tool_calls_map = {} # index -> tc object
async for chunk in prediction:
if not chunk.choices: continue
delta = chunk.choices[0].delta
# A. Handle Reasoning (Thinking)
# Some models use 'reasoning_content' (OpenAI-compatible / DeepSeek)
reasoning = getattr(delta, "reasoning_content", None) or (delta.get("reasoning_content") if isinstance(delta, dict) else None)
if reasoning:
accumulated_reasoning += reasoning
yield {"type": "reasoning", "content": reasoning}
# B. Handle Content
content = getattr(delta, "content", None) or (delta.get("content") if isinstance(delta, dict) else None)
if content:
accumulated_content += content
yield {"type": "content", "content": content}
# C. Handle Tool Calls
tool_calls = getattr(delta, "tool_calls", None) or (delta.get("tool_calls") if isinstance(delta, dict) else None)
if tool_calls:
for tc_delta in tool_calls:
idx = tc_delta.index
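# In OpenAI-style streaming the first delta for an index carries the call id and
# function name; later deltas only append argument fragments.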
if idx not in tool_calls_map:
tool_calls_map[idx] = tc_delta
else:
# Accumulate arguments
if tc_delta.function.arguments:
tool_calls_map[idx].function.arguments += tc_delta.function.arguments
# Process completed turn
if not tool_calls_map:
# If no tools, this is the final answer for this forward pass.
return
# 3. Parallel dispatch logic for tools
processed_tool_calls = list(tool_calls_map.values())
# Reconstruct the tool call list and message object for the next turn
assistant_msg = {
"role": "assistant",
"content": accumulated_content or None,
"tool_calls": processed_tool_calls
}
if accumulated_reasoning:
assistant_msg["reasoning_content"] = accumulated_reasoning
messages.append(assistant_msg)
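# The assistant message carrying tool_calls must precede the matching "tool" role
# results appended after execution, per the OpenAI-style tool-calling format.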
# A. Dispatch all tool calls simultaneously
tool_tasks = []
for tc in processed_tool_calls:
func_name = tc.function.name
func_args = {}
try:
func_args = json.loads(tc.function.arguments or "{}")
except (json.JSONDecodeError, TypeError):
logging.warning(f"[RagPipeline] Could not parse arguments for {func_name}: {tc.function.arguments!r}")
# --- M7 Parallel PTY Optimization ---
# If the tool is terminal control and no session is provided,
# use a unique session ID per SUBAGENT task to avoid PTY SERIALIZATION.
if func_name == "mesh_terminal_control" and "session_id" not in func_args:
func_args["session_id"] = f"subagent-{tc.id[:8]}"
yield {"type": "status", "content": f"AI decided to use tool: {func_name}"}
logging.info(f"[🔧] Agent calling tool (PARALLEL): {func_name} with {func_args}")
if tool_service:
# Notify UI about tool execution start
yield {"type": "tool_start", "name": func_name, "args": func_args}
# Create an async task for each tool call
tool_tasks.append(asyncio.create_task(
tool_service.call_tool(func_name, func_args, db=db, user_id=user_id)
))
else:
# Treat as failure immediately if no service
# Wrap in a real Task so the .done() polling below also works for this placeholder result
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result={"success": False, "error": "Tool service not available"})))
# B. HEARTBEAT WAIT: Wait for all sub-agent tasks to fulfill in parallel
wait_start = time.time()
if tool_tasks:
while not all(t.done() for t in tool_tasks):
elapsed = int(time.time() - wait_start)
# Surface the elapsed wait time in the chat stream so the user can see progress.
yield {"type": "status", "content": f"Waiting for nodes result... ({elapsed}s)"}
await asyncio.sleep(1)
# C. Collect results and populate history turn
for i, task in enumerate(tool_tasks):
tc = processed_tool_calls[i]
func_name = tc.function.name
result = await task
# Stream the result back so UI can see "behind the scenes"
yield {"type": "tool_result", "name": func_name, "result": result}
messages.append({
"role": "tool",
"tool_call_id": tc.id,
"name": func_name,
"content": json.dumps(result) if isinstance(result, dict) else str(result)
})
# --- Loop finished without return ---
yield {"type": "error", "content": "Agent loop reached maximum turns (5) without a final response."}
def _build_prompt(self, context, history, question):
return f"""Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history.
Relevant excerpts from the knowledge base:
{context}
Conversation History:
{history}
User Question: {question}
Answer:"""
# Default context processor: concatenate chunks. forward() passes chunk dicts, so an
# assumed "content"/"text" key is extracted when present; plain strings pass through.
def _default_context_postprocessor(self, contexts: List[Any]) -> str:
texts = [(c.get("content") or c.get("text") or str(c)) if isinstance(c, dict) else str(c) for c in contexts]
return "\n\n".join(texts) or "No context provided."
# Default history formatter: simple speaker prefix
def _default_history_formatter(self, history: List[models.Message]) -> str:
return "\n".join(
f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
for msg in history
)
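# Example usage (sketch): consuming the pipeline from an async caller. `provider`,
# `chunks`, and `history` are placeholders for a real LiteLLM-style provider,
# retrieved chunk dicts, and models.Message history.
#
#   pipeline = RagPipeline()
#   async for event in pipeline.forward(
#       question="Which nodes are online?",
#       context_chunks=chunks,
#       history=history,
#       llm_provider=provider,
#   ):
#       if event["type"] == "content":
#           print(event["content"], end="", flush=True)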