import asyncio
from unittest.mock import patch, MagicMock, AsyncMock

from sqlalchemy.orm import Session

# Import the service being tested
from app.core.rag_service import RAGService

# Import dependencies that need to be referenced in mocks
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


# Stacked @patch decorators are applied bottom-up, so the mock for the
# decorator closest to the function ('dspy.configure') is the first argument.
@patch('app.core.rag_service.get_llm_provider')
@patch('app.core.rag_service.DspyRagPipeline')
@patch('dspy.configure')
def test_rag_service_orchestration(mock_configure, mock_dspy_pipeline, mock_get_llm_provider):
    """
    Tests that RAGService.chat_with_rag correctly orchestrates its dependencies.

    It should:
    1. Get the correct LLM provider.
    2. Configure DSPy with a wrapped provider.
    3. Instantiate the pipeline and await it with the correct arguments.
    4. Return the pipeline's result unchanged.
    """
    # --- Arrange ---
    # Mock the dependencies that RAGService uses
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_db = MagicMock(spec=Session)
    mock_retriever = MagicMock(spec=Retriever)

    # Mock the pipeline instance; `forward` is async, so an AsyncMock is
    # required for the service to be able to await it.
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # Instantiate the service class we are testing
    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])

    prompt = "Test prompt."
    model = "deepseek"

    # --- Act ---
    response_text = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, prompt=prompt, model=model)
    )

    # --- Assert ---
    # 1. Assert that the correct LLM provider was requested
    mock_get_llm_provider.assert_called_once_with(model)

    # 2. Assert that dspy was configured with a correctly wrapped provider.
    # NOTE(review): assumes the service passes the LM as the `lm=` keyword.
    mock_configure.assert_called_once()
    lm_instance = mock_configure.call_args.kwargs['lm']
    assert isinstance(lm_instance, DSPyLLMProvider)
    assert lm_instance.provider is mock_llm_provider

    # 3. Assert that the pipeline was instantiated and awaited correctly.
    # `assert_awaited_once_with` (not `assert_called_once_with`) fails if the
    # coroutine was created but never awaited.
    mock_dspy_pipeline.assert_called_once_with(retrievers=[mock_retriever])
    mock_pipeline_instance.forward.assert_awaited_once_with(question=prompt, db=mock_db)

    # 4. Assert the final response is returned
    assert response_text == "Final RAG response"