# ai-hub/tests/core/test_rag_service.py
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
import dspy

# Code under test
from app.core.rag_service import RAGService, RAGPipeline, DSPyLLMProvider
# Dependencies used as specs for the mocks below
from app.core.retrievers import Retriever
from app.core.llm_providers import LLMProvider

# --- RAGService Unit Tests ---

# ... (existing add_document tests omitted here; they are unchanged) ...

# NOTE: get_llm_provider is patched where it is looked up (app.core.rag_service), not where it is defined.
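# The @patch decorators are applied bottom-up, so the mock arguments arrive in
# reverse order: mock_configure (dspy.configure) first, then mock_rag_pipeline,
# then mock_get_llm_provider.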
@patch('app.core.rag_service.get_llm_provider')
@patch('app.core.rag_service.RAGPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_with_context(mock_configure, mock_rag_pipeline, mock_get_llm_provider):
    """
    Test the RAGService.chat_with_rag method when context is retrieved.
    """
    # --- Arrange ---
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_db = MagicMock(spec=Session)

    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."]

    mock_rag_pipeline_instance = MagicMock(spec=RAGPipeline)
    mock_rag_pipeline_instance.forward = AsyncMock(return_value="LLM response with context")
    mock_rag_pipeline.return_value = mock_rag_pipeline_instance

    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])
    prompt = "Test prompt."

    # --- Act ---
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek"))

    # --- Assert ---
    mock_get_llm_provider.assert_called_once_with("deepseek")
    
    mock_configure.assert_called_once()
    lm_instance = mock_configure.call_args.kwargs['lm']
    
    # dspy.configure should be given a DSPyLLMProvider instance...
    assert isinstance(lm_instance, DSPyLLMProvider)
    # ...that wraps the provider returned by get_llm_provider.
    assert lm_instance.provider == mock_llm_provider

    mock_rag_pipeline.assert_called_once_with(retrievers=[mock_retriever])
    mock_rag_pipeline_instance.forward.assert_called_once_with(question=prompt, db=mock_db)
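    # forward is an AsyncMock, so assert_awaited_once_with(question=prompt, db=mock_db)
    # could be used instead to also verify that the coroutine was actually awaited.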
    assert response_text == "LLM response with context"


# As above, get_llm_provider is patched where it is looked up in app.core.rag_service.
@patch('app.core.rag_service.get_llm_provider')
@patch('app.core.rag_service.RAGPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_without_context(mock_configure, mock_rag_pipeline, mock_get_llm_provider):
    """
    Test the RAGService.chat_with_rag method when no context is retrieved.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = []

    mock_rag_pipeline_instance = MagicMock(spec=RAGPipeline)
    mock_rag_pipeline_instance.forward = AsyncMock(return_value="LLM response without context")
    mock_rag_pipeline.return_value = mock_rag_pipeline_instance

    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])
    prompt = "Test prompt without context."

    # --- Act ---
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek"))

    # --- Assert ---
    mock_get_llm_provider.assert_called_once_with("deepseek")

    mock_configure.assert_called_once()
    lm_instance = mock_configure.call_args.kwargs['lm']

    assert isinstance(lm_instance, DSPyLLMProvider)
    # The wrapper should hold the provider returned by the mocked get_llm_provider.
    assert lm_instance.provider == mock_llm_provider

    mock_rag_pipeline.assert_called_once_with(retrievers=[mock_retriever])
    mock_rag_pipeline_instance.forward.assert_called_once_with(question=prompt, db=mock_db)
    assert response_text == "LLM response without context"
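

# A possible consolidation (a sketch, not part of the original suite): the two tests
# above differ only in the retrieved context and the canned pipeline response, so a
# parametrized variant could cover both paths with the same assertions. This assumes
# pytest is the test runner; the import below would normally sit at the top of the file.
import pytest


@pytest.mark.parametrize(
    "context, pipeline_response",
    [
        (["Context text 1.", "Context text 2."], "LLM response with context"),
        ([], "LLM response without context"),
    ],
)
@patch('app.core.rag_service.get_llm_provider')
@patch('app.core.rag_service.RAGPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_parametrized(
    mock_configure, mock_rag_pipeline, mock_get_llm_provider, context, pipeline_response
):
    """
    Parametrized sketch covering chat_with_rag with and without retrieved context.
    """
    # --- Arrange ---
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_db = MagicMock(spec=Session)

    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = context

    mock_rag_pipeline_instance = MagicMock(spec=RAGPipeline)
    mock_rag_pipeline_instance.forward = AsyncMock(return_value=pipeline_response)
    mock_rag_pipeline.return_value = mock_rag_pipeline_instance

    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])
    prompt = "Test prompt."

    # --- Act ---
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek"))

    # --- Assert ---
    mock_get_llm_provider.assert_called_once_with("deepseek")
    lm_instance = mock_configure.call_args.kwargs['lm']
    assert isinstance(lm_instance, DSPyLLMProvider)
    assert lm_instance.provider == mock_llm_provider
    mock_rag_pipeline_instance.forward.assert_called_once_with(question=prompt, db=mock_db)
    assert response_text == pipeline_response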