"""API endpoint tests for the AI Model Hub FastAPI application.

Every test builds a fresh app with ``create_app()``. Tests that exercise
service-backed routes patch ``app.app.RAGService`` (the service instance is
created inside ``create_app()``) and replace the ``get_db`` dependency with a
shared ``MagicMock`` session.
"""
import os
from datetime import datetime  # Used for created_at on model instances.
from unittest.mock import AsyncMock, MagicMock, patch

from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

# Import the factory function directly to get a fresh app instance for testing.
from app.app import create_app
# get_db lives in app.api.dependencies; it is the dependency overridden below.
from app.api.dependencies import get_db
from app.db import models  # SQLAlchemy models used to build fake service results.


# --- Dependency Override for Testing ---
# One mock database session shared by every test. Route handlers receive it via
# the get_db override, and tests assert it was forwarded to the mocked service.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Yield the shared mock database session in place of a real one."""
    yield mock_db


def _make_client():
    """Build a fresh app wired to ``mock_db`` and return a TestClient for it.

    Call this inside the test body so the ``@patch('app.app.RAGService')``
    decorator is active when ``create_app()`` instantiates the service.
    ``mock_db``'s recorded calls are cleared first so state cannot leak between
    tests through the shared module-level mock.
    """
    mock_db.reset_mock()
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    return TestClient(app)


# --- API Endpoint Tests ---


def test_read_root():
    """Test the root endpoint to ensure it's running."""
    # Deliberately unpatched and without the DB override, so that no mocking
    # interferes with the health check.
    # NOTE(review): this means create_app() builds a real RAGService; confirm
    # that is cheap and needs no external credentials in CI.
    app = create_app()
    client = TestClient(app)

    response = client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


# The remaining tests patch the RAGService class itself, because the service
# instance is created inside create_app().


@patch('app.app.RAGService')
def test_create_session_success(mock_rag_service_class):
    """Tests successfully creating a new chat session via the POST /sessions endpoint."""
    # Arrange: the service returns a SQLAlchemy Session object.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_session_obj = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat Session",
        created_at=datetime.now(),
    )
    mock_rag_service_instance.create_session.return_value = mock_session_obj
    client = _make_client()

    # Act
    response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    # Assert
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["id"] == 1
    assert response_data["user_id"] == "test_user"
    mock_rag_service_instance.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )


@patch('app.app.RAGService')
def test_chat_in_session_success(mock_rag_service_class):
    """
    Test the session-based chat endpoint with a successful, mocked response.
    It should default to 'deepseek' if no model is specified.
    """
    # Arrange: chat_with_rag is replaced by an AsyncMock that returns the
    # (answer_text, model_used) tuple the route unpacks.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.chat_with_rag = AsyncMock(
        return_value=("This is a mock response.", "deepseek")
    )
    client = _make_client()

    # Act
    response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    # Assert
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["answer"] == "This is a mock response."
    assert response_data["model_used"] == "deepseek"
    # Pin every keyword the route forwards, including load_faiss_retriever=False.
    mock_rag_service_instance.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=123,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )


@patch('app.app.RAGService')
def test_chat_in_session_with_model_switch(mock_rag_service_class):
    """
    Tests sending a message in an existing session and explicitly switching the model.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )
    client = _make_client()

    response = client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there, Gemini!", "model": "gemini"},
    )

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
    # chat_with_rag must receive the explicitly requested model 'gemini'.
    mock_rag_service_instance.chat_with_rag.assert_called_once_with(
        db=mock_db,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )


@patch('app.app.RAGService')
def test_get_session_messages_success(mock_rag_service_class):
    """Tests retrieving the message history for a session."""
    # Arrange: the service returns a list of message objects.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_history = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()),
    ]
    mock_rag_service_instance.get_message_history.return_value = mock_history
    client = _make_client()

    # Act
    response = client.get("/sessions/123/messages")

    # Assert
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    mock_rag_service_instance.get_message_history.assert_called_once_with(
        db=mock_db, session_id=123
    )


@patch('app.app.RAGService')
def test_get_session_messages_not_found(mock_rag_service_class):
    """Tests retrieving messages for a session that does not exist."""
    # Arrange: None from the service signals that the session wasn't found.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.get_message_history.return_value = None
    client = _make_client()

    # Act
    response = client.get("/sessions/999/messages")

    # Assert
    assert response.status_code == 404
    assert response.json()["detail"] == "Session with ID 999 not found."
    mock_rag_service_instance.get_message_history.assert_called_once_with(
        db=mock_db, session_id=999
    )


@patch('app.app.RAGService')
def test_add_document_success(mock_rag_service_class):
    """
    Test the /documents endpoint with a successful, mocked RAG service response.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.add_document.return_value = 1
    client = _make_client()

    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }
    response = client.post("/documents", json=doc_data)

    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"

    # The service must receive the payload including the defaults Pydantic fills in.
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_rag_service_instance.add_document.assert_called_once_with(
        db=mock_db, doc_data=expected_doc_data
    )


@patch('app.app.RAGService')
def test_add_document_api_failure(mock_rag_service_class):
    """
    Test the /documents endpoint when the RAG service encounters an error.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.add_document.side_effect = Exception("Service failed")
    client = _make_client()

    doc_data = {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }
    response = client.post("/documents", json=doc_data)

    assert response.status_code == 500
    assert "An error occurred: Service failed" in response.json()["detail"]

    # The service must receive the payload including the defaults Pydantic fills in.
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_rag_service_instance.add_document.assert_called_once_with(
        db=mock_db, doc_data=expected_doc_data
    )


@patch('app.app.RAGService')
def test_get_documents_success(mock_rag_service_class):
    """
    Tests the /documents endpoint for successful retrieval of documents.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_docs = [
        models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
        models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
    ]
    mock_rag_service_instance.get_all_documents.return_value = mock_docs
    client = _make_client()

    response = client.get("/documents")

    assert response.status_code == 200
    documents = response.json()["documents"]
    assert len(documents) == 2
    assert documents[0]["title"] == "Doc One"
    mock_rag_service_instance.get_all_documents.assert_called_once_with(db=mock_db)


@patch('app.app.RAGService')
def test_delete_document_success(mock_rag_service_class):
    """
    Tests the DELETE /documents/{document_id} endpoint for successful deletion.
    """
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.delete_document.return_value = 42
    client = _make_client()

    response = client.delete("/documents/42")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["message"] == "Document deleted successfully"
    assert response_data["document_id"] == 42
    mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=42)


@patch('app.app.RAGService')
def test_delete_document_not_found(mock_rag_service_class):
    """
    Tests the DELETE /documents/{document_id} endpoint when the document is not found.
    """
    # None from the service signals that the document wasn't found.
    mock_rag_service_instance = mock_rag_service_class.return_value
    mock_rag_service_instance.delete_document.return_value = None
    client = _make_client()

    response = client.delete("/documents/999")

    assert response.status_code == 404
    assert response.json()["detail"] == "Document with ID 999 not found."
    mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=999)