diff --git a/.gitignore b/.gitignore
index c02da13..576208c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 __pycache__/
 .env
 **/.env
-**/*.egg-info
\ No newline at end of file
+**/*.egg-info
+faiss_index.bin
\ No newline at end of file
diff --git a/ai-hub/app/api_endpoints.py b/ai-hub/app/api_endpoints.py
new file mode 100644
index 0000000..fc4f1db
--- /dev/null
+++ b/ai-hub/app/api_endpoints.py
@@ -0,0 +1,66 @@
+from fastapi import APIRouter, HTTPException, Query, Depends
+from pydantic import BaseModel
+from typing import Literal
+from sqlalchemy.orm import Session
+from app.core.rag_service import RAGService
+from app.db_setup import get_db
+
+# Pydantic Models for API requests
+class ChatRequest(BaseModel):
+    prompt: str
+
+class DocumentCreate(BaseModel):
+    title: str
+    text: str
+    source_url: str = None
+    author: str = None
+    user_id: str = "default_user"
+
+def create_api_router(rag_service: RAGService) -> APIRouter:
+    """
+    Creates and returns an APIRouter with all the application's endpoints.
+
+    This function takes the RAGService instance as an argument, so it can be
+    injected from the main application factory.
+    """
+    router = APIRouter()
+
+    @router.get("/")
+    def read_root():
+        return {"status": "AI Model Hub is running!"}
+
+    @router.post("/chat")
+    async def chat_handler(
+        request: ChatRequest,
+        model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use."),
+        db: Session = Depends(get_db)
+    ):
+        try:
+            response_text = await rag_service.chat_with_rag(
+                db=db,
+                prompt=request.prompt,
+                model=model
+            )
+            return {"response": response_text, "model_used": model}
+        except ValueError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}")
+
+    @router.post("/document")
+    async def add_document(
+        doc: DocumentCreate,
+        db: Session = Depends(get_db)
+    ):
+        """
+        Adds a new document to the database and its vector embedding to the FAISS index.
+        """
+        try:
+            doc_data = doc.model_dump()
+            document_id = rag_service.add_document(db=db, doc_data=doc_data)
+
+            return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
+
+    return router
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
new file mode 100644
index 0000000..bdb93e0
--- /dev/null
+++ b/ai-hub/app/app.py
@@ -0,0 +1,55 @@
+import os
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from dotenv import load_dotenv
+from typing import List
+
+# Import core application logic
+from app.core.vector_store import FaissVectorStore
+from app.core.retrievers import FaissDBRetriever, Retriever
+from app.core.rag_service import RAGService
+
+# Import the new files for database and API routes
+from app.db_setup import create_db_tables
+from app.api_endpoints import create_api_router
+
+# Load environment variables from a .env file
+load_dotenv()
+
+# --- Application Factory Function ---
+def create_app() -> FastAPI:
+    """
+    Factory function to create and configure the FastAPI application.
+    This encapsulates all setup logic, making the main entry point clean.
+    """
+    # Initialize core services for RAG
+    vector_store = FaissVectorStore()
+    retrievers: List[Retriever] = [
+        FaissDBRetriever(vector_store=vector_store),
+    ]
+    rag_service = RAGService(vector_store=vector_store, retrievers=retrievers)
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        """
+        Initializes the database and vector store on startup and handles
+        cleanup on shutdown.
+        """
+        print("Initializing application services...")
+        create_db_tables()
+        yield
+        print("Shutting down application services...")
+        vector_store.save_index()
+
+    app = FastAPI(
+        title="AI Model Hub Service",
+        description="An extensible hub to route requests to various LLMs with RAG capabilities.",
+        version="0.0.0",
+        lifespan=lifespan
+    )
+
+    # Create and include the API router
+    api_router = create_api_router(rag_service=rag_service)
+    app.include_router(api_router)
+
+    return app
diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md
new file mode 100644
index 0000000..880aab4
--- /dev/null
+++ b/ai-hub/app/core/guide.md
@@ -0,0 +1,104 @@
+# 🤖 LLM Provider Module Documentation
+
+> **Location**: `app/core/llm_providers.py`
+> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**.
+> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer.
+
+---
+
+## 1. ⚙️ Configuration
+
+The module behavior is controlled by **environment variables**, primarily for authentication and model selection.
+
+### 📌 Environment Variables
+
+| Variable              | Description                                | Default Value               |
+| --------------------- | ----------------------------------------- | --------------------------- |
+| `DEEPSEEK_API_KEY`    | Your DeepSeek API key (**required**)       | `None`                      |
+| `GEMINI_API_KEY`      | Your Google Gemini API key (**required**)  | `None`                      |
+| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use          | `"deepseek-chat"`           |
+| `GEMINI_MODEL_NAME`   | Name of the Gemini model to use            | `"gemini-1.5-flash-latest"` |
+
+> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application.
+
+---
+
+## 2. 🧱 Core Components
+
+The system uses a **provider interface pattern** to support multiple LLMs with a consistent API.
+
+### 🧩 `LLMProvider` (Abstract Base Class)
+
+* Base class for all LLM providers.
+* Defines the required method:
+
+  ```python
+  async def generate_response(self, prompt: str) -> str
+  ```
+
+* All concrete providers must implement this method.
+
+---
+
+### 🧠 `DeepSeekProvider`
+
+* Implements the `LLMProvider` interface.
+* Uses:
+
+  * `DEEPSEEK_API_KEY`
+  * `DEEPSEEK_MODEL_NAME`
+* Handles API calls to **DeepSeek's** LLM service.
+
+---
+
+### 🌟 `GeminiProvider`
+
+* Implements the `LLMProvider` interface.
+* Uses:
+
+  * `GEMINI_API_KEY`
+  * `GEMINI_MODEL_NAME`
+* Handles API calls to **Google Gemini's** LLM service.
+
+---
+
+### 🏭 `get_llm_provider(model_name: str) -> LLMProvider`
+
+* Factory function for obtaining a provider instance.
+
+#### **Usage:**
+
+```python
+provider = get_llm_provider("deepseek")
+response = await provider.generate_response("Tell me a joke.")  # inside an async function
+```
+
+#### **Supported values for `model_name`:**
+
+* `"deepseek"`
+* `"gemini"`
+
+#### **Raises:**
+
+* `ValueError` if an unsupported `model_name` is passed.
+
+---
+
+## 3. 🚀 Usage Example
+
+Use the `get_llm_provider()` factory function to get a provider instance, then `await generate_response()` on it:
+
+```python
+from app.core.llm_providers import get_llm_provider
+
+# Get provider
+llm = get_llm_provider("gemini")
+
+# Generate response
+response = await llm.generate_response("What is the capital of France?")  # inside an async function
+print(response)
+```
+
+This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class.
+
+---
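The guide closes by noting that new backends can be added by implementing another provider class. A hedged sketch of what that might look like follows; the `OpenAIProvider` name, the model string, and the registration step are illustrative only and not part of this change:

```python
# Hypothetical additional provider; names and model are illustrative only.
from openai import OpenAI

from app.core.llm_providers import LLMProvider, _providers

class OpenAIProvider(LLMProvider):
    """Example third backend following the same interface as DeepSeek/Gemini."""
    def __init__(self, model_name: str = "gpt-4o-mini"):
        self.client = OpenAI()  # reads OPENAI_API_KEY from the environment
        self.model = model_name

    async def generate_response(self, prompt: str) -> str:
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content

# Registering it would also require widening the Literal["deepseek", "gemini"]
# type used by the API layer; shown here only to illustrate the pattern.
_providers["openai"] = OpenAIProvider()
```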
+ """ + router = APIRouter() + + @router.get("/") + def read_root(): + return {"status": "AI Model Hub is running!"} + + @router.post("/chat") + async def chat_handler( + request: ChatRequest, + model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use."), + db: Session = Depends(get_db) + ): + try: + response_text = await rag_service.chat_with_rag( + db=db, + prompt=request.prompt, + model=model + ) + return {"response": response_text, "model_used": model} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") + + @router.post("/document") + async def add_document( + doc: DocumentCreate, + db: Session = Depends(get_db) + ): + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. 
diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py
new file mode 100644
index 0000000..75629c7
--- /dev/null
+++ b/ai-hub/app/core/llm_providers.py
@@ -0,0 +1,91 @@
+import os
+import httpx
+from abc import ABC, abstractmethod
+from openai import OpenAI
+from typing import final
+
+# --- 1. Load Configuration from Environment ---
+# Best practice is to centralize configuration loading at the top.
+
+# API Keys (required)
+DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+# Model Names (optional, with defaults)
+# Allows changing the model version without code changes.
+DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat")
+GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
+
+# --- 2. Initialize API Clients and URLs ---
+# Initialize any clients or constants that will be used by the providers.
+deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}"
+
+
+# --- 3. Provider Interface and Implementations ---
+
+class LLMProvider(ABC):
+    """Abstract base class ('Interface') for all LLM providers."""
+    @abstractmethod
+    async def generate_response(self, prompt: str) -> str:
+        """Generates a response from the LLM."""
+        pass
+
+@final
+class DeepSeekProvider(LLMProvider):
+    """Provider for the DeepSeek API."""
+    def __init__(self, model_name: str):
+        self.model = model_name
+        print(f"DeepSeekProvider initialized with model: {self.model}")
+
+    async def generate_response(self, prompt: str) -> str:
+        try:
+            chat_completion = deepseek_client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": prompt},
+                ],
+                stream=False
+            )
+            return chat_completion.choices[0].message.content
+        except Exception as e:
+            print(f"DeepSeek Error: {e}")
+            raise  # Re-raise the exception to be handled by the main app
+
+@final
+class GeminiProvider(LLMProvider):
+    """Provider for the Google Gemini API."""
+    def __init__(self, api_url: str):
+        self.url = api_url
+        print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}")
+
+    async def generate_response(self, prompt: str) -> str:
+        payload = {"contents": [{"parts": [{"text": prompt}]}]}
+        headers = {"Content-Type": "application/json"}
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.post(self.url, json=payload, headers=headers)
+                response.raise_for_status()
+                data = response.json()
+                return data['candidates'][0]['content']['parts'][0]['text']
+        except (httpx.HTTPStatusError, KeyError, IndexError) as e:
+            print(f"Gemini Error: {e}")
+            raise  # Re-raise for the main app to handle
+
+# --- 4. The Factory Function ---
+# This is where we instantiate our concrete providers with their configuration.
+
+_providers = {
+    "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL),
+    "gemini": GeminiProvider(api_url=GEMINI_URL)
+}
+
+def get_llm_provider(model_name: str) -> LLMProvider:
+    """Factory function to get the appropriate, pre-configured LLM provider."""
+    provider = _providers.get(model_name)
+    if not provider:
+        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}")
+    return provider
+
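Because both concrete providers expose `generate_response` as a coroutine, calling them outside the FastAPI request cycle requires an event loop. A small sketch (the model choice and prompt are arbitrary, and the relevant API key must be set in the environment):

```python
import asyncio

from app.core.llm_providers import get_llm_provider

async def main() -> None:
    provider = get_llm_provider("deepseek")
    answer = await provider.generate_response("Summarize what a vector index does.")
    print(answer)

if __name__ == "__main__":
    # Requires DEEPSEEK_API_KEY to be set before running.
    asyncio.run(main())
```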
diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py
new file mode 100644
index 0000000..8a51f48
--- /dev/null
+++ b/ai-hub/app/core/rag_service.py
@@ -0,0 +1,95 @@
+import logging
+from typing import Literal, List
+from sqlalchemy.orm import Session
+from app.core.vector_store import FaissVectorStore
+from app.core.llm_providers import get_llm_provider
+from app.core.retrievers import Retriever
+from app.db import models
+
+# Configure logging for the service
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class RAGService:
+    """
+    A service class to handle all RAG-related business logic.
+    This includes adding documents and processing chat requests with context retrieval.
+    The retrieval logic is now handled by pluggable Retriever components.
+    """
+    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
+        """
+        Initializes the service.
+
+        Args:
+            vector_store (FaissVectorStore): The FAISS vector store for document vectors.
+            retrievers (List[Retriever]): A list of retriever components to use for
+                                          context retrieval.
+        """
+        self.vector_store = vector_store
+        self.retrievers = retrievers
+
+    def add_document(self, db: Session, doc_data: dict) -> int:
+        """
+        Adds a new document to the database and its vector to the FAISS index.
+        """
+        try:
+            new_document = models.Document(**doc_data)
+            db.add(new_document)
+            db.commit()
+            db.refresh(new_document)
+
+            faiss_id = self.vector_store.add_document(new_document.text)
+
+            vector_meta = models.VectorMetadata(
+                document_id=new_document.id,
+                faiss_index=faiss_id,
+                embedding_model="mock_embedder"
+            )
+            db.add(vector_meta)
+            db.commit()
+
+            logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}")
+            return new_document.id
+        except Exception as e:
+            db.rollback()
+            logger.error(f"Failed to add document: {e}")
+            raise e
+
+    async def chat_with_rag(
+        self,
+        db: Session,
+        prompt: str,
+        model: Literal["deepseek", "gemini"]
+    ) -> str:
+        """
+        Handles a chat request by retrieving context from all configured
+        retrievers and passing it to the LLM.
+        """
+        context_docs_text = []
+        # The service now iterates through all configured retrievers to gather context
+        for retriever in self.retrievers:
+            context_docs_text.extend(retriever.retrieve_context(prompt, db))
+
+        combined_context = "\n\n".join(context_docs_text)
+
+        if combined_context:
+            logger.info(f"Retrieved context for prompt: '{prompt}'")
+            rag_prompt = f"""
+            You are an AI assistant that answers questions based on the provided context.
+
+            Context:
+            {combined_context}
+
+            Question:
+            {prompt}
+
+            If the answer is not in the context, say that you cannot answer the question based on the information provided.
+            """
+        else:
+            rag_prompt = prompt
+            logger.warning("No context found for the query.")
+
+        provider = get_llm_provider(model)
+        response_text = await provider.generate_response(rag_prompt)
+
+        return response_text
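The service can also be driven outside the HTTP layer, for example from an offline ingestion script. A sketch under stated assumptions: the session factory exposed by `app.db_setup` is assumed to be named `SessionLocal`, which this diff does not show.

```python
# Offline ingestion sketch; SessionLocal is an assumed name, not part of this diff.
from app.core.rag_service import RAGService
from app.core.retrievers import FaissDBRetriever
from app.core.vector_store import FaissVectorStore
from app.db_setup import SessionLocal  # assumption: not shown in this change

vector_store = FaissVectorStore()
service = RAGService(
    vector_store=vector_store,
    retrievers=[FaissDBRetriever(vector_store=vector_store)],
)

db = SessionLocal()
try:
    doc_id = service.add_document(db=db, doc_data={
        "title": "Retrieval notes",
        "text": "Context retrieved from FAISS is prepended to the prompt.",
        "user_id": "default_user",
    })
    print(f"Stored document {doc_id}")
finally:
    db.close()
```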
+ """ + router = APIRouter() + + @router.get("/") + def read_root(): + return {"status": "AI Model Hub is running!"} + + @router.post("/chat") + async def chat_handler( + request: ChatRequest, + model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use."), + db: Session = Depends(get_db) + ): + try: + response_text = await rag_service.chat_with_rag( + db=db, + prompt=request.prompt, + model=model + ) + return {"response": response_text, "model_used": model} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") + + @router.post("/document") + async def add_document( + doc: DocumentCreate, + db: Session = Depends(get_db) + ): + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. 
⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. + +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. 
+deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. 
diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py
new file mode 100644
index 0000000..8f20713
--- /dev/null
+++ b/ai-hub/app/core/retrievers.py
@@ -0,0 +1,65 @@
+import abc
+from typing import List, Dict
+from sqlalchemy.orm import Session
+from app.core.vector_store import FaissVectorStore
+from app.db import models
+
+class Retriever(abc.ABC):
+    """
+    Abstract base class for a Retriever.
+
+    A retriever is a pluggable component that is responsible for fetching
+    relevant context for a given query from a specific data source.
+    """
+    @abc.abstractmethod
+    def retrieve_context(self, query: str, db: Session) -> List[str]:
+        """
+        Fetches context for a given query.
+
+        Args:
+            query (str): The user's query string.
+            db (Session): The database session.
+
+        Returns:
+            List[str]: A list of text strings representing the retrieved context.
+        """
+        raise NotImplementedError
+
+class FaissDBRetriever(Retriever):
+    """
+    A concrete retriever that uses a FAISS index and a local database
+    to find and return relevant document text.
+    """
+    def __init__(self, vector_store: FaissVectorStore):
+        self.vector_store = vector_store
+
+    def retrieve_context(self, query: str, db: Session) -> List[str]:
+        """
+        Retrieves document text by first searching the FAISS index
+        and then fetching the corresponding documents from the database.
+        """
+        faiss_ids = self.vector_store.search_similar_documents(query, k=3)
+        context_docs_text = []
+
+        if faiss_ids:
+            # Use FAISS IDs to find the corresponding document_id from the database
+            document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter(
+                models.VectorMetadata.faiss_index.in_(faiss_ids)
+            ).all()
+
+            document_ids = [doc_id for (doc_id,) in document_ids_from_vectors]
+
+            # Retrieve the full documents from the Document table
+            context_docs = db.query(models.Document).filter(
+                models.Document.id.in_(document_ids)
+            ).all()
+
+            context_docs_text = [doc.text for doc in context_docs]
+
+        return context_docs_text
+
+# You could add other retriever implementations here, like:
+# class RemoteServiceRetriever(Retriever):
+#     def retrieve_context(self, query: str, db: Session) -> List[str]:
+#         # Logic to call a remote API and return context
+#         ...
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+        index_id = self.index.ntotal - 1
+        print(f"Document added to FAISS with index ID: {index_id}")
+        self.save_index()  # Save after every addition for persistence
+        return index_id
+
+    def search_similar_documents(self, query: str, k: int = 5) -> List[int]:
+        """
+        Performs a similarity search on the index for a given query.
+
+        Args:
+            query (str): The search query text.
+            k (int): The number of nearest neighbors to retrieve.
+
+        Returns:
+            List[int]: A list of FAISS index IDs for the top k similar documents.
+        """
+        query_vector = self.embedder.embed(query).reshape(1, -1)
+        # Faiss search returns distances and the corresponding index IDs.
+        distances, indices = self.index.search(query_vector, k)
+
+        print(f"Found {len(indices[0])} similar documents for the query.")
+        return [int(i) for i in indices[0] if i >= 0]
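As a quick sanity check, the store can be exercised on its own. A sketch, assuming the package is installed locally; the index file name here is arbitrary, and with `MockEmbedder` the similarity ranking is random rather than meaningful:

```python
from app.core.vector_store import FaissVectorStore

# Use a throwaway index file so the real faiss_index.bin is not touched.
store = FaissVectorStore(index_file_path="demo_index.bin", dimension=768)

doc_id = store.add_document("FAISS is a library for efficient similarity search.")
print(f"Stored vector with FAISS ID {doc_id}")

# With a real embedding model these IDs would correspond to the most similar documents.
ids = store.search_similar_documents("similarity search library", k=3)
print(ids)
```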
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md
new file mode 100644
index 0000000..a3185c7
--- /dev/null
+++ b/ai-hub/app/db/guide.md
@@ -0,0 +1,193 @@
+
+---
+
+# 📚 Database Module Documentation
+
+> **File:** `app/db/database.py`
+> This module provides a streamlined way to connect to and manage database sessions for your application.
+> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment.
+
+---
+
+## 1. ⚙️ Configuration
+
+The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults.
+
+### 📌 Environment Variables
+
+| Variable       | Description                            | Default Value                                                                                   | Supported Values     |
+| -------------- | -------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- |
+| `DB_MODE`      | Specifies the type of database to use  | `"postgres"`                                                                                      | `postgres`, `sqlite` |
+| `DATABASE_URL` | Connection string for the database     | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`<br>SQLite: `sqlite:///./ai_hub.db`   | Any SQLAlchemy URI   |
+
+### 💡 Example: Switch to SQLite
+
+```bash
+export DB_MODE="sqlite"
+```
+
+---
+
+## 2. 🧱 Core Components
+
+This module exposes several key components used to interface with the database.
+
+### `engine`
+
+* A SQLAlchemy **Engine** instance.
+* Manages connections, pooling, and execution context.
+
+### `SessionLocal`
+
+* A factory for creating new database **session objects**.
+* Typically accessed indirectly through the `get_db` dependency.
+
+### `Base`
+
+* The base class all SQLAlchemy models should inherit from.
+* Used to declare table mappings.
+
+### `get_db()`
+
+* A **FastAPI dependency function**.
+* Yields a new session per request and ensures it's closed afterward, even in case of errors.
+
+---
+
+## 3. 🚀 Usage with FastAPI
+
+To safely interact with the database in your FastAPI routes, use the `get_db()` dependency.
+
+### ✅ Steps to Use
+
+1. **Import `get_db`:**
+
+   ```python
+   from app.db.database import get_db
+   ```
+
+2. **Import `Depends` and `Session`:**
+
+   ```python
+   from fastapi import Depends
+   from sqlalchemy.orm import Session
+   ```
+
+3. **Inject into your route handler:**
+
+   ```python
+   @app.get("/items/")
+   def read_items(db: Session = Depends(get_db)):
+       return db.query(Item).all()
+   ```
+
+This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward.
+
+---
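The `database.py` module itself is not reproduced in this diff, but given the components documented above it is presumably close to the standard SQLAlchemy setup. A sketch, with the default URLs taken from the configuration table and everything else an assumption:

```python
import os

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

# Assumed implementation: pick a default URL based on DB_MODE, allow override via DATABASE_URL.
DB_MODE = os.getenv("DB_MODE", "postgres")
DEFAULT_URLS = {
    "postgres": "postgresql://user:password@localhost/ai_hub_db",
    "sqlite": "sqlite:///./ai_hub.db",
}
DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_URLS.get(DB_MODE, DEFAULT_URLS["postgres"]))

# SQLite needs check_same_thread=False when a connection is shared across FastAPI requests.
connect_args = {"check_same_thread": False} if DB_MODE == "sqlite" else {}

engine = create_engine(DATABASE_URL, connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


def get_db():
    """FastAPI dependency: yield a session per request and always close it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
```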
+
+# 🗂️ Database Models Documentation
+
+This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**.
+
+All models inherit from the `Base` class, imported from the `database.py` module.
+
+---
+
+## 1. 🗨️ Session Model
+
+Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session.
+
+### 🔑 Key Fields
+
+| Field         | Data Type | Description                                |
+| ------------- | --------- | ------------------------------------------ |
+| `id`          | Integer   | Primary key for the session                |
+| `user_id`     | String    | Unique identifier of the user              |
+| `title`       | String    | Optional AI-generated title                |
+| `model_name`  | String    | Name of the LLM used in the session        |
+| `created_at`  | DateTime  | Timestamp of when the session was created  |
+| `is_archived` | Boolean   | Soft-delete/archive flag for the session   |
+
+### 🔗 Relationships
+
+* **`messages`**:
+  One-to-many relationship with the **Message** model.
+  Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`).
+
+---
+
+## 2. 💬 Message Model
+
+Stores individual **messages** within a session, including user inputs and AI responses.
+
+### 🔑 Key Fields
+
+| Field                 | Data Type | Description                                                          |
+| --------------------- | --------- | -------------------------------------------------------------------- |
+| `id`                  | Integer   | Primary key for the message                                           |
+| `session_id`          | Integer   | Foreign key linking to the parent session                             |
+| `sender`              | String    | Role of the sender (`"user"` or `"assistant"`)                        |
+| `content`             | Text      | Full text of the message                                              |
+| `created_at`          | DateTime  | Timestamp of message creation                                         |
+| `model_response_time` | Integer   | Time (in seconds) taken by the model to generate the response         |
+| `token_count`         | Integer   | Number of tokens in the message                                       |
+| `message_metadata`    | JSON      | Flexible field for storing unstructured metadata (e.g., tool calls)   |
+
+### 🔗 Relationships
+
+* **`session`**:
+  Many-to-one relationship with the **Session** model.
+
+---
+
+## 3. 📄 Document Model
+
+Stores **metadata** and **content** of documents ingested into the system for RAG purposes.
+
+### 🔑 Key Fields
+
+| Field        | Data Type | Description                                          |
+| ------------ | --------- | ---------------------------------------------------- |
+| `id`         | Integer   | Primary key for the document                         |
+| `title`      | String    | Human-readable title                                 |
+| `text`       | Text      | Full content of the document                         |
+| `source_url` | String    | URL or path where the document was retrieved from    |
+| `author`     | String    | Author of the document                               |
+| `status`     | String    | Processing status (`"ready"`, `"processing"` etc.)   |
+| `created_at` | DateTime  | Timestamp of document creation                       |
+| `user_id`    | String    | ID of the user who added the document                |
+
+### 🔗 Relationships
+
+* **`vector_metadata`**:
+  One-to-one relationship with the **VectorMetadata** model.
+
+---
+
+## 4. 🧠 VectorMetadata Model
+
+Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows.
+
+### 🔑 Key Fields
+
+| Field             | Data Type | Description                                                  |
+| ----------------- | --------- | ------------------------------------------------------------ |
+| `id`              | Integer   | Primary key for the metadata entry                           |
+| `document_id`     | Integer   | Foreign key to the parent **Document** (unique constraint)   |
+| `faiss_index`     | Integer   | Vector's index in the FAISS store                            |
+| `session_id`      | Integer   | Foreign key to the **Session** used in the RAG context       |
+| `embedding_model` | String    | Embedding model used to generate the vector                  |
+
+### 🔗 Relationships
+
+* **`document`**:
+  Many-to-one relationship with the **Document** model
+
+* **`session`**:
+  Many-to-one relationship with the **Session** model
+
+---
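To make the RAG-related tables concrete, here is a rough sketch of how the `Document` and `VectorMetadata` models might be declared. Column types follow the tables above; table names, nullability, and defaults are assumptions, and the `Session` and `Message` models are omitted:

```python
from datetime import datetime

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship

from app.db.database import Base


class Document(Base):
    __tablename__ = "documents"  # assumed table name

    id = Column(Integer, primary_key=True)
    title = Column(String, nullable=False)
    text = Column(Text, nullable=False)
    source_url = Column(String)
    author = Column(String)
    status = Column(String, default="ready")
    created_at = Column(DateTime, default=datetime.utcnow)
    user_id = Column(String)

    # One-to-one link to the vector metadata entry.
    vector_metadata = relationship("VectorMetadata", back_populates="document", uselist=False)


class VectorMetadata(Base):
    __tablename__ = "vector_metadata"  # assumed table name

    id = Column(Integer, primary_key=True)
    document_id = Column(Integer, ForeignKey("documents.id"), unique=True)
    faiss_index = Column(Integer)
    # Points at the Session model, which is not reproduced in this sketch.
    session_id = Column(Integer, ForeignKey("sessions.id"))
    embedding_model = Column(String)

    document = relationship("Document", back_populates="vector_metadata")
```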
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+ index_id = self.index.ntotal - 1 + print(f"Document added to FAISS with index ID: {index_id}") + self.save_index() # Save after every addition for persistence + return index_id + + def search_similar_documents(self, query: str, k: int = 5) -> List[int]: + """ + Performs a similarity search on the index for a given query. + + Args: + query (str): The search query text. + k (int): The number of nearest neighbors to retrieve. + + Returns: + List[int]: A list of FAISS index IDs for the top k similar documents. + """ + query_vector = self.embedder.embed(query).reshape(1, -1) + # Faiss search returns distances and the corresponding index IDs. + distances, indices = self.index.search(query_vector, k) + + print(f"Found {len(indices[0])} similar documents for the query.") + return [int(i) for i in indices[0] if i >= 0] diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md new file mode 100644 index 0000000..a3185c7 --- /dev/null +++ b/ai-hub/app/db/guide.md @@ -0,0 +1,193 @@ + +--- + +# 📚 Database Module Documentation + +> **File:** `app/db/database.py` +> This module provides a streamlined way to connect to and manage database sessions for your application. +> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment. + +--- + +## 1. ⚙️ Configuration + +The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults. + +### 📌 Environment Variables + +| Variable | Description | Default Value | Supported Values | +| -------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- | +| `DB_MODE` | Specifies the type of database to use | `"postgres"` | `postgres`, `sqlite` | +| `DATABASE_URL` | Connection string for the database | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`
SQLite: `sqlite:///./ai_hub.db` | Any SQLAlchemy URI | + +### 💡 Example: Switch to SQLite + +```bash +export DB_MODE="sqlite" +``` + +--- + +## 2. 🧱 Core Components + +This module exposes several key components used to interface with the database. + +### `engine` + +* A SQLAlchemy **Engine** instance. +* Manages connections, pooling, and execution context. + +### `SessionLocal` + +* A factory for creating new database **session objects**. +* Typically accessed indirectly through the `get_db` dependency. + +### `Base` + +* The base class all SQLAlchemy models should inherit from. +* Used to declare table mappings. + +### `get_db()` + +* A **FastAPI dependency function**. +* Yields a new session per request and ensures it's closed afterward, even in case of errors. + +--- + +## 3. 🚀 Usage with FastAPI + +To safely interact with the database in your FastAPI routes, use the `get_db()` dependency. + +### ✅ Steps to Use + +1. **Import `get_db`:** + + ```python + from app.db.database import get_db + ``` + +2. **Add it as a dependency:** + + ```python + from fastapi import Depends + from sqlalchemy.orm import Session + ``` + +3. **Inject into your route handler:** + + ```python + @app.get("/items/") + def read_items(db: Session = Depends(get_db)): + return db.query(Item).all() + ``` + +This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward. + +--- + +Here is your fully **formatted and organized** documentation for the SQLAlchemy models in `app/db/models.py`, suitable for technical documentation or a README: + +--- + +# 🗂️ Database Models Documentation + +This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**. + +All models inherit from the `Base` class, imported from the `database.py` module. + +--- + +## 1. 🗨️ Session Model + +Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------- | --------- | ----------------------------------------- | +| `id` | Integer | Primary key for the session | +| `user_id` | String | Unique identifier of the user | +| `title` | String | Optional AI-generated title | +| `model_name` | String | Name of the LLM used in the session | +| `created_at` | DateTime | Timestamp of when the session was created | +| `is_archived` | Boolean | Soft-delete/archive flag for the session | + +### 🔗 Relationships + +* **`messages`**: + One-to-many relationship with the **Message** model. + Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`). + +--- + +## 2. 💬 Message Model + +Stores individual **messages** within a session, including user inputs and AI responses. 
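+
+As a small illustrative sketch (the helper name and exact constructor arguments are assumptions; field names follow the tables in this guide), a session and one of its messages could be persisted like this:
+
+```python
+from app.db import models
+
+def record_user_message(db, user_id: str, text: str) -> None:
+    # Create the conversation container first, then attach the user's message to it.
+    session = models.Session(user_id=user_id, model_name="deepseek")
+    db.add(session)
+    db.commit()
+    db.refresh(session)
+
+    message = models.Message(session_id=session.id, sender="user", content=text)
+    db.add(message)
+    db.commit()
+```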
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
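+# For example (illustrative only, not part of this commit): a deployment pointing at
+# PostgreSQL might start the service with
+#   export DB_MODE="postgres"
+#   export DATABASE_URL="postgresql://user:password@localhost/ai_hub_db"
+# whereas leaving both variables unset falls back to the local SQLite file configured below.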
+DB_MODE = os.getenv("DB_MODE", "sqlite")
+if DB_MODE == "sqlite":
+    DATABASE_URL = "sqlite:///./ai_hub.db"
+    # The connect_args are needed for SQLite to work with FastAPI's multiple threads
+    engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
+else:
+    DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db")
+    engine = create_engine(DATABASE_URL)
+
+# Create a database session class
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+def create_db_tables():
+    """Create all database tables based on the models."""
+    print("Creating database tables...")
+    Base.metadata.create_all(bind=engine)
+
+# The dependency to get a database session
+def get_db():
+    """Dependency that provides a database session."""
+    db = SessionLocal()
+    try:
+        yield db
+    finally:
+        db.close()
\ No newline at end of file
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
new file mode 100644
index 0000000..bdb93e0
--- /dev/null
+++ b/ai-hub/app/app.py
@@ -0,0 +1,55 @@
+import os
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from dotenv import load_dotenv
+from typing import List
+
+# Import core application logic
+from app.core.vector_store import FaissVectorStore
+from app.core.retrievers import FaissDBRetriever, Retriever
+from app.core.rag_service import RAGService
+
+# Import the new files for database and API routes
+from app.db_setup import create_db_tables
+from app.api_endpoints import create_api_router
+
+# Load environment variables from a .env file
+load_dotenv()
+
+# --- Application Factory Function ---
+def create_app() -> FastAPI:
+    """
+    Factory function to create and configure the FastAPI application.
+    This encapsulates all setup logic, making the main entry point clean.
+    """
+    # Initialize core services for RAG
+    vector_store = FaissVectorStore()
+    retrievers: List[Retriever] = [
+        FaissDBRetriever(vector_store=vector_store),
+    ]
+    rag_service = RAGService(vector_store=vector_store, retrievers=retrievers)
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        """
+        Initializes the database and vector store on startup and handles
+        cleanup on shutdown.
+        """
+        print("Initializing application services...")
+        create_db_tables()
+        yield
+        print("Shutting down application services...")
+        vector_store.save_index()
+
+    app = FastAPI(
+        title="AI Model Hub Service",
+        description="An extensible hub to route requests to various LLMs with RAG capabilities.",
+        version="0.0.0",
+        lifespan=lifespan
+    )
+
+    # Create and include the API router
+    api_router = create_api_router(rag_service=rag_service)
+    app.include_router(api_router)
+
+    return app
diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md
new file mode 100644
index 0000000..880aab4
--- /dev/null
+++ b/ai-hub/app/core/guide.md
@@ -0,0 +1,104 @@
+# 🤖 LLM Provider Module Documentation
+
+> **Location**: `app/core/providers.py`
+> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**.
+> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer.
+
+---
+
+## 1. ⚙️ Configuration
+
+The module behavior is controlled by **environment variables**, primarily for authentication and model selection.
+
+### 📌 Environment Variables
+
+| Variable | Description | Default Value |
+| --------------------- | ----------------------------------------- | --------------------------- |
+| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` |
+| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` |
+| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` |
+| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` |
+
+> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application.
+
+---
+
+## 2. 🧱 Core Components
+
+The system uses a **provider interface pattern** to support multiple LLMs with a consistent API.
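+
+As a minimal illustration (the class name is hypothetical and not part of this commit), a new backend only needs to subclass `LLMProvider` and implement `generate_response`:
+
+```python
+from app.core.llm_providers import LLMProvider
+
+class EchoProvider(LLMProvider):
+    """Hypothetical provider that simply echoes the prompt back."""
+    async def generate_response(self, prompt: str) -> str:
+        return f"(echo) {prompt}"
+```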
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations ---
+
+class LLMProvider(ABC):
+    """Abstract base class ('Interface') for all LLM providers."""
+    @abstractmethod
+    async def generate_response(self, prompt: str) -> str:
+        """Generates a response from the LLM."""
+        pass
+
+@final
+class DeepSeekProvider(LLMProvider):
+    """Provider for the DeepSeek API."""
+    def __init__(self, model_name: str):
+        self.model = model_name
+        print(f"DeepSeekProvider initialized with model: {self.model}")
+
+    async def generate_response(self, prompt: str) -> str:
+        try:
+            chat_completion = deepseek_client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": prompt},
+                ],
+                stream=False
+            )
+            return chat_completion.choices[0].message.content
+        except Exception as e:
+            print(f"DeepSeek Error: {e}")
+            raise  # Re-raise the exception to be handled by the main app
+
+@final
+class GeminiProvider(LLMProvider):
+    """Provider for the Google Gemini API."""
+    def __init__(self, api_url: str):
+        self.url = api_url
+        print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}")
+
+    async def generate_response(self, prompt: str) -> str:
+        payload = {"contents": [{"parts": [{"text": prompt}]}]}
+        headers = {"Content-Type": "application/json"}
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.post(self.url, json=payload, headers=headers)
+                response.raise_for_status()
+                data = response.json()
+                return data['candidates'][0]['content']['parts'][0]['text']
+        except (httpx.HTTPStatusError, KeyError, IndexError) as e:
+            print(f"Gemini Error: {e}")
+            raise  # Re-raise for the main app to handle
+
+# --- 4. The Factory Function ---
+# This is where we instantiate our concrete providers with their configuration.
+
+_providers = {
+    "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL),
+    "gemini": GeminiProvider(api_url=GEMINI_URL)
+}
+
+def get_llm_provider(model_name: str) -> LLMProvider:
+    """Factory function to get the appropriate, pre-configured LLM provider."""
+    provider = _providers.get(model_name)
+    if not provider:
+        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}")
+    return provider
+
diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py
deleted file mode 100644
index 75629c7..0000000
--- a/ai-hub/app/llm_providers.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import os
-import httpx
-from abc import ABC, abstractmethod
-from openai import OpenAI
-from typing import final
-
-# --- 1. Load Configuration from Environment ---
-# Best practice is to centralize configuration loading at the top.
-
-# API Keys (required)
-DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
-GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
-
-# Model Names (optional, with defaults)
-# Allows changing the model version without code changes.
-DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat")
-GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
-
-# --- 2. Initialize API Clients and URLs ---
-# Initialize any clients or constants that will be used by the providers.
-deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
-GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}"
-
-
-# --- 3.
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/.gitignore b/.gitignore index c02da13..576208c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ .env **/.env -**/*.egg-info \ No newline at end of file +**/*.egg-info +faiss_index.bin \ No newline at end of file diff --git a/ai-hub/app/api_endpoints.py b/ai-hub/app/api_endpoints.py new file mode 100644 index 0000000..fc4f1db --- /dev/null +++ b/ai-hub/app/api_endpoints.py @@ -0,0 +1,66 @@ +from fastapi import APIRouter, HTTPException, Query, Depends +from pydantic import BaseModel +from typing import Literal +from sqlalchemy.orm import Session +from app.core.rag_service import RAGService +from app.db_setup import get_db + +# Pydantic Models for API requests +class ChatRequest(BaseModel): + prompt: str + +class DocumentCreate(BaseModel): + title: str + text: str + source_url: str = None + author: str = None + user_id: str = "default_user" + +def create_api_router(rag_service: RAGService) -> APIRouter: + """ + Creates and returns an APIRouter with all the application's endpoints. + + This function takes the RAGService instance as an argument, so it can be + injected from the main application factory. 
+ """ + router = APIRouter() + + @router.get("/") + def read_root(): + return {"status": "AI Model Hub is running!"} + + @router.post("/chat") + async def chat_handler( + request: ChatRequest, + model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use."), + db: Session = Depends(get_db) + ): + try: + response_text = await rag_service.chat_with_rag( + db=db, + prompt=request.prompt, + model=model + ) + return {"response": response_text, "model_used": model} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") + + @router.post("/document") + async def add_document( + doc: DocumentCreate, + db: Session = Depends(get_db) + ): + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. 
⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. + +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. 
+deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. 
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+ index_id = self.index.ntotal - 1 + print(f"Document added to FAISS with index ID: {index_id}") + self.save_index() # Save after every addition for persistence + return index_id + + def search_similar_documents(self, query: str, k: int = 5) -> List[int]: + """ + Performs a similarity search on the index for a given query. + + Args: + query (str): The search query text. + k (int): The number of nearest neighbors to retrieve. + + Returns: + List[int]: A list of FAISS index IDs for the top k similar documents. + """ + query_vector = self.embedder.embed(query).reshape(1, -1) + # Faiss search returns distances and the corresponding index IDs. + distances, indices = self.index.search(query_vector, k) + + print(f"Found {len(indices[0])} similar documents for the query.") + return [int(i) for i in indices[0] if i >= 0] diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md new file mode 100644 index 0000000..a3185c7 --- /dev/null +++ b/ai-hub/app/db/guide.md @@ -0,0 +1,193 @@ + +--- + +# 📚 Database Module Documentation + +> **File:** `app/db/database.py` +> This module provides a streamlined way to connect to and manage database sessions for your application. +> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment. + +--- + +## 1. ⚙️ Configuration + +The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults. + +### 📌 Environment Variables + +| Variable | Description | Default Value | Supported Values | +| -------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- | +| `DB_MODE` | Specifies the type of database to use | `"postgres"` | `postgres`, `sqlite` | +| `DATABASE_URL` | Connection string for the database | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`
SQLite: `sqlite:///./ai_hub.db` | Any SQLAlchemy URI | + +### 💡 Example: Switch to SQLite + +```bash +export DB_MODE="sqlite" +``` + +--- + +## 2. 🧱 Core Components + +This module exposes several key components used to interface with the database. + +### `engine` + +* A SQLAlchemy **Engine** instance. +* Manages connections, pooling, and execution context. + +### `SessionLocal` + +* A factory for creating new database **session objects**. +* Typically accessed indirectly through the `get_db` dependency. + +### `Base` + +* The base class all SQLAlchemy models should inherit from. +* Used to declare table mappings. + +### `get_db()` + +* A **FastAPI dependency function**. +* Yields a new session per request and ensures it's closed afterward, even in case of errors. + +--- + +## 3. 🚀 Usage with FastAPI + +To safely interact with the database in your FastAPI routes, use the `get_db()` dependency. + +### ✅ Steps to Use + +1. **Import `get_db`:** + + ```python + from app.db.database import get_db + ``` + +2. **Add it as a dependency:** + + ```python + from fastapi import Depends + from sqlalchemy.orm import Session + ``` + +3. **Inject into your route handler:** + + ```python + @app.get("/items/") + def read_items(db: Session = Depends(get_db)): + return db.query(Item).all() + ``` + +This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward. + +--- + +Here is your fully **formatted and organized** documentation for the SQLAlchemy models in `app/db/models.py`, suitable for technical documentation or a README: + +--- + +# 🗂️ Database Models Documentation + +This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**. + +All models inherit from the `Base` class, imported from the `database.py` module. + +--- + +## 1. 🗨️ Session Model + +Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------- | --------- | ----------------------------------------- | +| `id` | Integer | Primary key for the session | +| `user_id` | String | Unique identifier of the user | +| `title` | String | Optional AI-generated title | +| `model_name` | String | Name of the LLM used in the session | +| `created_at` | DateTime | Timestamp of when the session was created | +| `is_archived` | Boolean | Soft-delete/archive flag for the session | + +### 🔗 Relationships + +* **`messages`**: + One-to-many relationship with the **Message** model. + Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`). + +--- + +## 2. 💬 Message Model + +Stores individual **messages** within a session, including user inputs and AI responses. 
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
+DB_MODE = os.getenv("DB_MODE", "sqlite") +if DB_MODE == "sqlite": + DATABASE_URL = "sqlite:///./ai_hub.db" + # The connect_args are needed for SQLite to work with FastAPI's multiple threads + engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +else: + DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db") + engine = create_engine(DATABASE_URL) + +# Create a database session class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def create_db_tables(): + """Create all database tables based on the models.""" + print("Creating database tables...") + Base.metadata.create_all(bind=engine) + +# The dependency to get a database session +def get_db(): + """Dependency that provides a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py deleted file mode 100644 index 75629c7..0000000 --- a/ai-hub/app/llm_providers.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import httpx -from abc import ABC, abstractmethod -from openai import OpenAI -from typing import final - -# --- 1. Load Configuration from Environment --- -# Best practice is to centralize configuration loading at the top. - -# API Keys (required) -DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") - -# Model Names (optional, with defaults) -# Allows changing the model version without code changes. -DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") -GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") - -# --- 2. Initialize API Clients and URLs --- -# Initialize any clients or constants that will be used by the providers. -deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" - - -# --- 3. 
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/ai-hub/app/main.py b/ai-hub/app/main.py index 05d3952..cfc25bf 100644 --- a/ai-hub/app/main.py +++ b/ai-hub/app/main.py @@ -1,47 +1,10 @@ -# main.py +import uvicorn +from app.app import create_app -from fastapi import FastAPI, HTTPException, Query -from pydantic import BaseModel -from typing import Literal -from dotenv import load_dotenv +# Use the application factory to create the FastAPI app instance +app = create_app() -# Import our new factory function -from app.llm_providers import get_llm_provider - -# --- 1. Application Setup --- -load_dotenv() -app = FastAPI( - title="AI Model Hub Service", - description="A extensible hub to route requests to various LLMs using a Factory Pattern.", - version="0.0.0", -) - -# --- 2. Pydantic Models --- -class ChatRequest(BaseModel): - prompt: str - -# --- 3. API Endpoints --- -@app.get("/") -def read_root(): - return {"status": "AI Model Hub is running!"} - -@app.post("/chat") -async def chat_handler( - request: ChatRequest, - model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use.") -): - try: - # Use the factory to get the correct provider instance - provider = get_llm_provider(model) - - # Call the method on the instance. 
We don't need to know if it's - # Gemini or DeepSeek, only that it fulfills the "contract". - response_text = await provider.generate_response(request.prompt) - - return {"response": response_text, "model_used": model} - except ValueError as e: - # This catches errors from the factory (e.g., unsupported model) - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - # This catches errors from the provider's API call - raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") \ No newline at end of file +if __name__ == "__main__": + # This block allows you to run the application directly with 'python main.py' + # For production, it's recommended to use the `uvicorn` command directly. + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/ai-hub/pytest.ini b/ai-hub/pytest.ini new file mode 100644 index 0000000..cfd6bf4 --- /dev/null +++ b/ai-hub/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning +testpaths = + tests \ No newline at end of file
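The new `pytest.ini` points the test runner at a `tests/` directory that is not part of this change set. As a minimal, hypothetical sketch of a first test for the refactored application (assuming FAISS and the SQLite database can initialize in the test environment), one could exercise the health-check endpoint through the application factory with FastAPI's `TestClient`:

```python
# tests/test_api.py (hypothetical file, not included in this diff)
from fastapi.testclient import TestClient

from app.app import create_app


def test_read_root():
    # Using the client as a context manager runs the lifespan hooks, so
    # create_db_tables() executes on startup and the FAISS index is saved on exit.
    with TestClient(create_app()) as client:
        response = client.get("/")
        assert response.status_code == 200
        assert response.json() == {"status": "AI Model Hub is running!"}
```

Run with `pytest` from the `ai-hub/` directory; the `testpaths = tests` setting above picks the file up automatically.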
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+ index_id = self.index.ntotal - 1 + print(f"Document added to FAISS with index ID: {index_id}") + self.save_index() # Save after every addition for persistence + return index_id + + def search_similar_documents(self, query: str, k: int = 5) -> List[int]: + """ + Performs a similarity search on the index for a given query. + + Args: + query (str): The search query text. + k (int): The number of nearest neighbors to retrieve. + + Returns: + List[int]: A list of FAISS index IDs for the top k similar documents. + """ + query_vector = self.embedder.embed(query).reshape(1, -1) + # Faiss search returns distances and the corresponding index IDs. + distances, indices = self.index.search(query_vector, k) + + print(f"Found {len(indices[0])} similar documents for the query.") + return [int(i) for i in indices[0] if i >= 0] diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md new file mode 100644 index 0000000..a3185c7 --- /dev/null +++ b/ai-hub/app/db/guide.md @@ -0,0 +1,193 @@ + +--- + +# 📚 Database Module Documentation + +> **File:** `app/db/database.py` +> This module provides a streamlined way to connect to and manage database sessions for your application. +> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment. + +--- + +## 1. ⚙️ Configuration + +The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults. + +### 📌 Environment Variables + +| Variable | Description | Default Value | Supported Values | +| -------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- | +| `DB_MODE` | Specifies the type of database to use | `"postgres"` | `postgres`, `sqlite` | +| `DATABASE_URL` | Connection string for the database | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`
SQLite: `sqlite:///./ai_hub.db` | Any SQLAlchemy URI |
+
+### 💡 Example: Switch to SQLite
+
+```bash
+export DB_MODE="sqlite"
+```
+
+---
+
+## 2. 🧱 Core Components
+
+This module exposes several key components used to interface with the database.
+
+### `engine`
+
+* A SQLAlchemy **Engine** instance.
+* Manages connections, pooling, and execution context.
+
+### `SessionLocal`
+
+* A factory for creating new database **session objects**.
+* Typically accessed indirectly through the `get_db` dependency.
+
+### `Base`
+
+* The base class all SQLAlchemy models should inherit from.
+* Used to declare table mappings.
+
+### `get_db()`
+
+* A **FastAPI dependency function**.
+* Yields a new session per request and ensures it's closed afterward, even in case of errors.
+
+---
+
+## 3. 🚀 Usage with FastAPI
+
+To safely interact with the database in your FastAPI routes, use the `get_db()` dependency.
+
+### ✅ Steps to Use
+
+1. **Import `get_db`:**
+
+   ```python
+   from app.db.database import get_db
+   ```
+
+2. **Add it as a dependency:**
+
+   ```python
+   from fastapi import Depends
+   from sqlalchemy.orm import Session
+   ```
+
+3. **Inject into your route handler:**
+
+   ```python
+   @app.get("/items/")
+   def read_items(db: Session = Depends(get_db)):
+       return db.query(Item).all()
+   ```
+
+This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward.
+
+---
+
+# 🗂️ Database Models Documentation
+
+This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**.
+
+All models inherit from the `Base` class, imported from the `database.py` module.
+
+---
+
+## 1. 🗨️ Session Model
+
+Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session.
+
+### 🔑 Key Fields
+
+| Field | Data Type | Description |
+| ------------- | --------- | ----------------------------------------- |
+| `id` | Integer | Primary key for the session |
+| `user_id` | String | Unique identifier of the user |
+| `title` | String | Optional AI-generated title |
+| `model_name` | String | Name of the LLM used in the session |
+| `created_at` | DateTime | Timestamp of when the session was created |
+| `is_archived` | Boolean | Soft-delete/archive flag for the session |
+
+### 🔗 Relationships
+
+* **`messages`**:
+  One-to-many relationship with the **Message** model.
+  Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`).
+
+---
+
+## 2. 💬 Message Model
+
+Stores individual **messages** within a session, including user inputs and AI responses.
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
+DB_MODE = os.getenv("DB_MODE", "sqlite") +if DB_MODE == "sqlite": + DATABASE_URL = "sqlite:///./ai_hub.db" + # The connect_args are needed for SQLite to work with FastAPI's multiple threads + engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +else: + DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db") + engine = create_engine(DATABASE_URL) + +# Create a database session class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def create_db_tables(): + """Create all database tables based on the models.""" + print("Creating database tables...") + Base.metadata.create_all(bind=engine) + +# The dependency to get a database session +def get_db(): + """Dependency that provides a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py deleted file mode 100644 index 75629c7..0000000 --- a/ai-hub/app/llm_providers.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import httpx -from abc import ABC, abstractmethod -from openai import OpenAI -from typing import final - -# --- 1. Load Configuration from Environment --- -# Best practice is to centralize configuration loading at the top. - -# API Keys (required) -DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") - -# Model Names (optional, with defaults) -# Allows changing the model version without code changes. -DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") -GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") - -# --- 2. Initialize API Clients and URLs --- -# Initialize any clients or constants that will be used by the providers. -deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" - - -# --- 3. 
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/ai-hub/app/main.py b/ai-hub/app/main.py index 05d3952..cfc25bf 100644 --- a/ai-hub/app/main.py +++ b/ai-hub/app/main.py @@ -1,47 +1,10 @@ -# main.py +import uvicorn +from app.app import create_app -from fastapi import FastAPI, HTTPException, Query -from pydantic import BaseModel -from typing import Literal -from dotenv import load_dotenv +# Use the application factory to create the FastAPI app instance +app = create_app() -# Import our new factory function -from app.llm_providers import get_llm_provider - -# --- 1. Application Setup --- -load_dotenv() -app = FastAPI( - title="AI Model Hub Service", - description="A extensible hub to route requests to various LLMs using a Factory Pattern.", - version="0.0.0", -) - -# --- 2. Pydantic Models --- -class ChatRequest(BaseModel): - prompt: str - -# --- 3. API Endpoints --- -@app.get("/") -def read_root(): - return {"status": "AI Model Hub is running!"} - -@app.post("/chat") -async def chat_handler( - request: ChatRequest, - model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use.") -): - try: - # Use the factory to get the correct provider instance - provider = get_llm_provider(model) - - # Call the method on the instance. 
We don't need to know if it's
-        # Gemini or DeepSeek, only that it fulfills the "contract".
-        response_text = await provider.generate_response(request.prompt)
-
-        return {"response": response_text, "model_used": model}
-    except ValueError as e:
-        # This catches errors from the factory (e.g., unsupported model)
-        raise HTTPException(status_code=400, detail=str(e))
-    except Exception as e:
-        # This catches errors from the provider's API call
-        raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}")
\ No newline at end of file
+if __name__ == "__main__":
+    # This block allows you to run the application directly with 'python main.py'
+    # For production, it's recommended to use the `uvicorn` command directly.
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/ai-hub/pytest.ini b/ai-hub/pytest.ini
new file mode 100644
index 0000000..cfd6bf4
--- /dev/null
+++ b/ai-hub/pytest.ini
@@ -0,0 +1,5 @@
+[pytest]
+filterwarnings =
+    ignore::DeprecationWarning
+testpaths =
+    tests
\ No newline at end of file
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt
index 9426b1e..b5be92c 100644
--- a/ai-hub/requirements.txt
+++ b/ai-hub/requirements.txt
@@ -10,4 +10,6 @@ psycopg2
 pytest-asyncio
 pytest-tornasync
-pytest-trio
\ No newline at end of file
+pytest-trio
+numpy
+faiss-cpu
\ No newline at end of file
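
The changeset above wires the RAG endpoints into the running service. As a quick smoke test, the sketch below shows how a client might exercise them; it assumes the hub is running locally on port 8000 (the default in `main.py`), and the example document, URL, and prompt are placeholders. The payloads mirror the `DocumentCreate` and `ChatRequest` models from `api_endpoints.py`.

```python
# Minimal client sketch, assuming the service is reachable on localhost:8000.
import httpx

BASE_URL = "http://localhost:8000"

with httpx.Client(base_url=BASE_URL) as client:
    # Ingest a document; the fields mirror the DocumentCreate model.
    doc = {
        "title": "FAISS Overview",
        "text": "FAISS is a library for efficient similarity search.",
        "source_url": "http://example.com/faiss",
    }
    r = client.post("/document", json=doc)
    print(r.json())  # e.g. {"message": "Document 'FAISS Overview' added successfully with ID 1"}

    # Ask a question; the provider is chosen via the `model` query parameter.
    r = client.post("/chat", params={"model": "deepseek"}, json={"prompt": "What is FAISS?"})
    print(r.json())  # {"response": "...", "model_used": "deepseek"}
```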
diff --git a/ai-hub/tests/core/test_rag_service.py b/ai-hub/tests/core/test_rag_service.py
new file mode 100644
index 0000000..b86d893
--- /dev/null
+++ b/ai-hub/tests/core/test_rag_service.py
@@ -0,0 +1,170 @@
+import asyncio
+from unittest.mock import patch, MagicMock, AsyncMock, call
+from sqlalchemy.orm import Session
+from typing import List
+
+# Import the RAGService class and its dependencies
+from app.core.rag_service import RAGService
+from app.core.vector_store import FaissVectorStore
+from app.core.retrievers import Retriever
+from app.db import models
+
+# --- RAGService Unit Tests ---
+# These tests directly target the methods of the RAGService class
+# to verify their internal logic in isolation.
+
+@patch('app.db.models.VectorMetadata')
+@patch('app.db.models.Document')
+@patch('app.core.vector_store.FaissVectorStore')
+def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model):
+    """
+    Test the RAGService.add_document method for a successful run.
+    Verifies that the method correctly calls db.add(), db.commit(), and the vector store.
+    """
+    # Setup mocks
+    mock_db = MagicMock(spec=Session)
+    mock_new_document_instance = MagicMock()
+    mock_document_model.return_value = mock_new_document_instance
+    mock_new_document_instance.id = 1
+    mock_new_document_instance.text = "Test text."
+    mock_new_document_instance.title = "Test Title"
+
+    mock_vector_store_instance = mock_vector_store.return_value
+    mock_vector_store_instance.add_document.return_value = 123
+
+    # Instantiate the service with the mock dependencies
+    rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[])
+
+    doc_data = {
+        "title": "Test Title",
+        "text": "Test text.",
+        "source_url": "http://test.com"
+    }
+
+    # Call the method under test
+    document_id = rag_service.add_document(db=mock_db, doc_data=doc_data)
+
+    # Assertions
+    assert document_id == 1
+
+    # Use mock.call to check for both calls to db.add in the correct order. 
+ # We must mock the VectorMetadata model to check its constructor call + expected_calls = [ + call(mock_new_document_instance), + call(mock_vector_metadata_model.return_value) + ] + mock_db.add.assert_has_calls(expected_calls) + + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_new_document_instance) + mock_vector_store_instance.add_document.assert_called_once_with("Test text.") + + # Assert that VectorMetadata was instantiated with the correct arguments + mock_vector_metadata_model.assert_called_once_with( + document_id=mock_new_document_instance.id, + faiss_index=mock_vector_store_instance.add_document.return_value, + embedding_model="mock_embedder" + ) + +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_error_handling(mock_vector_store): + """ + Test the RAGService.add_document method's error handling. + Verifies that the transaction is rolled back on an exception. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + + # Configure the mock db.add to raise an exception + mock_db.add.side_effect = Exception("Database error") + + mock_vector_store_instance = mock_vector_store.return_value + + # Instantiate the service + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test and expect an exception + try: + rag_service.add_document(db=mock_db, doc_data=doc_data) + assert False, "Expected an exception to be raised" + except Exception as e: + assert str(e) == "Database error" + + # Assertions + # The first db.add was called + mock_db.add.assert_called_once() + # No commit should have occurred + mock_db.commit.assert_not_called() + # The transaction should have been rolled back + mock_db.rollback.assert_called_once() + + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when context is retrieved. + Verifies that the RAG prompt is correctly constructed. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response with context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + expected_context = "Context text 1.\n\nContext text 2." + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once() + actual_llm_prompt = mock_llm_provider.generate_response.call_args[0][0] + + # Check if the generated prompt contains the expected context and question + assert expected_context in actual_llm_prompt + assert prompt in actual_llm_prompt + assert response_text == "LLM response with context" + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when no context is retrieved. 
+ Verifies that the original prompt is sent to the LLM. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response without context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = [] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt without context." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once_with(prompt) + assert response_text == "LLM response without context"
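Note: the unit tests above exercise RAGService in isolation with mocks. As a minimal sketch (not part of the diff above), the same flow could also be checked end to end through the HTTP layer with FastAPI's TestClient; this assumes the default SQLite configuration from `db_setup.py` and accepts that the vector store writes a local `faiss_index.bin`.

```python
# Sketch only: end-to-end smoke test through the application factory.
# Assumes DB_MODE defaults to "sqlite" and that writing faiss_index.bin locally is acceptable.
from fastapi.testclient import TestClient
from app.app import create_app

def smoke_test_add_document():
    app = create_app()
    # Entering the client as a context manager runs the lifespan hooks,
    # so create_db_tables() executes before the request is made.
    with TestClient(app) as client:
        payload = {"title": "Smoke test", "text": "Hello world."}
        resp = client.post("/document", json=payload)
        assert resp.status_code == 200
        assert "added successfully" in resp.json()["message"]
```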
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+ index_id = self.index.ntotal - 1 + print(f"Document added to FAISS with index ID: {index_id}") + self.save_index() # Save after every addition for persistence + return index_id + + def search_similar_documents(self, query: str, k: int = 5) -> List[int]: + """ + Performs a similarity search on the index for a given query. + + Args: + query (str): The search query text. + k (int): The number of nearest neighbors to retrieve. + + Returns: + List[int]: A list of FAISS index IDs for the top k similar documents. + """ + query_vector = self.embedder.embed(query).reshape(1, -1) + # Faiss search returns distances and the corresponding index IDs. + distances, indices = self.index.search(query_vector, k) + + print(f"Found {len(indices[0])} similar documents for the query.") + return [int(i) for i in indices[0] if i >= 0] diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md new file mode 100644 index 0000000..a3185c7 --- /dev/null +++ b/ai-hub/app/db/guide.md @@ -0,0 +1,193 @@ + +--- + +# 📚 Database Module Documentation + +> **File:** `app/db/database.py` +> This module provides a streamlined way to connect to and manage database sessions for your application. +> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment. + +--- + +## 1. ⚙️ Configuration + +The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults. + +### 📌 Environment Variables + +| Variable | Description | Default Value | Supported Values | +| -------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- | +| `DB_MODE` | Specifies the type of database to use | `"postgres"` | `postgres`, `sqlite` | +| `DATABASE_URL` | Connection string for the database | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`
SQLite: `sqlite:///./ai_hub.db` | Any SQLAlchemy URI | + +### 💡 Example: Switch to SQLite + +```bash +export DB_MODE="sqlite" +``` + +--- + +## 2. 🧱 Core Components + +This module exposes several key components used to interface with the database. + +### `engine` + +* A SQLAlchemy **Engine** instance. +* Manages connections, pooling, and execution context. + +### `SessionLocal` + +* A factory for creating new database **session objects**. +* Typically accessed indirectly through the `get_db` dependency. + +### `Base` + +* The base class all SQLAlchemy models should inherit from. +* Used to declare table mappings. + +### `get_db()` + +* A **FastAPI dependency function**. +* Yields a new session per request and ensures it's closed afterward, even in case of errors. + +--- + +## 3. 🚀 Usage with FastAPI + +To safely interact with the database in your FastAPI routes, use the `get_db()` dependency. + +### ✅ Steps to Use + +1. **Import `get_db`:** + + ```python + from app.db.database import get_db + ``` + +2. **Add it as a dependency:** + + ```python + from fastapi import Depends + from sqlalchemy.orm import Session + ``` + +3. **Inject into your route handler:** + + ```python + @app.get("/items/") + def read_items(db: Session = Depends(get_db)): + return db.query(Item).all() + ``` + +This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward. + +--- + +Here is your fully **formatted and organized** documentation for the SQLAlchemy models in `app/db/models.py`, suitable for technical documentation or a README: + +--- + +# 🗂️ Database Models Documentation + +This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**. + +All models inherit from the `Base` class, imported from the `database.py` module. + +--- + +## 1. 🗨️ Session Model + +Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------- | --------- | ----------------------------------------- | +| `id` | Integer | Primary key for the session | +| `user_id` | String | Unique identifier of the user | +| `title` | String | Optional AI-generated title | +| `model_name` | String | Name of the LLM used in the session | +| `created_at` | DateTime | Timestamp of when the session was created | +| `is_archived` | Boolean | Soft-delete/archive flag for the session | + +### 🔗 Relationships + +* **`messages`**: + One-to-many relationship with the **Message** model. + Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`). + +--- + +## 2. 💬 Message Model + +Stores individual **messages** within a session, including user inputs and AI responses. 
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
+DB_MODE = os.getenv("DB_MODE", "sqlite") +if DB_MODE == "sqlite": + DATABASE_URL = "sqlite:///./ai_hub.db" + # The connect_args are needed for SQLite to work with FastAPI's multiple threads + engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +else: + DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db") + engine = create_engine(DATABASE_URL) + +# Create a database session class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def create_db_tables(): + """Create all database tables based on the models.""" + print("Creating database tables...") + Base.metadata.create_all(bind=engine) + +# The dependency to get a database session +def get_db(): + """Dependency that provides a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py deleted file mode 100644 index 75629c7..0000000 --- a/ai-hub/app/llm_providers.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import httpx -from abc import ABC, abstractmethod -from openai import OpenAI -from typing import final - -# --- 1. Load Configuration from Environment --- -# Best practice is to centralize configuration loading at the top. - -# API Keys (required) -DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") - -# Model Names (optional, with defaults) -# Allows changing the model version without code changes. -DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") -GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") - -# --- 2. Initialize API Clients and URLs --- -# Initialize any clients or constants that will be used by the providers. -deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" - - -# --- 3. 
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/ai-hub/app/main.py b/ai-hub/app/main.py index 05d3952..cfc25bf 100644 --- a/ai-hub/app/main.py +++ b/ai-hub/app/main.py @@ -1,47 +1,10 @@ -# main.py +import uvicorn +from app.app import create_app -from fastapi import FastAPI, HTTPException, Query -from pydantic import BaseModel -from typing import Literal -from dotenv import load_dotenv +# Use the application factory to create the FastAPI app instance +app = create_app() -# Import our new factory function -from app.llm_providers import get_llm_provider - -# --- 1. Application Setup --- -load_dotenv() -app = FastAPI( - title="AI Model Hub Service", - description="A extensible hub to route requests to various LLMs using a Factory Pattern.", - version="0.0.0", -) - -# --- 2. Pydantic Models --- -class ChatRequest(BaseModel): - prompt: str - -# --- 3. API Endpoints --- -@app.get("/") -def read_root(): - return {"status": "AI Model Hub is running!"} - -@app.post("/chat") -async def chat_handler( - request: ChatRequest, - model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use.") -): - try: - # Use the factory to get the correct provider instance - provider = get_llm_provider(model) - - # Call the method on the instance. 
We don't need to know if it's - # Gemini or DeepSeek, only that it fulfills the "contract". - response_text = await provider.generate_response(request.prompt) - - return {"response": response_text, "model_used": model} - except ValueError as e: - # This catches errors from the factory (e.g., unsupported model) - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - # This catches errors from the provider's API call - raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") \ No newline at end of file +if __name__ == "__main__": + # This block allows you to run the application directly with 'python main.py' + # For production, it's recommended to use the `uvicorn` command directly. + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/ai-hub/pytest.ini b/ai-hub/pytest.ini new file mode 100644 index 0000000..cfd6bf4 --- /dev/null +++ b/ai-hub/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning +testpaths = + tests \ No newline at end of file diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt index 9426b1e..b5be92c 100644 --- a/ai-hub/requirements.txt +++ b/ai-hub/requirements.txt @@ -10,4 +10,6 @@ psycopg2 pytest-asyncio pytest-tornasync -pytest-trio \ No newline at end of file +pytest-trio +numpy +faiss-cpu \ No newline at end of file diff --git a/ai-hub/tests/core/test_rag_service.py b/ai-hub/tests/core/test_rag_service.py new file mode 100644 index 0000000..b86d893 --- /dev/null +++ b/ai-hub/tests/core/test_rag_service.py @@ -0,0 +1,170 @@ +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock, call +from sqlalchemy.orm import Session +from typing import List + +# Import the RAGService class and its dependencies +from app.core.rag_service import RAGService +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import Retriever +from app.db import models + +# --- RAGService Unit Tests --- +# These tests directly target the methods of the RAGService class +# to verify their internal logic in isolation. + +@patch('app.db.models.VectorMetadata') +@patch('app.db.models.Document') +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): + """ + Test the RAGService.add_document method for a successful run. + Verifies that the method correctly calls db.add(), db.commit(), and the vector store. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_new_document_instance = MagicMock() + mock_document_model.return_value = mock_new_document_instance + mock_new_document_instance.id = 1 + mock_new_document_instance.text = "Test text." + mock_new_document_instance.title = "Test Title" + + mock_vector_store_instance = mock_vector_store.return_value + mock_vector_store_instance.add_document.return_value = 123 + + # Instantiate the service with the mock dependencies + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test + document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) + + # Assertions + assert document_id == 1 + + # Use mock.call to check for both calls to db.add in the correct order. 
+ # We must mock the VectorMetadata model to check its constructor call + expected_calls = [ + call(mock_new_document_instance), + call(mock_vector_metadata_model.return_value) + ] + mock_db.add.assert_has_calls(expected_calls) + + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_new_document_instance) + mock_vector_store_instance.add_document.assert_called_once_with("Test text.") + + # Assert that VectorMetadata was instantiated with the correct arguments + mock_vector_metadata_model.assert_called_once_with( + document_id=mock_new_document_instance.id, + faiss_index=mock_vector_store_instance.add_document.return_value, + embedding_model="mock_embedder" + ) + +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_error_handling(mock_vector_store): + """ + Test the RAGService.add_document method's error handling. + Verifies that the transaction is rolled back on an exception. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + + # Configure the mock db.add to raise an exception + mock_db.add.side_effect = Exception("Database error") + + mock_vector_store_instance = mock_vector_store.return_value + + # Instantiate the service + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test and expect an exception + try: + rag_service.add_document(db=mock_db, doc_data=doc_data) + assert False, "Expected an exception to be raised" + except Exception as e: + assert str(e) == "Database error" + + # Assertions + # The first db.add was called + mock_db.add.assert_called_once() + # No commit should have occurred + mock_db.commit.assert_not_called() + # The transaction should have been rolled back + mock_db.rollback.assert_called_once() + + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when context is retrieved. + Verifies that the RAG prompt is correctly constructed. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response with context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + expected_context = "Context text 1.\n\nContext text 2." + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once() + actual_llm_prompt = mock_llm_provider.generate_response.call_args[0][0] + + # Check if the generated prompt contains the expected context and question + assert expected_context in actual_llm_prompt + assert prompt in actual_llm_prompt + assert response_text == "LLM response with context" + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when no context is retrieved. 
diff --git a/ai-hub/tests/core/test_vector_store.py b/ai-hub/tests/core/test_vector_store.py new file mode 100644 index 0000000..1969dbb --- /dev/null +++ b/ai-hub/tests/core/test_vector_store.py @@ -0,0 +1,150 @@ +import pytest +import numpy as np +import faiss +import os +import shutil +from typing import List, Tuple + +# We need to configure the python path so that pytest can find our application code +# Since this is a test file, we assume the app/ directory is available from the +# pytest root. +from app.core.vector_store import FaissVectorStore, MockEmbedder + +# Define constants for our tests to ensure consistency +TEST_DIMENSION = 128 +TEST_INDEX_FILE = "test_faiss_index.bin" + + +# --- Fixtures --- +# Pytest fixtures are used to set up a clean environment for each test. + +@pytest.fixture(scope="function") +def temp_faiss_dir(tmp_path): + """ + Fixture to create a temporary directory for each test function. + This ensures that each test runs in a clean environment without + interfering with other tests or the main application. + """ + # Create a sub-directory within the pytest temporary path + test_dir = tmp_path / "faiss_test" + test_dir.mkdir() + yield test_dir + # The cleanup is automatically handled by the tmp_path fixture, + # but we'll add a manual check just in case. + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + + +@pytest.fixture(scope="function") +def faiss_store(temp_faiss_dir): + """ + Fixture that provides a fresh FaissVectorStore instance for each test. + The index file path points to the temporary directory. + """ + index_file_path = os.path.join(temp_faiss_dir, TEST_INDEX_FILE) + store = FaissVectorStore(index_file_path=index_file_path, dimension=TEST_DIMENSION) + return store + + +# --- Unit Tests --- + +def test_init_creates_new_index(faiss_store): + """ + Test that the constructor correctly creates a new FAISS index + if the index file does not exist. + """ + # We verify that the index is a faiss.IndexFlatL2 instance + assert isinstance(faiss_store.index, faiss.IndexFlatL2) + # The index should be empty initially + assert faiss_store.index.ntotal == 0 + # The file should NOT exist yet as it's only saved on add_document + assert not os.path.exists(faiss_store.index_file_path) + + +def test_add_document(faiss_store): + """ + Test the add_document method to ensure it adds a vector and saves the index. + """ + test_text = "This is a test document."
+ + # The index should be empty before adding + assert faiss_store.index.ntotal == 0 + + # Add the document and get the returned index ID + faiss_id = faiss_store.add_document(test_text) + + # The index should now have one item + assert faiss_store.index.ntotal == 1 + # The returned ID should be the first index, which is 0 + assert faiss_id == 0 + # The index file should now exist on disk + assert os.path.exists(faiss_store.index_file_path) + + +def test_add_multiple_documents(faiss_store): + """ + Test that multiple documents can be added and the index size grows correctly. + """ + docs = ["Doc 1", "Doc 2", "Doc 3"] + + # Add each document and check the total number of items + for i, doc in enumerate(docs): + faiss_id = faiss_store.add_document(doc) + assert faiss_store.index.ntotal == i + 1 + assert faiss_id == i + + # The final index file should exist and the count should be correct + assert os.path.exists(faiss_store.index_file_path) + assert faiss_store.index.ntotal == 3 + + +def test_load_existing_index(temp_faiss_dir): + """ + Test that the store can load an existing index file from disk. + """ + # Step 1: Create an index and add an item to it, then save it. + first_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + first_store.add_document("Document for persistence test.") + + # Ensure the file was saved + assert os.path.exists(first_store.index_file_path) + assert first_store.index.ntotal == 1 + + # Step 2: Create a new store instance pointing to the same file. + second_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + + # The new store should have loaded the index and should have 1 item. + assert second_store.index.ntotal == 1 + assert isinstance(second_store.index, faiss.IndexFlatL2) + + +def test_search_similar_documents(faiss_store): + """ + Test the search functionality. Since we're using a mock embedder with + random vectors, we can't predict the exact result, but we can + verify the format and number of results. 
+ """ + # Add some documents to the store + faiss_store.add_document("Document 1") + faiss_store.add_document("Document 2") + faiss_store.add_document("Document 3") + faiss_store.add_document("Document 4") + faiss_store.add_document("Document 5") + + # Search for a query and ask for 3 results + results = faiss_store.search_similar_documents("A query string", k=3) + + # The results should be a list of 3 items + assert isinstance(results, list) + assert len(results) == 3 + + # The results should be integers, and valid FAISS IDs + for result_id in results: + assert isinstance(result_id, int) + assert 0 <= result_id < 5 # IDs should be between 0 and 4 diff --git a/.gitignore b/.gitignore index c02da13..576208c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ .env **/.env -**/*.egg-info \ No newline at end of file +**/*.egg-info +faiss_index.bin \ No newline at end of file diff --git a/ai-hub/app/api_endpoints.py b/ai-hub/app/api_endpoints.py new file mode 100644 index 0000000..fc4f1db --- /dev/null +++ b/ai-hub/app/api_endpoints.py @@ -0,0 +1,66 @@ +from fastapi import APIRouter, HTTPException, Query, Depends +from pydantic import BaseModel +from typing import Literal +from sqlalchemy.orm import Session +from app.core.rag_service import RAGService +from app.db_setup import get_db + +# Pydantic Models for API requests +class ChatRequest(BaseModel): + prompt: str + +class DocumentCreate(BaseModel): + title: str + text: str + source_url: str = None + author: str = None + user_id: str = "default_user" + +def create_api_router(rag_service: RAGService) -> APIRouter: + """ + Creates and returns an APIRouter with all the application's endpoints. + + This function takes the RAGService instance as an argument, so it can be + injected from the main application factory. + """ + router = APIRouter() + + @router.get("/") + def read_root(): + return {"status": "AI Model Hub is running!"} + + @router.post("/chat") + async def chat_handler( + request: ChatRequest, + model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use."), + db: Session = Depends(get_db) + ): + try: + response_text = await rag_service.chat_with_rag( + db=db, + prompt=request.prompt, + model=model + ) + return {"response": response_text, "model_used": model} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") + + @router.post("/document") + async def add_document( + doc: DocumentCreate, + db: Session = Depends(get_db) + ): + """ + Adds a new document to the database and its vector embedding to the FAISS index. 
+ """ + try: + doc_data = doc.model_dump() + document_id = rag_service.add_document(db=db, doc_data=doc_data) + + return {"message": f"Document '{doc.title}' added successfully with ID {document_id}"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py new file mode 100644 index 0000000..bdb93e0 --- /dev/null +++ b/ai-hub/app/app.py @@ -0,0 +1,55 @@ +import os +from contextlib import asynccontextmanager +from fastapi import FastAPI +from dotenv import load_dotenv +from typing import List + +# Import core application logic +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.rag_service import RAGService + +# Import the new files for database and API routes +from app.db_setup import create_db_tables +from app.api_endpoints import create_api_router + +# Load environment variables from a .env file +load_dotenv() + +# --- Application Factory Function --- +def create_app() -> FastAPI: + """ + Factory function to create and configure the FastAPI application. + This encapsulates all setup logic, making the main entry point clean. + """ + # Initialize core services for RAG + vector_store = FaissVectorStore() + retrievers: List[Retriever] = [ + FaissDBRetriever(vector_store=vector_store), + ] + rag_service = RAGService(vector_store=vector_store, retrievers=retrievers) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + Initializes the database and vector store on startup and handles + cleanup on shutdown. + """ + print("Initializing application services...") + create_db_tables() + yield + print("Shutting down application services...") + vector_store.save_index() + + app = FastAPI( + title="AI Model Hub Service", + description="A extensible hub to route requests to various LLMs with RAG capabilities.", + version="0.0.0", + lifespan=lifespan + ) + + # Create and include the API router + api_router = create_api_router(rag_service=rag_service) + app.include_router(api_router) + + return app diff --git a/ai-hub/app/core/guide.md b/ai-hub/app/core/guide.md new file mode 100644 index 0000000..880aab4 --- /dev/null +++ b/ai-hub/app/core/guide.md @@ -0,0 +1,104 @@ +# 🤖 LLM Provider Module Documentation + +> **Location**: `app/core/providers.py` +> This module provides a **standardized and extensible interface** for interacting with multiple **Large Language Model (LLM) APIs**. +> It enables pluggable support for models like DeepSeek and Gemini using a unified abstraction layer. + +--- + +## 1. ⚙️ Configuration + +The module behavior is controlled by **environment variables**, primarily for authentication and model selection. + +### 📌 Environment Variables + +| Variable | Description | Default Value | +| --------------------- | ----------------------------------------- | --------------------------- | +| `DEEPSEEK_API_KEY` | Your DeepSeek API key (**required**) | `None` | +| `GEMINI_API_KEY` | Your Google Gemini API key (**required**) | `None` | +| `DEEPSEEK_MODEL_NAME` | Name of the DeepSeek model to use | `"deepseek-chat"` | +| `GEMINI_MODEL_NAME` | Name of the Gemini model to use | `"gemini-1.5-flash-latest"` | + +> 💡 **Tip**: Set these variables in your `.env` file or environment before running the application. + +--- + +## 2. 🧱 Core Components + +The system uses a **provider interface pattern** to support multiple LLMs with a consistent API. 
+ +### 🧩 `LLMProvider` (Abstract Base Class) + +* Base class for all LLM providers. +* Defines the required method: + + ```python + def generate_response(self, prompt: str) -> str + ``` + +* All concrete providers must implement this method. + +--- + +### 🧠 `DeepSeekProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `DEEPSEEK_API_KEY` + * `DEEPSEEK_MODEL_NAME` +* Handles API calls to **DeepSeek's** LLM service. + +--- + +### 🌟 `GeminiProvider` + +* Implements the `LLMProvider` interface. +* Uses: + + * `GEMINI_API_KEY` + * `GEMINI_MODEL_NAME` +* Handles API calls to **Google Gemini's** LLM service. + +--- + +### 🏭 `get_llm_provider(model_name: str) -> LLMProvider` + +* Factory function for obtaining a provider instance. + +#### **Usage:** + +```python +provider = get_llm_provider("deepseek") +response = provider.generate_response("Tell me a joke.") +``` + +#### **Supported values for `model_name`:** + +* `"deepseek"` +* `"gemini"` + +#### **Raises:** + +* `ValueError` if an unsupported `model_name` is passed. + +--- + +## 3. 🚀 Usage Example + +Use the `get_llm_provider()` factory function to get a provider instance. Then call `generate_response()` on it: + +```python +from app.core.providers import get_llm_provider + +# Get provider +llm = get_llm_provider("gemini") + +# Generate response +response = llm.generate_response("What is the capital of France?") +print(response) +``` + +This design allows you to easily **swap or extend LLMs** (e.g., add OpenAI, Anthropic) by simply implementing a new provider class. + +--- diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py new file mode 100644 index 0000000..75629c7 --- /dev/null +++ b/ai-hub/app/core/llm_providers.py @@ -0,0 +1,91 @@ +import os +import httpx +from abc import ABC, abstractmethod +from openai import OpenAI +from typing import final + +# --- 1. Load Configuration from Environment --- +# Best practice is to centralize configuration loading at the top. + +# API Keys (required) +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") + +# Model Names (optional, with defaults) +# Allows changing the model version without code changes. +DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") +GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") + +# --- 2. Initialize API Clients and URLs --- +# Initialize any clients or constants that will be used by the providers. +deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") +GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" + + +# --- 3. 
Provider Interface and Implementations --- + +class LLMProvider(ABC): + """Abstract base class ('Interface') for all LLM providers.""" + @abstractmethod + async def generate_response(self, prompt: str) -> str: + """Generates a response from the LLM.""" + pass + +@final +class DeepSeekProvider(LLMProvider): + """Provider for the DeepSeek API.""" + def __init__(self, model_name: str): + self.model = model_name + print(f"DeepSeekProvider initialized with model: {self.model}") + + async def generate_response(self, prompt: str) -> str: + try: + chat_completion = deepseek_client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + stream=False + ) + return chat_completion.choices[0].message.content + except Exception as e: + print(f"DeepSeek Error: {e}") + raise # Re-raise the exception to be handled by the main app + +@final +class GeminiProvider(LLMProvider): + """Provider for the Google Gemini API.""" + def __init__(self, api_url: str): + self.url = api_url + print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") + + async def generate_response(self, prompt: str) -> str: + payload = {"contents": [{"parts": [{"text": prompt}]}]} + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + return data['candidates'][0]['content']['parts'][0]['text'] + except (httpx.HTTPStatusError, KeyError, IndexError) as e: + print(f"Gemini Error: {e}") + raise # Re-raise for the main app to handle + +# --- 4. The Factory Function --- +# This is where we instantiate our concrete providers with their configuration. + +_providers = { + "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), + "gemini": GeminiProvider(api_url=GEMINI_URL) +} + +def get_llm_provider(model_name: str) -> LLMProvider: + """Factory function to get the appropriate, pre-configured LLM provider.""" + provider = _providers.get(model_name) + if not provider: + raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") + return provider + diff --git a/ai-hub/app/core/rag_service.py b/ai-hub/app/core/rag_service.py new file mode 100644 index 0000000..8a51f48 --- /dev/null +++ b/ai-hub/app/core/rag_service.py @@ -0,0 +1,95 @@ +import logging +from typing import Literal, List +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.core.llm_providers import get_llm_provider +from app.core.retrievers import Retriever +from app.db import models + +# Configure logging for the service +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RAGService: + """ + A service class to handle all RAG-related business logic. + This includes adding documents and processing chat requests with context retrieval. + The retrieval logic is now handled by pluggable Retriever components. + """ + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + """ + Initializes the service. + + Args: + vector_store (FaissVectorStore): The FAISS vector store for document vectors. + retrievers (List[Retriever]): A list of retriever components to use for + context retrieval. 
+ """ + self.vector_store = vector_store + self.retrievers = retrievers + + def add_document(self, db: Session, doc_data: dict) -> int: + """ + Adds a new document to the database and its vector to the FAISS index. + """ + try: + new_document = models.Document(**doc_data) + db.add(new_document) + db.commit() + db.refresh(new_document) + + faiss_id = self.vector_store.add_document(new_document.text) + + vector_meta = models.VectorMetadata( + document_id=new_document.id, + faiss_index=faiss_id, + embedding_model="mock_embedder" + ) + db.add(vector_meta) + db.commit() + + logger.info(f"Document '{new_document.title}' added successfully with ID {new_document.id}") + return new_document.id + except Exception as e: + db.rollback() + logger.error(f"Failed to add document: {e}") + raise e + + async def chat_with_rag( + self, + db: Session, + prompt: str, + model: Literal["deepseek", "gemini"] + ) -> str: + """ + Handles a chat request by retrieving context from all configured + retrievers and passing it to the LLM. + """ + context_docs_text = [] + # The service now iterates through all configured retrievers to gather context + for retriever in self.retrievers: + context_docs_text.extend(retriever.retrieve_context(prompt, db)) + + combined_context = "\n\n".join(context_docs_text) + + if combined_context: + logger.info(f"Retrieved context for prompt: '{prompt}'") + rag_prompt = f""" + You are an AI assistant that answers questions based on the provided context. + + Context: + {combined_context} + + Question: + {prompt} + + If the answer is not in the context, say that you cannot answer the question based on the information provided. + """ + else: + rag_prompt = prompt + logger.warning("No context found for the query.") + + provider = get_llm_provider(model) + response_text = await provider.generate_response(rag_prompt) + + return response_text diff --git a/ai-hub/app/core/retrievers.py b/ai-hub/app/core/retrievers.py new file mode 100644 index 0000000..8f20713 --- /dev/null +++ b/ai-hub/app/core/retrievers.py @@ -0,0 +1,65 @@ +import abc +from typing import List, Dict +from sqlalchemy.orm import Session +from app.core.vector_store import FaissVectorStore +from app.db import models + +class Retriever(abc.ABC): + """ + Abstract base class for a Retriever. + + A retriever is a pluggable component that is responsible for fetching + relevant context for a given query from a specific data source. + """ + @abc.abstractmethod + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Fetches context for a given query. + + Args: + query (str): The user's query string. + db (Session): The database session. + + Returns: + List[str]: A list of text strings representing the retrieved context. + """ + raise NotImplementedError + +class FaissDBRetriever(Retriever): + """ + A concrete retriever that uses a FAISS index and a local database + to find and return relevant document text. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + + def retrieve_context(self, query: str, db: Session) -> List[str]: + """ + Retrieves document text by first searching the FAISS index + and then fetching the corresponding documents from the database. 
+ """ + faiss_ids = self.vector_store.search_similar_documents(query, k=3) + context_docs_text = [] + + if faiss_ids: + # Use FAISS IDs to find the corresponding document_id from the database + document_ids_from_vectors = db.query(models.VectorMetadata.document_id).filter( + models.VectorMetadata.faiss_index.in_(faiss_ids) + ).all() + + document_ids = [doc_id for (doc_id,) in document_ids_from_vectors] + + # Retrieve the full documents from the Document table + context_docs = db.query(models.Document).filter( + models.Document.id.in_(document_ids) + ).all() + + context_docs_text = [doc.text for doc in context_docs] + + return context_docs_text + +# You could add other retriever implementations here, like: +# class RemoteServiceRetriever(Retriever): +# def retrieve_context(self, query: str, db: Session) -> List[str]: +# # Logic to call a remote API and return context +# ... diff --git a/ai-hub/app/core/vector_store.py b/ai-hub/app/core/vector_store.py new file mode 100644 index 0000000..b5ffe70 --- /dev/null +++ b/ai-hub/app/core/vector_store.py @@ -0,0 +1,89 @@ +import numpy as np +import faiss +import os +from typing import List, Tuple + +# Mock embedding function for demonstration. In a real app, you'd use a +# real embedding model like sentence-transformers, OpenAI's API, or a local model. +class MockEmbedder: + """A simple class to simulate an embedding model.""" + def __init__(self, dimension: int = 768): + self.dimension = dimension + + def embed(self, text: str) -> np.ndarray: + """Generates a random vector to simulate an embedding.""" + # This is a mock. A real embedder would take the text and return a + # meaningful vector. + return np.random.rand(self.dimension).astype('float32') + + +class FaissVectorStore: + """ + Manages a FAISS index for efficient vector storage and search. + This class handles the creation, persistence, and querying of the index. + """ + def __init__(self, index_file_path: str = "faiss_index.bin", dimension: int = 768): + """ + Initializes the vector store. + + Args: + index_file_path (str): The file path to save/load the FAISS index. + dimension (int): The dimension of the vectors in the index. + """ + self.index_file_path = index_file_path + self.dimension = dimension + self.index = self._load_or_create_index() + self.embedder = MockEmbedder(dimension) + + def _load_or_create_index(self): + """Loads an existing index from disk or creates a new one.""" + if os.path.exists(self.index_file_path): + print(f"Loading FAISS index from {self.index_file_path}") + return faiss.read_index(self.index_file_path) + else: + print("Creating a new FAISS index.") + # We'll use IndexFlatL2 for a simple Euclidean distance search. + return faiss.IndexFlatL2(self.dimension) + + def save_index(self): + """Saves the current index to disk.""" + faiss.write_index(self.index, self.index_file_path) + print(f"FAISS index saved to {self.index_file_path}") + + def add_document(self, text: str) -> int: + """ + Embeds a document and adds its vector to the index. + + Args: + text (str): The text content of the document. + + Returns: + int: The index ID of the added vector. + """ + vector = self.embedder.embed(text).reshape(1, -1) + # Add the vector to the index. FAISS assigns a new internal ID. + self.index.add(vector) + # Get the new total number of vectors in the index. The ID of the + # newly added vector is one less than this count. 
+ index_id = self.index.ntotal - 1 + print(f"Document added to FAISS with index ID: {index_id}") + self.save_index() # Save after every addition for persistence + return index_id + + def search_similar_documents(self, query: str, k: int = 5) -> List[int]: + """ + Performs a similarity search on the index for a given query. + + Args: + query (str): The search query text. + k (int): The number of nearest neighbors to retrieve. + + Returns: + List[int]: A list of FAISS index IDs for the top k similar documents. + """ + query_vector = self.embedder.embed(query).reshape(1, -1) + # Faiss search returns distances and the corresponding index IDs. + distances, indices = self.index.search(query_vector, k) + + print(f"Found {len(indices[0])} similar documents for the query.") + return [int(i) for i in indices[0] if i >= 0] diff --git a/ai-hub/app/db/guide.md b/ai-hub/app/db/guide.md new file mode 100644 index 0000000..a3185c7 --- /dev/null +++ b/ai-hub/app/db/guide.md @@ -0,0 +1,193 @@ + +--- + +# 📚 Database Module Documentation + +> **File:** `app/db/database.py` +> This module provides a streamlined way to connect to and manage database sessions for your application. +> It supports both **PostgreSQL** and **SQLite**, with environment-based configuration for flexible deployment. + +--- + +## 1. ⚙️ Configuration + +The database connection is controlled using environment variables. If these are not set, the module uses sensible defaults. + +### 📌 Environment Variables + +| Variable | Description | Default Value | Supported Values | +| -------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------- | -------------------- | +| `DB_MODE` | Specifies the type of database to use | `"postgres"` | `postgres`, `sqlite` | +| `DATABASE_URL` | Connection string for the database | PostgreSQL: `postgresql://user:password@localhost/ai_hub_db`
SQLite: `sqlite:///./ai_hub.db` | Any SQLAlchemy URI | + +### 💡 Example: Switch to SQLite + +```bash +export DB_MODE="sqlite" +``` + +--- + +## 2. 🧱 Core Components + +This module exposes several key components used to interface with the database. + +### `engine` + +* A SQLAlchemy **Engine** instance. +* Manages connections, pooling, and execution context. + +### `SessionLocal` + +* A factory for creating new database **session objects**. +* Typically accessed indirectly through the `get_db` dependency. + +### `Base` + +* The base class all SQLAlchemy models should inherit from. +* Used to declare table mappings. + +### `get_db()` + +* A **FastAPI dependency function**. +* Yields a new session per request and ensures it's closed afterward, even in case of errors. + +--- + +## 3. 🚀 Usage with FastAPI + +To safely interact with the database in your FastAPI routes, use the `get_db()` dependency. + +### ✅ Steps to Use + +1. **Import `get_db`:** + + ```python + from app.db.database import get_db + ``` + +2. **Add it as a dependency:** + + ```python + from fastapi import Depends + from sqlalchemy.orm import Session + ``` + +3. **Inject into your route handler:** + + ```python + @app.get("/items/") + def read_items(db: Session = Depends(get_db)): + return db.query(Item).all() + ``` + +This pattern ensures that every request has a dedicated, isolated database session that is properly cleaned up afterward. + +--- + +# 🗂️ Database Models Documentation + +This document describes the SQLAlchemy models defined in `app/db/models.py`. These classes represent the tables used to store application data, including **chat sessions**, **messages**, and **document metadata** for **Retrieval-Augmented Generation (RAG)**. + +All models inherit from the `Base` class, imported from the `database.py` module. + +--- + +## 1. 🗨️ Session Model + +Represents a single **conversation** between a user and an AI. It serves as the container for all messages in that session. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------- | --------- | ----------------------------------------- | +| `id` | Integer | Primary key for the session | +| `user_id` | String | Unique identifier of the user | +| `title` | String | Optional AI-generated title | +| `model_name` | String | Name of the LLM used in the session | +| `created_at` | DateTime | Timestamp of when the session was created | +| `is_archived` | Boolean | Soft-delete/archive flag for the session | + +### 🔗 Relationships + +* **`messages`**: + One-to-many relationship with the **Message** model. + Deleting a session will also delete all related messages (via `cascade="all, delete-orphan"`). + +--- + +## 2. 💬 Message Model + +Stores individual **messages** within a session, including user inputs and AI responses. 
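+
+> 💡 **Example**: messages are normally created through their parent session. A minimal sketch (it assumes the field and relationship names documented in this guide and an open SQLAlchemy session `db`):
+
+```python
+from app.db import models
+
+# Create a session with one message; the default save-update cascade persists both.
+chat = models.Session(user_id="default_user", model_name="deepseek")
+chat.messages.append(models.Message(sender="user", content="Hello!"))
+db.add(chat)
+db.commit()
+
+print([m.content for m in chat.messages])  # -> ["Hello!"]
+```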
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
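+# NOTE: DB_MODE defaults to "sqlite" so the service runs locally without extra setup;
+# set DB_MODE=postgres and provide DATABASE_URL to target PostgreSQL.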
+DB_MODE = os.getenv("DB_MODE", "sqlite") +if DB_MODE == "sqlite": + DATABASE_URL = "sqlite:///./ai_hub.db" + # The connect_args are needed for SQLite to work with FastAPI's multiple threads + engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +else: + DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db") + engine = create_engine(DATABASE_URL) + +# Create a database session class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def create_db_tables(): + """Create all database tables based on the models.""" + print("Creating database tables...") + Base.metadata.create_all(bind=engine) + +# The dependency to get a database session +def get_db(): + """Dependency that provides a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py deleted file mode 100644 index 75629c7..0000000 --- a/ai-hub/app/llm_providers.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import httpx -from abc import ABC, abstractmethod -from openai import OpenAI -from typing import final - -# --- 1. Load Configuration from Environment --- -# Best practice is to centralize configuration loading at the top. - -# API Keys (required) -DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") - -# Model Names (optional, with defaults) -# Allows changing the model version without code changes. -DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") -GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") - -# --- 2. Initialize API Clients and URLs --- -# Initialize any clients or constants that will be used by the providers. -deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" - - -# --- 3. 
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/ai-hub/app/main.py b/ai-hub/app/main.py index 05d3952..cfc25bf 100644 --- a/ai-hub/app/main.py +++ b/ai-hub/app/main.py @@ -1,47 +1,10 @@ -# main.py +import uvicorn +from app.app import create_app -from fastapi import FastAPI, HTTPException, Query -from pydantic import BaseModel -from typing import Literal -from dotenv import load_dotenv +# Use the application factory to create the FastAPI app instance +app = create_app() -# Import our new factory function -from app.llm_providers import get_llm_provider - -# --- 1. Application Setup --- -load_dotenv() -app = FastAPI( - title="AI Model Hub Service", - description="A extensible hub to route requests to various LLMs using a Factory Pattern.", - version="0.0.0", -) - -# --- 2. Pydantic Models --- -class ChatRequest(BaseModel): - prompt: str - -# --- 3. API Endpoints --- -@app.get("/") -def read_root(): - return {"status": "AI Model Hub is running!"} - -@app.post("/chat") -async def chat_handler( - request: ChatRequest, - model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use.") -): - try: - # Use the factory to get the correct provider instance - provider = get_llm_provider(model) - - # Call the method on the instance. 
We don't need to know if it's - # Gemini or DeepSeek, only that it fulfills the "contract". - response_text = await provider.generate_response(request.prompt) - - return {"response": response_text, "model_used": model} - except ValueError as e: - # This catches errors from the factory (e.g., unsupported model) - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - # This catches errors from the provider's API call - raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") \ No newline at end of file +if __name__ == "__main__": + # This block allows you to run the application directly with 'python main.py' + # For production, it's recommended to use the `uvicorn` command directly. + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/ai-hub/pytest.ini b/ai-hub/pytest.ini new file mode 100644 index 0000000..cfd6bf4 --- /dev/null +++ b/ai-hub/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning +testpaths = + tests \ No newline at end of file diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt index 9426b1e..b5be92c 100644 --- a/ai-hub/requirements.txt +++ b/ai-hub/requirements.txt @@ -10,4 +10,6 @@ psycopg2 pytest-asyncio pytest-tornasync -pytest-trio \ No newline at end of file +pytest-trio +numpy +faiss-cpu \ No newline at end of file diff --git a/ai-hub/tests/core/test_rag_service.py b/ai-hub/tests/core/test_rag_service.py new file mode 100644 index 0000000..b86d893 --- /dev/null +++ b/ai-hub/tests/core/test_rag_service.py @@ -0,0 +1,170 @@ +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock, call +from sqlalchemy.orm import Session +from typing import List + +# Import the RAGService class and its dependencies +from app.core.rag_service import RAGService +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import Retriever +from app.db import models + +# --- RAGService Unit Tests --- +# These tests directly target the methods of the RAGService class +# to verify their internal logic in isolation. + +@patch('app.db.models.VectorMetadata') +@patch('app.db.models.Document') +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): + """ + Test the RAGService.add_document method for a successful run. + Verifies that the method correctly calls db.add(), db.commit(), and the vector store. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_new_document_instance = MagicMock() + mock_document_model.return_value = mock_new_document_instance + mock_new_document_instance.id = 1 + mock_new_document_instance.text = "Test text." + mock_new_document_instance.title = "Test Title" + + mock_vector_store_instance = mock_vector_store.return_value + mock_vector_store_instance.add_document.return_value = 123 + + # Instantiate the service with the mock dependencies + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test + document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) + + # Assertions + assert document_id == 1 + + # Use mock.call to check for both calls to db.add in the correct order. 
+ # We must mock the VectorMetadata model to check its constructor call + expected_calls = [ + call(mock_new_document_instance), + call(mock_vector_metadata_model.return_value) + ] + mock_db.add.assert_has_calls(expected_calls) + + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_new_document_instance) + mock_vector_store_instance.add_document.assert_called_once_with("Test text.") + + # Assert that VectorMetadata was instantiated with the correct arguments + mock_vector_metadata_model.assert_called_once_with( + document_id=mock_new_document_instance.id, + faiss_index=mock_vector_store_instance.add_document.return_value, + embedding_model="mock_embedder" + ) + +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_error_handling(mock_vector_store): + """ + Test the RAGService.add_document method's error handling. + Verifies that the transaction is rolled back on an exception. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + + # Configure the mock db.add to raise an exception + mock_db.add.side_effect = Exception("Database error") + + mock_vector_store_instance = mock_vector_store.return_value + + # Instantiate the service + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test and expect an exception + try: + rag_service.add_document(db=mock_db, doc_data=doc_data) + assert False, "Expected an exception to be raised" + except Exception as e: + assert str(e) == "Database error" + + # Assertions + # The first db.add was called + mock_db.add.assert_called_once() + # No commit should have occurred + mock_db.commit.assert_not_called() + # The transaction should have been rolled back + mock_db.rollback.assert_called_once() + + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when context is retrieved. + Verifies that the RAG prompt is correctly constructed. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response with context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + expected_context = "Context text 1.\n\nContext text 2." + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once() + actual_llm_prompt = mock_llm_provider.generate_response.call_args[0][0] + + # Check if the generated prompt contains the expected context and question + assert expected_context in actual_llm_prompt + assert prompt in actual_llm_prompt + assert response_text == "LLM response with context" + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when no context is retrieved. 
+ Verifies that the original prompt is sent to the LLM. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response without context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = [] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt without context." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once_with(prompt) + assert response_text == "LLM response without context" diff --git a/ai-hub/tests/core/test_vector_store.py b/ai-hub/tests/core/test_vector_store.py new file mode 100644 index 0000000..1969dbb --- /dev/null +++ b/ai-hub/tests/core/test_vector_store.py @@ -0,0 +1,150 @@ +import pytest +import numpy as np +import faiss +import os +import shutil +from typing import List, Tuple + +# We need to configure the python path so that pytest can find our application code +# Since this is a test file, we assume the app/ directory is available from the +# pytest root. +from app.core.vector_store import FaissVectorStore, MockEmbedder + +# Define constants for our tests to ensure consistency +TEST_DIMENSION = 128 +TEST_INDEX_FILE = "test_faiss_index.bin" + + +# --- Fixtures --- +# Pytest fixtures are used to set up a clean environment for each test. + +@pytest.fixture(scope="function") +def temp_faiss_dir(tmp_path): + """ + Fixture to create a temporary directory for each test function. + This ensures that each test runs in a clean environment without + interfering with other tests or the main application. + """ + # Create a sub-directory within the pytest temporary path + test_dir = tmp_path / "faiss_test" + test_dir.mkdir() + yield test_dir + # The cleanup is automatically handled by the tmp_path fixture, + # but we'll add a manual check just in case. + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + + +@pytest.fixture(scope="function") +def faiss_store(temp_faiss_dir): + """ + Fixture that provides a fresh FaissVectorStore instance for each test. + The index file path points to the temporary directory. + """ + index_file_path = os.path.join(temp_faiss_dir, TEST_INDEX_FILE) + store = FaissVectorStore(index_file_path=index_file_path, dimension=TEST_DIMENSION) + return store + + +# --- Unit Tests --- + +def test_init_creates_new_index(faiss_store): + """ + Test that the constructor correctly creates a new FAISS index + if the index file does not exist. + """ + # We verify that the index is a faiss.IndexFlatL2 instance + assert isinstance(faiss_store.index, faiss.IndexFlatL2) + # The index should be empty initially + assert faiss_store.index.ntotal == 0 + # The file should NOT exist yet as it's only saved on add_document + assert not os.path.exists(faiss_store.index_file_path) + + +def test_add_document(faiss_store): + """ + Test the add_document method to ensure it adds a vector and saves the index. + """ + test_text = "This is a test document." 
+ + # The index should be empty before adding + assert faiss_store.index.ntotal == 0 + + # Add the document and get the returned index ID + faiss_id = faiss_store.add_document(test_text) + + # The index should now have one item + assert faiss_store.index.ntotal == 1 + # The returned ID should be the first index, which is 0 + assert faiss_id == 0 + # The index file should now exist on disk + assert os.path.exists(faiss_store.index_file_path) + + +def test_add_multiple_documents(faiss_store): + """ + Test that multiple documents can be added and the index size grows correctly. + """ + docs = ["Doc 1", "Doc 2", "Doc 3"] + + # Add each document and check the total number of items + for i, doc in enumerate(docs): + faiss_id = faiss_store.add_document(doc) + assert faiss_store.index.ntotal == i + 1 + assert faiss_id == i + + # The final index file should exist and the count should be correct + assert os.path.exists(faiss_store.index_file_path) + assert faiss_store.index.ntotal == 3 + + +def test_load_existing_index(temp_faiss_dir): + """ + Test that the store can load an existing index file from disk. + """ + # Step 1: Create an index and add an item to it, then save it. + first_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + first_store.add_document("Document for persistence test.") + + # Ensure the file was saved + assert os.path.exists(first_store.index_file_path) + assert first_store.index.ntotal == 1 + + # Step 2: Create a new store instance pointing to the same file. + second_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + + # The new store should have loaded the index and should have 1 item. + assert second_store.index.ntotal == 1 + assert isinstance(second_store.index, faiss.IndexFlatL2) + + +def test_search_similar_documents(faiss_store): + """ + Test the search functionality. Since we're using a mock embedder with + random vectors, we can't predict the exact result, but we can + verify the format and number of results. + """ + # Add some documents to the store + faiss_store.add_document("Document 1") + faiss_store.add_document("Document 2") + faiss_store.add_document("Document 3") + faiss_store.add_document("Document 4") + faiss_store.add_document("Document 5") + + # Search for a query and ask for 3 results + results = faiss_store.search_similar_documents("A query string", k=3) + + # The results should be a list of 3 items + assert isinstance(results, list) + assert len(results) == 3 + + # The results should be integers, and valid FAISS IDs + for result_id in results: + assert isinstance(result_id, int) + assert 0 <= result_id < 5 # IDs should be between 0 and 4 diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py new file mode 100644 index 0000000..9aa1a0c --- /dev/null +++ b/ai-hub/tests/test_app.py @@ -0,0 +1,156 @@ +import os +from fastapi.testclient import TestClient +from unittest.mock import patch, MagicMock, AsyncMock +from sqlalchemy.orm import Session + +# Import the factory function directly to get a fresh app instance for testing +from app.app import create_app +# The get_db function is now in app/db_setup.py, so we must update the import path. +from app.db_setup import get_db + +# --- Dependency Override for Testing --- +# This is a mock database session that will be used in our tests. 
+mock_db = MagicMock(spec=Session) + +def override_get_db(): + """Returns the mock database session for tests.""" + try: + yield mock_db + finally: + pass + + +# --- API Endpoint Tests --- +# We patch the RAGService class itself, as the instance is created inside create_app(). + +# This test does not require mocking, so the app can be created at the module level. +# For consistency, we can still move it inside a function if preferred. +def test_read_root(): + """Test the root endpoint to ensure it's running.""" + # Create app and client here to be sure no mocking interferes + app = create_app() + client = TestClient(app) + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + +@patch('app.app.RAGService') +def test_chat_handler_success(mock_rag_service_class): + """ + Test the /chat endpoint with a successful, mocked RAG service response. + + We patch the RAGService class and configure a mock instance + with a controlled return value. + """ + # Create a mock instance of RAGService that will be returned by the factory + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.chat_with_rag = AsyncMock(return_value="This is a mock response from the RAG service.") + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Make the request to our app + response = client.post("/chat", json={"prompt": "Hello there"}) + + # Assert our app behaved as expected + assert response.status_code == 200 + assert response.json()["response"] == "This is a mock response from the RAG service." + + # Verify that the mocked method was called with the correct arguments + mock_rag_service_instance.chat_with_rag.assert_called_once_with( + db=mock_db, prompt="Hello there", model="deepseek" + ) + +@patch('app.app.RAGService') +def test_chat_handler_api_failure(mock_rag_service_class): + """ + Test the /chat endpoint when the RAG service encounters an error. + + We configure the mock RAGService instance's chat_with_rag method + to raise an exception. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.chat_with_rag = AsyncMock(side_effect=Exception("API connection error")) + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Make the request to our app + response = client.post("/chat", json={"prompt": "This request will fail"}) + + # Assert our app handles the error gracefully + assert response.status_code == 500 + assert "An error occurred with the deepseek API" in response.json()["detail"] + + # Verify that the mocked method was called with the correct arguments + mock_rag_service_instance.chat_with_rag.assert_called_once_with( + db=mock_db, prompt="This request will fail", model="deepseek" + ) + +@patch('app.app.RAGService') +def test_add_document_success(mock_rag_service_class): + """ + Test the /document endpoint with a successful, mocked RAG service response. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.return_value = 1 + + # Now create the app and client, so the patch takes effect. 
+ app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/document", json=doc_data) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1" + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + + +@patch('app.app.RAGService') +def test_add_document_api_failure(mock_rag_service_class): + """ + Test the /document endpoint when the RAG service encounters an error. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.side_effect = Exception("Service failed") + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/document", json=doc_data) + + assert response.status_code == 500 + assert "An error occurred: Service failed" in response.json()["detail"] + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data)
+ +### 🔑 Key Fields + +| Field | Data Type | Description | +| --------------------- | --------- | ------------------------------------------------------------------- | +| `id` | Integer | Primary key for the message | +| `session_id` | Integer | Foreign key linking to the parent session | +| `sender` | String | Role of the sender (`"user"` or `"assistant"`) | +| `content` | Text | Full text of the message | +| `created_at` | DateTime | Timestamp of message creation | +| `model_response_time` | Integer | Time (in seconds) taken by the model to generate the response | +| `token_count` | Integer | Number of tokens in the message | +| `message_metadata` | JSON | Flexible field for storing unstructured metadata (e.g., tool calls) | + +### 🔗 Relationships + +* **`session`**: + Many-to-one relationship with the **Session** model. + +--- + +## 3. 📄 Document Model + +Stores **metadata** and **content** of documents ingested into the system for RAG purposes. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ------------ | --------- | -------------------------------------------------- | +| `id` | Integer | Primary key for the document | +| `title` | String | Human-readable title | +| `text` | Text | Full content of the document | +| `source_url` | String | URL or path where the document was retrieved from | +| `author` | String | Author of the document | +| `status` | String | Processing status (`"ready"`, `"processing"` etc.) | +| `created_at` | DateTime | Timestamp of document creation | +| `user_id` | String | ID of the user who added the document | + +### 🔗 Relationships + +* **`vector_metadata`**: + One-to-one relationship with the **VectorMetadata** model. + +--- + +## 4. 🧠 VectorMetadata Model + +Connects documents to their **vector representations** (e.g., FAISS indices) for efficient retrieval in RAG workflows. + +### 🔑 Key Fields + +| Field | Data Type | Description | +| ----------------- | --------- | ---------------------------------------------------------- | +| `id` | Integer | Primary key for the metadata entry | +| `document_id` | Integer | Foreign key to the parent **Document** (unique constraint) | +| `faiss_index` | Integer | Vector's index in the FAISS store | +| `session_id` | Integer | Foreign key to the **Session** used in the RAG context | +| `embedding_model` | String | Embedding model used to generate the vector | + +### 🔗 Relationships + +* **`document`**: + Many-to-one relationship with the **Document** model + +* **`session`**: + Many-to-one relationship with the **Session** model + +--- diff --git a/ai-hub/app/db_setup.py b/ai-hub/app/db_setup.py new file mode 100644 index 0000000..739ebd2 --- /dev/null +++ b/ai-hub/app/db_setup.py @@ -0,0 +1,36 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import Base # Assuming `Base` is in this file + +# Load environment variables from a .env file +load_dotenv() + +# --- Database Connection Setup --- +# This configuration allows for easy switching between SQLite and PostgreSQL. 
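+# DB_MODE=sqlite (the default below) keeps everything in a local ./ai_hub.db file; DB_MODE=postgres expects DATABASE_URL to point at a reachable PostgreSQL server.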
+DB_MODE = os.getenv("DB_MODE", "sqlite") +if DB_MODE == "sqlite": + DATABASE_URL = "sqlite:///./ai_hub.db" + # The connect_args are needed for SQLite to work with FastAPI's multiple threads + engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +else: + DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost/ai_hub_db") + engine = create_engine(DATABASE_URL) + +# Create a database session class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def create_db_tables(): + """Create all database tables based on the models.""" + print("Creating database tables...") + Base.metadata.create_all(bind=engine) + +# The dependency to get a database session +def get_db(): + """Dependency that provides a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/ai-hub/app/llm_providers.py b/ai-hub/app/llm_providers.py deleted file mode 100644 index 75629c7..0000000 --- a/ai-hub/app/llm_providers.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import httpx -from abc import ABC, abstractmethod -from openai import OpenAI -from typing import final - -# --- 1. Load Configuration from Environment --- -# Best practice is to centralize configuration loading at the top. - -# API Keys (required) -DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") - -# Model Names (optional, with defaults) -# Allows changing the model version without code changes. -DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat") -GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest") - -# --- 2. Initialize API Clients and URLs --- -# Initialize any clients or constants that will be used by the providers. -deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}" - - -# --- 3. 
Provider Interface and Implementations --- - -class LLMProvider(ABC): - """Abstract base class ('Interface') for all LLM providers.""" - @abstractmethod - async def generate_response(self, prompt: str) -> str: - """Generates a response from the LLM.""" - pass - -@final -class DeepSeekProvider(LLMProvider): - """Provider for the DeepSeek API.""" - def __init__(self, model_name: str): - self.model = model_name - print(f"DeepSeekProvider initialized with model: {self.model}") - - async def generate_response(self, prompt: str) -> str: - try: - chat_completion = deepseek_client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - stream=False - ) - return chat_completion.choices[0].message.content - except Exception as e: - print(f"DeepSeek Error: {e}") - raise # Re-raise the exception to be handled by the main app - -@final -class GeminiProvider(LLMProvider): - """Provider for the Google Gemini API.""" - def __init__(self, api_url: str): - self.url = api_url - print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}") - - async def generate_response(self, prompt: str) -> str: - payload = {"contents": [{"parts": [{"text": prompt}]}]} - headers = {"Content-Type": "application/json"} - - try: - async with httpx.AsyncClient() as client: - response = await client.post(self.url, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - return data['candidates'][0]['content']['parts'][0]['text'] - except (httpx.HTTPStatusError, KeyError, IndexError) as e: - print(f"Gemini Error: {e}") - raise # Re-raise for the main app to handle - -# --- 4. The Factory Function --- -# This is where we instantiate our concrete providers with their configuration. - -_providers = { - "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL), - "gemini": GeminiProvider(api_url=GEMINI_URL) -} - -def get_llm_provider(model_name: str) -> LLMProvider: - """Factory function to get the appropriate, pre-configured LLM provider.""" - provider = _providers.get(model_name) - if not provider: - raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") - return provider - diff --git a/ai-hub/app/main.py b/ai-hub/app/main.py index 05d3952..cfc25bf 100644 --- a/ai-hub/app/main.py +++ b/ai-hub/app/main.py @@ -1,47 +1,10 @@ -# main.py +import uvicorn +from app.app import create_app -from fastapi import FastAPI, HTTPException, Query -from pydantic import BaseModel -from typing import Literal -from dotenv import load_dotenv +# Use the application factory to create the FastAPI app instance +app = create_app() -# Import our new factory function -from app.llm_providers import get_llm_provider - -# --- 1. Application Setup --- -load_dotenv() -app = FastAPI( - title="AI Model Hub Service", - description="A extensible hub to route requests to various LLMs using a Factory Pattern.", - version="0.0.0", -) - -# --- 2. Pydantic Models --- -class ChatRequest(BaseModel): - prompt: str - -# --- 3. API Endpoints --- -@app.get("/") -def read_root(): - return {"status": "AI Model Hub is running!"} - -@app.post("/chat") -async def chat_handler( - request: ChatRequest, - model: Literal["deepseek", "gemini"] = Query("deepseek", description="The AI model to use.") -): - try: - # Use the factory to get the correct provider instance - provider = get_llm_provider(model) - - # Call the method on the instance. 
We don't need to know if it's - # Gemini or DeepSeek, only that it fulfills the "contract". - response_text = await provider.generate_response(request.prompt) - - return {"response": response_text, "model_used": model} - except ValueError as e: - # This catches errors from the factory (e.g., unsupported model) - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - # This catches errors from the provider's API call - raise HTTPException(status_code=500, detail=f"An error occurred with the {model} API: {e}") \ No newline at end of file +if __name__ == "__main__": + # This block allows you to run the application directly with 'python main.py' + # For production, it's recommended to use the `uvicorn` command directly. + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/ai-hub/pytest.ini b/ai-hub/pytest.ini new file mode 100644 index 0000000..cfd6bf4 --- /dev/null +++ b/ai-hub/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning +testpaths = + tests \ No newline at end of file diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt index 9426b1e..b5be92c 100644 --- a/ai-hub/requirements.txt +++ b/ai-hub/requirements.txt @@ -10,4 +10,6 @@ psycopg2 pytest-asyncio pytest-tornasync -pytest-trio \ No newline at end of file +pytest-trio +numpy +faiss-cpu \ No newline at end of file diff --git a/ai-hub/tests/core/test_rag_service.py b/ai-hub/tests/core/test_rag_service.py new file mode 100644 index 0000000..b86d893 --- /dev/null +++ b/ai-hub/tests/core/test_rag_service.py @@ -0,0 +1,170 @@ +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock, call +from sqlalchemy.orm import Session +from typing import List + +# Import the RAGService class and its dependencies +from app.core.rag_service import RAGService +from app.core.vector_store import FaissVectorStore +from app.core.retrievers import Retriever +from app.db import models + +# --- RAGService Unit Tests --- +# These tests directly target the methods of the RAGService class +# to verify their internal logic in isolation. + +@patch('app.db.models.VectorMetadata') +@patch('app.db.models.Document') +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): + """ + Test the RAGService.add_document method for a successful run. + Verifies that the method correctly calls db.add(), db.commit(), and the vector store. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_new_document_instance = MagicMock() + mock_document_model.return_value = mock_new_document_instance + mock_new_document_instance.id = 1 + mock_new_document_instance.text = "Test text." + mock_new_document_instance.title = "Test Title" + + mock_vector_store_instance = mock_vector_store.return_value + mock_vector_store_instance.add_document.return_value = 123 + + # Instantiate the service with the mock dependencies + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test + document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) + + # Assertions + assert document_id == 1 + + # Use mock.call to check for both calls to db.add in the correct order. 
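+ # (add_document persists the Document row first, then the VectorMetadata row linking it to its FAISS index, so db.add should be called twice.)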
+ # We must mock the VectorMetadata model to check its constructor call + expected_calls = [ + call(mock_new_document_instance), + call(mock_vector_metadata_model.return_value) + ] + mock_db.add.assert_has_calls(expected_calls) + + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_new_document_instance) + mock_vector_store_instance.add_document.assert_called_once_with("Test text.") + + # Assert that VectorMetadata was instantiated with the correct arguments + mock_vector_metadata_model.assert_called_once_with( + document_id=mock_new_document_instance.id, + faiss_index=mock_vector_store_instance.add_document.return_value, + embedding_model="mock_embedder" + ) + +@patch('app.core.vector_store.FaissVectorStore') +def test_rag_service_add_document_error_handling(mock_vector_store): + """ + Test the RAGService.add_document method's error handling. + Verifies that the transaction is rolled back on an exception. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + + # Configure the mock db.add to raise an exception + mock_db.add.side_effect = Exception("Database error") + + mock_vector_store_instance = mock_vector_store.return_value + + # Instantiate the service + rag_service = RAGService(vector_store=mock_vector_store_instance, retrievers=[]) + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Call the method under test and expect an exception + try: + rag_service.add_document(db=mock_db, doc_data=doc_data) + assert False, "Expected an exception to be raised" + except Exception as e: + assert str(e) == "Database error" + + # Assertions + # The first db.add was called + mock_db.add.assert_called_once() + # No commit should have occurred + mock_db.commit.assert_not_called() + # The transaction should have been rolled back + mock_db.rollback.assert_called_once() + + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_with_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when context is retrieved. + Verifies that the RAG prompt is correctly constructed. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response with context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = ["Context text 1.", "Context text 2."] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + expected_context = "Context text 1.\n\nContext text 2." + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once() + actual_llm_prompt = mock_llm_provider.generate_response.call_args[0][0] + + # Check if the generated prompt contains the expected context and question + assert expected_context in actual_llm_prompt + assert prompt in actual_llm_prompt + assert response_text == "LLM response with context" + +@patch('app.core.rag_service.get_llm_provider') +def test_rag_service_chat_with_rag_without_context(mock_get_llm_provider): + """ + Test the RAGService.chat_with_rag method when no context is retrieved. 
+ Verifies that the original prompt is sent to the LLM. + """ + # Setup mocks + mock_db = MagicMock(spec=Session) + mock_llm_provider = MagicMock() + mock_llm_provider.generate_response = AsyncMock(return_value="LLM response without context") + mock_get_llm_provider.return_value = mock_llm_provider + + mock_retriever = MagicMock(spec=Retriever) + mock_retriever.retrieve_context.return_value = [] + + # Instantiate the service with the mock retriever + rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) + + prompt = "Test prompt without context." + + # Call the method under test and run the async function + response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model="deepseek")) + + # Assertions + mock_retriever.retrieve_context.assert_called_once_with(prompt, mock_db) + + mock_llm_provider.generate_response.assert_called_once_with(prompt) + assert response_text == "LLM response without context" diff --git a/ai-hub/tests/core/test_vector_store.py b/ai-hub/tests/core/test_vector_store.py new file mode 100644 index 0000000..1969dbb --- /dev/null +++ b/ai-hub/tests/core/test_vector_store.py @@ -0,0 +1,150 @@ +import pytest +import numpy as np +import faiss +import os +import shutil +from typing import List, Tuple + +# We need to configure the python path so that pytest can find our application code +# Since this is a test file, we assume the app/ directory is available from the +# pytest root. +from app.core.vector_store import FaissVectorStore, MockEmbedder + +# Define constants for our tests to ensure consistency +TEST_DIMENSION = 128 +TEST_INDEX_FILE = "test_faiss_index.bin" + + +# --- Fixtures --- +# Pytest fixtures are used to set up a clean environment for each test. + +@pytest.fixture(scope="function") +def temp_faiss_dir(tmp_path): + """ + Fixture to create a temporary directory for each test function. + This ensures that each test runs in a clean environment without + interfering with other tests or the main application. + """ + # Create a sub-directory within the pytest temporary path + test_dir = tmp_path / "faiss_test" + test_dir.mkdir() + yield test_dir + # The cleanup is automatically handled by the tmp_path fixture, + # but we'll add a manual check just in case. + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + + +@pytest.fixture(scope="function") +def faiss_store(temp_faiss_dir): + """ + Fixture that provides a fresh FaissVectorStore instance for each test. + The index file path points to the temporary directory. + """ + index_file_path = os.path.join(temp_faiss_dir, TEST_INDEX_FILE) + store = FaissVectorStore(index_file_path=index_file_path, dimension=TEST_DIMENSION) + return store + + +# --- Unit Tests --- + +def test_init_creates_new_index(faiss_store): + """ + Test that the constructor correctly creates a new FAISS index + if the index file does not exist. + """ + # We verify that the index is a faiss.IndexFlatL2 instance + assert isinstance(faiss_store.index, faiss.IndexFlatL2) + # The index should be empty initially + assert faiss_store.index.ntotal == 0 + # The file should NOT exist yet as it's only saved on add_document + assert not os.path.exists(faiss_store.index_file_path) + + +def test_add_document(faiss_store): + """ + Test the add_document method to ensure it adds a vector and saves the index. + """ + test_text = "This is a test document." 
+ + # The index should be empty before adding + assert faiss_store.index.ntotal == 0 + + # Add the document and get the returned index ID + faiss_id = faiss_store.add_document(test_text) + + # The index should now have one item + assert faiss_store.index.ntotal == 1 + # The returned ID should be the first index, which is 0 + assert faiss_id == 0 + # The index file should now exist on disk + assert os.path.exists(faiss_store.index_file_path) + + +def test_add_multiple_documents(faiss_store): + """ + Test that multiple documents can be added and the index size grows correctly. + """ + docs = ["Doc 1", "Doc 2", "Doc 3"] + + # Add each document and check the total number of items + for i, doc in enumerate(docs): + faiss_id = faiss_store.add_document(doc) + assert faiss_store.index.ntotal == i + 1 + assert faiss_id == i + + # The final index file should exist and the count should be correct + assert os.path.exists(faiss_store.index_file_path) + assert faiss_store.index.ntotal == 3 + + +def test_load_existing_index(temp_faiss_dir): + """ + Test that the store can load an existing index file from disk. + """ + # Step 1: Create an index and add an item to it, then save it. + first_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + first_store.add_document("Document for persistence test.") + + # Ensure the file was saved + assert os.path.exists(first_store.index_file_path) + assert first_store.index.ntotal == 1 + + # Step 2: Create a new store instance pointing to the same file. + second_store = FaissVectorStore( + index_file_path=os.path.join(temp_faiss_dir, TEST_INDEX_FILE), + dimension=TEST_DIMENSION + ) + + # The new store should have loaded the index and should have 1 item. + assert second_store.index.ntotal == 1 + assert isinstance(second_store.index, faiss.IndexFlatL2) + + +def test_search_similar_documents(faiss_store): + """ + Test the search functionality. Since we're using a mock embedder with + random vectors, we can't predict the exact result, but we can + verify the format and number of results. + """ + # Add some documents to the store + faiss_store.add_document("Document 1") + faiss_store.add_document("Document 2") + faiss_store.add_document("Document 3") + faiss_store.add_document("Document 4") + faiss_store.add_document("Document 5") + + # Search for a query and ask for 3 results + results = faiss_store.search_similar_documents("A query string", k=3) + + # The results should be a list of 3 items + assert isinstance(results, list) + assert len(results) == 3 + + # The results should be integers, and valid FAISS IDs + for result_id in results: + assert isinstance(result_id, int) + assert 0 <= result_id < 5 # IDs should be between 0 and 4 diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py new file mode 100644 index 0000000..9aa1a0c --- /dev/null +++ b/ai-hub/tests/test_app.py @@ -0,0 +1,156 @@ +import os +from fastapi.testclient import TestClient +from unittest.mock import patch, MagicMock, AsyncMock +from sqlalchemy.orm import Session + +# Import the factory function directly to get a fresh app instance for testing +from app.app import create_app +# The get_db function is now in app/db_setup.py, so we must update the import path. +from app.db_setup import get_db + +# --- Dependency Override for Testing --- +# This is a mock database session that will be used in our tests. 
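+# It is injected via app.dependency_overrides in each test, so no real database is ever touched.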
+mock_db = MagicMock(spec=Session) + +def override_get_db(): + """Returns the mock database session for tests.""" + try: + yield mock_db + finally: + pass + + +# --- API Endpoint Tests --- +# We patch the RAGService class itself, as the instance is created inside create_app(). + +# This test does not require mocking, so the app can be created at the module level. +# For consistency, we can still move it inside a function if preferred. +def test_read_root(): + """Test the root endpoint to ensure it's running.""" + # Create app and client here to be sure no mocking interferes + app = create_app() + client = TestClient(app) + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + +@patch('app.app.RAGService') +def test_chat_handler_success(mock_rag_service_class): + """ + Test the /chat endpoint with a successful, mocked RAG service response. + + We patch the RAGService class and configure a mock instance + with a controlled return value. + """ + # Create a mock instance of RAGService that will be returned by the factory + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.chat_with_rag = AsyncMock(return_value="This is a mock response from the RAG service.") + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Make the request to our app + response = client.post("/chat", json={"prompt": "Hello there"}) + + # Assert our app behaved as expected + assert response.status_code == 200 + assert response.json()["response"] == "This is a mock response from the RAG service." + + # Verify that the mocked method was called with the correct arguments + mock_rag_service_instance.chat_with_rag.assert_called_once_with( + db=mock_db, prompt="Hello there", model="deepseek" + ) + +@patch('app.app.RAGService') +def test_chat_handler_api_failure(mock_rag_service_class): + """ + Test the /chat endpoint when the RAG service encounters an error. + + We configure the mock RAGService instance's chat_with_rag method + to raise an exception. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.chat_with_rag = AsyncMock(side_effect=Exception("API connection error")) + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Make the request to our app + response = client.post("/chat", json={"prompt": "This request will fail"}) + + # Assert our app handles the error gracefully + assert response.status_code == 500 + assert "An error occurred with the deepseek API" in response.json()["detail"] + + # Verify that the mocked method was called with the correct arguments + mock_rag_service_instance.chat_with_rag.assert_called_once_with( + db=mock_db, prompt="This request will fail", model="deepseek" + ) + +@patch('app.app.RAGService') +def test_add_document_success(mock_rag_service_class): + """ + Test the /document endpoint with a successful, mocked RAG service response. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.return_value = 1 + + # Now create the app and client, so the patch takes effect. 
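+    # RAGService is instantiated inside create_app(), so the class must already be patched when the factory runs.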
+ app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/document", json=doc_data) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1" + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + + +@patch('app.app.RAGService') +def test_add_document_api_failure(mock_rag_service_class): + """ + Test the /document endpoint when the RAG service encounters an error. + """ + # Create a mock instance of RAGService + mock_rag_service_instance = mock_rag_service_class.return_value + mock_rag_service_instance.add_document.side_effect = Exception("Service failed") + + # Now create the app and client, so the patch takes effect. + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + response = client.post("/document", json=doc_data) + + assert response.status_code == 500 + assert "An error occurred: Service failed" in response.json()["detail"] + + # Verify that the mocked method was called with the correct arguments, + # including the default values added by Pydantic. + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) diff --git a/ai-hub/tests/test_main.py b/ai-hub/tests/test_main.py deleted file mode 100644 index ee713b3..0000000 --- a/ai-hub/tests/test_main.py +++ /dev/null @@ -1,65 +0,0 @@ -from fastapi.testclient import TestClient -from unittest.mock import patch, MagicMock, AsyncMock - -# Import the FastAPI app instance to create a test client -from app.main import app - -# Create a TestClient instance based on our FastAPI app -client = TestClient(app) - -def test_read_root(): - """Test the root endpoint to ensure it's running.""" - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - -@patch('app.main.get_llm_provider') -def test_chat_handler_success(mock_get_llm_provider): - """ - Test the /chat endpoint with a successful, mocked LLM response. - - We patch the get_llm_provider factory function to control the - behavior of the LLM provider instance it returns. - """ - # Configure a mock LLM provider instance with an async method - mock_provider = MagicMock() - mock_provider.generate_response = AsyncMock(return_value="This is a mock response from a provider.") - - # Configure our mocked factory function to return our mock provider - mock_get_llm_provider.return_value = mock_provider - - # Make the request to our app - response = client.post("/chat", json={"prompt": "Hello there"}) - - # Assert our app behaved as expected - assert response.status_code == 200 - assert response.json()["response"] == "This is a mock response from a provider." 
- - # Verify that the mocked factory and its method were called - mock_get_llm_provider.assert_called_once_with("deepseek") - mock_provider.generate_response.assert_called_once_with("Hello there") - -@patch('app.main.get_llm_provider') -def test_chat_handler_api_failure(mock_get_llm_provider): - """ - Test the /chat endpoint when the external LLM API fails. - - We configure the mocked provider's generate_response method to raise an exception. - """ - # Configure a mock LLM provider instance with an async method that raises an exception - mock_provider = MagicMock() - mock_provider.generate_response = AsyncMock(side_effect=Exception("API connection error")) - - # Configure our mocked factory function to return our mock provider - mock_get_llm_provider.return_value = mock_provider - - # Make the request to our app - response = client.post("/chat", json={"prompt": "This request will fail"}) - - # Assert our app handles the error gracefully - assert response.status_code == 500 - assert "An error occurred with the deepseek API" in response.json()["detail"] - - # Verify that the mocked factory and its method were called - mock_get_llm_provider.assert_called_once_with("deepseek") - mock_provider.generate_response.assert_called_once_with("This request will fail")
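With the new pieces in place, the ingestion and chat flows can be exercised end to end. The sketch below uses FastAPI's `TestClient` against the application factory, mirroring the tests above; the document values are illustrative, and the `/chat` call assumes valid DeepSeek or Gemini credentials in the environment (without them it returns a 500 from the provider).

```python
from fastapi.testclient import TestClient

from app.app import create_app
from app.db_setup import create_db_tables

create_db_tables()                  # make sure the SQLite/PostgreSQL tables exist
client = TestClient(create_app())

# Ingest a document; this writes a Document row, embeds the text,
# and records the FAISS index ID in VectorMetadata.
resp = client.post("/document", json={
    "title": "FAISS overview",
    "text": "FAISS is a library for efficient similarity search over dense vectors.",
    "source_url": "https://example.com/faiss",
})
print(resp.json())  # e.g. {"message": "Document 'FAISS overview' added successfully with ID 1"}

# Ask a question; the RAG service gathers context from every configured
# retriever before calling the selected LLM provider.
resp = client.post("/chat", params={"model": "deepseek"},
                   json={"prompt": "What is FAISS used for?"})
print(resp.json())  # {"response": "...", "model_used": "deepseek"}
```

Note that with the current `MockEmbedder` the vectors are random, so the retrieved context is not semantically related to the query; swapping in a real embedding model is what makes the retrieval step meaningful.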