# cortex-hub / ai-hub / app / core / pipelines / rag_pipeline.py
import logging
from typing import List, Dict, Any, Optional, Callable
from sqlalchemy.orm import Session

from app.db import models

# Default prompt template, used when no user-specific prompt is found in the
# database (see RagPipeline.forward). Placeholders {context}, {chat_history},
# and {question} are filled via str.format.
PROMPT_TEMPLATE = """Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history.

Relevant excerpts from the knowledge base:
{context}

Conversation History:
{chat_history}

User Question: {question}

Answer:"""

class RagPipeline:
    """
    A flexible and extensible RAG pipeline (DSPy dependency removed).

    Context formatting, history formatting, and response postprocessing are
    injectable hooks, so callers can customize behavior without subclassing.
    """

    def __init__(
        self,
        context_postprocessor: Optional[Callable[[List[Any]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        """
        Args:
            context_postprocessor: Turns the retrieved chunks into a single
                context string. Defaults to ``_default_context_postprocessor``.
            history_formatter: Turns a message history into a transcript
                string. Defaults to ``_default_history_formatter``.
            response_postprocessor: Optional final transform of the raw LLM
                answer; when ``None`` the raw answer is returned unchanged.
        """
        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List[models.Message],
        llm_provider=None,
        prompt_service=None,
        db: Optional[Session] = None,
        user_id: Optional[str] = None,
        prompt_slug: str = "rag-pipeline",
    ) -> str:
        """
        Run one RAG turn: format history/context, resolve the prompt template,
        call the LLM, and optionally postprocess the answer.

        Args:
            question: The user's question.
            context_chunks: Retrieved knowledge-base chunks.
            history: Prior conversation messages.
            llm_provider: Object exposing ``acompletion(prompt=...)``;
                required (raises ValueError when missing).
            prompt_service: Optional service for user-specific prompt lookup.
            db: DB session, required for the prompt lookup.
            user_id: Owner of the custom prompt, required for the lookup.
            prompt_slug: Slug identifying the custom prompt template.

        Returns:
            The (optionally postprocessed) answer text.

        Raises:
            ValueError: If ``llm_provider`` is not supplied.
        """
        # Lazy %-style args avoid f-string work when DEBUG logging is off.
        logging.debug("[RagPipeline.forward] Received question: '%s'", question)

        if not llm_provider:
            raise ValueError("LLM Provider is required.")

        history_text = self.history_formatter(history)
        context_text = self.context_postprocessor(context_chunks)

        # Prefer a user-specific template from the DB when every lookup
        # prerequisite is available; otherwise use the module default.
        template = PROMPT_TEMPLATE
        if prompt_service and db and user_id:
            db_prompt = prompt_service.get_prompt_by_slug(db, prompt_slug, user_id)
            if db_prompt:
                template = db_prompt.content

        prompt = template.format(
            question=question,
            context=context_text,
            chat_history=history_text,
        )

        prediction = await llm_provider.acompletion(prompt=prompt)
        # NOTE(review): assumes an OpenAI-style response shape — confirm the
        # provider always populates choices[0].message.content.
        raw_response = prediction.choices[0].message.content

        # Optional response postprocessing.
        if self.response_postprocessor:
            return self.response_postprocessor(raw_response)

        return raw_response

    def _build_prompt(self, context, history, question):
        """Build the default prompt text.

        Not called by ``forward`` (kept for backward compatibility). Delegates
        to PROMPT_TEMPLATE so the two copies of the template cannot drift.
        """
        return PROMPT_TEMPLATE.format(
            context=context, chat_history=history, question=question
        )

    def _default_context_postprocessor(self, contexts: List[Any]) -> str:
        """Join retrieved chunks into one context string.

        ``forward`` passes ``List[Dict[str, Any]]`` chunks, while the hook was
        originally typed for ``List[str]``; accept both. For dicts the text is
        taken from a 'content' or 'text' key — TODO confirm the chunk schema
        against the retriever.
        """
        parts: List[str] = []
        for chunk in contexts:
            if isinstance(chunk, str):
                parts.append(chunk)
            elif isinstance(chunk, dict):
                parts.append(str(chunk.get("content") or chunk.get("text") or chunk))
            else:
                parts.append(str(chunk))
        # Empty input yields "" (falsy) -> explicit placeholder text.
        return "\n\n".join(parts) or "No context provided."

    def _default_history_formatter(self, history: List[models.Message]) -> str:
        """Render messages as 'Human: ...' / 'Assistant: ...' lines.

        Any sender other than 'user' is labeled 'Assistant'.
        """
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )