# Newer
# Older
# cortex-hub / ai-hub / integration_tests / test_integration.py
import pytest
import httpx

# Address of the local server under test (started by the run_tests.sh script).
BASE_URL = "http://127.0.0.1:8000"

# Prompts shared by the chat tests.
TEST_PROMPT = "Explain the theory of relativity in one sentence."
CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
FOLLOW_UP_PROMPT = "When was he born?"

# Fixture document and question for the retrieval-augmented generation test.
RAG_DOC_TITLE = "Fictional Company History"
RAG_DOC_TEXT = (
    "The company AlphaCorp was founded in 2021 by Jane Doe. "
    "Their primary product is a smart home device called 'Nexus'."
)
RAG_PROMPT = "Who founded AlphaCorp and what is their main product?"

# Module-level state handed from one test to the next; the tests that read
# these values assert that an earlier test has already filled them in.
created_document_id = None
created_session_id = None

async def test_root_endpoint():
    """
    Checks that a GET on the server root answers 200 with the expected status payload.
    """
    print("\n--- Running test_root_endpoint ---")
    expected_body = {"status": "AI Model Hub is running!"}
    async with httpx.AsyncClient() as http:
        root_reply = await http.get(f"{BASE_URL}/")

        assert root_reply.status_code == 200
        assert root_reply.json() == expected_body
    print("✅ Root endpoint test passed.")

# --- Session and Chat Lifecycle Tests ---

async def test_create_session():
    """Creates a chat session via POST /sessions and stores its ID in `created_session_id`."""
    global created_session_id
    print("\n--- Running test_create_session ---")
    session_request = {"user_id": "integration_tester", "model": "deepseek"}

    async with httpx.AsyncClient() as http:
        reply = await http.post(f"{BASE_URL}/sessions", json=session_request)

    assert reply.status_code == 200, f"Failed to create session. Response: {reply.text}"
    body = reply.json()
    assert "id" in body
    # Published at module level so the chat tests can address this session.
    created_session_id = body["id"]
    print(f"✅ Session created successfully with ID: {created_session_id}")

async def test_chat_in_session_turn_1():
    """Posts the context-setting prompt as the first turn of the shared session, without naming a model."""
    print("\n--- Running test_chat_in_session (Turn 1) ---")
    assert created_session_id is not None, "Session ID was not set."

    # No "model" key is sent; the "model_used" check below expects "deepseek".
    first_turn = {"prompt": CONTEXT_PROMPT}
    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=first_turn)

    assert reply.status_code == 200, f"Chat request failed. Response: {reply.text}"
    body = reply.json()
    # The answer is expected to contain the CEO's name.
    assert "Satya Nadella" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat Turn 1 (context) test passed.")

async def test_chat_in_session_turn_2_follow_up():
    """
    Posts a follow-up question to the shared session, without naming a model,
    and checks that the answer contains the expected birth year.
    """
    print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---")
    assert created_session_id is not None, "Session ID was not set."

    # No "model" key is sent; the "model_used" check below expects "deepseek".
    follow_up_turn = {"prompt": FOLLOW_UP_PROMPT}
    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=follow_up_turn)

    assert reply.status_code == 200, f"Follow-up chat request failed. Response: {reply.text}"
    body = reply.json()
    # The prompt only says "he", so the birth year in the answer is taken as
    # evidence that the earlier turn was remembered.
    assert "1967" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat Turn 2 (follow-up) test passed.")

async def test_chat_in_session_with_model_switch():
    """
    Posts a message to the shared session while explicitly requesting the 'gemini' model.
    """
    print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---")
    assert created_session_id is not None, "Session ID was not set."

    # This turn names its model explicitly.
    gemini_turn = {"prompt": "What is the capital of France?", "model": "gemini"}
    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=gemini_turn)

    assert reply.status_code == 200, f"Gemini chat request failed. Response: {reply.text}"
    body = reply.json()
    assert "Paris" in body["answer"]
    assert body["model_used"] == "gemini"
    print("✅ Chat (Model Switch to Gemini) test passed.")

async def test_chat_in_session_switch_back_to_deepseek():
    """
    Posts one more message to the shared session while explicitly requesting the 'deepseek' model.
    """
    print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---")
    assert created_session_id is not None, "Session ID was not set."

    # This turn names its model explicitly.
    deepseek_turn = {"prompt": "What is the largest ocean?", "model": "deepseek"}
    chat_endpoint = f"{BASE_URL}/sessions/{created_session_id}/chat"

    async with httpx.AsyncClient(timeout=60.0) as http:
        reply = await http.post(chat_endpoint, json=deepseek_turn)

    assert reply.status_code == 200, f"DeepSeek chat request failed. Response: {reply.text}"
    body = reply.json()
    assert "Pacific Ocean" in body["answer"]
    assert body["model_used"] == "deepseek"
    print("✅ Chat (Model Switch back to DeepSeek) test passed.")

async def test_chat_with_document_retrieval():
    """
    Uploads a document, asks a question about it with 'load_faiss_retriever'
    switched on, and checks that the answer reflects the document's content.

    The test creates its own session, and the uploaded document is deleted
    afterwards whether or not the chat assertions pass.
    """
    print("\n--- Running test_chat_with_document_retrieval ---")
    async with httpx.AsyncClient(timeout=60.0) as http:
        # This test runs in a session of its own rather than the shared one.
        new_session = await http.post(
            f"{BASE_URL}/sessions",
            json={"user_id": "rag_tester", "model": "deepseek"},
        )
        assert new_session.status_code == 200
        rag_session_id = new_session.json()["id"]

        # Upload the document whose facts the answer is expected to mention.
        upload = await http.post(
            f"{BASE_URL}/documents",
            json={"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT},
        )
        assert upload.status_code == 200
        try:
            # The ID is taken from whatever follows " with ID " in the "message" field.
            upload_message = upload.json().get("message", "")
            rag_document_id = int(upload_message.split(" with ID ")[-1])
            print(f"Document for RAG created with ID: {rag_document_id}")
        except (ValueError, IndexError):
            pytest.fail("Could not parse document ID from response message.")

        document_url = f"{BASE_URL}/documents/{rag_document_id}"
        try:
            # Reference the uploaded document and request retrieval for this turn.
            rag_request = {
                "prompt": RAG_PROMPT,
                "document_id": rag_document_id,
                "model": "deepseek",
                "load_faiss_retriever": True,
            }
            rag_reply = await http.post(
                f"{BASE_URL}/sessions/{rag_session_id}/chat",
                json=rag_request,
            )

            assert rag_reply.status_code == 200, f"RAG chat request failed. Response: {rag_reply.text}"
            answer = rag_reply.json()

            # Both facts appear only in the uploaded document text.
            assert "Jane Doe" in answer["answer"]
            assert "Nexus" in answer["answer"]
            print("✅ Chat with document retrieval test passed.")
        finally:
            # Remove the uploaded document regardless of how the chat checks went.
            removal = await http.delete(document_url)
            assert removal.status_code == 200
            print(f"Document {rag_document_id} deleted successfully.")

# --- Document Management Lifecycle Tests ---
async def test_add_document_for_lifecycle():
    """Uploads a document via POST /documents and stores its parsed ID in `created_document_id`."""
    global created_document_id
    print("\n--- Running test_add_document (for lifecycle) ---")
    new_document = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}

    async with httpx.AsyncClient(timeout=30.0) as http:
        reply = await http.post(f"{BASE_URL}/documents", json=new_document)

    assert reply.status_code == 200
    try:
        # The ID is taken from whatever follows " with ID " in the "message" field.
        reply_message = reply.json().get("message", "")
        created_document_id = int(reply_message.split(" with ID ")[-1])
    except (ValueError, IndexError):
        pytest.fail("Could not parse document ID from response message.")
    print(f"✅ Document for lifecycle test created with ID: {created_document_id}")

async def test_list_documents():
    """Fetches the document listing and checks that the ID held in `created_document_id` is in it."""
    print("\n--- Running test_list_documents ---")
    assert created_document_id is not None, "Document ID was not set."

    async with httpx.AsyncClient() as http:
        listing = await http.get(f"{BASE_URL}/documents")

    assert listing.status_code == 200
    # Collect the "id" of every entry under the "documents" key of the response.
    listed_ids = {entry["id"] for entry in listing.json()["documents"]}
    assert created_document_id in listed_ids
    print("✅ Document list test passed.")

async def test_delete_document():
    """Deletes the document whose ID is held in `created_document_id` and checks the echoed ID."""
    print("\n--- Running test_delete_document ---")
    assert created_document_id is not None, "Document ID was not set."

    async with httpx.AsyncClient() as http:
        removal = await http.delete(f"{BASE_URL}/documents/{created_document_id}")

    assert removal.status_code == 200
    # The response body is expected to report the ID that was deleted.
    assert removal.json()["document_id"] == created_document_id
    print("✅ Document delete test passed.")