# cortex-hub / ai-hub / tests / core / test_services.py
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session

# Import the service being tested
from app.core.services import RAGService

# Import dependencies that need to be referenced in mocks
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')  # patched where services.py looks it up
@patch('dspy.configure')  # assumes services calls dspy.configure via the module attribute
def test_rag_service_orchestration(mock_configure, mock_dspy_pipeline, mock_get_llm_provider):
    """
    Tests that RAGService.chat_with_rag correctly orchestrates its dependencies.

    It should:
    1. Request the LLM provider for the given model name.
    2. Configure DSPy (via ``dspy.configure(lm=...)``) with a DSPyLLMProvider
       wrapping that provider.
    3. Instantiate the pipeline with the service's retrievers and await its
       ``forward`` with the prompt and DB session.
    4. Return the pipeline's result unchanged.

    Note: stacked ``@patch`` decorators are applied bottom-up, so the mocks are
    passed in reverse decorator order (``dspy.configure`` first).
    """
    # --- Arrange ---
    # Mock the dependencies that RAGService uses
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider
    mock_db = MagicMock(spec=Session)
    mock_retriever = MagicMock(spec=Retriever)

    # Mock the pipeline instance; forward is async, so it needs an AsyncMock
    # whose awaited result is the final response.
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # Instantiate the service class we are testing
    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])
    prompt = "Test prompt."
    model = "deepseek"

    # --- Act ---
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model=model))

    # --- Assert ---
    # 1. The correct LLM provider was requested
    mock_get_llm_provider.assert_called_once_with(model)

    # 2. dspy was configured with a DSPyLLMProvider wrapping that exact provider
    mock_configure.assert_called_once()
    lm_instance = mock_configure.call_args.kwargs['lm']
    assert isinstance(lm_instance, DSPyLLMProvider)
    assert lm_instance.provider is mock_llm_provider

    # 3. The pipeline was instantiated with the retrievers, and forward was
    #    actually awaited (not merely called) with the expected arguments.
    mock_dspy_pipeline.assert_called_once_with(retrievers=[mock_retriever])
    mock_pipeline_instance.forward.assert_awaited_once_with(question=prompt, db=mock_db)

    # 4. The pipeline's result is returned unchanged
    assert response_text == "Final RAG response"