diff --git a/.gitignore b/.gitignore
index 96a9867..176d2c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@
**.bin
**.db
ai-hub/data/*
+.vscode/
diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py
index 2515147..5ff6fb6 100644
--- a/ai-hub/app/api/routes/sessions.py
+++ b/ai-hub/app/api/routes/sessions.py
@@ -16,7 +16,7 @@
new_session = services.rag_service.create_session(
db=db,
user_id=request.user_id,
- model=request.model
+ provider_name=request.provider_name
)
return new_session
except Exception as e:
@@ -29,14 +29,14 @@
db: Session = Depends(get_db)
):
try:
- response_text, model_used = await services.rag_service.chat_with_rag(
+ response_text, provider_used = await services.rag_service.chat_with_rag(
db=db,
session_id=session_id,
prompt=request.prompt,
- model=request.model,
+ provider_name=request.provider_name,
load_faiss_retriever=request.load_faiss_retriever
)
- return schemas.ChatResponse(answer=response_text, model_used=model_used)
+ return schemas.ChatResponse(answer=response_text, provider_used=provider_used)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index a1eff32..9014400 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -7,7 +7,7 @@
"""Defines the shape of a request to the /chat endpoint."""
prompt: str = Field(..., min_length=1)
# The 'model' can now be specified in the request body to switch models mid-conversation.
- model: Literal["deepseek", "gemini"] = Field("deepseek")
+ provider_name: Literal["deepseek", "gemini"] = Field("deepseek")
# Add a new optional boolean field to control the retriever
load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.")
@@ -15,7 +15,7 @@
class ChatResponse(BaseModel):
"""Defines the shape of a successful response from the /chat endpoint."""
answer: str
- model_used: str
+ provider_used: str
# --- Document Schemas ---
class DocumentCreate(BaseModel):
@@ -47,14 +47,14 @@
class SessionCreate(BaseModel):
"""Defines the shape for starting a new conversation session."""
user_id: str
- model: Literal["deepseek", "gemini"] = "deepseek"
+ provider_name: Literal["deepseek", "gemini"] = "deepseek"
class Session(BaseModel):
"""Defines the shape of a session object returned by the API."""
id: int
user_id: str
title: str
- model_name: str
+ provider_name: str
created_at: datetime
model_config = ConfigDict(from_attributes=True)
diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py
index 7e2f243..22ef056 100644
--- a/ai-hub/app/core/pipelines/dspy_rag.py
+++ b/ai-hub/app/core/pipelines/dspy_rag.py
@@ -1,28 +1,13 @@
import dspy
import logging
from typing import List, Callable, Optional
-from types import SimpleNamespace
from sqlalchemy.orm import Session
from app.db import models
from app.core.retrievers.base_retriever import Retriever
-from app.core.providers.base import LLMProvider
-class DSPyLLMProvider(dspy.BaseLM):
- def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
- super().__init__(model=model_name)
- self.provider = provider
- self.kwargs.update(kwargs)
-
- async def aforward(self, prompt: str, **kwargs):
- if not prompt or not prompt.strip():
- return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))])
- response_text = await self.provider.generate_response(prompt)
- choice = SimpleNamespace(message=SimpleNamespace(content=response_text))
- return SimpleNamespace(choices=[choice])
-
-
+# --- DSPy Signature ---
class AnswerWithHistory(dspy.Signature):
"""Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history."""
@@ -32,6 +17,7 @@
answer = dspy.OutputField(desc="A well-formed answer suitable for delivery in an audio play format.")
+# --- DSPy RAG Pipeline ---
class DspyRagPipeline(dspy.Module):
"""
A flexible and extensible DSPy-based RAG pipeline with modular stages.
@@ -66,21 +52,18 @@
# Step 2: Format history
history_text = self.history_formatter(history)
- # Step 3: Build final prompt
- instruction = self.generate_answer.signature.__doc__
- full_prompt = self._build_prompt(instruction, context_text, history_text, question)
+ # Step 3: Generate response using LLM
+        # With DSPy and LiteLLM, signature-based generation builds the prompt,
+        # so the prompt string no longer needs to be assembled manually.
+ prediction = await self.generate_answer.aforward(
+ context=context_text,
+ chat_history=history_text,
+ question=question
+ )
- logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}")
+ raw_response = prediction.answer
- # Step 4: Generate response using LLM
- lm = dspy.settings.lm
- if lm is None:
- raise RuntimeError("DSPy LM not configured.")
-
- response_obj = await lm.aforward(prompt=full_prompt)
- raw_response = response_obj.choices[0].message.content
-
- # Step 5: Optional response postprocessing
+ # Step 4: Optional response postprocessing
if self.response_postprocessor:
return self.response_postprocessor(raw_response)
@@ -97,15 +80,4 @@
for msg in history
)
- # Prompt builder
- def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str:
- return (
- f"{instruction.strip()}\n\n"
- f"---\n\n"
- f"Context:\n{context.strip()}\n\n"
- f"---\n\n"
- f"Chat History:\n{history.strip()}\n\n"
- f"---\n\n"
- f"Human: {question.strip()}\n"
- f"Assistant:"
- )
+# Note: The _build_prompt method is removed as DSPy handles this automatically.
\ No newline at end of file
diff --git a/ai-hub/app/core/pipelines/file_selector_rag.py b/ai-hub/app/core/pipelines/file_selector_rag.py
new file mode 100644
index 0000000..fb28ff6
--- /dev/null
+++ b/ai-hub/app/core/pipelines/file_selector_rag.py
@@ -0,0 +1,42 @@
+import dspy
+from typing import List
+from app.db import models
+
+# DSPy signature and module for selecting the files most relevant to a user question.
+
+class SelectFiles(dspy.Signature):
+ """
+ Based on the user's question, communication history, and the code folder's file list, identify the files that are most relevant to answer the question.
+ """
+ question = dspy.InputField(desc="The user's current question.")
+ chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI.")
+ code_folder_filename_list = dspy.InputField(desc="A list of file names as strings, representing the file structure of the code base.")
+ answer = dspy.OutputField(format=list, desc="A list of strings containing the names of the most relevant files to examine further.")
+
+class CodeRagFileSelector(dspy.Module):
+ """
+ A single-step module to select relevant files from a list based on a user question.
+ """
+ def __init__(self):
+ super().__init__()
+ self.select_files = dspy.Predict(SelectFiles)
+
+ async def forward(self, question: str, history: List[models.Message], file_list: List[str]) -> List[str]:
+ # Format history for the signature
+ history_text = self._default_history_formatter(history)
+
+ # Call the predictor with the necessary inputs
+        prediction = await self.select_files.acall(
+ question=question,
+ chat_history=history_text,
+ code_folder_filename_list="\n".join(file_list)
+ )
+
+ # The output is expected to be a list of strings
+ return prediction.answer
+
+ def _default_history_formatter(self, history: List[models.Message]) -> str:
+ return "\n".join(
+ f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
+ for msg in history
+ )
\ No newline at end of file
diff --git a/ai-hub/app/core/pipelines/file_selector_rag_test.py b/ai-hub/app/core/pipelines/file_selector_rag_test.py
new file mode 100644
index 0000000..c8f4c5e
--- /dev/null
+++ b/ai-hub/app/core/pipelines/file_selector_rag_test.py
@@ -0,0 +1,85 @@
+import dspy
+import asyncio
+import json
+import os
+from typing import List
+from app.core.providers.factory import get_llm_provider
+
+# Minimal stand-in for the Message model, used to build a fake chat history.
+class MockMessage:
+ def __init__(self, sender: str, content: str):
+ self.sender = sender
+ self.content = content
+
+# --- DSPy components under test (mirrors file_selector_rag.py) ---
+class SelectFiles(dspy.Signature):
+ """
+ Based on the user's question, communication history, and the code folder's file list, identify the files that are most relevant to answer the question.
+ """
+ question = dspy.InputField(desc="The user's current question.")
+ chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI.")
+ code_folder_filename_list = dspy.InputField(desc="A list of file names as strings, representing the file structure of the code base.")
+ answer = dspy.OutputField(format=list, desc="A list of strings containing the names of the most relevant files to examine further.")
+
+class CodeRagFileSelector(dspy.Module):
+ """
+ A single-step module to select relevant files from a list based on a user question.
+ """
+ def __init__(self):
+ super().__init__()
+ self.select_files = dspy.Predict(SelectFiles)
+
+ async def forward(self, question: str, history: List[MockMessage], file_list: List[str]) -> List[str]:
+ # Format history for the signature
+ history_text = self._default_history_formatter(history)
+
+ # Call the predictor with the necessary inputs
+ prediction = await self.select_files.acall(
+ question=question,
+ chat_history=history_text,
+ code_folder_filename_list="\n".join(file_list)
+ )
+
+ # The output is expected to be a list of strings
+ # The DSPy Predict method automatically handles the `format=list` and parses the JSON.
+ return prediction.answer
+
+ def _default_history_formatter(self, history: List[MockMessage]) -> str:
+ return "\n".join(
+ f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
+ for msg in history
+ )
+
+# --- Test driver ---
+async def main():
+    dspy.settings.configure(lm=get_llm_provider("gemini", "gemini-1.5-flash-latest"))
+
+ # Instantiate the module
+ file_selector = CodeRagFileSelector()
+
+ # Define sample data
+ question = "How is the `data` variable initialized in the main application? Also, where are the tests?"
+ history = [
+ MockMessage(sender="user", content="What does the main script do?"),
+ MockMessage(sender="assistant", content="The main script handles data processing and initializes the primary variables.")
+ ]
+ file_list = [
+ "main.py",
+ "utils.py",
+ "README.md",
+ "config.yaml",
+ "tests/test_main.py",
+ "src/data_handler.py"
+ ]
+
+ print("Running the file selector with Gemini...")
+    selected_files = await file_selector.forward(question, history, file_list)
+
+ print("\n--- Test Results ---")
+ print(f"Question: {question}")
+ print(f"All files: {file_list}")
+ print(f"Selected files: {selected_files}")
+ print("--------------------")
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/ai-hub/app/core/pipelines/test_gemini.py b/ai-hub/app/core/pipelines/test_gemini.py
new file mode 100644
index 0000000..0060b9b
--- /dev/null
+++ b/ai-hub/app/core/pipelines/test_gemini.py
@@ -0,0 +1,10 @@
+import dspy
+
+lm = dspy.LM('gemini/gemini-2.0-flash')
+dspy.configure(lm=lm)
+
+# Send a simple prompt
+response = lm(prompt="hello world")
+
+# Print the response text
+print(response)
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 725f871..4c2664e 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,29 +1,43 @@
from app.config import settings
-from .base import LLMProvider, TTSProvider, STTProvider
-from .llm.deepseek import DeepSeekProvider
-from .llm.gemini import GeminiProvider
+from .base import TTSProvider, STTProvider
+from .llm.general import GeneralProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
+from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI
-# --- 1. Initialize API Clients from Central Config ---
-deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
-GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
+import litellm
-# --- 2. The Factory Dictionaries ---
+
+# --- 1. Initialize API Clients from Central Config ---
+# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
+
+# --- 2. The Factory Dictionaries ---
_llm_providers = {
- "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME, client=deepseek_client),
- "gemini": GeminiProvider(api_url=GEMINI_URL)
+ "deepseek": settings.DEEPSEEK_API_KEY,
+ "gemini": settings.GEMINI_API_KEY
+}
+
+_llm_models = {
+ "deepseek": settings.DEEPSEEK_MODEL_NAME,
+ "gemini": settings.GEMINI_MODEL_NAME
}
# --- 3. The Factory Functions ---
-def get_llm_provider(model_name: str) -> LLMProvider:
+def get_llm_provider(provider_name: str, model_name: str = "") -> BaseLM:
"""Factory function to get the appropriate, pre-configured LLM provider."""
- provider = _llm_providers.get(model_name)
- if not provider:
- raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}")
- return provider
+    api_key = _llm_providers.get(provider_name)
+    if not api_key:
+        raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {list(_llm_providers.keys())}")
+
+    resolved_model_name = model_name or _llm_models.get(provider_name)
+    if not resolved_model_name:
+        raise ValueError(f"No default model configured for provider: '{provider_name}'.")
+
+    return GeneralProvider(model_name=f"{provider_name}/{resolved_model_name}", api_key=api_key)
def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str) -> TTSProvider:
if provider_name == "google_gemini":
@@ -35,4 +49,6 @@
def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
if provider_name == "google_gemini":
return GoogleSTTProvider(api_key=api_key, model_name=model_name)
- raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
+ raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
+
+# async def lite_llm_call(model_name: str, prompt: str) -> str:
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/llm/README.md b/ai-hub/app/core/providers/llm/README.md
new file mode 100644
index 0000000..34f12a8
--- /dev/null
+++ b/ai-hub/app/core/providers/llm/README.md
@@ -0,0 +1,8 @@
+Provider-specific LLM implementations (such as `DeepSeekProvider` and `GeminiProvider`) are no longer needed in this package: the setup has been migrated to LiteLLM, which lets the application handle requests for all providers through a single, unified syntax. This simplifies the codebase and makes it easier to add new models.
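+
+As a rough sketch of that unified syntax (illustrative only; the model strings and API keys below are placeholders, not values from this repo's settings):
+
+```python
+import litellm
+
+# The same call works for every provider; only the "provider/model" string
+# and the API key change.
+deepseek_reply = litellm.completion(
+    model="deepseek/deepseek-chat",      # placeholder model name
+    messages=[{"role": "user", "content": "hello"}],
+    api_key="YOUR_DEEPSEEK_API_KEY",     # placeholder key
+)
+gemini_reply = litellm.completion(
+    model="gemini/gemini-2.0-flash",     # placeholder model name
+    messages=[{"role": "user", "content": "hello"}],
+    api_key="YOUR_GEMINI_API_KEY",       # placeholder key
+)
+print(deepseek_reply.choices[0].message.content)
+print(gemini_reply.choices[0].message.content)
+```
+
+In this codebase the call is wrapped by `GeneralProvider` (`app/core/providers/llm/general.py`), and `get_llm_provider()` in `app/core/providers/factory.py` builds the `"provider/model"` string from settings.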
diff --git a/ai-hub/app/core/providers/llm/deepseek.py b/ai-hub/app/core/providers/llm/deepseek.py
deleted file mode 100644
index 2fde32c..0000000
--- a/ai-hub/app/core/providers/llm/deepseek.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import logging
-from openai import AsyncOpenAI # Use AsyncOpenAI
-from typing import final
-from app.core.providers.base import LLMProvider
-from app.config import settings
-
-@final
-class DeepSeekProvider(LLMProvider):
- """Provider for the DeepSeek API."""
- def __init__(self, model_name: str, client: AsyncOpenAI): # Type hint with AsyncOpenAI
- self.model = model_name
- self._client = client
-
- async def generate_response(self, prompt: str) -> str:
- messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
- try:
- # This await is now correct for AsyncOpenAI
- chat_completion = await self._client.chat.completions.create(model=self.model, messages=messages)
- return chat_completion.choices[0].message.content
- except Exception as e:
- logging.error("DeepSeek Provider Error", exc_info=True)
- raise
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/llm/gemini.py b/ai-hub/app/core/providers/llm/gemini.py
deleted file mode 100644
index ecbba9c..0000000
--- a/ai-hub/app/core/providers/llm/gemini.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import httpx
-import logging
-import json
-from typing import final
-from app.core.providers.base import LLMProvider
-from app.config import settings
-
-@final
-class GeminiProvider(LLMProvider):
- """Provider for the Google Gemini API."""
- def __init__(self, api_url: str):
- self.url = api_url
-
- async def generate_response(self, prompt: str) -> str:
- payload = {"contents": [{"parts": [{"text": prompt}]}]}
- headers = {"Content-Type": "application/json"}
- try:
- async with httpx.AsyncClient() as client:
- response = await client.post(self.url, json=payload, headers=headers)
- response.raise_for_status()
- # Await the async `json` method
- data = response.json()
- return data['candidates'][0]['content']['parts'][0]['text']
- except Exception as e:
- logging.error("Gemini Provider Error", exc_info=True)
- raise
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/llm/general.py b/ai-hub/app/core/providers/llm/general.py
new file mode 100644
index 0000000..13b961c
--- /dev/null
+++ b/ai-hub/app/core/providers/llm/general.py
@@ -0,0 +1,44 @@
+
+import litellm
+from dspy.clients.base_lm import BaseLM
+
+class GeneralProvider(BaseLM):
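+    """A provider-agnostic DSPy language model that routes all requests through LiteLLM."""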
+ def __init__(self, model_name: str, api_key: str):
+ self.model_name = model_name
+ self.api_key = api_key
+ # Call the parent constructor
+ super().__init__(model=model_name)
+
+ def forward(self, prompt=None, messages=None, **kwargs):
+ """
+ Synchronous forward pass using LiteLLM.
+ """
+ messages = messages or [{"role": "user", "content": prompt}]
+ request = {
+ "model": self.model_name,
+ "messages": messages,
+ "api_key": self.api_key,
+ **self.kwargs,
+ **kwargs,
+ }
+ try:
+ return litellm.completion(**request)
+ except Exception as e:
+ raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}")
+
+ async def aforward(self, prompt=None, messages=None, **kwargs):
+ """
+ Asynchronous forward pass using LiteLLM.
+ """
+ messages = messages or [{"role": "user", "content": prompt}]
+ request = {
+ "model": self.model_name,
+ "messages": messages,
+ "api_key": self.api_key,
+ **self.kwargs,
+ **kwargs,
+ }
+ try:
+ return await litellm.acompletion(**request)
+ except Exception as e:
+ raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}")
\ No newline at end of file
diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py
index e3973da..4965630 100644
--- a/ai-hub/app/core/services/rag.py
+++ b/ai-hub/app/core/services/rag.py
@@ -9,7 +9,7 @@
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
-from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
+from app.core.pipelines.dspy_rag import DspyRagPipeline
class RAGService:
"""
@@ -21,10 +21,10 @@
self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)
# --- Session Management ---
- def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
+ def create_session(self, db: Session, user_id: str, provider_name: str) -> models.Session:
"""Creates a new chat session in the database."""
try:
- new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session")
+ new_session = models.Session(user_id=user_id, provider_name=provider_name, title=f"New Chat Session")
db.add(new_session)
db.commit()
db.refresh(new_session)
@@ -38,7 +38,7 @@
db: Session,
session_id: int,
prompt: str,
- model: str,
+ provider_name: str,
load_faiss_retriever: bool = False
) -> Tuple[str, str]:
"""
@@ -54,10 +54,9 @@
user_message = models.Message(session_id=session_id, sender="user", content=prompt)
db.add(user_message)
db.commit()
+ db.refresh(user_message)
- llm_provider = get_llm_provider(model)
- dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
- dspy.configure(lm=dspy_llm)
+ llm_provider = get_llm_provider(provider_name)
current_retrievers = []
if load_faiss_retriever:
@@ -68,17 +67,20 @@
rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)
- answer_text = await rag_pipeline.forward(
- question=prompt,
- history=session.messages,
- db=db
- )
+ # Use dspy.context to configure the language model for this specific async task
+ with dspy.context(lm=llm_provider):
+ answer_text = await rag_pipeline.forward(
+ question=prompt,
+ history=session.messages,
+ db=db
+ )
assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
db.add(assistant_message)
db.commit()
+ db.refresh(assistant_message)
- return answer_text, model
+ return answer_text, provider_name
def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
"""
diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py
index ec4f703..2d3476c 100644
--- a/ai-hub/app/db/models.py
+++ b/ai-hub/app/db/models.py
@@ -25,7 +25,7 @@
# A title for the conversation, which can be generated by the AI.
title = Column(String, index=True, nullable=True)
# The name of the LLM model used for this session (e.g., "Gemini", "DeepSeek").
- model_name = Column(String, nullable=True)
+ provider_name = Column(String, nullable=True)
# Timestamp for when the session was created.
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
# Flag to indicate if the session has been archived or soft-deleted.
diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py
index 435ce3d..2286daf 100644
--- a/ai-hub/integration_tests/test_sessions_api.py
+++ b/ai-hub/integration_tests/test_sessions_api.py
@@ -16,7 +16,7 @@
print("\n--- Running test_chat_in_session_lifecycle ---")
# 1. Create a new session with a trailing slash
- payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"}
+ payload = {"user_id": "integration_tester_lifecycle", "provider_name": "deepseek"}
response = await http_client.post("/sessions/", json=payload)
assert response.status_code == 200
session_id = response.json()["id"]
@@ -27,7 +27,7 @@
response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1)
assert response_1.status_code == 200
assert "Satya Nadella" in response_1.json()["answer"]
- assert response_1.json()["model_used"] == "deepseek"
+ assert response_1.json()["provider_used"] == "deepseek"
print("✅ Chat Turn 1 (context) test passed.")
# 3. Second chat turn (follow-up) to test conversational memory
@@ -35,7 +35,7 @@
response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2)
assert response_2.status_code == 200
assert "1967" in response_2.json()["answer"]
- assert response_2.json()["model_used"] == "deepseek"
+ assert response_2.json()["provider_used"] == "deepseek"
print("✅ Chat Turn 2 (follow-up) test passed.")
# 4. Cleanup (optional, but good practice if not using a test database that resets)
@@ -47,19 +47,19 @@
print("\n--- Running test_chat_with_model_switch ---")
# Send a message to the new session with a different model
- payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
+ payload_gemini = {"prompt": "What is the capital of France?", "provider_name": "gemini"}
response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
assert response_gemini.status_code == 200
assert "Paris" in response_gemini.json()["answer"]
- assert response_gemini.json()["model_used"] == "gemini"
+ assert response_gemini.json()["provider_used"] == "gemini"
print("✅ Chat (Model Switch to Gemini) test passed.")
# Switch back to the original model
- payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+ payload_deepseek = {"prompt": "What is the largest ocean?", "provider_name": "deepseek"}
response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
assert response_deepseek.status_code == 200
assert "Pacific Ocean" in response_deepseek.json()["answer"]
- assert response_deepseek.json()["model_used"] == "deepseek"
+ assert response_deepseek.json()["provider_used"] == "deepseek"
print("✅ Chat (Model Switch back to DeepSeek) test passed.")
@pytest.mark.asyncio
@@ -71,14 +71,12 @@
print("\n--- Running test_chat_with_document_retrieval ---")
# Create a new session for this RAG test
- # Corrected URL with a trailing slash
- session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"})
+ session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "provider_name": "deepseek"})
assert session_response.status_code == 200
rag_session_id = session_response.json()["id"]
# Add a new document with specific content for retrieval
doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
- # Corrected URL with a trailing slash
add_doc_response = await http_client.post("/documents/", json=doc_data)
assert add_doc_response.status_code == 200
try:
@@ -92,12 +90,22 @@
chat_payload = {
"prompt": RAG_PROMPT,
"document_id": rag_document_id,
- "model": "deepseek",
+ "provider_name": "deepseek",
"load_faiss_retriever": True
}
chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload)
+ # If a 500 error occurs, print the detailed response text.
+ if chat_response.status_code != 200:
+ print(f"❌ Test Failed! Received status code: {chat_response.status_code}")
+ print("--- Response Body (for debugging) ---")
+ print(chat_response.text)
+ print("---------------------------------------")
+
assert chat_response.status_code == 200
+
chat_data = chat_response.json()
assert "Jane Doe" in chat_data["answer"]
assert "Nexus" in chat_data["answer"]
@@ -106,4 +114,4 @@
# Clean up the document after the test
delete_response = await http_client.delete(f"/documents/{rag_document_id}")
assert delete_response.status_code == 200
- print(f"Document {rag_document_id} deleted successfully.")
+ print(f"Document {rag_document_id} deleted successfully.")
\ No newline at end of file
diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh
index 5646e5a..258c374 100644
--- a/ai-hub/run_chat.sh
+++ b/ai-hub/run_chat.sh
@@ -39,7 +39,7 @@
# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect
SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \
-H "Content-Type: application/json" \
- -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \
+ -d '{"user_id": "local_user", "provider_name": "'"$DEFAULT_MODEL"'"}' \
-w '\n%{http_code}') # Add a new line and the status code
# Extract body and status code
@@ -110,9 +110,9 @@
# Note the use of --argjson to pass the boolean value correctly
json_payload=$(jq -n \
--arg prompt "$PROMPT_TEXT" \
- --arg model "$MODEL_TO_USE" \
+ --arg provider_name "$MODEL_TO_USE" \
--argjson faiss_flag "$LOAD_FAISS_RETRIEVER" \
- '{"prompt": $prompt, "model": $model, "load_faiss_retriever": $faiss_flag}')
+ '{"prompt": $prompt, "provider_name": $provider_name, "load_faiss_retriever": $faiss_flag}')
echo "Payload: $json_payload" # Optional: for debugging
@@ -126,8 +126,8 @@
echo "Server response: $ai_response_json"
else
ai_answer=$(echo "$ai_response_json" | jq -r '.answer')
- model_used=$(echo "$ai_response_json" | jq -r '.model_used')
- echo "AI [$model_used] [faiss_enabled: $LOAD_FAISS_RETRIEVER]: $ai_answer"
+ provider_used=$(echo "$ai_response_json" | jq -r '.provider_used')
+ echo "AI [$provider_used] [faiss_enabled: $LOAD_FAISS_RETRIEVER]: $ai_answer"
fi
done
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index e0f94d5..5a2ecb8 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -21,6 +21,7 @@
export DB_MODE=sqlite
export LOCAL_DB_PATH="data/integration_test_ai_hub.db"
export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin"
+export LOG_LEVEL="DEBUG"
# --- User Interaction ---
echo "--- AI Hub Test Runner ---"
@@ -60,7 +61,7 @@
# Start the uvicorn server in the background
# We bind it to 127.0.0.1 to ensure it's not accessible from outside the local machine.
-uvicorn app.main:app --host 127.0.0.1 --port 8000 &
+uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload &
# Get the Process ID (PID) of the background server
SERVER_PID=$!
diff --git a/ai-hub/tests/api/routes/test_sessions.py b/ai-hub/tests/api/routes/test_sessions.py
index 10334d2..ad23a3a 100644
--- a/ai-hub/tests/api/routes/test_sessions.py
+++ b/ai-hub/tests/api/routes/test_sessions.py
@@ -8,12 +8,12 @@
mock_session = MagicMock(spec=models.Session)
mock_session.id = 1
mock_session.user_id = "test_user"
- mock_session.model_name = "gemini"
+ mock_session.provider_name = "gemini"
mock_session.title = "New Chat"
mock_session.created_at = datetime.now()
mock_services.rag_service.create_session.return_value = mock_session
- response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
+ response = test_client.post("/sessions", json={"user_id": "test_user", "provider_name": "gemini"})
assert response.status_code == 200
assert response.json()["id"] == 1
@@ -30,13 +30,13 @@
response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
assert response.status_code == 200
- assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
+ assert response.json() == {"answer": "Mocked response", "provider_used": "deepseek"}
mock_services.rag_service.chat_with_rag.assert_called_once_with(
db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
session_id=42,
prompt="Hello there",
- model="deepseek",
+ provider_name="deepseek",
load_faiss_retriever=False
)
@@ -47,16 +47,16 @@
test_client, mock_services = client
mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
- response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
+ response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "provider_name": "gemini"})
assert response.status_code == 200
- assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
+ assert response.json() == {"answer": "Mocked response from Gemini", "provider_used": "gemini"}
mock_services.rag_service.chat_with_rag.assert_called_once_with(
db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
session_id=42,
prompt="Hello there, Gemini!",
- model="gemini",
+ provider_name="gemini",
load_faiss_retriever=False
)
@@ -73,13 +73,13 @@
)
assert response.status_code == 200
- assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
+ assert response.json() == {"answer": "Response with context", "provider_used": "deepseek"}
mock_services.rag_service.chat_with_rag.assert_called_once_with(
db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
session_id=42,
prompt="What is RAG?",
- model="deepseek",
+ provider_name="deepseek",
load_faiss_retriever=True
)
diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py
index 4e44c4b..be4a13d 100644
--- a/ai-hub/tests/core/pipelines/test_dspy_rag.py
+++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py
@@ -1,97 +1,132 @@
import pytest
+from typing import List
import asyncio
from unittest.mock import MagicMock, AsyncMock
from sqlalchemy.orm import Session
-import dspy
+import dspy
# Import the pipeline and its new signature
from app.core.pipelines.dspy_rag import DspyRagPipeline, AnswerWithHistory
from app.db import models # Import your SQLAlchemy models for mocking history
from app.core.retrievers.base_retriever import Retriever
+
+# --- Mock Classes ---
+
+class MockRetriever(Retriever):
+ """A mock retriever that returns a predefined list of strings."""
+ def __init__(self, name: str, mock_data: List[str]):
+ self.name = name
+ self.mock_data = mock_data
+
+ def retrieve_context(self, question: str, db: Session) -> List[str]:
+ return self.mock_data
+
+# --- Fixtures ---
+
@pytest.fixture
-def mock_lm_configured():
- """Pytest fixture to mock and configure the dspy language model for a test."""
- mock_lm_instance = MagicMock()
- mock_lm_instance.aforward = AsyncMock(
- return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="Mocked LLM answer"))])
+def mock_db():
+ """A mock SQLAlchemy session."""
+ return MagicMock()
+
+@pytest.fixture(autouse=True)
+def mock_dspy_predict_instance(mocker):
+ """
+ Mocks the dspy.Predict class itself to return a mock instance
+ with a controllable aforward method.
+ """
+ # Create a mock instance of dspy.Predict
+ mock_predict_instance = MagicMock()
+ mock_predict_instance.aforward = AsyncMock(return_value=MagicMock(answer="Mocked LLM answer."))
+
+ # Patch the dspy.Predict class to return our mock instance
+ mocker.patch('dspy.Predict', return_value=mock_predict_instance)
+
+ return mock_predict_instance
+
+
+# --- Test Cases ---
+
+# def test_pipeline_initializes_with_defaults(mock_dspy_predict_instance):
+# """Test that the pipeline initializes correctly with default processors."""
+# pipeline = DspyRagPipeline(retrievers=[MockRetriever("test", ["test context"])])
+# assert pipeline.retrievers is not None
+# assert pipeline.context_postprocessor is pipeline._default_context_postprocessor
+# assert pipeline.history_formatter is pipeline._default_history_formatter
+# assert pipeline.response_postprocessor is None
+# # Verify that dspy.Predict was instantiated once
+# dspy.Predict.assert_called_once()
+# # Verify the aforward method was not called yet
+# mock_dspy_predict_instance.aforward.assert_not_called()
+
+@pytest.mark.asyncio
+async def test_forward_pass_with_defaults(mock_db, mock_dspy_predict_instance):
+ """Test a successful forward pass using default processors."""
+ retriever = MockRetriever("test_retriever", ["Context 1.", "Context 2."])
+ pipeline = DspyRagPipeline(retrievers=[retriever])
+
+ question = "What is the capital of France?"
+ history = [models.Message(sender="user", content="Hello there."), models.Message(sender="assistant", content="Hi.")]
+
+ response = await pipeline.forward(question, history, mock_db)
+
+ expected_context = "Context 1.\n\nContext 2."
+ expected_history = "Human: Hello there.\nAssistant: Hi."
+
+ mock_dspy_predict_instance.aforward.assert_called_once_with(
+ context=expected_context,
+ chat_history=expected_history,
+ question=question
)
- original_lm = dspy.settings.lm
- dspy.configure(lm=mock_lm_instance)
- yield mock_lm_instance
- dspy.configure(lm=original_lm)
+ assert response == "Mocked LLM answer."
-def test_pipeline_with_context_and_history(mock_lm_configured):
- """
- Tests the pipeline's prompt construction when it has both retrieved context
- and a conversation history.
- """
- # --- Arrange ---
- mock_retriever = MagicMock(spec=Retriever)
- mock_retriever.retrieve_context.return_value = ["Context chunk 1."]
- mock_db = MagicMock(spec=Session)
+@pytest.mark.asyncio
+async def test_forward_with_custom_processors(mock_db, mock_dspy_predict_instance):
+ """Test that custom processors are used correctly."""
- # Create a mock conversation history
- mock_history = [
- models.Message(sender="user", content="What is the capital of France?"),
- models.Message(sender="assistant", content="The capital of France is Paris.")
- ]
-
- pipeline = DspyRagPipeline(retrievers=[mock_retriever])
- question = "What is its population?"
+ def custom_context_processor(contexts: List[str]) -> str:
+ return "CUSTOM_CONTEXT: " + " | ".join(contexts)
- # --- Act ---
- response = asyncio.run(pipeline.forward(question=question, history=mock_history, db=mock_db))
+ def custom_history_formatter(history: List[models.Message]) -> str:
+ return "