Newer
Older
cortex-hub / ai-hub / app / core / retrievers.py
import abc
from typing import List, Dict
from sqlalchemy.orm import Session
from app.core.vector_store.faiss_store import FaissVectorStore
from app.db import models

class Retriever(abc.ABC):
    """
    Interface for pluggable context retrievers.

    Each implementation looks up material relevant to a user query from one
    particular data source and hands it back as plain text snippets, so the
    caller does not need to know where the context came from.
    """

    @abc.abstractmethod
    def retrieve_context(self, query: str, db: Session) -> List[str]:
        """
        Return context snippets relevant to ``query``.

        Args:
            query (str): The user's query text.
            db (Session): Database session made available to the implementation.

        Returns:
            List[str]: Retrieved context, one string per snippet.

        Raises:
            NotImplementedError: When the base implementation is invoked
                (e.g. through ``super()``); subclasses must override it.
        """
        raise NotImplementedError()

class FaissDBRetriever(Retriever):
    """
    A concrete retriever that uses a FAISS index and a local database
    to find and return relevant document text.

    Lookup chain: the vector store yields FAISS vector ids; the
    ``VectorMetadata`` table maps each ``faiss_index`` to a ``document_id``;
    the ``Document`` table supplies the ``text`` that is returned.
    """

    def __init__(self, vector_store: FaissVectorStore, top_k: int = 3):
        """
        Args:
            vector_store (FaissVectorStore): Store used to find vectors
                similar to the query.
            top_k (int): Number of nearest vectors requested from the store
                per query. Defaults to 3, the value previously hard-coded.
        """
        self.vector_store = vector_store
        self.top_k = top_k

    def retrieve_context(self, query: str, db: Session) -> List[str]:
        """
        Retrieves document text by first searching the FAISS index
        and then fetching the corresponding documents from the database.

        Args:
            query (str): The user's query string.
            db (Session): The database session.

        Returns:
            List[str]: The text of each matching document, at most once per
            document, ordered by the position of its first matching vector id
            in the vector store's results (presumably most similar first —
            depends on ``search_similar_documents``). Empty when nothing matches.
        """
        faiss_ids = self.vector_store.search_similar_documents(query, k=self.top_k)
        if not faiss_ids:
            return []
        # Materialize once: the ids are used both in the SQL filter and for
        # ranking below, so a one-shot iterable must not be consumed twice.
        faiss_ids = list(faiss_ids)

        # Fetch the faiss_index alongside document_id so results can be
        # re-associated with their rank; SQL IN (...) returns rows in no
        # guaranteed order. A list per key keeps every row even if a
        # faiss_index appears more than once in the table.
        rows = db.query(
            models.VectorMetadata.faiss_index,
            models.VectorMetadata.document_id,
        ).filter(
            models.VectorMetadata.faiss_index.in_(faiss_ids)
        ).all()
        doc_ids_by_faiss_id = {}
        for faiss_index, document_id in rows:
            doc_ids_by_faiss_id.setdefault(faiss_index, []).append(document_id)

        # Walk the search results in order; dict.fromkeys de-duplicates while
        # keeping first occurrence, since several vectors (e.g. chunks) may
        # belong to the same document.
        ranked_doc_ids = list(dict.fromkeys(
            doc_id
            for fid in faiss_ids
            for doc_id in doc_ids_by_faiss_id.get(fid, ())
        ))
        if not ranked_doc_ids:
            # No metadata rows matched; skip a query with an empty IN () list.
            return []

        # Retrieve the full documents from the Document table, then restore
        # the ranking computed above.
        context_docs = db.query(models.Document).filter(
            models.Document.id.in_(ranked_doc_ids)
        ).all()
        docs_by_id = {doc.id: doc for doc in context_docs}

        return [
            docs_by_id[doc_id].text
            for doc_id in ranked_doc_ids
            if doc_id in docs_by_id
        ]

# You could add other retriever implementations here, like:
# class RemoteServiceRetriever(Retriever):
#     def retrieve_context(self, query: str, db: Session) -> List[str]:
#         # Logic to call a remote API and return context
#         ...