diff --git a/.gitignore b/.gitignore
index 96a9867..176d2c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@
 **.bin
 **.db
 ai-hub/data/*
+.vscode/
diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py
index 2515147..5ff6fb6 100644
--- a/ai-hub/app/api/routes/sessions.py
+++ b/ai-hub/app/api/routes/sessions.py
@@ -16,7 +16,7 @@
         new_session = services.rag_service.create_session(
             db=db,
             user_id=request.user_id,
-            model=request.model
+            provider_name=request.provider_name
         )
         return new_session
     except Exception as e:
@@ -29,14 +29,14 @@
     db: Session = Depends(get_db)
 ):
     try:
-        response_text, model_used = await services.rag_service.chat_with_rag(
+        response_text, provider_used = await services.rag_service.chat_with_rag(
             db=db,
             session_id=session_id,
             prompt=request.prompt,
-            model=request.model,
+            provider_name=request.provider_name,
             load_faiss_retriever=request.load_faiss_retriever
         )
-        return schemas.ChatResponse(answer=response_text, model_used=model_used)
+        return schemas.ChatResponse(answer=response_text, provider_used=provider_used)
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index a1eff32..9014400 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -7,7 +7,7 @@
     """Defines the shape of a request to the /chat endpoint."""
     prompt: str = Field(..., min_length=1)
-    # The 'model' can now be specified in the request body to switch models mid-conversation.
-    model: Literal["deepseek", "gemini"] = Field("deepseek")
+    # The provider can be specified in the request body to switch providers mid-conversation.
+    provider_name: Literal["deepseek", "gemini"] = Field("deepseek")
     # Add a new optional boolean field to control the retriever
     load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.")
@@ -15,7 +15,7 @@
 class ChatResponse(BaseModel):
     """Defines the shape of a successful response from the /chat endpoint."""
     answer: str
-    model_used: str
+    provider_used: str
 
 # --- Document Schemas ---
 class DocumentCreate(BaseModel):
@@ -47,14 +47,14 @@
 class SessionCreate(BaseModel):
     """Defines the shape for starting a new conversation session."""
     user_id: str
-    model: Literal["deepseek", "gemini"] = "deepseek"
+    provider_name: Literal["deepseek", "gemini"] = "deepseek"
 
 class Session(BaseModel):
     """Defines the shape of a session object returned by the API."""
     id: int
     user_id: str
     title: str
-    model_name: str
+    provider_name: str
     created_at: datetime
 
     model_config = ConfigDict(from_attributes=True)
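
Reviewer note: for anyone tracking the API contract change, the renamed fields line up as follows; a minimal client-side sketch (the httpx usage and session id are illustrative, the field names are from the diff):

```python
import httpx

# Request: "provider_name" replaces the old "model" field.
payload = {
    "prompt": "What is the capital of France?",
    "provider_name": "gemini",
    "load_faiss_retriever": False,
}
resp = httpx.post("http://127.0.0.1:8000/sessions/42/chat", json=payload)

# Response: "provider_used" replaces the old "model_used" field.
body = resp.json()
print(body["answer"], body["provider_used"])
```
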
diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py
index 7e2f243..22ef056 100644
--- a/ai-hub/app/core/pipelines/dspy_rag.py
+++ b/ai-hub/app/core/pipelines/dspy_rag.py
@@ -1,28 +1,13 @@
 import dspy
 import logging
 from typing import List, Callable, Optional
-from types import SimpleNamespace
 from sqlalchemy.orm import Session
 from app.db import models
 from app.core.retrievers.base_retriever import Retriever
-from app.core.providers.base import LLMProvider
 
-class DSPyLLMProvider(dspy.BaseLM):
-    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
-        super().__init__(model=model_name)
-        self.provider = provider
-        self.kwargs.update(kwargs)
-
-    async def aforward(self, prompt: str, **kwargs):
-        if not prompt or not prompt.strip():
-            return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))])
-        response_text = await self.provider.generate_response(prompt)
-        choice = SimpleNamespace(message=SimpleNamespace(content=response_text))
-        return SimpleNamespace(choices=[choice])
-
-
+# --- DSPy Signature Class (No Change) ---
 class AnswerWithHistory(dspy.Signature):
     """Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history."""
@@ -32,6 +17,7 @@
     answer = dspy.OutputField(desc="A well-formed answer suitable for delivery in an audio play format.")
 
+# --- DSPy RAG Pipeline Class (Updated) ---
 class DspyRagPipeline(dspy.Module):
     """
     A flexible and extensible DSPy-based RAG pipeline with modular stages.
@@ -66,21 +52,18 @@
         # Step 2: Format history
         history_text = self.history_formatter(history)
 
-        # Step 3: Build final prompt
-        instruction = self.generate_answer.signature.__doc__
-        full_prompt = self._build_prompt(instruction, context_text, history_text, question)
+        # Step 3: Generate response using LLM
+        # With DSPy and LiteLLM, the signature-based generation handles the prompt building.
+        # You no longer need to manually build the prompt string.
+        prediction = await self.generate_answer.aforward(
+            context=context_text,
+            chat_history=history_text,
+            question=question
+        )
 
-        logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}")
+        raw_response = prediction.answer
 
-        # Step 4: Generate response using LLM
-        lm = dspy.settings.lm
-        if lm is None:
-            raise RuntimeError("DSPy LM not configured.")
-
-        response_obj = await lm.aforward(prompt=full_prompt)
-        raw_response = response_obj.choices[0].message.content
-
-        # Step 5: Optional response postprocessing
+        # Step 4: Optional response postprocessing
         if self.response_postprocessor:
             return self.response_postprocessor(raw_response)
 
@@ -97,15 +80,4 @@
             for msg in history
         )
 
-    # Prompt builder
-    def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str:
-        return (
-            f"{instruction.strip()}\n\n"
-            f"---\n\n"
-            f"Context:\n{context.strip()}\n\n"
-            f"---\n\n"
-            f"Chat History:\n{history.strip()}\n\n"
-            f"---\n\n"
-            f"Human: {question.strip()}\n"
-            f"Assistant:"
-        )
+# Note: The _build_prompt method is removed as DSPy handles this automatically.
\ No newline at end of file
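
Reviewer note: the heart of this file's change is replacing the hand-rolled `_build_prompt` with signature-driven generation. A standalone sketch of the pattern, assuming a LiteLLM-style model string and an API key in the environment:

```python
import dspy

class AnswerWithHistory(dspy.Signature):
    """Answer the question using the provided context and chat history."""
    context = dspy.InputField()
    chat_history = dspy.InputField()
    question = dspy.InputField()
    answer = dspy.OutputField()

# DSPy renders the signature docstring, field descriptions, and inputs into
# the prompt itself, so no manual prompt assembly is needed.
dspy.configure(lm=dspy.LM("gemini/gemini-2.0-flash"))  # assumes GEMINI_API_KEY is set
prediction = dspy.Predict(AnswerWithHistory)(
    context="Paris is the capital of France.",
    chat_history="Human: Hi.\nAssistant: Hello!",
    question="What is the capital of France?",
)
print(prediction.answer)
```
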
diff --git a/ai-hub/app/core/pipelines/file_selector_rag.py b/ai-hub/app/core/pipelines/file_selector_rag.py
new file mode 100644
index 0000000..fb28ff6
--- /dev/null
+++ b/ai-hub/app/core/pipelines/file_selector_rag.py
@@ -0,0 +1,42 @@
+import dspy
+from typing import List
+from app.db import models
+
+# Assuming SelectFiles and other necessary imports are defined as in the previous example
+
+class SelectFiles(dspy.Signature):
+    """
+    Based on the user's question, communication history, and the code folder's file list, identify the files that are most relevant to answer the question.
+    """
+    question = dspy.InputField(desc="The user's current question.")
+    chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI.")
+    code_folder_filename_list = dspy.InputField(desc="A list of file names as strings, representing the file structure of the code base.")
+    answer = dspy.OutputField(format=list, desc="A list of strings containing the names of the most relevant files to examine further.")
+
+class CodeRagFileSelector(dspy.Module):
+    """
+    A single-step module to select relevant files from a list based on a user question.
+    """
+    def __init__(self):
+        super().__init__()
+        self.select_files = dspy.Predict(SelectFiles)
+
+    async def forward(self, question: str, history: List[models.Message], file_list: List[str]) -> List[str]:
+        # Format history for the signature
+        history_text = self._default_history_formatter(history)
+
+        # Call the predictor asynchronously with the necessary inputs
+        # (a bare `await self.select_files(...)` is not awaitable; use .acall as in the test)
+        prediction = await self.select_files.acall(
+            question=question,
+            chat_history=history_text,
+            code_folder_filename_list="\n".join(file_list)
+        )
+
+        # The output is expected to be a list of strings
+        return prediction.answer
+
+    def _default_history_formatter(self, history: List[models.Message]) -> str:
+        return "\n".join(
+            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
+            for msg in history
+        )
\ No newline at end of file
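
Reviewer note: `prediction.answer` is only a `List[str]` if the LM honors the `format=list` output field, so callers may want defensive parsing. A hedged helper sketch (the fallback behavior is an assumption, not something this module guarantees):

```python
import json
from typing import List

def normalize_file_list(raw) -> List[str]:
    """Best-effort coercion of the predictor's `answer` field into filenames.
    Assumes the LM returns a real list, a JSON array string, or newline-separated names."""
    if isinstance(raw, list):
        return [str(item) for item in raw]
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
    except (TypeError, ValueError):
        pass
    # Fall back to one filename per line.
    return [line.strip() for line in str(raw).splitlines() if line.strip()]
```
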
diff --git a/ai-hub/app/core/pipelines/file_selector_rag_test.py b/ai-hub/app/core/pipelines/file_selector_rag_test.py
new file mode 100644
index 0000000..c8f4c5e
--- /dev/null
+++ b/ai-hub/app/core/pipelines/file_selector_rag_test.py
@@ -0,0 +1,85 @@
+import dspy
+import asyncio
+import json
+import os
+from typing import List
+from app.core.providers.factory import get_llm_provider
+
+# Assume these are defined elsewhere
+class MockMessage:
+    def __init__(self, sender: str, content: str):
+        self.sender = sender
+        self.content = content
+
+# --- Step 2: Paste your existing DSPy components here ---
+class SelectFiles(dspy.Signature):
+    """
+    Based on the user's question, communication history, and the code folder's file list, identify the files that are most relevant to answer the question.
+    """
+    question = dspy.InputField(desc="The user's current question.")
+    chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI.")
+    code_folder_filename_list = dspy.InputField(desc="A list of file names as strings, representing the file structure of the code base.")
+    answer = dspy.OutputField(format=list, desc="A list of strings containing the names of the most relevant files to examine further.")
+
+class CodeRagFileSelector(dspy.Module):
+    """
+    A single-step module to select relevant files from a list based on a user question.
+    """
+    def __init__(self):
+        super().__init__()
+        self.select_files = dspy.Predict(SelectFiles)
+
+    async def forward(self, question: str, history: List[MockMessage], file_list: List[str]) -> List[str]:
+        # Format history for the signature
+        history_text = self._default_history_formatter(history)
+
+        # Call the predictor with the necessary inputs
+        prediction = await self.select_files.acall(
+            question=question,
+            chat_history=history_text,
+            code_folder_filename_list="\n".join(file_list)
+        )
+
+        # The output is expected to be a list of strings
+        # The DSPy Predict method automatically handles the `format=list` and parses the JSON.
+        return prediction.answer
+
+    def _default_history_formatter(self, history: List[MockMessage]) -> str:
+        return "\n".join(
+            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
+            for msg in history
+        )
+
+# --- Step 3: Write the test code ---
+async def main():
+    dspy.settings.configure(lm=get_llm_provider("gemini", "gemini-1.5-flash-latest"))
+
+    # Instantiate the module
+    file_selector = CodeRagFileSelector()
+
+    # Define sample data
+    question = "How is the `data` variable initialized in the main application? Also, where are the tests?"
+    history = [
+        MockMessage(sender="user", content="What does the main script do?"),
+        MockMessage(sender="assistant", content="The main script handles data processing and initializes the primary variables.")
+    ]
+    file_list = [
+        "main.py",
+        "utils.py",
+        "README.md",
+        "config.yaml",
+        "tests/test_main.py",
+        "src/data_handler.py"
+    ]
+
+    print("Running the file selector with Gemini...")
+    selected_files = await file_selector(question, history, file_list)
+
+    print("\n--- Test Results ---")
+    print(f"Question: {question}")
+    print(f"All files: {file_list}")
+    print(f"Selected files: {selected_files}")
+    print("--------------------")
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/ai-hub/app/core/pipelines/test_gemini.py b/ai-hub/app/core/pipelines/test_gemini.py
new file mode 100644
index 0000000..0060b9b
--- /dev/null
+++ b/ai-hub/app/core/pipelines/test_gemini.py
@@ -0,0 +1,10 @@
+import dspy
+
+lm = dspy.LM('gemini/gemini-2.0-flash')
+dspy.configure(lm=lm)
+
+# Send a simple prompt
+response = lm(prompt="hello world")
+
+# Print the response text
+print(response)
\ No newline at end of file
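
Reviewer note: both of these scripts rely on LiteLLM resolving credentials from the environment rather than from explicit client objects; per LiteLLM's conventions, the relevant variables for the models above would be:

```python
import os

# Expected by LiteLLM for "gemini/..." and "deepseek/..." model strings.
# (Values are placeholders; set real keys in the shell, not in code.)
os.environ.setdefault("GEMINI_API_KEY", "<your-gemini-key>")
os.environ.setdefault("DEEPSEEK_API_KEY", "<your-deepseek-key>")
```
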
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 725f871..4c2664e 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,29 +1,43 @@
 from app.config import settings
-from .base import LLMProvider, TTSProvider, STTProvider
-from .llm.deepseek import DeepSeekProvider
-from .llm.gemini import GeminiProvider
+from .base import TTSProvider, STTProvider
+from .llm.general import GeneralProvider
 from .tts.gemini import GeminiTTSProvider
 from .tts.gcloud_tts import GCloudTTSProvider
 from .stt.gemini import GoogleSTTProvider
+from dspy.clients.base_lm import BaseLM
 from openai import AsyncOpenAI
+import litellm
 
-# --- 1. Initialize API Clients from Central Config ---
-deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
-GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
-
-# --- 2. The Factory Dictionaries ---
+
+# --- 1. Initialize API Clients from Central Config ---
+# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
+
+# # --- 2. The Factory Dictionaries ---
 _llm_providers = {
-    "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME, client=deepseek_client),
-    "gemini": GeminiProvider(api_url=GEMINI_URL)
+    "deepseek": settings.DEEPSEEK_API_KEY,
+    "gemini": settings.GEMINI_API_KEY
+}
+
+_llm_models = {
+    "deepseek": settings.DEEPSEEK_MODEL_NAME,
+    "gemini": settings.GEMINI_MODEL_NAME
 }
 
 # --- 3. The Factory Functions ---
-def get_llm_provider(model_name: str) -> LLMProvider:
+def get_llm_provider(provider_name: str, model_name: str = "") -> BaseLM:
     """Factory function to get the appropriate, pre-configured LLM provider."""
-    provider = _llm_providers.get(model_name)
-    if not provider:
-        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}")
-    return provider
+    provider_key = _llm_providers.get(provider_name)
+    if not provider_key:
+        raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {list(_llm_providers.keys())}")
+    resolved_model = model_name or _llm_models.get(provider_name)
+    if not resolved_model:
+        raise ValueError(f"No default model configured for provider: '{provider_name}'.")
+
+    return GeneralProvider(model_name=f'{provider_name}/{resolved_model}', api_key=provider_key)
 
 def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str) -> TTSProvider:
     if provider_name == "google_gemini":
@@ -35,4 +49,6 @@
 def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
     if provider_name == "google_gemini":
         return GoogleSTTProvider(api_key=api_key, model_name=model_name)
-    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
+    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
+
+# async def lite_llm_call(model_name: str, prompt: str) -> str:
\ No newline at end of file
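
Reviewer note: with the factory collapsed onto `GeneralProvider`, a call site only picks a provider prefix and may override the default model. A quick sketch of the resulting behavior, per the code above:

```python
from app.core.providers.factory import get_llm_provider

# Uses the provider's default model from settings (e.g. settings.GEMINI_MODEL_NAME).
lm_default = get_llm_provider("gemini")

# Explicit override; wraps the LiteLLM route "gemini/gemini-1.5-flash-latest".
lm_flash = get_llm_provider("gemini", "gemini-1.5-flash-latest")

# Unknown providers fail fast with a ValueError.
try:
    get_llm_provider("anthropic")
except ValueError as err:
    print(err)
```
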
diff --git a/ai-hub/app/core/providers/llm/README.md b/ai-hub/app/core/providers/llm/README.md
new file mode 100644
index 0000000..34f12a8
--- /dev/null
+++ b/ai-hub/app/core/providers/llm/README.md
@@ -0,0 +1 @@
+Provider-specific LLM classes are no longer needed in this package: the setup has been migrated to LiteLLM, which lets the application send requests to all the different providers through a single, unified syntax. This simplifies the codebase and makes it easier to add new models.
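
Reviewer note: the "single, unified syntax" the README refers to is LiteLLM's `provider/model` routing; the same call shape works against every configured backend. A minimal sketch (API keys assumed present in the environment):

```python
import litellm

for model in ("deepseek/deepseek-chat", "gemini/gemini-2.0-flash"):
    # One signature for every provider; LiteLLM translates the request to the
    # provider's native API based on the prefix before the slash.
    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(model, "->", response.choices[0].message.content)
```
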
diff --git a/ai-hub/app/core/providers/llm/deepseek.py b/ai-hub/app/core/providers/llm/deepseek.py
deleted file mode 100644
index 2fde32c..0000000
--- a/ai-hub/app/core/providers/llm/deepseek.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import logging
-from openai import AsyncOpenAI  # Use AsyncOpenAI
-from typing import final
-from app.core.providers.base import LLMProvider
-from app.config import settings
-
-@final
-class DeepSeekProvider(LLMProvider):
-    """Provider for the DeepSeek API."""
-    def __init__(self, model_name: str, client: AsyncOpenAI):  # Type hint with AsyncOpenAI
-        self.model = model_name
-        self._client = client
-
-    async def generate_response(self, prompt: str) -> str:
-        messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
-        try:
-            # This await is now correct for AsyncOpenAI
-            chat_completion = await self._client.chat.completions.create(model=self.model, messages=messages)
-            return chat_completion.choices[0].message.content
-        except Exception as e:
-            logging.error("DeepSeek Provider Error", exc_info=True)
-            raise
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/llm/gemini.py b/ai-hub/app/core/providers/llm/gemini.py
deleted file mode 100644
index ecbba9c..0000000
--- a/ai-hub/app/core/providers/llm/gemini.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import httpx
-import logging
-import json
-from typing import final
-from app.core.providers.base import LLMProvider
-from app.config import settings
-
-@final
-class GeminiProvider(LLMProvider):
-    """Provider for the Google Gemini API."""
-    def __init__(self, api_url: str):
-        self.url = api_url
-
-    async def generate_response(self, prompt: str) -> str:
-        payload = {"contents": [{"parts": [{"text": prompt}]}]}
-        headers = {"Content-Type": "application/json"}
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.post(self.url, json=payload, headers=headers)
-                response.raise_for_status()
-                # Await the async `json` method
-                data = response.json()
-                return data['candidates'][0]['content']['parts'][0]['text']
-        except Exception as e:
-            logging.error("Gemini Provider Error", exc_info=True)
-            raise
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/llm/general.py b/ai-hub/app/core/providers/llm/general.py
new file mode 100644
index 0000000..13b961c
--- /dev/null
+++ b/ai-hub/app/core/providers/llm/general.py
@@ -0,0 +1,44 @@
+
+import litellm
+from dspy.clients.base_lm import BaseLM
+
+class GeneralProvider(BaseLM):
+    def __init__(self, model_name: str, api_key: str):
+        self.model_name = model_name
+        self.api_key = api_key
+        # Call the parent constructor
+        super().__init__(model=model_name)
+
+    def forward(self, prompt=None, messages=None, **kwargs):
+        """
+        Synchronous forward pass using LiteLLM.
+        """
+        messages = messages or [{"role": "user", "content": prompt}]
+        request = {
+            "model": self.model_name,
+            "messages": messages,
+            "api_key": self.api_key,
+            **self.kwargs,
+            **kwargs,
+        }
+        try:
+            return litellm.completion(**request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}")
+
+    async def aforward(self, prompt=None, messages=None, **kwargs):
+        """
+        Asynchronous forward pass using LiteLLM.
+        """
+        messages = messages or [{"role": "user", "content": prompt}]
+        request = {
+            "model": self.model_name,
+            "messages": messages,
+            "api_key": self.api_key,
+            **self.kwargs,
+            **kwargs,
+        }
+        try:
+            return await litellm.acompletion(**request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}")
\ No newline at end of file
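
Reviewer note: because `GeneralProvider` subclasses DSPy's `BaseLM`, one instance serves both direct calls and DSPy pipelines. A brief usage sketch (model string and key are placeholders):

```python
import asyncio
from app.core.providers.llm.general import GeneralProvider

lm = GeneralProvider(model_name="deepseek/deepseek-chat", api_key="<key>")

# Synchronous path: returns a LiteLLM response object.
print(lm.forward(prompt="What is 2 + 2?").choices[0].message.content)

# Asynchronous path, as exercised by the RAG pipeline and the unit tests below.
resp = asyncio.run(lm.aforward(prompt="What is 2 + 2?"))
print(resp.choices[0].message.content)
```
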
diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py
index e3973da..4965630 100644
--- a/ai-hub/app/core/services/rag.py
+++ b/ai-hub/app/core/services/rag.py
@@ -9,7 +9,7 @@
 from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 from app.core.retrievers.base_retriever import Retriever
 from app.core.providers.factory import get_llm_provider
-from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
+from app.core.pipelines.dspy_rag import DspyRagPipeline
 
 class RAGService:
     """
@@ -21,10 +21,10 @@
         self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)
 
     # --- Session Management ---
-    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
+    def create_session(self, db: Session, user_id: str, provider_name: str) -> models.Session:
         """Creates a new chat session in the database."""
        try:
-            new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session")
+            new_session = models.Session(user_id=user_id, provider_name=provider_name, title=f"New Chat Session")
             db.add(new_session)
             db.commit()
             db.refresh(new_session)
@@ -38,7 +38,7 @@
         db: Session,
         session_id: int,
         prompt: str,
-        model: str,
+        provider_name: str,
         load_faiss_retriever: bool = False
     ) -> Tuple[str, str]:
         """
@@ -54,10 +54,9 @@
         user_message = models.Message(session_id=session_id, sender="user", content=prompt)
         db.add(user_message)
         db.commit()
+        db.refresh(user_message)
 
-        llm_provider = get_llm_provider(model)
-        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
-        dspy.configure(lm=dspy_llm)
+        llm_provider = get_llm_provider(provider_name)
 
         current_retrievers = []
         if load_faiss_retriever:
@@ -68,17 +67,20 @@
 
         rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)
 
-        answer_text = await rag_pipeline.forward(
-            question=prompt,
-            history=session.messages,
-            db=db
-        )
+        # Use dspy.context to configure the language model for this specific async task
+        with dspy.context(lm=llm_provider):
+            answer_text = await rag_pipeline.forward(
+                question=prompt,
+                history=session.messages,
+                db=db
+            )
 
         assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
         db.add(assistant_message)
         db.commit()
+        db.refresh(assistant_message)
 
-        return answer_text, model
+        return answer_text, provider_name
 
     def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
         """
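
Reviewer note: swapping the global `dspy.configure(...)` for `with dspy.context(lm=...)` is the concurrency-relevant part of this hunk: the override is scoped to the block, so overlapping requests that chose different providers do not clobber each other's LM. A minimal sketch of the pattern under that assumption:

```python
import asyncio
import dspy

async def answer_with(lm, question: str) -> str:
    # The LM override is visible only inside this block, keeping
    # concurrent tasks with different providers isolated.
    with dspy.context(lm=lm):
        prediction = await dspy.Predict("question -> answer").acall(question=question)
        return prediction.answer

async def main():
    gemini = dspy.LM("gemini/gemini-2.0-flash")
    deepseek = dspy.LM("deepseek/deepseek-chat")
    print(await asyncio.gather(
        answer_with(gemini, "What is the capital of France?"),
        answer_with(deepseek, "What is the largest ocean?"),
    ))

asyncio.run(main())
```
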
diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py
index ec4f703..2d3476c 100644
--- a/ai-hub/app/db/models.py
+++ b/ai-hub/app/db/models.py
@@ -25,7 +25,7 @@
     # A title for the conversation, which can be generated by the AI.
     title = Column(String, index=True, nullable=True)
-    # The name of the LLM model used for this session (e.g., "Gemini", "DeepSeek").
-    model_name = Column(String, nullable=True)
+    # The name of the LLM provider used for this session (e.g., "gemini", "deepseek").
+    provider_name = Column(String, nullable=True)
     # Timestamp for when the session was created.
     created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
     # Flag to indicate if the session has been archived or soft-deleted.
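
Reviewer note: renaming the column also implies a migration for existing databases, which this patch does not include. A hedged Alembic sketch; the table name `sessions` is an assumption, since the model's `__tablename__` is not shown in the diff:

```python
from alembic import op

def upgrade() -> None:
    # Rename in place; row data is preserved.
    op.alter_column("sessions", "model_name", new_column_name="provider_name")

def downgrade() -> None:
    op.alter_column("sessions", "provider_name", new_column_name="model_name")
```
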
diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py
index 435ce3d..2286daf 100644
--- a/ai-hub/integration_tests/test_sessions_api.py
+++ b/ai-hub/integration_tests/test_sessions_api.py
@@ -16,7 +16,7 @@
     print("\n--- Running test_chat_in_session_lifecycle ---")
 
     # 1. Create a new session with a trailing slash
-    payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"}
+    payload = {"user_id": "integration_tester_lifecycle", "provider_name": "deepseek"}
     response = await http_client.post("/sessions/", json=payload)
     assert response.status_code == 200
     session_id = response.json()["id"]
@@ -27,7 +27,7 @@
     response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1)
     assert response_1.status_code == 200
     assert "Satya Nadella" in response_1.json()["answer"]
-    assert response_1.json()["model_used"] == "deepseek"
+    assert response_1.json()["provider_used"] == "deepseek"
     print("✅ Chat Turn 1 (context) test passed.")
 
     # 3. Second chat turn (follow-up) to test conversational memory
@@ -35,7 +35,7 @@
     response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2)
     assert response_2.status_code == 200
     assert "1967" in response_2.json()["answer"]
-    assert response_2.json()["model_used"] == "deepseek"
+    assert response_2.json()["provider_used"] == "deepseek"
     print("✅ Chat Turn 2 (follow-up) test passed.")
 
     # 4. Cleanup (optional, but good practice if not using a test database that resets)
@@ -47,19 +47,19 @@
     print("\n--- Running test_chat_with_model_switch ---")
 
     # Send a message to the new session with a different model
-    payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
+    payload_gemini = {"prompt": "What is the capital of France?", "provider_name": "gemini"}
     response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
     assert response_gemini.status_code == 200
     assert "Paris" in response_gemini.json()["answer"]
-    assert response_gemini.json()["model_used"] == "gemini"
+    assert response_gemini.json()["provider_used"] == "gemini"
     print("✅ Chat (Model Switch to Gemini) test passed.")
 
     # Switch back to the original model
-    payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+    payload_deepseek = {"prompt": "What is the largest ocean?", "provider_name": "deepseek"}
     response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
     assert response_deepseek.status_code == 200
     assert "Pacific Ocean" in response_deepseek.json()["answer"]
-    assert response_deepseek.json()["model_used"] == "deepseek"
+    assert response_deepseek.json()["provider_used"] == "deepseek"
     print("✅ Chat (Model Switch back to DeepSeek) test passed.")
 
 @pytest.mark.asyncio
@@ -71,14 +71,12 @@
     print("\n--- Running test_chat_with_document_retrieval ---")
 
     # Create a new session for this RAG test
-    # Corrected URL with a trailing slash
-    session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"})
+    session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "provider_name": "deepseek"})
     assert session_response.status_code == 200
     rag_session_id = session_response.json()["id"]
 
     # Add a new document with specific content for retrieval
     doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
-    # Corrected URL with a trailing slash
     add_doc_response = await http_client.post("/documents/", json=doc_data)
     assert add_doc_response.status_code == 200
     try:
@@ -92,12 +90,22 @@
     chat_payload = {
         "prompt": RAG_PROMPT,
         "document_id": rag_document_id,
-        "model": "deepseek",
+        "provider_name": "deepseek",
         "load_faiss_retriever": True
     }
     chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload)
 
+    # If a 500 error occurs, print the detailed response text for debugging.
+    if chat_response.status_code != 200:
+        print(f"❌ Test Failed! Received status code: {chat_response.status_code}")
+        print("--- Response Body (for debugging) ---")
+        print(chat_response.text)
+        print("---------------------------------------")
+
     assert chat_response.status_code == 200
 
     chat_data = chat_response.json()
     assert "Jane Doe" in chat_data["answer"]
     assert "Nexus" in chat_data["answer"]
@@ -106,4 +114,4 @@
     # Clean up the document after the test
     delete_response = await http_client.delete(f"/documents/{rag_document_id}")
     assert delete_response.status_code == 200
-    print(f"Document {rag_document_id} deleted successfully.")
+    print(f"Document {rag_document_id} deleted successfully.")
\ No newline at end of file
diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh
index 5646e5a..258c374 100644
--- a/ai-hub/run_chat.sh
+++ b/ai-hub/run_chat.sh
@@ -39,7 +39,7 @@
 # FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect
 SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \
     -H "Content-Type: application/json" \
-    -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \
+    -d '{"user_id": "local_user", "provider_name": "'"$DEFAULT_MODEL"'"}' \
     -w '\n%{http_code}') # Add a new line and the status code
 
 # Extract body and status code
@@ -110,9 +110,9 @@
     # Note the use of --argjson to pass the boolean value correctly
     json_payload=$(jq -n \
         --arg prompt "$PROMPT_TEXT" \
-        --arg model "$MODEL_TO_USE" \
+        --arg provider_name "$MODEL_TO_USE" \
         --argjson faiss_flag "$LOAD_FAISS_RETRIEVER" \
-        '{"prompt": $prompt, "model": $model, "load_faiss_retriever": $faiss_flag}')
+        '{"prompt": $prompt, "provider_name": $provider_name, "load_faiss_retriever": $faiss_flag}')
 
     echo "Payload: $json_payload" # Optional: for debugging
 
@@ -126,8 +126,8 @@
         echo "Server response: $ai_response_json"
     else
         ai_answer=$(echo "$ai_response_json" | jq -r '.answer')
-        model_used=$(echo "$ai_response_json" | jq -r '.model_used')
-        echo "AI [$model_used] [faiss_enabled: $LOAD_FAISS_RETRIEVER]: $ai_answer"
+        provider_used=$(echo "$ai_response_json" | jq -r '.provider_used')
+        echo "AI [$provider_used] [faiss_enabled: $LOAD_FAISS_RETRIEVER]: $ai_answer"
     fi
 done
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index e0f94d5..5a2ecb8 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -21,6 +21,7 @@
 export DB_MODE=sqlite
 export LOCAL_DB_PATH="data/integration_test_ai_hub.db"
 export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin"
+export LOG_LEVEL="DEBUG"
 
 # --- User Interaction ---
 echo "--- AI Hub Test Runner ---"
@@ -60,7 +61,7 @@
 
 # Start the uvicorn server in the background
 # We bind it to 127.0.0.1 to ensure it's not accessible from outside the local machine.
-uvicorn app.main:app --host 127.0.0.1 --port 8000 &
+uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload &
 
 # Get the Process ID (PID) of the background server
 SERVER_PID=$!
diff --git a/ai-hub/tests/api/routes/test_sessions.py b/ai-hub/tests/api/routes/test_sessions.py
index 10334d2..ad23a3a 100644
--- a/ai-hub/tests/api/routes/test_sessions.py
+++ b/ai-hub/tests/api/routes/test_sessions.py
@@ -8,12 +8,12 @@
     mock_session = MagicMock(spec=models.Session)
     mock_session.id = 1
     mock_session.user_id = "test_user"
-    mock_session.model_name = "gemini"
+    mock_session.provider_name = "gemini"
     mock_session.title = "New Chat"
     mock_session.created_at = datetime.now()
     mock_services.rag_service.create_session.return_value = mock_session
 
-    response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
+    response = test_client.post("/sessions", json={"user_id": "test_user", "provider_name": "gemini"})
 
     assert response.status_code == 200
     assert response.json()["id"] == 1
@@ -30,13 +30,13 @@
     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
 
     assert response.status_code == 200
-    assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
+    assert response.json() == {"answer": "Mocked response", "provider_used": "deepseek"}
 
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
         prompt="Hello there",
-        model="deepseek",
+        provider_name="deepseek",
         load_faiss_retriever=False
     )
 
@@ -47,16 +47,16 @@
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
 
-    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
+    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "provider_name": "gemini"})
 
     assert response.status_code == 200
-    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
+    assert response.json() == {"answer": "Mocked response from Gemini", "provider_used": "gemini"}
 
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
         prompt="Hello there, Gemini!",
-        model="gemini",
+        provider_name="gemini",
         load_faiss_retriever=False
     )
 
@@ -73,13 +73,13 @@
     )
 
     assert response.status_code == 200
-    assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
+    assert response.json() == {"answer": "Response with context", "provider_used": "deepseek"}
 
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
         prompt="What is RAG?",
-        model="deepseek",
+        provider_name="deepseek",
         load_faiss_retriever=True
     )
diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py
index 4e44c4b..be4a13d 100644
--- a/ai-hub/tests/core/pipelines/test_dspy_rag.py
+++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py
@@ -1,97 +1,132 @@
 import pytest
+from typing import List
 import asyncio
 from unittest.mock import MagicMock, AsyncMock
 from sqlalchemy.orm import Session
 import dspy
 
 # Import the pipeline and its new signature
 from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory
 from app.db import models  # Import your SQLAlchemy models for mocking history
 from app.core.retrievers.base_retriever import Retriever
 
-@pytest.fixture
-def mock_lm_configured():
-    """Pytest fixture to mock and configure the dspy language model for a test."""
-    mock_lm_instance = MagicMock()
-    mock_lm_instance.aforward = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="Mocked LLM answer"))])
-    )
-    original_lm = dspy.settings.lm
-    dspy.configure(lm=mock_lm_instance)
-    yield mock_lm_instance
-    dspy.configure(lm=original_lm)
-
-def test_pipeline_with_context_and_history(mock_lm_configured):
-    """
-    Tests the pipeline's prompt construction when it has both retrieved context
-    and a conversation history.
-    """
-    # --- Arrange ---
-    mock_retriever = MagicMock(spec=Retriever)
-    mock_retriever.retrieve_context.return_value = ["Context chunk 1."]
-    mock_db = MagicMock(spec=Session)
-
-    # Create a mock conversation history
-    mock_history = [
-        models.Message(sender="user", content="What is the capital of France?"),
-        models.Message(sender="assistant", content="The capital of France is Paris.")
-    ]
-
-    pipeline = DspyRagPipeline(retrievers=[mock_retriever])
-    question = "What is its population?"
-
-    # --- Act ---
-    response = asyncio.run(pipeline.forward(question=question, history=mock_history, db=mock_db))
-
-    # --- Assert ---
-    mock_retriever.retrieve_context.assert_called_once_with(question, mock_db)
-
-    # Assert that the final prompt includes all parts: context, history, and the new question
-    instruction = AnswerWithHistory.__doc__
-    expected_history_str = "Human: What is the capital of France?\nAssistant: The capital of France is Paris."
-
-    expected_prompt = (
-        f"{instruction}\n\n"
-        f"---\n\n"
-        f"Context:\nContext chunk 1.\n\n"
-        f"---\n\n"
-        f"Chat History:\n{expected_history_str}\n\n"
-        f"---\n\n"
-        f"Human: {question}\n"
-        f"Assistant:"
-    )
-
-    mock_lm_configured.aforward.assert_called_once_with(prompt=expected_prompt)
-    assert response == "Mocked LLM answer"
-
-def test_pipeline_with_no_context_or_history(mock_lm_configured):
-    """
-    Tests the pipeline's prompt construction for a new conversation where no
-    relevant documents are found.
-    """
-    # --- Arrange ---
-    mock_retriever = MagicMock(spec=Retriever)
-    mock_retriever.retrieve_context.return_value = []  # No context found
-    mock_db = MagicMock(spec=Session)
-
-    pipeline = DspyRagPipeline(retrievers=[mock_retriever])
-    question = "First question"
-    empty_history = []
-
-    # --- Act ---
-    asyncio.run(pipeline.forward(question=question, history=empty_history, db=mock_db))
-
-    # --- Assert ---
-    # Check that the prompt was constructed with placeholder context and empty history
-    instruction = AnswerWithHistory.__doc__
-    expected_prompt = (
-        f"{instruction}\n\n"
-        f"---\n\n"
-        f"Context:\nNo context provided.\n\n"
-        f"---\n\n"
-        f"Chat History:\n\n\n"  # History string is empty
-        f"---\n\n"
-        f"Human: {question}\n"
-        f"Assistant:"
-    )
-    mock_lm_configured.aforward.assert_called_once_with(prompt=expected_prompt)
+
+# --- Mock Classes ---
+
+class MockRetriever(Retriever):
+    """A mock retriever that returns a predefined list of strings."""
+    def __init__(self, name: str, mock_data: List[str]):
+        self.name = name
+        self.mock_data = mock_data
+
+    def retrieve_context(self, question: str, db: Session) -> List[str]:
+        return self.mock_data
+
+# --- Fixtures ---
+
+@pytest.fixture
+def mock_db():
+    """A mock SQLAlchemy session."""
+    return MagicMock()
+
+@pytest.fixture(autouse=True)
+def mock_dspy_predict_instance(mocker):
+    """
+    Mocks the dspy.Predict class itself to return a mock instance
+    with a controllable aforward method.
+    """
+    # Create a mock instance of dspy.Predict
+    mock_predict_instance = MagicMock()
+    mock_predict_instance.aforward = AsyncMock(return_value=MagicMock(answer="Mocked LLM answer."))
+
+    # Patch the dspy.Predict class to return our mock instance
+    mocker.patch('dspy.Predict', return_value=mock_predict_instance)
+
+    return mock_predict_instance
+
+
+# --- Test Cases ---
+
+# def test_pipeline_initializes_with_defaults(mock_dspy_predict_instance):
+#     """Test that the pipeline initializes correctly with default processors."""
+#     pipeline = DspyRagPipeline(retrievers=[MockRetriever("test", ["test context"])])
+#     assert pipeline.retrievers is not None
+#     assert pipeline.context_postprocessor is pipeline._default_context_postprocessor
+#     assert pipeline.history_formatter is pipeline._default_history_formatter
+#     assert pipeline.response_postprocessor is None
+#     # Verify that dspy.Predict was instantiated once
+#     dspy.Predict.assert_called_once()
+#     # Verify the aforward method was not called yet
+#     mock_dspy_predict_instance.aforward.assert_not_called()
+
+@pytest.mark.asyncio
+async def test_forward_pass_with_defaults(mock_db, mock_dspy_predict_instance):
+    """Test a successful forward pass using default processors."""
+    retriever = MockRetriever("test_retriever", ["Context 1.", "Context 2."])
+    pipeline = DspyRagPipeline(retrievers=[retriever])
+
+    question = "What is the capital of France?"
+    history = [models.Message(sender="user", content="Hello there."), models.Message(sender="assistant", content="Hi.")]
+
+    response = await pipeline.forward(question, history, mock_db)
+
+    expected_context = "Context 1.\n\nContext 2."
+    expected_history = "Human: Hello there.\nAssistant: Hi."
+
+    mock_dspy_predict_instance.aforward.assert_called_once_with(
+        context=expected_context,
+        chat_history=expected_history,
+        question=question
+    )
+    assert response == "Mocked LLM answer."
+
+@pytest.mark.asyncio
+async def test_forward_with_custom_processors(mock_db, mock_dspy_predict_instance):
+    """Test that custom processors are used correctly."""
+
+    def custom_context_processor(contexts: List[str]) -> str:
+        return "CUSTOM_CONTEXT: " + " | ".join(contexts)
+
+    def custom_history_formatter(history: List[models.Message]) -> str:
+        return " " + " ".join([m.content for m in history]) + " "
+
+    def custom_response_processor(response: str) -> str:
+        return f"FINAL: {response.upper()}"
+
+    retriever = MockRetriever("test_retriever", ["Context A", "Context B"])
+    pipeline = DspyRagPipeline(
+        retrievers=[retriever],
+        context_postprocessor=custom_context_processor,
+        history_formatter=custom_history_formatter,
+        response_postprocessor=custom_response_processor
+    )
+
+    question = "Custom question?"
+    history = [models.Message(sender="user", content="User message.")]
+
+    response = await pipeline.forward(question, history, mock_db)
+
+    mock_dspy_predict_instance.aforward.assert_called_once_with(
+        context="CUSTOM_CONTEXT: Context A | Context B",
+        chat_history=" User message. ",
+        question="Custom question?"
+    )
+    assert response == "FINAL: MOCKED LLM ANSWER."
+
+@pytest.mark.asyncio
+async def test_empty_context_and_history_handling(mock_db, mock_dspy_predict_instance):
+    """Test behavior with empty context and chat history."""
+    retriever = MockRetriever("empty_retriever", [])
+    pipeline = DspyRagPipeline(retrievers=[retriever])
+
+    question = "No context question."
+    history = []
+
+    response = await pipeline.forward(question, history, mock_db)
+
+    mock_dspy_predict_instance.aforward.assert_called_once_with(
+        context="No context provided.",
+        chat_history="",
+        question=question
+    )
+    assert response == "Mocked LLM answer."
\ No newline at end of file
diff --git a/ai-hub/tests/core/providers/llm/test_llm_general.py b/ai-hub/tests/core/providers/llm/test_llm_general.py
new file mode 100644
index 0000000..978a4dd
--- /dev/null
+++ b/ai-hub/tests/core/providers/llm/test_llm_general.py
@@ -0,0 +1,73 @@
+import unittest
+import asyncio
+from unittest.mock import patch, AsyncMock, MagicMock
+
+# Import the class to be tested
+from app.core.providers.llm.general import GeneralProvider
+
+# A mock class to simulate the LiteLLM completion response object
+# It mimics the nested structure: response.choices[0].message.content
+class MockResponse:
+    def __init__(self, content):
+        self.choices = [self.MockChoice(content)]
+
+    class MockChoice:
+        def __init__(self, content):
+            self.message = self.MockMessage(content)
+
+        class MockMessage:
+            def __init__(self, content):
+                self.content = content
+
+class TestGeneralProvider(unittest.TestCase):
+    def setUp(self):
+        """Set up a new GeneralProvider instance before each test."""
+        self.model_name = "gemini/gemini-1.5-pro-latest"
+        self.api_key = "test-api-key"
+        self.provider = GeneralProvider(self.model_name, self.api_key)
+        self.prompt = "What is the capital of France?"
+
+    @patch('litellm.acompletion', new_callable=AsyncMock)
+    def test_generate_response_success(self, mock_litellm_completion: AsyncMock):
+        """
+        Test that aforward returns the correct content on a successful call.
+        """
+        async def run_test():
+            # Arrange: Configure the mock to return a successful response
+            expected_content = "The capital of France is Paris."
+            mock_litellm_completion.return_value = MockResponse(expected_content)
+
+            # Act: Call the async method
+            response_obj = await self.provider.aforward(prompt=self.prompt)
+            response = response_obj.choices[0].message.content
+
+            # Assert: Check the returned value and the mock's call arguments
+            self.assertEqual(response, expected_content)
+
+            # Assert that the mock was called with the correct parameters, including defaults from BaseLM
+            mock_litellm_completion.assert_called_once_with(
+                model=self.model_name,
+                messages=[{"role": "user", "content": self.prompt}],
+                api_key=self.api_key,
+                temperature=0.0,  # default injected by BaseLM
+                max_tokens=1000   # default injected by BaseLM
+            )
+
+        # Run the async test function
+        asyncio.run(run_test())
+
+    @patch('litellm.acompletion', new_callable=AsyncMock)
+    def test_generate_response_failure(self, mock_litellm_completion: AsyncMock):
+        """
+        Test that aforward raises a RuntimeError on an API failure.
+        """
+        async def run_test():
+            # Arrange: Configure the mock to raise an exception
+            mock_litellm_completion.side_effect = Exception("API connection failed.")
+
+            # Act & Assert: Use assertRaises to check for the RuntimeError
+            with self.assertRaisesRegex(RuntimeError, f"Failed to get response from LiteLLM for model '{self.model_name}': API connection failed."):
+                await self.provider.aforward(prompt=self.prompt)
+
+        # Run the async test function
+        asyncio.run(run_test())
diff --git a/ai-hub/tests/core/providers/llm/test_llm_providers.py b/ai-hub/tests/core/providers/llm/test_llm_providers.py
index 38ef108..cef3808 100644
--- a/ai-hub/tests/core/providers/llm/test_llm_providers.py
+++ b/ai-hub/tests/core/providers/llm/test_llm_providers.py
@@ -1,81 +1,81 @@
-import pytest
-from unittest.mock import AsyncMock, patch, MagicMock
+# import pytest
+# from unittest.mock import AsyncMock, patch, MagicMock
 
-from app.core.providers.base import LLMProvider
-from app.core.providers.llm.deepseek import DeepSeekProvider
-from app.core.providers.llm.gemini import GeminiProvider
-from openai import OpenAI
-import httpx
+# from app.core.providers.base import LLMProvider
+# from app.core.providers.llm.deepseek import DeepSeekProvider
+# from app.core.providers.llm.gemini import GeminiProvider
+# from openai import OpenAI
+# import httpx
 
-# --- Fixtures ---
+# # --- Fixtures ---
 
-@pytest.fixture
-def mock_deepseek_client():
-    """Provides a mock OpenAI client with a mocked chat.completions.create method."""
-    mock_client = MagicMock(spec=OpenAI)
+# @pytest.fixture
+# def mock_deepseek_client():
+#     """Provides a mock OpenAI client with a mocked chat.completions.create method."""
+#     mock_client = MagicMock(spec=OpenAI)
 
-    mock_response = MagicMock()
-    mock_response.choices = [MagicMock()]
-    mock_response.choices[0].message.content = "This is a mocked DeepSeek response."
+#     mock_response = MagicMock()
+#     mock_response.choices = [MagicMock()]
+#     mock_response.choices[0].message.content = "This is a mocked DeepSeek response."
 
-    mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
+#     mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
 
-    return mock_client
+#     return mock_client
 
-# --- DeepSeek Provider Tests ---
+# # --- DeepSeek Provider Tests ---
 
-@pytest.mark.asyncio
-async def test_deepseek_provider_generates_response(mock_deepseek_client):
-    """Tests that DeepSeekProvider returns the expected response."""
-    provider = DeepSeekProvider(model_name="deepseek-chat", client=mock_deepseek_client)
+# @pytest.mark.asyncio
+# async def test_deepseek_provider_generates_response(mock_deepseek_client):
+#     """Tests that DeepSeekProvider returns the expected response."""
+#     provider = DeepSeekProvider(model_name="deepseek-chat", client=mock_deepseek_client)
 
-    assert isinstance(provider, LLMProvider)
+#     assert isinstance(provider, LLMProvider)
 
-    response = await provider.generate_response("Test prompt")
+#     response = await provider.generate_response("Test prompt")
 
-    mock_deepseek_client.chat.completions.create.assert_awaited_with(
-        model="deepseek-chat",
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "Test prompt"}
-        ]
-    )
-    assert response == "This is a mocked DeepSeek response."
+#     mock_deepseek_client.chat.completions.create.assert_awaited_with(
+#         model="deepseek-chat",
+#         messages=[
+#             {"role": "system", "content": "You are a helpful assistant."},
+#             {"role": "user", "content": "Test prompt"}
+#         ]
+#     )
+#     assert response == "This is a mocked DeepSeek response."
 
-# --- Gemini Provider Tests ---
-@pytest.mark.asyncio
-async def test_gemini_provider_generates_response(mocker):
-    """Tests that GeminiProvider returns the expected response."""
-    # Mock the response object from the post call
-    mock_response = MagicMock()
-    mock_response.raise_for_status.return_value = None
-    mock_response.json = MagicMock(return_value={
-        'candidates': [{'content': {'parts': [{'text': 'This is a mocked Gemini response.'}]}}]
-    })
+# # --- Gemini Provider Tests ---
+# @pytest.mark.asyncio
+# async def test_gemini_provider_generates_response(mocker):
+#     """Tests that GeminiProvider returns the expected response."""
+#     # Mock the response object from the post call
+#     mock_response = MagicMock()
+#     mock_response.raise_for_status.return_value = None
+#     mock_response.json = MagicMock(return_value={
+#         'candidates': [{'content': {'parts': [{'text': 'This is a mocked Gemini response.'}]}}]
+#     })
 
-    # Create a mock for the AsyncClient instance
-    mock_client_instance = MagicMock(spec=httpx.AsyncClient)
-    mock_client_instance.post = AsyncMock(return_value=mock_response)
+#     # Create a mock for the AsyncClient instance
+#     mock_client_instance = MagicMock(spec=httpx.AsyncClient)
+#     mock_client_instance.post = AsyncMock(return_value=mock_response)
 
-    # Create a mock for the AsyncClient class itself
-    mock_async_client_class = MagicMock(spec=httpx.AsyncClient)
+#     # Create a mock for the AsyncClient class itself
+#     mock_async_client_class = MagicMock(spec=httpx.AsyncClient)
 
-    # Mock the behavior of `async with`
-    mock_async_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance)
-    mock_async_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
+#     # Mock the behavior of `async with`
+#     mock_async_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance)
+#     mock_async_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
 
-    # Patch the AsyncClient class
-    mocker.patch('httpx.AsyncClient', new=mock_async_client_class)
+#     # Patch the AsyncClient class
+#     mocker.patch('httpx.AsyncClient', new=mock_async_client_class)
 
-    provider = GeminiProvider(api_url="http://mock-gemini.com")
-    assert isinstance(provider, LLMProvider)
+#     provider = GeminiProvider(api_url="http://mock-gemini.com")
+#     assert isinstance(provider, LLMProvider)
 
-    response = await provider.generate_response("Test prompt")
+#     response = await provider.generate_response("Test prompt")
 
-    mock_client_instance.post.assert_awaited_with(
-        "http://mock-gemini.com",
-        json={"contents": [{"parts": [{"text": "Test prompt"}]}]},
-        headers={"Content-Type": "application/json"}
-    )
+#     mock_client_instance.post.assert_awaited_with(
+#         "http://mock-gemini.com",
+#         json={"contents": [{"parts": [{"text": "Test prompt"}]}]},
+#         headers={"Content-Type": "application/json"}
+#     )
 
-    assert response == "This is a mocked Gemini response."
+#     assert response == "This is a mocked Gemini response."
diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py
index e00e71b..6caf8d8 100644
--- a/ai-hub/tests/core/providers/test_factory.py
+++ b/ai-hub/tests/core/providers/test_factory.py
@@ -1,21 +1,20 @@
 import pytest
 from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider
-from app.core.providers.llm.deepseek import DeepSeekProvider
-from app.core.providers.llm.gemini import GeminiProvider
+from app.core.providers.llm.general import GeneralProvider
 from app.core.providers.tts.gemini import GeminiTTSProvider
 from app.core.providers.stt.gemini import GoogleSTTProvider
 
 # --- Existing Tests for LLM Provider ---
 
 def test_get_llm_provider_returns_deepseek_provider():
-    """Tests that the factory returns a DeepSeekProvider instance."""
+    """Tests that the factory returns a GeneralProvider instance for 'deepseek'."""
     provider = get_llm_provider("deepseek")
-    assert isinstance(provider, DeepSeekProvider)
+    assert isinstance(provider, GeneralProvider)
 
 def test_get_llm_provider_returns_gemini_provider():
-    """Tests that the factory returns a GeminiProvider instance."""
+    """Tests that the factory returns a GeneralProvider instance for 'gemini'."""
     provider = get_llm_provider("gemini")
-    assert isinstance(provider, GeminiProvider)
+    assert isinstance(provider, GeneralProvider)
 
 def test_get_llm_provider_raises_error_for_unsupported_provider():
     """Tests that the factory raises an error for an unsupported provider name."""
diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py
index 06a133f..84ee258 100644
--- a/ai-hub/tests/core/services/test_rag.py
+++ b/ai-hub/tests/core/services/test_rag.py
@@ -39,13 +39,13 @@
     """Tests that the create_session method correctly creates a new session."""
     mock_db = MagicMock(spec=Session)
 
-    rag_service.create_session(db=mock_db, user_id="test_user", model="gemini")
+    rag_service.create_session(db=mock_db, user_id="test_user", provider_name="gemini")
 
     mock_db.add.assert_called_once()
     added_object = mock_db.add.call_args[0][0]
     assert isinstance(added_object, models.Session)
     assert added_object.user_id == "test_user"
-    assert added_object.model_name == "gemini"
+    assert added_object.provider_name == "gemini"
 
 @patch('app.core.services.rag.get_llm_provider')
 @patch('app.core.services.rag.DspyRagPipeline')
@@ -57,7 +57,7 @@
     """
     # --- Arrange ---
     mock_db = MagicMock(spec=Session)
-    mock_session = models.Session(id=42, model_name="deepseek", messages=[])
+    mock_session = models.Session(id=42, provider_name="deepseek", messages=[])
     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
 
     mock_llm_provider = MagicMock(spec=LLMProvider)
@@ -67,12 +67,12 @@
     mock_dspy_pipeline.return_value = mock_pipeline_instance
 
     # --- Act ---
-    answer, model_name = asyncio.run(
+    answer, provider_name = asyncio.run(
         rag_service.chat_with_rag(
             db=mock_db,
             session_id=42,
             prompt="Test prompt",
-            model="deepseek",
+            provider_name="deepseek",
             load_faiss_retriever=False
         )
     )
@@ -91,7 +91,7 @@
     )
 
     assert answer == "Final RAG response"
-    assert model_name == "deepseek"
+    assert provider_name == "deepseek"
 
 def test_chat_with_rag_model_switch(rag_service: RAGService):
     """
@@ -100,7 +100,7 @@
     """
     # --- Arrange ---
     mock_db = MagicMock(spec=Session)
-    mock_session = models.Session(id=43, model_name="deepseek", messages=[])
+    mock_session = models.Session(id=43, provider_name="deepseek", messages=[])
     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
 
     with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
@@ -114,12 +114,12 @@
         mock_dspy_pipeline.return_value = mock_pipeline_instance
 
         # --- Act ---
-        answer, model_name = asyncio.run(
+        answer, provider_name = asyncio.run(
             rag_service.chat_with_rag(
                 db=mock_db,
                 session_id=43,
                 prompt="Test prompt for Gemini",
-                model="gemini",
+                provider_name="gemini",
                 load_faiss_retriever=False
             )
         )
@@ -138,7 +138,7 @@
         )
 
         assert answer == "Final RAG response from Gemini"
-        assert model_name == "gemini"
+        assert provider_name == "gemini"
 
 def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService):
@@ -148,7 +148,7 @@
     """
     # --- Arrange ---
     mock_db = MagicMock(spec=Session)
-    mock_session = models.Session(id=44, model_name="deepseek", messages=[])
+    mock_session = models.Session(id=44, provider_name="deepseek", messages=[])
     mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session
 
     with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \
@@ -162,12 +162,12 @@
         mock_dspy_pipeline.return_value = mock_pipeline_instance
 
         # --- Act ---
-        answer, model_name = asyncio.run(
+        answer, provider_name = asyncio.run(
             rag_service.chat_with_rag(
                 db=mock_db,
                 session_id=44,
                 prompt="Test prompt with FAISS",
-                model="deepseek",
+                provider_name="deepseek",
                 load_faiss_retriever=True
             )
         )
@@ -183,7 +183,7 @@
     )
 
     assert answer == "Response with FAISS context"
-    assert model_name == "deepseek"
+    assert provider_name == "deepseek"
 
 def test_get_message_history_success(rag_service: RAGService):
diff --git a/ai-hub/tests/db/test_models.py b/ai-hub/tests/db/test_models.py
index c2e928b..20173e9 100644
--- a/ai-hub/tests/db/test_models.py
+++ b/ai-hub/tests/db/test_models.py
@@ -66,7 +66,7 @@
     Tests the creation and retrieval of a Session object.
     """
     # Create a new session object
-    new_session = Session(user_id="test-user-123", title="Test Session", model_name="gemini")
+    new_session = Session(user_id="test-user-123", title="Test Session", provider_name="gemini")
 
     # Add to session and commit to the database
     db_session.add(new_session)
@@ -80,7 +80,7 @@
     assert retrieved_session is not None
     assert retrieved_session.user_id == "test-user-123"
     assert retrieved_session.title == "Test Session"
-    assert retrieved_session.model_name == "gemini"
+    assert retrieved_session.provider_name == "gemini"
 
 def test_create_message_with_session_relationship(db_session):
diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py
index c0bb4b4..1870907 100644
--- a/ai-hub/tests/test_app.py
+++ b/ai-hub/tests/test_app.py
@@ -73,7 +73,7 @@
     mock_session_obj = models.Session(
         id=1,
         user_id="test_user",
-        model_name="gemini",
+        provider_name="gemini",
         title="New Chat Session",
         created_at=datetime.now()
     )
@@ -84,7 +84,7 @@
     client = TestClient(app)
 
     # Act
-    response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
+    response = client.post("/sessions", json={"user_id": "test_user", "provider_name": "gemini"})
 
     # Assert
     assert response.status_code == 200
@@ -92,7 +92,7 @@
     assert response_data["id"] == 1
     assert response_data["user_id"] == "test_user"
     mock_services.rag_service.create_session.assert_called_once_with(
-        db=mock_db, user_id="test_user", model="gemini"
+        db=mock_db, user_id="test_user", provider_name="gemini"
     )
 
 @patch('app.app.ServiceContainer')
@@ -126,9 +126,9 @@
     # Assert
     assert response.status_code == 200
     assert response.json()["answer"] == "This is a mock response."
-    assert response.json()["model_used"] == "deepseek"
+    assert response.json()["provider_used"] == "deepseek"
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
-        db=mock_db, session_id=123, prompt="Hello there", model="deepseek", load_faiss_retriever=False
+        db=mock_db, session_id=123, prompt="Hello there", provider_name="deepseek", load_faiss_retriever=False
     )
 
 @patch('app.app.ServiceContainer')
@@ -155,16 +155,16 @@
     app.dependency_overrides[get_db] = override_get_db
     client = TestClient(app)
 
-    response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
+    response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "provider_name": "gemini"})
 
     # Assert
     assert response.status_code == 200
-    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
+    assert response.json() == {"answer": "Mocked response from Gemini", "provider_used": "gemini"}
 
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_db,
         session_id=42,
         prompt="Hello there, Gemini!",
-        model="gemini",
+        provider_name="gemini",
         load_faiss_retriever=False
     )
diff --git a/ui/client-app/src/components/Controls.js b/ui/client-app/src/components/Controls.js
index 7400067..23d960b 100644
--- a/ui/client-app/src/components/Controls.js
+++ b/ui/client-app/src/components/Controls.js
@@ -1,6 +1,6 @@
 // src/components/Controls.js
 import React from "react";
-import { FaMicrophone, FaRegStopCircle, FaCog } from "react-icons/fa";
+import { FaMicrophone, FaRegStopCircle } from "react-icons/fa";
 
 const Controls = ({
   status,
@@ -37,7 +37,7 @@
-
diff --git a/ui/client-app/src/services/apiService.js b/ui/client-app/src/services/apiService.js
index ed4c4ec..315acbf 100644
--- a/ui/client-app/src/services/apiService.js
+++ b/ui/client-app/src/services/apiService.js
@@ -57,7 +57,7 @@
   const response = await fetch(SESSIONS_CHAT_ENDPOINT(sessionId), {
     method: "POST",
     headers: { "Content-Type": "application/json" },
-    body: JSON.stringify({ prompt: prompt, model: "gemini" }),
+    body: JSON.stringify({ prompt: prompt, provider_name: "gemini" }),
   });
   if (!response.ok) {
     throw new Error("LLM API failed");