# cortex-hub / ai-hub / tests / test_app.py
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from datetime import datetime

from app.app import create_app
from app.api.dependencies import get_db
from app.db import models # Import your SQLAlchemy models

# --- Test Setup ---

# A mock DB session that can be used across tests
# spec=Session limits the mock to attributes that exist on a real SQLAlchemy Session.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Replacement for ``get_db``: hand out the module-level ``mock_db``.

    Written as a generator so it can be registered in
    ``app.dependency_overrides`` in place of the real ``get_db`` dependency.
    """
    yield mock_db

# --- API Endpoint Tests ---

def test_read_root():
    """GET / must answer 200 with the hub's liveness message."""
    expected_body = {"status": "AI Model Hub is running!"}

    resp = TestClient(create_app()).get("/")

    assert resp.status_code == 200
    assert resp.json() == expected_body

@patch('app.app.RAGService')
def test_create_session_success(mock_rag_service_class):
    """
    POST /sessions should delegate to RAGService.create_session and echo the
    created session's id and user_id back in the JSON response.
    """
    # Arrange: under the patch, every RAGService() call yields this same mock
    # instance; its create_session returns a real ORM Session row.
    rag = mock_rag_service_class.return_value
    created_row = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat Session",
        created_at=datetime.now(),
    )
    rag.create_session.return_value = created_row

    application = create_app()
    application.dependency_overrides[get_db] = override_get_db
    client = TestClient(application)

    # Act
    request_payload = {"user_id": "test_user", "model": "gemini"}
    resp = client.post("/sessions", json=request_payload)

    # Assert
    assert resp.status_code == 200
    body = resp.json()
    assert body["id"] == 1
    assert body["user_id"] == "test_user"
    rag.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )

@patch('app.app.RAGService')
def test_chat_in_session_success(mock_rag_service_class):
    """
    Test the session-based chat endpoint with a successful, mocked response.

    ``RAGService.chat_with_rag`` is replaced by an ``AsyncMock`` that returns an
    ``(answer_text, model_used)`` tuple. The test checks that both values appear
    in the JSON body and that the coroutine was actually awaited with the
    expected arguments.
    """
    # Arrange
    mock_rag_service_instance = mock_rag_service_class.return_value
    # The service returns a tuple: (answer_text, model_used)
    mock_rag_service_instance.chat_with_rag = AsyncMock(
        return_value=("This is a mock response.", "gemini")
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    # Act
    response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    # Assert
    assert response.status_code == 200
    body = response.json()
    assert body["answer"] == "This is a mock response."
    assert body["model_used"] == "gemini"
    # assert_awaited_once_with (rather than assert_called_once_with) also fails
    # if the endpoint creates the coroutine but never awaits it. The "123" path
    # segment is expected to reach the service as the int 123.
    mock_rag_service_instance.chat_with_rag.assert_awaited_once_with(
        db=mock_db, session_id=123, prompt="Hello there"
    )