import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from typing import List
from datetime import datetime
import dspy
from app.core.services.rag import RAGService
from app.db import models
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.mock import MockEmbedder
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever, Retriever
from app.core.retrievers.base_retriever import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.providers.base import LLMProvider
@pytest.fixture
def rag_service() -> RAGService:
    """
    Pytest fixture that builds a RAGService whose retrievers are all mocks.

    The service receives two retrievers, in this order: a mock of the generic
    ``Retriever`` interface followed by a mock ``FaissDBRetriever``. Having
    both kinds available lets the tests exercise conditional loading of the
    FAISS retriever.

    Returns:
        A RAGService constructed with ``retrievers=[<generic mock>, <FAISS mock>]``.
    """
    # spec= limits each mock to the attributes of the class it stands in for.
    mock_web_retriever = MagicMock(spec=Retriever)
    mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)
    return RAGService(
        retrievers=[mock_web_retriever, mock_faiss_retriever]
    )
# --- Session Management Tests ---
def test_create_session(rag_service: RAGService):
    """Checks that create_session adds one Session row carrying the given user and provider."""
    db = MagicMock(spec=Session)

    rag_service.create_session(db=db, user_id="test_user", provider_name="gemini")

    # Exactly one object must have been handed to db.add().
    db.add.assert_called_once()
    # call_args unpacks into (positional args, keyword args) of that single call.
    positional_args, _keyword_args = db.add.call_args
    new_session = positional_args[0]
    assert isinstance(new_session, models.Session)
    assert new_session.user_id == "test_user"
    assert new_session.provider_name == "gemini"
@patch('app.core.services.rag.get_llm_provider')
@patch('app.core.services.rag.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(
    mock_configure,
    mock_dspy_pipeline,
    mock_get_llm_provider,
    rag_service: RAGService,
):
    """
    Exercises chat_with_rag end to end against mocked collaborators.

    The requested provider ("deepseek") matches the provider stored on the
    session, and ``load_faiss_retriever`` is passed explicitly as False, so
    the pipeline is expected to be built with an empty retriever list.
    """
    # --- Arrange ---
    stored_session = models.Session(id=42, provider_name="deepseek", messages=[])
    db = MagicMock(spec=Session)
    # Stub the query(...).options(...).filter(...).first() lookup. Only
    # return_value attributes are touched, so no call is recorded on db.query.
    filtered_query = db.query.return_value.options.return_value.filter.return_value
    filtered_query.first.return_value = stored_session

    mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

    pipeline = MagicMock(spec=DspyRagPipeline)
    pipeline.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = pipeline

    # --- Act ---
    chat_call = rag_service.chat_with_rag(
        db=db,
        session_id=42,
        prompt="Test prompt",
        provider_name="deepseek",
        load_faiss_retriever=False,
    )
    answer, provider_name = asyncio.run(chat_call)

    # --- Assert ---
    db.query.assert_called_once_with(models.Session)
    # Two objects are added to the DB session; presumably the user prompt and
    # the assistant reply -- confirm against RAGService.chat_with_rag.
    assert db.add.call_count == 2
    mock_get_llm_provider.assert_called_once_with("deepseek")
    mock_dspy_pipeline.assert_called_once_with(retrievers=[])
    pipeline.forward.assert_called_once_with(
        question="Test prompt",
        history=stored_session.messages,
        db=db,
    )
    assert answer == "Final RAG response"
    assert provider_name == "deepseek"
def test_chat_with_rag_model_switch(rag_service: RAGService):
    """
    Checks that chat_with_rag uses the provider named in the call.

    The session is stored with provider "deepseek" while the call passes
    ``provider_name="gemini"``; the test expects "gemini" to be the provider
    that is looked up and reported back. ``load_faiss_retriever`` is passed
    explicitly as False, so the pipeline is expected to get no retrievers.
    """
    # --- Arrange ---
    stored_session = models.Session(id=43, provider_name="deepseek", messages=[])
    db = MagicMock(spec=Session)
    # Stub the query(...).options(...).filter(...).first() lookup. Only
    # return_value attributes are touched, so no call is recorded on db.query.
    filtered_query = db.query.return_value.options.return_value.filter.return_value
    filtered_query.first.return_value = stored_session

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
            patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
            patch('dspy.configure'):
        mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

        pipeline = MagicMock(spec=DspyRagPipeline)
        pipeline.forward = AsyncMock(return_value="Final RAG response from Gemini")
        mock_dspy_pipeline.return_value = pipeline

        # --- Act ---
        chat_call = rag_service.chat_with_rag(
            db=db,
            session_id=43,
            prompt="Test prompt for Gemini",
            provider_name="gemini",
            load_faiss_retriever=False,
        )
        answer, provider_name = asyncio.run(chat_call)

        # --- Assert ---
        db.query.assert_called_once_with(models.Session)
        # Two objects are added to the DB session; presumably the user prompt
        # and the assistant reply -- confirm against RAGService.chat_with_rag.
        assert db.add.call_count == 2
        mock_get_llm_provider.assert_called_once_with("gemini")
        mock_dspy_pipeline.assert_called_once_with(retrievers=[])
        pipeline.forward.assert_called_once_with(
            question="Test prompt for Gemini",
            history=stored_session.messages,
            db=db,
        )
        assert answer == "Final RAG response from Gemini"
        assert provider_name == "gemini"
def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService):
    """
    Checks that ``load_faiss_retriever=True`` wires the FAISS retriever in.

    The DspyRagPipeline is expected to be constructed with a retriever list
    containing only ``rag_service.faiss_retriever``.
    """
    # --- Arrange ---
    stored_session = models.Session(id=44, provider_name="deepseek", messages=[])
    db = MagicMock(spec=Session)
    # Stub the query(...).options(...).filter(...).first() lookup.
    filtered_query = db.query.return_value.options.return_value.filter.return_value
    filtered_query.first.return_value = stored_session

    with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
            patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \
            patch('dspy.configure'):
        mock_get_llm_provider.return_value = MagicMock(spec=LLMProvider)

        pipeline = MagicMock(spec=DspyRagPipeline)
        pipeline.forward = AsyncMock(return_value="Response with FAISS context")
        mock_dspy_pipeline.return_value = pipeline

        # --- Act ---
        chat_call = rag_service.chat_with_rag(
            db=db,
            session_id=44,
            prompt="Test prompt with FAISS",
            provider_name="deepseek",
            load_faiss_retriever=True,
        )
        answer, provider_name = asyncio.run(chat_call)

        # --- Assert ---
        mock_dspy_pipeline.assert_called_once_with(
            retrievers=[rag_service.faiss_retriever]
        )
        pipeline.forward.assert_called_once_with(
            question="Test prompt with FAISS",
            history=stored_session.messages,
            db=db,
        )
        assert answer == "Response with FAISS context"
        assert provider_name == "deepseek"
def test_get_message_history_success(rag_service: RAGService):
    """Checks that get_message_history returns both messages of an existing session, oldest first."""
    # Arrange
    earlier_message = models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0))
    later_message = models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
    stored_session = models.Session(id=1, messages=[earlier_message, later_message])

    db = MagicMock(spec=Session)
    # Stub the query(...).options(...).filter(...).first() lookup.
    filtered_query = db.query.return_value.options.return_value.filter.return_value
    filtered_query.first.return_value = stored_session

    # Act
    history = rag_service.get_message_history(db=db, session_id=1)

    # Assert
    assert len(history) == 2
    first, second = history[0], history[1]
    assert first.created_at < second.created_at
    db.query.assert_called_once_with(models.Session)
def test_get_message_history_not_found(rag_service: RAGService):
    """Checks that get_message_history returns None when the session lookup finds nothing."""
    # Arrange
    db = MagicMock(spec=Session)
    # Make the query(...).options(...).filter(...).first() lookup come back empty.
    filtered_query = db.query.return_value.options.return_value.filter.return_value
    filtered_query.first.return_value = None

    # Act
    history = rag_service.get_message_history(db=db, session_id=999)

    # Assert
    assert history is None