# Newer
# Older
# cortex-hub / ai-hub / tests / core / test_services.py
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider

@pytest.fixture
def rag_service():
    """Provide a RAGService built from a mocked vector store and one mocked retriever."""
    # Both collaborators are spec'd mocks, so no real index or retriever is touched.
    return RAGService(
        vector_store=MagicMock(spec=FaissVectorStore),
        retrievers=[MagicMock(spec=Retriever)],
    )

# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """Check that create_session persists a models.Session with the given user and model."""
    # Arrange: a stand-in for the SQLAlchemy session
    db = MagicMock(spec=Session)

    # Act
    _created = rag_service.create_session(db=db, user_id="test_user", model="gemini")

    # Assert: exactly one add / commit / refresh round-trip happened
    db.add.assert_called_once()
    db.commit.assert_called_once()
    db.refresh.assert_called_once()

    # Inspect the first positional argument that was handed to db.add
    persisted = db.add.call_args.args[0]
    assert isinstance(persisted, models.Session)
    assert persisted.user_id == "test_user"
    assert persisted.model_name == "gemini"

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """
    Exercise chat_with_rag end to end against mocked collaborators.

    The database, the LLM provider factory and the DSPy pipeline class are
    all replaced with mocks; the test then checks how they were called and
    what chat_with_rag returned.
    """
    # --- Arrange ---
    # The query(...).options(...).filter(...).first() chain yields a stored session.
    db = MagicMock(spec=Session)
    stored_session = models.Session(id=42, model_name="deepseek")
    lookup = db.query.return_value.options.return_value.filter.return_value
    lookup.first.return_value = stored_session

    # The patched provider factory hands back a spec'd LLM provider.
    llm_provider = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = llm_provider

    # The patched pipeline class produces an instance whose forward() is awaitable.
    pipeline = MagicMock(spec=DspyRagPipeline)
    pipeline.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = pipeline

    # --- Act ---
    chat_call = rag_service.chat_with_rag(db=db, session_id=42, prompt="Test prompt")
    answer, model_name = asyncio.run(chat_call)

    # --- Assert ---
    # 1. The session lookup went through models.Session
    db.query.assert_called_once_with(models.Session)

    # 2. Two adds and two commits were issued (user and assistant messages)
    assert db.add.call_count == 2
    assert db.commit.call_count == 2

    # 3. Provider and pipeline were each used exactly once
    mock_get_llm_provider.assert_called_once_with("deepseek")
    mock_dspy_pipeline.assert_called_once()
    pipeline.forward.assert_called_once()

    # 4. The pipeline's answer and the session's model name are returned
    assert (answer, model_name) == ("Final RAG response", "deepseek")