import asyncio
from unittest.mock import patch, MagicMock, AsyncMock, call
from sqlalchemy.orm import Session
import dspy

# Import what you are testing
from app.core.rag_service import RAGService, RAGPipeline, DSPyLLMProvider

# Import dependencies that need to be referenced (used as mock specs)
from app.core.retrievers import Retriever
from app.core.llm_providers import LLMProvider


# --- RAGService Unit Tests ---
#
# get_llm_provider and RAGPipeline are patched in ``app.core.rag_service``
# (the namespace the service looks them up in), not where they are defined,
# so the service sees the mocks. ``dspy.configure`` is patched on the
# ``dspy`` module itself.


def _run_chat_with_rag(
    mock_configure,
    mock_rag_pipeline,
    mock_get_llm_provider,
    *,
    retrieved_context,
    pipeline_response,
    prompt,
    model="deepseek",
):
    """Run RAGService.chat_with_rag against mocks and check the shared wiring.

    Builds a mocked LLM provider, retriever, db session and RAGPipeline
    instance, runs ``chat_with_rag`` to completion, and asserts the
    interactions common to every chat_with_rag scenario:

    * ``get_llm_provider`` is called once with *model*;
    * ``dspy.configure`` is called once with an ``lm=`` keyword holding a
      ``DSPyLLMProvider`` that wraps exactly the mocked provider;
    * ``RAGPipeline`` is constructed once with the service's retrievers;
    * the pipeline's ``forward`` is called once with the prompt and session.

    Args:
        mock_configure: Patch of ``dspy.configure``.
        mock_rag_pipeline: Patch of ``app.core.rag_service.RAGPipeline``.
        mock_get_llm_provider: Patch of
            ``app.core.rag_service.get_llm_provider``.
        retrieved_context: Value the mocked retriever's ``retrieve_context``
            returns. NOTE(review): if RAGService reaches its retrievers only
            through RAGPipeline, this value is never consumed (the pipeline
            is mocked), so it does not by itself distinguish the
            "with context" and "without context" scenarios — confirm.
        pipeline_response: Value the mocked ``forward`` coroutine resolves to.
        prompt: Prompt passed to ``chat_with_rag``.
        model: Model name passed to ``chat_with_rag``.

    Returns:
        Whatever ``chat_with_rag`` returned, for the caller to check.
    """
    # --- Arrange ---
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_db = MagicMock(spec=Session)

    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = retrieved_context

    mock_rag_pipeline_instance = MagicMock(spec=RAGPipeline)
    # forward is an AsyncMock so chat_with_rag can await its result.
    mock_rag_pipeline_instance.forward = AsyncMock(return_value=pipeline_response)
    mock_rag_pipeline.return_value = mock_rag_pipeline_instance

    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])

    # --- Act ---
    # chat_with_rag is a coroutine; asyncio.run drives it on a fresh loop.
    response_text = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, prompt=prompt, model=model)
    )

    # --- Assert ---
    mock_get_llm_provider.assert_called_once_with(model)

    mock_configure.assert_called_once()
    lm_instance = mock_configure.call_args.kwargs["lm"]
    # The raw provider must be wrapped in the DSPy adapter class.
    assert isinstance(lm_instance, DSPyLLMProvider)
    # Identity, not equality: the adapter must wrap this exact provider object.
    assert lm_instance.provider is mock_llm_provider

    mock_rag_pipeline.assert_called_once_with(retrievers=[mock_retriever])
    mock_rag_pipeline_instance.forward.assert_called_once_with(
        question=prompt, db=mock_db
    )

    return response_text


# Decorators apply bottom-up, so the mocks arrive as
# (dspy.configure, RAGPipeline, get_llm_provider).
@patch("app.core.rag_service.get_llm_provider")
@patch("app.core.rag_service.RAGPipeline")
@patch("dspy.configure")
def test_rag_service_chat_with_rag_with_context(
    mock_configure, mock_rag_pipeline, mock_get_llm_provider
):
    """
    Test the RAGService.chat_with_rag method when context is retrieved.
    """
    response_text = _run_chat_with_rag(
        mock_configure,
        mock_rag_pipeline,
        mock_get_llm_provider,
        retrieved_context=["Context text 1.", "Context text 2."],
        pipeline_response="LLM response with context",
        prompt="Test prompt.",
    )

    assert response_text == "LLM response with context"


# Decorators apply bottom-up, so the mocks arrive as
# (dspy.configure, RAGPipeline, get_llm_provider).
@patch("app.core.rag_service.get_llm_provider")
@patch("app.core.rag_service.RAGPipeline")
@patch("dspy.configure")
def test_rag_service_chat_with_rag_without_context(
    mock_configure, mock_rag_pipeline, mock_get_llm_provider
):
    """
    Test the RAGService.chat_with_rag method when no context is retrieved.
    """
    response_text = _run_chat_with_rag(
        mock_configure,
        mock_rag_pipeline,
        mock_get_llm_provider,
        retrieved_context=[],
        pipeline_response="LLM response without context",
        prompt="Test prompt without context.",
    )

    assert response_text == "LLM response without context"