# cortex-hub / ai-hub / integration_tests / test_sessions.py
import os
import httpx
import pytest
from conftest import BASE_URL

def _headers():
    """Return the request headers that identify the integration-test user.

    The ``SYNC_TEST_USER_ID`` environment variable is read fresh on every
    call. When it is unset, the header value is an empty string.
    """
    return {"X-User-ID": os.environ.get("SYNC_TEST_USER_ID", "")}

def test_sessions_crud():
    """Create a session, confirm it is listed, delete it, then confirm it is gone.

    Requires ``SYNC_TEST_USER_ID`` in the environment. Every request carries
    the ``X-User-ID`` header built by ``_headers()``.
    """
    user_id = os.getenv("SYNC_TEST_USER_ID", "")
    assert user_id, "User ID not found in environment."

    headers = _headers()
    sessions_url = f"{BASE_URL}/sessions/"

    with httpx.Client(timeout=10.0) as client:
        # Create a fresh session so the listing below has a known entry.
        created = client.post(
            sessions_url,
            headers=headers,
            json={
                "user_id": user_id,
                "provider_name": "gemini",
                "feature_name": "default",
            },
        )
        assert created.status_code == 200, f"Failed to create session: {created.text}"
        session_id = created.json()["id"]
        session_url = f"{BASE_URL}/sessions/{session_id}"

        # The new session must show up among the user's sessions.
        listed = client.get(sessions_url, headers=headers)
        assert listed.status_code == 200, f"Failed to list sessions: {listed.text}"
        sessions = listed.json()
        assert isinstance(sessions, list)
        assert len(sessions) > 0
        assert any(s["id"] == session_id for s in sessions)

        deleted = client.delete(session_url, headers=headers)
        assert deleted.status_code == 200, f"Failed to delete session: {deleted.text}"

        # Once deleted, fetching the session by id must return 404.
        fetched = client.get(session_url, headers=headers)
        assert fetched.status_code == 404, f"Expected 404, got {fetched.status_code}"

def test_sessions_extended_apis():
    """Exercise the secondary session endpoints against a live server.

    Flow:
    1. Create a session and patch its title.
    2. Fetch token usage, clear its history, and cancel any running task.
    3. List, attach, and detach a node.
    4. Send one chat turn so a message exists.
    5. Upload and download audio for that message.
    6. Bulk-delete the user's ``default`` feature sessions and verify the
       created session is gone.

    Environment:
        SYNC_TEST_USER_ID: required; sent as ``X-User-ID`` via ``_headers()``.
        SYNC_TEST_NODE1: optional node id for the attach/detach calls
            (defaults to ``"test-node-1"``).
    """
    user_id = os.getenv("SYNC_TEST_USER_ID", "")
    assert user_id, "User ID not found in environment."

    headers = _headers()
    session_payload = {
        "provider_name": "gemini",
        "feature_name": "default"
    }

    with httpx.Client(timeout=10.0) as client:
        # 1. Create Session
        r = client.post(f"{BASE_URL}/sessions/", headers=headers, json=session_payload)
        assert r.status_code == 200, f"Failed to create session: {r.text}"
        session_id = r.json()["id"]
        session_url = f"{BASE_URL}/sessions/{session_id}"

        # 2. Patch Session
        patch_payload = {"title": "Updated Session Title"}
        r = client.patch(session_url, headers=headers, json=patch_payload)
        assert r.status_code == 200, f"Failed to patch session: {r.text}"
        assert r.json()["title"] == "Updated Session Title"

        # 3. Get Session Tokens
        r = client.get(f"{session_url}/tokens", headers=headers)
        assert r.status_code == 200, f"Failed to get session tokens: {r.text}"

        # 4. Clear History
        r = client.post(f"{session_url}/clear-history", headers=headers)
        assert r.status_code == 200, f"Failed to clear history: {r.text}"

        # 5. Cancel Task
        r = client.post(f"{session_url}/cancel", headers=headers)
        assert r.status_code == 200, f"Failed to cancel task: {r.text}"

        # 6. Nodes
        r = client.get(f"{session_url}/nodes", headers=headers)
        assert r.status_code == 200, f"Failed to list session nodes: {r.text}"

        node_id = os.getenv("SYNC_TEST_NODE1", "test-node-1")
        r = client.post(f"{session_url}/nodes", headers=headers, json={"node_ids": [node_id]})
        assert r.status_code == 200, f"Failed to attach node {node_id}: {r.text}"

        r = client.delete(f"{session_url}/nodes/{node_id}", headers=headers)
        assert r.status_code == 200, f"Failed to detach node {node_id}: {r.text}"

        # 7. Audio on messages
        # One chat turn gives the session a message to attach audio to. Only the
        # first streamed line is consumed. The call stays best-effort, but its
        # status and any error are recorded so that the "no messages" assertion
        # below can explain a failure instead of failing silently.
        chat_payload = {
            "prompt": "Hello",
            "provider_name": "gemini",
            "load_faiss_retriever": False
        }
        chat_status = None
        chat_error = None
        try:
            with client.stream("POST", f"{session_url}/chat", headers=headers, json=chat_payload) as r_chat:
                chat_status = r_chat.status_code
                next(r_chat.iter_lines(), None)
        except Exception as exc:  # deliberately broad: reported below, not raised
            chat_error = exc

        r = client.get(f"{session_url}/messages", headers=headers)
        assert r.status_code == 200, f"Failed to list messages: {r.text}"
        messages = r.json().get("messages", [])
        assert len(messages) > 0, (
            f"No messages after chat (chat status={chat_status}, chat error={chat_error!r})"
        )
        message_id = messages[0]["id"]

        dummy_audio = b"RIFF....WAVEfmt ....data...."
        files = {'file': ('test.wav', dummy_audio, 'audio/wav')}
        r = client.post(f"{BASE_URL}/sessions/messages/{message_id}/audio", files=files, headers=headers)
        assert r.status_code == 200, f"Failed to upload audio: {r.text}"

        r = client.get(f"{BASE_URL}/sessions/messages/{message_id}/audio", headers=headers)
        assert r.status_code == 200, f"Failed to download audio: {r.text}"
        assert r.content == dummy_audio

        # 8. Delete All Sessions
        r = client.delete(f"{BASE_URL}/sessions/", params={"feature_name": "default"}, headers=headers)
        assert r.status_code == 200, f"Failed to delete sessions: {r.text}"

        # Verify deleted. Check the status and shape first. Otherwise an error
        # body could make `not any(...)` pass vacuously (e.g. an empty body), or
        # fail with an unrelated TypeError (e.g. a dict with string keys).
        r = client.get(f"{BASE_URL}/sessions/", headers=headers)
        assert r.status_code == 200, f"Failed to list sessions: {r.text}"
        sessions = r.json()
        assert isinstance(sessions, list)
        assert not any(s["id"] == session_id for s in sessions)