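# cortex-hub / ai-hub / tests / api / test_routes.py
# NOTE(review): every fixture and test in this module is commented out, so pytest
# collects no tests from it. Re-enable them against the current API or delete the file.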
# import pytest
# from unittest.mock import MagicMock, AsyncMock
# from fastapi import FastAPI, Response
# from fastapi.testclient import TestClient
# from sqlalchemy.orm import Session
# from datetime import datetime
# from httpx import AsyncClient, ASGITransport
# import asyncio

# # Import the dependencies and router factory
# from app.api.dependencies import get_db, ServiceContainer
# from app.core.services.rag import RAGService
# from app.core.services.document import DocumentService
# from app.core.services.tts import TTSService
# from app.api.routes import create_api_router
# from app.db import models

# @pytest.fixture
# def client():
#     """
#     Pytest fixture to create a TestClient with a fully mocked environment
#     for synchronous endpoints.
#     """
#     test_app = FastAPI()

#     mock_rag_service = MagicMock(spec=RAGService)
#     mock_document_service = MagicMock(spec=DocumentService)
#     mock_tts_service = MagicMock(spec=TTSService)

#     mock_services = MagicMock(spec=ServiceContainer)
#     mock_services.rag_service = mock_rag_service
#     mock_services.document_service = mock_document_service
#     mock_services.tts_service = mock_tts_service

#     mock_db_session = MagicMock(spec=Session)

#     def override_get_db():
#         yield mock_db_session

#     api_router = create_api_router(services=mock_services)
#     test_app.dependency_overrides[get_db] = override_get_db
#     test_app.include_router(api_router)

#     test_client = TestClient(test_app)

#     yield test_client, mock_services

# @pytest.fixture
# async def async_client():
#     """
#     Pytest fixture to create an AsyncClient for testing async endpoints.
#     """
#     test_app = FastAPI()

#     mock_rag_service = MagicMock(spec=RAGService)
#     mock_document_service = MagicMock(spec=DocumentService)
#     mock_tts_service = MagicMock(spec=TTSService)

#     mock_services = MagicMock(spec=ServiceContainer)
#     mock_services.rag_service = mock_rag_service
#     mock_services.document_service = mock_document_service
#     mock_services.tts_service = mock_tts_service

#     mock_db_session = MagicMock(spec=Session)

#     def override_get_db():
#         yield mock_db_session

#     api_router = create_api_router(services=mock_services)
#     test_app.dependency_overrides[get_db] = override_get_db
#     test_app.include_router(api_router)

#     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
#         yield client, mock_services

# # --- General Endpoint ---

# def test_read_root(client):
#     """Tests the root endpoint."""
#     test_client, _ = client
#     response = test_client.get("/")
#     assert response.status_code == 200
#     assert response.json() == {"status": "AI Model Hub is running!"}

# # --- Session and Chat Endpoints ---

# def test_create_session_success(client):
#     """Tests successfully creating a new chat session."""
#     test_client, mock_services = client
#     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
#     mock_services.rag_service.create_session.return_value = mock_session

#     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

#     assert response.status_code == 200
#     assert response.json()["id"] == 1
#     mock_services.rag_service.create_session.assert_called_once()

# def test_chat_in_session_success(client):
#     """
#     Tests sending a message in an existing session without specifying a model
#     or retriever flag. The route should default to model="deepseek" and
#     load_faiss_retriever=False.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))

#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

#     assert response.status_code == 200
#     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}

#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="Hello there",
#         model="deepseek",
#         load_faiss_retriever=False
#     )

# def test_chat_in_session_with_model_switch(client):
#     """
#     Tests sending a message in an existing session and explicitly switching the model.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))

#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})

#     assert response.status_code == 200
#     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}

#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="Hello there, Gemini!",
#         model="gemini",
#         load_faiss_retriever=False
#     )

# def test_chat_in_session_with_faiss_retriever(client):
#     """
#     Tests sending a message and explicitly enabling the FAISS retriever.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))

#     response = test_client.post(
#         "/sessions/42/chat",
#         json={"prompt": "What is RAG?", "load_faiss_retriever": True}
#     )

#     assert response.status_code == 200
#     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}

#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="What is RAG?",
#         model="deepseek",
#         load_faiss_retriever=True
#     )

# def test_get_session_messages_success(client):
#     """Tests retrieving the message history for a session."""
#     test_client, mock_services = client
#     mock_history = [
#         models.Message(sender="user", content="Hello", created_at=datetime.now()),
#         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
#     ]
#     mock_services.rag_service.get_message_history.return_value = mock_history

#     response = test_client.get("/sessions/123/messages")

#     assert response.status_code == 200
#     response_data = response.json()
#     assert response_data["session_id"] == 123
#     assert len(response_data["messages"]) == 2
#     assert response_data["messages"][0]["sender"] == "user"
#     assert response_data["messages"][1]["content"] == "Hi there!"
#     mock_services.rag_service.get_message_history.assert_called_once_with(
#         db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
#         session_id=123
#     )

# def test_get_session_messages_not_found(client):
#     """Tests retrieving messages for a session that does not exist."""
#     test_client, mock_services = client
#     mock_services.rag_service.get_message_history.return_value = None

#     response = test_client.get("/sessions/999/messages")

#     assert response.status_code == 404
#     assert response.json()["detail"] == "Session with ID 999 not found."

# # --- Document Endpoints ---

# def test_add_document_success(client):
#     """Tests the /documents endpoint for adding a new document."""
#     test_client, mock_services = client
#     mock_services.document_service.add_document.return_value = 123
#     doc_payload = {"title": "Test Doc", "text": "Content here"}
#     response = test_client.post("/documents", json=doc_payload)
#     assert response.status_code == 200
#     assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"

# def test_get_documents_success(client):
#     """Tests the /documents endpoint for retrieving all documents."""
#     test_client, mock_services = client
#     mock_docs = [
#         models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
#         models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now())
#     ]
#     mock_services.document_service.get_all_documents.return_value = mock_docs
#     response = test_client.get("/documents")
#     assert response.status_code == 200
#     assert len(response.json()["documents"]) == 2

# def test_delete_document_success(client):
#     """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
#     test_client, mock_services = client
#     mock_services.document_service.delete_document.return_value = 42
#     response = test_client.delete("/documents/42")
#     assert response.status_code == 200
#     assert response.json()["document_id"] == 42

# def test_delete_document_not_found(client):
#     """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
#     test_client, mock_services = client
#     mock_services.document_service.delete_document.return_value = None
#     response = test_client.delete("/documents/999")
#     assert response.status_code == 404

# @pytest.mark.asyncio
# async def test_create_speech_response(async_client):
#     """Test the /speech endpoint returns audio bytes."""
#     test_client, mock_services = await anext(async_client)
#     mock_audio_bytes = b"fake wav audio bytes"

#     # The route handler calls `create_speech_non_stream`, not `create_speech_stream`
#     # It's an async function, so we must use AsyncMock
#     mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)

#     response = await test_client.post("/speech", json={"text": "Hello, this is a test"})

#     assert response.status_code == 200
#     assert response.headers["content-type"] == "audio/wav"
#     assert response.content == mock_audio_bytes

#     mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")

# @pytest.mark.asyncio
# async def test_create_speech_stream_response(async_client):
#     """Test the consolidated /speech endpoint with stream=true returns a streaming response."""
#     test_client, mock_services = await anext(async_client)
#     mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]

#     # This async generator mock correctly simulates the streaming service
#     async def mock_async_generator():
#         for chunk in mock_audio_bytes_chunks:
#             yield chunk

#     # We mock `create_speech_stream` with a MagicMock returning the async generator
#     mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())

#     # Streaming is requested via the stream=true query parameter on the consolidated /speech endpoint
#     response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"})

#     assert response.status_code == 200
#     assert response.headers["content-type"] == "audio/wav"

#     # Read the streamed content and verify it matches the mocked chunks
#     streamed_content = b""
#     async for chunk in response.aiter_bytes():
#         streamed_content += chunk

#     assert streamed_content == b"".join(mock_audio_bytes_chunks)
#     mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test")