"""API tests for the AI Model Hub FastAPI app.

The ``get_db`` dependency is overridden with a MagicMock session and
``RAGService`` is patched where ``app.app`` looks it up, so the session
tests exercise the HTTP layer against mocks.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.api.dependencies import get_db
from app.app import create_app
from app.db import models  # SQLAlchemy models

# --- Test Setup ---

# A mock DB session shared by every test that overrides ``get_db``.
# ``_build_client`` resets it so call history from one test cannot leak
# into the next.
mock_db = MagicMock(spec=Session)

# Fixed timestamp so fixture objects are identical on every run.
_FIXED_CREATED_AT = datetime(2024, 1, 1, 12, 0, 0)


def override_get_db():
    """Dependency override that yields the shared mock DB session."""
    yield mock_db


def _build_client() -> TestClient:
    """Return a client for a fresh app whose ``get_db`` yields ``mock_db``."""
    mock_db.reset_mock()
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    return TestClient(app)


# --- API Endpoint Tests ---


def test_read_root():
    """The root endpoint reports that the service is running."""
    client = TestClient(create_app())

    response = client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


@patch("app.app.RAGService")
def test_create_session_success(mock_rag_service_class):
    """POST /sessions returns the session created by the RAG service."""
    # Arrange
    # ``return_value`` is the object any ``RAGService(...)`` call made inside
    # ``app.app`` produces, so configuring it configures the app's service.
    mock_rag_service_instance = mock_rag_service_class.return_value
    # The service returns a SQLAlchemy ``models.Session`` row object.
    mock_rag_service_instance.create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat Session",
        created_at=_FIXED_CREATED_AT,
    )
    client = _build_client()

    # Act
    response = client.post(
        "/sessions", json={"user_id": "test_user", "model": "gemini"}
    )

    # Assert
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["id"] == 1
    assert response_data["user_id"] == "test_user"
    mock_rag_service_instance.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )


@patch("app.app.RAGService")
def test_chat_in_session_success(mock_rag_service_class):
    """POST /sessions/{id}/chat returns the mocked answer and model name."""
    # Arrange
    mock_rag_service_instance = mock_rag_service_class.return_value
    # ``chat_with_rag`` is awaited by the endpoint and yields a tuple of
    # (answer_text, model_used).
    mock_rag_service_instance.chat_with_rag = AsyncMock(
        return_value=("This is a mock response.", "gemini")
    )
    client = _build_client()

    # Act
    response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    # Assert
    assert response.status_code == 200
    payload = response.json()
    assert payload["answer"] == "This is a mock response."
    assert payload["model_used"] == "gemini"
    # The path parameter "123" is expected to arrive as the int 123.
    mock_rag_service_instance.chat_with_rag.assert_called_once_with(
        db=mock_db, session_id=123, prompt="Hello there"
    )