"""Tests for the chat-session endpoints.

Covers ``POST /sessions``, ``POST /sessions/{id}/chat`` and
``GET /sessions/{id}/messages``. Every test takes the ``client`` fixture
(defined outside this module), which unpacks to a ``(test_client,
mock_services)`` pair; the RAG service on ``mock_services`` is a mock, so
only the HTTP layer's status codes, response bodies and service calls are
checked here.
"""

from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

from app.db import models


def test_create_session_success(client):
    """Tests successfully creating a new chat session."""
    test_client, mock_services = client

    # Stand-in for the ORM object returned by the service layer.
    mock_session = MagicMock(spec=models.Session)
    mock_session.id = 1
    mock_session.user_id = "test_user"
    mock_session.model_name = "gemini"
    mock_session.title = "New Chat"
    mock_session.created_at = datetime.now()
    mock_services.rag_service.create_session.return_value = mock_session

    response = test_client.post(
        "/sessions",
        json={"user_id": "test_user", "model": "gemini"},
    )

    assert response.status_code == 200
    assert response.json()["id"] == 1
    mock_services.rag_service.create_session.assert_called_once()


def test_chat_in_session_success(client):
    """
    Tests sending a message in an existing session without specifying a
    model or retriever.

    The service is expected to be called with model 'deepseek' and
    load_faiss_retriever=False.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response", "deepseek")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Mocked response",
        "model_used": "deepseek",
    }
    # ``db`` is whatever handle the endpoint passes along; ANY matches it
    # without reading it back from the mock, so a missing call fails with a
    # normal assertion message instead of an AttributeError/KeyError.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_model_switch(client):
    """
    Tests sending a message in an existing session and explicitly switching
    the model.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there, Gemini!", "model": "gemini"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Mocked response from Gemini",
        "model_used": "gemini",
    }
    # See test_chat_in_session_success for why ``db`` is matched with ANY.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_faiss_retriever(client):
    """
    Tests sending a message and explicitly enabling the FAISS retriever.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Response with context", "deepseek")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Response with context",
        "model_used": "deepseek",
    }
    # See test_chat_in_session_success for why ``db`` is matched with ANY.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        model="deepseek",
        load_faiss_retriever=True,
    )


def test_get_session_messages_success(client):
    """Tests retrieving the message history for a session."""
    test_client, mock_services = client
    mock_history = [
        MagicMock(
            spec=models.Message,
            sender="user",
            content="Hello",
            created_at=datetime.now(),
        ),
        MagicMock(
            spec=models.Message,
            sender="assistant",
            content="Hi there!",
            created_at=datetime.now(),
        ),
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    # ``db`` is matched with ANY rather than read back from the mock, so a
    # missing call surfaces as an assertion failure.
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY,
        session_id=123,
    )


def test_get_session_messages_not_found(client):
    """Tests retrieving messages for a session that does not exist."""
    test_client, mock_services = client
    # In this test a ``None`` history stands for a missing session.
    mock_services.rag_service.get_message_history.return_value = None

    response = test_client.get("/sessions/999/messages")

    assert response.status_code == 404
    assert response.json()["detail"] == "Session with ID 999 not found."