"""Unit tests for ``RAGService`` session management and chat orchestration.

All collaborators (vector store, retriever, database session, LLM provider and
the DSPy pipeline) are replaced with mocks, so these tests exercise only the
service's own wiring.
"""

import asyncio
from unittest.mock import patch, MagicMock, AsyncMock

import pytest
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider


@pytest.fixture
def rag_service() -> RAGService:
    """Pytest fixture to create a RAGService instance with mocked dependencies."""
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_retriever = MagicMock(spec=Retriever)
    return RAGService(vector_store=mock_vector_store, retrievers=[mock_retriever])


# --- Session Management Tests ---


def test_create_session(rag_service: RAGService) -> None:
    """Tests that the create_session method correctly creates a new session."""
    mock_db = MagicMock(spec=Session)

    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")

    # Exactly one object is added, and it must be a Session row carrying the
    # user id and model name that were passed in.
    mock_db.add.assert_called_once()
    added_object = mock_db.add.call_args[0][0]
    assert isinstance(added_object, models.Session)
    assert added_object.user_id == "test_user"
    assert added_object.model_name == "gemini"


# Stacked patches are applied bottom-up, so the mock arguments arrive in the
# order: dspy.configure, DspyRagPipeline, get_llm_provider.
@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(
    mock_configure,
    mock_dspy_pipeline,
    mock_get_llm_provider,
    rag_service: RAGService,
) -> None:
    """
    Tests the full orchestration of a chat message within a session.

    ``dspy.configure`` is patched out and its mock is not inspected; the
    assertions cover the session lookup, the provider lookup, the pipeline
    call and the returned ``(answer, model_name)`` pair.
    """
    # --- Arrange ---
    mock_db = MagicMock(spec=Session)

    # The session carries a 'messages' attribute because the test expects it
    # to be forwarded to the pipeline as the conversation history.
    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    mock_llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = mock_llm_provider

    # forward is awaited by the service, hence AsyncMock rather than MagicMock.
    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = mock_pipeline_instance

    # --- Act ---
    answer, model_name = asyncio.run(
        rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt")
    )

    # --- Assert ---
    mock_db.query.assert_called_once_with(models.Session)
    # NOTE(review): two adds are presumably the user message and the assistant
    # reply -- confirm against RAGService.chat_with_rag.
    assert mock_db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # The pipeline must receive the session's message list as the history.
    mock_pipeline_instance.forward.assert_called_once_with(
        question="Test prompt",
        history=mock_session.messages,
        db=mock_db
    )

    assert answer == "Final RAG response"
    assert model_name == "deepseek"


def test_get_message_history_success(rag_service: RAGService) -> None:
    """Tests successfully retrieving message history for an existing session."""
    # Arrange
    mock_db = MagicMock(spec=Session)
    mock_session = models.Session(id=1, messages=[models.Message(), models.Message()])
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=1)

    # Assert
    assert len(messages) == 2
    mock_db.query.assert_called_once_with(models.Session)


def test_get_message_history_not_found(rag_service: RAGService) -> None:
    """Tests retrieving history for a non-existent session."""
    # Arrange: the query chain finds no matching session.
    mock_db = MagicMock(spec=Session)
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None