diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py
index b8185af..18afe8f 100644
--- a/ai-hub/app/api/dependencies.py
+++ b/ai-hub/app/api/dependencies.py
@@ -1,7 +1,12 @@
 # app/api/dependencies.py
 from fastapi import Depends, HTTPException, status
+from typing import List
 from sqlalchemy.orm import Session
 from app.db.session import SessionLocal
+from app.core.retrievers import Retriever
+from app.core.services.document import DocumentService
+from app.core.services.rag import RAGService
+from app.core.vector_store.faiss_store import FaissVectorStore
 
 # This is a dependency
 def get_db():
@@ -16,4 +21,13 @@
     # In a real app, you would decode the token and fetch the user
     if not token:
         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
-    return {"email": "user@example.com", "id": 1} # Dummy user
\ No newline at end of file
+    return {"email": "user@example.com", "id": 1} # Dummy user
+
+
+class ServiceContainer:
+    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
+        # Initialize all services within the container
+        self.document_service = DocumentService(vector_store=vector_store)
+        self.rag_service = RAGService(
+            retrievers=retrievers
+        )
\ No newline at end of file
diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index 85ecafd..6d216ec 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -1,10 +1,10 @@
 from fastapi import APIRouter, HTTPException, Depends
 from sqlalchemy.orm import Session
-from app.core.services import RAGService
+from app.api.dependencies import ServiceContainer
 from app.api.dependencies import get_db
 from app.api import schemas
 
-def create_api_router(rag_service: RAGService) -> APIRouter:
+def create_api_router(services: ServiceContainer) -> APIRouter:
     """
     Creates and returns an APIRouter with all the application's endpoints.
     """
@@ -25,7 +25,7 @@
         Starts a new conversation session and returns its details.
         """
         try:
-            new_session = rag_service.create_session(
+            new_session = services.rag_service.create_session(
                 db=db,
                 user_id=request.user_id,
                 model=request.model
@@ -47,7 +47,7 @@
         The 'model' and 'load_faiss_retriever' can now be specified in the request body.
         """
         try:
-            response_text, model_used = await rag_service.chat_with_rag(
+            response_text, model_used = await services.rag_service.chat_with_rag(
                 db=db,
                 session_id=session_id,
                 prompt=request.prompt,
@@ -66,7 +66,7 @@
         """
         try:
             # Note: You'll need to add a `get_message_history` method to your RAGService.
-            messages = rag_service.get_message_history(db=db, session_id=session_id)
+            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
 
             if messages is None: # Service can return None if session not found
                 raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
@@ -83,7 +83,7 @@
     def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
         try:
             doc_data = doc.model_dump()
-            document_id = rag_service.add_document(db=db, doc_data=doc_data)
+            document_id = services.document_service.add_document(db=db, doc_data=doc_data)
             return schemas.DocumentResponse(
                 message=f"Document '{doc.title}' added successfully with ID {document_id}"
             )
@@ -93,7 +93,7 @@
     @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"])
     def get_documents(db: Session = Depends(get_db)):
         try:
-            documents_from_db = rag_service.get_all_documents(db=db)
+            documents_from_db = services.document_service.get_all_documents(db=db)
             return {"documents": documents_from_db}
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
@@ -101,7 +101,7 @@
     @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"])
     def delete_document(document_id: int, db: Session = Depends(get_db)):
         try:
-            deleted_id = rag_service.delete_document(db=db, document_id=document_id)
+            deleted_id = services.document_service.delete_document(db=db, document_id=document_id)
 
             if deleted_id is None:
                 raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.")
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index df21ba3..94413c8 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -7,10 +7,10 @@
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.vector_store.embedder.factory import get_embedder_from_config
 from app.core.retrievers import FaissDBRetriever, Retriever
-from app.core.services import RAGService
 from app.db.session import create_db_and_tables
 from app.api.routes import create_api_router
 from app.utils import print_config
+from app.api.dependencies import ServiceContainer
 # Note: The llm_clients import and initialization are removed as they
 # are not used in RAGService's constructor based on your services.py
 # from app.core.llm_clients import DeepSeekClient, GeminiClient
@@ -56,7 +56,7 @@
     )
 
     # 2. Initialize the FaissVectorStore with the chosen embedder
-    app.state.vector_store = FaissVectorStore(
+    vector_store = FaissVectorStore(
         index_file_path=settings.FAISS_INDEX_PATH,
         dimension=settings.EMBEDDING_DIMENSION,
         embedder=embedder # Pass the instantiated embedder object
@@ -65,18 +65,15 @@
 
     # 3. Create the FaissDBRetriever, regardless of the embedder type
     retrievers: List[Retriever] = [
-        FaissDBRetriever(vector_store=app.state.vector_store),
+        FaissDBRetriever(vector_store=vector_store),
     ]
 
-    # 4. Initialize the RAGService with the created retriever list
-    #    The llm_clients are no longer passed here, as per your services.py
-    rag_service = RAGService(
-        vector_store=app.state.vector_store,
-        retrievers=retrievers
-    )
+    # 4. Initialize the Service Container
+
+    services = ServiceContainer(vector_store, retrievers)
 
     # Create and include the API router, injecting the service
-    api_router = create_api_router(rag_service=rag_service)
+    api_router = create_api_router(services=services)
     app.include_router(api_router)
 
     return app
diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py
index afa0e1f..a5694dc 100644
--- a/ai-hub/app/core/services.py
+++ b/ai-hub/app/core/services.py
@@ -1,141 +1,141 @@
-import asyncio
-from typing import List, Dict, Any, Tuple
-from sqlalchemy.orm import Session, joinedload
-from sqlalchemy.exc import SQLAlchemyError
-import dspy
+# import asyncio
+# from typing import List, Dict, Any, Tuple
+# from sqlalchemy.orm import Session, joinedload
+# from sqlalchemy.exc import SQLAlchemyError
+# import dspy
 
-from app.core.vector_store.faiss_store import FaissVectorStore
-from app.core.vector_store.embedder.mock import MockEmbedder
-from app.db import models
-from app.core.retrievers import Retriever, FaissDBRetriever
-from app.core.llm_providers import get_llm_provider
-from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
+# from app.core.vector_store.faiss_store import FaissVectorStore
+# from app.core.vector_store.embedder.mock import MockEmbedder
+# from app.db import models
+# from app.core.retrievers import Retriever, FaissDBRetriever
+# from app.core.llm_providers import get_llm_provider
+# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
 
-class RAGService:
-    """
-    Service class for managing documents and conversational RAG sessions.
-    This class is now more robust and can handle both real and mock embedders
-    by inspecting its dependencies.
-    """
-    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
-        self.vector_store = vector_store
-        self.retrievers = retrievers
+# class RAGService:
+#     """
+#     Service class for managing documents and conversational RAG sessions.
+#     This class is now more robust and can handle both real and mock embedders
+#     by inspecting its dependencies.
+#     """
+#     def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
+#         self.vector_store = vector_store
+#         self.retrievers = retrievers
 
-        # Assume one of the retrievers is the FAISS retriever, and you can access it.
-        # A better approach might be to have a dictionary of named retrievers.
-        self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)
+#         # Assume one of the retrievers is the FAISS retriever, and you can access it.
+#         # A better approach might be to have a dictionary of named retrievers.
+#         self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)
 
-        # Store the embedder from the vector store for dynamic naming
-        self.embedder = self.vector_store.embedder
+#         # Store the embedder from the vector store for dynamic naming
+#         self.embedder = self.vector_store.embedder
 
-    # --- Session Management ---
+#     # --- Session Management ---
 
-    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
-        """Creates a new chat session in the database."""
-        try:
-            new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session")
-            db.add(new_session)
-            db.commit()
-            db.refresh(new_session)
-            return new_session
-        except SQLAlchemyError as e:
-            db.rollback()
-            raise
+#     def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
+#         """Creates a new chat session in the database."""
+#         try:
+#             new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session")
+#             db.add(new_session)
+#             db.commit()
+#             db.refresh(new_session)
+#             return new_session
+#         except SQLAlchemyError as e:
+#             db.rollback()
+#             raise
 
-    async def chat_with_rag(
-        self,
-        db: Session,
-        session_id: int,
-        prompt: str,
-        model: str,
-        load_faiss_retriever: bool = False
-    ) -> Tuple[str, str]:
-        """
-        Handles a message within a session, including saving history and getting a response.
-        Allows switching the LLM model and conditionally using the FAISS retriever.
-        """
-        session = db.query(models.Session).options(
-            joinedload(models.Session.messages)
-        ).filter(models.Session.id == session_id).first()
+#     async def chat_with_rag(
+#         self,
+#         db: Session,
+#         session_id: int,
+#         prompt: str,
+#         model: str,
+#         load_faiss_retriever: bool = False
+#     ) -> Tuple[str, str]:
+#         """
+#         Handles a message within a session, including saving history and getting a response.
+#         Allows switching the LLM model and conditionally using the FAISS retriever.
+#         """
+#         session = db.query(models.Session).options(
+#             joinedload(models.Session.messages)
+#         ).filter(models.Session.id == session_id).first()
 
-        if not session:
-            raise ValueError(f"Session with ID {session_id} not found.")
+#         if not session:
+#             raise ValueError(f"Session with ID {session_id} not found.")
 
-        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
-        db.add(user_message)
-        db.commit()
+#         user_message = models.Message(session_id=session_id, sender="user", content=prompt)
+#         db.add(user_message)
+#         db.commit()
 
-        llm_provider = get_llm_provider(model)
-        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
-        dspy.configure(lm=dspy_llm)
+#         llm_provider = get_llm_provider(model)
+#         dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
+#         dspy.configure(lm=dspy_llm)
 
-        current_retrievers = []
-        if load_faiss_retriever:
-            if self.faiss_retriever:
-                current_retrievers.append(self.faiss_retriever)
-            else:
-                print("Warning: FaissDBRetriever requested but not available. Proceeding without it.")
+#         current_retrievers = []
+#         if load_faiss_retriever:
+#             if self.faiss_retriever:
+#                 current_retrievers.append(self.faiss_retriever)
+#             else:
+#                 print("Warning: FaissDBRetriever requested but not available. Proceeding without it.")
 
-        rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)
+#         rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)
 
-        answer_text = await rag_pipeline.forward(
-            question=prompt,
-            history=session.messages,
-            db=db
-        )
+#         answer_text = await rag_pipeline.forward(
+#             question=prompt,
+#             history=session.messages,
+#             db=db
+#         )
 
-        assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
-        db.add(assistant_message)
-        db.commit()
+#         assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
+#         db.add(assistant_message)
+#         db.commit()
 
-        return answer_text, model
+#         return answer_text, model
 
-    def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
-        """
-        Retrieves all messages for a given session, or None if the session doesn't exist.
-        """
-        session = db.query(models.Session).options(
-            joinedload(models.Session.messages)
-        ).filter(models.Session.id == session_id).first()
+#     def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
+#         """
+#         Retrieves all messages for a given session, or None if the session doesn't exist.
+#         """
+#         session = db.query(models.Session).options(
+#             joinedload(models.Session.messages)
+#         ).filter(models.Session.id == session_id).first()
 
-        return sorted(session.messages, key=lambda msg: msg.created_at) if session else None
+#         return sorted(session.messages, key=lambda msg: msg.created_at) if session else None
 
-    # --- Document Management (Updated) ---
-    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
-        try:
-            document_db = models.Document(**doc_data)
-            db.add(document_db)
-            db.commit()
-            db.refresh(document_db)
+#     # --- Document Management (Updated) ---
+#     def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
+#         try:
+#             document_db = models.Document(**doc_data)
+#             db.add(document_db)
+#             db.commit()
+#             db.refresh(document_db)
 
-            # Use the embedder provided to the vector store to get the correct model name
-            embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder"
+#             # Use the embedder provided to the vector store to get the correct model name
+#             embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder"
 
-            faiss_index = self.vector_store.add_document(document_db.text)
-            vector_metadata = models.VectorMetadata(
-                document_id=document_db.id,
-                faiss_index=faiss_index,
-                embedding_model=embedding_model_name
-            )
-            db.add(vector_metadata)
-            db.commit()
-            return document_db.id
-        except SQLAlchemyError as e:
-            db.rollback()
-            raise
+#             faiss_index = self.vector_store.add_document(document_db.text)
+#             vector_metadata = models.VectorMetadata(
+#                 document_id=document_db.id,
+#                 faiss_index=faiss_index,
+#                 embedding_model=embedding_model_name
+#             )
+#             db.add(vector_metadata)
+#             db.commit()
+#             return document_db.id
+#         except SQLAlchemyError as e:
+#             db.rollback()
+#             raise
 
-    def get_all_documents(self, db: Session) -> List[models.Document]:
-        return db.query(models.Document).order_by(models.Document.created_at.desc()).all()
+#     def get_all_documents(self, db: Session) -> List[models.Document]:
+#         return db.query(models.Document).order_by(models.Document.created_at.desc()).all()
 
-    def delete_document(self, db: Session, document_id: int) -> int:
-        try:
-            doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first()
-            if not doc_to_delete:
-                return None
-            db.delete(doc_to_delete)
-            db.commit()
-            return document_id
-        except SQLAlchemyError as e:
-            db.rollback()
-            raise
+#     def delete_document(self, db: Session, document_id: int) -> int:
+#         try:
+#             doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first()
+#             if not doc_to_delete:
+#                 return None
+#             db.delete(doc_to_delete)
+#             db.commit()
+#             return document_id
+#         except SQLAlchemyError as e:
+#             db.rollback()
+#             raise
db.delete(doc_to_delete) - db.commit() - return document_id - except SQLAlchemyError as e: - db.rollback() - raise +# def delete_document(self, db: Session, document_id: int) -> int: +# try: +# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() +# if not doc_to_delete: +# return None +# db.delete(doc_to_delete) +# db.commit() +# return document_id +# except SQLAlchemyError as e: +# db.rollback() +# raise diff --git a/ai-hub/app/core/services/__init__.py b/ai-hub/app/core/services/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/services/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py new file mode 100644 index 0000000..3634f3e --- /dev/null +++ b/ai-hub/app/core/services/document.py @@ -0,0 +1,67 @@ +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.db import models + +class DocumentService: + """ + Service class for managing document lifecycle, including + adding, retrieving, and deleting documents and their vector metadata. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + self.embedder = self.vector_store.embedder + + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + document_db = models.Document(**doc_data) + db.add(document_db) + db.commit() + db.refresh(document_db) + + embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" + + faiss_index = self.vector_store.add_document(document_db.text) + vector_metadata = models.VectorMetadata( + document_id=document_db.id, + faiss_index=faiss_index, + embedding_model=embedding_model_name + ) + db.add(vector_metadata) + db.commit() + return document_db.id + except SQLAlchemyError as e: + db.rollback() + raise + + def get_all_documents(self, db: Session) -> List[models.Document]: + """ + Retrieves all documents from the database. + """ + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + + def delete_document(self, db: Session, document_id: int) -> int: + """ + Deletes a document and its associated vector metadata from the database. + """ + try: + doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() + if not doc_to_delete: + return None + + # Assuming you also need to delete the vector metadata associated with the document + # for a full cleanup. 
+ # db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() + + db.delete(doc_to_delete) + db.commit() + return document_id + except SQLAlchemyError as e: + db.rollback() + raise \ No newline at end of file
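The delete path above removes only the Document row; the associated VectorMetadata row (and the FAISS entry behind it) is left behind, as the commented-out query hints. Below is a minimal sketch of a fuller cleanup, assuming only the models and session handling shown in this diff; the helper name is illustrative and not part of the changeset, and removing the embedding from the FAISS index itself would still need a store-level method this diff does not add.

# Sketch only: fuller delete that also clears vector metadata (assumes app.db.models as in this diff).
from typing import Optional

from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError

from app.db import models


def delete_document_with_metadata(db: Session, document_id: int) -> Optional[int]:
    """Delete a document plus its VectorMetadata rows; return the id, or None if it does not exist."""
    try:
        doc = db.query(models.Document).filter(models.Document.id == document_id).first()
        if not doc:
            return None

        # Clear the bookkeeping rows first so no orphaned FAISS references remain in the database.
        db.query(models.VectorMetadata).filter(
            models.VectorMetadata.document_id == document_id
        ).delete(synchronize_session=False)

        db.delete(doc)
        db.commit()
        return document_id
    except SQLAlchemyError:
        db.rollback()
        raise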
diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py new file mode 100644 index 0000000..344e779 --- /dev/null +++ b/ai-hub/app/core/services/rag.py @@ -0,0 +1,90 @@ +import asyncio +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload +from sqlalchemy.exc import SQLAlchemyError +import dspy + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.llm_providers import get_llm_provider +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline + +class RAGService: + """ + Service class for managing conversational RAG sessions. + This class orchestrates the RAG pipeline and manages chat sessions. + """ + def __init__(self, retrievers: List[Retriever]): + self.retrievers = retrievers + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """Creates a new chat session in the database.""" + try: + new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False + ) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) + dspy.configure(lm=dspy_llm) + + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + + answer_text = await rag_pipeline.forward( + question=prompt, + history=session.messages, + db=db + ) + + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, model + + def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: + """ + Retrieves all messages for a given session.
+ """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file
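With the ServiceContainer wiring in place, the slimmed-down RAGService above is driven per request by the routes, but it can also be exercised directly. Here is a minimal usage sketch, assuming the SessionLocal factory from app.db.session and the constructor and method signatures introduced in this diff; the user id, model name, and prompt are placeholders.

# Sketch only: driving the refactored RAGService outside FastAPI (assumptions noted above).
import asyncio

from app.db.session import SessionLocal
from app.core.services.rag import RAGService


async def demo() -> None:
    db = SessionLocal()
    try:
        # An empty retriever list is allowed; chat_with_rag then simply skips FAISS retrieval.
        rag_service = RAGService(retrievers=[])

        session = rag_service.create_session(db=db, user_id="local-user", model="example-model")
        answer, model_used = await rag_service.chat_with_rag(
            db=db,
            session_id=session.id,
            prompt="Summarize the documents you have indexed.",
            model="example-model",
            load_faiss_retriever=False,
        )
        print(f"[{model_used}] {answer}")
    finally:
        db.close()


if __name__ == "__main__":
    asyncio.run(demo())

Note that chat_with_rag resolves a provider from the model string and reconfigures dspy on every call, so the model can change per request without touching the container.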
diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py new file mode 100644 index 0000000..470780d --- /dev/null +++ b/ai-hub/tests/api/test_dependencies.py @@ -0,0 +1,112 @@ +# tests/api/test_dependencies.py +import pytest +import asyncio +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from fastapi import HTTPException + +# Import the dependencies and services to be tested +from app.api.dependencies import get_db, get_current_user, ServiceContainer +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.retrievers import Retriever + +@pytest.fixture +def mock_session(): + """ + Fixture that provides a mock SQLAlchemy session. + """ + mock = MagicMock(spec=Session) + yield mock + +# --- Tests for get_db dependency --- + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_yields_session_and_closes(mock_session_local, mock_session): + """ + Tests that get_db yields a database session and ensures it's closed correctly. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act: Use the generator in a context manager + db_generator = get_db() + db = next(db_generator) + + # Assert 1: The correct session object was yielded + assert db == mock_session + + # Act 2: Manually close the generator + with pytest.raises(StopIteration): + next(db_generator) + + # Assert 2: The session's close method was called + mock_session.close.assert_called_once() + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_closes_on_exception(mock_session_local, mock_session): + """ + Tests that get_db still closes the session even if an exception occurs. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act & Assert: Call the generator and raise an exception + db_generator = get_db() + db = next(db_generator) + with pytest.raises(Exception): + db_generator.throw(Exception("Test exception")) + + # Assert: The session's close method was still called after the exception was handled + mock_session.close.assert_called_once() + + +# --- Tests for get_current_user dependency --- + +def test_get_current_user_with_valid_token(): + """ + Tests that get_current_user returns the expected user dictionary for a valid token. + """ + # Act + user = asyncio.run(get_current_user(token="valid_token")) + + # Assert + assert user == {"email": "user@example.com", "id": 1} + +def test_get_current_user_with_no_token(): + """ + Tests that get_current_user raises an HTTPException for a missing token. + """ + # Assert + with pytest.raises(HTTPException) as excinfo: + asyncio.run(get_current_user(token=None)) + + assert excinfo.value.status_code == 401 + assert "Unauthorized" in excinfo.value.detail + +# --- Tests for ServiceContainer class --- + +def test_service_container_initialization(): + """ + Tests that ServiceContainer initializes DocumentService and RAGService + with the correct dependencies.
+ """ + # Arrange: Create mock dependencies + mock_vector_store = MagicMock(spec=FaissVectorStore) + # The DocumentService constructor needs a .embedder attribute on the vector_store + mock_vector_store.embedder = MagicMock() + mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + + # Act: Instantiate the ServiceContainer + container = ServiceContainer( + vector_store=mock_vector_store, + retrievers=mock_retrievers + ) + + # Assert: Check if the services were created and configured correctly + assert isinstance(container.document_service, DocumentService) + assert container.document_service.vector_store == mock_vector_store + + assert isinstance(container.rag_service, RAGService) + assert container.rag_service.retrievers == mock_retrievers
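The container test above checks which dependencies each service receives; one further property worth pinning down is that DocumentService takes its embedder from the injected vector store. A small addition in the same MagicMock style could look like the sketch below; the test name is illustrative and not part of this changeset.

# Sketch only: extra ServiceContainer test in the style of tests/api/test_dependencies.py above.
from unittest.mock import MagicMock

from app.api.dependencies import ServiceContainer
from app.core.retrievers import Retriever
from app.core.vector_store.faiss_store import FaissVectorStore


def test_service_container_shares_embedder_with_vector_store():
    # Arrange: the DocumentService constructor reads .embedder off the vector store it is given.
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_vector_store.embedder = MagicMock()
    retrievers = [MagicMock(spec=Retriever)]

    # Act
    container = ServiceContainer(vector_store=mock_vector_store, retrievers=retrievers)

    # Assert: the embedder handed in via the vector store is the one the document service holds.
    assert container.document_service.embedder is mock_vector_store.embedder

Running pytest against tests/api/test_dependencies.py from the ai-hub directory should exercise both the dependency helpers and the container wiring.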
db.delete(doc_to_delete) - db.commit() - return document_id - except SQLAlchemyError as e: - db.rollback() - raise +# def delete_document(self, db: Session, document_id: int) -> int: +# try: +# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() +# if not doc_to_delete: +# return None +# db.delete(doc_to_delete) +# db.commit() +# return document_id +# except SQLAlchemyError as e: +# db.rollback() +# raise diff --git a/ai-hub/app/core/services/__init__.py b/ai-hub/app/core/services/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/services/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py new file mode 100644 index 0000000..3634f3e --- /dev/null +++ b/ai-hub/app/core/services/document.py @@ -0,0 +1,67 @@ +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.db import models + +class DocumentService: + """ + Service class for managing document lifecycle, including + adding, retrieving, and deleting documents and their vector metadata. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + self.embedder = self.vector_store.embedder + + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + document_db = models.Document(**doc_data) + db.add(document_db) + db.commit() + db.refresh(document_db) + + embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" + + faiss_index = self.vector_store.add_document(document_db.text) + vector_metadata = models.VectorMetadata( + document_id=document_db.id, + faiss_index=faiss_index, + embedding_model=embedding_model_name + ) + db.add(vector_metadata) + db.commit() + return document_db.id + except SQLAlchemyError as e: + db.rollback() + raise + + def get_all_documents(self, db: Session) -> List[models.Document]: + """ + Retrieves all documents from the database. + """ + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + + def delete_document(self, db: Session, document_id: int) -> int: + """ + Deletes a document and its associated vector metadata from the database. + """ + try: + doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() + if not doc_to_delete: + return None + + # Assuming you also need to delete the vector metadata associated with the document + # for a full cleanup. 
+ # db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() + + db.delete(doc_to_delete) + db.commit() + return document_id + except SQLAlchemyError as e: + db.rollback() + raise \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py new file mode 100644 index 0000000..344e779 --- /dev/null +++ b/ai-hub/app/core/services/rag.py @@ -0,0 +1,90 @@ +import asyncio +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload +from sqlalchemy.exc import SQLAlchemyError +import dspy + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.llm_providers import get_llm_provider +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline + +class RAGService: + """ + Service class for managing conversational RAG sessions. + This class orchestrates the RAG pipeline and manages chat sessions. + """ + def __init__(self, retrievers: List[Retriever]): + self.retrievers = retrievers + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """Creates a new chat session in the database.""" + try: + new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False + ) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) + dspy.configure(lm=dspy_llm) + + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + + answer_text = await rag_pipeline.forward( + question=prompt, + history=session.messages, + db=db + ) + + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, model + + def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: + """ + Retrieves all messages for a given session. 
+ """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py new file mode 100644 index 0000000..470780d --- /dev/null +++ b/ai-hub/tests/api/test_dependencies.py @@ -0,0 +1,112 @@ +# tests/api/test_dependencies.py +import pytest +import asyncio +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from fastapi import HTTPException + +# Import the dependencies and services to be tested +from app.api.dependencies import get_db, get_current_user, ServiceContainer +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.retrievers import Retriever + +@pytest.fixture +def mock_session(): + """ + Fixture that provides a mock SQLAlchemy session. + """ + mock = MagicMock(spec=Session) + yield mock + +# --- Tests for get_db dependency --- + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_yields_session_and_closes(mock_session_local, mock_session): + """ + Tests that get_db yields a database session and ensures it's closed correctly. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act: Use the generator in a context manager + db_generator = get_db() + db = next(db_generator) + + # Assert 1: The correct session object was yielded + assert db == mock_session + + # Act 2: Manually close the generator + with pytest.raises(StopIteration): + next(db_generator) + + # Assert 2: The session's close method was called + mock_session.close.assert_called_once() + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_closes_on_exception(mock_session_local, mock_session): + """ + Tests that get_db still closes the session even if an exception occurs. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act & Assert: Call the generator and raise an exception + db_generator = get_db() + db = next(db_generator) + with pytest.raises(Exception): + db_generator.throw(Exception("Test exception")) + + # Assert: The session's close method was still called after the exception was handled + mock_session.close.assert_called_once() + + +# --- Tests for get_current_user dependency --- + +def test_get_current_user_with_valid_token(): + """ + Tests that get_current_user returns the expected user dictionary for a valid token. + """ + # Act + user = asyncio.run(get_current_user(token="valid_token")) + + # Assert + assert user == {"email": "user@example.com", "id": 1} + +def test_get_current_user_with_no_token(): + """ + Tests that get_current_user raises an HTTPException for a missing token. + """ + # Assert + with pytest.raises(HTTPException) as excinfo: + asyncio.run(get_current_user(token=None)) + + assert excinfo.value.status_code == 401 + assert "Unauthorized" in excinfo.value.detail + +# --- Tests for ServiceContainer class --- + +def test_service_container_initialization(): + """ + Tests that ServiceContainer initializes DocumentService and RAGService + with the correct dependencies. 
+ """ + # Arrange: Create mock dependencies + mock_vector_store = MagicMock(spec=FaissVectorStore) + # The DocumentService constructor needs a .embedder attribute on the vector_store + mock_vector_store.embedder = MagicMock() + mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + + # Act: Instantiate the ServiceContainer + container = ServiceContainer( + vector_store=mock_vector_store, + retrievers=mock_retrievers + ) + + # Assert: Check if the services were created and configured correctly + assert isinstance(container.document_service, DocumentService) + assert container.document_service.vector_store == mock_vector_store + + assert isinstance(container.rag_service, RAGService) + assert container.rag_service.retrievers == mock_retrievers diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index c58a6c7..8d841a2 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,3 +1,4 @@ +# tests/app/api/test_routes.py import pytest from unittest.mock import MagicMock, AsyncMock from fastapi import FastAPI @@ -6,26 +7,42 @@ from datetime import datetime # Import the dependencies and router factory -from app.core.services import RAGService -from app.api.dependencies import get_db +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService from app.api.routes import create_api_router from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """Pytest fixture to create a TestClient with a fully mocked environment.""" + """ + Pytest fixture to create a TestClient with a fully mocked environment, + including a mock ServiceContainer. + """ test_app = FastAPI() + + # Mock individual services mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + + # Create a mock ServiceContainer that holds the mocked services + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + + # Mock the database session mock_db_session = MagicMock(spec=Session) def override_get_db(): yield mock_db_session - api_router = create_api_router(rag_service=mock_rag_service) + # Pass the mock ServiceContainer to the router factory + api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) - yield TestClient(test_app), mock_rag_service + # Return the test client and the mock services for assertion + yield TestClient(test_app), mock_services # --- General Endpoint --- @@ -40,32 +57,33 @@ def test_create_session_success(client): """Tests successfully creating a new chat session.""" - test_client, mock_rag_service = client + test_client, mock_services = client mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) - mock_rag_service.create_session.return_value = mock_session + mock_services.rag_service.create_session.return_value = mock_session response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) assert response.status_code == 200 assert response.json()["id"] == 1 - mock_rag_service.create_session.assert_called_once() + mock_services.rag_service.create_session.assert_called_once() def test_chat_in_session_success(client): """ Tests sending a message in an existing session without specifying a 
model or retriever. It should default to 'deepseek' and 'False'. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the default model 'deepseek' # and the default load_faiss_retriever=False - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", model="deepseek", @@ -76,16 +94,17 @@ """ Tests sending a message in an existing session and explicitly switching the model. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", model="gemini", @@ -96,8 +115,8 @@ """ Tests sending a message and explicitly enabling the FAISS retriever. 
""" - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) response = test_client.post( "/sessions/42/chat", @@ -106,9 +125,10 @@ assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="What is RAG?", model="deepseek", # The model still defaults to deepseek @@ -117,13 +137,13 @@ def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return a list of message objects mock_history = [ models.Message(sender="user", content="Hello", created_at=datetime.now()), models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] - mock_rag_service.get_message_history.return_value = mock_history + mock_services.rag_service.get_message_history.return_value = mock_history # Act response = test_client.get("/sessions/123/messages") @@ -135,13 +155,16 @@ assert len(response_data["messages"]) == 2 assert response_data["messages"][0]["sender"] == "user" assert response_data["messages"][1]["content"] == "Hi there!" - mock_rag_service.get_message_history.assert_called_once_with(db=mock_rag_service.get_message_history.call_args.kwargs['db'], session_id=123) + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) def test_get_session_messages_not_found(client): """Tests retrieving messages for a session that does not exist.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return None, indicating the session wasn't found - mock_rag_service.get_message_history.return_value = None + mock_services.rag_service.get_message_history.return_value = None # Act response = test_client.get("/sessions/999/messages") @@ -151,35 +174,39 @@ assert response.json()["detail"] == "Session with ID 999 not found." 
# --- Document Endpoints --- -# (These tests are unchanged) + def test_add_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.add_document.return_value = 123 + """Tests the /documents endpoint for adding a new document.""" + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" def test_get_documents_success(client): - test_client, mock_rag_service = client + """Tests the /documents endpoint for retrieving all documents.""" + test_client, mock_services = client mock_docs = [ models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] - mock_rag_service.get_all_documents.return_value = mock_docs + mock_services.document_service.get_all_documents.return_value = mock_docs response = test_client.get("/documents") assert response.status_code == 200 assert len(response.json()["documents"]) == 2 def test_delete_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = 42 + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 response = test_client.delete("/documents/42") assert response.status_code == 200 assert response.json()["document_id"] == 42 def test_delete_document_not_found(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = None + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404
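Taken together, these diffs split the old monolithic RAGService into a DocumentService plus a slimmer RAGService, both owned by a ServiceContainer that is built once at startup and handed to the router factory. As a quick orientation, here is a minimal wiring sketch; the build_app name, the default index path, and the dimension value are illustrative stand-ins only (the real app reads these from settings and builds the embedder via get_embedder_from_config, as the app.py diff shows):

# Condensed wiring sketch of the refactor shown in the diffs above.
# `build_app`, the default index path, and the dimension are illustrative;
# the actual application pulls them from `settings` and the embedder factory.
from typing import List

from fastapi import FastAPI

from app.api.dependencies import ServiceContainer
from app.api.routes import create_api_router
from app.core.retrievers import FaissDBRetriever, Retriever
from app.core.vector_store.faiss_store import FaissVectorStore


def build_app(embedder, index_path: str = "faiss.index", dimension: int = 768) -> FastAPI:
    app = FastAPI()

    # One vector store, shared by the document service and the FAISS retriever.
    vector_store = FaissVectorStore(
        index_file_path=index_path,
        dimension=dimension,
        embedder=embedder,
    )
    retrievers: List[Retriever] = [FaissDBRetriever(vector_store=vector_store)]

    # The container owns both services; the router only ever sees the container.
    services = ServiceContainer(vector_store=vector_store, retrievers=retrievers)
    app.include_router(create_api_router(services=services))
    return app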
# --- Document Endpoints --- -# (These tests are unchanged) + def test_add_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.add_document.return_value = 123 + """Tests the /documents endpoint for adding a new document.""" + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" def test_get_documents_success(client): - test_client, mock_rag_service = client + """Tests the /documents endpoint for retrieving all documents.""" + test_client, mock_services = client mock_docs = [ models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] - mock_rag_service.get_all_documents.return_value = mock_docs + mock_services.document_service.get_all_documents.return_value = mock_docs response = test_client.get("/documents") assert response.status_code == 200 assert len(response.json()["documents"]) == 2 def test_delete_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = 42 + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 response = test_client.delete("/documents/42") assert response.status_code == 200 assert response.json()["document_id"] == 42 def test_delete_document_not_found(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = None + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404 diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py new file mode 100644 index 0000000..cdf8b44 --- /dev/null +++ b/ai-hub/tests/core/services/test_document.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from datetime import datetime + +from app.core.services.document import DocumentService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder + +@pytest.fixture +def document_service(): + """ + Pytest fixture to create a DocumentService instance with mocked dependencies. + """ + mock_embedder = MagicMock(spec=MockEmbedder) + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_vector_store.embedder = mock_embedder + return DocumentService(vector_store=mock_vector_store) + +# --- add_document Tests --- + +def test_add_document_success(document_service: DocumentService): + """ + Tests that add_document successfully adds a document to the database + and its vector embedding to the FAISS index. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_document = MagicMock(id=1, text="Test text.") + mock_db.add.side_effect = [None, None] # Allow multiple calls + + # Configure the mock db.query to return a document object + mock_document_model_instance = models.Document( + id=1, + title="Test Title", + text="Test text.", + source_url="http://test.com" + ) + + with patch('app.core.services.document.models.Document', return_value=mock_document_model_instance) as mock_document_model, \ + patch('app.core.services.document.models.VectorMetadata') as mock_vector_metadata_model: + + document_service.vector_store.add_document.return_value = 123 + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Act + document_id = document_service.add_document(db=mock_db, doc_data=doc_data) + + # Assert + assert document_id == 1 + mock_db.add.assert_any_call(mock_document_model_instance) + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_document_model_instance) + document_service.vector_store.add_document.assert_called_once_with("Test text.") + mock_vector_metadata_model.assert_called_once_with( + document_id=1, + faiss_index=123, + embedding_model="mock_embedder" + ) + mock_db.add.assert_any_call(mock_vector_metadata_model.return_value) + + +def test_add_document_sql_error(document_service: DocumentService): + """ + Tests that add_document correctly handles a SQLAlchemyError by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.add.side_effect = SQLAlchemyError("Database error") + doc_data = {"title": "Test", "text": "...", "source_url": "http://test.com"} + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Database error"): + document_service.add_document(db=mock_db, doc_data=doc_data) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() + +# --- get_all_documents Tests --- + +def test_get_all_documents_success(document_service: DocumentService): + """ + Tests that get_all_documents returns a list of documents. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_documents = [models.Document(id=1), models.Document(id=2)] + mock_db.query.return_value.order_by.return_value.all.return_value = mock_documents + + # Act + documents = document_service.get_all_documents(db=mock_db) + + # Assert + assert documents == mock_documents + mock_db.query.assert_called_once_with(models.Document) + mock_db.query.return_value.order_by.assert_called_once() + +# --- delete_document Tests --- + +def test_delete_document_success(document_service: DocumentService): + """ + Tests that delete_document correctly deletes a document. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id_to_delete = 1 + doc_to_delete = models.Document(id=doc_id_to_delete) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=doc_id_to_delete) + + # Assert + assert deleted_id == doc_id_to_delete + mock_db.query.assert_called_once_with(models.Document) + mock_db.delete.assert_called_once_with(doc_to_delete) + mock_db.commit.assert_called_once() + +def test_delete_document_not_found(document_service: DocumentService): + """ + Tests that delete_document returns None if the document is not found. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=999) + + # Assert + assert deleted_id is None + mock_db.delete.assert_not_called() + mock_db.commit.assert_not_called() + +def test_delete_document_sql_error(document_service: DocumentService): + """ + Tests that delete_document handles a SQLAlchemyError correctly by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id = 1 + doc_to_delete = models.Document(id=doc_id) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + mock_db.delete.side_effect = SQLAlchemyError("Delete error") + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Delete error"): + document_service.delete_document(db=mock_db, document_id=doc_id) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index b8185af..18afe8f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -1,7 +1,12 @@ # app/api/dependencies.py from fastapi import Depends, HTTPException, status +from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal +from app.core.retrievers import Retriever +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore # This is a dependency def get_db(): @@ -16,4 +21,13 @@ # In a real app, you would decode the token and fetch the user if not token: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - return {"email": "user@example.com", "id": 1} # Dummy user \ No newline at end of file + return {"email": "user@example.com", "id": 1} # Dummy user + + +class ServiceContainer: + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + # Initialize all services within the container + self.document_service = DocumentService(vector_store=vector_store) + self.rag_service = RAGService( + retrievers=retrievers + ) \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 85ecafd..6d216ec 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session -from app.core.services import RAGService +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas -def create_api_router(rag_service: RAGService) -> APIRouter: +def create_api_router(services: ServiceContainer) -> APIRouter: """ Creates and returns an APIRouter with all the application's endpoints. """ @@ -25,7 +25,7 @@ Starts a new conversation session and returns its details. """ try: - new_session = rag_service.create_session( + new_session = services.rag_service.create_session( db=db, user_id=request.user_id, model=request.model @@ -47,7 +47,7 @@ The 'model' and 'load_faiss_retriever' can now be specified in the request body. """ try: - response_text, model_used = await rag_service.chat_with_rag( + response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, @@ -66,7 +66,7 @@ """ try: # Note: You'll need to add a `get_message_history` method to your RAGService. 
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index df21ba3..94413c8 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -7,10 +7,10 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config from app.core.retrievers import FaissDBRetriever, Retriever -from app.core.services import RAGService from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config +from app.api.dependencies import ServiceContainer # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -56,7 +56,7 @@ ) # 2. Initialize the FaissVectorStore with the chosen embedder - app.state.vector_store = FaissVectorStore( + vector_store = FaissVectorStore( index_file_path=settings.FAISS_INDEX_PATH, dimension=settings.EMBEDDING_DIMENSION, embedder=embedder # Pass the instantiated embedder object, @@ -65,18 +65,15 @@ # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ - FaissDBRetriever(vector_store=app.state.vector_store), + FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the RAGService with the created retriever list - # The llm_clients are no longer passed here, as per your services.py - rag_service = RAGService( - vector_store=app.state.vector_store, - retrievers=retrievers - ) + # 4.
Initialize the Service Container + + services = ServiceContainer(vector_store, retrievers) # Create and include the API router, injecting the service - api_router = create_api_router(rag_service=rag_service) + api_router = create_api_router(services=services) app.include_router(api_router) return app diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index afa0e1f..a5694dc 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,141 +1,141 @@ -import asyncio -from typing import List, Dict, Any, Tuple -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.exc import SQLAlchemyError -import dspy +# import asyncio +# from typing import List, Dict, Any, Tuple +# from sqlalchemy.orm import Session, joinedload +# from sqlalchemy.exc import SQLAlchemyError +# import dspy -from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.vector_store.embedder.mock import MockEmbedder -from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever -from app.core.llm_providers import get_llm_provider -from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline +# from app.core.vector_store.faiss_store import FaissVectorStore +# from app.core.vector_store.embedder.mock import MockEmbedder +# from app.db import models +# from app.core.retrievers import Retriever, FaissDBRetriever +# from app.core.llm_providers import get_llm_provider +# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline -class RAGService: - """ - Service class for managing documents and conversational RAG sessions. - This class is now more robust and can handle both real and mock embedders - by inspecting its dependencies. - """ - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): - self.vector_store = vector_store - self.retrievers = retrievers +# class RAGService: +# """ +# Service class for managing documents and conversational RAG sessions. +# This class is now more robust and can handle both real and mock embedders +# by inspecting its dependencies. +# """ +# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): +# self.vector_store = vector_store +# self.retrievers = retrievers - # Assume one of the retrievers is the FAISS retriever, and you can access it. - # A better approach might be to have a dictionary of named retrievers. - self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) +# # Assume one of the retrievers is the FAISS retriever, and you can access it. +# # A better approach might be to have a dictionary of named retrievers. 
+# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - # Store the embedder from the vector store for dynamic naming - self.embedder = self.vector_store.embedder +# # Store the embedder from the vector store for dynamic naming +# self.embedder = self.vector_store.embedder - # --- Session Management --- +# # --- Session Management --- - def create_session(self, db: Session, user_id: str, model: str) -> models.Session: - """Creates a new chat session in the database.""" - try: - new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") - db.add(new_session) - db.commit() - db.refresh(new_session) - return new_session - except SQLAlchemyError as e: - db.rollback() - raise +# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: +# """Creates a new chat session in the database.""" +# try: +# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") +# db.add(new_session) +# db.commit() +# db.refresh(new_session) +# return new_session +# except SQLAlchemyError as e: +# db.rollback() +# raise - async def chat_with_rag( - self, - db: Session, - session_id: int, - prompt: str, - model: str, - load_faiss_retriever: bool = False - ) -> Tuple[str, str]: - """ - Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model and conditionally using the FAISS retriever. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# async def chat_with_rag( +# self, +# db: Session, +# session_id: int, +# prompt: str, +# model: str, +# load_faiss_retriever: bool = False +# ) -> Tuple[str, str]: +# """ +# Handles a message within a session, including saving history and getting a response. +# Allows switching the LLM model and conditionally using the FAISS retriever. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - if not session: - raise ValueError(f"Session with ID {session_id} not found.") +# if not session: +# raise ValueError(f"Session with ID {session_id} not found.") - user_message = models.Message(session_id=session_id, sender="user", content=prompt) - db.add(user_message) - db.commit() +# user_message = models.Message(session_id=session_id, sender="user", content=prompt) +# db.add(user_message) +# db.commit() - llm_provider = get_llm_provider(model) - dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) - dspy.configure(lm=dspy_llm) +# llm_provider = get_llm_provider(model) +# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) +# dspy.configure(lm=dspy_llm) - current_retrievers = [] - if load_faiss_retriever: - if self.faiss_retriever: - current_retrievers.append(self.faiss_retriever) - else: - print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") +# current_retrievers = [] +# if load_faiss_retriever: +# if self.faiss_retriever: +# current_retrievers.append(self.faiss_retriever) +# else: +# print("Warning: FaissDBRetriever requested but not available. 
Proceeding without it.") - rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) +# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - answer_text = await rag_pipeline.forward( - question=prompt, - history=session.messages, - db=db - ) +# answer_text = await rag_pipeline.forward( +# question=prompt, +# history=session.messages, +# db=db +# ) - assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) - db.add(assistant_message) - db.commit() +# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) +# db.add(assistant_message) +# db.commit() - return answer_text, model +# return answer_text, model - def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: - """ - Retrieves all messages for a given session, or None if the session doesn't exist. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: +# """ +# Retrieves all messages for a given session, or None if the session doesn't exist. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - return sorted(session.messages, key=lambda msg: msg.created_at) if session else None +# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - # --- Document Management (Updated) --- - def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - try: - document_db = models.Document(**doc_data) - db.add(document_db) - db.commit() - db.refresh(document_db) +# # --- Document Management (Updated) --- +# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: +# try: +# document_db = models.Document(**doc_data) +# db.add(document_db) +# db.commit() +# db.refresh(document_db) - # Use the embedder provided to the vector store to get the correct model name - embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" +# # Use the embedder provided to the vector store to get the correct model name +# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( - document_id=document_db.id, - faiss_index=faiss_index, - embedding_model=embedding_model_name - ) - db.add(vector_metadata) - db.commit() - return document_db.id - except SQLAlchemyError as e: - db.rollback() - raise +# faiss_index = self.vector_store.add_document(document_db.text) +# vector_metadata = models.VectorMetadata( +# document_id=document_db.id, +# faiss_index=faiss_index, +# embedding_model=embedding_model_name +# ) +# db.add(vector_metadata) +# db.commit() +# return document_db.id +# except SQLAlchemyError as e: +# db.rollback() +# raise - def get_all_documents(self, db: Session) -> List[models.Document]: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() +# def get_all_documents(self, db: Session) -> List[models.Document]: +# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - def delete_document(self, db: Session, document_id: int) -> int: - try: - doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: - return None - 
db.delete(doc_to_delete) - db.commit() - return document_id - except SQLAlchemyError as e: - db.rollback() - raise +# def delete_document(self, db: Session, document_id: int) -> int: +# try: +# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() +# if not doc_to_delete: +# return None +# db.delete(doc_to_delete) +# db.commit() +# return document_id +# except SQLAlchemyError as e: +# db.rollback() +# raise diff --git a/ai-hub/app/core/services/__init__.py b/ai-hub/app/core/services/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/services/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py new file mode 100644 index 0000000..3634f3e --- /dev/null +++ b/ai-hub/app/core/services/document.py @@ -0,0 +1,67 @@ +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.db import models + +class DocumentService: + """ + Service class for managing document lifecycle, including + adding, retrieving, and deleting documents and their vector metadata. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + self.embedder = self.vector_store.embedder + + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + document_db = models.Document(**doc_data) + db.add(document_db) + db.commit() + db.refresh(document_db) + + embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" + + faiss_index = self.vector_store.add_document(document_db.text) + vector_metadata = models.VectorMetadata( + document_id=document_db.id, + faiss_index=faiss_index, + embedding_model=embedding_model_name + ) + db.add(vector_metadata) + db.commit() + return document_db.id + except SQLAlchemyError as e: + db.rollback() + raise + + def get_all_documents(self, db: Session) -> List[models.Document]: + """ + Retrieves all documents from the database. + """ + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + + def delete_document(self, db: Session, document_id: int) -> int: + """ + Deletes a document and its associated vector metadata from the database. + """ + try: + doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() + if not doc_to_delete: + return None + + # Assuming you also need to delete the vector metadata associated with the document + # for a full cleanup. 
+ # db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() + + db.delete(doc_to_delete) + db.commit() + return document_id + except SQLAlchemyError as e: + db.rollback() + raise \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py new file mode 100644 index 0000000..344e779 --- /dev/null +++ b/ai-hub/app/core/services/rag.py @@ -0,0 +1,90 @@ +import asyncio +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload +from sqlalchemy.exc import SQLAlchemyError +import dspy + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.llm_providers import get_llm_provider +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline + +class RAGService: + """ + Service class for managing conversational RAG sessions. + This class orchestrates the RAG pipeline and manages chat sessions. + """ + def __init__(self, retrievers: List[Retriever]): + self.retrievers = retrievers + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """Creates a new chat session in the database.""" + try: + new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False + ) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) + dspy.configure(lm=dspy_llm) + + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + + answer_text = await rag_pipeline.forward( + question=prompt, + history=session.messages, + db=db + ) + + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, model + + def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: + """ + Retrieves all messages for a given session. 
+ """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py new file mode 100644 index 0000000..470780d --- /dev/null +++ b/ai-hub/tests/api/test_dependencies.py @@ -0,0 +1,112 @@ +# tests/api/test_dependencies.py +import pytest +import asyncio +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from fastapi import HTTPException + +# Import the dependencies and services to be tested +from app.api.dependencies import get_db, get_current_user, ServiceContainer +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.retrievers import Retriever + +@pytest.fixture +def mock_session(): + """ + Fixture that provides a mock SQLAlchemy session. + """ + mock = MagicMock(spec=Session) + yield mock + +# --- Tests for get_db dependency --- + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_yields_session_and_closes(mock_session_local, mock_session): + """ + Tests that get_db yields a database session and ensures it's closed correctly. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act: Use the generator in a context manager + db_generator = get_db() + db = next(db_generator) + + # Assert 1: The correct session object was yielded + assert db == mock_session + + # Act 2: Manually close the generator + with pytest.raises(StopIteration): + next(db_generator) + + # Assert 2: The session's close method was called + mock_session.close.assert_called_once() + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_closes_on_exception(mock_session_local, mock_session): + """ + Tests that get_db still closes the session even if an exception occurs. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act & Assert: Call the generator and raise an exception + db_generator = get_db() + db = next(db_generator) + with pytest.raises(Exception): + db_generator.throw(Exception("Test exception")) + + # Assert: The session's close method was still called after the exception was handled + mock_session.close.assert_called_once() + + +# --- Tests for get_current_user dependency --- + +def test_get_current_user_with_valid_token(): + """ + Tests that get_current_user returns the expected user dictionary for a valid token. + """ + # Act + user = asyncio.run(get_current_user(token="valid_token")) + + # Assert + assert user == {"email": "user@example.com", "id": 1} + +def test_get_current_user_with_no_token(): + """ + Tests that get_current_user raises an HTTPException for a missing token. + """ + # Assert + with pytest.raises(HTTPException) as excinfo: + asyncio.run(get_current_user(token=None)) + + assert excinfo.value.status_code == 401 + assert "Unauthorized" in excinfo.value.detail + +# --- Tests for ServiceContainer class --- + +def test_service_container_initialization(): + """ + Tests that ServiceContainer initializes DocumentService and RAGService + with the correct dependencies. 
+ """ + # Arrange: Create mock dependencies + mock_vector_store = MagicMock(spec=FaissVectorStore) + # The DocumentService constructor needs a .embedder attribute on the vector_store + mock_vector_store.embedder = MagicMock() + mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + + # Act: Instantiate the ServiceContainer + container = ServiceContainer( + vector_store=mock_vector_store, + retrievers=mock_retrievers + ) + + # Assert: Check if the services were created and configured correctly + assert isinstance(container.document_service, DocumentService) + assert container.document_service.vector_store == mock_vector_store + + assert isinstance(container.rag_service, RAGService) + assert container.rag_service.retrievers == mock_retrievers diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index c58a6c7..8d841a2 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,3 +1,4 @@ +# tests/app/api/test_routes.py import pytest from unittest.mock import MagicMock, AsyncMock from fastapi import FastAPI @@ -6,26 +7,42 @@ from datetime import datetime # Import the dependencies and router factory -from app.core.services import RAGService -from app.api.dependencies import get_db +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService from app.api.routes import create_api_router from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """Pytest fixture to create a TestClient with a fully mocked environment.""" + """ + Pytest fixture to create a TestClient with a fully mocked environment, + including a mock ServiceContainer. + """ test_app = FastAPI() + + # Mock individual services mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + + # Create a mock ServiceContainer that holds the mocked services + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + + # Mock the database session mock_db_session = MagicMock(spec=Session) def override_get_db(): yield mock_db_session - api_router = create_api_router(rag_service=mock_rag_service) + # Pass the mock ServiceContainer to the router factory + api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) - yield TestClient(test_app), mock_rag_service + # Return the test client and the mock services for assertion + yield TestClient(test_app), mock_services # --- General Endpoint --- @@ -40,32 +57,33 @@ def test_create_session_success(client): """Tests successfully creating a new chat session.""" - test_client, mock_rag_service = client + test_client, mock_services = client mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) - mock_rag_service.create_session.return_value = mock_session + mock_services.rag_service.create_session.return_value = mock_session response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) assert response.status_code == 200 assert response.json()["id"] == 1 - mock_rag_service.create_session.assert_called_once() + mock_services.rag_service.create_session.assert_called_once() def test_chat_in_session_success(client): """ Tests sending a message in an existing session without specifying a 
model or retriever. It should default to 'deepseek' and 'False'. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the default model 'deepseek' # and the default load_faiss_retriever=False - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", model="deepseek", @@ -76,16 +94,17 @@ """ Tests sending a message in an existing session and explicitly switching the model. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", model="gemini", @@ -96,8 +115,8 @@ """ Tests sending a message and explicitly enabling the FAISS retriever. 
""" - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) response = test_client.post( "/sessions/42/chat", @@ -106,9 +125,10 @@ assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="What is RAG?", model="deepseek", # The model still defaults to deepseek @@ -117,13 +137,13 @@ def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return a list of message objects mock_history = [ models.Message(sender="user", content="Hello", created_at=datetime.now()), models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] - mock_rag_service.get_message_history.return_value = mock_history + mock_services.rag_service.get_message_history.return_value = mock_history # Act response = test_client.get("/sessions/123/messages") @@ -135,13 +155,16 @@ assert len(response_data["messages"]) == 2 assert response_data["messages"][0]["sender"] == "user" assert response_data["messages"][1]["content"] == "Hi there!" - mock_rag_service.get_message_history.assert_called_once_with(db=mock_rag_service.get_message_history.call_args.kwargs['db'], session_id=123) + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) def test_get_session_messages_not_found(client): """Tests retrieving messages for a session that does not exist.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return None, indicating the session wasn't found - mock_rag_service.get_message_history.return_value = None + mock_services.rag_service.get_message_history.return_value = None # Act response = test_client.get("/sessions/999/messages") @@ -151,35 +174,39 @@ assert response.json()["detail"] == "Session with ID 999 not found." 
# --- Document Endpoints --- -# (These tests are unchanged) + def test_add_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.add_document.return_value = 123 + """Tests the /documents endpoint for adding a new document.""" + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" def test_get_documents_success(client): - test_client, mock_rag_service = client + """Tests the /documents endpoint for retrieving all documents.""" + test_client, mock_services = client mock_docs = [ models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] - mock_rag_service.get_all_documents.return_value = mock_docs + mock_services.document_service.get_all_documents.return_value = mock_docs response = test_client.get("/documents") assert response.status_code == 200 assert len(response.json()["documents"]) == 2 def test_delete_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = 42 + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 response = test_client.delete("/documents/42") assert response.status_code == 200 assert response.json()["document_id"] == 42 def test_delete_document_not_found(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = None + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404 diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py new file mode 100644 index 0000000..cdf8b44 --- /dev/null +++ b/ai-hub/tests/core/services/test_document.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from datetime import datetime + +from app.core.services.document import DocumentService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder + +@pytest.fixture +def document_service(): + """ + Pytest fixture to create a DocumentService instance with mocked dependencies. + """ + mock_embedder = MagicMock(spec=MockEmbedder) + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_vector_store.embedder = mock_embedder + return DocumentService(vector_store=mock_vector_store) + +# --- add_document Tests --- + +def test_add_document_success(document_service: DocumentService): + """ + Tests that add_document successfully adds a document to the database + and its vector embedding to the FAISS index. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_document = MagicMock(id=1, text="Test text.") + mock_db.add.side_effect = [None, None] # Allow multiple calls + + # Configure the mock db.query to return a document object + mock_document_model_instance = models.Document( + id=1, + title="Test Title", + text="Test text.", + source_url="http://test.com" + ) + + with patch('app.core.services.document.models.Document', return_value=mock_document_model_instance) as mock_document_model, \ + patch('app.core.services.document.models.VectorMetadata') as mock_vector_metadata_model: + + document_service.vector_store.add_document.return_value = 123 + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Act + document_id = document_service.add_document(db=mock_db, doc_data=doc_data) + + # Assert + assert document_id == 1 + mock_db.add.assert_any_call(mock_document_model_instance) + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_document_model_instance) + document_service.vector_store.add_document.assert_called_once_with("Test text.") + mock_vector_metadata_model.assert_called_once_with( + document_id=1, + faiss_index=123, + embedding_model="mock_embedder" + ) + mock_db.add.assert_any_call(mock_vector_metadata_model.return_value) + + +def test_add_document_sql_error(document_service: DocumentService): + """ + Tests that add_document correctly handles a SQLAlchemyError by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.add.side_effect = SQLAlchemyError("Database error") + doc_data = {"title": "Test", "text": "...", "source_url": "http://test.com"} + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Database error"): + document_service.add_document(db=mock_db, doc_data=doc_data) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() + +# --- get_all_documents Tests --- + +def test_get_all_documents_success(document_service: DocumentService): + """ + Tests that get_all_documents returns a list of documents. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_documents = [models.Document(id=1), models.Document(id=2)] + mock_db.query.return_value.order_by.return_value.all.return_value = mock_documents + + # Act + documents = document_service.get_all_documents(db=mock_db) + + # Assert + assert documents == mock_documents + mock_db.query.assert_called_once_with(models.Document) + mock_db.query.return_value.order_by.assert_called_once() + +# --- delete_document Tests --- + +def test_delete_document_success(document_service: DocumentService): + """ + Tests that delete_document correctly deletes a document. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id_to_delete = 1 + doc_to_delete = models.Document(id=doc_id_to_delete) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=doc_id_to_delete) + + # Assert + assert deleted_id == doc_id_to_delete + mock_db.query.assert_called_once_with(models.Document) + mock_db.delete.assert_called_once_with(doc_to_delete) + mock_db.commit.assert_called_once() + +def test_delete_document_not_found(document_service: DocumentService): + """ + Tests that delete_document returns None if the document is not found. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=999) + + # Assert + assert deleted_id is None + mock_db.delete.assert_not_called() + mock_db.commit.assert_not_called() + +def test_delete_document_sql_error(document_service: DocumentService): + """ + Tests that delete_document handles a SQLAlchemyError correctly by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id = 1 + doc_to_delete = models.Document(id=doc_id) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + mock_db.delete.side_effect = SQLAlchemyError("Delete error") + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Delete error"): + document_service.delete_document(db=mock_db, document_id=doc_id) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() \ No newline at end of file diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py new file mode 100644 index 0000000..2fd4ab3 --- /dev/null +++ b/ai-hub/tests/core/services/test_rag.py @@ -0,0 +1,216 @@ +import pytest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from typing import List +from datetime import datetime +import dspy + +from app.core.services.rag import RAGService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.pipelines.dspy_rag import DspyRagPipeline +from app.core.llm_providers import LLMProvider + +@pytest.fixture +def rag_service(): + """ + Pytest fixture to create a RAGService instance with mocked dependencies. + It includes a mock FaissDBRetriever and a mock generic Retriever to test + conditional loading. + """ + # Create a mock vector store to provide a mock retriever + mock_vector_store = MagicMock(spec=FaissVectorStore) + + mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) + mock_web_retriever = MagicMock(spec=Retriever) + + return RAGService( + retrievers=[mock_web_retriever, mock_faiss_retriever] + ) + +# --- Session Management Tests --- + +def test_create_session(rag_service: RAGService): + """Tests that the create_session method correctly creates a new session.""" + mock_db = MagicMock(spec=Session) + + rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") + + mock_db.add.assert_called_once() + added_object = mock_db.add.call_args[0][0] + assert isinstance(added_object, models.Session) + assert added_object.user_id == "test_user" + assert added_object.model_name == "gemini" + +@patch('app.core.services.rag.get_llm_provider') +@patch('app.core.services.rag.DspyRagPipeline') +@patch('dspy.configure') +def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Tests the full orchestration of a chat message within a session using the default model + and with the retriever loading parameter explicitly set to False. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=42, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=42, + prompt="Test prompt", + model="deepseek", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("deepseek") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response" + assert model_name == "deepseek" + +def test_chat_with_rag_model_switch(rag_service: RAGService): + """ + Tests that chat_with_rag correctly switches the model based on the 'model' argument, + while still using the default retriever setting. + """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=43, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=43, + prompt="Test prompt for Gemini", + model="gemini", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("gemini") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt for Gemini", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response from Gemini" + assert model_name == "gemini" + + +def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService): + """ + Tests that the chat_with_rag method correctly initializes the DspyRagPipeline + with the FaissDBRetriever when `load_faiss_retriever` is True. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=44, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=44, + prompt="Test prompt with FAISS", + model="deepseek", + load_faiss_retriever=True + ) + ) + + # --- Assert --- + expected_retrievers = [rag_service.faiss_retriever] + mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt with FAISS", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Response with FAISS context" + assert model_name == "deepseek" + + +def test_get_message_history_success(rag_service: RAGService): + """Tests successfully retrieving message history for an existing session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=1, messages=[ + models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), + models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) + ]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=1) + + # Assert + assert len(messages) == 2 + assert messages[0].created_at < messages[1].created_at + mock_db.query.assert_called_once_with(models.Session) + +def test_get_message_history_not_found(rag_service: RAGService): + """Tests retrieving history for a non-existent session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=999) + + # Assert + assert messages is None \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index b8185af..18afe8f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -1,7 +1,12 @@ # app/api/dependencies.py from fastapi import Depends, HTTPException, status +from typing import List from sqlalchemy.orm import Session from app.db.session import SessionLocal +from app.core.retrievers import Retriever +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore # This is a dependency def get_db(): @@ -16,4 +21,13 @@ # In a real app, you would decode the token and fetch the user if not token: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - return {"email": "user@example.com", "id": 1} # Dummy user \ No newline at end of file + return {"email": "user@example.com", "id": 1} # Dummy user + + +class ServiceContainer: + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + # Initialize all services within the container 
+ self.document_service = DocumentService(vector_store=vector_store) + self.rag_service = RAGService( + retrievers=retrievers + ) \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 85ecafd..6d216ec 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session -from app.core.services import RAGService +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas -def create_api_router(rag_service: RAGService) -> APIRouter: +def create_api_router(services: ServiceContainer) -> APIRouter: """ Creates and returns an APIRouter with all the application's endpoints. """ @@ -25,7 +25,7 @@ Starts a new conversation session and returns its details. """ try: - new_session = rag_service.create_session( + new_session = services.rag_service.create_session( db=db, user_id=request.user_id, model=request.model @@ -47,7 +47,7 @@ The 'model' and 'load_faiss_retriever' can now be specified in the request body. """ try: - response_text, model_used = await rag_service.chat_with_rag( + response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, @@ -66,7 +66,7 @@ """ try: # Note: You'll need to add a `get_message_history` method to your RAGService. - messages = rag_service.get_message_history(db=db, session_id=session_id) + messages = services.rag_service.get_message_history(db=db, session_id=session_id) if messages is None: # Service can return None if session not found raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") @@ -83,7 +83,7 @@ def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() - document_id = rag_service.add_document(db=db, doc_data=doc_data) + document_id = services.document_service.add_document(db=db, doc_data=doc_data) return schemas.DocumentResponse( message=f"Document '{doc.title}' added successfully with ID {document_id}" ) @@ -93,7 +93,7 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): try: - documents_from_db = rag_service.get_all_documents(db=db) + documents_from_db = services.document_service.get_all_documents(db=db) return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @@ -101,7 +101,7 @@ @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): try: - deleted_id = rag_service.delete_document(db=db, document_id=document_id) + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index df21ba3..94413c8 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -7,10 +7,10 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config from app.core.retrievers import FaissDBRetriever, Retriever -from app.core.services import RAGService from app.db.session import create_db_and_tables from app.api.routes 
import create_api_router from app.utils import print_config +from app.api.dependencies import ServiceContainer # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -56,7 +56,7 @@ ) # 2. Initialize the FaissVectorStore with the chosen embedder - app.state.vector_store = FaissVectorStore( + vector_store = FaissVectorStore( index_file_path=settings.FAISS_INDEX_PATH, dimension=settings.EMBEDDING_DIMENSION, embedder=embedder # Pass the instantiated embedder object, @@ -65,18 +65,15 @@ # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ - FaissDBRetriever(vector_store=app.state.vector_store), + FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the RAGService with the created retriever list - # The llm_clients are no longer passed here, as per your services.py - rag_service = RAGService( - vector_store=app.state.vector_store, - retrievers=retrievers - ) + # 4. Initialize the Service Container + + services = ServiceContainer(vector_store, retrievers) # Create and include the API router, injecting the service - api_router = create_api_router(rag_service=rag_service) + api_router = create_api_router(services=services) app.include_router(api_router) return app diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index afa0e1f..a5694dc 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,141 +1,141 @@ -import asyncio -from typing import List, Dict, Any, Tuple -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.exc import SQLAlchemyError -import dspy +# import asyncio +# from typing import List, Dict, Any, Tuple +# from sqlalchemy.orm import Session, joinedload +# from sqlalchemy.exc import SQLAlchemyError +# import dspy -from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.vector_store.embedder.mock import MockEmbedder -from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever -from app.core.llm_providers import get_llm_provider -from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline +# from app.core.vector_store.faiss_store import FaissVectorStore +# from app.core.vector_store.embedder.mock import MockEmbedder +# from app.db import models +# from app.core.retrievers import Retriever, FaissDBRetriever +# from app.core.llm_providers import get_llm_provider +# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline -class RAGService: - """ - Service class for managing documents and conversational RAG sessions. - This class is now more robust and can handle both real and mock embedders - by inspecting its dependencies. - """ - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): - self.vector_store = vector_store - self.retrievers = retrievers +# class RAGService: +# """ +# Service class for managing documents and conversational RAG sessions. +# This class is now more robust and can handle both real and mock embedders +# by inspecting its dependencies. +# """ +# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): +# self.vector_store = vector_store +# self.retrievers = retrievers - # Assume one of the retrievers is the FAISS retriever, and you can access it. - # A better approach might be to have a dictionary of named retrievers. 
- self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) +# # Assume one of the retrievers is the FAISS retriever, and you can access it. +# # A better approach might be to have a dictionary of named retrievers. +# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - # Store the embedder from the vector store for dynamic naming - self.embedder = self.vector_store.embedder +# # Store the embedder from the vector store for dynamic naming +# self.embedder = self.vector_store.embedder - # --- Session Management --- +# # --- Session Management --- - def create_session(self, db: Session, user_id: str, model: str) -> models.Session: - """Creates a new chat session in the database.""" - try: - new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") - db.add(new_session) - db.commit() - db.refresh(new_session) - return new_session - except SQLAlchemyError as e: - db.rollback() - raise +# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: +# """Creates a new chat session in the database.""" +# try: +# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") +# db.add(new_session) +# db.commit() +# db.refresh(new_session) +# return new_session +# except SQLAlchemyError as e: +# db.rollback() +# raise - async def chat_with_rag( - self, - db: Session, - session_id: int, - prompt: str, - model: str, - load_faiss_retriever: bool = False - ) -> Tuple[str, str]: - """ - Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model and conditionally using the FAISS retriever. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# async def chat_with_rag( +# self, +# db: Session, +# session_id: int, +# prompt: str, +# model: str, +# load_faiss_retriever: bool = False +# ) -> Tuple[str, str]: +# """ +# Handles a message within a session, including saving history and getting a response. +# Allows switching the LLM model and conditionally using the FAISS retriever. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - if not session: - raise ValueError(f"Session with ID {session_id} not found.") +# if not session: +# raise ValueError(f"Session with ID {session_id} not found.") - user_message = models.Message(session_id=session_id, sender="user", content=prompt) - db.add(user_message) - db.commit() +# user_message = models.Message(session_id=session_id, sender="user", content=prompt) +# db.add(user_message) +# db.commit() - llm_provider = get_llm_provider(model) - dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) - dspy.configure(lm=dspy_llm) +# llm_provider = get_llm_provider(model) +# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) +# dspy.configure(lm=dspy_llm) - current_retrievers = [] - if load_faiss_retriever: - if self.faiss_retriever: - current_retrievers.append(self.faiss_retriever) - else: - print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") +# current_retrievers = [] +# if load_faiss_retriever: +# if self.faiss_retriever: +# current_retrievers.append(self.faiss_retriever) +# else: +# print("Warning: FaissDBRetriever requested but not available. 
Proceeding without it.") - rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) +# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - answer_text = await rag_pipeline.forward( - question=prompt, - history=session.messages, - db=db - ) +# answer_text = await rag_pipeline.forward( +# question=prompt, +# history=session.messages, +# db=db +# ) - assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) - db.add(assistant_message) - db.commit() +# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) +# db.add(assistant_message) +# db.commit() - return answer_text, model +# return answer_text, model - def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: - """ - Retrieves all messages for a given session, or None if the session doesn't exist. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: +# """ +# Retrieves all messages for a given session, or None if the session doesn't exist. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - return sorted(session.messages, key=lambda msg: msg.created_at) if session else None +# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - # --- Document Management (Updated) --- - def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - try: - document_db = models.Document(**doc_data) - db.add(document_db) - db.commit() - db.refresh(document_db) +# # --- Document Management (Updated) --- +# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: +# try: +# document_db = models.Document(**doc_data) +# db.add(document_db) +# db.commit() +# db.refresh(document_db) - # Use the embedder provided to the vector store to get the correct model name - embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" +# # Use the embedder provided to the vector store to get the correct model name +# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( - document_id=document_db.id, - faiss_index=faiss_index, - embedding_model=embedding_model_name - ) - db.add(vector_metadata) - db.commit() - return document_db.id - except SQLAlchemyError as e: - db.rollback() - raise +# faiss_index = self.vector_store.add_document(document_db.text) +# vector_metadata = models.VectorMetadata( +# document_id=document_db.id, +# faiss_index=faiss_index, +# embedding_model=embedding_model_name +# ) +# db.add(vector_metadata) +# db.commit() +# return document_db.id +# except SQLAlchemyError as e: +# db.rollback() +# raise - def get_all_documents(self, db: Session) -> List[models.Document]: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() +# def get_all_documents(self, db: Session) -> List[models.Document]: +# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - def delete_document(self, db: Session, document_id: int) -> int: - try: - doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: - return None - 
db.delete(doc_to_delete) - db.commit() - return document_id - except SQLAlchemyError as e: - db.rollback() - raise +# def delete_document(self, db: Session, document_id: int) -> int: +# try: +# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() +# if not doc_to_delete: +# return None +# db.delete(doc_to_delete) +# db.commit() +# return document_id +# except SQLAlchemyError as e: +# db.rollback() +# raise diff --git a/ai-hub/app/core/services/__init__.py b/ai-hub/app/core/services/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/services/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py new file mode 100644 index 0000000..3634f3e --- /dev/null +++ b/ai-hub/app/core/services/document.py @@ -0,0 +1,67 @@ +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.db import models + +class DocumentService: + """ + Service class for managing document lifecycle, including + adding, retrieving, and deleting documents and their vector metadata. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + self.embedder = self.vector_store.embedder + + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + document_db = models.Document(**doc_data) + db.add(document_db) + db.commit() + db.refresh(document_db) + + embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" + + faiss_index = self.vector_store.add_document(document_db.text) + vector_metadata = models.VectorMetadata( + document_id=document_db.id, + faiss_index=faiss_index, + embedding_model=embedding_model_name + ) + db.add(vector_metadata) + db.commit() + return document_db.id + except SQLAlchemyError as e: + db.rollback() + raise + + def get_all_documents(self, db: Session) -> List[models.Document]: + """ + Retrieves all documents from the database. + """ + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + + def delete_document(self, db: Session, document_id: int) -> int: + """ + Deletes a document and its associated vector metadata from the database. + """ + try: + doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() + if not doc_to_delete: + return None + + # Assuming you also need to delete the vector metadata associated with the document + # for a full cleanup. 
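+ # Note that nothing is removed from the FAISS index itself in this method; only database rows are affected.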
+ # db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() + + db.delete(doc_to_delete) + db.commit() + return document_id + except SQLAlchemyError as e: + db.rollback() + raise \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py new file mode 100644 index 0000000..344e779 --- /dev/null +++ b/ai-hub/app/core/services/rag.py @@ -0,0 +1,90 @@ +import asyncio +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload +from sqlalchemy.exc import SQLAlchemyError +import dspy + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.llm_providers import get_llm_provider +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline + +class RAGService: + """ + Service class for managing conversational RAG sessions. + This class orchestrates the RAG pipeline and manages chat sessions. + """ + def __init__(self, retrievers: List[Retriever]): + self.retrievers = retrievers + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """Creates a new chat session in the database.""" + try: + new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False + ) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) + dspy.configure(lm=dspy_llm) + + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + + answer_text = await rag_pipeline.forward( + question=prompt, + history=session.messages, + db=db + ) + + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, model + + def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: + """ + Retrieves all messages for a given session. 
+ """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py new file mode 100644 index 0000000..470780d --- /dev/null +++ b/ai-hub/tests/api/test_dependencies.py @@ -0,0 +1,112 @@ +# tests/api/test_dependencies.py +import pytest +import asyncio +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from fastapi import HTTPException + +# Import the dependencies and services to be tested +from app.api.dependencies import get_db, get_current_user, ServiceContainer +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.retrievers import Retriever + +@pytest.fixture +def mock_session(): + """ + Fixture that provides a mock SQLAlchemy session. + """ + mock = MagicMock(spec=Session) + yield mock + +# --- Tests for get_db dependency --- + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_yields_session_and_closes(mock_session_local, mock_session): + """ + Tests that get_db yields a database session and ensures it's closed correctly. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act: Use the generator in a context manager + db_generator = get_db() + db = next(db_generator) + + # Assert 1: The correct session object was yielded + assert db == mock_session + + # Act 2: Manually close the generator + with pytest.raises(StopIteration): + next(db_generator) + + # Assert 2: The session's close method was called + mock_session.close.assert_called_once() + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_closes_on_exception(mock_session_local, mock_session): + """ + Tests that get_db still closes the session even if an exception occurs. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act & Assert: Call the generator and raise an exception + db_generator = get_db() + db = next(db_generator) + with pytest.raises(Exception): + db_generator.throw(Exception("Test exception")) + + # Assert: The session's close method was still called after the exception was handled + mock_session.close.assert_called_once() + + +# --- Tests for get_current_user dependency --- + +def test_get_current_user_with_valid_token(): + """ + Tests that get_current_user returns the expected user dictionary for a valid token. + """ + # Act + user = asyncio.run(get_current_user(token="valid_token")) + + # Assert + assert user == {"email": "user@example.com", "id": 1} + +def test_get_current_user_with_no_token(): + """ + Tests that get_current_user raises an HTTPException for a missing token. + """ + # Assert + with pytest.raises(HTTPException) as excinfo: + asyncio.run(get_current_user(token=None)) + + assert excinfo.value.status_code == 401 + assert "Unauthorized" in excinfo.value.detail + +# --- Tests for ServiceContainer class --- + +def test_service_container_initialization(): + """ + Tests that ServiceContainer initializes DocumentService and RAGService + with the correct dependencies. 
+ """ + # Arrange: Create mock dependencies + mock_vector_store = MagicMock(spec=FaissVectorStore) + # The DocumentService constructor needs a .embedder attribute on the vector_store + mock_vector_store.embedder = MagicMock() + mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + + # Act: Instantiate the ServiceContainer + container = ServiceContainer( + vector_store=mock_vector_store, + retrievers=mock_retrievers + ) + + # Assert: Check if the services were created and configured correctly + assert isinstance(container.document_service, DocumentService) + assert container.document_service.vector_store == mock_vector_store + + assert isinstance(container.rag_service, RAGService) + assert container.rag_service.retrievers == mock_retrievers diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index c58a6c7..8d841a2 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,3 +1,4 @@ +# tests/app/api/test_routes.py import pytest from unittest.mock import MagicMock, AsyncMock from fastapi import FastAPI @@ -6,26 +7,42 @@ from datetime import datetime # Import the dependencies and router factory -from app.core.services import RAGService -from app.api.dependencies import get_db +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService from app.api.routes import create_api_router from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """Pytest fixture to create a TestClient with a fully mocked environment.""" + """ + Pytest fixture to create a TestClient with a fully mocked environment, + including a mock ServiceContainer. + """ test_app = FastAPI() + + # Mock individual services mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + + # Create a mock ServiceContainer that holds the mocked services + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + + # Mock the database session mock_db_session = MagicMock(spec=Session) def override_get_db(): yield mock_db_session - api_router = create_api_router(rag_service=mock_rag_service) + # Pass the mock ServiceContainer to the router factory + api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) - yield TestClient(test_app), mock_rag_service + # Return the test client and the mock services for assertion + yield TestClient(test_app), mock_services # --- General Endpoint --- @@ -40,32 +57,33 @@ def test_create_session_success(client): """Tests successfully creating a new chat session.""" - test_client, mock_rag_service = client + test_client, mock_services = client mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) - mock_rag_service.create_session.return_value = mock_session + mock_services.rag_service.create_session.return_value = mock_session response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) assert response.status_code == 200 assert response.json()["id"] == 1 - mock_rag_service.create_session.assert_called_once() + mock_services.rag_service.create_session.assert_called_once() def test_chat_in_session_success(client): """ Tests sending a message in an existing session without specifying a 
model or retriever. It should default to 'deepseek' and 'False'. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the default model 'deepseek' # and the default load_faiss_retriever=False - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", model="deepseek", @@ -76,16 +94,17 @@ """ Tests sending a message in an existing session and explicitly switching the model. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", model="gemini", @@ -96,8 +115,8 @@ """ Tests sending a message and explicitly enabling the FAISS retriever. 
""" - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) response = test_client.post( "/sessions/42/chat", @@ -106,9 +125,10 @@ assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="What is RAG?", model="deepseek", # The model still defaults to deepseek @@ -117,13 +137,13 @@ def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return a list of message objects mock_history = [ models.Message(sender="user", content="Hello", created_at=datetime.now()), models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] - mock_rag_service.get_message_history.return_value = mock_history + mock_services.rag_service.get_message_history.return_value = mock_history # Act response = test_client.get("/sessions/123/messages") @@ -135,13 +155,16 @@ assert len(response_data["messages"]) == 2 assert response_data["messages"][0]["sender"] == "user" assert response_data["messages"][1]["content"] == "Hi there!" - mock_rag_service.get_message_history.assert_called_once_with(db=mock_rag_service.get_message_history.call_args.kwargs['db'], session_id=123) + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) def test_get_session_messages_not_found(client): """Tests retrieving messages for a session that does not exist.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return None, indicating the session wasn't found - mock_rag_service.get_message_history.return_value = None + mock_services.rag_service.get_message_history.return_value = None # Act response = test_client.get("/sessions/999/messages") @@ -151,35 +174,39 @@ assert response.json()["detail"] == "Session with ID 999 not found." 
# --- Document Endpoints --- -# (These tests are unchanged) + def test_add_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.add_document.return_value = 123 + """Tests the /documents endpoint for adding a new document.""" + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" def test_get_documents_success(client): - test_client, mock_rag_service = client + """Tests the /documents endpoint for retrieving all documents.""" + test_client, mock_services = client mock_docs = [ models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] - mock_rag_service.get_all_documents.return_value = mock_docs + mock_services.document_service.get_all_documents.return_value = mock_docs response = test_client.get("/documents") assert response.status_code == 200 assert len(response.json()["documents"]) == 2 def test_delete_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = 42 + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 response = test_client.delete("/documents/42") assert response.status_code == 200 assert response.json()["document_id"] == 42 def test_delete_document_not_found(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = None + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404 diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py new file mode 100644 index 0000000..cdf8b44 --- /dev/null +++ b/ai-hub/tests/core/services/test_document.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from datetime import datetime + +from app.core.services.document import DocumentService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder + +@pytest.fixture +def document_service(): + """ + Pytest fixture to create a DocumentService instance with mocked dependencies. + """ + mock_embedder = MagicMock(spec=MockEmbedder) + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_vector_store.embedder = mock_embedder + return DocumentService(vector_store=mock_vector_store) + +# --- add_document Tests --- + +def test_add_document_success(document_service: DocumentService): + """ + Tests that add_document successfully adds a document to the database + and its vector embedding to the FAISS index. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_document = MagicMock(id=1, text="Test text.") + mock_db.add.side_effect = [None, None] # Allow multiple calls + + # Configure the mock db.query to return a document object + mock_document_model_instance = models.Document( + id=1, + title="Test Title", + text="Test text.", + source_url="http://test.com" + ) + + with patch('app.core.services.document.models.Document', return_value=mock_document_model_instance) as mock_document_model, \ + patch('app.core.services.document.models.VectorMetadata') as mock_vector_metadata_model: + + document_service.vector_store.add_document.return_value = 123 + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Act + document_id = document_service.add_document(db=mock_db, doc_data=doc_data) + + # Assert + assert document_id == 1 + mock_db.add.assert_any_call(mock_document_model_instance) + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_document_model_instance) + document_service.vector_store.add_document.assert_called_once_with("Test text.") + mock_vector_metadata_model.assert_called_once_with( + document_id=1, + faiss_index=123, + embedding_model="mock_embedder" + ) + mock_db.add.assert_any_call(mock_vector_metadata_model.return_value) + + +def test_add_document_sql_error(document_service: DocumentService): + """ + Tests that add_document correctly handles a SQLAlchemyError by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.add.side_effect = SQLAlchemyError("Database error") + doc_data = {"title": "Test", "text": "...", "source_url": "http://test.com"} + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Database error"): + document_service.add_document(db=mock_db, doc_data=doc_data) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() + +# --- get_all_documents Tests --- + +def test_get_all_documents_success(document_service: DocumentService): + """ + Tests that get_all_documents returns a list of documents. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_documents = [models.Document(id=1), models.Document(id=2)] + mock_db.query.return_value.order_by.return_value.all.return_value = mock_documents + + # Act + documents = document_service.get_all_documents(db=mock_db) + + # Assert + assert documents == mock_documents + mock_db.query.assert_called_once_with(models.Document) + mock_db.query.return_value.order_by.assert_called_once() + +# --- delete_document Tests --- + +def test_delete_document_success(document_service: DocumentService): + """ + Tests that delete_document correctly deletes a document. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id_to_delete = 1 + doc_to_delete = models.Document(id=doc_id_to_delete) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=doc_id_to_delete) + + # Assert + assert deleted_id == doc_id_to_delete + mock_db.query.assert_called_once_with(models.Document) + mock_db.delete.assert_called_once_with(doc_to_delete) + mock_db.commit.assert_called_once() + +def test_delete_document_not_found(document_service: DocumentService): + """ + Tests that delete_document returns None if the document is not found. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=999) + + # Assert + assert deleted_id is None + mock_db.delete.assert_not_called() + mock_db.commit.assert_not_called() + +def test_delete_document_sql_error(document_service: DocumentService): + """ + Tests that delete_document handles a SQLAlchemyError correctly by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id = 1 + doc_to_delete = models.Document(id=doc_id) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + mock_db.delete.side_effect = SQLAlchemyError("Delete error") + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Delete error"): + document_service.delete_document(db=mock_db, document_id=doc_id) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() \ No newline at end of file diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py new file mode 100644 index 0000000..2fd4ab3 --- /dev/null +++ b/ai-hub/tests/core/services/test_rag.py @@ -0,0 +1,216 @@ +import pytest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from typing import List +from datetime import datetime +import dspy + +from app.core.services.rag import RAGService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.pipelines.dspy_rag import DspyRagPipeline +from app.core.llm_providers import LLMProvider + +@pytest.fixture +def rag_service(): + """ + Pytest fixture to create a RAGService instance with mocked dependencies. + It includes a mock FaissDBRetriever and a mock generic Retriever to test + conditional loading. + """ + # Create a mock vector store to provide a mock retriever + mock_vector_store = MagicMock(spec=FaissVectorStore) + + mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) + mock_web_retriever = MagicMock(spec=Retriever) + + return RAGService( + retrievers=[mock_web_retriever, mock_faiss_retriever] + ) + +# --- Session Management Tests --- + +def test_create_session(rag_service: RAGService): + """Tests that the create_session method correctly creates a new session.""" + mock_db = MagicMock(spec=Session) + + rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") + + mock_db.add.assert_called_once() + added_object = mock_db.add.call_args[0][0] + assert isinstance(added_object, models.Session) + assert added_object.user_id == "test_user" + assert added_object.model_name == "gemini" + +@patch('app.core.services.rag.get_llm_provider') +@patch('app.core.services.rag.DspyRagPipeline') +@patch('dspy.configure') +def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Tests the full orchestration of a chat message within a session using the default model + and with the retriever loading parameter explicitly set to False. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=42, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=42, + prompt="Test prompt", + model="deepseek", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("deepseek") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response" + assert model_name == "deepseek" + +def test_chat_with_rag_model_switch(rag_service: RAGService): + """ + Tests that chat_with_rag correctly switches the model based on the 'model' argument, + while still using the default retriever setting. + """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=43, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=43, + prompt="Test prompt for Gemini", + model="gemini", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("gemini") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt for Gemini", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response from Gemini" + assert model_name == "gemini" + + +def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService): + """ + Tests that the chat_with_rag method correctly initializes the DspyRagPipeline + with the FaissDBRetriever when `load_faiss_retriever` is True. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=44, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=44, + prompt="Test prompt with FAISS", + model="deepseek", + load_faiss_retriever=True + ) + ) + + # --- Assert --- + expected_retrievers = [rag_service.faiss_retriever] + mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt with FAISS", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Response with FAISS context" + assert model_name == "deepseek" + + +def test_get_message_history_success(rag_service: RAGService): + """Tests successfully retrieving message history for an existing session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=1, messages=[ + models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), + models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) + ]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=1) + + # Assert + assert len(messages) == 2 + assert messages[0].created_at < messages[1].created_at + mock_db.query.assert_called_once_with(models.Session) + +def test_get_message_history_not_found(rag_service: RAGService): + """Tests retrieving history for a non-existent session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=999) + + # Assert + assert messages is None \ No newline at end of file diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py index d36c203..a59a6d4 100644 --- a/ai-hub/tests/core/test_services.py +++ b/ai-hub/tests/core/test_services.py @@ -1,320 +1,320 @@ -import pytest -import asyncio -from unittest.mock import patch, MagicMock, AsyncMock -from sqlalchemy.orm import Session -from sqlalchemy.exc import SQLAlchemyError -from typing import List -from datetime import datetime -import dspy +# import pytest +# import asyncio +# from unittest.mock import patch, MagicMock, AsyncMock +# from sqlalchemy.orm import Session +# from sqlalchemy.exc import SQLAlchemyError +# from typing import List +# from datetime import datetime +# import dspy -# Import the service and its dependencies -from app.core.services import RAGService -from app.db import models -from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.vector_store.embedder.mock import MockEmbedder -# Import FaissDBRetriever and a mock WebRetriever for testing different cases -from app.core.retrievers import FaissDBRetriever, Retriever -from 
app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider -from app.core.llm_providers import LLMProvider +# # Import the service and its dependencies +# from app.core.services import RAGService +# from app.db import models +# from app.core.vector_store.faiss_store import FaissVectorStore +# from app.core.vector_store.embedder.mock import MockEmbedder +# # Import FaissDBRetriever and a mock WebRetriever for testing different cases +# from app.core.retrievers import FaissDBRetriever, Retriever +# from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider +# from app.core.llm_providers import LLMProvider -@pytest.fixture -def rag_service(): - """ - Pytest fixture to create a RAGService instance with mocked dependencies. - It includes a mock FaissDBRetriever and a mock generic Retriever to test - conditional loading. - """ - # Create a mock embedder to be attached to the vector store mock - mock_embedder = MagicMock(spec=MockEmbedder) - mock_vector_store = MagicMock(spec=FaissVectorStore) - mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute +# @pytest.fixture +# def rag_service(): +# """ +# Pytest fixture to create a RAGService instance with mocked dependencies. +# It includes a mock FaissDBRetriever and a mock generic Retriever to test +# conditional loading. +# """ +# # Create a mock embedder to be attached to the vector store mock +# mock_embedder = MagicMock(spec=MockEmbedder) +# mock_vector_store = MagicMock(spec=FaissVectorStore) +# mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute - mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) - mock_web_retriever = MagicMock(spec=Retriever) - return RAGService( - vector_store=mock_vector_store, - retrievers=[mock_web_retriever, mock_faiss_retriever] - ) +# mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) +# mock_web_retriever = MagicMock(spec=Retriever) +# return RAGService( +# vector_store=mock_vector_store, +# retrievers=[mock_web_retriever, mock_faiss_retriever] +# ) -# --- Session Management Tests --- +# # --- Session Management Tests --- -def test_create_session(rag_service: RAGService): - """Tests that the create_session method correctly creates a new session.""" - mock_db = MagicMock(spec=Session) +# def test_create_session(rag_service: RAGService): +# """Tests that the create_session method correctly creates a new session.""" +# mock_db = MagicMock(spec=Session) - rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") +# rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") - mock_db.add.assert_called_once() - added_object = mock_db.add.call_args[0][0] - assert isinstance(added_object, models.Session) - assert added_object.user_id == "test_user" - assert added_object.model_name == "gemini" +# mock_db.add.assert_called_once() +# added_object = mock_db.add.call_args[0][0] +# assert isinstance(added_object, models.Session) +# assert added_object.user_id == "test_user" +# assert added_object.model_name == "gemini" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests the full orchestration of a chat message within a session using the default model - and with the retriever loading parameter explicitly set to False. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=42, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests the full orchestration of a chat message within a session using the default model +# and with the retriever loading parameter explicitly set to False. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=42, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=42, - prompt="Test prompt", - model="deepseek", - load_faiss_retriever=False # Explicitly pass the default value - ) - ) +# # --- Act --- +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=42, +# prompt="Test prompt", +# model="deepseek", +# load_faiss_retriever=False # Explicitly pass the default value +# ) +# ) - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("deepseek") +# # --- Assert --- +# mock_db.query.assert_called_once_with(models.Session) +# assert mock_db.add.call_count == 2 +# mock_get_llm_provider.assert_called_once_with("deepseek") - # Assert that DspyRagPipeline was initialized with an empty list of retrievers - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) +# # Assert that DspyRagPipeline was initialized with an empty list of retrievers +# mock_dspy_pipeline.assert_called_once_with(retrievers=[]) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Final RAG response" - assert model_name == "deepseek" +# assert answer == "Final RAG response" +# assert model_name == "deepseek" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests that chat_with_rag correctly switches the model based on the 'model' argument, - while still using the default retriever setting. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=43, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests that chat_with_rag correctly switches the model based on the 'model' argument, +# while still using the default retriever setting. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=43, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=43, - prompt="Test prompt for Gemini", - model="gemini", - load_faiss_retriever=False # Explicitly pass the default value - ) - ) +# # --- Act --- +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=43, +# prompt="Test prompt for Gemini", +# model="gemini", +# load_faiss_retriever=False # Explicitly pass the default value +# ) +# ) - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("gemini") +# # --- Assert --- +# mock_db.query.assert_called_once_with(models.Session) +# assert mock_db.add.call_count == 2 +# mock_get_llm_provider.assert_called_once_with("gemini") - # Assert that DspyRagPipeline was initialized with an empty list of retrievers - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) +# # Assert that DspyRagPipeline was initialized with an empty list of retrievers +# mock_dspy_pipeline.assert_called_once_with(retrievers=[]) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt for Gemini", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt for Gemini", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Final RAG response from Gemini" - assert model_name == "gemini" +# assert answer == "Final RAG response from Gemini" +# assert model_name == "gemini" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests that the chat_with_rag method correctly initializes the DspyRagPipeline - with the FaissDBRetriever when `load_faiss_retriever` is True. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=44, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests that the chat_with_rag method correctly initializes the DspyRagPipeline +# with the FaissDBRetriever when `load_faiss_retriever` is True. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=44, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - # Explicitly enable the FAISS retriever - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=44, - prompt="Test prompt with FAISS", - model="deepseek", - load_faiss_retriever=True - ) - ) +# # --- Act --- +# # Explicitly enable the FAISS retriever +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=44, +# prompt="Test prompt with FAISS", +# model="deepseek", +# load_faiss_retriever=True +# ) +# ) - # --- Assert --- - # The crucial part is to verify that the pipeline was called with the correct retriever - expected_retrievers = [rag_service.faiss_retriever] - mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) +# # --- Assert --- +# # The crucial part is to verify that the pipeline was called with the correct retriever +# expected_retrievers = [rag_service.faiss_retriever] +# mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt with FAISS", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt with FAISS", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Response with FAISS context" - assert model_name == "deepseek" +# assert answer == "Response with FAISS context" +# assert model_name == "deepseek" -def test_get_message_history_success(rag_service: RAGService): - """Tests successfully retrieving message history for an existing session.""" - # Arrange - mock_db = MagicMock(spec=Session) - # Ensure mocked messages have created_at for sorting - mock_session = models.Session(id=1, messages=[ - models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), - models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) - ]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# def test_get_message_history_success(rag_service: 
RAGService): +# """Tests successfully retrieving message history for an existing session.""" +# # Arrange +# mock_db = MagicMock(spec=Session) +# # Ensure mocked messages have created_at for sorting +# mock_session = models.Session(id=1, messages=[ +# models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), +# models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) +# ]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - # Act - messages = rag_service.get_message_history(db=mock_db, session_id=1) +# # Act +# messages = rag_service.get_message_history(db=mock_db, session_id=1) - # Assert - assert len(messages) == 2 - assert messages[0].created_at < messages[1].created_at # Verify sorting - mock_db.query.assert_called_once_with(models.Session) +# # Assert +# assert len(messages) == 2 +# assert messages[0].created_at < messages[1].created_at # Verify sorting +# mock_db.query.assert_called_once_with(models.Session) -def test_get_message_history_not_found(rag_service: RAGService): - """Tests retrieving history for a non-existent session.""" - # Arrange - mock_db = MagicMock(spec=Session) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None +# def test_get_message_history_not_found(rag_service: RAGService): +# """Tests retrieving history for a non-existent session.""" +# # Arrange +# mock_db = MagicMock(spec=Session) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None - # Act - messages = rag_service.get_message_history(db=mock_db, session_id=999) +# # Act +# messages = rag_service.get_message_history(db=mock_db, session_id=999) - # Assert - assert messages is None +# # Assert +# assert messages is None -# --- Document Management Tests --- +# # --- Document Management Tests --- -@patch('app.db.models.VectorMetadata') -@patch('app.db.models.Document') -@patch('app.core.vector_store.faiss_store.FaissVectorStore') -def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): - """ - Test the RAGService.add_document method for a successful run. - Verifies that the method correctly calls db.add(), db.commit(), and the vector store. - """ - # Setup mocks - mock_db = MagicMock(spec=Session) - mock_new_document_instance = MagicMock() - mock_document_model.return_value = mock_new_document_instance - mock_new_document_instance.id = 1 - mock_new_document_instance.text = "Test text." - mock_new_document_instance.title = "Test Title" +# @patch('app.db.models.VectorMetadata') +# @patch('app.db.models.Document') +# @patch('app.core.vector_store.faiss_store.FaissVectorStore') +# def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): +# """ +# Test the RAGService.add_document method for a successful run. +# Verifies that the method correctly calls db.add(), db.commit(), and the vector store. +# """ +# # Setup mocks +# mock_db = MagicMock(spec=Session) +# mock_new_document_instance = MagicMock() +# mock_document_model.return_value = mock_new_document_instance +# mock_new_document_instance.id = 1 +# mock_new_document_instance.text = "Test text." 
+# mock_new_document_instance.title = "Test Title" - mock_vector_store_instance = mock_vector_store.return_value - # Fix: Manually set the embedder on the mock vector store instance - mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) - mock_vector_store_instance.add_document.return_value = 123 +# mock_vector_store_instance = mock_vector_store.return_value +# # Fix: Manually set the embedder on the mock vector store instance +# mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) +# mock_vector_store_instance.add_document.return_value = 123 - # Instantiate the service correctly - rag_service = RAGService( - vector_store=mock_vector_store_instance, - retrievers=[] - ) +# # Instantiate the service correctly +# rag_service = RAGService( +# vector_store=mock_vector_store_instance, +# retrievers=[] +# ) - doc_data = { - "title": "Test Title", - "text": "Test text.", - "source_url": "http://test.com" - } +# doc_data = { +# "title": "Test Title", +# "text": "Test text.", +# "source_url": "http://test.com" +# } - # Call the method under test - document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) +# # Call the method under test +# document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) - # Assertions - assert document_id == 1 +# # Assertions +# assert document_id == 1 - from unittest.mock import call - expected_calls = [ - call(mock_new_document_instance), - call(mock_vector_metadata_model.return_value) - ] - mock_db.add.assert_has_calls(expected_calls) +# from unittest.mock import call +# expected_calls = [ +# call(mock_new_document_instance), +# call(mock_vector_metadata_model.return_value) +# ] +# mock_db.add.assert_has_calls(expected_calls) - mock_db.commit.assert_called() - mock_db.refresh.assert_called_with(mock_new_document_instance) - mock_vector_store_instance.add_document.assert_called_once_with("Test text.") +# mock_db.commit.assert_called() +# mock_db.refresh.assert_called_with(mock_new_document_instance) +# mock_vector_store_instance.add_document.assert_called_once_with("Test text.") - # Assert that VectorMetadata was instantiated with the correct arguments - mock_vector_metadata_model.assert_called_once_with( - document_id=mock_new_document_instance.id, - faiss_index=mock_vector_store_instance.add_document.return_value, - embedding_model="mock_embedder" # This now passes because the mock embedder is of type MockEmbedder - ) +# # Assert that VectorMetadata was instantiated with the correct arguments +# mock_vector_metadata_model.assert_called_once_with( +# document_id=mock_new_document_instance.id, +# faiss_index=mock_vector_store_instance.add_document.return_value, +# embedding_model="mock_embedder" # This now passes because the mock embedder is of type MockEmbedder +# ) -@patch('app.core.vector_store.faiss_store.FaissVectorStore') -def test_rag_service_add_document_error_handling(mock_vector_store): - """ - Test the RAGService.add_document method's error handling. - Verifies that the transaction is rolled back on an exception. - """ - # Setup mocks - mock_db = MagicMock(spec=Session) +# @patch('app.core.vector_store.faiss_store.FaissVectorStore') +# def test_rag_service_add_document_error_handling(mock_vector_store): +# """ +# Test the RAGService.add_document method's error handling. +# Verifies that the transaction is rolled back on an exception. +# """ +# # Setup mocks +# mock_db = MagicMock(spec=Session) - # Configure the mock db.add to raise the specific SQLAlchemyError. 
- mock_db.add.side_effect = SQLAlchemyError("Database error")
+# # Configure the mock db.add to raise the specific SQLAlchemyError.
+# mock_db.add.side_effect = SQLAlchemyError("Database error")
- mock_vector_store_instance = mock_vector_store.return_value
+# mock_vector_store_instance = mock_vector_store.return_value
- # Instantiate the service correctly
- rag_service = RAGService(
- vector_store=mock_vector_store_instance,
- retrievers=[]
- )
+# # Instantiate the service correctly
+# rag_service = RAGService(
+# vector_store=mock_vector_store_instance,
+# retrievers=[]
+# )
- doc_data = {
- "title": "Test Title",
- "text": "Test text.",
- "source_url": "http://test.com"
- }
+# doc_data = {
+# "title": "Test Title",
+# "text": "Test text.",
+# "source_url": "http://test.com"
+# }
- # Call the method under test and expect an exception
- with pytest.raises(SQLAlchemyError, match="Database error"):
- rag_service.add_document(db=mock_db, doc_data=doc_data)
+# # Call the method under test and expect an exception
+# with pytest.raises(SQLAlchemyError, match="Database error"):
+# rag_service.add_document(db=mock_db, doc_data=doc_data)
- # Assertions
- mock_db.add.assert_called_once()
- mock_db.commit.assert_not_called()
- mock_db.rollback.assert_called_once()
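Before the application wiring shown in the app.py diff below, the following is a minimal, self-contained sketch of how the new ServiceContainer is meant to be composed with the router factory. It is an illustration rather than the project's real startup code: it uses only names introduced by this change (ServiceContainer, create_api_router, FaissVectorStore, Retriever) and stubs the vector store and retriever with mocks, so it runs without a FAISS index or a database.

from typing import List
from unittest.mock import MagicMock

from fastapi import FastAPI

from app.api.dependencies import ServiceContainer
from app.api.routes import create_api_router
from app.core.retrievers import Retriever
from app.core.vector_store.faiss_store import FaissVectorStore

# Stub the heavy dependencies. DocumentService reads `.embedder` off the
# vector store, so the mock needs that attribute (the new
# tests/api/test_dependencies.py sets it up the same way).
vector_store = MagicMock(spec=FaissVectorStore)
vector_store.embedder = MagicMock()
retrievers: List[Retriever] = [MagicMock(spec=Retriever)]

# One container owns both services; the router factory only sees the container.
services = ServiceContainer(vector_store=vector_store, retrievers=retrievers)
assert services.document_service.vector_store is vector_store
assert services.rag_service.retrievers == retrievers

app = FastAPI()
app.include_router(create_api_router(services=services))

Keeping construction in one place means create_api_router takes a single argument, and tests can hand it a mocked container instead of individual services, which is exactly what the updated tests/api/test_routes.py fixture does.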
""" try: - new_session = rag_service.create_session( + new_session = services.rag_service.create_session( db=db, user_id=request.user_id, model=request.model @@ -47,7 +47,7 @@ The 'model' and 'load_faiss_retriever' can now be specified in the request body. """ try: - response_text, model_used = await rag_service.chat_with_rag( + response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, @@ -66,7 +66,7 @@ """ try: # Note: You'll need to add a `get_message_history` method to your RAGService. - messages = rag_service.get_message_history(db=db, session_id=session_id) + messages = services.rag_service.get_message_history(db=db, session_id=session_id) if messages is None: # Service can return None if session not found raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") @@ -83,7 +83,7 @@ def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() - document_id = rag_service.add_document(db=db, doc_data=doc_data) + document_id = services.document_service.add_document(db=db, doc_data=doc_data) return schemas.DocumentResponse( message=f"Document '{doc.title}' added successfully with ID {document_id}" ) @@ -93,7 +93,7 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): try: - documents_from_db = rag_service.get_all_documents(db=db) + documents_from_db = services.document_service.get_all_documents(db=db) return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @@ -101,7 +101,7 @@ @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): try: - deleted_id = rag_service.delete_document(db=db, document_id=document_id) + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index df21ba3..94413c8 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -7,10 +7,10 @@ from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config from app.core.retrievers import FaissDBRetriever, Retriever -from app.core.services import RAGService from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config +from app.api.dependencies import ServiceContainer # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -56,7 +56,7 @@ ) # 2. Initialize the FaissVectorStore with the chosen embedder - app.state.vector_store = FaissVectorStore( + vector_store = FaissVectorStore( index_file_path=settings.FAISS_INDEX_PATH, dimension=settings.EMBEDDING_DIMENSION, embedder=embedder # Pass the instantiated embedder object, @@ -65,18 +65,15 @@ # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ - FaissDBRetriever(vector_store=app.state.vector_store), + FaissDBRetriever(vector_store=vector_store), ] - # 4. 
Initialize the RAGService with the created retriever list - # The llm_clients are no longer passed here, as per your services.py - rag_service = RAGService( - vector_store=app.state.vector_store, - retrievers=retrievers - ) + # 4. Initialize the Service Container + + services = ServiceContainer(vector_store, retrievers) # Create and include the API router, injecting the service - api_router = create_api_router(rag_service=rag_service) + api_router = create_api_router(services=services) app.include_router(api_router) return app diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index afa0e1f..a5694dc 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,141 +1,141 @@ -import asyncio -from typing import List, Dict, Any, Tuple -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.exc import SQLAlchemyError -import dspy +# import asyncio +# from typing import List, Dict, Any, Tuple +# from sqlalchemy.orm import Session, joinedload +# from sqlalchemy.exc import SQLAlchemyError +# import dspy -from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.vector_store.embedder.mock import MockEmbedder -from app.db import models -from app.core.retrievers import Retriever, FaissDBRetriever -from app.core.llm_providers import get_llm_provider -from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline +# from app.core.vector_store.faiss_store import FaissVectorStore +# from app.core.vector_store.embedder.mock import MockEmbedder +# from app.db import models +# from app.core.retrievers import Retriever, FaissDBRetriever +# from app.core.llm_providers import get_llm_provider +# from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline -class RAGService: - """ - Service class for managing documents and conversational RAG sessions. - This class is now more robust and can handle both real and mock embedders - by inspecting its dependencies. - """ - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): - self.vector_store = vector_store - self.retrievers = retrievers +# class RAGService: +# """ +# Service class for managing documents and conversational RAG sessions. +# This class is now more robust and can handle both real and mock embedders +# by inspecting its dependencies. +# """ +# def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): +# self.vector_store = vector_store +# self.retrievers = retrievers - # Assume one of the retrievers is the FAISS retriever, and you can access it. - # A better approach might be to have a dictionary of named retrievers. - self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) +# # Assume one of the retrievers is the FAISS retriever, and you can access it. +# # A better approach might be to have a dictionary of named retrievers. 
+# self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) - # Store the embedder from the vector store for dynamic naming - self.embedder = self.vector_store.embedder +# # Store the embedder from the vector store for dynamic naming +# self.embedder = self.vector_store.embedder - # --- Session Management --- +# # --- Session Management --- - def create_session(self, db: Session, user_id: str, model: str) -> models.Session: - """Creates a new chat session in the database.""" - try: - new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") - db.add(new_session) - db.commit() - db.refresh(new_session) - return new_session - except SQLAlchemyError as e: - db.rollback() - raise +# def create_session(self, db: Session, user_id: str, model: str) -> models.Session: +# """Creates a new chat session in the database.""" +# try: +# new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") +# db.add(new_session) +# db.commit() +# db.refresh(new_session) +# return new_session +# except SQLAlchemyError as e: +# db.rollback() +# raise - async def chat_with_rag( - self, - db: Session, - session_id: int, - prompt: str, - model: str, - load_faiss_retriever: bool = False - ) -> Tuple[str, str]: - """ - Handles a message within a session, including saving history and getting a response. - Allows switching the LLM model and conditionally using the FAISS retriever. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# async def chat_with_rag( +# self, +# db: Session, +# session_id: int, +# prompt: str, +# model: str, +# load_faiss_retriever: bool = False +# ) -> Tuple[str, str]: +# """ +# Handles a message within a session, including saving history and getting a response. +# Allows switching the LLM model and conditionally using the FAISS retriever. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - if not session: - raise ValueError(f"Session with ID {session_id} not found.") +# if not session: +# raise ValueError(f"Session with ID {session_id} not found.") - user_message = models.Message(session_id=session_id, sender="user", content=prompt) - db.add(user_message) - db.commit() +# user_message = models.Message(session_id=session_id, sender="user", content=prompt) +# db.add(user_message) +# db.commit() - llm_provider = get_llm_provider(model) - dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) - dspy.configure(lm=dspy_llm) +# llm_provider = get_llm_provider(model) +# dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) +# dspy.configure(lm=dspy_llm) - current_retrievers = [] - if load_faiss_retriever: - if self.faiss_retriever: - current_retrievers.append(self.faiss_retriever) - else: - print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") +# current_retrievers = [] +# if load_faiss_retriever: +# if self.faiss_retriever: +# current_retrievers.append(self.faiss_retriever) +# else: +# print("Warning: FaissDBRetriever requested but not available. 
Proceeding without it.") - rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) +# rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) - answer_text = await rag_pipeline.forward( - question=prompt, - history=session.messages, - db=db - ) +# answer_text = await rag_pipeline.forward( +# question=prompt, +# history=session.messages, +# db=db +# ) - assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) - db.add(assistant_message) - db.commit() +# assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) +# db.add(assistant_message) +# db.commit() - return answer_text, model +# return answer_text, model - def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: - """ - Retrieves all messages for a given session, or None if the session doesn't exist. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() +# def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: +# """ +# Retrieves all messages for a given session, or None if the session doesn't exist. +# """ +# session = db.query(models.Session).options( +# joinedload(models.Session.messages) +# ).filter(models.Session.id == session_id).first() - return sorted(session.messages, key=lambda msg: msg.created_at) if session else None +# return sorted(session.messages, key=lambda msg: msg.created_at) if session else None - # --- Document Management (Updated) --- - def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - try: - document_db = models.Document(**doc_data) - db.add(document_db) - db.commit() - db.refresh(document_db) +# # --- Document Management (Updated) --- +# def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: +# try: +# document_db = models.Document(**doc_data) +# db.add(document_db) +# db.commit() +# db.refresh(document_db) - # Use the embedder provided to the vector store to get the correct model name - embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" +# # Use the embedder provided to the vector store to get the correct model name +# embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( - document_id=document_db.id, - faiss_index=faiss_index, - embedding_model=embedding_model_name - ) - db.add(vector_metadata) - db.commit() - return document_db.id - except SQLAlchemyError as e: - db.rollback() - raise +# faiss_index = self.vector_store.add_document(document_db.text) +# vector_metadata = models.VectorMetadata( +# document_id=document_db.id, +# faiss_index=faiss_index, +# embedding_model=embedding_model_name +# ) +# db.add(vector_metadata) +# db.commit() +# return document_db.id +# except SQLAlchemyError as e: +# db.rollback() +# raise - def get_all_documents(self, db: Session) -> List[models.Document]: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() +# def get_all_documents(self, db: Session) -> List[models.Document]: +# return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - def delete_document(self, db: Session, document_id: int) -> int: - try: - doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: - return None - 
db.delete(doc_to_delete) - db.commit() - return document_id - except SQLAlchemyError as e: - db.rollback() - raise +# def delete_document(self, db: Session, document_id: int) -> int: +# try: +# doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() +# if not doc_to_delete: +# return None +# db.delete(doc_to_delete) +# db.commit() +# return document_id +# except SQLAlchemyError as e: +# db.rollback() +# raise diff --git a/ai-hub/app/core/services/__init__.py b/ai-hub/app/core/services/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/core/services/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py new file mode 100644 index 0000000..3634f3e --- /dev/null +++ b/ai-hub/app/core/services/document.py @@ -0,0 +1,67 @@ +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.db import models + +class DocumentService: + """ + Service class for managing document lifecycle, including + adding, retrieving, and deleting documents and their vector metadata. + """ + def __init__(self, vector_store: FaissVectorStore): + self.vector_store = vector_store + self.embedder = self.vector_store.embedder + + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: + """ + Adds a new document to the database and its vector embedding to the FAISS index. + """ + try: + document_db = models.Document(**doc_data) + db.add(document_db) + db.commit() + db.refresh(document_db) + + embedding_model_name = "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder" + + faiss_index = self.vector_store.add_document(document_db.text) + vector_metadata = models.VectorMetadata( + document_id=document_db.id, + faiss_index=faiss_index, + embedding_model=embedding_model_name + ) + db.add(vector_metadata) + db.commit() + return document_db.id + except SQLAlchemyError as e: + db.rollback() + raise + + def get_all_documents(self, db: Session) -> List[models.Document]: + """ + Retrieves all documents from the database. + """ + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + + def delete_document(self, db: Session, document_id: int) -> int: + """ + Deletes a document and its associated vector metadata from the database. + """ + try: + doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() + if not doc_to_delete: + return None + + # Assuming you also need to delete the vector metadata associated with the document + # for a full cleanup. 
+ # db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() + + db.delete(doc_to_delete) + db.commit() + return document_id + except SQLAlchemyError as e: + db.rollback() + raise \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py new file mode 100644 index 0000000..344e779 --- /dev/null +++ b/ai-hub/app/core/services/rag.py @@ -0,0 +1,90 @@ +import asyncio +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload +from sqlalchemy.exc import SQLAlchemyError +import dspy + +from app.core.vector_store.faiss_store import FaissVectorStore +from app.db import models +from app.core.retrievers import Retriever, FaissDBRetriever +from app.core.llm_providers import get_llm_provider +from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline + +class RAGService: + """ + Service class for managing conversational RAG sessions. + This class orchestrates the RAG pipeline and manages chat sessions. + """ + def __init__(self, retrievers: List[Retriever]): + self.retrievers = retrievers + self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None) + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """Creates a new chat session in the database.""" + try: + new_session = models.Session(user_id=user_id, model_name=model, title=f"New Chat Session") + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag( + self, + db: Session, + session_id: int, + prompt: str, + model: str, + load_faiss_retriever: bool = False + ) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + llm_provider = get_llm_provider(model) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model) + dspy.configure(lm=dspy_llm) + + current_retrievers = [] + if load_faiss_retriever: + if self.faiss_retriever: + current_retrievers.append(self.faiss_retriever) + else: + print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") + + rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + + answer_text = await rag_pipeline.forward( + question=prompt, + history=session.messages, + db=db + ) + + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, model + + def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: + """ + Retrieves all messages for a given session. 
+ """ + session = db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py new file mode 100644 index 0000000..470780d --- /dev/null +++ b/ai-hub/tests/api/test_dependencies.py @@ -0,0 +1,112 @@ +# tests/api/test_dependencies.py +import pytest +import asyncio +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from fastapi import HTTPException + +# Import the dependencies and services to be tested +from app.api.dependencies import get_db, get_current_user, ServiceContainer +from app.core.services.document import DocumentService +from app.core.services.rag import RAGService +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.retrievers import Retriever + +@pytest.fixture +def mock_session(): + """ + Fixture that provides a mock SQLAlchemy session. + """ + mock = MagicMock(spec=Session) + yield mock + +# --- Tests for get_db dependency --- + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_yields_session_and_closes(mock_session_local, mock_session): + """ + Tests that get_db yields a database session and ensures it's closed correctly. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act: Use the generator in a context manager + db_generator = get_db() + db = next(db_generator) + + # Assert 1: The correct session object was yielded + assert db == mock_session + + # Act 2: Manually close the generator + with pytest.raises(StopIteration): + next(db_generator) + + # Assert 2: The session's close method was called + mock_session.close.assert_called_once() + +@patch('app.api.dependencies.SessionLocal') +def test_get_db_closes_on_exception(mock_session_local, mock_session): + """ + Tests that get_db still closes the session even if an exception occurs. + """ + # Arrange: Configure the mock SessionLocal to return our mock_session + mock_session_local.return_value = mock_session + + # Act & Assert: Call the generator and raise an exception + db_generator = get_db() + db = next(db_generator) + with pytest.raises(Exception): + db_generator.throw(Exception("Test exception")) + + # Assert: The session's close method was still called after the exception was handled + mock_session.close.assert_called_once() + + +# --- Tests for get_current_user dependency --- + +def test_get_current_user_with_valid_token(): + """ + Tests that get_current_user returns the expected user dictionary for a valid token. + """ + # Act + user = asyncio.run(get_current_user(token="valid_token")) + + # Assert + assert user == {"email": "user@example.com", "id": 1} + +def test_get_current_user_with_no_token(): + """ + Tests that get_current_user raises an HTTPException for a missing token. + """ + # Assert + with pytest.raises(HTTPException) as excinfo: + asyncio.run(get_current_user(token=None)) + + assert excinfo.value.status_code == 401 + assert "Unauthorized" in excinfo.value.detail + +# --- Tests for ServiceContainer class --- + +def test_service_container_initialization(): + """ + Tests that ServiceContainer initializes DocumentService and RAGService + with the correct dependencies. 
+ """ + # Arrange: Create mock dependencies + mock_vector_store = MagicMock(spec=FaissVectorStore) + # The DocumentService constructor needs a .embedder attribute on the vector_store + mock_vector_store.embedder = MagicMock() + mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + + # Act: Instantiate the ServiceContainer + container = ServiceContainer( + vector_store=mock_vector_store, + retrievers=mock_retrievers + ) + + # Assert: Check if the services were created and configured correctly + assert isinstance(container.document_service, DocumentService) + assert container.document_service.vector_store == mock_vector_store + + assert isinstance(container.rag_service, RAGService) + assert container.rag_service.retrievers == mock_retrievers diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index c58a6c7..8d841a2 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,3 +1,4 @@ +# tests/app/api/test_routes.py import pytest from unittest.mock import MagicMock, AsyncMock from fastapi import FastAPI @@ -6,26 +7,42 @@ from datetime import datetime # Import the dependencies and router factory -from app.core.services import RAGService -from app.api.dependencies import get_db +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService from app.api.routes import create_api_router from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """Pytest fixture to create a TestClient with a fully mocked environment.""" + """ + Pytest fixture to create a TestClient with a fully mocked environment, + including a mock ServiceContainer. + """ test_app = FastAPI() + + # Mock individual services mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + + # Create a mock ServiceContainer that holds the mocked services + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + + # Mock the database session mock_db_session = MagicMock(spec=Session) def override_get_db(): yield mock_db_session - api_router = create_api_router(rag_service=mock_rag_service) + # Pass the mock ServiceContainer to the router factory + api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) - yield TestClient(test_app), mock_rag_service + # Return the test client and the mock services for assertion + yield TestClient(test_app), mock_services # --- General Endpoint --- @@ -40,32 +57,33 @@ def test_create_session_success(client): """Tests successfully creating a new chat session.""" - test_client, mock_rag_service = client + test_client, mock_services = client mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) - mock_rag_service.create_session.return_value = mock_session + mock_services.rag_service.create_session.return_value = mock_session response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) assert response.status_code == 200 assert response.json()["id"] == 1 - mock_rag_service.create_session.assert_called_once() + mock_services.rag_service.create_session.assert_called_once() def test_chat_in_session_success(client): """ Tests sending a message in an existing session without specifying a 
model or retriever. It should default to 'deepseek' and 'False'. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the default model 'deepseek' # and the default load_faiss_retriever=False - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there", model="deepseek", @@ -76,16 +94,17 @@ """ Tests sending a message in an existing session and explicitly switching the model. """ - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + # Verify that chat_with_rag was called with the specified model 'gemini' - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", model="gemini", @@ -96,8 +115,8 @@ """ Tests sending a message and explicitly enabling the FAISS retriever. 
""" - test_client, mock_rag_service = client - mock_rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) response = test_client.post( "/sessions/42/chat", @@ -106,9 +125,10 @@ assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + # Verify that chat_with_rag was called with the correct parameters - mock_rag_service.chat_with_rag.assert_called_once_with( - db=mock_rag_service.chat_with_rag.call_args.kwargs['db'], + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="What is RAG?", model="deepseek", # The model still defaults to deepseek @@ -117,13 +137,13 @@ def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return a list of message objects mock_history = [ models.Message(sender="user", content="Hello", created_at=datetime.now()), models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] - mock_rag_service.get_message_history.return_value = mock_history + mock_services.rag_service.get_message_history.return_value = mock_history # Act response = test_client.get("/sessions/123/messages") @@ -135,13 +155,16 @@ assert len(response_data["messages"]) == 2 assert response_data["messages"][0]["sender"] == "user" assert response_data["messages"][1]["content"] == "Hi there!" - mock_rag_service.get_message_history.assert_called_once_with(db=mock_rag_service.get_message_history.call_args.kwargs['db'], session_id=123) + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) def test_get_session_messages_not_found(client): """Tests retrieving messages for a session that does not exist.""" - test_client, mock_rag_service = client + test_client, mock_services = client # Arrange: Mock the service to return None, indicating the session wasn't found - mock_rag_service.get_message_history.return_value = None + mock_services.rag_service.get_message_history.return_value = None # Act response = test_client.get("/sessions/999/messages") @@ -151,35 +174,39 @@ assert response.json()["detail"] == "Session with ID 999 not found." 
# --- Document Endpoints --- -# (These tests are unchanged) + def test_add_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.add_document.return_value = 123 + """Tests the /documents endpoint for adding a new document.""" + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" def test_get_documents_success(client): - test_client, mock_rag_service = client + """Tests the /documents endpoint for retrieving all documents.""" + test_client, mock_services = client mock_docs = [ models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] - mock_rag_service.get_all_documents.return_value = mock_docs + mock_services.document_service.get_all_documents.return_value = mock_docs response = test_client.get("/documents") assert response.status_code == 200 assert len(response.json()["documents"]) == 2 def test_delete_document_success(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = 42 + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 response = test_client.delete("/documents/42") assert response.status_code == 200 assert response.json()["document_id"] == 42 def test_delete_document_not_found(client): - test_client, mock_rag_service = client - mock_rag_service.delete_document.return_value = None + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") - assert response.status_code == 404 \ No newline at end of file + assert response.status_code == 404 diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py new file mode 100644 index 0000000..cdf8b44 --- /dev/null +++ b/ai-hub/tests/core/services/test_document.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import MagicMock, patch +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from datetime import datetime + +from app.core.services.document import DocumentService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder + +@pytest.fixture +def document_service(): + """ + Pytest fixture to create a DocumentService instance with mocked dependencies. + """ + mock_embedder = MagicMock(spec=MockEmbedder) + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_vector_store.embedder = mock_embedder + return DocumentService(vector_store=mock_vector_store) + +# --- add_document Tests --- + +def test_add_document_success(document_service: DocumentService): + """ + Tests that add_document successfully adds a document to the database + and its vector embedding to the FAISS index. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_document = MagicMock(id=1, text="Test text.") + mock_db.add.side_effect = [None, None] # Allow multiple calls + + # Configure the mock db.query to return a document object + mock_document_model_instance = models.Document( + id=1, + title="Test Title", + text="Test text.", + source_url="http://test.com" + ) + + with patch('app.core.services.document.models.Document', return_value=mock_document_model_instance) as mock_document_model, \ + patch('app.core.services.document.models.VectorMetadata') as mock_vector_metadata_model: + + document_service.vector_store.add_document.return_value = 123 + + doc_data = { + "title": "Test Title", + "text": "Test text.", + "source_url": "http://test.com" + } + + # Act + document_id = document_service.add_document(db=mock_db, doc_data=doc_data) + + # Assert + assert document_id == 1 + mock_db.add.assert_any_call(mock_document_model_instance) + mock_db.commit.assert_called() + mock_db.refresh.assert_called_with(mock_document_model_instance) + document_service.vector_store.add_document.assert_called_once_with("Test text.") + mock_vector_metadata_model.assert_called_once_with( + document_id=1, + faiss_index=123, + embedding_model="mock_embedder" + ) + mock_db.add.assert_any_call(mock_vector_metadata_model.return_value) + + +def test_add_document_sql_error(document_service: DocumentService): + """ + Tests that add_document correctly handles a SQLAlchemyError by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.add.side_effect = SQLAlchemyError("Database error") + doc_data = {"title": "Test", "text": "...", "source_url": "http://test.com"} + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Database error"): + document_service.add_document(db=mock_db, doc_data=doc_data) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() + +# --- get_all_documents Tests --- + +def test_get_all_documents_success(document_service: DocumentService): + """ + Tests that get_all_documents returns a list of documents. + """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_documents = [models.Document(id=1), models.Document(id=2)] + mock_db.query.return_value.order_by.return_value.all.return_value = mock_documents + + # Act + documents = document_service.get_all_documents(db=mock_db) + + # Assert + assert documents == mock_documents + mock_db.query.assert_called_once_with(models.Document) + mock_db.query.return_value.order_by.assert_called_once() + +# --- delete_document Tests --- + +def test_delete_document_success(document_service: DocumentService): + """ + Tests that delete_document correctly deletes a document. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id_to_delete = 1 + doc_to_delete = models.Document(id=doc_id_to_delete) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=doc_id_to_delete) + + # Assert + assert deleted_id == doc_id_to_delete + mock_db.query.assert_called_once_with(models.Document) + mock_db.delete.assert_called_once_with(doc_to_delete) + mock_db.commit.assert_called_once() + +def test_delete_document_not_found(document_service: DocumentService): + """ + Tests that delete_document returns None if the document is not found. 
+ """ + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act + deleted_id = document_service.delete_document(db=mock_db, document_id=999) + + # Assert + assert deleted_id is None + mock_db.delete.assert_not_called() + mock_db.commit.assert_not_called() + +def test_delete_document_sql_error(document_service: DocumentService): + """ + Tests that delete_document handles a SQLAlchemyError correctly by rolling back. + """ + # Arrange + mock_db = MagicMock(spec=Session) + doc_id = 1 + doc_to_delete = models.Document(id=doc_id) + mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + mock_db.delete.side_effect = SQLAlchemyError("Delete error") + + # Act & Assert + with pytest.raises(SQLAlchemyError, match="Delete error"): + document_service.delete_document(db=mock_db, document_id=doc_id) + + mock_db.rollback.assert_called_once() + mock_db.commit.assert_not_called() \ No newline at end of file diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py new file mode 100644 index 0000000..2fd4ab3 --- /dev/null +++ b/ai-hub/tests/core/services/test_rag.py @@ -0,0 +1,216 @@ +import pytest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError +from typing import List +from datetime import datetime +import dspy + +from app.core.services.rag import RAGService +from app.db import models +from app.core.vector_store.faiss_store import FaissVectorStore +from app.core.vector_store.embedder.mock import MockEmbedder +from app.core.retrievers import FaissDBRetriever, Retriever +from app.core.pipelines.dspy_rag import DspyRagPipeline +from app.core.llm_providers import LLMProvider + +@pytest.fixture +def rag_service(): + """ + Pytest fixture to create a RAGService instance with mocked dependencies. + It includes a mock FaissDBRetriever and a mock generic Retriever to test + conditional loading. + """ + # Create a mock vector store to provide a mock retriever + mock_vector_store = MagicMock(spec=FaissVectorStore) + + mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) + mock_web_retriever = MagicMock(spec=Retriever) + + return RAGService( + retrievers=[mock_web_retriever, mock_faiss_retriever] + ) + +# --- Session Management Tests --- + +def test_create_session(rag_service: RAGService): + """Tests that the create_session method correctly creates a new session.""" + mock_db = MagicMock(spec=Session) + + rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") + + mock_db.add.assert_called_once() + added_object = mock_db.add.call_args[0][0] + assert isinstance(added_object, models.Session) + assert added_object.user_id == "test_user" + assert added_object.model_name == "gemini" + +@patch('app.core.services.rag.get_llm_provider') +@patch('app.core.services.rag.DspyRagPipeline') +@patch('dspy.configure') +def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): + """ + Tests the full orchestration of a chat message within a session using the default model + and with the retriever loading parameter explicitly set to False. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=42, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=42, + prompt="Test prompt", + model="deepseek", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("deepseek") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response" + assert model_name == "deepseek" + +def test_chat_with_rag_model_switch(rag_service: RAGService): + """ + Tests that chat_with_rag correctly switches the model based on the 'model' argument, + while still using the default retriever setting. + """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=43, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=43, + prompt="Test prompt for Gemini", + model="gemini", + load_faiss_retriever=False + ) + ) + + # --- Assert --- + mock_db.query.assert_called_once_with(models.Session) + assert mock_db.add.call_count == 2 + mock_get_llm_provider.assert_called_once_with("gemini") + + mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt for Gemini", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Final RAG response from Gemini" + assert model_name == "gemini" + + +def test_chat_with_rag_with_faiss_retriever(rag_service: RAGService): + """ + Tests that the chat_with_rag method correctly initializes the DspyRagPipeline + with the FaissDBRetriever when `load_faiss_retriever` is True. 
+ """ + # --- Arrange --- + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=44, model_name="deepseek", messages=[]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ + patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ + patch('dspy.configure'): + + mock_llm_provider = MagicMock(spec=LLMProvider) + mock_get_llm_provider.return_value = mock_llm_provider + mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) + mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") + mock_dspy_pipeline.return_value = mock_pipeline_instance + + # --- Act --- + answer, model_name = asyncio.run( + rag_service.chat_with_rag( + db=mock_db, + session_id=44, + prompt="Test prompt with FAISS", + model="deepseek", + load_faiss_retriever=True + ) + ) + + # --- Assert --- + expected_retrievers = [rag_service.faiss_retriever] + mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) + + mock_pipeline_instance.forward.assert_called_once_with( + question="Test prompt with FAISS", + history=mock_session.messages, + db=mock_db + ) + + assert answer == "Response with FAISS context" + assert model_name == "deepseek" + + +def test_get_message_history_success(rag_service: RAGService): + """Tests successfully retrieving message history for an existing session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=1, messages=[ + models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), + models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) + ]) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=1) + + # Assert + assert len(messages) == 2 + assert messages[0].created_at < messages[1].created_at + mock_db.query.assert_called_once_with(models.Session) + +def test_get_message_history_not_found(rag_service: RAGService): + """Tests retrieving history for a non-existent session.""" + # Arrange + mock_db = MagicMock(spec=Session) + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None + + # Act + messages = rag_service.get_message_history(db=mock_db, session_id=999) + + # Assert + assert messages is None \ No newline at end of file diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py index d36c203..a59a6d4 100644 --- a/ai-hub/tests/core/test_services.py +++ b/ai-hub/tests/core/test_services.py @@ -1,320 +1,320 @@ -import pytest -import asyncio -from unittest.mock import patch, MagicMock, AsyncMock -from sqlalchemy.orm import Session -from sqlalchemy.exc import SQLAlchemyError -from typing import List -from datetime import datetime -import dspy +# import pytest +# import asyncio +# from unittest.mock import patch, MagicMock, AsyncMock +# from sqlalchemy.orm import Session +# from sqlalchemy.exc import SQLAlchemyError +# from typing import List +# from datetime import datetime +# import dspy -# Import the service and its dependencies -from app.core.services import RAGService -from app.db import models -from app.core.vector_store.faiss_store import FaissVectorStore -from app.core.vector_store.embedder.mock import MockEmbedder -# Import FaissDBRetriever and a mock WebRetriever for testing different cases -from app.core.retrievers import FaissDBRetriever, Retriever -from 
app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider -from app.core.llm_providers import LLMProvider +# # Import the service and its dependencies +# from app.core.services import RAGService +# from app.db import models +# from app.core.vector_store.faiss_store import FaissVectorStore +# from app.core.vector_store.embedder.mock import MockEmbedder +# # Import FaissDBRetriever and a mock WebRetriever for testing different cases +# from app.core.retrievers import FaissDBRetriever, Retriever +# from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider +# from app.core.llm_providers import LLMProvider -@pytest.fixture -def rag_service(): - """ - Pytest fixture to create a RAGService instance with mocked dependencies. - It includes a mock FaissDBRetriever and a mock generic Retriever to test - conditional loading. - """ - # Create a mock embedder to be attached to the vector store mock - mock_embedder = MagicMock(spec=MockEmbedder) - mock_vector_store = MagicMock(spec=FaissVectorStore) - mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute +# @pytest.fixture +# def rag_service(): +# """ +# Pytest fixture to create a RAGService instance with mocked dependencies. +# It includes a mock FaissDBRetriever and a mock generic Retriever to test +# conditional loading. +# """ +# # Create a mock embedder to be attached to the vector store mock +# mock_embedder = MagicMock(spec=MockEmbedder) +# mock_vector_store = MagicMock(spec=FaissVectorStore) +# mock_vector_store.embedder = mock_embedder # Explicitly set the embedder attribute - mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) - mock_web_retriever = MagicMock(spec=Retriever) - return RAGService( - vector_store=mock_vector_store, - retrievers=[mock_web_retriever, mock_faiss_retriever] - ) +# mock_faiss_retriever = MagicMock(spec=FaissDBRetriever) +# mock_web_retriever = MagicMock(spec=Retriever) +# return RAGService( +# vector_store=mock_vector_store, +# retrievers=[mock_web_retriever, mock_faiss_retriever] +# ) -# --- Session Management Tests --- +# # --- Session Management Tests --- -def test_create_session(rag_service: RAGService): - """Tests that the create_session method correctly creates a new session.""" - mock_db = MagicMock(spec=Session) +# def test_create_session(rag_service: RAGService): +# """Tests that the create_session method correctly creates a new session.""" +# mock_db = MagicMock(spec=Session) - rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") +# rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") - mock_db.add.assert_called_once() - added_object = mock_db.add.call_args[0][0] - assert isinstance(added_object, models.Session) - assert added_object.user_id == "test_user" - assert added_object.model_name == "gemini" +# mock_db.add.assert_called_once() +# added_object = mock_db.add.call_args[0][0] +# assert isinstance(added_object, models.Session) +# assert added_object.user_id == "test_user" +# assert added_object.model_name == "gemini" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests the full orchestration of a chat message within a session using the default model - and with the retriever loading parameter explicitly set to False. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=42, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests the full orchestration of a chat message within a session using the default model +# and with the retriever loading parameter explicitly set to False. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=42, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=42, - prompt="Test prompt", - model="deepseek", - load_faiss_retriever=False # Explicitly pass the default value - ) - ) +# # --- Act --- +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=42, +# prompt="Test prompt", +# model="deepseek", +# load_faiss_retriever=False # Explicitly pass the default value +# ) +# ) - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("deepseek") +# # --- Assert --- +# mock_db.query.assert_called_once_with(models.Session) +# assert mock_db.add.call_count == 2 +# mock_get_llm_provider.assert_called_once_with("deepseek") - # Assert that DspyRagPipeline was initialized with an empty list of retrievers - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) +# # Assert that DspyRagPipeline was initialized with an empty list of retrievers +# mock_dspy_pipeline.assert_called_once_with(retrievers=[]) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Final RAG response" - assert model_name == "deepseek" +# assert answer == "Final RAG response" +# assert model_name == "deepseek" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests that chat_with_rag correctly switches the model based on the 'model' argument, - while still using the default retriever setting. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=43, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_model_switch(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests that chat_with_rag correctly switches the model based on the 'model' argument, +# while still using the default retriever setting. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=43, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response from Gemini") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=43, - prompt="Test prompt for Gemini", - model="gemini", - load_faiss_retriever=False # Explicitly pass the default value - ) - ) +# # --- Act --- +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=43, +# prompt="Test prompt for Gemini", +# model="gemini", +# load_faiss_retriever=False # Explicitly pass the default value +# ) +# ) - # --- Assert --- - mock_db.query.assert_called_once_with(models.Session) - assert mock_db.add.call_count == 2 - mock_get_llm_provider.assert_called_once_with("gemini") +# # --- Assert --- +# mock_db.query.assert_called_once_with(models.Session) +# assert mock_db.add.call_count == 2 +# mock_get_llm_provider.assert_called_once_with("gemini") - # Assert that DspyRagPipeline was initialized with an empty list of retrievers - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) +# # Assert that DspyRagPipeline was initialized with an empty list of retrievers +# mock_dspy_pipeline.assert_called_once_with(retrievers=[]) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt for Gemini", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt for Gemini", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Final RAG response from Gemini" - assert model_name == "gemini" +# assert answer == "Final RAG response from Gemini" +# assert model_name == "gemini" -@patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') -@patch('dspy.configure') -def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): - """ - Tests that the chat_with_rag method correctly initializes the DspyRagPipeline - with the FaissDBRetriever when `load_faiss_retriever` is True. 
- """ - # --- Arrange --- - mock_db = MagicMock(spec=Session) - mock_session = models.Session(id=44, model_name="deepseek", messages=[]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# @patch('app.core.services.get_llm_provider') +# @patch('app.core.services.DspyRagPipeline') +# @patch('dspy.configure') +# def test_chat_with_rag_with_faiss_retriever(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): +# """ +# Tests that the chat_with_rag method correctly initializes the DspyRagPipeline +# with the FaissDBRetriever when `load_faiss_retriever` is True. +# """ +# # --- Arrange --- +# mock_db = MagicMock(spec=Session) +# mock_session = models.Session(id=44, model_name="deepseek", messages=[]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - mock_llm_provider = MagicMock(spec=LLMProvider) - mock_get_llm_provider.return_value = mock_llm_provider - mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) - mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") - mock_dspy_pipeline.return_value = mock_pipeline_instance +# mock_llm_provider = MagicMock(spec=LLMProvider) +# mock_get_llm_provider.return_value = mock_llm_provider +# mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) +# mock_pipeline_instance.forward = AsyncMock(return_value="Response with FAISS context") +# mock_dspy_pipeline.return_value = mock_pipeline_instance - # --- Act --- - # Explicitly enable the FAISS retriever - answer, model_name = asyncio.run( - rag_service.chat_with_rag( - db=mock_db, - session_id=44, - prompt="Test prompt with FAISS", - model="deepseek", - load_faiss_retriever=True - ) - ) +# # --- Act --- +# # Explicitly enable the FAISS retriever +# answer, model_name = asyncio.run( +# rag_service.chat_with_rag( +# db=mock_db, +# session_id=44, +# prompt="Test prompt with FAISS", +# model="deepseek", +# load_faiss_retriever=True +# ) +# ) - # --- Assert --- - # The crucial part is to verify that the pipeline was called with the correct retriever - expected_retrievers = [rag_service.faiss_retriever] - mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) +# # --- Assert --- +# # The crucial part is to verify that the pipeline was called with the correct retriever +# expected_retrievers = [rag_service.faiss_retriever] +# mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) - mock_pipeline_instance.forward.assert_called_once_with( - question="Test prompt with FAISS", - history=mock_session.messages, - db=mock_db - ) +# mock_pipeline_instance.forward.assert_called_once_with( +# question="Test prompt with FAISS", +# history=mock_session.messages, +# db=mock_db +# ) - assert answer == "Response with FAISS context" - assert model_name == "deepseek" +# assert answer == "Response with FAISS context" +# assert model_name == "deepseek" -def test_get_message_history_success(rag_service: RAGService): - """Tests successfully retrieving message history for an existing session.""" - # Arrange - mock_db = MagicMock(spec=Session) - # Ensure mocked messages have created_at for sorting - mock_session = models.Session(id=1, messages=[ - models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), - models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) - ]) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session +# def test_get_message_history_success(rag_service: 
RAGService): +# """Tests successfully retrieving message history for an existing session.""" +# # Arrange +# mock_db = MagicMock(spec=Session) +# # Ensure mocked messages have created_at for sorting +# mock_session = models.Session(id=1, messages=[ +# models.Message(created_at=datetime(2023, 1, 1, 10, 0, 0)), +# models.Message(created_at=datetime(2023, 1, 1, 10, 1, 0)) +# ]) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session - # Act - messages = rag_service.get_message_history(db=mock_db, session_id=1) +# # Act +# messages = rag_service.get_message_history(db=mock_db, session_id=1) - # Assert - assert len(messages) == 2 - assert messages[0].created_at < messages[1].created_at # Verify sorting - mock_db.query.assert_called_once_with(models.Session) +# # Assert +# assert len(messages) == 2 +# assert messages[0].created_at < messages[1].created_at # Verify sorting +# mock_db.query.assert_called_once_with(models.Session) -def test_get_message_history_not_found(rag_service: RAGService): - """Tests retrieving history for a non-existent session.""" - # Arrange - mock_db = MagicMock(spec=Session) - mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None +# def test_get_message_history_not_found(rag_service: RAGService): +# """Tests retrieving history for a non-existent session.""" +# # Arrange +# mock_db = MagicMock(spec=Session) +# mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None - # Act - messages = rag_service.get_message_history(db=mock_db, session_id=999) +# # Act +# messages = rag_service.get_message_history(db=mock_db, session_id=999) - # Assert - assert messages is None +# # Assert +# assert messages is None -# --- Document Management Tests --- +# # --- Document Management Tests --- -@patch('app.db.models.VectorMetadata') -@patch('app.db.models.Document') -@patch('app.core.vector_store.faiss_store.FaissVectorStore') -def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): - """ - Test the RAGService.add_document method for a successful run. - Verifies that the method correctly calls db.add(), db.commit(), and the vector store. - """ - # Setup mocks - mock_db = MagicMock(spec=Session) - mock_new_document_instance = MagicMock() - mock_document_model.return_value = mock_new_document_instance - mock_new_document_instance.id = 1 - mock_new_document_instance.text = "Test text." - mock_new_document_instance.title = "Test Title" +# @patch('app.db.models.VectorMetadata') +# @patch('app.db.models.Document') +# @patch('app.core.vector_store.faiss_store.FaissVectorStore') +# def test_rag_service_add_document_success(mock_vector_store, mock_document_model, mock_vector_metadata_model): +# """ +# Test the RAGService.add_document method for a successful run. +# Verifies that the method correctly calls db.add(), db.commit(), and the vector store. +# """ +# # Setup mocks +# mock_db = MagicMock(spec=Session) +# mock_new_document_instance = MagicMock() +# mock_document_model.return_value = mock_new_document_instance +# mock_new_document_instance.id = 1 +# mock_new_document_instance.text = "Test text." 
+# mock_new_document_instance.title = "Test Title" - mock_vector_store_instance = mock_vector_store.return_value - # Fix: Manually set the embedder on the mock vector store instance - mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) - mock_vector_store_instance.add_document.return_value = 123 +# mock_vector_store_instance = mock_vector_store.return_value +# # Fix: Manually set the embedder on the mock vector store instance +# mock_vector_store_instance.embedder = MagicMock(spec=MockEmbedder) +# mock_vector_store_instance.add_document.return_value = 123 - # Instantiate the service correctly - rag_service = RAGService( - vector_store=mock_vector_store_instance, - retrievers=[] - ) +# # Instantiate the service correctly +# rag_service = RAGService( +# vector_store=mock_vector_store_instance, +# retrievers=[] +# ) - doc_data = { - "title": "Test Title", - "text": "Test text.", - "source_url": "http://test.com" - } +# doc_data = { +# "title": "Test Title", +# "text": "Test text.", +# "source_url": "http://test.com" +# } - # Call the method under test - document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) +# # Call the method under test +# document_id = rag_service.add_document(db=mock_db, doc_data=doc_data) - # Assertions - assert document_id == 1 +# # Assertions +# assert document_id == 1 - from unittest.mock import call - expected_calls = [ - call(mock_new_document_instance), - call(mock_vector_metadata_model.return_value) - ] - mock_db.add.assert_has_calls(expected_calls) +# from unittest.mock import call +# expected_calls = [ +# call(mock_new_document_instance), +# call(mock_vector_metadata_model.return_value) +# ] +# mock_db.add.assert_has_calls(expected_calls) - mock_db.commit.assert_called() - mock_db.refresh.assert_called_with(mock_new_document_instance) - mock_vector_store_instance.add_document.assert_called_once_with("Test text.") +# mock_db.commit.assert_called() +# mock_db.refresh.assert_called_with(mock_new_document_instance) +# mock_vector_store_instance.add_document.assert_called_once_with("Test text.") - # Assert that VectorMetadata was instantiated with the correct arguments - mock_vector_metadata_model.assert_called_once_with( - document_id=mock_new_document_instance.id, - faiss_index=mock_vector_store_instance.add_document.return_value, - embedding_model="mock_embedder" # This now passes because the mock embedder is of type MockEmbedder - ) +# # Assert that VectorMetadata was instantiated with the correct arguments +# mock_vector_metadata_model.assert_called_once_with( +# document_id=mock_new_document_instance.id, +# faiss_index=mock_vector_store_instance.add_document.return_value, +# embedding_model="mock_embedder" # This now passes because the mock embedder is of type MockEmbedder +# ) -@patch('app.core.vector_store.faiss_store.FaissVectorStore') -def test_rag_service_add_document_error_handling(mock_vector_store): - """ - Test the RAGService.add_document method's error handling. - Verifies that the transaction is rolled back on an exception. - """ - # Setup mocks - mock_db = MagicMock(spec=Session) +# @patch('app.core.vector_store.faiss_store.FaissVectorStore') +# def test_rag_service_add_document_error_handling(mock_vector_store): +# """ +# Test the RAGService.add_document method's error handling. +# Verifies that the transaction is rolled back on an exception. +# """ +# # Setup mocks +# mock_db = MagicMock(spec=Session) - # Configure the mock db.add to raise the specific SQLAlchemyError. 
- mock_db.add.side_effect = SQLAlchemyError("Database error") +# # Configure the mock db.add to raise the specific SQLAlchemyError. +# mock_db.add.side_effect = SQLAlchemyError("Database error") - mock_vector_store_instance = mock_vector_store.return_value +# mock_vector_store_instance = mock_vector_store.return_value - # Instantiate the service correctly - rag_service = RAGService( - vector_store=mock_vector_store_instance, - retrievers=[] - ) +# # Instantiate the service correctly +# rag_service = RAGService( +# vector_store=mock_vector_store_instance, +# retrievers=[] +# ) - doc_data = { - "title": "Test Title", - "text": "Test text.", - "source_url": "http://test.com" - } +# doc_data = { +# "title": "Test Title", +# "text": "Test text.", +# "source_url": "http://test.com" +# } - # Call the method under test and expect an exception - with pytest.raises(SQLAlchemyError, match="Database error"): - rag_service.add_document(db=mock_db, doc_data=doc_data) +# # Call the method under test and expect an exception +# with pytest.raises(SQLAlchemyError, match="Database error"): +# rag_service.add_document(db=mock_db, doc_data=doc_data) - # Assertions - mock_db.add.assert_called_once() - mock_db.commit.assert_not_called() - mock_db.rollback.assert_called_once() +# # Assertions +# mock_db.add.assert_called_once() +# mock_db.commit.assert_not_called() +# mock_db.rollback.assert_called_once() diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 1574b46..1d7a96d 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -1,14 +1,17 @@ +# tests/app/test_app.py import os +import asyncio from fastapi.testclient import TestClient -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session from datetime import datetime import numpy as np -# Import the factory function directly to get a fresh app instance for testing +# Import the factory function directly and dependencies from app.app import create_app -from app.api.dependencies import get_db +from app.api.dependencies import get_db, ServiceContainer from app.db import models +from app.core.retrievers import Retriever # Define a constant for the dimension to ensure consistency TEST_DIMENSION = 768 @@ -24,357 +27,310 @@ finally: pass -# --- API Endpoint Tests --- -# We patch the RAGService class itself, as the instance is created inside create_app(). - -def test_read_root(): +# We patch ServiceContainer directly to control its instantiation in create_app +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.app.FaissVectorStore.save_index') +@patch('app.app.print_config') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') # This patch is for the FaissVectorStore initialization +def test_read_root(mock_read_index, mock_get_embedder, mock_print_config, mock_save_index, mock_create_db, mock_service_container): """Test the root endpoint to ensure it's running.""" - # Patch the requests.post call for the GenAIEmbedder to avoid network calls during app creation. - # Also patch faiss.read_index to prevent file system errors. 
- with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Arrange: We patch the embedder and faiss calls to prevent real logic + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - app = create_app() - client = TestClient(app) - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} + # The mock_service_container is a mock of the ServiceContainer class. + # We create an instance of it (mock_services) and configure it. + mock_services = MagicMock() + mock_service_container.return_value = mock_services -@patch('app.app.RAGService') -def test_create_session_success(mock_rag_service_class): + app = create_app() + client = TestClient(app) + response = client.get("/") + + # Assert + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + + +# We patch ServiceContainer directly to control its instantiation in create_app +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_create_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Tests successfully creating a new chat session via the POST /sessions endpoint. """ - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response - - # Arrange - mock_rag_service_instance = mock_rag_service_class.return_value - mock_session_obj = models.Session( - id=1, - user_id="test_user", - model_name="gemini", - title="New Chat Session", - created_at=datetime.now() - ) - mock_rag_service_instance.create_session.return_value = mock_session_obj - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) - - # Act - response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - # Assert - assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == 1 - assert response_data["user_id"] == "test_user" - mock_rag_service_instance.create_session.assert_called_once_with( - db=mock_db, user_id="test_user", model="gemini" - ) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services -@patch('app.app.RAGService') -def test_chat_in_session_success(mock_rag_service_class): + # Configure the mock rag_service to return a mocked session object + mock_session_obj = models.Session( + id=1, + user_id="test_user", + model_name="gemini", + title="New Chat Session", + created_at=datetime.now() + ) + mock_services.rag_service.create_session.return_value = mock_session_obj + + app = create_app() + app.dependency_overrides[get_db] = override_get_db 
+ client = TestClient(app) + + # Act + response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + + # Assert + assert response.status_code == 200 + response_data = response.json() + assert response_data["id"] == 1 + assert response_data["user_id"] == "test_user" + mock_services.rag_service.create_session.assert_called_once_with( + db=mock_db, user_id="test_user", model="gemini" + ) + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_chat_in_session_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Test the session-based chat endpoint with a successful, mocked response. It should default to 'deepseek' if no model is specified. """ - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() + + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services - # Arrange - mock_rag_service_instance = mock_rag_service_class.return_value - # Mock the async method correctly using a mock async function - async def mock_chat_with_rag(*args, **kwargs): - return "This is a mock response.", "deepseek" - mock_rag_service_instance.chat_with_rag = MagicMock(side_effect=mock_chat_with_rag) - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) - - # Act - response = client.post("/sessions/123/chat", json={"prompt": "Hello there"}) + # Correctly mock the async method using AsyncMock + mock_chat_with_rag = AsyncMock(return_value=("This is a mock response.", "deepseek")) + mock_services.rag_service.chat_with_rag = mock_chat_with_rag - # Assert - assert response.status_code == 200 - assert response.json()["answer"] == "This is a mock response." - assert response.json()["model_used"] == "deepseek" - mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, session_id=123, prompt="Hello there", model="deepseek", load_faiss_retriever=False - ) + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) -@patch('app.app.RAGService') -def test_chat_in_session_with_model_switch(mock_rag_service_class): + # Act + response = client.post("/sessions/123/chat", json={"prompt": "Hello there"}) + + # Assert + assert response.status_code == 200 + assert response.json()["answer"] == "This is a mock response." + assert response.json()["model_used"] == "deepseek" + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_db, session_id=123, prompt="Hello there", model="deepseek", load_faiss_retriever=False + ) + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_chat_in_session_with_model_switch(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Tests sending a message in an existing session and explicitly switching the model. 
""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response - - mock_rag_service_instance = mock_rag_service_class.return_value - # Mock the async method correctly using a mock async function - async def mock_chat_with_rag(*args, **kwargs): - return "Mocked response from Gemini", "gemini" - mock_rag_service_instance.chat_with_rag = MagicMock(side_effect=mock_chat_with_rag) - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - - assert response.status_code == 200 - assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} - mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, - session_id=42, - prompt="Hello there, Gemini!", - model="gemini", - load_faiss_retriever=False - ) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services -@patch('app.app.RAGService') -def test_get_session_messages_success(mock_rag_service_class): - """Tests retrieving the message history for a session.""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Correctly mock the async method using AsyncMock + mock_chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + mock_services.rag_service.chat_with_rag = mock_chat_with_rag - mock_rag_service_instance = mock_rag_service_class.return_value - mock_history = [ - models.Message(sender="user", content="Hello", created_at=datetime.now()), - models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) - ] - mock_rag_service_instance.get_message_history.return_value = mock_history - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) - - # Act - response = client.get("/sessions/123/messages") - - # Assert - assert response.status_code == 200 - response_data = response.json() - assert response_data["session_id"] == 123 - assert len(response_data["messages"]) == 2 - assert response_data["messages"][0]["sender"] == "user" - assert response_data["messages"][1]["content"] == "Hi there!" 
- mock_rag_service_instance.get_message_history.assert_called_once_with(db=mock_db, session_id=123) + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) -@patch('app.app.RAGService') -def test_get_session_messages_not_found(mock_rag_service_class): - """Tests retrieving messages for a session that does not exist.""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + response = client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.get_message_history.return_value = None - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Assert + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_db, + session_id=42, + prompt="Hello there, Gemini!", + model="gemini", + load_faiss_retriever=False + ) - # Act - response = client.get("/sessions/999/messages") - - # Assert - assert response.status_code == 404 - assert response.json()["detail"] == "Session with ID 999 not found." - mock_rag_service_instance.get_message_history.assert_called_once_with(db=mock_db, session_id=999) - -@patch('app.app.RAGService') -def test_add_document_success(mock_rag_service_class): +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_add_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Test the /document endpoint with a successful, mocked RAG service response. 
""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.add_document.return_value = 1 - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services + mock_services.document_service.add_document.return_value = 1 - doc_data = { - "title": "Test Document", - "text": "This is a test document.", - "source_url": "http://example.com/test" - } - - response = client.post("/documents", json=doc_data) - - assert response.status_code == 200 - assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1" - - expected_doc_data = doc_data.copy() - expected_doc_data.update({"author": None, "user_id": "default_user"}) - mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } + + # Act + response = client.post("/documents", json=doc_data) + + # Assert + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1" + + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_services.document_service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) -@patch('app.app.RAGService') -def test_add_document_api_failure(mock_rag_service_class): +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_add_document_api_failure(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Test the /document endpoint when the RAG service encounters an error. 
""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.add_document.side_effect = Exception("Service failed") - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services + mock_services.document_service.add_document.side_effect = Exception("Service failed") - doc_data = { - "title": "Test Document", - "text": "This is a test document.", - "source_url": "http://example.com/test" - } - - response = client.post("/documents", json=doc_data) - - assert response.status_code == 500 - assert "An error occurred: Service failed" in response.json()["detail"] + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) - expected_doc_data = doc_data.copy() - expected_doc_data.update({"author": None, "user_id": "default_user"}) - mock_rag_service_instance.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + doc_data = { + "title": "Test Document", + "text": "This is a test document.", + "source_url": "http://example.com/test" + } -@patch('app.app.RAGService') -def test_get_documents_success(mock_rag_service_class): + # Act + response = client.post("/documents", json=doc_data) + + # Assert + assert response.status_code == 500 + assert "An error occurred: Service failed" in response.json()["detail"] + + expected_doc_data = doc_data.copy() + expected_doc_data.update({"author": None, "user_id": "default_user"}) + mock_services.document_service.add_document.assert_called_once_with(db=mock_db, doc_data=expected_doc_data) + + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_get_documents_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Tests the /documents endpoint for successful retrieval of documents. 
""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - mock_rag_service_instance = mock_rag_service_class.return_value - mock_docs = [ - models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), - models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) - ] - mock_rag_service_instance.get_all_documents.return_value = mock_docs - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services - response = client.get("/documents") - assert response.status_code == 200 - assert len(response.json()["documents"]) == 2 - assert response.json()["documents"][0]["title"] == "Doc One" - mock_rag_service_instance.get_all_documents.assert_called_once_with(db=mock_db) + mock_docs = [ + models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), + models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) + ] + mock_services.document_service.get_all_documents.return_value = mock_docs -@patch('app.app.RAGService') -def test_delete_document_success(mock_rag_service_class): + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Act + response = client.get("/documents") + assert response.status_code == 200 + assert len(response.json()["documents"]) == 2 + assert response.json()["documents"][0]["title"] == "Doc One" + mock_services.document_service.get_all_documents.assert_called_once_with(db=mock_db) + + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_delete_document_success(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Tests the DELETE /documents/{document_id} endpoint for successful deletion. 
""" - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response - - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.delete_document.return_value = 42 - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - response = client.delete("/documents/42") - assert response.status_code == 200 - assert response.json()["message"] == "Document deleted successfully" - assert response.json()["document_id"] == 42 - mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=42) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services + mock_services.document_service.delete_document.return_value = 42 -@patch('app.app.RAGService') -def test_delete_document_not_found(mock_rag_service_class): + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["message"] == "Document deleted successfully" + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=42) + +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_delete_document_not_found(mock_read_index, mock_get_embedder, mock_create_db, mock_service_container): """ Tests the DELETE /documents/{document_id} endpoint when the document is not found. """ - with patch('requests.post') as mock_post, patch('faiss.read_index') as mock_read_index: - mock_read_index.return_value = MagicMock() - mock_response = MagicMock() - mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - "embedding": {"values": np.random.rand(TEST_DIMENSION).tolist()} - } - mock_post.return_value = mock_response - - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.delete_document.return_value = None - - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() - response = client.delete("/documents/999") - assert response.status_code == 404 - assert response.json()["detail"] == "Document with ID 999 not found." - mock_rag_service_instance.delete_document.assert_called_once_with(db=mock_db, document_id=999) + # Create a mock instance of ServiceContainer and its services + mock_services = MagicMock() + mock_service_container.return_value = mock_services + mock_services.document_service.delete_document.return_value = None + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + response = client.delete("/documents/999") + assert response.status_code == 404 + assert response.json()["detail"] == "Document with ID 999 not found." 
+ mock_services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=999)
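Note on the wiring these tests assume: the rewritten tests in tests/test_app.py patch app.app.ServiceContainer, app.app.create_db_and_tables, app.app.print_config, app.core.vector_store.embedder.factory.get_embedder_from_config, and faiss.read_index, which implies that create_app() now builds a single ServiceContainer at startup and hands it to the router factory instead of constructing a RAGService directly. The new app/app.py is not part of this diff, so the sketch below is only an orientation aid; the constructor signatures of FaissVectorStore and FaissDBRetriever, the argument list of get_embedder_from_config, the shutdown hook, and the import locations of print_config and create_db_and_tables are assumptions, not facts from the patch.

# Hypothetical sketch of app/app.py after the ServiceContainer refactor.
# Only the names patched by the tests are taken from this diff; everything
# else (constructor signatures, import paths, lifecycle hooks) is assumed.
from fastapi import FastAPI

from app.api.dependencies import ServiceContainer
from app.api.routes import create_api_router
from app.config import print_config                      # assumed location
from app.core.retrievers import FaissDBRetriever
from app.core.vector_store.embedder.factory import get_embedder_from_config
from app.core.vector_store.faiss_store import FaissVectorStore
from app.db.session import create_db_and_tables          # assumed location


def create_app() -> FastAPI:
    print_config()
    create_db_and_tables()

    # Build the embedder and vector store once; faiss.read_index is invoked
    # inside FaissVectorStore, which is why the tests patch it directly.
    embedder = get_embedder_from_config()                 # signature assumed
    vector_store = FaissVectorStore(embedder=embedder)    # signature assumed

    # One container owns every service; the router receives it as a single object.
    services = ServiceContainer(
        vector_store=vector_store,
        retrievers=[FaissDBRetriever(vector_store=vector_store)],  # assumed
    )

    app = FastAPI(title="AI Model Hub")
    app.include_router(create_api_router(services=services))

    @app.get("/")
    def read_root():
        return {"status": "AI Model Hub is running!"}

    # The tests also patch FaissVectorStore.save_index, suggesting the index is
    # persisted on shutdown; the exact hook is not shown in this diff.
    @app.on_event("shutdown")
    def persist_index():
        vector_store.save_index()                         # assumed call site

    return app

With this layout, every endpoint test needs exactly one patch target for the services (app.app.ServiceContainer) instead of one patch per service class, which is what the rewritten tests above rely on.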