# Newer
# Older
# cortex-hub / ai-hub / tests / test_app.py
# tests/app/test_app.py
import os
import asyncio
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from datetime import datetime
import numpy as np

# Import the factory function directly and dependencies
from app.app import create_app
from app.api.dependencies import get_db, ServiceContainer
from app.db import models
from app.core.retrievers import Retriever

# Dimension constant kept in one place so tests agree on a single value.
# NOTE(review): presumably the embedding vector size; it is not referenced by
# any test visible in this section -- confirm it is still needed.
TEST_DIMENSION = 768

# --- Dependency Override for Testing ---
# Shared stand-in for the database session: a MagicMock restricted to the
# SQLAlchemy Session interface via spec=Session. override_get_db() yields this
# same object, so tests can assert that services received db=mock_db.
mock_db = MagicMock(spec=Session)

def override_get_db():
    """Generator dependency that yields the module-level ``mock_db`` once.

    Registered in ``app.dependency_overrides`` under ``get_db`` by the tests
    in this module. There is nothing to clean up after the yield.
    """
    yield mock_db

# We patch ServiceContainer directly to control its instantiation in create_app
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.app.FaissVectorStore.save_index')
@patch('app.app.print_config')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index') # This patch is for the FaissVectorStore initialization
def test_read_root(mock_read_index, mock_get_embedder, mock_print_config, mock_save_index, mock_create_db, mock_service_container):
    """GET / should answer 200 with the hub's status payload."""
    # Arrange: stub index loading and embedder creation so no real logic runs.
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    # Calling the patched ServiceContainer class hands back this stand-in.
    container_stub = MagicMock()
    mock_service_container.return_value = container_stub

    # Act
    http = TestClient(create_app())
    root_reply = http.get("/")

    # Assert
    expected_body = {"status": "AI Model Hub is running!"}
    assert root_reply.status_code == 200
    assert root_reply.json() == expected_body


# We patch ServiceContainer directly to control its instantiation in create_app
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_create_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /sessions should return the session produced by the mocked RAG service.

    Verifies the 200 status, the ``id`` and ``user_id`` in the response body,
    and that ``rag_service.create_session`` is called exactly once with the
    mock DB session plus the user id and model taken from the request.
    """
    # Arrange: keep FAISS and the embedder factory out of the picture.
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    # Stand-in returned when the patched ServiceContainer class is called.
    container_stub = MagicMock()
    mock_service_container.return_value = container_stub

    # The session object the mocked create_session hands back to the endpoint.
    session_fields = {
        "id": 1,
        "user_id": "test_user",
        "model_name": "gemini",
        "title": "New Chat Session",
        "created_at": datetime.now(),
    }
    container_stub.rag_service.create_session.return_value = models.Session(**session_fields)

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    request_body = {"user_id": "test_user", "model": "gemini"}
    reply = http.post("/sessions", json=request_body)

    # Assert
    assert reply.status_code == 200
    body = reply.json()
    assert body["id"] == 1
    assert body["user_id"] == "test_user"
    container_stub.rag_service.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )

@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_chat_in_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /sessions/{id}/chat without a model should succeed using 'deepseek'.

    The request body carries only a prompt; the test expects the mocked
    ``chat_with_rag`` coroutine to be awaited with ``model="deepseek"`` and
    ``load_faiss_retriever=False``, and its result to be echoed back as
    ``answer`` / ``model_used``.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub

    # chat_with_rag is awaited by the endpoint, so it has to be an AsyncMock.
    chat_stub = AsyncMock(return_value=("This is a mock response.", "deepseek"))
    container_stub.rag_service.chat_with_rag = chat_stub

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    reply = http.post("/sessions/123/chat", json={"prompt": "Hello there"})

    # Assert
    assert reply.status_code == 200
    body = reply.json()
    assert body["answer"] == "This is a mock response."
    assert body["model_used"] == "deepseek"
    container_stub.rag_service.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=123,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )

@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_chat_in_session_with_model_switch(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /sessions/{id}/chat with an explicit model should pass that model on.

    The request names ``"gemini"``; the test expects ``chat_with_rag`` to be
    awaited once with that model and the response body to be exactly the
    answer/model pair returned by the mock.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub

    # The endpoint awaits chat_with_rag, hence AsyncMock rather than MagicMock.
    chat_stub = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
    container_stub.rag_service.chat_with_rag = chat_stub

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    request_body = {"prompt": "Hello there, Gemini!", "model": "gemini"}
    reply = http.post("/sessions/42/chat", json=request_body)

    # Assert
    expected_body = {"answer": "Mocked response from Gemini", "model_used": "gemini"}
    assert reply.status_code == 200
    assert reply.json() == expected_body

    expected_call = {
        "db": mock_db,
        "session_id": 42,
        "prompt": "Hello there, Gemini!",
        "model": "gemini",
        "load_faiss_retriever": False,
    }
    container_stub.rag_service.chat_with_rag.assert_called_once_with(**expected_call)

@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_add_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /documents should report success when the document service returns an ID.

    The mocked ``document_service.add_document`` returns ``1``; the test
    expects a 200 response whose message mentions the title and that ID, and
    expects the service to receive the posted fields plus ``author=None`` and
    ``user_id="default_user"``.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub
    container_stub.document_service.add_document.return_value = 1

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    posted_doc = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test"
    }

    # Act
    reply = http.post("/documents", json=posted_doc)

    # Assert
    assert reply.status_code == 200
    assert reply.json()["message"] == "Document 'Test Document' added successfully with ID 1"

    # The service is expected to see the posted fields plus these two extras.
    forwarded_doc = {**posted_doc, "author": None, "user_id": "default_user"}
    container_stub.document_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=forwarded_doc
    )


@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_add_document_api_failure(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /documents should answer 500 when the document service raises.

    The mocked ``document_service.add_document`` raises
    ``Exception("Service failed")``; the test expects a 500 response whose
    ``detail`` contains the error text, and still expects the service to have
    been called once with the posted fields plus ``author=None`` and
    ``user_id="default_user"``.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub
    container_stub.document_service.add_document.side_effect = Exception("Service failed")

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    posted_doc = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test"
    }

    # Act
    reply = http.post("/documents", json=posted_doc)

    # Assert
    assert reply.status_code == 500
    error_detail = reply.json()["detail"]
    assert "An error occurred: Service failed" in error_detail

    # The failing call must still have received the full document payload.
    forwarded_doc = {**posted_doc, "author": None, "user_id": "default_user"}
    container_stub.document_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=forwarded_doc
    )


@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_get_documents_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    GET /documents should list the documents returned by the document service.

    Two ``models.Document`` rows are returned by the mocked
    ``get_all_documents``; the test expects both under the ``documents`` key
    in the same order, and the service to be called once with the mock DB session.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub

    first_doc = models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now())
    second_doc = models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now())
    container_stub.document_service.get_all_documents.return_value = [first_doc, second_doc]

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    reply = http.get("/documents")

    # Assert
    assert reply.status_code == 200
    listed = reply.json()["documents"]
    assert len(listed) == 2
    assert listed[0]["title"] == "Doc One"
    container_stub.document_service.get_all_documents.assert_called_once_with(db=mock_db)


@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_delete_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    DELETE /documents/{document_id} should confirm a successful deletion.

    The mocked ``delete_document`` returns ``42``; the test expects a 200
    response carrying the success message and that ID, and the service to be
    called once with ``document_id=42`` and the mock DB session.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub
    container_stub.document_service.delete_document.return_value = 42

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    reply = http.delete("/documents/42")

    # Assert
    assert reply.status_code == 200
    body = reply.json()
    assert body["message"] == "Document deleted successfully"
    assert body["document_id"] == 42
    container_stub.document_service.delete_document.assert_called_once_with(
        db=mock_db, document_id=42
    )

@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_delete_document_not_found(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    DELETE /documents/{document_id} should answer 404 for an unknown document.

    The mocked ``delete_document`` returns ``None``; the test expects a 404
    whose ``detail`` names the requested ID, and the service to be called
    once with ``document_id=999`` and the mock DB session.
    """
    # Arrange
    mock_get_embedder.return_value = MagicMock()
    mock_read_index.return_value = MagicMock()

    container_stub = MagicMock()
    mock_service_container.return_value = container_stub
    # None from the service is what the test uses to represent "no such document".
    container_stub.document_service.delete_document.return_value = None

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    http = TestClient(app)

    # Act
    reply = http.delete("/documents/999")

    # Assert
    assert reply.status_code == 404
    error_detail = reply.json()["detail"]
    assert error_detail == "Document with ID 999 not found."
    container_stub.document_service.delete_document.assert_called_once_with(
        db=mock_db, document_id=999
    )