import os
from contextlib import contextmanager
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

# Import the factory function directly to get a fresh app instance for testing
from app.app import create_app
from app.api.dependencies import get_db
from app.db import models

# Define a constant for the dimension to ensure consistency
TEST_DIMENSION = 768

# Fixed seed so the fake embedding vector is identical on every run.
_EMBEDDING_SEED = 0

# --- Dependency Override for Testing ---
# A single mock database session shared by every test. Its call history is
# cleared in `_build_client` so one test's interactions never leak into another.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Yield the shared mock database session in place of a real one."""
    yield mock_db


# --- Shared test helpers ---

@contextmanager
def _patched_external_services():
    """Stub the network and filesystem calls made while the app is created.

    Patches ``requests.post`` (used by the GenAIEmbedder) so it returns a fake
    embedding response of ``TEST_DIMENSION`` floats. Also patches
    ``faiss.read_index`` so it returns a ``MagicMock`` instead of reading an
    index file. Both patches stay active for the whole ``with`` body, so build
    the app and issue requests inside it.
    """
    with patch("requests.post") as mock_post, patch("faiss.read_index") as mock_read_index:
        mock_read_index.return_value = MagicMock()

        rng = np.random.default_rng(_EMBEDDING_SEED)
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None
        mock_response.json.return_value = {
            "embedding": {"values": rng.random(TEST_DIMENSION).tolist()}
        }
        mock_post.return_value = mock_response
        yield


def _build_client(override_db=True):
    """Create a fresh app and wrap it in a ``TestClient``.

    Args:
        override_db: When true, replace the ``get_db`` dependency with
            ``override_get_db`` so endpoints receive ``mock_db``.

    Returns:
        A ``TestClient`` bound to a newly created app.

    Call this only after the ``RAGService`` mock has been configured. The
    service instance is created inside ``create_app()``, so attributes swapped
    afterwards might not be the ones the app captured.
    """
    # Clear call history left behind by earlier tests on the shared mock session.
    mock_db.reset_mock()
    app = create_app()
    if override_db:
        app.dependency_overrides[get_db] = override_get_db
    return TestClient(app)


# --- API Endpoint Tests ---
# We patch the RAGService class itself, as the instance is created inside create_app().

def test_read_root():
    """Test the root endpoint to ensure it's running."""
    # NOTE(review): RAGService is not patched here, so create_app() builds the
    # real service with only the network/index calls stubbed out.
    with _patched_external_services():
        client = _build_client(override_db=False)

        response = client.get("/")

        assert response.status_code == 200
        assert response.json() == {"status": "AI Model Hub is running!"}


@patch("app.app.RAGService")
def test_create_session_success(mock_rag_service_class):
    """Tests successfully creating a new chat session via the POST /sessions endpoint."""
    with _patched_external_services():
        # Arrange
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_session_obj = models.Session(
            id=1,
            user_id="test_user",
            model_name="gemini",
            title="New Chat Session",
            created_at=datetime.now(),
        )
        mock_rag_service_instance.create_session.return_value = mock_session_obj
        client = _build_client()

        # Act
        response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

        # Assert
        assert response.status_code == 200
        response_data = response.json()
        assert response_data["id"] == 1
        assert response_data["user_id"] == "test_user"
        mock_rag_service_instance.create_session.assert_called_once_with(
            db=mock_db, user_id="test_user", model="gemini"
        )


@patch("app.app.RAGService")
def test_chat_in_session_success(mock_rag_service_class):
    """
    Test the session-based chat endpoint with a successful, mocked response.
    It should default to 'deepseek' if no model is specified.
    """
    with _patched_external_services():
        # Arrange: chat_with_rag is awaited by the endpoint, so use an AsyncMock.
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_rag_service_instance.chat_with_rag = AsyncMock(
            return_value=("This is a mock response.", "deepseek")
        )
        client = _build_client()

        # Act
        response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

        # Assert
        assert response.status_code == 200
        assert response.json()["answer"] == "This is a mock response."
        assert response.json()["model_used"] == "deepseek"
        mock_rag_service_instance.chat_with_rag.assert_called_once_with(
            db=mock_db,
            session_id=123,
            prompt="Hello there",
            model="deepseek",
            load_faiss_retriever=False,
        )


@patch("app.app.RAGService")
def test_chat_in_session_with_model_switch(mock_rag_service_class):
    """Tests sending a message in an existing session and explicitly switching the model."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_rag_service_instance.chat_with_rag = AsyncMock(
            return_value=("Mocked response from Gemini", "gemini")
        )
        client = _build_client()

        response = client.post(
            "/sessions/42/chat",
            json={"prompt": "Hello there, Gemini!", "model": "gemini"},
        )

        assert response.status_code == 200
        assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
        mock_rag_service_instance.chat_with_rag.assert_called_once_with(
            db=mock_db,
            session_id=42,
            prompt="Hello there, Gemini!",
            model="gemini",
            load_faiss_retriever=False,
        )


@patch("app.app.RAGService")
def test_get_session_messages_success(mock_rag_service_class):
    """Tests retrieving the message history for a session."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_history = [
            models.Message(sender="user", content="Hello", created_at=datetime.now()),
            models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()),
        ]
        mock_rag_service_instance.get_message_history.return_value = mock_history
        client = _build_client()

        # Act
        response = client.get("/sessions/123/messages")

        # Assert
        assert response.status_code == 200
        response_data = response.json()
        assert response_data["session_id"] == 123
        assert len(response_data["messages"]) == 2
        assert response_data["messages"][0]["sender"] == "user"
        assert response_data["messages"][1]["content"] == "Hi there!"
        mock_rag_service_instance.get_message_history.assert_called_once_with(
            db=mock_db, session_id=123
        )


@patch("app.app.RAGService")
def test_get_session_messages_not_found(mock_rag_service_class):
    """Tests retrieving messages for a session that does not exist."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        # The endpoint is expected to map a None history to a 404.
        mock_rag_service_instance.get_message_history.return_value = None
        client = _build_client()

        # Act
        response = client.get("/sessions/999/messages")

        # Assert
        assert response.status_code == 404
        assert response.json()["detail"] == "Session with ID 999 not found."
        mock_rag_service_instance.get_message_history.assert_called_once_with(
            db=mock_db, session_id=999
        )


@patch("app.app.RAGService")
def test_add_document_success(mock_rag_service_class):
    """Test the /document endpoint with a successful, mocked RAG service response."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_rag_service_instance.add_document.return_value = 1
        client = _build_client()

        doc_data = {
            "title": "Test Document",
            "text": "This is a test document.",
            "source_url": "http://example.com/test",
        }
        response = client.post("/documents", json=doc_data)

        assert response.status_code == 200
        assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"

        # The endpoint fills in the optional fields before calling the service.
        expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
        mock_rag_service_instance.add_document.assert_called_once_with(
            db=mock_db, doc_data=expected_doc_data
        )


@patch("app.app.RAGService")
def test_add_document_api_failure(mock_rag_service_class):
    """Test the /document endpoint when the RAG service encounters an error."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_rag_service_instance.add_document.side_effect = Exception("Service failed")
        client = _build_client()

        doc_data = {
            "title": "Test Document",
            "text": "This is a test document.",
            "source_url": "http://example.com/test",
        }
        response = client.post("/documents", json=doc_data)

        assert response.status_code == 500
        assert "An error occurred: Service failed" in response.json()["detail"]

        expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
        mock_rag_service_instance.add_document.assert_called_once_with(
            db=mock_db, doc_data=expected_doc_data
        )


@patch("app.app.RAGService")
def test_get_documents_success(mock_rag_service_class):
    """Tests the /documents endpoint for successful retrieval of documents."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_docs = [
            models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
            models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
        ]
        mock_rag_service_instance.get_all_documents.return_value = mock_docs
        client = _build_client()

        response = client.get("/documents")

        assert response.status_code == 200
        assert len(response.json()["documents"]) == 2
        assert response.json()["documents"][0]["title"] == "Doc One"
        mock_rag_service_instance.get_all_documents.assert_called_once_with(db=mock_db)


@patch("app.app.RAGService")
def test_delete_document_success(mock_rag_service_class):
    """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        mock_rag_service_instance.delete_document.return_value = 42
        client = _build_client()

        response = client.delete("/documents/42")

        assert response.status_code == 200
        assert response.json()["message"] == "Document deleted successfully"
        assert response.json()["document_id"] == 42
        mock_rag_service_instance.delete_document.assert_called_once_with(
            db=mock_db, document_id=42
        )


@patch("app.app.RAGService")
def test_delete_document_not_found(mock_rag_service_class):
    """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
    with _patched_external_services():
        mock_rag_service_instance = mock_rag_service_class.return_value
        # The endpoint is expected to map a None result to a 404.
        mock_rag_service_instance.delete_document.return_value = None
        client = _build_client()

        response = client.delete("/documents/999")

        assert response.status_code == 404
        assert response.json()["detail"] == "Document with ID 999 not found."
        mock_rag_service_instance.delete_document.assert_called_once_with(
            db=mock_db, document_id=999
        )