# cortex-hub / ai-hub / tests / test_app.py
import os
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from datetime import datetime # Import datetime for models.Session

# Import the factory function directly to get a fresh app instance for testing
from app.app import create_app
# get_db lives in app.api.dependencies; the tests below replace it with a mock session via app.dependency_overrides.
from app.api.dependencies import get_db
from app.db import models # Import your SQLAlchemy models

# --- Dependency Override for Testing ---
# One MagicMock restricted to the SQLAlchemy Session interface. Every test that
# overrides get_db receives this very object, which is what lets the tests
# assert that it was forwarded unchanged to the RAGService methods.
mock_db = MagicMock(spec=Session)

def override_get_db():
    """Replacement for get_db used in tests: yield the shared mock session.

    Kept as a generator so FastAPI resolves it as a yield-style dependency;
    a mock needs no teardown once the request finishes.
    """
    yield mock_db


# --- API Endpoint Tests ---
# We patch the RAGService class itself, as the instance is created inside create_app().

def test_read_root():
    """The root route answers 200 with the hub's status message."""
    # Built inline and without any @patch so no mocking can interfere.
    client = TestClient(create_app())
    resp = client.get("/")
    assert resp.status_code == 200
    assert resp.json() == {"status": "AI Model Hub is running!"}

@patch('app.app.RAGService')
def test_create_session_success(mock_rag_service_class):
    """
    POST /sessions should create a session through the RAG service and
    return the created session's fields to the caller.
    """
    # Arrange: the patched service hands back an ORM Session row.
    service = mock_rag_service_class.return_value
    service.create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat Session",
        created_at=datetime.now(),
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    # Act
    resp = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    # Assert: the HTTP contract first, then the call the route made.
    assert resp.status_code == 200
    body = resp.json()
    assert body["id"] == 1
    assert body["user_id"] == "test_user"
    service.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )

@patch('app.app.RAGService')
def test_chat_in_session_success(mock_rag_service_class):
    """
    POST /sessions/{id}/chat with no "model" field should answer through the
    RAG service and fall back to the 'deepseek' model.
    """
    # chat_with_rag is replaced by an AsyncMock because the route is expected
    # to await it; it resolves to an (answer_text, model_used) pair.
    service = mock_rag_service_class.return_value
    service.chat_with_rag = AsyncMock(
        return_value=("This is a mock response.", "deepseek")
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    resp = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["answer"] == "This is a mock response."
    assert payload["model_used"] == "deepseek"
    # The omitted model must reach the service as the 'deepseek' default.
    service.chat_with_rag.assert_called_once_with(
        db=mock_db, session_id=123, prompt="Hello there", model="deepseek"
    )

@patch('app.app.RAGService')
def test_chat_in_session_with_model_switch(mock_rag_service_class):
    """
    Tests sending a message in an existing session while explicitly selecting
    the 'gemini' model: the choice must be passed through to the RAG service
    and reported back as model_used.
    """
    # create_app() instantiates the patched RAGService class, so the app ends up
    # holding this same return_value object configured here.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )

    # Exactly one app/client per test, with the DB dependency overridden.
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there, Gemini!", "model": "gemini"},
    )

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
    # The explicitly requested model 'gemini' must reach the service unchanged.
    mock_rag_service_instance.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
    )

@patch('app.app.RAGService')
def test_get_session_messages_success(mock_rag_service_class):
    """GET /sessions/{id}/messages returns the session's history in order."""
    service = mock_rag_service_class.return_value
    # Two ORM Message rows standing in for a stored conversation.
    service.get_message_history.return_value = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()),
    ]

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    resp = client.get("/sessions/123/messages")

    assert resp.status_code == 200
    body = resp.json()
    messages = body["messages"]
    assert body["session_id"] == 123
    assert len(messages) == 2
    assert messages[0]["sender"] == "user"
    assert messages[1]["content"] == "Hi there!"
    service.get_message_history.assert_called_once_with(db=mock_db, session_id=123)

@patch('app.app.RAGService')
def test_get_session_messages_not_found(mock_rag_service_class):
    """
    Tests retrieving messages for a session that does not exist: the route
    must answer 404 and must have asked the RAG service about that exact ID.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    # Arrange: None from the service is how an unknown session is signalled.
    mock_rag_service_instance.get_message_history.return_value = None

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    # Act
    response = client.get("/sessions/999/messages")

    # Assert
    assert response.status_code == 404
    assert response.json()["detail"] == "Session with ID 999 not found."
    # Pin the lookup itself, as the success-path test does, so the 404 is known
    # to come from the service lookup for session 999 with the overridden DB.
    mock_rag_service_instance.get_message_history.assert_called_once_with(
        db=mock_db, session_id=999
    )

@patch('app.app.RAGService')
def test_add_document_success(mock_rag_service_class):
    """
    POST /documents stores the document via the RAG service and reports the
    new document's ID in its success message.
    """
    service = mock_rag_service_class.return_value
    service.add_document.return_value = 1

    # Build the app while the patch is active so it is wired to the mock.
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    payload = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    resp = client.post("/documents", json=payload)

    assert resp.status_code == 200
    assert resp.json()["message"] == "Document 'Test Document' added successfully with ID 1"

    # The service receives the request body plus the defaults Pydantic fills in.
    expected = {**payload, "author": None, "user_id": "default_user"}
    service.add_document.assert_called_once_with(db=mock_db, doc_data=expected)


@patch('app.app.RAGService')
def test_add_document_api_failure(mock_rag_service_class):
    """
    POST /documents turns an exception raised by the RAG service into an
    HTTP 500 whose detail carries the original error text.
    """
    service = mock_rag_service_class.return_value
    service.add_document.side_effect = Exception("Service failed")

    # Build the app while the patch is active so it is wired to the mock.
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    payload = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    resp = client.post("/documents", json=payload)

    assert resp.status_code == 500
    assert "An error occurred: Service failed" in resp.json()["detail"]

    # Even on failure, the service was called with the body plus Pydantic defaults.
    expected = {**payload, "author": None, "user_id": "default_user"}
    service.add_document.assert_called_once_with(db=mock_db, doc_data=expected)

@patch('app.app.RAGService')
def test_get_documents_success(mock_rag_service_class):
    """GET /documents lists every document the RAG service returns."""
    service = mock_rag_service_class.return_value
    service.get_all_documents.return_value = [
        models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
        models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
    ]

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    resp = client.get("/documents")
    assert resp.status_code == 200
    documents = resp.json()["documents"]
    assert len(documents) == 2
    assert documents[0]["title"] == "Doc One"
    service.get_all_documents.assert_called_once_with(db=mock_db)

@patch('app.app.RAGService')
def test_delete_document_success(mock_rag_service_class):
    """
    DELETE /documents/{document_id} confirms the deletion and echoes the
    ID the RAG service reports as removed.
    """
    service = mock_rag_service_class.return_value
    service.delete_document.return_value = 42

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    resp = client.delete("/documents/42")
    assert resp.status_code == 200
    body = resp.json()
    assert body["message"] == "Document deleted successfully"
    assert body["document_id"] == 42
    service.delete_document.assert_called_once_with(db=mock_db, document_id=42)

@patch('app.app.RAGService')
def test_delete_document_not_found(mock_rag_service_class):
    """
    DELETE /documents/{document_id} answers 404 when the RAG service reports
    nothing was deleted (returns None).
    """
    service = mock_rag_service_class.return_value
    service.delete_document.return_value = None

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    resp = client.delete("/documents/999")
    assert resp.status_code == 404
    assert resp.json()["detail"] == "Document with ID 999 not found."
    service.delete_document.assert_called_once_with(db=mock_db, document_id=999)