from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

from app.db import models
def test_create_session_success(client):
    """Verify that POST /sessions returns the newly created chat session.

    The session service is stubbed to return a mock ``models.Session``; the
    test checks for a 200 status, the ``id`` in the JSON body, and that the
    service's ``create_session`` was invoked exactly once.
    """
    api, services = client

    # Build the stubbed session in one expression rather than attribute by
    # attribute; the resulting mock exposes the same values.
    stub_session = MagicMock(
        spec=models.Session,
        id=1,
        user_id="test_user",
        provider_name="gemini",
        title="New Chat",
        created_at=datetime.now(),
    )
    services.session_service.create_session.return_value = stub_session

    request_body = {"user_id": "test_user", "provider_name": "gemini"}
    result = api.post("/sessions", json=request_body)

    assert result.status_code == 200
    body = result.json()
    assert body["id"] == 1
    services.session_service.create_session.assert_called_once()
def test_chat_in_session_success(client):
    """Send a chat message without specifying a provider or retriever flag.

    The request body carries only ``prompt``; the test expects the RAG
    service to be called with ``provider_name="deepseek"`` and
    ``load_faiss_retriever=False``, and the stubbed answer to be returned
    in the JSON response.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response", "deepseek")
    )

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Mocked response",
        "provider_used": "deepseek",
    }
    # ``db`` is supplied by the application, so match it with ANY. Reading it
    # back from ``call_args`` would compare the value with itself and would
    # raise AttributeError (call_args is None) if the service was never called,
    # hiding the real assertion failure.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        provider_name="deepseek",
        load_faiss_retriever=False,
    )
def test_chat_in_session_with_model_switch(client):
    """Send a chat message that explicitly selects the ``gemini`` provider.

    The test expects the requested ``provider_name`` to be forwarded to the
    RAG service unchanged, with ``load_faiss_retriever`` left at ``False``,
    and the stubbed answer to be returned in the JSON response.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there, Gemini!", "provider_name": "gemini"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Mocked response from Gemini",
        "provider_used": "gemini",
    }
    # ``db`` is supplied by the application, so match it with ANY instead of
    # reading it back from ``call_args`` (a self-comparison that also raises
    # AttributeError, not a clear assertion error, when no call was made).
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        provider_name="gemini",
        load_faiss_retriever=False,
    )
def test_chat_in_session_with_faiss_retriever(client):
    """Send a chat message with the FAISS retriever explicitly enabled.

    The test expects ``load_faiss_retriever=True`` to reach the RAG service
    while the provider stays at ``"deepseek"`` (none is given in the
    request), and the stubbed answer to be returned in the JSON response.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Response with context", "deepseek")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Response with context",
        "provider_used": "deepseek",
    }
    # ``db`` is supplied by the application, so match it with ANY instead of
    # reading it back from ``call_args`` (a self-comparison that also raises
    # AttributeError, not a clear assertion error, when no call was made).
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        provider_name="deepseek",
        load_faiss_retriever=True,
    )
def test_get_session_messages_success(client):
    """Retrieve the message history for a session.

    The RAG service is stubbed to return two mock ``models.Message``
    objects; the test checks the 200 status, the ``session_id`` in the
    body, the number and order of messages, and the arguments passed to
    ``get_message_history``.
    """
    test_client, mock_services = client
    mock_history = [
        MagicMock(
            spec=models.Message,
            sender="user",
            content="Hello",
            created_at=datetime.now(),
        ),
        MagicMock(
            spec=models.Message,
            sender="assistant",
            content="Hi there!",
            created_at=datetime.now(),
        ),
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    # ``db`` is supplied by the application, so match it with ANY instead of
    # reading it back from ``call_args`` (a self-comparison that also raises
    # AttributeError, not a clear assertion error, when no call was made).
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY,
        session_id=123,
    )
def test_get_session_messages_not_found(client):
    """Request the history of a session the service reports as missing.

    ``get_message_history`` is stubbed to return ``None``; the test expects
    a 404 status and a ``detail`` message naming the requested session ID.
    """
    api, services = client
    history_lookup = services.rag_service.get_message_history
    history_lookup.return_value = None

    result = api.get("/sessions/999/messages")

    assert result.status_code == 404
    error_body = result.json()
    assert error_body["detail"] == "Session with ID 999 not found."