"""Unit tests for ``RAGService`` in ``app.core.rag_service``.

Each test replaces the service's collaborators (database session, vector
store, ORM models, retrievers, LLM provider) with mocks so that only the
service's own logic is exercised.
"""

import asyncio
from unittest.mock import patch, MagicMock, AsyncMock, call
from typing import List

from sqlalchemy.orm import Session

# Import the RAGService class and its dependencies
from app.core.rag_service import RAGService
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.db import models


# --- RAGService Unit Tests ---
# These tests directly target the methods of the RAGService class
# to verify their internal logic in isolation.


@patch('app.db.models.VectorMetadata')
@patch('app.db.models.Document')
@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
    """
    Test the RAGService.add_document method for a successful run.

    Verifies that the method correctly calls db.add(), db.commit(), and the
    vector store, creates a VectorMetadata row for the new document, and
    returns the new document's id.
    """
    # `patch` decorators are applied bottom-up, so the injected mocks arrive
    # in the reverse order of the decorator list above.
    mock_db = MagicMock(spec=Session)

    # The Document constructor yields a controllable instance; its `id` is the
    # value add_document is expected to return.
    mock_new_document_instance = MagicMock()
    mock_document_model.return_value = mock_new_document_instance
    mock_new_document_instance.id = 1
    mock_new_document_instance.text = "Test text."
    mock_new_document_instance.title = "Test Title"

    # The vector store reports the index it assigned to the embedded text.
    mock_vector_store_instance = mock_vector_store.return_value
    mock_vector_store_instance.add_document.return_value = 123

    # Instantiate the service with the mock dependencies
    rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[])

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com",
    }

    # Call the method under test
    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)

    # Assertions
    assert document_id == 1

    # Use mock.call to check for both calls to db.add in the correct order.
    # We must mock the VectorMetadata model to check its constructor call
    expected_calls = [
        call(mock_new_document_instance),
        call(mock_vector_metadata_model.return_value),
    ]
    mock_db.add.assert_has_calls(expected_calls)
    mock_db.commit.assert_called()
    mock_db.refresh.assert_called_with(mock_new_document_instance)
    mock_vector_store_instance.add_document.assert_called_once_with("Test text.")

    # Assert that VectorMetadata was instantiated with the correct arguments.
    # NOTE(review): "mock_embedder" is assumed to be the embedding-model name
    # RAGService records for this vector store; confirm against
    # RAGService.add_document.
    mock_vector_metadata_model.assert_called_once_with(
        document_id=mock_new_document_instance.id,
        faiss_index=mock_vector_store_instance.add_document.return_value,
        embedding_model="mock_embedder",
    )


@patch('app.core.vector_store.FaissVectorStore')
def test_rag_service_add_document_error_handling(mock_vector_store):
    """
    Test the RAGService.add_document method's error handling.

    Verifies that an exception raised by the database propagates to the
    caller, that nothing is committed, and that the transaction is rolled back.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)

    # Configure the mock db.add to raise an exception
    mock_db.add.side_effect = Exception("Database error")

    mock_vector_store_instance = mock_vector_store.return_value

    # Instantiate the service
    rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[])

    doc_data = {
        "title": "Test Title",
        "text": "Test text.",
        "source_url": "http://test.com",
    }

    # Call the method under test and expect an exception. The "no exception"
    # failure lives in the `else` clause: placed inside the `try`, its
    # AssertionError (itself an Exception) would be caught by the `except`
    # below and misreported as a message mismatch.
    try:
        rag_service.add_document(db=mock_db, doc_data=doc_data)
    except Exception as e:
        assert str(e) == "Database error"
    else:
        raise AssertionError("Expected an exception to be raised")

    # Assertions
    # The first db.add was called (and raised)
    mock_db.add.assert_called_once()
    # No commit should have occurred
    mock_db.commit.assert_not_called()
    # The transaction should have been rolled back
    mock_db.rollback.assert_called_once()


@patch('app.core.rag_service.get_llm_provider')
def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider):
    """
    Test the RAGService.chat_with_rag method when context is retrieved.

    Verifies that the RAG prompt is correctly constructed: the retrieved
    context (joined by blank lines) and the user's question both appear in
    the prompt sent to the LLM.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)

    # generate_response is awaited by the service, hence AsyncMock.
    mock_llm_provider = MagicMock()
    mock_llm_provider.generate_response = AsyncMock(return_value="LLM response with context")
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."]

    # Instantiate the service with the mock retriever
    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])

    prompt = "Test prompt."

    # Call the method under test and run the async function
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek"))

    # Assertions
    expected_context = "Context text 1.\n\nContext text 2."

    mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db)
    mock_llm_provider.generate_response.assert_called_once()

    # First positional argument of the single generate_response call.
    actual_llm_prompt = mock_llm_provider.generate_response.call_args[0][0]

    # Check if the generated prompt contains the expected context and question
    assert expected_context in actual_llm_prompt
    assert prompt in actual_llm_prompt
    assert response_text == "LLM response with context"


@patch('app.core.rag_service.get_llm_provider')
def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider):
    """
    Test the RAGService.chat_with_rag method when no context is retrieved.

    Verifies that the original prompt is sent to the LLM unchanged.
    """
    # Setup mocks
    mock_db = MagicMock(spec=Session)

    # generate_response is awaited by the service, hence AsyncMock.
    mock_llm_provider = MagicMock()
    mock_llm_provider.generate_response = AsyncMock(return_value="LLM response without context")
    mock_get_llm_provider.return_value = mock_llm_provider

    # An empty retrieval result means there is no context to inject.
    mock_retriever = MagicMock(spec=Retriever)
    mock_retriever.retrieve_context.return_value = []

    # Instantiate the service with the mock retriever
    rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever])

    prompt = "Test prompt without context."

    # Call the method under test and run the async function
    response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek"))

    # Assertions
    mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db)
    mock_llm_provider.generate_response.assert_called_once_with(prompt)
    assert response_text == "LLM response without context"