Newer
Older
cortex-hub / ai-hub / tests / core / services / test_rag.py
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from typing import List
from datetime import datetime
import dspy

from app.core.services.rag import RAGService
from app.db import models
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.mock import MockEmbedder
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever 
from app.core.retrievers.base_retriever import Retriever 
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.llm_providers import LLMProvider

@pytest.fixture
def rag_service():
    """
    Build a RAGService wired with two mocked retrievers.

    The retriever list holds a spec'd mock of the generic ``Retriever`` first
    and a spec'd mock of ``FaissDBRetriever`` second. ``MagicMock(spec=...)``
    makes ``isinstance`` checks against the spec class succeed. Presumably the
    service uses that to pick out the FAISS retriever regardless of its
    position in the list (confirm against RAGService). The tests below read it
    back as ``rag_service.faiss_retriever``.
    """
    mock_web_retriever = MagicMock(spec=Retriever)
    mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)

    return RAGService(
        retrievers=[mock_web_retriever, mock_faiss_retriever]
    )

# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """
    Verify that create_session stages exactly one new ``models.Session``.

    The staged object must carry the requested user id and model name.
    """
    fake_db = MagicMock(spec=Session)

    rag_service.create_session(db=fake_db, user_id="test_user", model="gemini")

    fake_db.add.assert_called_once()
    staged = fake_db.add.call_args.args[0]
    assert isinstance(staged, models.Session)
    assert staged.user_id == "test_user"
    assert staged.model_name == "gemini"

@patch('app.core.services.rag.get_llm_provider')
@patch('app.core.services.rag.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within an existing session.

    The requested model ("deepseek") is passed explicitly and matches the
    session's stored model. ``load_faiss_retriever`` is explicitly False, so
    the pipeline must be built with an empty retriever list.
    ``dspy.configure`` is patched so the test does not touch the real global
    DSPy configuration; it is not asserted on.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
    # Mirrors the service's query(...).options(...).filter(...).first() chain.
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=42,
            prompt="Test prompt",
            model="deepseek",
            load_faiss_retriever=False
        )
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # FAISS loading disabled -> no retrievers handed to the pipeline.
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])

    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt",
        history=mock_session.messages,
        db=mock_db
    )

    assert answer == "Final RAG response"
    assert model_name == "deepseek"

def test_chat_with_rag_model_switch(rag_service: RAGService):
    """
    Tests that chat_with_rag uses the model named in the ``model`` argument.

    The session was stored with "deepseek", but the call requests "gemini".
    The LLM provider must be resolved for "gemini", and "gemini" must be the
    model name returned. ``load_faiss_retriever`` is explicitly False, so the
    pipeline must be built with an empty retriever list.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=43, model_name="deepseek", messages=[])
    # Mirrors the service's query(...).options(...).filter(...).first() chain.
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider
        mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
        mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, model_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=43,
                prompt="Test prompt for Gemini",
                model="gemini",
                load_faiss_retriever=False
            )
        )

        # --- Assert ---
        mock_db.query.assert_called_once_with(models.Session)
        assert mock_db.add.call_count == 2
        # The requested model wins over the session's stored "deepseek".
        mock_get_llm_provider.assert_called_once_with("gemini")

        mock_dspy_pipeline.assert_called_once_with(retrievers=[])

        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt for Gemini",
            history=mock_session.messages,
            db=mock_db
        )

        assert answer == "Final RAG response from Gemini"
        assert model_name == "gemini"


def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService):
    """
    Tests that chat_with_rag builds the DspyRagPipeline with only the FAISS
    retriever when ``load_faiss_retriever`` is True.

    The fixture also registers a generic (non-FAISS) retriever. The expected
    retriever list is taken from ``rag_service.faiss_retriever``, so the test
    first checks that attribute is populated.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=44, model_name="deepseek", messages=[])
    # Mirrors the service's query(...).options(...).filter(...).first() chain.
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider
        mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
        mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, model_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=44,
                prompt="Test prompt with FAISS",
                model="deepseek",
                load_faiss_retriever=True
            )
        )

        # --- Assert ---
        mock_get_llm_provider.assert_called_once_with("deepseek")

        # Guard against a vacuous pass: if the service never exposed the FAISS
        # retriever, the expected list below would just be [None].
        assert rag_service.faiss_retriever is not None
        expected_retrievers = [rag_service.faiss_retriever]
        mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers)

        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt with FAISS",
            history=mock_session.messages,
            db=mock_db
        )

        assert answer == "Response with FAISS context"
        assert model_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService):
    """
    get_message_history returns both messages of an existing session, earliest
    first, after querying ``models.Session``.

    NOTE(review): the messages below are already supplied in chronological
    order, so the ordering assertion does not on its own prove that the
    service sorts anything.
    """
    fake_db = MagicMock(spec=Session)
    timestamps = (
        datetime(2023, 1, 1, 10, 0, 0),
        datetime(2023, 1, 1, 10, 1, 0),
    )
    stored_session = models.Session(
        id=1,
        messages=[models.Message(created_at=ts) for ts in timestamps],
    )
    lookup = fake_db.query.return_value.options.return_value.filter.return_value
    lookup.first.return_value = stored_session

    history = rag_service.get_message_history(db=fake_db, session_id=1)

    assert len(history) == 2
    assert history[0].created_at < history[1].created_at
    fake_db.query.assert_called_once_with(models.Session)

def test_get_message_history_not_found(rag_service: RAGService):
    """get_message_history yields None when no session matches the requested id."""
    fake_db = MagicMock(spec=Session)
    lookup = fake_db.query.return_value.options.return_value.filter.return_value
    lookup.first.return_value = None

    history = rag_service.get_message_history(db=fake_db, session_id=999)

    assert history is None