# cortex-hub / ai-hub / tests / core / test_services.py
import asyncio
from datetime import datetime
from typing import List
from unittest.mock import patch, MagicMock, AsyncMock, call

import dspy
import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
# FaissDBRetriever and the generic Retriever base, used to build retriever mocks
from app.core.retrievers import FaissDBRetriever, Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


@pytest.fixture
def rag_service():
    """
    Build a RAGService wired entirely to mocked collaborators.

    The service receives a spec'd FaissVectorStore mock and two retriever
    mocks: a generic Retriever first, then a FaissDBRetriever. Tests can
    then check which retrievers chat_with_rag hands to the pipeline.
    """
    store = MagicMock(spec=FaissVectorStore)
    generic_retriever = MagicMock(spec=Retriever)
    faiss_retriever = MagicMock(spec=FaissDBRetriever)
    retriever_list = [generic_retriever, faiss_retriever]
    return RAGService(vector_store=store, retrievers=retriever_list)

# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """create_session should add exactly one models.Session for the given user and model."""
    db = MagicMock(spec=Session)

    rag_service.create_session(db=db, user_id="test_user", model="gemini")

    db.add.assert_called_once()
    new_session = db.add.call_args.args[0]
    assert isinstance(new_session, models.Session)
    assert (new_session.user_id, new_session.model_name) == ("test_user", "gemini")

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Run one chat turn end to end with model "deepseek" and the FAISS
    retriever disabled (load_faiss_retriever=False).

    Checks the session lookup, the two db.add calls and provider selection.
    Also checks that the pipeline is built with no retrievers and that its
    answer comes back together with the model name.
    """
    # --- Arrange ---
    db = MagicMock(spec=Session)
    session = models.Session(id=42, model_name="deepseek", messages=[])
    session_lookup = db.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline = MagicMock(spec=DspyRagPipeline)
    pipeline.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = pipeline

    # --- Act ---
    chat_coro = rag_service.chat_with_rag(
        db=db,
        session_id=42,
        prompt="Test prompt",
        model="deepseek",
        load_faiss_retriever=False,  # passed explicitly: FAISS retrieval stays off
    )
    answer, model_name = asyncio.run(chat_coro)

    # --- Assert ---
    db.query.assert_called_once_with(models.Session)
    assert db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # FAISS disabled -> the pipeline is constructed with an empty retriever list.
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])

    pipeline.forward.assert_called_once_with(
        question="Test prompt",
        history=session.messages,
        db=db,
    )
    assert (answer, model_name) == ("Final RAG response", "deepseek")


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    The `model` argument, not the session's stored model_name, selects the LLM.

    The session was created with "deepseek", but the turn is requested with
    "gemini". The provider lookup and the returned model name must both be
    "gemini". The FAISS retriever stays disabled (load_faiss_retriever=False).
    """
    # --- Arrange ---
    db = MagicMock(spec=Session)
    session = models.Session(id=43, model_name="deepseek", messages=[])
    session_lookup = db.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline = MagicMock(spec=DspyRagPipeline)
    pipeline.forward = AsyncMock(return_value="Final RAG response from Gemini")
    mock_dspy_pipeline.return_value = pipeline

    # --- Act ---
    chat_coro = rag_service.chat_with_rag(
        db=db,
        session_id=43,
        prompt="Test prompt for Gemini",
        model="gemini",
        load_faiss_retriever=False,  # passed explicitly: FAISS retrieval stays off
    )
    answer, model_name = asyncio.run(chat_coro)

    # --- Assert ---
    db.query.assert_called_once_with(models.Session)
    assert db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("gemini")

    # FAISS disabled -> the pipeline is constructed with an empty retriever list.
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])

    pipeline.forward.assert_called_once_with(
        question="Test prompt for Gemini",
        history=session.messages,
        db=db,
    )
    assert (answer, model_name) == ("Final RAG response from Gemini", "gemini")


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests that chat_with_rag builds the DspyRagPipeline with the service's
    FaissDBRetriever when `load_faiss_retriever` is True.

    Besides the retriever wiring, it checks the same orchestration steps as
    the load_faiss_retriever=False tests: session lookup, two db.add calls
    and provider selection. Enabling FAISS therefore cannot silently skip
    any of them.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=44, model_name="deepseek", messages=[])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    # Explicitly enable the FAISS retriever
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=44,
            prompt="Test prompt with FAISS",
            model="deepseek",
            load_faiss_retriever=True
        )
    )

    # --- Assert ---
    # Same orchestration checks as the FAISS-disabled tests.
    mock_db.query.assert_called_once_with(models.Session)
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # The crucial part: the pipeline receives only the service's FAISS retriever.
    expected_retrievers = [rag_service.faiss_retriever]
    mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers)

    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt with FAISS",
        history=mock_session.messages,
        db=mock_db
    )

    assert answer == "Response with FAISS context"
    assert model_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService):
    """
    Tests retrieving message history for an existing session.

    The returned messages must be ordered by created_at ascending. The
    session's messages are deliberately stored newest-first, so the ordering
    assertion fails if get_message_history does not sort. An already-sorted
    fixture would pass whether or not sorting happens.
    """
    # Arrange
    mock_db = MagicMock(spec=Session)
    # Messages need created_at because the history is sorted on it.
    earlier = models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0))
    later = models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    # Out of order on purpose (newest first).
    mock_session = models.Session(id=1, messages=[later, earlier])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    # Verify sorting: oldest message first.
    assert [m.created_at for m in messages] == [earlier.created_at, later.created_at]
    mock_db.query.assert_called_once_with(models.Session)

def test_get_message_history_not_found(rag_service: RAGService):
    """get_message_history returns None when the session lookup finds nothing."""
    db = MagicMock(spec=Session)
    session_lookup = db.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = None

    history = rag_service.get_message_history(db=db, session_id=999)

    assert history is None

# --- Document Management Tests ---

@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    Test the RAGService.add_document method for a successful run.

    Patches models.Document, models.VectorMetadata and FaissVectorStore, then
    verifies that add_document:
      * adds the new Document and then its VectorMetadata row via db.add(),
      * commits and refreshes the new Document,
      * sends the document text to the vector store exactly once,
      * builds VectorMetadata from the document id and the index returned
        by the vector store,
      * returns the new document's id.
    """
    # Arrange: the patched Document class hands back a controllable instance.
    mock_db = MagicMock(spec=Session)
    mock_new_document_instance = MagicMock()
    mock_document_model.return_value = mock_new_document_instance
    mock_new_document_instance.id = 1
    mock_new_document_instance.text = "Test text."
    mock_new_document_instance.title = "Test Title"

    # The vector store reports index 123 for the added text.
    mock_vector_store_instance = mock_vector_store.return_value
    mock_vector_store_instance.add_document.return_value = 123

    rag_service = RAGService(
        vector_store=mock_vector_store_instance,
        retrievers=[]
    )

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com"
    }

    # Act
    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assert
    assert document_id == 1

    # assert_has_calls checks order: the Document is added before its metadata row.
    mock_db.add.assert_has_calls([
        call(mock_new_document_instance),
        call(mock_vector_metadata_model.return_value),
    ])

    mock_db.commit.assert_called()
    mock_db.refresh.assert_called_with(mock_new_document_instance)
    mock_vector_store_instance.add_document.assert_called_once_with("Test text.")

    # VectorMetadata links the document id to the vector-store index.
    # NOTE(review): "mock_embedder" must match the embedding-model name that
    # RAGService records; confirm against the service implementation.
    mock_vector_metadata_model.assert_called_once_with(
        document_id=mock_new_document_instance.id,
        faiss_index=mock_vector_store_instance.add_document.return_value,
        embedding_model="mock_embedder"
    )

@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    A SQLAlchemyError raised by db.add() must propagate out of add_document.
    The session must be rolled back once, and commit must never be called.
    """
    db = MagicMock(spec=Session)
    # The very first db.add() call fails with a database error.
    db.add.side_effect = SQLAlchemyError("Database error")

    service = RAGService(vector_store=mock_vector_store.return_value, retrievers=[])
    payload = dict(title="Test Title", text="Test text.", source_url="http://test.com")

    with pytest.raises(SQLAlchemyError, match="Database error"):
        service.add_document(db=db, doc_data=payload)

    db.add.assert_called_once()
    db.commit.assert_not_called()
    db.rollback.assert_called_once()