import asyncio
from datetime import datetime
from typing import List
from unittest.mock import patch, MagicMock, AsyncMock, call

import dspy
import pytest
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError  # Import the specific error type

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


@pytest.fixture
def rag_service():
    """
    Pytest fixture to create a RAGService instance with mocked dependencies.
    Correctly instantiates RAGService with only the required arguments.
    """
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_retriever = MagicMock(spec=Retriever)
    return RAGService(
        vector_store=mock_vector_store,
        retrievers=[mock_retriever]
    )


# --- Shared test helpers (underscore-prefixed so pytest does not collect them) ---

def _mock_db_returning(session) -> MagicMock:
    """Build a mock DB session whose ``query().options().filter().first()`` yields *session*.

    This mirrors the query chain the service is expected to use when loading a
    Session row; *session* may be ``None`` to simulate a missing session.
    """
    mock_db = MagicMock(spec=Session)
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = session
    return mock_db


def _arrange_chat_mocks(mock_dspy_pipeline, mock_get_llm_provider, session, response: str):
    """Wire up the patched LLM provider factory and pipeline class for one chat turn.

    Args:
        mock_dspy_pipeline: The patched ``app.core.services.DspyRagPipeline`` class.
        mock_get_llm_provider: The patched ``app.core.services.get_llm_provider``.
        session: The ``models.Session`` the mock DB should return.
        response: The value the pipeline's async ``forward`` resolves to.

    Returns:
        ``(mock_db, mock_pipeline_instance)`` for use in the Act/Assert phases.
    """
    mock_db = _mock_db_returning(session)

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value=response)
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    return mock_db, mock_pipeline_instance


def _assert_chat_turn(mock_db, mock_get_llm_provider, mock_pipeline_instance, *, session, prompt: str, model: str):
    """Assert the orchestration steps common to every successful chat turn."""
    mock_db.query.assert_called_once_with(models.Session)
    # Two db.add() calls per turn — presumably the user prompt and the
    # assistant reply; confirm against RAGService.chat_with_rag.
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with(model)
    # The pipeline receives the raw prompt plus the session's message history.
    mock_pipeline_instance.forward.assert_called_once_with(
        question=prompt,
        history=session.messages,
        db=mock_db
    )


# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """Tests that the create_session method correctly creates a new session."""
    mock_db = MagicMock(spec=Session)

    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")

    mock_db.add.assert_called_once()
    added_object = mock_db.add.call_args[0][0]
    assert isinstance(added_object, models.Session)
    assert added_object.user_id == "test_user"
    assert added_object.model_name == "gemini"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within a session when the
    requested model matches the session's stored model ("deepseek").
    """
    # --- Arrange ---
    # The session needs a 'messages' attribute because it is passed as history.
    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
    mock_db, mock_pipeline_instance = _arrange_chat_mocks(
        mock_dspy_pipeline, mock_get_llm_provider, mock_session, "Final RAG response"
    )

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt", model="deepseek")
    )

    # --- Assert ---
    _assert_chat_turn(
        mock_db, mock_get_llm_provider, mock_pipeline_instance,
        session=mock_session, prompt="Test prompt", model="deepseek"
    )
    assert answer == "Final RAG response"
    assert model_name == "deepseek"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests that chat_with_rag uses the model named by the 'model' argument
    ("gemini") even when the session was stored with a different one ("deepseek").
    """
    # --- Arrange ---
    mock_session = models.Session(id=43, model_name="deepseek", messages=[])
    mock_db, mock_pipeline_instance = _arrange_chat_mocks(
        mock_dspy_pipeline, mock_get_llm_provider, mock_session, "Final RAG response from Gemini"
    )

    # --- Act ---
    # Explicitly request the "gemini" model for this chat turn.
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, session_id=43, prompt="Test prompt for Gemini", model="gemini")
    )

    # --- Assert ---
    # The provider must be looked up for the requested model, not the stored one.
    _assert_chat_turn(
        mock_db, mock_get_llm_provider, mock_pipeline_instance,
        session=mock_session, prompt="Test prompt for Gemini", model="gemini"
    )
    assert answer == "Final RAG response from Gemini"
    assert model_name == "gemini"


def test_get_message_history_success(rag_service: RAGService):
    """Tests that history for an existing session is returned oldest-first."""
    # Arrange
    earlier = models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0))
    later = models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    # Store the messages newest-first on purpose: the ordering assertion below
    # is only meaningful if the input is not already sorted.
    mock_session = models.Session(id=1, messages=[later, earlier])
    mock_db = _mock_db_returning(mock_session)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    assert messages[0].created_at < messages[1].created_at  # Verify sorting
    mock_db.query.assert_called_once_with(models.Session)


def test_get_message_history_not_found(rag_service: RAGService):
    """Tests retrieving history for a non-existent session."""
    # Arrange
    mock_db = _mock_db_returning(None)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None


# --- Document Management Tests ---

@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    Test the RAGService.add_document method for a successful run.
    Verifies that the method correctly calls db.add(), db.commit(), and the vector store.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)
    mock_new_document_instance = MagicMock()
    mock_document_model.return_value = mock_new_document_instance
    mock_new_document_instance.id = 1
    mock_new_document_instance.text = "Test text."
    mock_new_document_instance.title = "Test Title"

    mock_vector_store_instance = mock_vector_store.return_value
    mock_vector_store_instance.add_document.return_value = 123

    # Instantiate the service correctly
    rag_service = RAGService(
        vector_store=mock_vector_store_instance,
        retrievers=[]
    )

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com"
    }

    # Call the method under test
    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assertions
    assert document_id == 1

    expected_calls = [
        call(mock_new_document_instance),
        call(mock_vector_metadata_model.return_value)
    ]
    mock_db.add.assert_has_calls(expected_calls)

    mock_db.commit.assert_called()
    mock_db.refresh.assert_called_with(mock_new_document_instance)
    mock_vector_store_instance.add_document.assert_called_once_with("Test text.")

    # Assert that VectorMetadata was instantiated with the correct arguments
    # NOTE(review): "mock_embedder" is not configured anywhere in this test;
    # presumably RAGService supplies it — confirm against the service.
    mock_vector_metadata_model.assert_called_once_with(
        document_id=mock_new_document_instance.id,
        faiss_index=mock_vector_store_instance.add_document.return_value,
        embedding_model="mock_embedder"
    )


@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    Test the RAGService.add_document method's error handling.
    Verifies that the transaction is rolled back on an exception.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)

    # Configure the mock db.add to raise the specific SQLAlchemyError.
    mock_db.add.side_effect = SQLAlchemyError("Database error")

    mock_vector_store_instance = mock_vector_store.return_value

    # Instantiate the service correctly
    rag_service = RAGService(
        vector_store=mock_vector_store_instance,
        retrievers=[]
    )

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com"
    }

    # Call the method under test and expect an exception
    with pytest.raises(SQLAlchemyError, match="Database error"):
        rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assertions
    mock_db.add.assert_called_once()
    mock_db.commit.assert_not_called()
    mock_db.rollback.assert_called_once()


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_with_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Test RAGService.chat_with_rag for a session with prior history while the
    retriever is primed to return context.

    Verifies that the prompt and the existing message history are forwarded to
    the pipeline unchanged. Prompt construction from retrieved context is NOT
    checked here, because DspyRagPipeline is patched out.
    """
    # --- Arrange ---
    mock_session = models.Session(id=1, model_name="deepseek", messages=[
        models.Message(sender="user", content="Previous user message", created_at=datetime(2023, 1, 1, 9, 0, 0)),
        models.Message(sender="assistant", content="Previous assistant response", created_at=datetime(2023, 1, 1, 9, 1, 0))
    ])
    mock_db, mock_pipeline_instance = _arrange_chat_mocks(
        mock_dspy_pipeline, mock_get_llm_provider, mock_session, "LLM response with context"
    )

    prompt = "Test prompt."

    # NOTE(review): presumably retrieval happens inside the (patched) pipeline,
    # so these contexts are not observable by this test — confirm. Verifying the
    # context-augmented prompt would require exercising the real pipeline.
    mock_retriever = rag_service.retrievers[0]
    mock_retriever.retrieve_context = AsyncMock(return_value=["Context text 1.", "Context text 2."])

    # --- Act ---
    response_text, model_used = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")
    )

    # --- Assert ---
    _assert_chat_turn(
        mock_db, mock_get_llm_provider, mock_pipeline_instance,
        session=mock_session, prompt=prompt, model="deepseek"
    )
    assert response_text == "LLM response with context"
    assert model_used == "deepseek"


@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_rag_service_chat_with_rag_without_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Test the RAGService.chat_with_rag method when no context is retrieved.
    Verifies that the original prompt is sent to the LLM.
    """
    # --- Arrange ---
    mock_session = models.Session(id=1, model_name="deepseek", messages=[])
    mock_db, mock_pipeline_instance = _arrange_chat_mocks(
        mock_dspy_pipeline, mock_get_llm_provider, mock_session, "LLM response without context"
    )

    prompt = "Test prompt without context."

    mock_retriever = rag_service.retrievers[0]
    mock_retriever.retrieve_context = AsyncMock(return_value=[])

    # --- Act ---
    response_text, model_used = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")
    )

    # --- Assert ---
    _assert_chat_turn(
        mock_db, mock_get_llm_provider, mock_pipeline_instance,
        session=mock_session, prompt=prompt, model="deepseek"
    )
    assert response_text == "LLM response without context"
    assert model_used == "deepseek"