# Source: cortex-hub / ai-hub / tests / test_main.py
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock

# Import the FastAPI app instance to create a test client
from app.main import app

# Create a TestClient instance based on our FastAPI app; requests made
# through it run against the app in-process, with no network server needed.
client = TestClient(app)

def test_read_root():
    """Verify the root health-check endpoint reports the service as running."""
    resp = client.get("/")
    body = resp.json()
    assert resp.status_code == 200
    assert body == {"status": "AI Model Hub is running!"}

@patch('app.main.get_llm_provider')
def test_chat_handler_success(mock_get_llm_provider):
    """
    Test the /chat endpoint with a successful, mocked LLM response.

    Patches the get_llm_provider factory so the endpoint receives a
    controlled provider whose async generate_response returns a fixed
    string, letting us assert on the HTTP contract without real API calls.
    """
    # Provider stub: generate_response is awaited by the endpoint,
    # so it must be an AsyncMock.
    mock_provider = MagicMock()
    mock_provider.generate_response = AsyncMock(return_value="This is a mock response from a provider.")

    # The mocked factory hands back our stub provider.
    mock_get_llm_provider.return_value = mock_provider

    # Make the request to our app.
    response = client.post("/chat", json={"prompt": "Hello there"})

    # Assert our app behaved as expected.
    assert response.status_code == 200
    assert response.json()["response"] == "This is a mock response from a provider."

    # The endpoint selects the "deepseek" provider — presumably its default;
    # verify against app.main if the default ever changes.
    mock_get_llm_provider.assert_called_once_with("deepseek")
    # assert_awaited_once_with is stronger than assert_called_once_with for
    # an AsyncMock: it also proves the coroutine was actually awaited,
    # not merely created and dropped.
    mock_provider.generate_response.assert_awaited_once_with("Hello there")

@patch('app.main.get_llm_provider')
def test_chat_handler_api_failure(mock_get_llm_provider):
    """
    Test the /chat endpoint when the external LLM API fails.

    The mocked provider's async generate_response raises, and we assert
    the endpoint converts that failure into an HTTP 500 with a detail
    message naming the provider.
    """
    # Provider stub whose async generate_response raises on await.
    mock_provider = MagicMock()
    mock_provider.generate_response = AsyncMock(side_effect=Exception("API connection error"))

    # The mocked factory hands back our failing stub provider.
    mock_get_llm_provider.return_value = mock_provider

    # Make the request to our app.
    response = client.post("/chat", json={"prompt": "This request will fail"})

    # Assert our app handles the error gracefully.
    assert response.status_code == 500
    assert "An error occurred with the deepseek API" in response.json()["detail"]

    # The endpoint selects the "deepseek" provider — presumably its default;
    # verify against app.main if the default ever changes.
    mock_get_llm_provider.assert_called_once_with("deepseek")
    # assert_awaited_once_with also proves the coroutine was actually
    # awaited (and therefore raised), not merely created.
    mock_provider.generate_response.assert_awaited_once_with("This request will fail")