# tests/app/test_app.py
import os
import asyncio
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.orm import Session
from datetime import datetime
import numpy as np
# Import the app factory directly, plus the dependencies the tests override or mock
from app.app import create_app
from app.api.dependencies import get_db, ServiceContainer
from app.db import models
from app.core.retrievers.base_retriever import Retriever
from app.core.vector_store.faiss_store import FaissVectorStore
# Embedding dimension shared by the tests so mocked vector stores stay consistent.
TEST_DIMENSION = 768

# --- Dependency Override for Testing ---
# A single mock database session shared across tests. Endpoints receive it through
# override_get_db, and tests assert that services were called with db=mock_db.
# NOTE: its call history is never reset between tests, so avoid asserting on calls
# made to mock_db itself.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Yield the shared mock database session in place of a real one.

    Intended for ``app.dependency_overrides[get_db]``. A mock session holds no
    resources, so no teardown step is needed after the yield.
    """
    yield mock_db
# We patch ServiceContainer directly to control its instantiation in create_app
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.app.print_config')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')  # Stubs FAISS index loading used during FaissVectorStore initialization
def test_read_root(mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container):
    """GET / should report that the service is up and running."""
    # Neutralize FAISS loading and the embedder factory so no real logic runs.
    for patched in (mock_read_index, mock_get_embedder):
        patched.return_value = MagicMock()

    # create_app() instantiates ServiceContainer; hand it a plain stub instead.
    mock_service_container.return_value = MagicMock()

    client = TestClient(create_app())
    response = client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_create_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /sessions should create a chat session through the session service
    and return the created session to the caller.
    """
    # Arrange: stub out FAISS and embedder initialization.
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    # The container instance create_app() will receive, with a canned session result.
    services = MagicMock()
    mock_service_container.return_value = services
    services.session_service.create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        provider_name="gemini",
        title="New Chat Session",
        created_at=datetime.now(),
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    # Act
    response = client.post("/sessions", json={"user_id": "test_user", "provider_name": "gemini"})

    # Assert
    assert response.status_code == 200
    body = response.json()
    assert (body["id"], body["user_id"]) == (1, "test_user")
    services.session_service.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", provider_name="gemini"
    )
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_chat_in_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    POST /sessions/{id}/chat without an explicit provider should succeed and
    fall back to the 'deepseek' provider.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    services = MagicMock()
    mock_service_container.return_value = services
    # chat_with_rag is an async method, so it is replaced with an AsyncMock.
    services.rag_service.chat_with_rag = AsyncMock(
        return_value=("This is a mock response.", "deepseek")
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    payload = response.json()
    assert payload["answer"] == "This is a mock response."
    assert payload["provider_used"] == "deepseek"
    services.rag_service.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=123,
        prompt="Hello there",
        provider_name="deepseek",
        load_faiss_retriever=False,
    )
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_chat_in_session_with_model_switch(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    An explicit provider_name in the chat request should be forwarded to the
    RAG service and reported back as the provider used.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    services = MagicMock()
    mock_service_container.return_value = services
    # Async service method -> AsyncMock.
    services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    request_body = {"prompt": "Hello there, Gemini!", "provider_name": "gemini"}
    response = client.post("/sessions/42/chat", json=request_body)

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "provider_used": "gemini"}
    services.rag_service.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=42,
        prompt="Hello there, Gemini!",
        provider_name="gemini",
        load_faiss_retriever=False,
    )
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_add_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    Tests the POST /documents endpoint with a successful, mocked document service response.

    The request omits ``author`` and ``user_id``; the endpoint is expected to pass
    ``author=None`` and ``user_id="default_user"`` to
    ``document_service.add_document`` and report the new document's ID.
    """
    # Arrange
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()
    # Create a mock instance of ServiceContainer and its services
    mock_services = MagicMock()
    mock_service_container.return_value = mock_services
    mock_services.document_service.add_document.return_value = 1
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)
    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    # Act
    response = client.post("/documents", json=doc_data)

    # Assert
    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"
    # The optional fields not sent in the request should arrive with their defaults.
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_services.document_service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data)
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_add_document_api_failure(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    Tests the POST /documents endpoint when the document service raises an error.

    The endpoint is expected to respond with HTTP 500 and include the
    exception message in the ``detail`` field.
    """
    # Arrange
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()
    # Create a mock instance of ServiceContainer and its services
    mock_services = MagicMock()
    mock_service_container.return_value = mock_services
    mock_services.document_service.add_document.side_effect = Exception("Service failed")
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)
    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }

    # Act
    response = client.post("/documents", json=doc_data)

    # Assert
    assert response.status_code == 500
    assert "An error occurred: Service failed" in response.json()["detail"]
    # The service must still have been invoked with the defaulted payload.
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_services.document_service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data)
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_get_documents_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    GET /documents should list every document returned by the document service.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    services = MagicMock()
    mock_service_container.return_value = services
    services.document_service.get_all_documents.return_value = [
        models.Document(id=doc_id, title=title, status=status, created_at=datetime.now())
        for doc_id, title, status in ((1, "Doc One", "ready"), (2, "Doc Two", "processing"))
    ]

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.get("/documents")

    assert response.status_code == 200
    documents = response.json()["documents"]
    assert len(documents) == 2
    assert documents[0]["title"] == "Doc One"
    services.document_service.get_all_documents.assert_called_once_with(db=mock_db)
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_delete_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    DELETE /documents/{document_id} should confirm the deletion and echo the
    deleted document's ID.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    services = MagicMock()
    mock_service_container.return_value = services
    services.document_service.delete_document.return_value = 42

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.delete("/documents/42")

    assert response.status_code == 200
    body = response.json()
    assert body["message"] == "Document deleted successfully"
    assert body["document_id"] == 42
    services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=42)
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('faiss.read_index')
def test_delete_document_not_found(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container):
    """
    DELETE /documents/{document_id} should answer 404 when the document
    service reports nothing was deleted.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()

    services = MagicMock()
    mock_service_container.return_value = services
    # A None result from the service is expected to surface as HTTP 404.
    services.document_service.delete_document.return_value = None

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    client = TestClient(app)

    response = client.delete("/documents/999")

    assert response.status_code == 404
    assert response.json()["detail"] == "Document with ID 999 not found."
    services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=999)
@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index_and_metadata')
@patch('app.app.ServiceContainer')
@patch('app.app.create_db_and_tables')
@patch('app.app.print_config')
@patch('app.core.vector_store.embedder.factory.get_embedder_from_config')
@patch('os.path.exists', return_value=True)
@patch('faiss.read_index')
def test_shutdown_saves_index_and_metadata(mock_read_index, mock_os_exists, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index_and_metadata):
    """
    Tests that the FAISS index and its associated metadata are saved on application shutdown.

    The lifespan events only run when ``TestClient`` is used as a context
    manager, so the client is entered and exited without issuing any requests.
    """
    # Arrange
    # A fake FAISS index that reports 10 stored vectors.
    mock_index = MagicMock()
    mock_index.ntotal = 10
    mock_read_index.return_value = mock_index
    mock_get_embedder.return_value = MagicMock()

    # Vector-store stand-in exposed through the mocked ServiceContainer, so the
    # lifespan handler finds a store with data to persist. Its metadata size and
    # dimension are tied to the fake index and the module-level TEST_DIMENSION.
    mock_faiss_vector_store = MagicMock(spec=FaissVectorStore)
    mock_faiss_vector_store.doc_tags = {i: {"tag": "value"} for i in range(mock_index.ntotal)}
    mock_faiss_vector_store.index = mock_index
    mock_faiss_vector_store.dimension = TEST_DIMENSION
    mock_service_container.return_value.vector_store = mock_faiss_vector_store

    app = create_app()

    # Act: entering and leaving the context runs the lifespan startup/shutdown events.
    with TestClient(app):
        pass

    # Assert
    # NOTE(review): this checks the class-level patch on FaissVectorStore. If the
    # shutdown hook calls save_index_and_metadata on ``services.vector_store`` (the
    # spec'd mock above), that call is recorded on
    # mock_faiss_vector_store.save_index_and_metadata instead. Confirm against the
    # lifespan implementation in app.app.
    mock_save_index_and_metadata.assert_called_once()