"""Unit tests for ``RAGService`` session management and RAG chat orchestration.

Every collaborator (SQLAlchemy session, LLM provider factory, DSPy pipeline,
retrievers) is replaced with a mock, so these tests check only how the
service wires those collaborators together.
"""
import asyncio
from datetime import datetime
from typing import List
from unittest.mock import patch, MagicMock, AsyncMock

import dspy
import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from app.core.services.rag import RAGService
from app.db import models
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.mock import MockEmbedder
# NOTE(review): ``Retriever`` is imported from both modules; the later
# base_retriever import is the binding that ends up in scope.
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever
from app.core.retrievers.base_retriever import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.providers.base import LLMProvider


# --- Helpers ---

def _mock_db_returning(session):
    """Return a mock DB session whose ``query().options().filter().first()`` yields *session*.

    Args:
        session: The object (or ``None``) the service's session lookup should receive.

    Returns:
        A ``MagicMock`` specced on ``sqlalchemy.orm.Session``.
    """
    mock_db = MagicMock(spec=Session)
    # Mirrors the lookup chain the service is expected to use for loading a session.
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = session
    return mock_db


def _mock_pipeline(response):
    """Return a mock ``DspyRagPipeline`` whose async ``forward`` resolves to *response*."""
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value=response)
    return mock_pipeline_instance


@pytest.fixture
def rag_service():
    """
    Pytest fixture to create a RAGService instance with mocked dependencies.
    It includes a mock FaissDBRetriever and a mock generic Retriever to test
    conditional loading.
    """
    mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)
    mock_web_retriever = MagicMock(spec=Retriever)
    # The FAISS retriever is listed second, so the service must locate it in
    # the list rather than take the first entry.
    return RAGService(
        retrievers=[mock_web_retriever, mock_faiss_retriever]
    )


# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """Tests that the create_session method correctly creates a new session."""
    mock_db = MagicMock(spec=Session)

    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")

    mock_db.add.assert_called_once()
    added_object = mock_db.add.call_args[0][0]
    assert isinstance(added_object, models.Session)
    assert added_object.user_id == "test_user"
    assert added_object.model_name == "gemini"


@patch('app.core.services.rag.get_llm_provider')
@patch('app.core.services.rag.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within a session when the
    requested model matches the session's existing model ("deepseek") and the
    FAISS retriever loading parameter is explicitly set to False.
    """
    # --- Arrange ---
    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_pipeline_instance = _mock_pipeline("Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(
            db=mock_db,
            session_id=42,
            prompt="Test prompt",
            model="deepseek",
            load_faiss_retriever=False,
        )
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    # One add for the user message and one for the assistant reply (presumably).
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])
    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt",
        history=mock_session.messages,
        db=mock_db,
    )
    assert answer == "Final RAG response"
    assert model_name == "deepseek"


def test_chat_with_rag_model_switch(rag_service: RAGService):
    """
    Tests that chat_with_rag switches to the model given by the 'model'
    argument ("gemini") even though the session was created with "deepseek",
    with FAISS retriever loading explicitly disabled.
    """
    # --- Arrange ---
    mock_session = models.Session(id=43, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider

        mock_pipeline_instance = _mock_pipeline("Final RAG response from Gemini")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, model_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=43,
                prompt="Test prompt for Gemini",
                model="gemini",
                load_faiss_retriever=False,
            )
        )

        # --- Assert ---
        mock_db.query.assert_called_once_with(models.Session)
        assert mock_db.add.call_count == 2
        mock_get_llm_provider.assert_called_once_with("gemini")
        mock_dspy_pipeline.assert_called_once_with(retrievers=[])
        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt for Gemini",
            history=mock_session.messages,
            db=mock_db,
        )
        assert answer == "Final RAG response from Gemini"
        assert model_name == "gemini"


def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService):
    """
    Tests that the chat_with_rag method correctly initializes the DspyRagPipeline
    with the FaissDBRetriever when `load_faiss_retriever` is True, while the rest
    of the orchestration (session lookup, persistence, provider) is unchanged.
    """
    # --- Arrange ---
    mock_session = models.Session(id=44, model_name="deepseek", messages=[])
    mock_db = _mock_db_returning(mock_session)

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
         patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
         patch('dspy.configure'):

        mock_llm_provider = MagicMock(spec=LLMProvider)
        mock_get_llm_provider.return_value = mock_llm_provider

        mock_pipeline_instance = _mock_pipeline("Response with FAISS context")
        mock_dspy_pipeline.return_value = mock_pipeline_instance

        # --- Act ---
        answer, model_name = asyncio.run(
            rag_service.chat_with_rag(
                db=mock_db,
                session_id=44,
                prompt="Test prompt with FAISS",
                model="deepseek",
                load_faiss_retriever=True,
            )
        )

        # --- Assert ---
        mock_db.query.assert_called_once_with(models.Session)
        assert mock_db.add.call_count == 2
        mock_get_llm_provider.assert_called_once_with("deepseek")
        # Only the FAISS retriever is passed on; the generic retriever is excluded.
        expected_retrievers = [rag_service.faiss_retriever]
        mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers)
        mock_pipeline_instance.forward.assert_called_once_with(
            question="Test prompt with FAISS",
            history=mock_session.messages,
            db=mock_db,
        )
        assert answer == "Response with FAISS context"
        assert model_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService):
    """Tests successfully retrieving message history for an existing session."""
    # Arrange
    mock_session = models.Session(id=1, messages=[
        models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)),
        models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)),
    ])
    mock_db = _mock_db_returning(mock_session)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    # NOTE(review): the fixture messages are already in chronological order, so
    # this check does not prove the service sorts them; confirm the intended
    # contract before feeding out-of-order input here.
    assert messages[0].created_at < messages[1].created_at
    mock_db.query.assert_called_once_with(models.Session)


def test_get_message_history_not_found(rag_service: RAGService):
    """Tests retrieving history for a non-existent session."""
    # Arrange
    mock_db = _mock_db_returning(None)

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None