"""Unit tests for RAGService session management and RAG chat orchestration."""
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock

import pytest
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


@pytest.fixture
def rag_service():
    """Pytest fixture to create a RAGService instance with mocked dependencies."""
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_retriever = MagicMock(spec=Retriever)
    return RAGService(vector_store=mock_vector_store, retrievers=[mock_retriever])


# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """Tests that the create_session method correctly creates a new session."""
    # Arrange
    mock_db = MagicMock(spec=Session)

    # Act
    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")

    # Assert
    mock_db.add.assert_called_once()
    mock_db.commit.assert_called_once()
    mock_db.refresh.assert_called_once()

    # Check that the object passed to db.add was a Session instance
    added_object = mock_db.add.call_args[0][0]
    assert isinstance(added_object, models.Session)
    assert added_object.user_id == "test_user"
    assert added_object.model_name == "gemini"


# Stacked @patch decorators are applied bottom-up, so the mocks arrive in
# reverse order: dspy.configure first, get_llm_provider last.
@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Tests the full orchestration of a chat message within a session.
    """
    # --- Arrange ---
    # Mock the database to return a session when queried. The chain mirrors
    # db.query(...).options(...).filter(...).first() in the service.
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=42, model_name="deepseek")
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    # Mock the LLM provider and the DSPy pipeline
    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt"))

    # --- Assert ---
    # 1. Assert the session was fetched correctly
    mock_db.query.assert_called_once_with(models.Session)

    # 2. Assert the user and assistant messages were saved
    assert mock_db.add.call_count == 2
    assert mock_db.commit.call_count == 2

    # 3. Assert the RAG pipeline was orchestrated correctly
    mock_get_llm_provider.assert_called_once_with("deepseek")
    mock_dspy_pipeline.assert_called_once()
    # forward is async: require that the coroutine was actually awaited,
    # not merely created (assert_called_once would pass in that case too).
    mock_pipeline_instance.forward.assert_awaited_once()

    # 4. Assert the correct response was returned
    assert answer == "Final RAG response"
    assert model_name == "deepseek"