# import pytest # import asyncio # from unittest.mock import patch, MagicMock, AsyncMock # from sqlalchemy.orm import Session # from sqlalchemy.exc import SQLAlchemyError # from typing import List # from datetime import datetime # import dspy # # Import the service and its dependencies # from app.core.services import RAGService # from app.db import models # from app.core.vector_store.faiss_store import FaissVectorStore # from app.core.vector_store.embedder.mock import MockEmbedder # # Import FaissDBRetriever and a mock WebRetriever for testing different cases # from app.core.retrievers import FaissDBRetriever, Retriever # from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider # from app.core.llm_providers import LLMProvider # @pytest.fixture # def rag_service(): # """ # Pytest fixture to create a RAGService instance with mocked dependencies. # It includes a mock FaissDBRetriever and a mock generic Retriever to test # conditional loading. # """ # # Create a mock embedder to be attached to the vector store mock # mock_embedder = MagicMock(spec=MockEmbedder) # mock_vector_store = MagicMock(spec=FaissVectorStore) # mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute # mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) # mock_web_retriever = MagicMock(spec=Retriever) # return RAGService( # vector_store=mock_vector_store, # retrievers=[mock_web_retriever, mock_faiss_retriever] # ) # # --- Session Management Tests --- # def test_create_session(rag_service: RAGService): # """Tests that the create_session method correctly creates a new session.""" # mock_db = MagicMock(spec=Session) # rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") # mock_db.add.assert_called_once() # added_object = mock_db.add.call_args[0][0] # assert isinstance(added_object, models.Session) # assert added_object.user_id == "test_user" # assert added_object.model_name == "gemini" # @patch('app.core.services.get_llm_provider') 
# @patch('app.core.services.DspyRagPipeline') # @patch('dspy.configure') # def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): # """ # Tests the full orchestration of a chat message within a session using the default model # and with the retriever loading parameter explicitly set to False. # """ # # --- Arrange --- # mock_db = MagicMock(spec=Session) # mock_session = models.Session(id=42, model_name="deepseek", messages=[]) # mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session # mock_llm_provider = MagicMock(spec=LLMProvider) # mock_get_llm_provider.return_value = mock_llm_provider # mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) # mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") # mock_dspy_pipeline.return_value = mock_pipeline_instance # # --- Act --- # answer, model_name = asyncio.run( # rag_service.chat_with_rag( # db=mock_db, # session_id=42, # prompt="Test prompt", # model="deepseek", # load_faiss_retriever=False # Explicitly pass the default value # ) # ) # # --- Assert --- # mock_db.query.assert_called_once_with(models.Session) # assert mock_db.add.call_count == 2 # mock_get_llm_provider.assert_called_once_with("deepseek") # # Assert that DspyRagPipeline was initialized with an empty list of retrievers # mock_dspy_pipeline.assert_called_once_with(retrievers=[]) # mock_pipeline_instance.forward.assert_called_once_with( # question="Test prompt", # history=mock_session.messages, # db=mock_db # ) # assert answer == "Final RAG response" # assert model_name == "deepseek" # @patch('app.core.services.get_llm_provider') # @patch('app.core.services.DspyRagPipeline') # @patch('dspy.configure') # def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): # """ # Tests that chat_with_rag correctly switches the model based on the 'model' argument, # while 
# still using the default retriever setting. # """ # # --- Arrange --- # mock_db = MagicMock(spec=Session) # mock_session = models.Session(id=43, model_name="deepseek", messages=[]) # mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session # mock_llm_provider = MagicMock(spec=LLMProvider) # mock_get_llm_provider.return_value = mock_llm_provider # mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) # mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") # mock_dspy_pipeline.return_value = mock_pipeline_instance # # --- Act --- # answer, model_name = asyncio.run( # rag_service.chat_with_rag( # db=mock_db, # session_id=43, # prompt="Test prompt for Gemini", # model="gemini", # load_faiss_retriever=False # Explicitly pass the default value # ) # ) # # --- Assert --- # mock_db.query.assert_called_once_with(models.Session) # assert mock_db.add.call_count == 2 # mock_get_llm_provider.assert_called_once_with("gemini") # # Assert that DspyRagPipeline was initialized with an empty list of retrievers # mock_dspy_pipeline.assert_called_once_with(retrievers=[]) # mock_pipeline_instance.forward.assert_called_once_with( # question="Test prompt for Gemini", # history=mock_session.messages, # db=mock_db # ) # assert answer == "Final RAG response from Gemini" # assert model_name == "gemini" # @patch('app.core.services.get_llm_provider') # @patch('app.core.services.DspyRagPipeline') # @patch('dspy.configure') # def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): # """ # Tests that the chat_with_rag method correctly initializes the DspyRagPipeline # with the FaissDBRetriever when `load_faiss_retriever` is True. 
# """ # # --- Arrange --- # mock_db = MagicMock(spec=Session) # mock_session = models.Session(id=44, model_name="deepseek", messages=[]) # mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session # mock_llm_provider = MagicMock(spec=LLMProvider) # mock_get_llm_provider.return_value = mock_llm_provider # mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) # mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") # mock_dspy_pipeline.return_value = mock_pipeline_instance # # --- Act --- # # Explicitly enable the FAISS retriever # answer, model_name = asyncio.run( # rag_service.chat_with_rag( # db=mock_db, # session_id=44, # prompt="Test prompt with FAISS", # model="deepseek", # load_faiss_retriever=True # ) # ) # # --- Assert --- # # The crucial part is to verify that the pipeline was called with the correct retriever # expected_retrievers = [rag_service.faiss_retriever] # mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) # mock_pipeline_instance.forward.assert_called_once_with( # question="Test prompt with FAISS", # history=mock_session.messages, # db=mock_db # ) # assert answer == "Response with FAISS context" # assert model_name == "deepseek" # def test_get_message_history_success(rag_service: RAGService): # """Tests successfully retrieving message history for an existing session.""" # # Arrange # mock_db = MagicMock(spec=Session) # # Ensure mocked messages have created_at for sorting # mock_session = models.Session(id=1, messages=[ # models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), # models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) # ]) # mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session # # Act # messages = rag_service.get_message_history(db=mock_db, session_id=1) # # Assert # assert len(messages) == 2 # assert messages[0].created_at < messages[1].created_at # Verify sorting # 
# mock_db.query.assert_called_once_with(models.Session) # def test_get_message_history_not_found(rag_service: RAGService): # """Tests retrieving history for a non-existent session.""" # # Arrange # mock_db = MagicMock(spec=Session) # mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None # # Act # messages = rag_service.get_message_history(db=mock_db, session_id=999) # # Assert # assert messages is None # # --- Document Management Tests --- # @patch('app.db.models.VectorMetadata') # @patch('app.db.models.Document') # @patch('app.core.vector_store.faiss_store.FaissVectorStore') # def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): # """ # Test the RAGService.add_document method for a successful run. # Verifies that the method correctly calls db.add(), db.commit(), and the vector store. # """ # # Setup mocks # mock_db = MagicMock(spec=Session) # mock_new_document_instance = MagicMock() # mock_document_model.return_value = mock_new_document_instance # mock_new_document_instance.id = 1 # mock_new_document_instance.text = "Test text." 
# mock_new_document_instance.title = "Test Title" # mock_vector_store_instance = mock_vector_store.return_value # # Fix: Manually set the embedder on the mock vector store instance # mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) # mock_vector_store_instance.add_document.return_value = 123 # # Instantiate the service correctly # rag_service = RAGService( # vector_store=mock_vector_store_instance, # retrievers=[] # ) # doc_data = { # "title": "Test Title", # "text": "Test text.", # "source_url": "http://test.com" # } # # Call the method under test # document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) # # Assertions # assert document_id == 1 # from unittest.mock import call # expected_calls = [ # call(mock_new_document_instance), # call(mock_vector_metadata_model.return_value) # ] # mock_db.add.assert_has_calls(expected_calls) # mock_db.commit.assert_called() # mock_db.refresh.assert_called_with(mock_new_document_instance) # mock_vector_store_instance.add_document.assert_called_once_with("Test text.") # # Assert that VectorMetadata was instantiated with the correct arguments # mock_vector_metadata_model.assert_called_once_with( # document_id=mock_new_document_instance.id, # faiss_index=mock_vector_store_instance.add_document.return_value, # embedding_model="mock_embedder" # This now passes because the mock embedder is of type MockEmbedder # ) # @patch('app.core.vector_store.faiss_store.FaissVectorStore') # def test_rag_service_add_document_error_handling(mock_vector_store): # """ # Test the RAGService.add_document method's error handling. # Verifies that the transaction is rolled back on an exception. # """ # # Setup mocks # mock_db = MagicMock(spec=Session) # # Configure the mock db.add to raise the specific SQLAlchemyError. 
# mock_db.add.side_effect = SQLAlchemyError("Database error") # mock_vector_store_instance = mock_vector_store.return_value # # Instantiate the service correctly # rag_service = RAGService( # vector_store=mock_vector_store_instance, # retrievers=[] # ) # doc_data = { # "title": "Test Title", # "text": "Test text.", # "source_url": "http://test.com" # } # # Call the method under test and expect an exception # with pytest.raises(SQLAlchemyError, match="Database error"): # rag_service.add_document(db=mock_db, doc_data=doc_data) # # Assertions # mock_db.add.assert_called_once() # mock_db.commit.assert_not_called() # mock_db.rollback.assert_called_once()