# cortex-hub / ai-hub / tests / core / test_rag_service.py
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock, call
from sqlalchemy.orm import Session
from typing import List

# Import the RAGService class and its dependencies
from app.core.rag_service import RAGService
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.db import models

# --- RAGService Unit Tests ---
# These tests directly target the methods of the RAGService class
# to verify their internal logic in isolation.

@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    Happy-path test for RAGService.add_document.

    The Document and VectorMetadata model classes and the vector store are
    replaced with mocks, so only the service's orchestration is checked:
    the returned id is the new document's id, the document row and then its
    VectorMetadata row are passed to ``db.add``, the session is committed,
    the document is refreshed, and the document text is embedded through the
    vector store exactly once.
    """
    # Fake ORM object handed back when the service constructs models.Document(...).
    fake_document = MagicMock(id=1, text="Test text.", title="Test Title")
    mock_document_model.return_value = fake_document

    session = MagicMock(spec=Session)

    store = mock_vector_store.return_value
    # Index position the vector store reports for the newly embedded text.
    store.add_document.return_value = 123

    service = RAGService(vector_store=store, retrievers=[])
    payload = dict(title="Test Title", text="Test text.", source_url="http://test.com")

    new_id = service.add_document(db=session, doc_data=payload)

    assert new_id == 1

    # The document row must be added before its vector metadata row; the
    # patched VectorMetadata class's return value stands in for the latter.
    session.add.assert_has_calls([
        call(fake_document),
        call(mock_vector_metadata_model.return_value),
    ])
    session.commit.assert_called()
    session.refresh.assert_called_with(fake_document)
    store.add_document.assert_called_once_with("Test text.")

    # The metadata row must link the document id to the index returned by the
    # store. NOTE(review): "mock_embedder" is presumably the embedding-model
    # name RAGService records; confirm against the service implementation.
    mock_vector_metadata_model.assert_called_once_with(
        document_id=fake_document.id,
        faiss_index=store.add_document.return_value,
        embedding_model="mock_embedder",
    )

@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    Test the RAGService.add_document method's error handling.

    ``db.add`` is configured to raise. The test verifies that:

    * the original exception propagates to the caller;
    * ``db.add`` was attempted exactly once;
    * nothing was committed;
    * the session was rolled back exactly once.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)

    # The very first db.add call fails.
    mock_db.add.side_effect = Exception("Database error")

    mock_vector_store_instance = mock_vector_store.return_value

    # Instantiate the service
    rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[])

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com"
    }

    # The "no exception raised" failure lives in the ``else`` clause. Raising
    # AssertionError inside ``try`` would be swallowed by ``except Exception``
    # (AssertionError subclasses Exception), and a bare ``assert False`` would
    # also be stripped under ``python -O``.
    try:
        rag_service.add_document(db=mock_db, doc_data=doc_data)
    except Exception as e:
        assert str(e) == "Database error", f"Unexpected exception: {e!r}"
    else:
        raise AssertionError("Expected an exception to be raised")

    # Assertions
    # The first db.add was called
    mock_db.add.assert_called_once()
    # No commit should have occurred
    mock_db.commit.assert_not_called()
    # The transaction should have been rolled back
    mock_db.rollback.assert_called_once()


@patch('app.core.rag_service.get_llm_provider')
def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider):
    """
    Test RAGService.chat_with_rag when the retriever returns context.

    The retriever yields two snippets. The prompt sent to the (mocked) LLM
    provider must contain both snippets joined by a blank line, and must also
    contain the user's question. The provider's reply must be returned
    unchanged.
    """
    # LLM provider whose async generate_response returns a canned reply.
    llm = MagicMock()
    llm.generate_response = AsyncMock(return_value="LLM response with context")
    mock_get_llm_provider.return_value = llm

    retriever = MagicMock(spec=Retriever)
    retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."]

    session = MagicMock(spec=Session)
    service = RAGService(vector_store=MagicMock(), retrievers=[retriever])
    question = "Test prompt."

    # chat_with_rag is a coroutine; drive it to completion on a fresh event loop.
    reply = asyncio.run(service.chat_with_rag(db=session, prompt=question, model="deepseek"))

    retriever.retrieve_context.assert_called_once_with(question, session)
    llm.generate_response.assert_called_once()

    # First positional argument of the single generate_response call.
    sent_prompt = llm.generate_response.call_args.args[0]

    assert "Context text 1.\n\nContext text 2." in sent_prompt
    assert question in sent_prompt
    assert reply == "LLM response with context"

@patch('app.core.rag_service.get_llm_provider')
def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider):
    """
    Test RAGService.chat_with_rag when the retriever finds nothing.

    With an empty context list, the user's prompt must be forwarded to the
    (mocked) LLM provider verbatim as the only positional argument. The
    provider's reply must be returned unchanged.
    """
    # LLM provider whose async generate_response returns a canned reply.
    provider = MagicMock()
    provider.generate_response = AsyncMock(return_value="LLM response without context")
    mock_get_llm_provider.return_value = provider

    # Retriever that yields no context at all.
    empty_retriever = MagicMock(spec=Retriever)
    empty_retriever.retrieve_context.return_value = []

    session = MagicMock(spec=Session)
    service = RAGService(vector_store=MagicMock(), retrievers=[empty_retriever])
    user_prompt = "Test prompt without context."

    # chat_with_rag is a coroutine; drive it to completion on a fresh event loop.
    result = asyncio.run(service.chat_with_rag(db=session, prompt=user_prompt, model="deepseek"))

    empty_retriever.retrieve_context.assert_called_once_with(user_prompt, session)
    provider.generate_response.assert_called_once_with(user_prompt)
    assert result == "LLM response without context"