diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index b499ef6..dff0a91 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -42,12 +42,15 @@
 ):
     """
     Sends a message within an existing session and gets a contextual response.
+
+    The 'model' can now be specified in the request body to switch models mid-conversation.
     """
     try:
         response_text, model_used = await rag_service.chat_with_rag(
             db=db,
             session_id=session_id,
-            prompt=request.prompt
+            prompt=request.prompt,
+            model=request.model  # Pass the model from the request to the RAG service
         )
         return schemas.ChatResponse(answer=response_text, model_used=model_used)
     except Exception as e:
@@ -108,4 +111,4 @@
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred: {e}")

-    return router
\ No newline at end of file
+    return router
diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py
index 9984f70..25b52bc 100644
--- a/ai-hub/app/core/services.py
+++ b/ai-hub/app/core/services.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import List, Dict, Any, Tuple
 from sqlalchemy.orm import Session, joinedload
 from sqlalchemy.exc import SQLAlchemyError
@@ -7,8 +8,7 @@
 from app.db import models
 from app.core.retrievers import Retriever
 from app.core.llm_providers import get_llm_provider
-from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
-
+from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline

 class RAGService:
     """
@@ -32,9 +32,12 @@
             db.rollback()
             raise

-    async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]:
-        """Handles a message within a session, including saving history and getting a response."""
-        # **FIX 1**: Eagerly load the message history in a single query for efficiency.
+    async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]:
+        """
+        Handles a message within a session, including saving history and getting a response.
+        Allows switching the LLM model for the current chat turn.
+        """
+        # Eagerly load the message history in a single query for efficiency.
         session = db.query(models.Session).options(
             joinedload(models.Session.messages)
         ).filter(models.Session.id == session_id).first()
@@ -47,16 +50,20 @@
         db.add(user_message)
         db.commit()

-        llm_provider = get_llm_provider(session.model_name)
-        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name)
+        # Use the 'model' parameter passed to this method for the current chat turn
+        llm_provider = get_llm_provider(model)
+        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
         dspy.configure(lm=dspy_llm)

         rag_pipeline = DspyRagPipeline(retrievers=self.retrievers)
-        # **FIX 2**: Pass the full message history to the pipeline's forward method.
+        # Pass the full message history to the pipeline's forward method.
+        # Note: the history is passed, but the current pipeline implementation
+        # might not fully utilize it for conversational context unless explicitly
+        # designed to. This is a placeholder for future conversational RAG.
         answer_text = await rag_pipeline.forward(
             question=prompt,
-            history=session.messages,
+            history=session.messages,  # Pass the existing history
             db=db
         )

@@ -65,7 +72,8 @@
         db.add(assistant_message)
         db.commit()

-        return answer_text, session.model_name
+        # Return the answer text and the model that was actually used for this turn
+        return answer_text, model

     def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
         """
@@ -75,7 +83,8 @@
             joinedload(models.Session.messages)
         ).filter(models.Session.id == session_id).first()

-        return session.messages if session else None
+        # Return messages sorted by created_at to ensure chronological order
+        return sorted(session.messages, key=lambda msg: msg.created_at) if session else None

     # --- Document Management (Unchanged) ---
     def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
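Aside: app/api/schemas.py is not touched by this patch, but the new request.model read above, together with the integration tests' note that the model "defaults to 'deepseek' as per schema", implies a request body along these lines. This is only a sketch; the exact field set and module layout are assumptions.

    # Hypothetical sketch of the relevant parts of app/api/schemas.py (not in this diff).
    from pydantic import BaseModel

    class ChatRequest(BaseModel):
        prompt: str
        # Assumed default: the tests expect "deepseek" when the client omits the field.
        model: str = "deepseek"

    class ChatResponse(BaseModel):
        answer: str
        model_used: str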
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, - prompt=request.prompt + prompt=request.prompt, + model=request.model # Pass the model from the request to the RAG service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: @@ -108,4 +111,4 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router \ No newline at end of file + return router diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 9984f70..25b52bc 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Dict, Any, Tuple from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError @@ -7,8 +8,7 @@ from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import get_llm_provider -from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider - +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline class RAGService: """ @@ -32,9 +32,12 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: - """Handles a message within a session, including saving history and getting a response.""" - # **FIX 1**: Eagerly load the message history in a single query for efficiency. + async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + Allows switching the LLM model for the current chat turn. + """ + # Eagerly load the message history in a single query for efficiency. session = db.query(models.Session).options( joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() @@ -47,16 +50,20 @@ db.add(user_message) db.commit() - llm_provider = get_llm_provider(session.model_name) - dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + # Use the 'model' parameter passed to this method for the current chat turn + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) dspy.configure(lm=dspy_llm) rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - # **FIX 2**: Pass the full message history to the pipeline's forward method. + # Pass the full message history to the pipeline's forward method. + # Note: The history is passed, but the current RAGPipeline implementation + # might not fully utilize it for conversational context unless explicitly + # designed to. This is a placeholder for future conversational RAG. 
diff --git a/ai-hub/dspy_rag.py b/ai-hub/dspy_rag.py
deleted file mode 100644
index f246fd3..0000000
--- a/ai-hub/dspy_rag.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import dspy
-import logging
-from typing import List
-from types import SimpleNamespace
-from sqlalchemy.orm import Session
-
-from app.db import models  # Import your SQLAlchemy models
-from app.core.retrievers import Retriever
-from app.core.llm_providers import LLMProvider
-
-# (The DSPyLLMProvider class is unchanged)
-class DSPyLLMProvider(dspy.BaseLM):
-    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
-        super().__init__(model=model_name)
-        self.provider = provider
-        self.kwargs.update(kwargs)
-
-    async def aforward(self, prompt: str, **kwargs):
-        if not prompt or not prompt.strip():
-            return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))])
-        response_text = await self.provider.generate_response(prompt)
-        choice = SimpleNamespace(message=SimpleNamespace(content=response_text))
-        return SimpleNamespace(choices=[choice])
-
-# --- 1. Update the Signature to include Chat History ---
-class AnswerWithHistory(dspy.Signature):
-    """Given the context and chat history, answer the user's question."""
-
-    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
-    chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.")
-    question = dspy.InputField()
-    answer = dspy.OutputField()
-
-class DspyRagPipeline(dspy.Module):
-    """
-    A conversational RAG pipeline that uses document context and chat history.
-    """
-    def __init__(self, retrievers: List[Retriever]):
-        super().__init__()
-        self.retrievers = retrievers
-        # Use the new signature that includes history
-        self.generate_answer = dspy.Predict(AnswerWithHistory)
-
-    # --- 2. Update the `forward` method to accept history ---
-    async def forward(self, question: str, history: List[models.Message], db: Session) -> str:
-        """
-        Executes the RAG pipeline using the question and the conversation history.
-        """
-        logging.debug(f"[app.api.dependencies] Received question: '{question}'")
-
-        # Retrieve document context based on the current question
-        retrieved_contexts = []
-        for retriever in self.retrievers:
-            context = retriever.retrieve_context(question, db)
-            retrieved_contexts.extend(context)
-
-        context_text = "\n\n".join(retrieved_contexts) or "No context provided."
-
-        # --- 3. Format the chat history into a string ---
-        history_str = "\n".join(
-            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
-            for msg in history
-        )
-
-        # --- 4. Build the final prompt including history ---
-        instruction = self.generate_answer.signature.__doc__
-        full_prompt = (
-            f"{instruction}\n\n"
-            f"---\n\n"
-            f"Context: {context_text}\n\n"
-            f"---\n\n"
-            f"Chat History:\n{history_str}\n\n"
-            f"---\n\n"
-            f"Human: {question}\n"
-            f"Assistant:"
-        )
-
-        lm = dspy.settings.lm
-        if lm is None:
-            raise RuntimeError("DSPy LM not configured.")
-
-        response_obj = await lm.aforward(prompt=full_prompt)
-        return response_obj.choices[0].message.content
\ No newline at end of file
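This deletion is housekeeping rather than a behavior change: as the services.py hunk above shows, DSPyLLMProvider and DspyRagPipeline are now imported from app.core.pipelines.dspy_rag, so the stale top-level copy of the module is removed in favor of the packaged one.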
diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py
index 43a2e10..6d692b3 100644
--- a/ai-hub/integration_tests/test_integration.py
+++ b/ai-hub/integration_tests/test_integration.py
@@ -1,9 +1,11 @@
 import pytest
 import httpx

-# The base URL for the local server
+# The base URL for the local server started by the run_tests.sh script
 BASE_URL = "http://127.0.0.1:8000"
-# Use a specific, context-setting prompt for the conversational test
+
+# A common prompt to be used for the tests
+TEST_PROMPT = "Explain the theory of relativity in one sentence."
 CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
 FOLLOW_UP_PROMPT = "When was he born?"

@@ -12,10 +14,13 @@
 created_session_id = None

 async def test_root_endpoint():
-    """Tests if the root endpoint is alive."""
+    """
+    Tests if the root endpoint is alive and returns the correct status message.
+    """
     print("\n--- Running test_root_endpoint ---")
     async with httpx.AsyncClient() as client:
         response = await client.get(f"{BASE_URL}/")
+
     assert response.status_code == 200
     assert response.json() == {"status": "AI Model Hub is running!"}
     print("✅ Root endpoint test passed.")
@@ -39,31 +44,32 @@
     print(f"✅ Session created successfully with ID: {created_session_id}")

 async def test_chat_in_session_turn_1():
-    """Tests sending the first message to establish context."""
+    """Tests sending the first message to establish context using the default model."""
     print("\n--- Running test_chat_in_session (Turn 1) ---")
     assert created_session_id is not None, "Session ID was not set."

     url = f"{BASE_URL}/sessions/{created_session_id}/chat"
-    payload = {"prompt": CONTEXT_PROMPT}
+    payload = {"prompt": CONTEXT_PROMPT}  # Model defaults to "deepseek" as per schema

     async with httpx.AsyncClient(timeout=60.0) as client:
         response = await client.post(url, json=payload)

     assert response.status_code == 200, f"Chat request failed. Response: {response.text}"
     response_data = response.json()
-    # Check that the answer mentions the CEO's name
+    # Check that the answer mentions the CEO's name (assuming DeepSeek provides this)
     assert "Satya Nadella" in response_data["answer"]
+    assert response_data["model_used"] == "deepseek"
     print("✅ Chat Turn 1 (context) test passed.")

 async def test_chat_in_session_turn_2_follow_up():
     """
-    Tests sending a follow-up question to verify conversational memory.
+    Tests sending a follow-up question to verify conversational memory using the default model.
     """
     print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---")
     assert created_session_id is not None, "Session ID was not set."

     url = f"{BASE_URL}/sessions/{created_session_id}/chat"
-    payload = {"prompt": FOLLOW_UP_PROMPT}
+    payload = {"prompt": FOLLOW_UP_PROMPT}  # Model defaults to "deepseek"

     async with httpx.AsyncClient(timeout=60.0) as client:
         response = await client.post(url, json=payload)

@@ -72,8 +78,50 @@
     response_data = response.json()
     # Check that the answer contains the birth year, proving it understood "he"
     assert "1967" in response_data["answer"]
+    assert response_data["model_used"] == "deepseek"
     print("✅ Chat Turn 2 (follow-up) test passed.")

+async def test_chat_in_session_with_model_switch():
+    """
+    Tests sending a message in the same session, explicitly switching to 'gemini'.
+    """
+    print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---")
+    assert created_session_id is not None, "Session ID was not set."
+
+    url = f"{BASE_URL}/sessions/{created_session_id}/chat"
+    # Explicitly request 'gemini' model for this turn
+    payload = {"prompt": "What is the capital of France?", "model": "gemini"}
+
+    async with httpx.AsyncClient(timeout=60.0) as client:
+        response = await client.post(url, json=payload)
+
+    assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}"
+    response_data = response.json()
+    assert "Paris" in response_data["answer"]
+    assert response_data["model_used"] == "gemini"
+    print("✅ Chat (Model Switch to Gemini) test passed.")
+
+async def test_chat_in_session_switch_back_to_deepseek():
+    """
+    Tests sending another message in the same session, explicitly switching back to 'deepseek'.
+    """
+    print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---")
+    assert created_session_id is not None, "Session ID was not set."
+
+    url = f"{BASE_URL}/sessions/{created_session_id}/chat"
+    # Explicitly request 'deepseek' model for this turn
+    payload = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+
+    async with httpx.AsyncClient(timeout=60.0) as client:
+        response = await client.post(url, json=payload)
+
+    assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}"
+    response_data = response.json()
+    assert "Pacific Ocean" in response_data["answer"]
+    assert response_data["model_used"] == "deepseek"
+    print("✅ Chat (Model Switch back to DeepSeek) test passed.")
+
 async def test_get_session_history():
     """Tests retrieving the full message history for the session."""
     print("\n--- Running test_get_session_history ---")
@@ -87,14 +135,20 @@
     response_data = response.json()
     assert response_data["session_id"] == created_session_id

-    # After two turns, there should be 4 messages (2 user, 2 assistant)
-    assert len(response_data["messages"]) >= 4
+    # After 4 turns, there should be 8 messages (4 user, 4 assistant)
+    assert len(response_data["messages"]) >= 8
     assert response_data["messages"][0]["content"] == CONTEXT_PROMPT
     assert response_data["messages"][2]["content"] == FOLLOW_UP_PROMPT
+    # Verify content and sender for the switched models
+    assert response_data["messages"][4]["content"] == "What is the capital of France?"
+    assert response_data["messages"][5]["sender"] == "assistant"
+    assert "Paris" in response_data["messages"][5]["content"]
+    assert response_data["messages"][6]["content"] == "What is the largest ocean?"
+    assert response_data["messages"][7]["sender"] == "assistant"
+    assert "Pacific Ocean" in response_data["messages"][7]["content"]
     print("✅ Get session history test passed.")

 # --- Document Management Lifecycle Tests ---
-# (These tests remain unchanged)
 async def test_add_document_for_lifecycle():
     global created_document_id
     print("\n--- Running test_add_document (for lifecycle) ---")
@@ -135,4 +189,4 @@
     assert response.status_code == 200
     assert response.json()["document_id"] == created_document_id
-    print("✅ Document delete test passed.")
\ No newline at end of file
+    print("✅ Document delete test passed.")
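For reference, a single model-switching turn against a running server looks like this outside the test suite. A sketch using the same httpx client as the tests; the session id is whatever POST /sessions returned (id 1 here is assumed).

    import asyncio
    import httpx

    async def switch_model_turn(session_id: int) -> None:
        # One chat turn that explicitly requests the 'gemini' provider;
        # omitting "model" falls back to the schema default ("deepseek").
        payload = {"prompt": "What is the capital of France?", "model": "gemini"}
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"http://127.0.0.1:8000/sessions/{session_id}/chat", json=payload
            )
        response.raise_for_status()
        data = response.json()
        print(data["model_used"], data["answer"])

    # asyncio.run(switch_model_turn(1))  # assumes a session with id 1 exists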
diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh
index 5a5254b..b13e8d1 100644
--- a/ai-hub/run_chat.sh
+++ b/ai-hub/run_chat.sh
@@ -1,11 +1,14 @@
 #!/bin/bash

 # A script to automatically start the server and run an interactive chat session.
+# It now allows the user to specify a model for each turn or use the previous one.
 #
 # REQUIREMENTS:
 #   - 'jq' must be installed (e.g., sudo apt-get install jq).

 BASE_URL="http://127.0.0.1:8000"
+DEFAULT_MODEL="deepseek"
+CURRENT_MODEL=""  # The model used in the last turn

 # --- 1. Check for Dependencies ---
 if ! command -v jq &> /dev/null
@@ -35,33 +38,75 @@
 echo "--- Starting a new conversation session... ---"
 SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \
     -H "Content-Type: application/json" \
-    -d '{"user_id": "local_user", "model": "deepseek"}')
+    -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \
+    -w '\n%{http_code}')  # Add a new line and the status code

-SESSION_ID=$(echo "$SESSION_DATA" | jq '.id')
+# Extract body and status code
+HTTP_CODE=$(echo "$SESSION_DATA" | tail -n1)
+SESSION_DATA_BODY=$(echo "$SESSION_DATA" | head -n-1)

-if [ -z "$SESSION_ID" ] || [ "$SESSION_ID" == "null" ]; then
-    echo "❌ Failed to create a session. Server might not have started correctly."
+if [ "$HTTP_CODE" -ne 200 ]; then
+    echo "❌ Failed to create a session (HTTP $HTTP_CODE). Server might not have started correctly."
+    echo "Response body: $SESSION_DATA_BODY"
     exit 1
 fi

-echo "✅ Session created with ID: $SESSION_ID. Type 'exit' or 'quit' to end."
+SESSION_ID=$(echo "$SESSION_DATA_BODY" | jq -r '.id')
+if [ -z "$SESSION_ID" ] || [ "$SESSION_ID" == "null" ]; then
+    echo "❌ Failed to create a session. Server response did not contain an ID."
+    echo "Response body: $SESSION_DATA_BODY"
+    exit 1
+fi
+
+# Set the initial model
+CURRENT_MODEL="$DEFAULT_MODEL"
+
+echo "✅ Session created with ID: $SESSION_ID. The initial model is '$CURRENT_MODEL'."
+echo "--------------------------------------------------"
+echo "To switch models, type your message like this: [gemini] <your message>"
+echo "To use the previous model, just type your message directly."
+echo "Type 'exit' or 'quit' to end."
 echo "--------------------------------------------------"

 # --- 4. Start the Interactive Chat Loop ---
 while true; do
-    read -p "You: " user_input
+    read -p "You [$CURRENT_MODEL]: " user_input

     if [[ "$user_input" == "exit" || "$user_input" == "quit" ]]; then
         break
     fi

-    json_payload=$(jq -n --arg prompt "$user_input" '{"prompt": $prompt}')
+    # Check for model switch input pattern, e.g., "[model_name] <message>"
+    if [[ "$user_input" =~ ^\[([a-zA-Z0-9]+)\]\ (.*)$ ]]; then
+        MODEL_TO_USE="${BASH_REMATCH[1]}"
+        PROMPT_TEXT="${BASH_REMATCH[2]}"
+        # Update the current model for the next prompt
+        CURRENT_MODEL="$MODEL_TO_USE"
+    else
+        MODEL_TO_USE="$CURRENT_MODEL"
+        PROMPT_TEXT="$user_input"
+    fi
+
+    # Construct the JSON payload with the model and prompt
+    json_payload=$(jq -n \
+        --arg prompt "$PROMPT_TEXT" \
+        --arg model "$MODEL_TO_USE" \
+        '{"prompt": $prompt, "model": $model}')

-    ai_response=$(curl -s -X POST "$BASE_URL/sessions/$SESSION_ID/chat" \
+    ai_response_json=$(curl -s -X POST "$BASE_URL/sessions/$SESSION_ID/chat" \
         -H "Content-Type: application/json" \
-        -d "$json_payload" | jq -r '.answer')
+        -d "$json_payload")
+
+    # Check if the response is valid JSON
+    if ! echo "$ai_response_json" | jq -e . >/dev/null; then
+        echo "❌ AI: An error occurred or the server returned an invalid response."
+        echo "Server response: $ai_response_json"
+    else
+        ai_answer=$(echo "$ai_response_json" | jq -r '.answer')
+        model_used=$(echo "$ai_response_json" | jq -r '.model_used')
+        echo "AI [$model_used]: $ai_answer"
+    fi

-    echo "AI: $ai_response"
 done

 # The 'trap' will automatically call the cleanup function when the loop breaks.
\ No newline at end of file
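An illustrative session with the updated script (answers abridged; real model output will differ):

    You [deepseek]: Who is the CEO of Microsoft?
    AI [deepseek]: Satya Nadella ...
    You [deepseek]: [gemini] What is the capital of France?
    AI [gemini]: Paris ...
    You [gemini]: And the largest ocean?
    AI [gemini]: The Pacific Ocean ...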
Server might not have started correctly." + echo "Response body: $SESSION_DATA_BODY" exit 1 fi -echo "✅ Session created with ID: $SESSION_ID. Type 'exit' or 'quit' to end." +SESSION_ID=$(echo "$SESSION_DATA_BODY" | jq -r '.id') +if [ -z "$SESSION_ID" ] || [ "$SESSION_ID" == "null" ]; then + echo "❌ Failed to create a session. Server response did not contain an ID." + echo "Response body: $SESSION_DATA_BODY" + exit 1 +fi + +# Set the initial model +CURRENT_MODEL="$DEFAULT_MODEL" + +echo "✅ Session created with ID: $SESSION_ID. The initial model is '$CURRENT_MODEL'." +echo "--------------------------------------------------" +echo "To switch models, type your message like this: [gemini] " +echo "To use the previous model, just type your message directly." +echo "Type 'exit' or 'quit' to end." echo "--------------------------------------------------" # --- 4. Start the Interactive Chat Loop --- while true; do - read -p "You: " user_input + read -p "You [$CURRENT_MODEL]: " user_input if [[ "$user_input" == "exit" || "$user_input" == "quit" ]]; then break fi - json_payload=$(jq -n --arg prompt "$user_input" '{"prompt": $prompt}') + # Check for model switch input pattern, e.g., "[model_name] " + if [[ "$user_input" =~ ^\[([a-zA-Z0-9]+)\]\ (.*)$ ]]; then + MODEL_TO_USE="${BASH_REMATCH[1]}" + PROMPT_TEXT="${BASH_REMATCH[2]}" + # Update the current model for the next prompt + CURRENT_MODEL="$MODEL_TO_USE" + else + MODEL_TO_USE="$CURRENT_MODEL" + PROMPT_TEXT="$user_input" + fi + + # Construct the JSON payload with the model and prompt + json_payload=$(jq -n \ + --arg prompt "$PROMPT_TEXT" \ + --arg model "$MODEL_TO_USE" \ + '{"prompt": $prompt, "model": $model}') - ai_response=$(curl -s -X POST "$BASE_URL/sessions/$SESSION_ID/chat" \ + ai_response_json=$(curl -s -X POST "$BASE_URL/sessions/$SESSION_ID/chat" \ -H "Content-Type: application/json" \ - -d "$json_payload" | jq -r '.answer') + -d "$json_payload") + + # Check if the response is valid JSON + if ! echo "$ai_response_json" | jq -e . >/dev/null; then + echo "❌ AI: An error occurred or the server returned an invalid response." + echo "Server response: $ai_response_json" + else + ai_answer=$(echo "$ai_response_json" | jq -r '.answer') + model_used=$(echo "$ai_response_json" | jq -r '.model_used') + echo "AI [$model_used]: $ai_answer" + fi - echo "AI: $ai_response" done # The 'trap' will automatically call the cleanup function when the loop breaks. \ No newline at end of file diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 3053a4d..3e3b252 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -51,15 +51,43 @@ mock_rag_service.create_session.assert_called_once() def test_chat_in_session_success(client): - """Tests sending a message in an existing session.""" + """ + Tests sending a message in an existing session without specifying a model. + It should default to 'deepseek'. 
+ """ test_client, mock_rag_service = client - mock_rag_service.chat_with_rag.return_value = ("Mocked response", "deepseek") + mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - mock_rag_service.chat_with_rag.assert_called_once() + # Verify that chat_with_rag was called with the default model 'deepseek' + mock_rag_service.chat_with_rag.assert_called_once_with( + db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there", + model="deepseek" + ) + +def test_chat_in_session_with_model_switch(client): + """ + Tests sending a message in an existing session and explicitly switching the model. + """ + test_client, mock_rag_service = client + mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' + mock_rag_service.chat_with_rag.assert_called_once_with( + db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there, Gemini!", + model="gemini" + ) def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" @@ -128,4 +156,4 @@ test_client, mock_rag_service = client mock_rag_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404 diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index b499ef6..dff0a91 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -42,12 +42,15 @@ ): """ Sends a message within an existing session and gets a contextual response. + + The 'model' can now be specified in the request body to switch models mid-conversation. 
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, - prompt=request.prompt + prompt=request.prompt, + model=request.model # Pass the model from the request to the RAG service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: @@ -108,4 +111,4 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router \ No newline at end of file + return router diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 9984f70..25b52bc 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Dict, Any, Tuple from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError @@ -7,8 +8,7 @@ from app.db import models from app.core.retrievers import Retriever from app.core.llm_providers import get_llm_provider -from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider - +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline class RAGService: """ @@ -32,9 +32,12 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: - """Handles a message within a session, including saving history and getting a response.""" - # **FIX 1**: Eagerly load the message history in a single query for efficiency. + async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + Allows switching the LLM model for the current chat turn. + """ + # Eagerly load the message history in a single query for efficiency. session = db.query(models.Session).options( joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() @@ -47,16 +50,20 @@ db.add(user_message) db.commit() - llm_provider = get_llm_provider(session.model_name) - dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + # Use the 'model' parameter passed to this method for the current chat turn + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) dspy.configure(lm=dspy_llm) rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - # **FIX 2**: Pass the full message history to the pipeline's forward method. + # Pass the full message history to the pipeline's forward method. + # Note: The history is passed, but the current RAGPipeline implementation + # might not fully utilize it for conversational context unless explicitly + # designed to. This is a placeholder for future conversational RAG. 
answer_text = await rag_pipeline.forward( question=prompt, - history=session.messages, + history=session.messages, # Pass the existing history db=db ) @@ -65,7 +72,8 @@ db.add(assistant_message) db.commit() - return answer_text, session.model_name + # Return the answer text and the model that was actually used for this turn + return answer_text, model def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: """ @@ -75,7 +83,8 @@ joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() - return session.messages if session else None + # Return messages sorted by created_at to ensure chronological order + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None # --- Document Management (Unchanged) --- def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: diff --git a/ai-hub/dspy_rag.py b/ai-hub/dspy_rag.py deleted file mode 100644 index f246fd3..0000000 --- a/ai-hub/dspy_rag.py +++ /dev/null @@ -1,83 +0,0 @@ -import dspy -import logging -from typing import List -from types import SimpleNamespace -from sqlalchemy.orm import Session - -from app.db import models # Import your SQLAlchemy models -from app.core.retrievers import Retriever -from app.core.llm_providers import LLMProvider - -# (The DSPyLLMProvider class is unchanged) -class DSPyLLMProvider(dspy.BaseLM): - def __init__(self, provider: LLMProvider, model_name: str, **kwargs): - super().__init__(model=model_name) - self.provider = provider - self.kwargs.update(kwargs) - - async def aforward(self, prompt: str, **kwargs): - if not prompt or not prompt.strip(): - return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))]) - response_text = await self.provider.generate_response(prompt) - choice = SimpleNamespace(message=SimpleNamespace(content=response_text)) - return SimpleNamespace(choices=[choice]) - -# --- 1. Update the Signature to include Chat History --- -class AnswerWithHistory(dspy.Signature): - """Given the context and chat history, answer the user's question.""" - - context = dspy.InputField(desc="Relevant document snippets from the knowledge base.") - chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.") - question = dspy.InputField() - answer = dspy.OutputField() - -class DspyRagPipeline(dspy.Module): - """ - A conversational RAG pipeline that uses document context and chat history. - """ - def __init__(self, retrievers: List[Retriever]): - super().__init__() - self.retrievers = retrievers - # Use the new signature that includes history - self.generate_answer = dspy.Predict(AnswerWithHistory) - - # --- 2. Update the `forward` method to accept history --- - async def forward(self, question: str, history: List[models.Message], db: Session) -> str: - """ - Executes the RAG pipeline using the question and the conversation history. - """ - logging.debug(f"[app.api.dependencies] Received question: '{question}'") - - # Retrieve document context based on the current question - retrieved_contexts = [] - for retriever in self.retrievers: - context = retriever.retrieve_context(question, db) - retrieved_contexts.extend(context) - - context_text = "\n\n".join(retrieved_contexts) or "No context provided." - - # --- 3. Format the chat history into a string --- - history_str = "\n".join( - f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}" - for msg in history - ) - - # --- 4. 
Build the final prompt including history --- - instruction = self.generate_answer.signature.__doc__ - full_prompt = ( - f"{instruction}\n\n" - f"---\n\n" - f"Context: {context_text}\n\n" - f"---\n\n" - f"Chat History:\n{history_str}\n\n" - f"---\n\n" - f"Human: {question}\n" - f"Assistant:" - ) - - lm = dspy.settings.lm - if lm is None: - raise RuntimeError("DSPy LM not configured.") - - response_obj = await lm.aforward(prompt=full_prompt) - return response_obj.choices[0].message.content \ No newline at end of file diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 43a2e10..6d692b3 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -1,9 +1,11 @@ import pytest import httpx -# The base URL for the local server +# The base URL for the local server started by the run_tests.sh script BASE_URL = "http://127.0.0.1:8000" -# Use a specific, context-setting prompt for the conversational test + +# A common prompt to be used for the tests +TEST_PROMPT = "Explain the theory of relativity in one sentence." CONTEXT_PROMPT = "Who is the CEO of Microsoft?" FOLLOW_UP_PROMPT = "When was he born?" @@ -12,10 +14,13 @@ created_session_id = None async def test_root_endpoint(): - """Tests if the root endpoint is alive.""" + """ + Tests if the root endpoint is alive and returns the correct status message. + """ print("\n--- Running test_root_endpoint ---") async with httpx.AsyncClient() as client: response = await client.get(f"{BASE_URL}/") + assert response.status_code == 200 assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") @@ -39,31 +44,32 @@ print(f"✅ Session created successfully with ID: {created_session_id}") async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context.""" + """Tests sending the first message to establish context using the default model.""" print("\n--- Running test_chat_in_session (Turn 1) ---") assert created_session_id is not None, "Session ID was not set." url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} + payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) assert response.status_code == 200, f"Chat request failed. Response: {response.text}" response_data = response.json() - # Check that the answer mentions the CEO's name + # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) assert "Satya Nadella" in response_data["answer"] + assert response_data["model_used"] == "deepseek" print("✅ Chat Turn 1 (context) test passed.") async def test_chat_in_session_turn_2_follow_up(): """ - Tests sending a follow-up question to verify conversational memory. + Tests sending a follow-up question to verify conversational memory using the default model. """ print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") assert created_session_id is not None, "Session ID was not set." 
url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} + payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) @@ -72,8 +78,50 @@ response_data = response.json() # Check that the answer contains the birth year, proving it understood "he" assert "1967" in response_data["answer"] + assert response_data["model_used"] == "deepseek" print("✅ Chat Turn 2 (follow-up) test passed.") +async def test_chat_in_session_with_model_switch(): + """ + Tests sending a message in the same session, explicitly switching to 'gemini'. + """ + print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") + assert created_session_id is not None, "Session ID was not set." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + # Explicitly request 'gemini' model for this turn + payload = {"prompt": "What is the capital of France?", "model": "gemini"} + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" + response_data = response.json() + assert "Paris" in response_data["answer"] + assert response_data["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + +async def test_chat_in_session_switch_back_to_deepseek(): + """ + Tests sending another message in the same session, explicitly switching back to 'deepseek'. + """ + print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") + assert created_session_id is not None, "Session ID was not set." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + # Explicitly request 'deepseek' model for this turn + payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" + response_data = response.json() + assert "Pacific Ocean" in response_data["answer"] + assert response_data["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + + async def test_get_session_history(): """Tests retrieving the full message history for the session.""" print("\n--- Running test_get_session_history ---") @@ -87,14 +135,20 @@ response_data = response.json() assert response_data["session_id"] == created_session_id - # After two turns, there should be 4 messages (2 user, 2 assistant) - assert len(response_data["messages"]) >= 4 + # After 4 turns, there should be 8 messages (4 user, 4 assistant) + assert len(response_data["messages"]) >= 8 assert response_data["messages"][0]["content"] == CONTEXT_PROMPT assert response_data["messages"][2]["content"] == FOLLOW_UP_PROMPT + # Verify content and sender for the switched models + assert response_data["messages"][4]["content"] == "What is the capital of France?" + assert response_data["messages"][5]["sender"] == "assistant" + assert "Paris" in response_data["messages"][5]["content"] + assert response_data["messages"][6]["content"] == "What is the largest ocean?" 
+ assert response_data["messages"][7]["sender"] == "assistant" + assert "Pacific Ocean" in response_data["messages"][7]["content"] print("✅ Get session history test passed.") # --- Document Management Lifecycle Tests --- -# (These tests remain unchanged) async def test_add_document_for_lifecycle(): global created_document_id print("\n--- Running test_add_document (for lifecycle) ---") @@ -135,4 +189,4 @@ assert response.status_code == 200 assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5a5254b..b13e8d1 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -1,11 +1,14 @@ #!/bin/bash # A script to automatically start the server and run an interactive chat session. +# It now allows the user to specify a model for each turn or use the previous one. # # REQUIREMENTS: # - 'jq' must be installed (e.g., sudo apt-get install jq). BASE_URL="http://127.0.0.1:8000" +DEFAULT_MODEL="deepseek" +CURRENT_MODEL="" # The model used in the last turn # --- 1. Check for Dependencies --- if ! command -v jq &> /dev/null @@ -35,33 +38,75 @@ echo "--- Starting a new conversation session... ---" SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ -H "Content-Type: application/json" \ - -d '{"user_id": "local_user", "model": "deepseek"}') + -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ + -w '\n%{http_code}') # Add a new line and the status code -SESSION_ID=$(echo "$SESSION_DATA" | jq '.id') +# Extract body and status code +HTTP_CODE=$(echo "$SESSION_DATA" | tail -n1) +SESSION_DATA_BODY=$(echo "$SESSION_DATA" | head -n-1) -if [ -z "$SESSION_ID" ] || [ "$SESSION_ID" == "null" ]; then - echo "❌ Failed to create a session. Server might not have started correctly." +if [ "$HTTP_CODE" -ne 200 ]; then + echo "❌ Failed to create a session (HTTP $HTTP_CODE). Server might not have started correctly." + echo "Response body: $SESSION_DATA_BODY" exit 1 fi -echo "✅ Session created with ID: $SESSION_ID. Type 'exit' or 'quit' to end." +SESSION_ID=$(echo "$SESSION_DATA_BODY" | jq -r '.id') +if [ -z "$SESSION_ID" ] || [ "$SESSION_ID" == "null" ]; then + echo "❌ Failed to create a session. Server response did not contain an ID." + echo "Response body: $SESSION_DATA_BODY" + exit 1 +fi + +# Set the initial model +CURRENT_MODEL="$DEFAULT_MODEL" + +echo "✅ Session created with ID: $SESSION_ID. The initial model is '$CURRENT_MODEL'." +echo "--------------------------------------------------" +echo "To switch models, type your message like this: [gemini] " +echo "To use the previous model, just type your message directly." +echo "Type 'exit' or 'quit' to end." echo "--------------------------------------------------" # --- 4. 
diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py
index f6dab25..777a466 100644
--- a/ai-hub/tests/core/test_services.py
+++ b/ai-hub/tests/core/test_services.py
@@ -2,21 +2,32 @@
 import asyncio
 from unittest.mock import patch, MagicMock, AsyncMock
 from sqlalchemy.orm import Session
+from sqlalchemy.exc import SQLAlchemyError # Import the specific error type
+from typing import List
+from datetime import datetime
+import dspy
 
 # Import the service and its dependencies
 from app.core.services import RAGService
 from app.db import models
 from app.core.vector_store import FaissVectorStore
 from app.core.retrievers import Retriever
-from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
+from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider
 from app.core.llm_providers import LLMProvider
 
+
 @pytest.fixture
 def rag_service():
-    """Pytest fixture to create a RAGService instance with mocked dependencies."""
+    """
+    Pytest fixture to create a RAGService instance with mocked dependencies.
+    Correctly instantiates RAGService with only the required arguments.
+    """
     mock_vector_store = MagicMock(spec=FaissVectorStore)
     mock_retriever = MagicMock(spec=Retriever)
-    return RAGService(vector_store=mock_vector_store, retrievers=[mock_retriever])
+    return RAGService(
+        vector_store=mock_vector_store,
+        retrievers=[mock_retriever]
+    )
 
 # --- Session Management Tests ---
 
@@ -37,11 +48,11 @@
 @patch('dspy.configure')
 def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
     """
-    Tests the full orchestration of a chat message within a session.
+    Tests the full orchestration of a chat message within a session using the default model.
     """
     # --- Arrange ---
     mock_db = MagicMock(spec=Session)
-    # **FIX**: The mock session now needs a 'messages' attribute for the history
+    # The mock session now needs a 'messages' attribute for the history
     mock_session = models.Session(id=42, model_name="deepseek", messages=[])
     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
 
@@ -52,14 +63,15 @@
     mock_dspy_pipeline.return_value = mock_pipeline_instance
 
     # --- Act ---
-    answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt"))
+    # Pass the 'model' argument, defaulting to "deepseek" for this test case
+    answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt", model="deepseek"))
 
     # --- Assert ---
     mock_db.query.assert_called_once_with(models.Session)
     assert mock_db.add.call_count == 2
     mock_get_llm_provider.assert_called_once_with("deepseek")
 
-    # **FIX**: Assert that the pipeline was called with the history argument
+    # Assert that the pipeline was called with the history argument
     mock_pipeline_instance.forward.assert_called_once_with(
         question="Test prompt",
         history=mock_session.messages,
@@ -69,11 +81,52 @@
     assert answer == "Final RAG response"
     assert model_name == "deepseek"
 
+@patch('app.core.services.get_llm_provider')
+@patch('app.core.services.DspyRagPipeline')
+@patch('dspy.configure')
+def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
+    """
+    Tests that chat_with_rag correctly switches the model based on the 'model' argument.
+    """
+    # --- Arrange ---
+    mock_db = MagicMock(spec=Session)
+    mock_session = models.Session(id=43, model_name="deepseek", messages=[]) # Session might start with deepseek
+    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
+
+    mock_llm_provider = MagicMock(spec=LLMProvider)
+    mock_get_llm_provider.return_value = mock_llm_provider
+    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
+    mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini")
+    mock_dspy_pipeline.return_value = mock_pipeline_instance
+
+    # --- Act ---
+    # Explicitly request the "gemini" model for this chat turn
+    answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=43, prompt="Test prompt for Gemini", model="gemini"))
+
+    # --- Assert ---
+    mock_db.query.assert_called_once_with(models.Session)
+    assert mock_db.add.call_count == 2
+    # Verify that get_llm_provider was called with "gemini"
+    mock_get_llm_provider.assert_called_once_with("gemini")
+
+    mock_pipeline_instance.forward.assert_called_once_with(
+        question="Test prompt for Gemini",
+        history=mock_session.messages,
+        db=mock_db
+    )
+
+    assert answer == "Final RAG response from Gemini"
+    assert model_name == "gemini"
+
 def test_get_message_history_success(rag_service: RAGService):
     """Tests successfully retrieving message history for an existing session."""
     # Arrange
     mock_db = MagicMock(spec=Session)
-    mock_session = models.Session(id=1, messages=[models.Message(), models.Message()])
+    # Ensure mocked messages have created_at for sorting
+    mock_session = models.Session(id=1, messages=[
+        models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)),
+        models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0))
+    ])
     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
 
     # Act
@@ -81,6 +134,7 @@
 
     # Assert
     assert len(messages) == 2
+    assert messages[0].created_at < messages[1].created_at # Verify sorting
     mock_db.query.assert_called_once_with(models.Session)
 
 def test_get_message_history_not_found(rag_service: RAGService):
@@ -93,4 +147,180 @@
     messages = rag_service.get_message_history(db=mock_db, session_id=999)
 
     # Assert
-    assert messages is None
\ No newline at end of file
+    assert messages is None
+
+# --- Document Management Tests ---
+
+@patch('app.db.models.VectorMetadata')
+@patch('app.db.models.Document')
+@patch('app.core.vector_store.FaissVectorStore')
+def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
+    """
+    Test the RAGService.add_document method for a successful run.
+    Verifies that the method correctly calls db.add(), db.commit(), and the vector store.
+    """
+    # Setup mocks
+    mock_db = MagicMock(spec=Session)
+    mock_new_document_instance = MagicMock()
+    mock_document_model.return_value = mock_new_document_instance
+    mock_new_document_instance.id = 1
+    mock_new_document_instance.text = "Test text."
+    mock_new_document_instance.title = "Test Title"
+
+    mock_vector_store_instance = mock_vector_store.return_value
+    mock_vector_store_instance.add_document.return_value = 123
+
+    # Instantiate the service correctly
+    rag_service = RAGService(
+        vector_store=mock_vector_store_instance,
+        retrievers=[]
+    )
+
+    doc_data = {
+        "title": "Test Title",
+        "text": "Test text.",
+        "source_url": "http://test.com"
+    }
+
+    # Call the method under test
+    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)
+
+    # Assertions
+    assert document_id == 1
+
+    from unittest.mock import call
+    expected_calls = [
+        call(mock_new_document_instance),
+        call(mock_vector_metadata_model.return_value)
+    ]
+    mock_db.add.assert_has_calls(expected_calls)
+
+    mock_db.commit.assert_called()
+    mock_db.refresh.assert_called_with(mock_new_document_instance)
+    mock_vector_store_instance.add_document.assert_called_once_with("Test text.")
+
+    # Assert that VectorMetadata was instantiated with the correct arguments
+    mock_vector_metadata_model.assert_called_once_with(
+        document_id=mock_new_document_instance.id,
+        faiss_index=mock_vector_store_instance.add_document.return_value,
+        embedding_model="mock_embedder"
+    )
+
+@patch('app.core.vector_store.FaissVectorStore')
+def test_rag_service_add_document_error_handling(mock_vector_store):
+    """
+    Test the RAGService.add_document method's error handling.
+    Verifies that the transaction is rolled back on an exception.
+    """
+    # Setup mocks
+    mock_db = MagicMock(spec=Session)
+
+    # Configure the mock db.add to raise the specific SQLAlchemyError.
+    mock_db.add.side_effect = SQLAlchemyError("Database error")
+
+    mock_vector_store_instance = mock_vector_store.return_value
+
+    # Instantiate the service correctly
+    rag_service = RAGService(
+        vector_store=mock_vector_store_instance,
+        retrievers=[]
+    )
+
+    doc_data = {
+        "title": "Test Title",
+        "text": "Test text.",
+        "source_url": "http://test.com"
+    }
+
+    # Call the method under test and expect an exception
+    with pytest.raises(SQLAlchemyError, match="Database error"):
+        rag_service.add_document(db=mock_db, doc_data=doc_data)
+
+    # Assertions
+    mock_db.add.assert_called_once()
+    mock_db.commit.assert_not_called()
+    mock_db.rollback.assert_called_once()
+
+@patch('app.core.services.get_llm_provider')
+@patch('app.core.services.DspyRagPipeline')
+@patch('dspy.configure')
+def test_rag_service_chat_with_rag_with_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
+    """
+    Test the RAGService.chat_with_rag method when context is retrieved.
+    Verifies that the RAG prompt is correctly constructed.
+    """
+    # --- Arrange ---
+    mock_db = MagicMock(spec=Session)
+    mock_session = models.Session(id=1, model_name="deepseek", messages=[
+        models.Message(sender="user", content="Previous user message", created_at=datetime(2023, 1, 1, 9, 0, 0)),
+        models.Message(sender="assistant", content="Previous assistant response", created_at=datetime(2023, 1, 1, 9, 1, 0))
+    ])
+    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
+
+    mock_llm_provider = MagicMock(spec=LLMProvider)
+    mock_get_llm_provider.return_value = mock_llm_provider
+    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
+    mock_pipeline_instance.forward = AsyncMock(return_value="LLM response with context")
+    mock_dspy_pipeline.return_value = mock_pipeline_instance
+
+    prompt = "Test prompt."
+    expected_context = "Context text 1.\n\nContext text 2."
+    mock_retriever = rag_service.retrievers[0]
+    mock_retriever.retrieve_context = AsyncMock(return_value=["Context text 1.", "Context text 2."])
+
+    # --- Act ---
+    response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek"))
+
+    # --- Assert ---
+    mock_db.query.assert_called_once_with(models.Session)
+    assert mock_db.add.call_count == 2
+    mock_get_llm_provider.assert_called_once_with("deepseek")
+
+    mock_pipeline_instance.forward.assert_called_once_with(
+        question=prompt,
+        history=mock_session.messages,
+        db=mock_db
+    )
+
+    assert response_text == "LLM response with context"
+    assert model_used == "deepseek"
+
+@patch('app.core.services.get_llm_provider')
+@patch('app.core.services.DspyRagPipeline')
+@patch('dspy.configure')
+def test_rag_service_chat_with_rag_without_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService):
+    """
+    Test the RAGService.chat_with_rag method when no context is retrieved.
+    Verifies that the original prompt is sent to the LLM.
+    """
+    # --- Arrange ---
+    mock_db = MagicMock(spec=Session)
+    mock_session = models.Session(id=1, model_name="deepseek", messages=[])
+    mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
+
+    mock_llm_provider = MagicMock(spec=LLMProvider)
+    mock_get_llm_provider.return_value = mock_llm_provider
+    mock_pipeline_instance = MagicMock(spec=DspyRagPipeline)
+    mock_pipeline_instance.forward = AsyncMock(return_value="LLM response without context")
+    mock_dspy_pipeline.return_value = mock_pipeline_instance
+
+    prompt = "Test prompt without context."
+    mock_retriever = rag_service.retrievers[0]
+    mock_retriever.retrieve_context = AsyncMock(return_value=[])
+
+    # --- Act ---
+    response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek"))
+
+    # --- Assert ---
+    mock_db.query.assert_called_once_with(models.Session)
+    assert mock_db.add.call_count == 2
+    mock_get_llm_provider.assert_called_once_with("deepseek")
+
+    mock_pipeline_instance.forward.assert_called_once_with(
+        question=prompt,
+        history=mock_session.messages,
+        db=mock_db
+    )
+
+    assert response_text == "LLM response without context"
+    assert model_used == "deepseek"
\ No newline at end of file
+ mock_db.add.side_effect = SQLAlchemyError("Database error") + + mock_vector_store_instance = mock_vector_store.return_value + + # Instantiate the service correctly + rag_service = RAGService( + vector_store=mock_vector_store_instance, + retrievers=[] + ) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test and expect an exception + with pytest.raises(SQLAlchemyError, match="Database error"): + rag_service.add_document(db=mock_db, doc_data=doc_data) + + # Assertions + mock_db.add.assert_called_once() + mock_db.commit.assert_not_called() + mock_db.rollback.assert_called_once() + +@patch('app.core.services.get_llm_provider') +@patch('app.core.services.DspyRagPipeline') +@patch('dspy.configure') +def test_rag_service_chat_with_rag_with_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Test the RAGService.chat_with_rag method when context is retrieved. + Verifies that the RAG prompt is correctly constructed. + """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=1, model_name="deepseek", messages=[ + models.Message(sender="user", content="Previous user message", created_at=datetime(2023, 1, 1, 9, 0, 0)), + models.Message(sender="assistant", content="Previous assistant response", created_at=datetime(2023, 1, 1, 9, 1, 0)) + ]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="LLM response with context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + prompt = "Test prompt." + expected_context = "Context text 1.\n\nContext text 2." + mock_retriever = rag_service.retrievers[0] + mock_retriever.retrieve_context = AsyncMock(return_value=["Context text 1.", "Context text 2."]) + + # --- Act --- + response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("deepseek") + + mock_pipeline_instance.forward.assert_called_once_with( + question=prompt, + history=mock_session.messages, + db=mock_db + ) + + assert response_text == "LLM response with context" + assert model_used == "deepseek" + +@patch('app.core.services.get_llm_provider') +@patch('app.core.services.DspyRagPipeline') +@patch('dspy.configure') +def test_rag_service_chat_with_rag_without_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Test the RAGService.chat_with_rag method when no context is retrieved. + Verifies that the original prompt is sent to the LLM. 
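Outside the test suite, the per-turn model override these tests exercise would be driven roughly as follows. This is a minimal sketch, assuming a live SQLAlchemy session `db`, an existing chat session row, and an already constructed RAGService; the keyword signature matches the calls mocked above.

import asyncio

async def demo(rag_service, db, session_id: int) -> None:
    # First turn on the session's default model
    answer, used = await rag_service.chat_with_rag(
        db=db, session_id=session_id, prompt="Summarize our docs", model="deepseek"
    )
    print(f"[{used}] {answer}")

    # Next turn switches providers without creating a new session
    answer, used = await rag_service.chat_with_rag(
        db=db, session_id=session_id, prompt="Now answer as Gemini", model="gemini"
    )
    print(f"[{used}] {answer}")

# asyncio.run(demo(rag_service, db, session_id=1))  # with real objects in place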
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=1, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="LLM response without context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + prompt = "Test prompt without context." + mock_retriever = rag_service.retrievers[0] + mock_retriever.retrieve_context = AsyncMock(return_value=[]) + + # --- Act --- + response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("deepseek") + + mock_pipeline_instance.forward.assert_called_once_with( + question=prompt, + history=mock_session.messages, + db=mock_db + ) + + assert response_text == "LLM response without context" + assert model_used == "deepseek" \ No newline at end of file diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index ae6226b..aeededc 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -1,27 +1,33 @@ +import os from fastapi.testclient import TestClient from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session -from datetime import datetime +from datetime import datetime # Import datetime for models.Session +# Import the factory function directly to get a fresh app instance for testing from app.app import create_app +# The get_db function is now in app.api.dependencies.py, so we must update the import path. from app.api.dependencies import get_db from app.db import models # Import your SQLAlchemy models -# --- Test Setup --- - -# A mock DB session that can be used across tests +# --- Dependency Override for Testing --- +# This is a mock database session that will be used in our tests. mock_db = MagicMock(spec=Session) + def override_get_db(): - """Dependency override to replace the real database with a mock.""" + """Returns the mock database session for tests.""" try: yield mock_db finally: pass + # --- API Endpoint Tests --- +# We patch the RAGService class itself, as the instance is created inside create_app(). def test_read_root(): """Test the root endpoint to ensure it's running.""" + # Create app and client here to be sure no mocking interferes app = create_app() client = TestClient(app) response = client.get("/") @@ -65,11 +71,12 @@ def test_chat_in_session_success(mock_rag_service_class): """ Test the session-based chat endpoint with a successful, mocked response. + It should default to 'deepseek' if no model is specified. """ # Arrange mock_rag_service_instance = mock_rag_service_class.return_value # The service now returns a tuple: (answer_text, model_used) - mock_rag_service_instance.chat_with_rag = AsyncMock(return_value=("This is a mock response.", "gemini")) + mock_rag_service_instance.chat_with_rag = AsyncMock(return_value=("This is a mock response.", "deepseek")) app = create_app() app.dependency_overrides[get_db] = override_get_db @@ -81,7 +88,199 @@ # Assert assert response.status_code == 200 assert response.json()["answer"] == "This is a mock response." 
- assert response.json()["model_used"] == "gemini" + assert response.json()["model_used"] == "deepseek" + # The fix: Include the default 'model' parameter in the assertion mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, session_id=123, prompt="Hello there" - ) \ No newline at end of file + db=mock_db, session_id=123, prompt="Hello there", model="deepseek" + ) + +@patch('app.app.RAGService') +def test_chat_in_session_with_model_switch(mock_rag_service_class): + """ + Tests sending a message in an existing session and explicitly switching the model. + """ + test_client = TestClient(create_app()) # Create client within test to ensure fresh mock + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' + mock_rag_service_instance.chat_with_rag.assert_called_once_with( + db=mock_db, + session_id=42, + prompt="Hello there, Gemini!", + model="gemini" + ) + +@patch('app.app.RAGService') +def test_get_session_messages_success(mock_rag_service_class): + """Tests retrieving the message history for a session.""" + mock_rag_service_instance = mock_rag_service_class.return_value + # Arrange: Mock the service to return a list of message objects + mock_history = [ + models.Message(sender="user", content="Hello", created_at=datetime.now()), + models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) + ] + mock_rag_service_instance.get_message_history.return_value = mock_history + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Act + response = client.get("/sessions/123/messages") + + # Assert + assert response.status_code == 200 + response_data = response.json() + assert response_data["session_id"] == 123 + assert len(response_data["messages"]) == 2 + assert response_data["messages"][0]["sender"] == "user" + assert response_data["messages"][1]["content"] == "Hi there!" + mock_rag_service_instance.get_message_history.assert_called_once_with(db=mock_db, session_id=123) + +@patch('app.app.RAGService') +def test_get_session_messages_not_found(mock_rag_service_class): + """Tests retrieving messages for a session that does not exist.""" + mock_rag_service_instance = mock_rag_service_class.return_value + # Arrange: Mock the service to return None, indicating the session wasn't found + mock_rag_service_instance.get_message_history.return_value = None + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Act + response = client.get("/sessions/999/messages") + + # Assert + assert response.status_code == 404 + assert response.json()["detail"] == "Session with ID 999 not found." + +@patch('app.app.RAGService') +def test_add_document_success(mock_rag_service_class): + """ + Test the /document endpoint with a successful, mocked RAG service response. 
+ """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.return_value = 1 + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/documents", json=doc_data) # Changed to /documents as per routes.py + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1" + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + + +@patch('app.app.RAGService') +def test_add_document_api_failure(mock_rag_service_class): + """ + Test the /document endpoint when the RAG service encounters an error. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.side_effect = Exception("Service failed") + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/documents", json=doc_data) # Changed to /documents + + assert response.status_code == 500 + assert "An error occurred: Service failed" in response.json()["detail"] + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + +@patch('app.app.RAGService') +def test_get_documents_success(mock_rag_service_class): + """ + Tests the /documents endpoint for successful retrieval of documents. + """ + mock_rag_service_instance = mock_rag_service_class.return_value + mock_docs = [ + models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), + models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) + ] + mock_rag_service_instance.get_all_documents.return_value = mock_docs + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.get("/documents") + assert response.status_code == 200 + assert len(response.json()["documents"]) == 2 + assert response.json()["documents"][0]["title"] == "Doc One" + mock_rag_service_instance.get_all_documents.assert_called_once_with(db=mock_db) + +@patch('app.app.RAGService') +def test_delete_document_success(mock_rag_service_class): + """ + Tests the DELETE /documents/{document_id} endpoint for successful deletion. 
+ """ + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.delete_document.return_value = 42 + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["message"] == "Document deleted successfully" + assert response.json()["document_id"] == 42 + mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=42) + +@patch('app.app.RAGService') +def test_delete_document_not_found(mock_rag_service_class): + """ + Tests the DELETE /documents/{document_id} endpoint when the document is not found. + """ + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.delete_document.return_value = None + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.delete("/documents/999") + assert response.status_code == 404 + assert response.json()["detail"] == "Document with ID 999 not found." + mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=999) +