# cortex-hub / ai-hub / tests / api / routes / test_sessions.py
from unittest.mock import MagicMock, AsyncMock
from datetime import datetime
from app.db import models

def test_create_session_success(client):
    """POST /sessions creates a session and echoes back its ID.

    The session service is stubbed to hand back a spec'd ``models.Session``
    mock; the test checks the HTTP status, the returned ``id`` and that the
    service was invoked exactly once.
    """
    http, services = client
    create_session = services.session_service.create_session

    fake_session = MagicMock(
        spec=models.Session,
        id=1,
        user_id="test_user",
        provider_name="gemini",
        title="New Chat",
        feature_name="default",
        created_at=datetime.now(),
    )
    create_session.return_value = fake_session

    payload = {"user_id": "test_user", "provider_name": "gemini"}
    resp = http.post("/sessions", json=payload)

    assert resp.status_code == 200
    assert resp.json()["id"] == 1
    create_session.assert_called_once()

def test_chat_in_session_success(client):
    """
    A chat request carrying only a prompt must reach the RAG service with
    provider 'deepseek' and the FAISS retriever disabled.
    """
    http, services = client
    rag = services.rag_service
    rag.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))

    resp = http.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert resp.status_code == 200
    assert resp.json() == {"answer": "Mocked response", "provider_used": "deepseek"}

    # The DB handle is injected by the route itself, so accept whatever it passed.
    injected_db = rag.chat_with_rag.call_args.kwargs["db"]
    rag.chat_with_rag.assert_called_once_with(
        db=injected_db,
        session_id=42,
        prompt="Hello there",
        provider_name="deepseek",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_model_switch(client):
    """
    An explicit ``provider_name`` in the chat request ('gemini' here) is
    forwarded to the RAG service and reported back as ``provider_used``.
    """
    http, services = client
    chat_stub = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
    services.rag_service.chat_with_rag = chat_stub

    request_body = {"prompt": "Hello there, Gemini!", "provider_name": "gemini"}
    resp = http.post("/sessions/42/chat", json=request_body)

    assert resp.status_code == 200
    assert resp.json() == {"answer": "Mocked response from Gemini", "provider_used": "gemini"}

    chat_stub.assert_called_once_with(
        db=chat_stub.call_args.kwargs["db"],
        session_id=42,
        prompt="Hello there, Gemini!",
        provider_name="gemini",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_faiss_retriever(client):
    """
    Setting ``load_faiss_retriever`` to True in the chat request is passed
    straight through to the RAG service.
    """
    http, services = client
    services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Response with context", "deepseek")
    )

    resp = http.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True},
    )
    assert resp.status_code == 200

    expected_body = {"answer": "Response with context", "provider_used": "deepseek"}
    assert resp.json() == expected_body

    chat = services.rag_service.chat_with_rag
    chat.assert_called_once_with(
        db=chat.call_args.kwargs["db"],
        session_id=42,
        prompt="What is RAG?",
        provider_name="deepseek",
        load_faiss_retriever=True,
    )

def test_get_session_messages_success(client):
    """GET /sessions/{id}/messages returns the stubbed history for that session."""
    http, services = client
    history_stub = services.rag_service.get_message_history

    user_msg = MagicMock(spec=models.Message, sender="user", content="Hello", created_at=datetime.now())
    bot_msg = MagicMock(spec=models.Message, sender="assistant", content="Hi there!", created_at=datetime.now())
    history_stub.return_value = [user_msg, bot_msg]

    resp = http.get("/sessions/123/messages")

    assert resp.status_code == 200
    body = resp.json()
    assert body["session_id"] == 123
    messages = body["messages"]
    assert len(messages) == 2
    assert messages[0]["sender"] == "user"
    assert messages[1]["content"] == "Hi there!"
    history_stub.assert_called_once_with(
        db=history_stub.call_args.kwargs["db"],
        session_id=123,
    )

def test_get_session_messages_not_found(client):
    """When the service reports no history (None), the route answers 404."""
    http, services = client
    services.rag_service.get_message_history.return_value = None

    resp = http.get("/sessions/999/messages")
    status, detail = resp.status_code, None

    assert status == 404
    detail = resp.json()["detail"]
    assert detail == "Session with ID 999 not found."


# --- New Session Management Tests ---

def test_create_session_with_feature_name(client):
    """A ``feature_name`` in the create payload is echoed back and forwarded to the service."""
    http, services = client
    create_session = services.session_service.create_session

    stub_session = MagicMock(spec=models.Session)
    session_fields = {
        "id": 2,
        "user_id": "test_user",
        "provider_name": "gemini",
        "title": "New Chat",
        "feature_name": "coding_assistant",
        "created_at": datetime.now(),
    }
    for attr_name, attr_value in session_fields.items():
        setattr(stub_session, attr_name, attr_value)
    create_session.return_value = stub_session

    resp = http.post(
        "/sessions",
        json={"user_id": "test_user", "provider_name": "gemini", "feature_name": "coding_assistant"},
    )

    assert resp.status_code == 200
    assert resp.json()["feature_name"] == "coding_assistant"
    create_session.assert_called_once_with(
        db=create_session.call_args.kwargs["db"],
        user_id="test_user",
        provider_name="gemini",
        feature_name="coding_assistant",
    )


def test_get_sessions_by_feature(client):
    """Listing sessions filtered by user and feature namespace returns a JSON list."""
    http, _ = client

    resp = http.get("/sessions/?user_id=test_user&feature_name=coding_assistant")

    # The route reads the DB directly rather than going through a stubbed
    # service, so only reachability and the list shape are checked here
    # (the fixture's DB stub presumably yields no rows).
    assert resp.status_code == 200
    payload = resp.json()
    assert isinstance(payload, list)


def test_delete_session_success(client):
    """Tests soft-deleting (archiving) a single session by ID.

    DELETE /sessions/{id} must be reachable and answer either 200 (session
    found and archived) or 404 (no such session). Which of the two occurs
    depends on what the ``client`` fixture's DB stub returns for the lookup,
    which is not visible from this module.
    """
    test_client, _ = client

    # Exercise the route against the real ``models.Session`` ORM class.
    # Patching ``models.Session`` here would swap the shared model class for a
    # MagicMock during the request (and its columns could not be restored from
    # ``models.Session``, since that name would already point at the patch).
    response = test_client.delete("/sessions/5")

    # NOTE(review): the outcome depends on the fixture's DB stub. Pin this to
    # a single status code once the fixture exposes a configurable query result.
    assert response.status_code in (200, 404)


def test_delete_all_sessions_for_feature(client):
    """Bulk-archiving every session in a user's feature namespace reports success."""
    http, _ = client

    resp = http.delete("/sessions/?user_id=test_user&feature_name=voice_chat")

    # The route queries the DB directly; with the fixture's mock DB there is
    # presumably nothing to archive, so only status and message are checked.
    assert resp.status_code == 200
    message = resp.json().get("message", "")
    assert "deleted" in message.lower()


def test_get_session_token_usage_success(client):
    """GET /sessions/{id}/tokens reports count, limit and percentage for a session."""
    http, services = client

    services.rag_service.get_message_history.return_value = [
        MagicMock(spec=models.Message, content=text, created_at=datetime.now())
        for text in ("Hello, assistant!", "Hi there, user!")
    ]

    resp = http.get("/sessions/1/tokens")

    assert resp.status_code == 200
    usage = resp.json()
    for key in ("token_count", "token_limit", "percentage"):
        assert key in usage
    assert usage["token_count"] >= 0