# cortex-hub / ai-hub / tests / core / test_services.py
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session

# Import the service and its dependencies
from app.core.services import RAGService
from app.db import models
from app.core.vector_store import FaissVectorStore
from app.core.retrievers import Retriever
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
from app.core.llm_providers import LLMProvider

@pytest.fixture
def rag_service():
    """Provide a RAGService wired to mock collaborators.

    Both the vector store and the single retriever are ``MagicMock`` objects
    spec'd against their real classes, so touching an attribute the real
    class does not define raises instead of silently creating a new mock.
    """
    vector_store_double = MagicMock(spec=FaissVectorStore)
    retriever_double = MagicMock(spec=Retriever)
    service = RAGService(
        vector_store=vector_store_double,
        retrievers=[retriever_double],
    )
    return service

# --- Session Management Tests ---

def test_create_session(rag_service: RAGService):
    """create_session should hand exactly one new models.Session row to the DB.

    The persisted row must carry the requested user id and model name.
    """
    db_double = MagicMock(spec=Session)

    rag_service.create_session(db=db_double, user_id="test_user", model="gemini")

    # Exactly one object is added to the SQLAlchemy session...
    db_double.add.assert_called_once()
    # ...and it is the first positional argument of that call.
    persisted = db_double.add.call_args.args[0]
    assert isinstance(persisted, models.Session)
    assert persisted.user_id == "test_user"
    assert persisted.model_name == "gemini"

@patch('app.core.services.get_llm_provider')
@patch('app.core.services.DspyRagPipeline')
@patch('dspy.configure')
def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
    """Exercise the full chat_with_rag orchestration for one session message.

    ``@patch`` decorators apply bottom-up, so the injected mocks arrive in
    the order dspy.configure, DspyRagPipeline, get_llm_provider.
    ``mock_configure`` is not inspected; patching it only keeps the real
    dspy.configure from running during the test.
    """
    # Arrange: the query(...).options(...).filter(...).first() chain yields a
    # session whose `messages` list is the conversation history.
    db_double = MagicMock(spec=Session)
    session_row = models.Session(id=42, model_name="deepseek", messages=[])
    session_lookup = db_double.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = session_row

    provider_double = MagicMock(spec=LLMProvider)
    mock_get_llm_provider.return_value = provider_double

    pipeline_double = MagicMock(spec=DspyRagPipeline)
    # The service presumably awaits `forward`, hence an AsyncMock.
    pipeline_double.forward = AsyncMock(return_value="Final RAG response")
    mock_dspy_pipeline.return_value = pipeline_double

    # Act
    chat_coro = rag_service.chat_with_rag(db=db_double, session_id=42, prompt="Test prompt")
    answer, model_name = asyncio.run(chat_coro)

    # Assert: a single session lookup, then two rows added
    # (presumably the user prompt and the reply - confirm against RAGService).
    db_double.query.assert_called_once_with(models.Session)
    assert db_double.add.call_count == 2
    # The provider is resolved from the session's stored model name.
    mock_get_llm_provider.assert_called_once_with("deepseek")

    # The pipeline must receive the session's history alongside the prompt.
    pipeline_double.forward.assert_called_once_with(
        question="Test prompt",
        history=session_row.messages,
        db=db_double,
    )

    assert answer == "Final RAG response"
    assert model_name == "deepseek"

def test_get_message_history_success(rag_service: RAGService):
    """get_message_history should return all messages of an existing session."""
    # Arrange: the session lookup chain finds a session holding two messages.
    db_double = MagicMock(spec=Session)
    two_messages = [models.Message() for _ in range(2)]
    found_session = models.Session(id=1, messages=two_messages)
    session_lookup = db_double.query.return_value.options.return_value.filter.return_value
    session_lookup.first.return_value = found_session

    # Act
    history = rag_service.get_message_history(db=db_double, session_id=1)

    # Assert: the lookup targeted models.Session and both messages came back.
    db_double.query.assert_called_once_with(models.Session)
    assert len(history) == 2

def test_get_message_history_not_found(rag_service: RAGService):
    """get_message_history should return None when no session matches the id.

    Also verifies that the lookup goes through the same
    ``query(models.Session)`` path as the success case, so the None result
    is known to come from the database lookup finding nothing.
    """
    # Arrange: the query(...).options(...).filter(...).first() chain finds nothing.
    mock_db = MagicMock(spec=Session)
    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None

    # Act
    messages = rag_service.get_message_history(db=mock_db, session_id=999)

    # Assert
    assert messages is None
    # Mirrors the success-path test: exactly one Session query was issued.
    mock_db.query.assert_called_once_with(models.Session)