import asyncio
from datetime import datetime
from typing import List
from unittest.mock import AsyncMock, MagicMock, call, patch

import dspy
import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore, MockEmbedder
# FaissDBRetriever plus the generic Retriever base class, which is used below
# as a stand-in for a non-FAISS (e.g. web) retriever.
from app.core.retrievers import FaissDBRetriever, Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


# --- Helpers ---


def _mock_db_returning(session_obj) -> MagicMock:
    """Build a mocked SQLAlchemy ``Session`` whose session lookup yields *session_obj*.

    The service under test fetches a session via
    ``db.query(...).options(...).filter(...).first()``; this wires that whole
    chain so ``first()`` returns *session_obj* (which may be ``None``).
    """
    mock_db = MagicMock(spec=Session)
    query_chain = mock_db.query.return_value.options.return_value.filter.return_value
    query_chain.first.return_value = session_obj
    return mock_db


def _mock_pipeline(response: str) -> MagicMock:
    """Build a mocked ``DspyRagPipeline`` instance whose async ``forward`` returns *response*."""
    pipeline = MagicMock(spec=DspyRagPipeline)
    pipeline.forward = AsyncMock(return_value=response)
    return pipeline


# --- Fixtures ---


@pytest.fixture
def rag_service():
    """
    Pytest fixture to create a RAGService instance with mocked dependencies.

    It includes a mock FaissDBRetriever and a mock generic Retriever to test
    conditional loading.
    """
    # Create a mock embedder to be attached to the vector store mock
    mock_embedder = MagicMock(spec=MockEmbedder)
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_vector_store.embedder = mock_embedder  # Explicitly set the embedder attribute

    mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)
    mock_web_retriever = MagicMock(spec=Retriever)

    return RAGService(
        vector_store=mock_vector_store,
        retrievers=[mock_web_retriever, mock_faiss_retriever],
    )


# --- Session Management Tests ---


def test_create_session(rag_service: RAGService):
    """Tests that the create_session method correctly creates a new session."""
    mock_db = MagicMock(spec=Session)

    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")

    mock_db.add.assert_called_once()
    added_object = mock_db.add.call_args[0][0]
    assert isinstance(added_object, models.Session)
    assert added_object.user_id == "test_user"
    assert added_object.model_name == "gemini"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within a session when the
    requested model matches the session's model, with the FAISS retriever
    loading parameter explicitly set to False.
    """
    # --- Arrange ---
    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    mock_pipeline_instance = _mock_pipeline("Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=42,
            prompt="Test prompt",
            model="deepseek",
            load_faiss_retriever=False,  # Explicitly disable the FAISS retriever
        )
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # Assert that DspyRagPipeline was initialized with an empty list of retrievers
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])

    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt",
        history=mock_session.messages,
        db=mock_db,
    )

    assert answer == "Final RAG response"
    assert model_name == "deepseek"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests that chat_with_rag correctly switches the model based on the 'model'
    argument (session stored "deepseek", request asks for "gemini"), with the
    FAISS retriever disabled.
    """
    # --- Arrange ---
    mock_session = models.Session(id=43, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    mock_pipeline_instance = _mock_pipeline("Final RAG response from Gemini")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=43,
            prompt="Test prompt for Gemini",
            model="gemini",
            load_faiss_retriever=False,  # Explicitly disable the FAISS retriever
        )
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("gemini")

    # Assert that DspyRagPipeline was initialized with an empty list of retrievers
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])

    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt for Gemini",
        history=mock_session.messages,
        db=mock_db,
    )

    assert answer == "Final RAG response from Gemini"
    assert model_name == "gemini"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests that the chat_with_rag method correctly initializes the DspyRagPipeline
    with the FaissDBRetriever when `load_faiss_retriever` is True.
    """
    # --- Arrange ---
    mock_session = models.Session(id=44, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    mock_pipeline_instance = _mock_pipeline("Response with FAISS context")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    # Explicitly enable the FAISS retriever
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=44,
            prompt="Test prompt with FAISS",
            model="deepseek",
            load_faiss_retriever=True,
        )
    )

    # --- Assert ---
    # The crucial part is to verify that the pipeline was built with only the FAISS retriever
    expected_retrievers = [rag_service.faiss_retriever]
    mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers)

    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt with FAISS",
        history=mock_session.messages,
        db=mock_db,
    )

    assert answer == "Response with FAISS context"
    assert model_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService):
    """Tests retrieving message history for an existing session, ordered by creation time."""
    # Arrange
    earlier = models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0))
    later = models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    # Deliberately supply the messages out of chronological order; with an
    # already-sorted input the ordering assertion below could never fail.
    mock_session = models.Session(id=1, messages=[later, earlier])
    mock_db = _mock_db_returning(mock_session)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    # Verify the service returned them sorted ascending by created_at
    assert [m.created_at for m in messages] == [earlier.created_at, later.created_at]
    mock_db.query.assert_called_once_with(models.Session)


def test_get_message_history_not_found(rag_service: RAGService):
    """Tests retrieving history for a non-existent session."""
    # Arrange
    mock_db = _mock_db_returning(None)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None


# --- Document Management Tests ---


@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    Test the RAGService.add_document method for a successful run.

    Verifies that the method correctly calls db.add(), db.commit(), and the
    vector store.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)
    mock_new_document_instance = MagicMock()
    mock_document_model.return_value = mock_new_document_instance
    mock_new_document_instance.id = 1
    mock_new_document_instance.text = "Test text."
    mock_new_document_instance.title = "Test Title"

    mock_vector_store_instance = mock_vector_store.return_value
    # The auto-created attribute would be a plain MagicMock; give the store an
    # embedder spec'd as MockEmbedder so the expected embedding_model name
    # below ("mock_embedder") applies. NOTE(review): presumably RAGService
    # derives that name from the embedder's type — confirm in the service.
    mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder)
    mock_vector_store_instance.add_document.return_value = 123

    rag_service = RAGService(
        vector_store=mock_vector_store_instance,
        retrievers=[],
    )

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com",
    }

    # Call the method under test
    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assertions
    assert document_id == 1

    expected_calls = [
        call(mock_new_document_instance),
        call(mock_vector_metadata_model.return_value),
    ]
    mock_db.add.assert_has_calls(expected_calls)

    mock_db.commit.assert_called()
    mock_db.refresh.assert_called_with(mock_new_document_instance)
    mock_vector_store_instance.add_document.assert_called_once_with("Test text.")

    # Assert that VectorMetadata was instantiated with the correct arguments
    mock_vector_metadata_model.assert_called_once_with(
        document_id=mock_new_document_instance.id,
        faiss_index=mock_vector_store_instance.add_document.return_value,
        embedding_model="mock_embedder",
    )


@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    Test the RAGService.add_document method's error handling.

    Verifies that the transaction is rolled back on an exception.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)
    # Configure the mock db.add to raise the specific SQLAlchemyError.
    mock_db.add.side_effect = SQLAlchemyError("Database error")

    mock_vector_store_instance = mock_vector_store.return_value

    rag_service = RAGService(
        vector_store=mock_vector_store_instance,
        retrievers=[],
    )

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com",
    }

    # Call the method under test and expect the exception to propagate
    with pytest.raises(SQLAlchemyError, match="Database error"):
        rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assertions
    mock_db.add.assert_called_once()
    mock_db.commit.assert_not_called()
    mock_db.rollback.assert_called_once()