Newer
Older
cortex-hub / ai-hub / tests / api / test_routes.py
# tests/api/test_routes.py
from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.orm import Session

# Import the dependencies and router factory
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes import create_api_router
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService
from app.db import models  # Import your SQLAlchemy models

@pytest.fixture
def client():
    """
    Pytest fixture that builds a FastAPI app around ``create_api_router`` with
    every external collaborator mocked.

    Yields:
        A ``(TestClient, mock_services)`` tuple: a client bound to the test
        app, and the mocked ``ServiceContainer`` whose ``rag_service``,
        ``document_service`` and ``tts_service`` attributes tests configure
        and assert against.
    """
    test_app = FastAPI()

    # Mock individual services. Spec'd mocks raise AttributeError when a test
    # reads a name the real class does not define, so typos fail loudly.
    # A MagicMock with a class spec automatically creates AsyncMock children
    # for the spec's ``async def`` methods, so no explicit AsyncMock is needed
    # for the TTS service here.
    mock_rag_service = MagicMock(spec=RAGService)
    mock_document_service = MagicMock(spec=DocumentService)
    mock_tts_service = MagicMock(spec=TTSService)

    # Create a mock ServiceContainer that holds the mocked services
    mock_services = MagicMock(spec=ServiceContainer)
    mock_services.rag_service = mock_rag_service
    mock_services.document_service = mock_document_service
    mock_services.tts_service = mock_tts_service

    # Mock the database session handed to routes through get_db
    mock_db_session = MagicMock(spec=Session)

    def override_get_db():
        yield mock_db_session

    # Pass the mock ServiceContainer to the router factory
    api_router = create_api_router(services=mock_services)
    test_app.dependency_overrides[get_db] = override_get_db
    test_app.include_router(api_router)

    # Return the test client and the mock services for assertion
    yield TestClient(test_app), mock_services

# --- General Endpoint ---

def test_read_root(client):
    """The root endpoint should report that the hub is up and running."""
    api_client, _services = client
    expected_body = {"status": "AI Model Hub is running!"}

    resp = api_client.get("/")

    assert resp.status_code == 200
    assert resp.json() == expected_body

# --- Session and Chat Endpoints ---

def test_create_session_success(client):
    """Creating a session returns the ID of the session built by the RAG service."""
    api_client, services = client
    rag = services.rag_service
    rag.create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat",
        created_at=datetime.now(),
    )

    resp = api_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    assert resp.status_code == 200
    assert resp.json()["id"] == 1
    rag.create_session.assert_called_once()

def test_chat_in_session_success(client):
    """
    Tests sending a message in an existing session without specifying a model
    or retriever. The route should fall back to model="deepseek" and
    load_faiss_retriever=False.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}

    # The fixture does not expose the DB session object, so match it with ANY.
    # Reading call_args here instead would raise AttributeError/KeyError
    # (not a clear assertion failure) if the call was missing or positional.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_model_switch(client):
    """
    Tests sending a message in an existing session while explicitly switching
    the model; the requested model must be forwarded to the RAG service.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}

    # The fixture does not expose the DB session object, so match it with ANY.
    # Reading call_args here instead would raise AttributeError/KeyError
    # (not a clear assertion failure) if the call was missing or positional.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_faiss_retriever(client):
    """
    Tests sending a message with the FAISS retriever explicitly enabled; the
    flag must be forwarded while the model falls back to "deepseek".
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True}
    )

    assert response.status_code == 200
    assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}

    # The fixture does not expose the DB session object, so match it with ANY.
    # Reading call_args here instead would raise AttributeError/KeyError
    # (not a clear assertion failure) if the call was missing or positional.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        model="deepseek",
        load_faiss_retriever=True,
    )

def test_get_session_messages_success(client):
    """Tests retrieving the message history for a session."""
    test_client, mock_services = client
    mock_history = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    # The fixture does not expose the DB session object, so match it with ANY.
    # Reading call_args here instead would raise AttributeError/KeyError
    # (not a clear assertion failure) if the call was missing or positional.
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY,
        session_id=123,
    )

def test_get_session_messages_not_found(client):
    """Requesting history for an unknown session yields a descriptive 404."""
    api_client, services = client
    # A None history is how the mocked service signals "no such session".
    services.rag_service.get_message_history.return_value = None

    resp = api_client.get("/sessions/999/messages")

    assert resp.status_code == 404
    error_detail = resp.json()["detail"]
    assert error_detail == "Session with ID 999 not found."

# --- Document Endpoints ---

def test_add_document_success(client):
    """Posting a document returns a confirmation quoting its title and new ID."""
    api_client, services = client
    services.document_service.add_document.return_value = 123

    resp = api_client.post("/documents", json={"title": "Test Doc", "text": "Content here"})

    assert resp.status_code == 200
    confirmation = resp.json()["message"]
    assert confirmation == "Document 'Test Doc' added successfully with ID 123"

def test_get_documents_success(client):
    """Listing documents returns every document reported by the service."""
    api_client, services = client
    services.document_service.get_all_documents.return_value = [
        models.Document(id=doc_id, title=title, status=status, created_at=datetime.now())
        for doc_id, title, status in (
            (1, "Doc One", "ready"),
            (2, "Doc Two", "processing"),
        )
    ]

    resp = api_client.get("/documents")

    assert resp.status_code == 200
    listed = resp.json()["documents"]
    assert len(listed) == 2

def test_delete_document_success(client):
    """Deleting an existing document echoes the deleted document's ID."""
    api_client, services = client
    deleted_id = 42
    services.document_service.delete_document.return_value = deleted_id

    resp = api_client.delete("/documents/42")

    assert resp.status_code == 200
    assert resp.json()["document_id"] == deleted_id

def test_delete_document_not_found(client):
    """Deleting a document the service cannot find results in a 404."""
    api_client, services = client
    # A None return is how the mocked service signals "nothing deleted".
    services.document_service.delete_document.return_value = None

    resp = api_client.delete("/documents/999")

    assert resp.status_code == 404
    
# --- TTS Endpoint ---

@pytest.mark.anyio
async def test_create_speech_stream_success(client):
    """
    Tests that the /speech endpoint streams the TTS service's audio chunks back
    to the caller and forwards the requested text to the service.
    """
    test_client, mock_services = client
    app = test_client.app  # Reuse the FastAPI app wired up by the fixture

    # Arrange: the text to convert, plus a list recording what the service receives.
    text_to_speak = "Hello, world!"
    received_texts = []

    # Async generator standing in for the real audio stream
    async def mock_audio_generator():
        yield b'chunk1'
        yield b'chunk2'
        yield b'chunk3'

    def fake_create_speech_stream(text):
        # Record the argument so the test can verify the payload was forwarded.
        received_texts.append(text)
        return mock_audio_generator()

    mock_services.tts_service.create_speech_stream = fake_create_speech_stream

    # Use AsyncClient with ASGITransport so the streamed body is consumed in-loop
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
        response = await ac.post("/speech", json={"text": text_to_speak})

    # Assert: status, content type, streamed body, and the forwarded text
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/wav"
    assert response.content == b"chunk1chunk2chunk3"
    assert received_texts == [text_to_speak]