# cortex-hub / ai-hub / tests / core / test_services.py
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError # Import the specific error type
from typing import List
from datetime import datetime
import dspy

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider 
from app.core.llm_providers import LLMProvider


@pytest.fixture
def rag_service():
    """
    Build a RAGService whose collaborators are spec'd mocks.

    The vector store is a MagicMock restricted to the FaissVectorStore
    interface. The service gets exactly one retriever, a MagicMock restricted
    to the Retriever interface, which tests can reach via ``retrievers[0]``.
    """
    retriever_stub = MagicMock(spec=Retriever)
    store_stub = MagicMock(spec=FaissVectorStore)
    service = RAGService(vector_store=store_stub, retrievers=[retriever_stub])
    return service

# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """create_session should stage one models.Session carrying the user id and model name."""
    db_stub = MagicMock(spec=Session)

    rag_service.create_session(db=db_stub, user_id="test_user", model="gemini")

    # Exactly one object is staged on the DB session: the new chat Session row.
    db_stub.add.assert_called_once()
    persisted = db_stub.add.call_args.args[0]
    assert isinstance(persisted, models.Session)
    assert persisted.user_id == "test_user"
    assert persisted.model_name == "gemini"

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    A chat turn on the "deepseek" model is orchestrated end to end.

    Checks the following:
    - the Session is looked up;
    - two rows are staged on the DB session;
    - the LLM provider is resolved by model name;
    - the (mocked) DSPy pipeline receives the prompt, the session's message
      history and the DB handle;
    - the pipeline's answer and the model name are returned.
    """
    # --- Arrange ---
    db_stub = MagicMock(spec=Session)
    # chat_with_rag reads the session's `messages` as the conversation history.
    chat_session = models.Session(id=42, model_name="deepseek", messages=[])
    session_lookup = db_stub.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = chat_session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline_stub = MagicMock(spec=DspyRagPipeline)
    pipeline_stub.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = pipeline_stub

    # --- Act ---
    chat_call = rag_service.chat_with_rag(
        db=db_stub, session_id=42, prompt="Test prompt", model="deepseek"
    )
    answer, model_name = asyncio.run(chat_call)

    # --- Assert ---
    db_stub.query.assert_called_once_with(models.Session)
    assert db_stub.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")
    pipeline_stub.forward.assert_called_once_with(
        question="Test prompt",
        history=chat_session.messages,
        db=db_stub,
    )
    assert answer == "Final RAG response"
    assert model_name == "deepseek"

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Requesting "gemini" on a session stored with "deepseek" uses gemini.

    The per-call ``model`` argument takes precedence over the session's stored
    model_name. The provider is resolved for "gemini", and "gemini" is the
    model name chat_with_rag returns alongside the pipeline's answer.
    """
    # --- Arrange ---
    db_stub = MagicMock(spec=Session)
    # The stored session was started on a different model than the one requested.
    chat_session = models.Session(id=43, model_name="deepseek", messages=[])
    session_lookup = db_stub.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = chat_session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline_stub = MagicMock(spec=DspyRagPipeline)
    pipeline_stub.forward = AsyncMock(return_value="Final RAG response from Gemini")
    mock_dspy_pipeline.return_value = pipeline_stub

    prompt = "Test prompt for Gemini"

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(db=db_stub, session_id=43, prompt=prompt, model="gemini")
    )

    # --- Assert ---
    db_stub.query.assert_called_once_with(models.Session)
    assert db_stub.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("gemini")
    pipeline_stub.forward.assert_called_once_with(
        question=prompt,
        history=chat_session.messages,
        db=db_stub,
    )
    assert answer == "Final RAG response from Gemini"
    assert model_name == "gemini"

def test_get_message_history_success(rag_service: RAGService):
    """
    get_message_history returns an existing session's messages oldest-first.

    The mocked session stores its messages newest-first. The ordering
    assertion can therefore only pass if the service actually orders by
    ``created_at``; with already-sorted input the check would be vacuous.
    """
    # Arrange
    mock_db = MagicMock(spec=Session)
    earlier = models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0))
    later = models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    # Deliberately out of chronological order so the sort is exercised.
    mock_session = models.Session(id=1, messages=[later, earlier])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    # Verify ascending order by creation time.
    assert [m.created_at for m in messages] == [earlier.created_at, later.created_at]
    mock_db.query.assert_called_once_with(models.Session)

def test_get_message_history_not_found(rag_service: RAGService):
    """An unknown session id yields None rather than an empty history."""
    db_stub = MagicMock(spec=Session)
    # Simulate the session lookup finding no matching row.
    session_lookup = db_stub.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = None

    history = rag_service.get_message_history(db=db_stub, session_id=999)

    assert history is None

# --- Document Management Tests ---

@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    add_document persists a document, indexes its text and records vector metadata.

    On the happy path the service should do the following:
    - stage the Document and then its VectorMetadata on the DB session;
    - commit, and refresh the Document;
    - send the document text to the vector store;
    - build the VectorMetadata from the document id and the returned index;
    - return the new document's id.
    """
    from unittest.mock import call

    # --- Arrange ---
    db_stub = MagicMock(spec=Session)

    created_doc = MagicMock()
    created_doc.id = 1
    created_doc.text = "Test text."
    created_doc.title = "Test Title"
    mock_document_model.return_value = created_doc

    store = mock_vector_store.return_value
    store.add_document.return_value = 123

    service = RAGService(vector_store=store, retrievers=[])
    payload = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com",
    }

    # --- Act ---
    new_id = service.add_document(db=db_stub, doc_data=payload)

    # --- Assert ---
    assert new_id == 1

    # The Document is staged first, then the metadata row that references it.
    db_stub.add.assert_has_calls([
        call(created_doc),
        call(mock_vector_metadata_model.return_value),
    ])
    db_stub.commit.assert_called()
    db_stub.refresh.assert_called_with(created_doc)
    store.add_document.assert_called_once_with("Test text.")

    mock_vector_metadata_model.assert_called_once_with(
        document_id=created_doc.id,
        faiss_index=store.add_document.return_value,
        embedding_model="mock_embedder",
    )

@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    A database failure inside add_document must roll back and propagate.

    db.add is made to raise SQLAlchemyError. The error must reach the caller,
    commit must never run, and rollback must run exactly once.
    """
    db_stub = MagicMock(spec=Session)
    # Fail on the very first attempt to stage a row.
    db_stub.add.side_effect = SQLAlchemyError("Database error")

    service = RAGService(vector_store=mock_vector_store.return_value, retrievers=[])
    payload = dict(
        title="Test Title",
        text="Test text.",
        source_url="http://test.com",
    )

    with pytest.raises(SQLAlchemyError, match="Database error"):
        service.add_document(db=db_stub, doc_data=payload)

    db_stub.add.assert_called_once()
    db_stub.commit.assert_not_called()
    db_stub.rollback.assert_called_once()

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_with_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Test RAGService.chat_with_rag for a session with prior history while the
    retriever is primed to return context.

    DspyRagPipeline is mocked, so no prompt is actually assembled from the
    retrieved context here. What this test verifies is the orchestration:
    - the Session lookup;
    - two rows staged on the DB session;
    - provider resolution by model name;
    - the pipeline receiving the raw prompt together with the session's
      existing message history;
    - the pipeline's answer and the model name being returned.

    Prompt construction from context would need a test against the real
    pipeline.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    # Two prior turns so the pipeline is handed a non-empty history.
    mock_session = models.Session(id=1, model_name="deepseek", messages=[
        models.Message(sender="user", content="Previous user message", created_at=datetime(2023, 1, 1, 9, 0, 0)),
        models.Message(sender="assistant", content="Previous assistant response", created_at=datetime(2023, 1, 1, 9, 1, 0))
    ])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="LLM response with context")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    prompt = "Test prompt."
    # Prime the fixture's retriever with context. NOTE: the mocked pipeline never
    # consumes this, so it only matters if chat_with_rag calls retrieve_context itself.
    mock_retriever = rag_service.retrievers[0]
    mock_retriever.retrieve_context = AsyncMock(return_value=["Context text 1.", "Context text 2."])

    # --- Act ---
    response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek"))

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    mock_pipeline_instance.forward.assert_called_once_with(
        question=prompt,
        history=mock_session.messages,
        db=mock_db
    )

    assert response_text == "LLM response with context"
    assert model_used == "deepseek"

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_without_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Test RAGService.chat_with_rag on an empty session while the retriever
    returns no context.

    With DspyRagPipeline mocked, this checks two things. First, the user's
    prompt reaches the pipeline unchanged as ``question``, alongside the
    session's history and the DB handle. Second, chat_with_rag returns the
    pipeline's answer and the requested model name.
    """
    # --- Arrange ---
    prompt = "Test prompt without context."

    db_stub = MagicMock(spec=Session)
    chat_session = models.Session(id=1, model_name="deepseek", messages=[])
    session_lookup = db_stub.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = chat_session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline_stub = MagicMock(spec=DspyRagPipeline)
    pipeline_stub.forward = AsyncMock(return_value="LLM response without context")
    mock_dspy_pipeline.return_value = pipeline_stub

    # The fixture's single retriever yields no context at all.
    rag_service.retrievers[0].retrieve_context = AsyncMock(return_value=[])

    # --- Act ---
    chat_call = rag_service.chat_with_rag(db=db_stub, session_id=1, prompt=prompt, model="deepseek")
    response_text, model_used = asyncio.run(chat_call)

    # --- Assert ---
    db_stub.query.assert_called_once_with(models.Session)
    assert db_stub.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")
    pipeline_stub.forward.assert_called_once_with(
        question=prompt,
        history=chat_session.messages,
        db=db_stub,
    )
    assert response_text == "LLM response without context"
    assert model_used == "deepseek"