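# cortex-hub / ai-hub / tests / core / test_services.py
# NOTE(review): every import, fixture and test in this module is commented out,
# so pytest collects no tests from it. Re-enable the tests or remove the module.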
# import pytest
# import asyncio
# from unittest.mock import patch, MagicMock, AsyncMock
# from sqlalchemy.orm import Session
# from sqlalchemy.exc import SQLAlchemyError
# from typing import List
# from datetime import datetime
# import dspy

# # Import the service and its dependencies
# from app.core.services import RAGService
# from app.db import models
# from app.core.vector_store.faiss_store import FaissVectorStore
# from app.core.vector_store.embedder.mock import MockEmbedder
# # Import FaissDBRetriever and the generic Retriever base (mocked as a web retriever) for testing different cases
# from app.core.retrievers import FaissDBRetriever, Retriever 
# from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider 
# from app.core.llm_providers import LLMProvider


# @pytest.fixture
# def rag_service():
#     """
#     Pytest fixture to create a RAGService instance with mocked dependencies.
#     It includes a mock FaissDBRetriever and a mock generic Retriever to test
#     conditional loading.
#     """
#     # Create a mock embedder to be attached to the vector store mock
#     mock_embedder = MagicMock(spec=MockEmbedder)
#     mock_vector_store = MagicMock(spec=FaissVectorStore)
#     mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute

#     mock_faiss_retriever = MagicMock(spec=FaissDBRetriever)
#     mock_web_retriever = MagicMock(spec=Retriever)
#     return RAGService(
#         vector_store=mock_vector_store,
#         retrievers=[mock_web_retriever, mock_faiss_retriever]
#     )

# # --- Session Management Tests ---

# def test_create_session(rag_service: RAGService):
#     """Tests that the create_session method correctly creates a new session."""
#     mock_db = MagicMock(spec=Session)
    
#     rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")
    
#     mock_db.add.assert_called_once()
#     added_object = mock_db.add.call_args[0][0]
#     assert isinstance(added_object, models.Session)
#     assert added_object.user_id == "test_user"
#     assert added_object.model_name == "gemini"

# @patch('app.core.services.get_llm_provider')
# @patch('app.core.services.DspyRagPipeline')
# @patch('dspy.configure')
# def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
#     """
#     Tests the full orchestration of a chat message within a session using the default model
#     and with the retriever loading parameter explicitly set to False.
#     """
#     # --- Arrange ---
#     mock_db = MagicMock(spec=Session)
#     mock_session = models.Session(id=42, model_name="deepseek", messages=[])
#     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

#     mock_llm_provider = MagicMock(spec=LLMProvider)
#     mock_get_llm_provider.return_value = mock_llm_provider
#     mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
#     mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response")
#     mock_dspy_pipeline.return_value = mock_pipeline_instance

#     # --- Act ---
#     answer, model_name = asyncio.run(
#         rag_service.chat_with_rag(
#             db=mock_db,
#             session_id=42,
#             prompt="Test prompt",
#             model="deepseek",
#             load_faiss_retriever=False # Explicitly pass the default value
#         )
#     )

#     # --- Assert ---
#     mock_db.query.assert_called_once_with(models.Session)
#     assert mock_db.add.call_count == 2
#     mock_get_llm_provider.assert_called_once_with("deepseek")
    
#     # Assert that DspyRagPipeline was initialized with an empty list of retrievers
#     mock_dspy_pipeline.assert_called_once_with(retrievers=[])

#     mock_pipeline_instance.forward.assert_called_once_with(
#         question="Test prompt",
#         history=mock_session.messages,
#         db=mock_db
#     )
    
#     assert answer == "Final RAG response"
#     assert model_name == "deepseek"


# @patch('app.core.services.get_llm_provider')
# @patch('app.core.services.DspyRagPipeline')
# @patch('dspy.configure')
# def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
#     """
#     Tests that chat_with_rag correctly switches the model based on the 'model' argument,
#     while still using the default retriever setting.
#     """
#     # --- Arrange ---
#     mock_db = MagicMock(spec=Session)
#     mock_session = models.Session(id=43, model_name="deepseek", messages=[]) 
#     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

#     mock_llm_provider = MagicMock(spec=LLMProvider)
#     mock_get_llm_provider.return_value = mock_llm_provider
#     mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
#     mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini")
#     mock_dspy_pipeline.return_value = mock_pipeline_instance

#     # --- Act ---
#     answer, model_name = asyncio.run(
#         rag_service.chat_with_rag(
#             db=mock_db,
#             session_id=43,
#             prompt="Test prompt for Gemini",
#             model="gemini",
#             load_faiss_retriever=False # Explicitly pass the default value
#         )
#     )

#     # --- Assert ---
#     mock_db.query.assert_called_once_with(models.Session)
#     assert mock_db.add.call_count == 2
#     mock_get_llm_provider.assert_called_once_with("gemini") 
    
#     # Assert that DspyRagPipeline was initialized with an empty list of retrievers
#     mock_dspy_pipeline.assert_called_once_with(retrievers=[])

#     mock_pipeline_instance.forward.assert_called_once_with(
#         question="Test prompt for Gemini",
#         history=mock_session.messages,
#         db=mock_db
#     )
    
#     assert answer == "Final RAG response from Gemini"
#     assert model_name == "gemini"


# @patch('app.core.services.get_llm_provider')
# @patch('app.core.services.DspyRagPipeline')
# @patch('dspy.configure')
# def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
#     """
#     Tests that the chat_with_rag method correctly initializes the DspyRagPipeline
#     with the FaissDBRetriever when `load_faiss_retriever` is True.
#     """
#     # --- Arrange ---
#     mock_db = MagicMock(spec=Session)
#     mock_session = models.Session(id=44, model_name="deepseek", messages=[]) 
#     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

#     mock_llm_provider = MagicMock(spec=LLMProvider)
#     mock_get_llm_provider.return_value = mock_llm_provider
#     mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
#     mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context")
#     mock_dspy_pipeline.return_value = mock_pipeline_instance
    
#     # --- Act ---
#     # Explicitly enable the FAISS retriever
#     answer, model_name = asyncio.run(
#         rag_service.chat_with_rag(
#             db=mock_db,
#             session_id=44,
#             prompt="Test prompt with FAISS",
#             model="deepseek",
#             load_faiss_retriever=True
#         )
#     )
    
#     # --- Assert ---
#     # The crucial part is to verify that the pipeline was called with the correct retriever
#     expected_retrievers = [rag_service.faiss_retriever]
#     mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers)
    
#     mock_pipeline_instance.forward.assert_called_once_with(
#         question="Test prompt with FAISS",
#         history=mock_session.messages,
#         db=mock_db
#     )
    
#     assert answer == "Response with FAISS context"
#     assert model_name == "deepseek"


# def test_get_message_history_success(rag_service: RAGService):
#     """Tests successfully retrieving message history for an existing session."""
#     # Arrange
#     mock_db = MagicMock(spec=Session)
#     # Ensure mocked messages have created_at for sorting
#     mock_session = models.Session(id=1, messages=[
#         models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)),
#         models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
#     ])
#     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session

#     # Act
#     messages = rag_service.get_message_history(db=mock_db, session_id=1)

#     # Assert
#     assert len(messages) == 2
#     assert messages[0].created_at < messages[1].created_at # Verify sorting
#     mock_db.query.assert_called_once_with(models.Session)

# def test_get_message_history_not_found(rag_service: RAGService):
#     """Tests retrieving history for a non-existent session."""
#     # Arrange
#     mock_db = MagicMock(spec=Session)
#     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None

#     # Act
#     messages = rag_service.get_message_history(db=mock_db, session_id=999)

#     # Assert
#     assert messages is None

# # --- Document Management Tests ---

# @patch('app.db.models.VectorMetadata')
# @patch('app.db.models.Document')
# @patch('app.core.vector_store.faiss_store.FaissVectorStore')
# def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
#     """
#     Test the RAGService.add_document method for a successful run.
#     Verifies that the method correctly calls db.add(), db.commit(), and the vector store.
#     """
#     # Setup mocks
#     mock_db = MagicMock(spec=Session)
#     mock_new_document_instance = MagicMock()
#     mock_document_model.return_value = mock_new_document_instance
#     mock_new_document_instance.id = 1
#     mock_new_document_instance.text = "Test text."
#     mock_new_document_instance.title = "Test Title"

#     mock_vector_store_instance = mock_vector_store.return_value
#     # Attach a MockEmbedder-spec embedder to the mock vector store; the expected embedding_model asserted below depends on it.
#     mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) 
#     mock_vector_store_instance.add_document.return_value = 123
    
#     # Instantiate the service correctly
#     rag_service = RAGService(
#         vector_store=mock_vector_store_instance,
#         retrievers=[]
#     )

#     doc_data = {
#         "title": "Test Title",
#         "text": "Test text.",
#         "source_url": "http://test.com"
#     }
    
#     # Call the method under test
#     document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)

#     # Assertions
#     assert document_id == 1
    
#     from unittest.mock import call
#     expected_calls = [
#         call(mock_new_document_instance),
#         call(mock_vector_metadata_model.return_value)
#     ]
#     mock_db.add.assert_has_calls(expected_calls)

#     mock_db.commit.assert_called()
#     mock_db.refresh.assert_called_with(mock_new_document_instance)
#     mock_vector_store_instance.add_document.assert_called_once_with("Test text.")
    
#     # Assert that VectorMetadata was instantiated with the correct arguments
#     mock_vector_metadata_model.assert_called_once_with(
#         document_id=mock_new_document_instance.id,
#         faiss_index=mock_vector_store_instance.add_document.return_value,
#         embedding_model="mock_embedder" # Relies on the vector store's embedder being a MockEmbedder-spec mock (set earlier in this test)
#     )

# @patch('app.core.vector_store.faiss_store.FaissVectorStore')
# def test_rag_service_add_document_error_handling(mock_vector_store):
#     """
#     Test the RAGService.add_document method's error handling.
#     Verifies that the transaction is rolled back on an exception.
#     """
#     # Setup mocks
#     mock_db = MagicMock(spec=Session)
    
#     # Configure the mock db.add to raise the specific SQLAlchemyError.
#     mock_db.add.side_effect = SQLAlchemyError("Database error")
    
#     mock_vector_store_instance = mock_vector_store.return_value
    
#     # Instantiate the service correctly
#     rag_service = RAGService(
#         vector_store=mock_vector_store_instance,
#         retrievers=[]
#     )

#     doc_data = {
#         "title": "Test Title",
#         "text": "Test text.",
#         "source_url": "http://test.com"
#     }
    
#     # Call the method under test and expect an exception
#     with pytest.raises(SQLAlchemyError, match="Database error"):
#         rag_service.add_document(db=mock_db, doc_data=doc_data)
        
#     # Assertions
#     mock_db.add.assert_called_once()
#     mock_db.commit.assert_not_called()
#     mock_db.rollback.assert_called_once()