# Tests for the AI Hub FastAPI app (cortex-hub/ai-hub/tests/test_app.py).
import os
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session

# Import the factory function directly to get a fresh app instance for testing
from app.app import create_app
# get_db is defined in app/db_setup.py; the tests below override it with a mock session.
from app.db_setup import get_db

# --- Dependency Override for Testing ---
# A single MagicMock shaped like a SQLAlchemy Session, shared by every test.
# spec=Session makes attribute access fail for names Session does not define.
mock_db: MagicMock = MagicMock(spec=Session)

def override_get_db():
    """Yield the shared ``mock_db`` in place of a real database session.

    Keeps the generator shape of a ``yield``-style FastAPI dependency so it can
    be registered through ``app.dependency_overrides[get_db]``. The mock holds
    no real connection, so nothing needs to be cleaned up after the yield.
    """
    yield mock_db


# --- API Endpoint Tests ---
# Tests that need a fake RAG service patch the RAGService class itself (at
# app.app), since the instance is created inside create_app(); each such test
# therefore builds its app only after the patch is active.

# test_read_root needs no mocking, but it still builds its app inside the
# test function, consistent with the other tests.
def test_read_root():
    """GET / should report that the hub is up.

    No patch decorator is applied, so this exercises the app exactly as
    ``create_app()`` builds it.
    """
    client = TestClient(create_app())

    reply = client.get("/")

    assert reply.status_code == 200
    assert reply.json() == {"status": "AI Model Hub is running!"}

@patch('app.app.RAGService')
def test_chat_handler_success(mock_rag_service_class):
    """POST /chat returns the RAG service's answer and forwards the prompt.

    ``RAGService`` is patched at ``app.app``. ``create_app()`` builds its
    service instance from that name, so the instance it gets is the mock's
    ``return_value``. The app must therefore be created after the mock is
    configured.
    """
    canned_reply = "This is a mock response from the RAG service."
    prompt = "Hello there"

    rag = mock_rag_service_class.return_value
    rag.chat_with_rag = AsyncMock(return_value=canned_reply)

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.post("/chat", json={"prompt": prompt})

    assert response.status_code == 200
    assert response.json()["response"] == canned_reply

    # No model is sent in the request, so the endpoint is expected to pass
    # "deepseek" along with the overridden db session.
    rag.chat_with_rag.assert_called_once_with(
        db=mock_db, prompt=prompt, model="deepseek"
    )

@patch('app.app.RAGService')
def test_chat_handler_api_failure(mock_rag_service_class):
    """POST /chat maps a RAG service exception to an HTTP 500.

    The mocked ``chat_with_rag`` raises, and the endpoint is expected to turn
    that into an error response whose ``detail`` names the deepseek API.
    """
    prompt = "This request will fail"

    rag = mock_rag_service_class.return_value
    rag.chat_with_rag = AsyncMock(side_effect=Exception("API connection error"))

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.post("/chat", json={"prompt": prompt})

    assert response.status_code == 500
    assert "An error occurred with the deepseek API" in response.json()["detail"]

    # The service must still have been invoked once with the forwarded
    # arguments before it failed.
    rag.chat_with_rag.assert_called_once_with(
        db=mock_db, prompt=prompt, model="deepseek"
    )

@patch('app.app.RAGService')
def test_add_document_success(mock_rag_service_class):
    """POST /document stores the document and echoes the ID the service returns."""
    rag = mock_rag_service_class.return_value
    rag.add_document.return_value = 1

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    payload = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    response = client.post("/document", json=payload)

    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"

    # The request omits author and user_id, so the service should receive the
    # request model's defaults for those two fields on top of the payload.
    expected = {**payload, "author": None, "user_id": "default_user"}
    rag.add_document.assert_called_once_with(db=mock_db, doc_data=expected)


@patch('app.app.RAGService')
def test_add_document_api_failure(mock_rag_service_class):
    """POST /document maps a service exception to an HTTP 500 with its message."""
    rag = mock_rag_service_class.return_value
    rag.add_document.side_effect = Exception("Service failed")

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    payload = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    response = client.post("/document", json=payload)

    assert response.status_code == 500
    assert "An error occurred: Service failed" in response.json()["detail"]

    # Even on failure, the service should have been called once with the
    # payload plus the request model's defaults for author and user_id.
    expected = {**payload, "author": None, "user_id": "default_user"}
    rag.add_document.assert_called_once_with(db=mock_db, doc_data=expected)