import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from typing import List
from datetime import datetime
import dspy

from app.core.services.rag import RAGService
from app.db import models
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.mock import MockEmbedder
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever
from app.core.retrievers.base_retriever import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.providers.base import LLMProvider


def _make_mock_db(session):
    """Build a mock SQLAlchemy ``Session`` whose session lookup yields *session*.

    Stubs the ``db.query(...).options(...).filter(...).first()`` chain so that
    it returns *session* (pass ``None`` to simulate a missing session).
    """
    mock_db = MagicMock(spec=Session)
    lookup = mock_db.query.return_value.options.return_value.filter.return_value
    lookup.first.return_value = session
    return mock_db


def _make_mock_pipeline(answer):
    """Build a mock ``DspyRagPipeline`` whose async ``forward`` returns *answer*."""
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value=answer)
    return mock_pipeline_instance


@pytest.fixture
def rag_service():
    """
    Pytest fixture to create a RAGService instance with mocked dependencies.
    It includes a mock FaissDBRetriever and a mock generic Retriever to test
    conditional loading.
    """
    mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)
    mock_web_retriever = MagicMock(spec=Retriever)
    return RAGService(
        retrievers=[mock_web_retriever, mock_faiss_retriever]
    )


# --- Session Management Tests ---

# NOTE(review): this test is disabled; the reason is not recorded in this
# file. Re-enable or delete once the status of create_session is confirmed.
# def test_create_session(rag_service: RAGService):
#     """Tests that the create_session method correctly creates a new session."""
#     mock_db = MagicMock(spec=Session)
#
#     rag_service.create_session(db=mock_db, user_id="test_user", provider_name="gemini")
#     mock_db.add.assert_called_once()
#     added_object = mock_db.add.call_args[0][0]
#     assert isinstance(added_object, models.Session)
#     assert added_object.user_id == "test_user"
#     assert added_object.provider_name == "gemini"


# Stacked @patch decorators inject their mocks bottom-up, so the parameter
# order below is the reverse of the decorator order.
@patch('app.core.services.rag.get_llm_provider')
@patch('app.core.services.rag.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within a session when the
    requested provider ("deepseek") matches the one stored on the session and
    the retriever loading parameter is explicitly set to False.
    """
    # --- Arrange ---
    mock_session = models.Session(id=42, provider_name="deepseek", messages=[])
    mock_db = _make_mock_db(mock_session)

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_pipeline_instance = _make_mock_pipeline("Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, provider_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=42,
            prompt="Test prompt",
            provider_name="deepseek",
            load_faiss_retriever=False
        )
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    # Presumably one add for the user message and one for the reply -- confirm
    # against RAGService.chat_with_rag.
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # Assert that DspyRagPipeline was called without any arguments
    mock_dspy_pipeline.assert_called_once_with()

    # Assert that the forward method received the correct arguments
    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt",
        history=mock_session.messages,
        context_chunks=[]  # It was called with an empty list
    )

    assert answer == "Final RAG response"
    assert provider_name == "deepseek"


def test_chat_with_rag_model_switch(rag_service: RAGService):
    """
    Tests that chat_with_rag uses the provider given in the 'provider_name'
    argument ("gemini") even though the stored session was created with
    "deepseek". FAISS retriever loading is explicitly disabled.
    """
    # --- Arrange ---
    mock_session = models.Session(id=43, provider_name="deepseek", messages=[])
    mock_db = _make_mock_db(mock_session)

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider

        mock_pipeline_instance = _make_mock_pipeline("Final RAG response from Gemini")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, provider_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=43,
                prompt="Test prompt for Gemini",
                provider_name="gemini",
                load_faiss_retriever=False
            )
        )

        # --- Assert ---
        mock_db.query.assert_called_once_with(models.Session)
        assert mock_db.add.call_count == 2
        mock_get_llm_provider.assert_called_once_with("gemini")

        # Assert that DspyRagPipeline was called without any arguments
        mock_dspy_pipeline.assert_called_once_with()

        # Assert that the forward method received the correct arguments
        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt for Gemini",
            history=mock_session.messages,
            context_chunks=[]  # It was called with an empty list
        )

        assert answer == "Final RAG response from Gemini"
        assert provider_name == "gemini"


def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService):
    """
    Tests that, when `load_faiss_retriever` is True, chat_with_rag queries the
    FaissDBRetriever and forwards the retrieved chunks to the pipeline as
    `context_chunks`.
    """
    # --- Arrange ---
    mock_session = models.Session(id=44, provider_name="deepseek", messages=[])
    mock_db = _make_mock_db(mock_session)

    # Mock FaissDBRetriever to return some chunks
    rag_service.faiss_retriever.retrieve_context.return_value = ["faiss_chunk_1", "faiss_chunk_2"]

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider

        mock_pipeline_instance = _make_mock_pipeline("Response with FAISS context")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, provider_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=44,
                prompt="Test prompt with FAISS",
                provider_name="deepseek",
                load_faiss_retriever=True
            )
        )

        # --- Assert ---
        # The DspyRagPipeline is still called without arguments
        mock_dspy_pipeline.assert_called_once_with()

        # The retriever's context method is now called
        rag_service.faiss_retriever.retrieve_context.assert_called_once_with(
            query="Test prompt with FAISS", db=mock_db
        )

        # The forward method receives the retrieved chunks
        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt with FAISS",
            history=mock_session.messages,
            context_chunks=["faiss_chunk_1", "faiss_chunk_2"]
        )

        assert answer == "Response with FAISS context"
        assert provider_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService):
    """Tests successfully retrieving message history for an existing session."""
    # Arrange
    mock_session = models.Session(id=1, messages=[
        models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)),
        models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    ])
    mock_db = _make_mock_db(mock_session)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    assert messages[0].created_at < messages[1].created_at
    mock_db.query.assert_called_once_with(models.Session)


def test_get_message_history_not_found(rag_service: RAGService):
    """Tests retrieving history for a non-existent session."""
    # Arrange: the session lookup finds nothing
    mock_db = _make_mock_db(None)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None