# cortex-hub / ai-hub / integration_tests / test_integration.py
import pytest
import httpx

# The base URL for the local server started by the run_tests.sh script.
# Every test in this module builds its request URL from this value.
BASE_URL = "http://127.0.0.1:8000"

# Prompts sent to the chat endpoint by the session tests.
# TEST_PROMPT is not referenced by the tests visible in this module.
TEST_PROMPT = "Explain the theory of relativity in one sentence."
# CONTEXT_PROMPT is the first chat turn; FOLLOW_UP_PROMPT is the second turn
# and only makes sense if the server remembers the first one ("he").
CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
FOLLOW_UP_PROMPT = "When was he born?"

# Global variables to pass state between sequential tests.
# They start as None and are assigned by test_add_document_for_lifecycle and
# test_create_session respectively; later tests assert they are not None,
# so the tests depend on running in file order.
created_document_id = None
created_session_id = None

async def test_root_endpoint():
    """
    Checks that GET / responds with HTTP 200 and the expected status payload.
    """
    print("\n--- Running test_root_endpoint ---")
    expected_body = {"status": "AI Model Hub is running!"}
    async with httpx.AsyncClient() as http:
        reply = await http.get(f"{BASE_URL}/")

        # Status first, then the exact JSON body.
        assert reply.status_code == 200
        assert reply.json() == expected_body
    print("✅ Root endpoint test passed.")

# --- Session and Chat Lifecycle Tests ---

async def test_create_session():
    """Creates a chat session via POST /sessions and stores its ID for later tests."""
    global created_session_id
    print("\n--- Running test_create_session ---")
    session_request = {"user_id": "integration_tester", "model": "deepseek"}

    async with httpx.AsyncClient() as http:
        reply = await http.post(f"{BASE_URL}/sessions", json=session_request)

    assert reply.status_code == 200, f"Failed to create session. Response: {reply.text}"
    body = reply.json()
    assert "id" in body
    # Publish the ID through the module-level global so the chat tests can reuse it.
    created_session_id = body["id"]
    print(f"✅ Session created successfully with ID: {created_session_id}")

async def test_chat_in_session_turn_1():
    """
    Sends the context-setting prompt as the first message of the session.

    No model is named in the request; the response is expected to report
    "deepseek" as the model used.
    """
    print("\n--- Running test_chat_in_session (Turn 1) ---")
    assert created_session_id is not None, "Session ID was not set."

    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"
    chat_request = {"prompt": CONTEXT_PROMPT}

    # Generous timeout: the request is answered by a language model.
    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=chat_request)

    assert reply.status_code == 200, f"Chat request failed. Response: {reply.text}"
    body = reply.json()
    # The answer is expected to name the CEO.
    assert "Satya Nadella" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat Turn 1 (context) test passed.")

async def test_chat_in_session_turn_2_follow_up():
    """
    Sends a follow-up question in the same session to check conversational memory.

    The prompt only says "he", so a correct answer requires the context from
    the previous turn. No model is named in the request.
    """
    print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---")
    assert created_session_id is not None, "Session ID was not set."

    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"
    chat_request = {"prompt": FOLLOW_UP_PROMPT}

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=chat_request)

    assert reply.status_code == 200, f"Follow-up chat request failed. Response: {reply.text}"
    body = reply.json()
    # The birth year in the answer shows that "he" was resolved from context.
    assert "1967" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat Turn 2 (follow-up) test passed.")

async def test_chat_in_session_with_model_switch():
    """
    Sends a message in the existing session while explicitly requesting 'gemini'.
    """
    print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---")
    assert created_session_id is not None, "Session ID was not set."

    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"
    # The "model" field overrides the model for this turn only.
    chat_request = {"prompt": "What is the capital of France?", "model": "gemini"}

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=chat_request)

    assert reply.status_code == 200, f"Gemini chat request failed. Response: {reply.text}"
    body = reply.json()
    assert "Paris" in body["answer"]
    assert body["model_used"] == "gemini"
    print("✅ Chat (Model Switch to Gemini) test passed.")

async def test_chat_in_session_switch_back_to_deepseek():
    """
    Sends one more message in the existing session, explicitly requesting 'deepseek' again.
    """
    print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---")
    assert created_session_id is not None, "Session ID was not set."

    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"
    # Name the model explicitly rather than relying on a default.
    chat_request = {"prompt": "What is the largest ocean?", "model": "deepseek"}

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=chat_request)

    assert reply.status_code == 200, f"DeepSeek chat request failed. Response: {reply.text}"
    body = reply.json()
    assert "Pacific Ocean" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat (Model Switch back to DeepSeek) test passed.")


async def test_get_session_history():
    """
    Fetches the session's message history and checks it against the earlier chat turns.

    The four chat tests each contribute a user prompt followed by an
    assistant reply, so at least 8 messages are expected, with prompts at
    even indices and replies at odd indices.
    """
    print("\n--- Running test_get_session_history ---")
    assert created_session_id is not None, "Session ID was not set."

    async with httpx.AsyncClient() as http:
        reply = await http.get(f"{BASE_URL}/sessions/{created_session_id}/messages")

    assert reply.status_code == 200
    history = reply.json()

    assert history["session_id"] == created_session_id
    messages = history["messages"]
    assert len(messages) >= 8

    # Turns 1 and 2: the context prompt and its follow-up.
    assert messages[0]["content"] == CONTEXT_PROMPT
    assert messages[2]["content"] == FOLLOW_UP_PROMPT

    # Turn 3: the prompt and reply from the model-switch test.
    assert messages[4]["content"] == "What is the capital of France?"
    assert messages[5]["sender"] == "assistant"
    assert "Paris" in messages[5]["content"]

    # Turn 4: the prompt and reply from the switch-back test.
    assert messages[6]["content"] == "What is the largest ocean?"
    assert messages[7]["sender"] == "assistant"
    assert "Pacific Ocean" in messages[7]["content"]
    print("✅ Get session history test passed.")

# --- Document Management Lifecycle Tests ---
async def test_add_document_for_lifecycle():
    """
    Tests adding a document and saves its ID for the list and delete tests.

    The ID is parsed from the response's "message" field: the text after the
    last " with ID " must be an integer. Any response that does not fit that
    shape fails the test with the raw response text included.
    """
    global created_document_id
    print("\n--- Running test_add_document (for lifecycle) ---")
    url = f"{BASE_URL}/documents"
    doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}

    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, json=doc_data)

    assert response.status_code == 200, f"Failed to add document. Response: {response.text}"
    try:
        message = response.json().get("message", "")
        created_document_id = int(message.split(" with ID ")[-1])
    except (ValueError, AttributeError):
        # ValueError: the body is not valid JSON, or the text after
        # " with ID " is not an integer.
        # AttributeError: the JSON body is not an object, or "message" is not
        # a string. str.split always returns at least one element, so
        # indexing [-1] cannot fail.
        pytest.fail(
            f"Could not parse document ID from response message. Response: {response.text}"
        )
    print(f"✅ Document for lifecycle test created with ID: {created_document_id}")

async def test_list_documents():
    """Lists documents via GET /documents and checks the previously created ID is present."""
    print("\n--- Running test_list_documents ---")
    assert created_document_id is not None, "Document ID was not set."

    async with httpx.AsyncClient() as http:
        reply = await http.get(f"{BASE_URL}/documents")

    assert reply.status_code == 200
    # Collect the IDs into a set; only membership matters here.
    listed_ids = set()
    for entry in reply.json()["documents"]:
        listed_ids.add(entry["id"])
    assert created_document_id in listed_ids
    print("✅ Document list test passed.")

async def test_delete_document():
    """Deletes the previously created document and checks the response echoes its ID."""
    print("\n--- Running test_delete_document ---")
    assert created_document_id is not None, "Document ID was not set."

    async with httpx.AsyncClient() as http:
        reply = await http.delete(f"{BASE_URL}/documents/{created_document_id}")

    assert reply.status_code == 200
    deleted_id = reply.json()["document_id"]
    assert deleted_id == created_document_id
    print("✅ Document delete test passed.")