diff --git a/ai-hub/add_knowledge.sh b/ai-hub/add_knowledge.sh
new file mode 100644
index 0000000..c7ab949
--- /dev/null
+++ b/ai-hub/add_knowledge.sh
@@ -0,0 +1,87 @@
+#!/bin/bash
+
+# ==============================================================================
+# Script to scan for specific file types and upload their content to an API.
+# The script starts its own FastAPI server on http://localhost:8000.
+# Prerequisites:
+# - The 'jq' command-line JSON processor must be installed.
+# - The 'curl' command-line tool must be installed.
+# ==============================================================================
+
+# Define the API endpoint
+API_URL="http://localhost:8000/documents"
+
+# --- 1. Check for Dependencies ---
+if ! command -v jq &> /dev/null; then
+    echo "❌ 'jq' is not installed. Please install it to run this script."
+    exit 1
+fi
+
+if ! command -v curl &> /dev/null; then
+    echo "❌ 'curl' is not installed. Please install it to run this script."
+    exit 1
+fi
+
+# --- 2. Start the FastAPI Server in the Background ---
+echo "--- Starting AI Hub Server ---"
+uvicorn app.main:app --host 127.0.0.1 --port 8000 &
+SERVER_PID=$!
+
+# Define a cleanup function to kill the server on exit
+cleanup() {
+    echo ""
+    echo "--- Shutting Down Server (PID: $SERVER_PID) ---"
+    kill "$SERVER_PID" 2> /dev/null
+}
+# Register the cleanup function to run when the script exits (e.g., on Ctrl+C or normal termination)
+trap cleanup EXIT
+
+echo "Server started with PID: $SERVER_PID. Waiting for it to initialize..."
+sleep 5 # Wait for the server to be ready
+
+# Find all files with the specified extensions in the current directory and its subdirectories.
+# The -print0 option is used to handle filenames with spaces or special characters.
+find . -type f \( -name "*.py" -o -name "*.txt" -o -name "*.md" -o -name "*.yaml" \) -print0 | while IFS= read -r -d $'\0' file_path; do
+
+    # Get the file's basename (e.g., "my_file.md") to use as the title
+    file_title=$(basename -- "$file_path")
+
+    # Get the file creation date (currently unused; kept for a future payload field).
+    # Note: 'stat' options differ between Linux and macOS.
+    #   Linux: stat -c %w
+    #   macOS: stat -f %B
+    # This script assumes a Linux-like stat output. You may need to adjust this.
+    file_created_at=$(stat -c %w "$file_path")
+
+    # Read the entire file content into a variable.
+    # No manual escaping is needed: jq's --arg handles JSON escaping below.
+    file_content=$(cat "$file_path")
+
+    # Construct the JSON payload using jq for robust formatting
+    json_payload=$(jq -n \
+        --arg title "$file_title" \
+        --arg text "$file_content" \
+        --arg file_path "$file_path" \
+        '{title: $title, text: $text, file_path: $file_path}')
+
+    echo "Uploading file: $file_path"
+
+    # Make the POST request to the API
+    response=$(curl -X POST \
+        -s \
+        -H "Content-Type: application/json" \
+        -d "$json_payload" \
+        "$API_URL")
+
+    # Display the response from the API
+    if [[ "$response" == *"success"* ]]; then
+        echo "✅ Uploaded successfully. Response: $response"
+    else
+        echo "❌ Failed to upload. Response: $response"
+    fi
+
+    echo "---"
+
+done
+
+echo "Script finished."
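For reference, the same upload loop as a minimal Python sketch, assuming only what the script above already shows: that POST /documents accepts a {title, text, file_path} JSON payload. httpx matches what the integration tests in this change use; this snippet is illustrative, not part of the diff.

# Minimal Python equivalent of the upload loop above (illustrative only).
import pathlib
import httpx

API_URL = "http://localhost:8000/documents"
EXTENSIONS = {".py", ".txt", ".md", ".yaml"}

for path in pathlib.Path(".").rglob("*"):
    if not path.is_file() or path.suffix not in EXTENSIONS:
        continue
    payload = {
        "title": path.name,
        "text": path.read_text(encoding="utf-8", errors="replace"),
        "file_path": str(path),
    }
    response = httpx.post(API_URL, json=payload, timeout=30.0)
    print(f"{path}: {response.status_code}")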
diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index dff0a91..85ecafd 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -34,6 +34,7 @@
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to create session: {e}")

+# --- Session Management Endpoints ---
 @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"])
 async def chat_in_session(
     session_id: int,
@@ -43,19 +44,21 @@
     """
     Sends a message within an existing session and gets a contextual response.
-    The 'model' can now be specified in the request body to switch models mid-conversation.
+    The 'model' and 'load_faiss_retriever' can now be specified in the request body.
     """
     try:
         response_text, model_used = await rag_service.chat_with_rag(
             db=db,
             session_id=session_id,
             prompt=request.prompt,
-            model=request.model  # Pass the model from the request to the RAG service
+            model=request.model,
+            load_faiss_retriever=request.load_faiss_retriever  # Pass the new parameter to the service
         )
         return schemas.ChatResponse(answer=response_text, model_used=model_used)
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")

+
 @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"])
 def get_session_messages(session_id: int, db: Session = Depends(get_db)):
     """
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index 95e0fb9..419aa44 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -6,9 +6,11 @@
 class ChatRequest(BaseModel):
     """Defines the shape of a request to the /chat endpoint."""
     prompt: str = Field(..., min_length=1)
-    # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed.
-    # For session-based chat, this field might be ignored.
+    # The 'model' can now be specified in the request body to switch models mid-conversation.
     model: Literal["deepseek", "gemini"] = Field("deepseek")
+    # Add a new optional boolean field to control the retriever
+    load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.")
+

 class ChatResponse(BaseModel):
     """Defines the shape of a successful response from the /chat endpoint."""
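Both new fields default sensibly, so existing clients keep working unchanged. The hunk does not show the import section, so Optional is assumed to already be imported from typing in schemas.py. A quick illustrative check of the defaults, not part of the diff:

# Illustrative check of ChatRequest defaults (assumes the schema above).
from app.api.schemas import ChatRequest

req = ChatRequest(prompt="Hello")
assert req.model == "deepseek"             # model still defaults to deepseek
assert req.load_faiss_retriever is False   # retrieval stays off unless requested

req = ChatRequest(prompt="Hello", load_faiss_retriever=True)
assert req.load_faiss_retriever is True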
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, - model=request.model # Pass the model from the request to the RAG service + model=request.model, + load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): """ diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 95e0fb9..419aa44 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -6,9 +6,11 @@ class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) - # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. - # For session-based chat, this field might be ignored. + # The 'model' can now be specified in the request body to switch models mid-conversation. model: Literal["deepseek", "gemini"] = Field("deepseek") + # Add a new optional boolean field to control the retriever + load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.") + class ChatResponse(BaseModel): """Defines the shape of a successful response from the /chat endpoint.""" diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 25b52bc..c3d05b3 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -6,7 +6,7 @@ from app.core.vector_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers import Retriever, FaissDBRetriever # Assuming FaissDBRetriever is available from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline @@ -17,6 +17,10 @@ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # Assume one of the retrievers is the FAISS retriever, and you can access it. + # A better approach might be to have a dictionary of named retrievers. + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- @@ -32,12 +36,18 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False # Add the new parameter with a default value + ) -> Tuple[str, str]: """ Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model for the current chat turn. + Allows switching the LLM model and conditionally using the FAISS retriever. """ - # Eagerly load the message history in a single query for efficiency. 
session = db.query(models.Session).options( joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() @@ -45,34 +55,39 @@ if not session: raise ValueError(f"Session with ID {session_id} not found.") - # Save the new user message to the database user_message = models.Message(session_id=session_id, sender="user", content=prompt) db.add(user_message) db.commit() - # Use the 'model' parameter passed to this method for the current chat turn llm_provider = get_llm_provider(model) dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) dspy.configure(lm=dspy_llm) - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # Conditionally choose the retriever list based on the new parameter + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + # Handle the case where the FaissDBRetriever isn't initialized + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - # Pass the full message history to the pipeline's forward method. - # Note: The history is passed, but the current RAGPipeline implementation - # might not fully utilize it for conversational context unless explicitly - # designed to. This is a placeholder for future conversational RAG. + # If no specific retriever is requested or available, fall back to a default or empty list + # This part of the logic may need to be adjusted based on your system's design. + # For this example, we proceed with an empty list if no retriever is selected. + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + answer_text = await rag_pipeline.forward( question=prompt, - history=session.messages, # Pass the existing history + history=session.messages, db=db ) - # Save the assistant's response assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) db.add(assistant_message) db.commit() - # Return the answer text and the model that was actually used for this turn return answer_text, model def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: @@ -83,7 +98,6 @@ joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() - # Return messages sorted by created_at to ensure chronological order return sorted(session.messages, key=lambda msg: msg.created_at) if session else None # --- Document Management (Unchanged) --- diff --git a/ai-hub/add_knowledge.sh b/ai-hub/add_knowledge.sh new file mode 100644 index 0000000..c7ab949 --- /dev/null +++ b/ai-hub/add_knowledge.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# ============================================================================== +# Script to scan for specific file types and upload their content to an API. +# Prerequisites: +# - The FastAPI application must be running locally on http://localhost:8000. +# - The 'jq' command-line JSON processor must be installed. +# - The 'curl' command-line tool must be installed. +# ============================================================================== + +# Define the API endpoint +API_URL="http://localhost:8000/documents" + +DEFAULT_MODEL="gemini" +CURRENT_MODEL="" # The model used in the last turn + +# --- 1. Check for Dependencies --- +if ! command -v jq &> /dev/null +then + echo "❌ 'jq' is not installed. Please install it to run this script." + exit 1 +fi + +# --- 2. 
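The constructor comment above notes that a dictionary of named retrievers would be a better approach than probing with isinstance(). A hypothetical sketch of that design; class and method names are illustrative, not part of this change:

# Hypothetical registry-of-named-retrievers design suggested by the comment above.
from typing import Dict, List

class RetrieverRegistry:
    def __init__(self, retrievers: List[object]):
        # Key each retriever by its class name instead of isinstance() probing.
        self._by_name: Dict[str, object] = {type(r).__name__: r for r in retrievers}

    def select(self, names: List[str]) -> List[object]:
        # Unknown names are skipped, mirroring chat_with_rag's
        # "proceed without it" fallback when FAISS is unavailable.
        return [self._by_name[n] for n in names if n in self._by_name]

# Usage inside RAGService.__init__ might then look like:
#   self.registry = RetrieverRegistry(retrievers)
#   current_retrievers = self.registry.select(["FaissDBRetriever"]) if load_faiss_retriever else []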
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, - model=request.model # Pass the model from the request to the RAG service + model=request.model, + load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): """ diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 95e0fb9..419aa44 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -6,9 +6,11 @@ class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) - # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. - # For session-based chat, this field might be ignored. + # The 'model' can now be specified in the request body to switch models mid-conversation. model: Literal["deepseek", "gemini"] = Field("deepseek") + # Add a new optional boolean field to control the retriever + load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.") + class ChatResponse(BaseModel): """Defines the shape of a successful response from the /chat endpoint.""" diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 25b52bc..c3d05b3 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -6,7 +6,7 @@ from app.core.vector_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers import Retriever, FaissDBRetriever # Assuming FaissDBRetriever is available from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline @@ -17,6 +17,10 @@ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # Assume one of the retrievers is the FAISS retriever, and you can access it. + # A better approach might be to have a dictionary of named retrievers. + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- @@ -32,12 +36,18 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False # Add the new parameter with a default value + ) -> Tuple[str, str]: """ Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model for the current chat turn. + Allows switching the LLM model and conditionally using the FAISS retriever. """ - # Eagerly load the message history in a single query for efficiency. 
diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py
index 6d692b3..541881f 100644
--- a/ai-hub/integration_tests/test_integration.py
+++ b/ai-hub/integration_tests/test_integration.py
@@ -9,6 +9,11 @@
 CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
 FOLLOW_UP_PROMPT = "When was he born?"

+# Document and prompt for the retrieval-augmented generation test
+RAG_DOC_TITLE = "Fictional Company History"
+RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'."
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?"
+
 # Global variables to pass state between sequential tests
 created_document_id = None
 created_session_id = None
@@ -121,32 +126,51 @@
     assert response_data["model_used"] == "deepseek"
     print("✅ Chat (Model Switch back to DeepSeek) test passed.")

+async def test_chat_with_document_retrieval():
+    """
+    Tests injecting a document and using it for retrieval-augmented generation.
+    This exercises the 'load_faiss_retriever' functionality.
+    """
+    print("\n--- Running test_chat_with_document_retrieval ---")
+    async with httpx.AsyncClient(timeout=60.0) as client:
+        # Create a new session for this test
+        session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"})
+        assert session_response.status_code == 200
+        rag_session_id = session_response.json()["id"]

-async def test_get_session_history():
-    """Tests retrieving the full message history for the session."""
-    print("\n--- Running test_get_session_history ---")
-    assert created_session_id is not None, "Session ID was not set."
-
-    url = f"{BASE_URL}/sessions/{created_session_id}/messages"
-    async with httpx.AsyncClient() as client:
-        response = await client.get(url)
+        # Add a new document with specific content for retrieval
+        doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
+        add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data)
+        assert add_doc_response.status_code == 200
+        try:
+            message = add_doc_response.json().get("message", "")
+            rag_document_id = int(message.split(" with ID ")[-1])
+            print(f"Document for RAG created with ID: {rag_document_id}")
+        except (ValueError, IndexError):
+            pytest.fail("Could not parse document ID from response message.")

-    assert response.status_code == 200
-    response_data = response.json()
-
-    assert response_data["session_id"] == created_session_id
-    # After 4 turns, there should be 8 messages (4 user, 4 assistant)
-    assert len(response_data["messages"]) >= 8
-    assert response_data["messages"][0]["content"] == CONTEXT_PROMPT
-    assert response_data["messages"][2]["content"] == FOLLOW_UP_PROMPT
-    # Verify content and sender for the switched models
-    assert response_data["messages"][4]["content"] == "What is the capital of France?"
-    assert response_data["messages"][5]["sender"] == "assistant"
-    assert "Paris" in response_data["messages"][5]["content"]
-    assert response_data["messages"][6]["content"] == "What is the largest ocean?"
-    assert response_data["messages"][7]["sender"] == "assistant"
-    assert "Pacific Ocean" in response_data["messages"][7]["content"]
-    print("✅ Get session history test passed.")
+        try:
+            # Send a chat request with the FAISS retriever enabled
+            chat_payload = {
+                "prompt": RAG_PROMPT,
+                "document_id": rag_document_id,
+                "model": "deepseek",  # or any other RAG-enabled model
+                "load_faiss_retriever": True
+            }
+            chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload)
+
+            assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}"
+            chat_data = chat_response.json()
+
+            # Verify the response contains information from the document
+            assert "Jane Doe" in chat_data["answer"]
+            assert "Nexus" in chat_data["answer"]
+            print("✅ Chat with document retrieval test passed.")
+        finally:
+            # Clean up the document after the test
+            delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}")
+            assert delete_response.status_code == 200
+            print(f"Document {rag_document_id} deleted successfully.")

 # --- Document Management Lifecycle Tests ---
 async def test_add_document_for_lifecycle():
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, - model=request.model # Pass the model from the request to the RAG service + model=request.model, + load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): """ diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 95e0fb9..419aa44 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -6,9 +6,11 @@ class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) - # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. - # For session-based chat, this field might be ignored. + # The 'model' can now be specified in the request body to switch models mid-conversation. model: Literal["deepseek", "gemini"] = Field("deepseek") + # Add a new optional boolean field to control the retriever + load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.") + class ChatResponse(BaseModel): """Defines the shape of a successful response from the /chat endpoint.""" diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 25b52bc..c3d05b3 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -6,7 +6,7 @@ from app.core.vector_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers import Retriever, FaissDBRetriever # Assuming FaissDBRetriever is available from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline @@ -17,6 +17,10 @@ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # Assume one of the retrievers is the FAISS retriever, and you can access it. + # A better approach might be to have a dictionary of named retrievers. + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- @@ -32,12 +36,18 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False # Add the new parameter with a default value + ) -> Tuple[str, str]: """ Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model for the current chat turn. + Allows switching the LLM model and conditionally using the FAISS retriever. """ - # Eagerly load the message history in a single query for efficiency. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + async with httpx.AsyncClient(timeout=60.0) as client: + # Create a new session for this test + session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] -async def test_get_session_history(): - """Tests retrieving the full message history for the session.""" - print("\n--- Running test_get_session_history ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/messages" - async with httpx.AsyncClient() as client: - response = await client.get(url) + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") - assert response.status_code == 200 - response_data = response.json() - - assert response_data["session_id"] == created_session_id - # After 4 turns, there should be 8 messages (4 user, 4 assistant) - assert len(response_data["messages"]) >= 8 - assert response_data["messages"][0]["content"] == CONTEXT_PROMPT - assert response_data["messages"][2]["content"] == FOLLOW_UP_PROMPT - # Verify content and sender for the switched models - assert response_data["messages"][4]["content"] == "What is the capital of France?" - assert response_data["messages"][5]["sender"] == "assistant" - assert "Paris" in response_data["messages"][5]["content"] - assert response_data["messages"][6]["content"] == "What is the largest ocean?" - assert response_data["messages"][7]["sender"] == "assistant" - assert "Pacific Ocean" in response_data["messages"][7]["content"] - print("✅ Get session history test passed.") + try: + # Send a chat request with the document ID to enable retrieval + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", # or any other RAG-enabled model + "load_faiss_retriever": True + } + chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" + chat_data = chat_response.json() + + # Verify the response contains information from the document + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") # --- Document Management Lifecycle Tests --- async def test_add_document_for_lifecycle(): diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 3e3b252..c58a6c7 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -52,8 +52,8 @@ def test_chat_in_session_success(client): """ - Tests sending a message in an existing session without specifying a model. - It should default to 'deepseek'. 
+ Tests sending a message in an existing session without specifying a model + or retriever. It should default to 'deepseek' and 'False'. """ test_client, mock_rag_service = client mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) @@ -63,11 +63,13 @@ assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} # Verify that chat_with_rag was called with the default model 'deepseek' + # and the default load_faiss_retriever=False mock_rag_service.chat_with_rag.assert_called_once_with( db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", - model="deepseek" + model="deepseek", + load_faiss_retriever=False ) def test_chat_in_session_with_model_switch(client): @@ -86,7 +88,31 @@ db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", - model="gemini" + model="gemini", + load_faiss_retriever=False # It should still default to False + ) + +def test_chat_in_session_with_faiss_retriever(client): + """ + Tests sending a message and explicitly enabling the FAISS retriever. + """ + test_client, mock_rag_service = client + mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + + response = test_client.post( + "/sessions/42/chat", + json={"prompt": "What is RAG?", "load_faiss_retriever": True} + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters + mock_rag_service.chat_with_rag.assert_called_once_with( + db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="What is RAG?", + model="deepseek", # The model still defaults to deepseek + load_faiss_retriever=True # Verify that the retriever was explicitly enabled ) def test_get_session_messages_success(client): @@ -156,4 +182,4 @@ test_client, mock_rag_service = client mock_rag_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 + assert response.status_code == 404 \ No newline at end of file diff --git a/ai-hub/add_knowledge.sh b/ai-hub/add_knowledge.sh new file mode 100644 index 0000000..c7ab949 --- /dev/null +++ b/ai-hub/add_knowledge.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# ============================================================================== +# Script to scan for specific file types and upload their content to an API. +# Prerequisites: +# - The FastAPI application must be running locally on http://localhost:8000. +# - The 'jq' command-line JSON processor must be installed. +# - The 'curl' command-line tool must be installed. +# ============================================================================== + +# Define the API endpoint +API_URL="http://localhost:8000/documents" + +DEFAULT_MODEL="gemini" +CURRENT_MODEL="" # The model used in the last turn + +# --- 1. Check for Dependencies --- +if ! command -v jq &> /dev/null +then + echo "❌ 'jq' is not installed. Please install it to run this script." + exit 1 +fi + +# --- 2. Start the FastAPI Server in the Background --- +echo "--- Starting AI Hub Server ---" +uvicorn app.main:app --host 127.0.0.1 --port 8000 & +SERVER_PID=$! 
""" try: response_text, model_used = await rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, - model=request.model # Pass the model from the request to the RAG service + model=request.model, + load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): """ diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 95e0fb9..419aa44 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -6,9 +6,11 @@ class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) - # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. - # For session-based chat, this field might be ignored. + # The 'model' can now be specified in the request body to switch models mid-conversation. model: Literal["deepseek", "gemini"] = Field("deepseek") + # Add a new optional boolean field to control the retriever + load_faiss_retriever: Optional[bool] = Field(False, description="Whether to use the FAISS DB retriever for the chat.") + class ChatResponse(BaseModel): """Defines the shape of a successful response from the /chat endpoint.""" diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index 25b52bc..c3d05b3 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -6,7 +6,7 @@ from app.core.vector_store import FaissVectorStore from app.db import models -from app.core.retrievers import Retriever +from app.core.retrievers import Retriever, FaissDBRetriever # Assuming FaissDBRetriever is available from app.core.llm_providers import get_llm_provider from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline @@ -17,6 +17,10 @@ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # Assume one of the retrievers is the FAISS retriever, and you can access it. + # A better approach might be to have a dictionary of named retrievers. + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- @@ -32,12 +36,18 @@ db.rollback() raise - async def chat_with_rag(self, db: Session, session_id: int, prompt: str, model: str) -> Tuple[str, str]: + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False # Add the new parameter with a default value + ) -> Tuple[str, str]: """ Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model for the current chat turn. + Allows switching the LLM model and conditionally using the FAISS retriever. """ - # Eagerly load the message history in a single query for efficiency. 
session = db.query(models.Session).options( joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() @@ -45,34 +55,39 @@ if not session: raise ValueError(f"Session with ID {session_id} not found.") - # Save the new user message to the database user_message = models.Message(session_id=session_id, sender="user", content=prompt) db.add(user_message) db.commit() - # Use the 'model' parameter passed to this method for the current chat turn llm_provider = get_llm_provider(model) dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) dspy.configure(lm=dspy_llm) - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # Conditionally choose the retriever list based on the new parameter + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + # Handle the case where the FaissDBRetriever isn't initialized + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - # Pass the full message history to the pipeline's forward method. - # Note: The history is passed, but the current RAGPipeline implementation - # might not fully utilize it for conversational context unless explicitly - # designed to. This is a placeholder for future conversational RAG. + # If no specific retriever is requested or available, fall back to a default or empty list + # This part of the logic may need to be adjusted based on your system's design. + # For this example, we proceed with an empty list if no retriever is selected. + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + answer_text = await rag_pipeline.forward( question=prompt, - history=session.messages, # Pass the existing history + history=session.messages, db=db ) - # Save the assistant's response assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) db.add(assistant_message) db.commit() - # Return the answer text and the model that was actually used for this turn return answer_text, model def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: @@ -83,7 +98,6 @@ joinedload(models.Session.messages) ).filter(models.Session.id == session_id).first() - # Return messages sorted by created_at to ensure chronological order return sorted(session.messages, key=lambda msg: msg.created_at) if session else None # --- Document Management (Unchanged) --- diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 6d692b3..541881f 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -9,6 +9,11 @@ CONTEXT_PROMPT = "Who is the CEO of Microsoft?" FOLLOW_UP_PROMPT = "When was he born?" +# Document and prompt for the retrieval-augmented generation test +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." +RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + # Global variables to pass state between sequential tests created_document_id = None created_session_id = None @@ -121,32 +126,51 @@ assert response_data["model_used"] == "deepseek" print("✅ Chat (Model Switch back to DeepSeek) test passed.") +async def test_chat_with_document_retrieval(): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This simulates the 'load_faiss_retriever' functionality. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + async with httpx.AsyncClient(timeout=60.0) as client: + # Create a new session for this test + session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] -async def test_get_session_history(): - """Tests retrieving the full message history for the session.""" - print("\n--- Running test_get_session_history ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/messages" - async with httpx.AsyncClient() as client: - response = await client.get(url) + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") - assert response.status_code == 200 - response_data = response.json() - - assert response_data["session_id"] == created_session_id - # After 4 turns, there should be 8 messages (4 user, 4 assistant) - assert len(response_data["messages"]) >= 8 - assert response_data["messages"][0]["content"] == CONTEXT_PROMPT - assert response_data["messages"][2]["content"] == FOLLOW_UP_PROMPT - # Verify content and sender for the switched models - assert response_data["messages"][4]["content"] == "What is the capital of France?" - assert response_data["messages"][5]["sender"] == "assistant" - assert "Paris" in response_data["messages"][5]["content"] - assert response_data["messages"][6]["content"] == "What is the largest ocean?" - assert response_data["messages"][7]["sender"] == "assistant" - assert "Pacific Ocean" in response_data["messages"][7]["content"] - print("✅ Get session history test passed.") + try: + # Send a chat request with the document ID to enable retrieval + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", # or any other RAG-enabled model + "load_faiss_retriever": True + } + chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" + chat_data = chat_response.json() + + # Verify the response contains information from the document + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") # --- Document Management Lifecycle Tests --- async def test_add_document_for_lifecycle(): diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 3e3b252..c58a6c7 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -52,8 +52,8 @@ def test_chat_in_session_success(client): """ - Tests sending a message in an existing session without specifying a model. - It should default to 'deepseek'. 
+ Tests sending a message in an existing session without specifying a model + or retriever. It should default to 'deepseek' and 'False'. """ test_client, mock_rag_service = client mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) @@ -63,11 +63,13 @@ assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} # Verify that chat_with_rag was called with the default model 'deepseek' + # and the default load_faiss_retriever=False mock_rag_service.chat_with_rag.assert_called_once_with( db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", - model="deepseek" + model="deepseek", + load_faiss_retriever=False ) def test_chat_in_session_with_model_switch(client): @@ -86,7 +88,31 @@ db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", - model="gemini" + model="gemini", + load_faiss_retriever=False # It should still default to False + ) + +def test_chat_in_session_with_faiss_retriever(client): + """ + Tests sending a message and explicitly enabling the FAISS retriever. + """ + test_client, mock_rag_service = client + mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + + response = test_client.post( + "/sessions/42/chat", + json={"prompt": "What is RAG?", "load_faiss_retriever": True} + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters + mock_rag_service.chat_with_rag.assert_called_once_with( + db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="What is RAG?", + model="deepseek", # The model still defaults to deepseek + load_faiss_retriever=True # Verify that the retriever was explicitly enabled ) def test_get_session_messages_success(client): @@ -156,4 +182,4 @@ test_client, mock_rag_service = client mock_rag_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 + assert response.status_code == 404 \ No newline at end of file diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py index 777a466..7fbd50f 100644 --- a/ai-hub/tests/core/test_services.py +++ b/ai-hub/tests/core/test_services.py @@ -2,7 +2,7 @@ import asyncio from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session -from sqlalchemy.exc import SQLAlchemyError # Import the specific error type +from sqlalchemy.exc import SQLAlchemyError from typing import List from datetime import datetime import dspy @@ -11,7 +11,8 @@ from app.core.services import RAGService from app.db import models from app.core.vector_store import FaissVectorStore -from app.core.retrievers import Retriever +# Import FaissDBRetriever and a mock WebRetriever for testing different cases +from app.core.retrievers import FaissDBRetriever, Retriever from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider from app.core.llm_providers import LLMProvider @@ -20,13 +21,15 @@ def rag_service(): """ Pytest fixture to create a RAGService instance with mocked dependencies. - Correctly instantiates RAGService with only the required arguments. + It includes a mock FaissDBRetriever and a mock generic Retriever to test + conditional loading. 
""" mock_vector_store = MagicMock(spec=FaissVectorStore) - mock_retriever = MagicMock(spec=Retriever) + mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) + mock_web_retriever = MagicMock(spec=Retriever) return RAGService( vector_store=mock_vector_store, - retrievers=[mock_retriever] + retrievers=[mock_web_retriever, mock_faiss_retriever] ) # --- Session Management Tests --- @@ -48,11 +51,11 @@ @patch('dspy.configure') def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): """ - Tests the full orchestration of a chat message within a session using the default model. + Tests the full orchestration of a chat message within a session using the default model + and with the retriever loading parameter explicitly set to False. """ # --- Arrange --- mock_db = MagicMock(spec=Session) - # The mock session now needs a 'messages' attribute for the history mock_session = models.Session(id=42, model_name="deepseek", messages=[]) mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session @@ -63,15 +66,24 @@ mock_dspy_pipeline.return_value = mock_pipeline_instance # --- Act --- - # Pass the 'model' argument, defaulting to "deepseek" for this test case - answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt", model="deepseek")) + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=42, + prompt="Test prompt", + model="deepseek", + load_faiss_retriever=False # Explicitly pass the default value + ) + ) # --- Assert --- mock_db.query.assert_called_once_with(models.Session) assert mock_db.add.call_count == 2 mock_get_llm_provider.assert_called_once_with("deepseek") - # Assert that the pipeline was called with the history argument + # Assert that DspyRagPipeline was initialized with an empty list of retrievers + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + mock_pipeline_instance.forward.assert_called_once_with( question="Test prompt", history=mock_session.messages, @@ -81,16 +93,18 @@ assert answer == "Final RAG response" assert model_name == "deepseek" + @patch('app.core.services.get_llm_provider') @patch('app.core.services.DspyRagPipeline') @patch('dspy.configure') def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): """ - Tests that chat_with_rag correctly switches the model based on the 'model' argument. + Tests that chat_with_rag correctly switches the model based on the 'model' argument, + while still using the default retriever setting. 
""" # --- Arrange --- mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=43, model_name="deepseek", messages=[]) # Session might start with deepseek + mock_session = models.Session(id=43, model_name="deepseek", messages=[]) mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session mock_llm_provider = MagicMock(spec=LLMProvider) @@ -100,15 +114,24 @@ mock_dspy_pipeline.return_value = mock_pipeline_instance # --- Act --- - # Explicitly request the "gemini" model for this chat turn - answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=43, prompt="Test prompt for Gemini", model="gemini")) + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=43, + prompt="Test prompt for Gemini", + model="gemini", + load_faiss_retriever=False # Explicitly pass the default value + ) + ) # --- Assert --- mock_db.query.assert_called_once_with(models.Session) assert mock_db.add.call_count == 2 - # Verify that get_llm_provider was called with "gemini" mock_get_llm_provider.assert_called_once_with("gemini") + # Assert that DspyRagPipeline was initialized with an empty list of retrievers + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + mock_pipeline_instance.forward.assert_called_once_with( question="Test prompt for Gemini", history=mock_session.messages, @@ -118,6 +141,53 @@ assert answer == "Final RAG response from Gemini" assert model_name == "gemini" + +@patch('app.core.services.get_llm_provider') +@patch('app.core.services.DspyRagPipeline') +@patch('dspy.configure') +def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Tests that the chat_with_rag method correctly initializes the DspyRagPipeline + with the FaissDBRetriever when `load_faiss_retriever` is True. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=44, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + # Explicitly enable the FAISS retriever + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=44, + prompt="Test prompt with FAISS", + model="deepseek", + load_faiss_retriever=True + ) + ) + + # --- Assert --- + # The crucial part is to verify that the pipeline was called with the correct retriever + expected_retrievers = [rag_service.faiss_retriever] + mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt with FAISS", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Response with FAISS context" + assert model_name == "deepseek" + + def test_get_message_history_success(rag_service: RAGService): """Tests successfully retrieving message history for an existing session.""" # Arrange @@ -241,86 +311,4 @@ mock_db.commit.assert_not_called() mock_db.rollback.assert_called_once() -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_rag_service_chat_with_rag_with_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Test the RAGService.chat_with_rag method when context is retrieved. - Verifies that the RAG prompt is correctly constructed. - """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=1, model_name="deepseek", messages=[ - models.Message(sender="user", content="Previous user message", created_at=datetime(2023, 1, 1, 9, 0, 0)), - models.Message(sender="assistant", content="Previous assistant response", created_at=datetime(2023, 1, 1, 9, 1, 0)) - ]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="LLM response with context") - mock_dspy_pipeline.return_value = mock_pipeline_instance - - prompt = "Test prompt." - expected_context = "Context text 1.\n\nContext text 2." 
- mock_retriever = rag_service.retrievers[0] - mock_retriever.retrieve_context = AsyncMock(return_value=["Context text 1.", "Context text 2."]) - - # --- Act --- - response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")) - - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("deepseek") - - mock_pipeline_instance.forward.assert_called_once_with( - question=prompt, - history=mock_session.messages, - db=mock_db - ) - - assert response_text == "LLM response with context" - assert model_used == "deepseek" - -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_rag_service_chat_with_rag_without_context(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Test the RAGService.chat_with_rag method when no context is retrieved. - Verifies that the original prompt is sent to the LLM. - """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=1, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="LLM response without context") - mock_dspy_pipeline.return_value = mock_pipeline_instance - - prompt = "Test prompt without context." - mock_retriever = rag_service.retrievers[0] - mock_retriever.retrieve_context = AsyncMock(return_value=[]) - - # --- Act --- - response_text, model_used = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=1, prompt=prompt, model="deepseek")) - - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("deepseek") - - mock_pipeline_instance.forward.assert_called_once_with( - question=prompt, - history=mock_session.messages, - db=mock_db - ) - - assert response_text == "LLM response without context" - assert model_used == "deepseek" \ No newline at end of file
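The service above locates the FAISS retriever by scanning the retriever list with isinstance(), and its own comment notes that a dictionary of named retrievers would be a better approach. Below is a minimal sketch of that alternative; the Retriever protocol is a stand-in for app.core.retrievers.Retriever, and the names RetrieverRegistry and select are hypothetical, not part of the codebase.

from typing import Dict, List, Protocol

class Retriever(Protocol):
    """Stand-in for app.core.retrievers.Retriever (assumed interface)."""
    async def retrieve_context(self, query: str) -> List[str]: ...

class RetrieverRegistry:
    """Keyed lookup instead of an isinstance() scan over a flat list."""

    def __init__(self, retrievers: Dict[str, Retriever]):
        # e.g. {"faiss": FaissDBRetriever(...), "web": WebRetriever(...)}
        self._retrievers = retrievers

    def select(self, load_faiss_retriever: bool) -> List[Retriever]:
        # Mirrors the conditional in chat_with_rag: include the FAISS
        # retriever only when explicitly requested, and degrade gracefully
        # (empty list) when it was never registered.
        if load_faiss_retriever and "faiss" in self._retrievers:
            return [self._retrievers["faiss"]]
        return []

With a registry like this, chat_with_rag could call self.registry.select(load_faiss_retriever) instead of building current_retrievers inline.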
diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index aeededc..645c541 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -89,9 +89,9 @@ assert response.status_code == 200 assert response.json()["answer"] == "This is a mock response."
assert response.json()["model_used"] == "deepseek" - # The fix: Include the default 'model' parameter in the assertion + # FIX: Add the missing 'load_faiss_retriever=False' argument to the assertion mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, session_id=123, prompt="Hello there", model="deepseek" + db=mock_db, session_id=123, prompt="Hello there", model="deepseek", load_faiss_retriever=False ) @patch('app.app.RAGService') @@ -112,11 +112,13 @@ assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} # Verify that chat_with_rag was called with the specified model 'gemini' + # FIX: Add the missing 'load_faiss_retriever=False' argument to the assertion mock_rag_service_instance.chat_with_rag.assert_called_once_with( db=mock_db, session_id=42, prompt="Hello there, Gemini!", - model="gemini" + model="gemini", + load_faiss_retriever=False ) @patch('app.app.RAGService') @@ -283,4 +285,3 @@ assert response.status_code == 404 assert response.json()["detail"] == "Document with ID 999 not found." mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=999) -
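For reference, a minimal client-side call that exercises the new field, written with httpx in the same style as the integration tests above. It assumes a server running locally on port 8000 (as add_knowledge.sh expects); the endpoint shapes are the ones shown in the diffs, but the user_id and prompt values are placeholders.

import asyncio
import httpx

BASE_URL = "http://localhost:8000"  # assumes a locally running AI Hub server

async def main() -> None:
    async with httpx.AsyncClient(timeout=60.0) as client:
        # Create a session, then send one message with the FAISS retriever enabled.
        session = await client.post(
            f"{BASE_URL}/sessions",
            json={"user_id": "demo", "model": "deepseek"},
        )
        session_id = session.json()["id"]
        chat = await client.post(
            f"{BASE_URL}/sessions/{session_id}/chat",
            json={
                "prompt": "What is RAG?",
                "model": "deepseek",
                "load_faiss_retriever": True,  # the new request field
            },
        )
        print(chat.json())  # e.g. {"answer": "...", "model_used": "deepseek"}

asyncio.run(main())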