import logging
from typing import List, Dict, Any, Optional, Callable
from sqlalchemy.orm import Session
from app.db import models
# Default system-prompt template used by RagPipeline.forward when the feature
# is not "voice" and no stored prompt overrides it.
# Rendered with str.format; placeholders: {mesh_context}, {context},
# {chat_history}, {question}.
PROMPT_TEMPLATE = """You are the Cortex AI Assistant, a powerful orchestrator of decentralized agent nodes.
## Architecture Highlights:
- You operate within a secure, gRPC-based mesh of Agent Nodes.
- You can execute shell commands, browse the web, and manage files on these nodes.
- You use 'skills' to interact with the physical world.
{mesh_context}
## Task:
Generate a natural and context-aware answer using the provided knowledge, conversation history, and available tools.
Relevant excerpts from the knowledge base:
{context}
Conversation History:
{chat_history}
User Question: {question}
Answer:"""
# System-prompt template selected by RagPipeline.forward when
# feature_name == "voice" (unless a stored prompt overrides it).
# Only {chat_history} and {question} are referenced; the extra keyword
# arguments that forward passes to str.format are simply unused here.
VOICE_PROMPT_TEMPLATE = """You are a conversational voice assistant.
Keep your responses short, natural, and helpful.
Avoid using technical jargon or listing technical infrastructure details unless specifically asked.
Focus on being a friendly companion.
Conversation History:
{chat_history}
User Question: {question}
Answer:"""
class RagPipeline:
    """
    A flexible and extensible RAG pipeline updated to remove DSPy dependency.

    The pipeline renders a system prompt from the retrieved context and the
    chat history, sends it to an LLM provider, and then runs a bounded
    tool-calling loop until the model returns a message without tool calls.
    """

    # Upper bound on LLM round-trips per forward() call; guards against an
    # endless tool-calling loop.
    MAX_TURNS = 5

    def __init__(
        self,
        context_postprocessor: Optional[Callable[[List[Any]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        """
        Configure the pluggable formatting hooks.

        Args:
            context_postprocessor: Turns the retrieved chunks into one string.
                Defaults to ``_default_context_postprocessor``.
            history_formatter: Turns the message history into one string.
                Defaults to ``_default_history_formatter``.
            response_postprocessor: Optional transform applied to the final
                LLM answer before it is returned; skipped when ``None``.
        """
        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List[models.Message],
        llm_provider=None,
        prompt_service=None,
        tool_service=None,
        tools: Optional[List[Dict[str, Any]]] = None,
        mesh_context: str = "",
        db: Optional[Session] = None,
        user_id: Optional[str] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline",
    ) -> str:
        """
        Answer ``question`` with the LLM, executing requested tools on the way.

        Args:
            question: The user's question; sent as the user message and also
                substituted into the prompt template.
            context_chunks: Retrieved knowledge chunks, passed to the
                configured context postprocessor.
            history: Prior conversation messages, passed to the configured
                history formatter.
            llm_provider: Object exposing an awaitable
                ``acompletion(messages=..., **kwargs)``. Required.
            prompt_service: Optional service exposing
                ``get_prompt_by_slug(db, slug, user_id)``; consulted only
                when ``db`` and ``user_id`` are also provided.
            tool_service: Optional service exposing an awaitable
                ``call_tool(name, args, db=..., user_id=...)``.
            tools: Tool definitions forwarded to the provider together with
                ``tool_choice="auto"``; omitted from the request when empty.
            mesh_context: Text substituted for ``{mesh_context}``.
            db: Database session handed to the prompt and tool services.
            user_id: User identifier handed to the prompt and tool services.
            feature_name: ``"voice"`` selects the voice template; any other
                value selects the default template.
            prompt_slug: Slug used to look up a stored prompt override.

        Returns:
            The (optionally post-processed) final answer. If the provider
            returns no choices, or ``MAX_TURNS`` is exhausted, a diagnostic
            message is returned instead of raising.

        Raises:
            ValueError: If ``llm_provider`` is missing.
        """
        logging.debug("[RagPipeline.forward] Received question: '%s'", question)
        if not llm_provider:
            raise ValueError("LLM Provider is required.")

        history_text = self.history_formatter(history)
        context_text = self.context_postprocessor(context_chunks)
        template = self._resolve_template(feature_name, prompt_slug, prompt_service, db, user_id)
        # str.format ignores unused keyword arguments, so templates may
        # reference any subset of these placeholders. Literal braces inside a
        # template must be doubled ("{{" / "}}").
        system_prompt = template.format(
            question=question,
            context=context_text,
            chat_history=history_text,
            mesh_context=mesh_context,
        )

        # 1. Prepare initial messages
        # The question is the user message; the rendered template carries the
        # instructions and context as the system message.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]

        # Request options and log metadata do not change between turns.
        request_kwargs = {}
        if tools:
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = "auto"
        model = getattr(llm_provider, "model_name", "unknown")
        tool_count = len(tools) if tools else 0

        # 2. Agentic tool loop (bounded by MAX_TURNS to prevent infinite loops)
        for turn in range(1, self.MAX_TURNS + 1):
            total_chars = self._count_message_chars(messages)
            logging.info(f"[RagPipeline] Turn {turn} starting. Model: {model}, Messages: {len(messages)}, Total Chars: {total_chars}, Tools: {tool_count}")
            # Lazy %-formatting: the (potentially large) payload is only
            # rendered when DEBUG logging is enabled.
            logging.debug("[RagPipeline] Turn %s Payload Metadata: %s", turn, messages)

            prediction = await llm_provider.acompletion(messages=messages, **request_kwargs)
            choices = getattr(prediction, "choices", None)
            if not choices:
                finish_reason = getattr(prediction, "finish_reason", "unknown")
                if choices is not None and len(choices) == 0:
                    # Some providers return empty list for safety filters
                    logging.warning(f"[RagPipeline.forward] LLM ({model}) returned 0 choices. Turn: {turn}. Filter/Safety Trigger likely.")
                logging.error(f"[RagPipeline.forward] LLM ({model}) failed at Turn {turn}. Reason: {finish_reason}. Full Response: {prediction}")
                return (f"The AI provider ({model}) returned an empty response (Turn {turn}). "
                        f"Reason: {finish_reason}. Context: {total_chars} chars, {tool_count} tools. "
                        "This is often a safety filter blocking the prompt.")

            message = choices[0].message

            # If no tool calls, we are done
            if not getattr(message, "tool_calls", None):
                raw_response = message.content or ""
                if self.response_postprocessor:
                    return self.response_postprocessor(raw_response)
                return raw_response

            # Process tool calls: echo the assistant message (it carries the
            # tool_calls) and then append one "tool" message per call.
            messages.append(message)
            for tool_call in message.tool_calls:
                func_name = tool_call.function.name
                func_args = self._parse_tool_arguments(
                    func_name, getattr(tool_call.function, "arguments", None)
                )
                logging.info(f"[🔧] Agent calling tool: {func_name} with {func_args}")
                if tool_service:
                    result = await tool_service.call_tool(func_name, func_args, db=db, user_id=user_id)
                else:
                    result = {"success": False, "error": "Tool service not available"}
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": func_name,
                    "content": str(result),
                })

        return "Agent loop reached maximum turns without a final response."

    def _resolve_template(self, feature_name, prompt_slug, prompt_service, db, user_id) -> str:
        """Pick the prompt template: stored override first, then the built-in one."""
        if prompt_service and db and user_id:
            db_prompt = prompt_service.get_prompt_by_slug(db, prompt_slug, user_id)
            if db_prompt:
                return db_prompt.content
        if feature_name == "voice":
            return VOICE_PROMPT_TEMPLATE
        return PROMPT_TEMPLATE

    @staticmethod
    def _count_message_chars(messages) -> int:
        """Sum the content length of dict and object messages; ``None`` counts as 0."""
        total = 0
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content")
            else:
                content = getattr(msg, "content", None)
            total += len(content or "")
        return total

    @staticmethod
    def _parse_tool_arguments(func_name, raw_arguments) -> Dict[str, Any]:
        """Decode a tool call's arguments into a dict, falling back to ``{}``."""
        # Local import: json is not imported at module level in this file.
        import json

        if isinstance(raw_arguments, dict):
            # Already decoded; nothing to parse.
            return raw_arguments
        if raw_arguments is None or raw_arguments == "":
            return {}
        try:
            parsed = json.loads(raw_arguments)
        except (TypeError, ValueError) as exc:
            # json.JSONDecodeError is a ValueError. Stay best-effort (call the
            # tool with no arguments) but leave a trace instead of hiding it.
            logging.warning(f"[RagPipeline] Could not parse arguments for tool '{func_name}': {exc}")
            return {}
        if not isinstance(parsed, dict):
            logging.warning(f"[RagPipeline] Ignoring non-object arguments for tool '{func_name}': {parsed!r}")
            return {}
        return parsed

    def _build_prompt(self, context, history, question):
        """Render a plain RAG prompt from context, history and question."""
        return f"""Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history.
Relevant excerpts from the knowledge base:
{context}
Conversation History:
{history}
User Question: {question}
Answer:"""

    @staticmethod
    def _chunk_to_text(chunk: Any) -> str:
        """Coerce one context chunk to text so that joining never fails."""
        if isinstance(chunk, str):
            return chunk
        if isinstance(chunk, dict):
            # NOTE(review): "content"/"text" are assumed key names for the
            # chunk body; confirm against whatever produces context_chunks.
            for key in ("content", "text"):
                value = chunk.get(key)
                if isinstance(value, str):
                    return value
        return str(chunk)

    # Default context processor: concatenate chunks
    def _default_context_postprocessor(self, contexts: List[Any]) -> str:
        """Join chunks with blank lines; return a placeholder when empty."""
        parts = [self._chunk_to_text(chunk) for chunk in (contexts or [])]
        return "\n\n".join(parts) or "No context provided."

    # Default history formatter: simple speaker prefix
    def _default_history_formatter(self, history: List[models.Message]) -> str:
        """Render one line per message, prefixed with "Human" or "Assistant"."""
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )