# cortex-hub/ai-hub/tests/test_app.py
import os
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

# Import the factory function directly to get a fresh app instance for testing
from app.app import create_app
from app.api.dependencies import get_db
from app.db import models

# Width of every fake embedding vector the patched HTTP call returns.
# One shared constant keeps all tests in agreement on the vector size.
TEST_DIMENSION: int = 768

# --- Dependency override for testing ---
# One shared stand-in for a SQLAlchemy Session. With ``spec=Session``, the mock
# raises AttributeError for attributes a real Session lacks. Tests compare
# against this exact object when checking what the routes forwarded.
mock_db: MagicMock = MagicMock(spec=Session)

def override_get_db():
    """Returns the mock database session for tests."""
    try:
        yield mock_db
    finally:
        pass

# --- API Endpoint Tests ---
# Most tests below patch the RAGService class in app.app, because create_app()
# creates the service instance internally. test_read_root does not patch it.

def test_read_root():
    """The root endpoint should report that the hub is up.

    ``requests.post`` (used by the GenAIEmbedder while the app is being built)
    and ``faiss.read_index`` are patched. This keeps app creation away from the
    network and the file system. Unlike the other tests in this module,
    RAGService itself is not patched here.
    """
    fake_vector = np.random.rand(TEST_DIMENSION).tolist()
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {"embedding": {"values": fake_vector}}

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        client = TestClient(create_app())
        reply = client.get("/")
        assert reply.status_code == 200
        assert reply.json() == {"status": "AI Model Hub is running!"}

@patch('app.app.RAGService')
def test_create_session_success(mock_rag_service_class):
    """POST /sessions should create a session via RAGService and return it."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        # Arrange: the patched service returns a pre-built session row.
        service = mock_rag_service_class.return_value
        service.create_session.return_value = models.Session(
            id=1,
            user_id="test_user",
            model_name="gemini",
            title="New Chat Session",
            created_at=datetime.now(),
        )

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        # Act
        resp = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

        # Assert
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == 1
        assert body["user_id"] == "test_user"
        service.create_session.assert_called_once_with(
            db=mock_db, user_id="test_user", model="gemini"
        )

@patch('app.app.RAGService')
def test_chat_in_session_success(mock_rag_service_class):
    """
    Test the session-based chat endpoint with a successful, mocked response.
    It should default to 'deepseek' if no model is specified.

    ``chat_with_rag`` is an ``AsyncMock``, so the final assertion checks that
    the endpoint awaited it with the defaulted model, not just that it called it.
    """
    with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index:
        mock_read_index.return_value = MagicMock()
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {
            "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
        }
        mock_post.return_value = mock_response

        # Arrange
        mock_rag_service_instance = mock_rag_service_class.return_value
        # Calling an AsyncMock returns a coroutine that resolves to return_value
        # and records the await. This replaces the hand-written async side_effect.
        mock_rag_service_instance.chat_with_rag = AsyncMock(
            return_value=("This is a mock response.", "deepseek")
        )

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        # Act
        response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

        # Assert
        assert response.status_code == 200
        assert response.json()["answer"] == "This is a mock response."
        assert response.json()["model_used"] == "deepseek"
        mock_rag_service_instance.chat_with_rag.assert_awaited_once_with(
            db=mock_db, session_id=123, prompt="Hello there", model="deepseek", load_faiss_retriever=False
        )

@patch('app.app.RAGService')
def test_chat_in_session_with_model_switch(mock_rag_service_class):
    """
    Tests sending a message in an existing session and explicitly switching the model.

    ``chat_with_rag`` is an ``AsyncMock``, so the final assertion checks that
    the endpoint awaited it with the requested model, not just that it called it.
    """
    with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index:
        mock_read_index.return_value = MagicMock()
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {
            "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
        }
        mock_post.return_value = mock_response

        mock_rag_service_instance = mock_rag_service_class.return_value
        # Calling an AsyncMock returns a coroutine that resolves to return_value
        # and records the await. This replaces the hand-written async side_effect.
        mock_rag_service_instance.chat_with_rag = AsyncMock(
            return_value=("Mocked response from Gemini", "gemini")
        )

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})

        assert response.status_code == 200
        assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
        mock_rag_service_instance.chat_with_rag.assert_awaited_once_with(
            db=mock_db,
            session_id=42,
            prompt="Hello there, Gemini!",
            model="gemini",
            load_faiss_retriever=False
        )

@patch('app.app.RAGService')
def test_get_session_messages_success(mock_rag_service_class):
    """GET /sessions/{id}/messages should return the stored history, in order."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        service = mock_rag_service_class.return_value
        service.get_message_history.return_value = [
            models.Message(sender="user", content="Hello", created_at=datetime.now()),
            models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()),
        ]

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        # Act
        resp = client.get("/sessions/123/messages")

        # Assert
        assert resp.status_code == 200
        body = resp.json()
        assert body["session_id"] == 123
        history = body["messages"]
        assert len(history) == 2
        assert history[0]["sender"] == "user"
        assert history[1]["content"] == "Hi there!"
        service.get_message_history.assert_called_once_with(db=mock_db, session_id=123)

@patch('app.app.RAGService')
def test_get_session_messages_not_found(mock_rag_service_class):
    """A history lookup that yields None should produce a 404 with a clear detail."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        # Arrange: the service reports "no such session" by returning None.
        service = mock_rag_service_class.return_value
        service.get_message_history.return_value = None

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        # Act
        resp = client.get("/sessions/999/messages")

        # Assert
        assert resp.status_code == 404
        assert resp.json()["detail"] == "Session with ID 999 not found."
        service.get_message_history.assert_called_once_with(db=mock_db, session_id=999)

@patch('app.app.RAGService')
def test_add_document_success(mock_rag_service_class):
    """POST /documents should store the document and report its new ID."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }
    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test"
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        service = mock_rag_service_class.return_value
        service.add_document.return_value = 1

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        resp = client.post("/documents", json=doc_data)

        assert resp.status_code == 200
        assert resp.json()["message"] == "Document 'Test Document' added successfully with ID 1"

        # The service should receive the posted payload plus an explicit
        # author of None and a user_id of "default_user".
        expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
        service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data)


@patch('app.app.RAGService')
def test_add_document_api_failure(mock_rag_service_class):
    """If RAGService.add_document raises, POST /documents should return a 500."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }
    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test"
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        # Arrange: every add_document call blows up.
        service = mock_rag_service_class.return_value
        service.add_document.side_effect = Exception("Service failed")

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        resp = client.post("/documents", json=doc_data)

        assert resp.status_code == 500
        assert "An error occurred: Service failed" in resp.json()["detail"]

        # The failing call must still have been made with the completed payload.
        expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
        service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data)

@patch('app.app.RAGService')
def test_get_documents_success(mock_rag_service_class):
    """GET /documents should list every document the service returns."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        service = mock_rag_service_class.return_value
        service.get_all_documents.return_value = [
            models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
            models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
        ]

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        resp = client.get("/documents")
        assert resp.status_code == 200
        listed = resp.json()["documents"]
        assert len(listed) == 2
        assert listed[0]["title"] == "Doc One"
        service.get_all_documents.assert_called_once_with(db=mock_db)

@patch('app.app.RAGService')
def test_delete_document_success(mock_rag_service_class):
    """DELETE /documents/{id} should confirm deletion and echo the ID."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        # Arrange: the service reports the ID of the row it removed.
        service = mock_rag_service_class.return_value
        service.delete_document.return_value = 42

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        resp = client.delete("/documents/42")
        assert resp.status_code == 200
        body = resp.json()
        assert body["message"] == "Document deleted successfully"
        assert body["document_id"] == 42
        service.delete_document.assert_called_once_with(db=mock_db, document_id=42)

@patch('app.app.RAGService')
def test_delete_document_not_found(mock_rag_service_class):
    """Deleting a missing document (service returns None) should yield a 404."""
    fake_reply = MagicMock()
    fake_reply.raise_for_status.return_value = None
    fake_reply.json.return_value = {
        "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()}
    }

    with patch('requests.post', return_value=fake_reply), \
            patch('faiss.read_index', return_value=MagicMock()):
        service = mock_rag_service_class.return_value
        service.delete_document.return_value = None

        app = create_app()
        app.dependency_overrides[get_db] = override_get_db
        client = TestClient(app)

        resp = client.delete("/documents/999")
        assert resp.status_code == 404
        assert resp.json()["detail"] == "Document with ID 999 not found."
        service.delete_document.assert_called_once_with(db=mock_db, document_id=999)