diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py
index e1f2be7..e1aecde 100644
--- a/ai-hub/app/api/dependencies.py
+++ b/ai-hub/app/api/dependencies.py
@@ -6,8 +6,10 @@
 from app.core.retrievers.base_retriever import Retriever
 from app.core.services.document import DocumentService
 from app.core.services.rag import RAGService
+from app.core.services.tts import TTSService
 from app.core.vector_store.faiss_store import FaissVectorStore
 
+
 # This is a dependency
 def get_db():
     db = SessionLocal()
@@ -25,9 +27,10 @@ class ServiceContainer:
-    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
+    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService):
         # Initialize all services within the container
         self.document_service = DocumentService(vector_store=vector_store)
         self.rag_service = RAGService(
             retrievers=retrievers
-        )
\ No newline at end of file
+        )
+        self.tts_service = tts_service
diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index 6d216ec..40860aa 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -1,8 +1,11 @@
 from fastapi import APIRouter, HTTPException, Depends
+from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session
-from app.api.dependencies import ServiceContainer
+from app.api.dependencies import ServiceContainer
 from app.api.dependencies import get_db
 from app.api import schemas
+from starlette.concurrency import run_in_threadpool
+
 
 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -15,15 +18,12 @@
         return {"status": "AI Model Hub is running!"}
 
     # --- Session Management Endpoints ---
-    
+
     @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"])
     def create_session(
         request: schemas.SessionCreate,
         db: Session = Depends(get_db)
     ):
-        """
-        Starts a new conversation session and returns its details.
-        """
        try:
            new_session = services.rag_service.create_session(
                db=db,
@@ -34,40 +34,29 @@
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Failed to create session: {e}")
 
-    # --- Session Management Endpoints ---
     @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"])
     async def chat_in_session(
         session_id: int,
         request: schemas.ChatRequest,
         db: Session = Depends(get_db)
     ):
-        """
-        Sends a message within an existing session and gets a contextual response.
-
-        The 'model' and 'load_faiss_retriever' can now be specified in the request body.
-        """
         try:
             response_text, model_used = await services.rag_service.chat_with_rag(
                 db=db,
                 session_id=session_id,
                 prompt=request.prompt,
                 model=request.model,
-                load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service
+                load_faiss_retriever=request.load_faiss_retriever
             )
             return schemas.ChatResponse(answer=response_text, model_used=model_used)
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
-
     @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"])
     def get_session_messages(session_id: int, db: Session = Depends(get_db)):
-        """
-        Retrieves the full message history for a specific session.
-        """
         try:
-            # Note: You'll need to add a `get_message_history` method to your RAGService.
             messages = services.rag_service.get_message_history(db=db, session_id=session_id)
-            if messages is None: # Service can return None if session not found
+            if messages is None:
                 raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
 
             return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
@@ -77,7 +66,6 @@
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
     # --- Document Management Endpoints ---
-    # (These endpoints remain unchanged)
 
     @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"])
     def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
@@ -114,4 +102,26 @@
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
-    return router
+    # --- TTS Endpoint ---
+    @router.post(
+        "/speech",
+        summary="Generate a speech stream from text",
+        tags=["TTS"],
+        response_description="A stream of audio bytes in WAV format",
+    )
+    async def create_speech_stream(request: schemas.SpeechRequest):
+        """
+        Generates an audio stream from the provided text using the TTS service.
+        """
+        try:
+            # Run the blocking stream setup in a worker thread; StreamingResponse
+            # then iterates the resulting synchronous generator off the event loop.
+            audio_stream = await run_in_threadpool(
+                services.tts_service.create_speech_stream, text=request.text
+            )
+            return StreamingResponse(audio_stream, media_type="audio/wav")
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to generate speech: {e}"
+            )
+
+    return router
\ No newline at end of file
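For context on the pattern in the new /speech route: run_in_threadpool offloads the (potentially blocking) creation of the audio stream to a worker thread, and StreamingResponse iterates the returned synchronous generator through Starlette's own threadpool. A minimal standalone sketch of the same pattern, illustrative only and not part of this diff:

import time
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.concurrency import run_in_threadpool

app = FastAPI()

def open_audio_stream(text: str):
    # Stand-in for blocking setup work, e.g. connecting to a TTS backend.
    time.sleep(0.05)
    def chunks():
        for word in text.split():
            yield word.encode() + b" "
    return chunks()

@app.post("/demo-stream")
async def demo_stream(text: str):
    # Run the blocking factory off the event loop, then stream its generator.
    gen = await run_in_threadpool(open_audio_stream, text)
    return StreamingResponse(gen, media_type="application/octet-stream")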
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. 
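With SpeechRequest in place, the endpoint can be exercised from any HTTP client. A sketch using requests, assuming the hub is served locally on port 8000:

import requests

resp = requests.post(
    "http://localhost:8000/speech",
    json={"text": "Hello from the AI hub."},
    stream=True,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    for chunk in resp.iter_content(chunk_size=4096):
        f.write(chunk)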
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. 
Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. 
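The wiring above reads settings.TTS_PROVIDER and settings.TTS_API_KEY, but the corresponding config change is not part of this diff. A plausible sketch of the fields the code assumes — names are inferred from usage, and the defaults here are hypothetical:

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # ... existing fields ...
    TTS_PROVIDER: str = "google_genai"  # must match a name known to get_tts_provider()
    TTS_API_KEY: str = ""               # typically supplied via an environment variable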
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. 
Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. - """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. 
- """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. + audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. 
Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. 
Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. 
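GeminiTTSProvider and the TTSProvider base class are imported here but their definitions are not shown in this diff. A minimal sketch of the contract the rest of the code relies on — the method name generate_speech is taken from TTSService, everything else is an assumption:

from abc import ABC, abstractmethod
from typing import Generator

class TTSProvider(ABC):
    """Contract assumed by TTSService: yield chunks of WAV audio for the text."""
    @abstractmethod
    def generate_speech(self, text: str) -> Generator[bytes, None, None]:
        ...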
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. 
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. 
+ """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. 
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. + """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
+    doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}
+    add_response = await http_client.post("/documents", json=doc_data)
+    assert add_response.status_code == 200
+    try:
+        message = add_response.json().get("message", "")
+        document_id = int(message.split(" with ID ")[-1])
+    except (ValueError, IndexError):
+        pytest.fail("Could not parse document ID from response message.")
+    print(f"✅ Document for lifecycle test created with ID: {document_id}")
+
+    # 2. List all documents and check if the new document is present
+    list_response = await http_client.get("/documents")
+    assert list_response.status_code == 200
+    ids_in_response = {doc["id"] for doc in list_response.json()["documents"]}
+    assert document_id in ids_in_response
+    print("✅ Document list test passed.")
+
+    # 3. Delete the document
+    delete_response = await http_client.delete(f"/documents/{document_id}")
+    assert delete_response.status_code == 200
+    assert delete_response.json()["document_id"] == document_id
+    print("✅ Document delete test passed.")
\ No newline at end of file
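
Both this test and conftest.py recover the document ID with split(" with ID "), which couples the tests to the exact wording of the response message. A more defensive parse is sketched below; the helper is hypothetical and assumes the message ends with the numeric ID. Returning the ID as a structured field on DocumentResponse would remove the parsing altogether:

import re

def parse_document_id(message: str) -> int:
    """Extract a trailing integer ID from a status message,
    e.g. "Document 'X' added with ID 42" -> 42."""
    match = re.search(r"(\d+)\s*$", message)
    if match is None:
        raise ValueError(f"No document ID found in message: {message!r}")
    return int(match.group(1))
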
- """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. - """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. + """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
new file mode 100644
index 0000000..3ff7925
--- /dev/null
+++ b/ai-hub/integration_tests/test_misc.py
@@ -0,0 +1,36 @@
+import pytest
+import httpx
+
+@pytest.mark.asyncio
+async def test_root_endpoint(http_client):
+    """
+    Tests if the root endpoint is alive and returns the correct status message.
+    """
+    print("\n--- Running test_root_endpoint ---")
+    response = await http_client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"status": "AI Model Hub is running!"}
+    print("✅ Root endpoint test passed.")
+
+# @pytest.mark.asyncio
+# async def test_create_speech_stream(http_client):
+#     """
+#     Tests the /speech endpoint for a successful audio stream response.
+#     """
+#     print("\n--- Running test_create_speech_stream ---")
+#     url = "/speech"
+#     payload = {"text": "Hello, world!"}
+
+#     # http_client.stream() tells httpx not to read the entire response
+#     # body at once; we handle it manually to check for content.
+#     async with http_client.stream("POST", url, json=payload) as response:
+#         assert response.status_code == 200, f"Speech stream request failed with status {response.status_code}"
+#         assert response.headers.get("content-type") == "audio/wav"
+
+#         # Check that the response body is not empty by iterating over chunks.
+#         content_length = 0
+#         async for chunk in response.aiter_bytes():
+#             content_length += len(chunk)
+
+#         assert content_length > 0
+#     print("✅ TTS stream test passed.")
\ No newline at end of file
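
Since the stream test above is commented out, here is a standalone sketch for exercising the endpoint by hand. It assumes the hub is running at the BASE_URL from conftest.py and that /speech streams audio/wav as declared in routes.py:

import asyncio
import httpx

async def fetch_speech(text: str, out_path: str = "speech.wav") -> None:
    # Stream the response body to disk chunk by chunk instead of
    # buffering the whole WAV file in memory.
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000", timeout=60.0) as client:
        async with client.stream("POST", "/speech", json={"text": text}) as response:
            response.raise_for_status()
            with open(out_path, "wb") as f:
                async for chunk in response.aiter_bytes():
                    f.write(chunk)

if __name__ == "__main__":
    asyncio.run(fetch_speech("Hello, world!"))
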
- """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. + """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
+    doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}
+    add_response = await http_client.post("/documents", json=doc_data)
+    assert add_response.status_code == 200
+    try:
+        message = add_response.json().get("message", "")
+        document_id = int(message.split(" with ID ")[-1])
+    except (ValueError, IndexError):
+        pytest.fail("Could not parse document ID from response message.")
+    print(f"✅ Document for lifecycle test created with ID: {document_id}")
+
+    # 2. List all documents and check if the new document is present
+    list_response = await http_client.get("/documents")
+    assert list_response.status_code == 200
+    ids_in_response = {doc["id"] for doc in list_response.json()["documents"]}
+    assert document_id in ids_in_response
+    print("✅ Document list test passed.")
+
+    # 3. Delete the document
+    delete_response = await http_client.delete(f"/documents/{document_id}")
+    assert delete_response.status_code == 200
+    assert delete_response.json()["document_id"] == document_id
+    print("✅ Document delete test passed.")
\ No newline at end of file
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
new file mode 100644
index 0000000..3ff7925
--- /dev/null
+++ b/ai-hub/integration_tests/test_misc.py
@@ -0,0 +1,36 @@
+import pytest
+import httpx
+
+@pytest.mark.asyncio
+async def test_root_endpoint(http_client):
+    """
+    Tests if the root endpoint is alive and returns the correct status message.
+    """
+    print("\n--- Running test_root_endpoint ---")
+    response = await http_client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"status": "AI Model Hub is running!"}
+    print("✅ Root endpoint test passed.")
+
+# @pytest.mark.asyncio
+# async def test_create_speech_stream(http_client):
+#     """
+#     Tests the /speech endpoint for a successful audio stream response.
+#     """
+#     print("\n--- Running test_create_speech_stream ---")
+#     url = "/speech"
+#     payload = {"text": "Hello, world!"}
+
+#     # client.stream() tells httpx not to read the entire response body
+#     # at once. We'll handle it manually to check for content.
+#     async with http_client.stream("POST", url, json=payload) as response:
+#         assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
+#         assert response.headers.get("content-type") == "audio/wav"
+
+#         # Check that the response body is not empty by iterating over chunks.
+#         content_length = 0
+#         async for chunk in response.aiter_bytes():
+#             content_length += len(chunk)
+
+#         assert content_length > 0
+#         print("✅ TTS stream test passed.")
\ No newline at end of file
diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py
new file mode 100644
index 0000000..44d63d7
--- /dev/null
+++ b/ai-hub/integration_tests/test_sessions.py
@@ -0,0 +1,107 @@
+import pytest
+
+# Test prompts and data
+CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
+FOLLOW_UP_PROMPT = "When was he born?"
+RAG_DOC_TITLE = "Fictional Company History"
+RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'."
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?"
+
+@pytest.mark.asyncio
+async def test_chat_in_session_lifecycle(http_client):
+    """
+    Tests a full session lifecycle from creation to conversational memory.
+    This test is a single, sequential unit.
+    """
+    print("\n--- Running test_chat_in_session_lifecycle ---")
+
+    # 1. Create a new session
+    payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"}
+    response = await http_client.post("/sessions", json=payload)
+    assert response.status_code == 200
+    session_id = response.json()["id"]
+    print(f"✅ Session created successfully with ID: {session_id}")
+
+    # 2. First chat turn to establish context
+    chat_payload_1 = {"prompt": CONTEXT_PROMPT}
+    response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1)
+    assert response_1.status_code == 200
+    assert "Satya Nadella" in response_1.json()["answer"]
+    assert response_1.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 1 (context) test passed.")
+
+    # 3. Second chat turn (follow-up) to test conversational memory
+    chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT}
+    response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2)
+    assert response_2.status_code == 200
+    assert "1967" in response_2.json()["answer"]
+    assert response_2.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 2 (follow-up) test passed.")
+
+    # 4. Cleanup (optional, but good practice if not using a test database that resets)
+    # The session data would typically be cleaned up by the database teardown.
+
+@pytest.mark.asyncio
+async def test_chat_with_model_switch(http_client, session_id):
+    """Tests switching models within an existing session."""
+    print("\n--- Running test_chat_with_model_switch ---")
+
+    # Send a message to the new session with a different model
+    payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
+    response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
+    assert response_gemini.status_code == 200
+    assert "Paris" in response_gemini.json()["answer"]
+    assert response_gemini.json()["model_used"] == "gemini"
+    print("✅ Chat (Model Switch to Gemini) test passed.")
+
+    # Switch back to the original model
+    payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+    response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
+    assert response_deepseek.status_code == 200
+    assert "Pacific Ocean" in response_deepseek.json()["answer"]
+    assert response_deepseek.json()["model_used"] == "deepseek"
+    print("✅ Chat (Model Switch back to DeepSeek) test passed.")
+
+@pytest.mark.asyncio
+async def test_chat_with_document_retrieval(http_client):
+    """
+    Tests injecting a document and using it for retrieval-augmented generation.
+    This test creates its own session and document for isolation.
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await http_client.post("/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. 
- """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. - """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. + """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + add_response = await http_client.post("/documents", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + list_response = await http_client.get("/documents") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py new file mode 100644 index 0000000..3ff7925 --- /dev/null +++ b/ai-hub/integration_tests/test_misc.py @@ -0,0 +1,36 @@ +import pytest +import httpx + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +# @pytest.mark.asyncio +# async def test_create_speech_stream(http_client): +# """ +# Tests the /speech endpoint for a successful audio stream response. +# """ +# print("\n--- Running test_create_speech_stream ---") +# url = "/speech" +# payload = {"text": "Hello, world!"} + +# # The `stream=True` parameter tells httpx to not read the entire response body +# # at once. We'll handle it manually to check for content. +# async with http_client.stream("POST", url, json=payload) as response: +# assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" +# assert response.headers.get("content-type") == "audio/wav" + +# # Check that the response body is not empty by iterating over chunks. +# content_length = 0 +# async for chunk in response.aiter_bytes(): +# content_length += len(chunk) + +# assert content_length > 0 +# print("✅ TTS stream test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py new file mode 100644 index 0000000..44d63d7 --- /dev/null +++ b/ai-hub/integration_tests/test_sessions.py @@ -0,0 +1,107 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." +RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. 
Create a new session + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. + +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await http_client.post("/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 65062f6..ef45dc0 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -4,15 +4,45 @@ # It starts the FastAPI server, runs the specified tests, and then shuts down the server. # --- Configuration --- -# Set the default path for tests to run. This will be used if no argument is provided. -DEFAULT_TEST_PATH="integration_tests/" -# You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' -TEST_PATH=${1:-$DEFAULT_TEST_PATH} +# You can define aliases for your test file paths here. +TEST_SUITES=( + "All tests" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) +TEST_PATHS=( + "integration_tests/" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) export DB_MODE=sqlite export LOCAL_DB_PATH="data/integration_test_ai_hub.db" export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" +# --- User Interaction --- +echo "--- AI Hub Test Runner ---" +echo "Select a test suite to run:" + +# Display the menu options +for i in "${!TEST_SUITES[@]}"; do + echo "$((i+1)). ${TEST_SUITES[$i]}" +done + +# Prompt for user input +read -p "Enter the number of your choice (1-${#TEST_SUITES[@]}): " choice + +# Validate input and set the TEST_PATH +if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#TEST_SUITES[@]}" ]; then + TEST_PATH=${TEST_PATHS[$((choice-1))]} + echo "You have selected: ${TEST_SUITES[$((choice-1))]}" +else + echo "Invalid choice. Running all tests by default." + TEST_PATH=${TEST_PATHS[0]} +fi + # --- Pre-test Cleanup --- # Check for and remove old test files to ensure a clean test environment. echo "--- Checking for and removing old test files ---" @@ -62,4 +92,4 @@ # The 'trap' will automatically call the cleanup function now. 
# Exit with the same code as the test script (0 for success, non-zero for failure). -exit $TEST_EXIT_CODE +exit $TEST_EXIT_CODE \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService from app.core.vector_store.faiss_store import FaissVectorStore + # This is a dependency def get_db(): db = SessionLocal() @@ -25,9 +27,10 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers - ) \ No newline at end of file + ) + self.tts_service = tts_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 6d216ec..40860aa 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,8 +1,11 @@ from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer +from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas +from starlette.concurrency import run_in_threadpool + def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -15,15 +18,12 @@ return {"status": "AI Model Hub is running!"} # --- Session Management Endpoints --- - + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) def create_session( request: schemas.SessionCreate, db: Session = Depends(get_db) ): - """ - Starts a new conversation session and returns its details. - """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. - """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. 
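With the menu in place, an interactive run of the script looks roughly like this (illustrative transcript derived from the echo/read statements above; server startup and pytest output omitted):

$ ./run_integration_tests.sh
--- AI Hub Test Runner ---
Select a test suite to run:
1. All tests
2. integration_tests/test_sessions.py
3. integration_tests/test_documents.py
4. integration_tests/test_misc.py
Enter the number of your choice (1-4): 2
You have selected: integration_tests/test_sessions.py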
- """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. + audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. 
Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. 
Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. + """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. 
+ """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + add_response = await http_client.post("/documents", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + list_response = await http_client.get("/documents") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py new file mode 100644 index 0000000..3ff7925 --- /dev/null +++ b/ai-hub/integration_tests/test_misc.py @@ -0,0 +1,36 @@ +import pytest +import httpx + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +# @pytest.mark.asyncio +# async def test_create_speech_stream(http_client): +# """ +# Tests the /speech endpoint for a successful audio stream response. +# """ +# print("\n--- Running test_create_speech_stream ---") +# url = "/speech" +# payload = {"text": "Hello, world!"} + +# # The `stream=True` parameter tells httpx to not read the entire response body +# # at once. We'll handle it manually to check for content. +# async with http_client.stream("POST", url, json=payload) as response: +# assert response.status_code == 200, f"Speech stream request failed. 
Response: {response.text}" +# assert response.headers.get("content-type") == "audio/wav" + +# # Check that the response body is not empty by iterating over chunks. +# content_length = 0 +# async for chunk in response.aiter_bytes(): +# content_length += len(chunk) + +# assert content_length > 0 +# print("✅ TTS stream test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py new file mode 100644 index 0000000..44d63d7 --- /dev/null +++ b/ai-hub/integration_tests/test_sessions.py @@ -0,0 +1,107 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." +RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. Create a new session + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. 
+ +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. + """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await http_client.post("/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 65062f6..ef45dc0 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -4,15 +4,45 @@ # It starts the FastAPI server, runs the specified tests, and then shuts down the server. # --- Configuration --- -# Set the default path for tests to run. This will be used if no argument is provided. 
-DEFAULT_TEST_PATH="integration_tests/" -# You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' -TEST_PATH=${1:-$DEFAULT_TEST_PATH} +# You can define aliases for your test file paths here. +TEST_SUITES=( + "All tests" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) +TEST_PATHS=( + "integration_tests/" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) export DB_MODE=sqlite export LOCAL_DB_PATH="data/integration_test_ai_hub.db" export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" +# --- User Interaction --- +echo "--- AI Hub Test Runner ---" +echo "Select a test suite to run:" + +# Display the menu options +for i in "${!TEST_SUITES[@]}"; do + echo "$((i+1)). ${TEST_SUITES[$i]}" +done + +# Prompt for user input +read -p "Enter the number of your choice (1-${#TEST_SUITES[@]}): " choice + +# Validate input and set the TEST_PATH +if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#TEST_SUITES[@]}" ]; then + TEST_PATH=${TEST_PATHS[$((choice-1))]} + echo "You have selected: ${TEST_SUITES[$((choice-1))]}" +else + echo "Invalid choice. Running all tests by default." + TEST_PATH=${TEST_PATHS[0]} +fi + # --- Pre-test Cleanup --- # Check for and remove old test files to ensure a clean test environment. echo "--- Checking for and removing old test files ---" @@ -62,4 +92,4 @@ # The 'trap' will automatically call the cleanup function now. # Exit with the same code as the test script (0 for success, non-zero for failure). -exit $TEST_EXIT_CODE +exit $TEST_EXIT_CODE \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 65490e8..d2e0472 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -9,6 +9,7 @@ from app.api.dependencies import get_db, get_current_user, ServiceContainer from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService # Added this import from app.core.vector_store.faiss_store import FaissVectorStore from app.core.retrievers.base_retriever import Retriever @@ -97,11 +98,13 @@ # The DocumentService constructor needs a .embedder attribute on the vector_store mock_vector_store.embedder = MagicMock() mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + mock_tts_service = MagicMock(spec=TTSService) # Act: Instantiate the ServiceContainer container = ServiceContainer( vector_store=mock_vector_store, - retrievers=mock_retrievers + retrievers=mock_retrievers, + tts_service=mock_tts_service # Passing the mock TTS service ) # Assert: Check if the services were created and configured correctly @@ -110,3 +113,7 @@ assert isinstance(container.rag_service, RAGService) assert container.rag_service.retrievers == mock_retrievers + + # Assert for the tts_service as well + assert isinstance(container.tts_service, TTSService) + assert container.tts_service == mock_tts_service diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1f2be7..e1aecde 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -6,8 +6,10 @@ from app.core.retrievers.base_retriever import Retriever from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import 
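The updated unit test above covers the ServiceContainer wiring; the /speech route itself can be exercised the same way with a mocked container. A minimal sketch, assuming the create_api_router factory from this change — the mock wiring and test name are illustrative, not committed code:

from unittest.mock import MagicMock

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.api.routes import create_api_router

def test_speech_route_streams_wav_bytes():
    # Stub the container: this route only touches services.tts_service.
    services = MagicMock()
    # The route calls create_speech_stream(text=...) and streams whatever it returns,
    # so a plain iterator of byte chunks is enough to stand in for the provider.
    services.tts_service.create_speech_stream.return_value = iter([b"RIFF", b"fake-wav-bytes"])

    app = FastAPI()
    app.include_router(create_api_router(services=services))
    client = TestClient(app)

    response = client.post("/speech", json={"text": "Hello"})

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/wav"
    assert response.content == b"RIFFfake-wav-bytes"
    services.tts_service.create_speech_stream.assert_called_once_with(text="Hello")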
diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index 6d216ec..40860aa 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -1,8 +1,11 @@
 from fastapi import APIRouter, HTTPException, Depends
+from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session
-from app.api.dependencies import ServiceContainer
+from app.api.dependencies import ServiceContainer
 from app.api.dependencies import get_db
 from app.api import schemas
+from starlette.concurrency import run_in_threadpool
+
 
 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -15,15 +18,12 @@
         return {"status": "AI Model Hub is running!"}
 
     # --- Session Management Endpoints ---
-
+
     @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"])
     def create_session(
         request: schemas.SessionCreate,
         db: Session = Depends(get_db)
     ):
-        """
-        Starts a new conversation session and returns its details.
-        """
         try:
             new_session = services.rag_service.create_session(
                 db=db,
@@ -34,40 +34,29 @@
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Failed to create session: {e}")
 
-    # --- Session Management Endpoints ---
     @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"])
     async def chat_in_session(
         session_id: int,
         request: schemas.ChatRequest,
         db: Session = Depends(get_db)
     ):
-        """
-        Sends a message within an existing session and gets a contextual response.
-
-        The 'model' and 'load_faiss_retriever' can now be specified in the request body.
-        """
         try:
             response_text, model_used = await services.rag_service.chat_with_rag(
                 db=db,
                 session_id=session_id,
                 prompt=request.prompt,
                 model=request.model,
-                load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service
+                load_faiss_retriever=request.load_faiss_retriever
             )
             return schemas.ChatResponse(answer=response_text, model_used=model_used)
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
 
-
     @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"])
     def get_session_messages(session_id: int, db: Session = Depends(get_db)):
-        """
-        Retrieves the full message history for a specific session.
-        """
         try:
-            # Note: You'll need to add a `get_message_history` method to your RAGService.
             messages = services.rag_service.get_message_history(db=db, session_id=session_id)
-            if messages is None: # Service can return None if session not found
+            if messages is None:
                 raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
 
             return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
@@ -77,7 +66,6 @@
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
     # --- Document Management Endpoints ---
-    # (These endpoints remain unchanged)
 
     @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"])
     def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
@@ -114,4 +102,26 @@
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
-    return router
+    # --- TTS Endpoint ---
+    @router.post(
+        "/speech",
+        summary="Generate a speech stream from text",
+        tags=["TTS"],
+        response_description="A stream of audio bytes in WAV format",
+    )
+    async def create_speech_stream(request: schemas.SpeechRequest):
+        """
+        Generates an audio stream from the provided text using the TTS service.
+        """
+        try:
+            # Run the synchronous service call in a threadpool so it does not
+            # block the event loop; StreamingResponse consumes the generator.
+            audio_stream = await run_in_threadpool(
+                services.tts_service.create_speech_stream, text=request.text
+            )
+            return StreamingResponse(audio_stream, media_type="audio/wav")
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to generate speech: {e}"
+            )
+    return router
\ No newline at end of file
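
For reference, a caller can consume the new /speech endpoint as a byte stream without buffering the whole file. A minimal httpx sketch (assumes a hub running locally on port 8000):

    import httpx

    with httpx.stream("POST", "http://127.0.0.1:8000/speech",
                      json={"text": "Hello, world!"}, timeout=60.0) as response:
        response.raise_for_status()
        with open("speech.wav", "wb") as f:
            for chunk in response.iter_bytes():
                f.write(chunk)  # write audio chunks as they arrive
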
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index 419aa44..31fb342 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -73,4 +73,7 @@
 class MessageHistoryResponse(BaseModel):
     """Defines the response for retrieving a session's chat history."""
     session_id: int
-    messages: List[Message]
\ No newline at end of file
+    messages: List[Message]
+
+class SpeechRequest(BaseModel):
+    text: str
\ No newline at end of file
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index a747ae6..a56fba6 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -6,12 +6,14 @@
 from app.config import settings
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.vector_store.embedder.factory import get_embedder_from_config
+from app.core.providers.factory import get_tts_provider
 from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 from app.core.retrievers.base_retriever import Retriever
 from app.db.session import create_db_and_tables
 from app.api.routes import create_api_router
 from app.utils import print_config
 from app.api.dependencies import ServiceContainer
+from app.core.services.tts import TTSService
 # Note: The llm_clients import and initialization are removed as they
 # are not used in RAGService's constructor based on your services.py
 # from app.core.llm_clients import DeepSeekClient, GeminiClient
@@ -64,14 +66,30 @@
     )
 
+    # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown.
+    app.state.vector_store = vector_store
+
     # 3. Create the FaissDBRetriever, regardless of the embedder type
     retrievers: List[Retriever] = [
         FaissDBRetriever(vector_store=vector_store),
     ]
 
-    # 4. Initialize the Service Container
-
-    services = ServiceContainer(vector_store, retrievers)
+    # --- New TTS Initialization ---
+    # 4. Get the concrete TTS provider from the factory
+    tts_provider = get_tts_provider(
+        provider_name=settings.TTS_PROVIDER,
+        api_key=settings.TTS_API_KEY
+    )
+
+    # 5. Initialize the TTSService
+    tts_service = TTSService(tts_provider=tts_provider)
+
+    # 6. Initialize the Service Container with all services
+    services = ServiceContainer(
+        vector_store=vector_store,
+        retrievers=retrievers,
+        tts_service=tts_service  # Pass the new TTS service instance
+    )
 
     # Create and include the API router, injecting the service
     api_router = create_api_router(services=services)
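
The CRITICAL FIX comment above parks the vector store on app.state so a shutdown hook can reach it. The shutdown side is not shown in this diff; a hypothetical lifespan sketch of the pattern (the real app's hook may differ):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        yield  # application serves requests here
        # On shutdown, persist the FAISS index attached during startup.
        app.state.vector_store.save_index()
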
diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py
deleted file mode 100644
index 09364b3..0000000
--- a/ai-hub/app/core/llm_providers.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# import httpx
-# import logging
-# import json
-# from abc import ABC, abstractmethod
-# from openai import OpenAI
-# from typing import final
-# from app.config import settings  # <-- Import the centralized settings
-
-# # --- 1. Initialize API Clients from Central Config ---
-# # All environment variable access is now gone from this file.
-# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
-# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
-
-
-# # --- 2. Provider Interface and Implementations (Unchanged) ---
-# class LLMProvider(ABC):
-#     """Abstract base class ('Interface') for all LLM providers."""
-#     @abstractmethod
-#     async def generate_response(self, prompt: str) -> str:
-#         """Generates a response from the LLM."""
-#         pass
-
-# @final
-# class DeepSeekProvider(LLMProvider):
-#     """Provider for the DeepSeek API."""
-#     def __init__(self, model_name: str):
-#         self.model = model_name
-
-#     async def generate_response(self, prompt: str) -> str:
-#         messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
-#         try:
-#             chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages)
-#             return chat_completion.choices[0].message.content
-#         except Exception as e:
-#             logging.error("DeepSeek Provider Error", exc_info=True)
-#             raise
-
-# @final
-# class GeminiProvider(LLMProvider):
-#     """Provider for the Google Gemini API."""
-#     def __init__(self, api_url: str):
-#         self.url = api_url
-
-#     async def generate_response(self, prompt: str) -> str:
-#         payload = {"contents": [{"parts": [{"text": prompt}]}]}
-#         headers = {"Content-Type": "application/json"}
-#         try:
-#             async with httpx.AsyncClient() as client:
-#                 response = await client.post(self.url, json=payload, headers=headers)
-#                 response.raise_for_status()
-#                 data = response.json()
-#                 return data['candidates'][0]['content']['parts'][0]['text']
-#         except Exception as e:
-#             logging.error("Gemini Provider Error", exc_info=True)
-#             raise
-
-# # --- 3. The Factory Function ---
-# # The dictionary of providers is now built using values from the settings object.
-# _providers = {
-#     "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME),
-#     "gemini": GeminiProvider(api_url=GEMINI_URL)
-# }
-
-# def get_llm_provider(model_name: str) -> LLMProvider:
-#     """Factory function to get the appropriate, pre-configured LLM provider."""
-#     provider = _providers.get(model_name)
-#     if not provider:
-#         raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}")
-#     return provider
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 7cf5119..53f3651 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -2,6 +2,7 @@
 from .base import LLMProvider,TTSProvider
 from .llm.deepseek import DeepSeekProvider
 from .llm.gemini import GeminiProvider
+from .tts.gemini import GeminiTTSProvider
 from openai import AsyncOpenAI
 
 # --- 1. Initialize API Clients from Central Config ---
@@ -23,7 +24,7 @@
     return provider
 
 def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
-    if provider_name == "gemini":
+    if provider_name == "google_genai":
         return GeminiTTSProvider(api_key=api_key)
     # Add other TTS providers here if needed
     raise ValueError(f"Unknown TTS provider: {provider_name}")
\ No newline at end of file
diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py
new file mode 100644
index 0000000..c196795
--- /dev/null
+++ b/ai-hub/app/core/services/tts.py
@@ -0,0 +1,28 @@
+from typing import Generator
+from app.core.providers.base import TTSProvider
+
+class TTSService:
+    """
+    Service class for generating speech from text using a TTS provider.
+    """
+    def __init__(self, tts_provider: TTSProvider):
+        """
+        Initializes the TTSService with a concrete TTS provider.
+        """
+        self.tts_provider = tts_provider
+
+    def create_speech_stream(self, text: str) -> Generator[bytes, None, None]:
+        """
+        Generates a stream of audio bytes from the given text using the
+        configured TTS provider.
+
+        Note: this method is deliberately synchronous; the API layer runs it
+        in a threadpool and streams the returned generator.
+
+        Args:
+            text: The text to be converted to speech.
+
+        Returns:
+            A generator that yields chunks of audio bytes.
+        """
+        return self.tts_provider.generate_speech(text)
\ No newline at end of file
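
The factory can return any object honoring the TTSProvider contract as the service consumes it: generate_speech(text) yields audio-byte chunks. A toy provider sketch for local testing (the real base-class signature in app.core.providers.base is assumed, not shown in this diff):

    from typing import Generator

    class SilenceTTSProvider:
        """Hypothetical provider that streams a fixed block of silence."""
        def generate_speech(self, text: str) -> Generator[bytes, None, None]:
            for _ in range(3):
                yield b"\x00" * 1024  # one chunk of silent PCM bytes
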
+ """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + add_response = await http_client.post("/documents", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + list_response = await http_client.get("/documents") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py new file mode 100644 index 0000000..3ff7925 --- /dev/null +++ b/ai-hub/integration_tests/test_misc.py @@ -0,0 +1,36 @@ +import pytest +import httpx + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +# @pytest.mark.asyncio +# async def test_create_speech_stream(http_client): +# """ +# Tests the /speech endpoint for a successful audio stream response. +# """ +# print("\n--- Running test_create_speech_stream ---") +# url = "/speech" +# payload = {"text": "Hello, world!"} + +# # The `stream=True` parameter tells httpx to not read the entire response body +# # at once. We'll handle it manually to check for content. +# async with http_client.stream("POST", url, json=payload) as response: +# assert response.status_code == 200, f"Speech stream request failed. 
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
new file mode 100644
index 0000000..3ff7925
--- /dev/null
+++ b/ai-hub/integration_tests/test_misc.py
@@ -0,0 +1,36 @@
+import pytest
+import httpx
+
+@pytest.mark.asyncio
+async def test_root_endpoint(http_client):
+    """
+    Tests if the root endpoint is alive and returns the correct status message.
+    """
+    print("\n--- Running test_root_endpoint ---")
+    response = await http_client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"status": "AI Model Hub is running!"}
+    print("✅ Root endpoint test passed.")
+
+# @pytest.mark.asyncio
+# async def test_create_speech_stream(http_client):
+#     """
+#     Tests the /speech endpoint for a successful audio stream response.
+#     """
+#     print("\n--- Running test_create_speech_stream ---")
+#     url = "/speech"
+#     payload = {"text": "Hello, world!"}
+
+#     # Using client.stream(...) tells httpx not to read the entire response
+#     # body at once. We'll handle it manually to check for content.
+#     async with http_client.stream("POST", url, json=payload) as response:
+#         assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
+#         assert response.headers.get("content-type") == "audio/wav"
+
+#         # Check that the response body is not empty by iterating over chunks.
+#         content_length = 0
+#         async for chunk in response.aiter_bytes():
+#             content_length += len(chunk)
+
+#         assert content_length > 0
+#     print("✅ TTS stream test passed.")
\ No newline at end of file
diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py
new file mode 100644
index 0000000..44d63d7
--- /dev/null
+++ b/ai-hub/integration_tests/test_sessions.py
@@ -0,0 +1,107 @@
+import pytest
+
+# Test prompts and data
+CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
+FOLLOW_UP_PROMPT = "When was he born?"
+RAG_DOC_TITLE = "Fictional Company History"
+RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'."
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?"
+
+@pytest.mark.asyncio
+async def test_chat_in_session_lifecycle(http_client):
+    """
+    Tests a full session lifecycle from creation to conversational memory.
+    This test is a single, sequential unit.
+    """
+    print("\n--- Running test_chat_in_session_lifecycle ---")
+
+    # 1. Create a new session
+    payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"}
+    response = await http_client.post("/sessions", json=payload)
+    assert response.status_code == 200
+    session_id = response.json()["id"]
+    print(f"✅ Session created successfully with ID: {session_id}")
+
+    # 2. First chat turn to establish context
+    chat_payload_1 = {"prompt": CONTEXT_PROMPT}
+    response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1)
+    assert response_1.status_code == 200
+    assert "Satya Nadella" in response_1.json()["answer"]
+    assert response_1.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 1 (context) test passed.")
+
+    # 3. Second chat turn (follow-up) to test conversational memory
+    chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT}
+    response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2)
+    assert response_2.status_code == 200
+    assert "1967" in response_2.json()["answer"]
+    assert response_2.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 2 (follow-up) test passed.")
+
+    # 4. Cleanup (optional, but good practice if not using a test database that resets)
+    # The session data would typically be cleaned up by the database teardown.
+
+@pytest.mark.asyncio
+async def test_chat_with_model_switch(http_client, session_id):
+    """Tests switching models within an existing session."""
+    print("\n--- Running test_chat_with_model_switch ---")
+
+    # Send a message to the new session with a different model
+    payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
+    response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
+    assert response_gemini.status_code == 200
+    assert "Paris" in response_gemini.json()["answer"]
+    assert response_gemini.json()["model_used"] == "gemini"
+    print("✅ Chat (Model Switch to Gemini) test passed.")
+
+    # Switch back to the original model
+    payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+    response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
+    assert response_deepseek.status_code == 200
+    assert "Pacific Ocean" in response_deepseek.json()["answer"]
+    assert response_deepseek.json()["model_used"] == "deepseek"
+    print("✅ Chat (Model Switch back to DeepSeek) test passed.")
+
+@pytest.mark.asyncio
+async def test_chat_with_document_retrieval(http_client):
+    """
+    Tests injecting a document and using it for retrieval-augmented generation.
+    This test creates its own session and document for isolation.
+    """
+    print("\n--- Running test_chat_with_document_retrieval ---")
+
+    # Create a new session for this RAG test
+    session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"})
+    assert session_response.status_code == 200
+    rag_session_id = session_response.json()["id"]
+
+    # Add a new document with specific content for retrieval
+    doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
+    add_doc_response = await http_client.post("/documents", json=doc_data)
+    assert add_doc_response.status_code == 200
+    try:
+        message = add_doc_response.json().get("message", "")
+        rag_document_id = int(message.split(" with ID ")[-1])
+        print(f"Document for RAG created with ID: {rag_document_id}")
+    except (ValueError, IndexError):
+        pytest.fail("Could not parse document ID from response message.")
+
+    try:
+        chat_payload = {
+            "prompt": RAG_PROMPT,
+            "document_id": rag_document_id,
+            "model": "deepseek",
+            "load_faiss_retriever": True
+        }
+        chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload)
+
+        assert chat_response.status_code == 200
+        chat_data = chat_response.json()
+        assert "Jane Doe" in chat_data["answer"]
+        assert "Nexus" in chat_data["answer"]
+        print("✅ Chat with document retrieval test passed.")
+    finally:
+        # Clean up the document after the test
+        delete_response = await http_client.delete(f"/documents/{rag_document_id}")
+        assert delete_response.status_code == 200
+        print(f"Document {rag_document_id} deleted successfully.")
\ No newline at end of file
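
These tests assert exact substrings ("Satya Nadella", "1967", "Pacific Ocean") in free-form LLM output, which can flake as models change phrasing. An optional hardening sketch, not part of the suite above:

    def assert_answer_mentions(answer: str, *needles: str) -> None:
        """Case-insensitive containment check for LLM answers."""
        lowered = answer.lower()
        missing = [n for n in needles if n.lower() not in lowered]
        assert not missing, f"Answer missing expected terms: {missing}"
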
diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py
index 8d841a2..6d5cd3e 100644
--- a/ai-hub/tests/api/test_routes.py
+++ b/ai-hub/tests/api/test_routes.py
@@ -1,15 +1,18 @@
-# tests/app/api/test_routes.py
+# tests/api/test_routes.py
 import pytest
 from unittest.mock import MagicMock, AsyncMock
 from fastapi import FastAPI
 from fastapi.testclient import TestClient
 from sqlalchemy.orm import Session
 from datetime import datetime
+from httpx import AsyncClient, ASGITransport
+
 # Import the dependencies and router factory
 from app.api.dependencies import get_db, ServiceContainer
 from app.core.services.rag import RAGService
 from app.core.services.document import DocumentService
+from app.core.services.tts import TTSService
 from app.api.routes import create_api_router
 from app.db import models  # Import your SQLAlchemy models
@@ -24,11 +27,15 @@
     # Mock individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
-
+
+    # Mock the TTS service; individual tests swap in the streaming behavior they need
+    mock_tts_service = MagicMock(spec=TTSService)
+
     # Create a mock ServiceContainer that holds the mocked services
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
+    mock_services.tts_service = mock_tts_service
 
     # Mock the database session
     mock_db_session = MagicMock(spec=Session)
@@ -80,8 +87,6 @@
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
 
-    # Verify that chat_with_rag was called with the default model 'deepseek'
-    # and the default load_faiss_retriever=False
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -102,13 +107,12 @@
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
 
-    # Verify that chat_with_rag was called with the specified model 'gemini'
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
         prompt="Hello there, Gemini!",
         model="gemini",
-        load_faiss_retriever=False # It should still default to False
+        load_faiss_retriever=False
     )
 
 def test_chat_in_session_with_faiss_retriever(client):
@@ -126,29 +130,25 @@
     assert response.status_code == 200
     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
 
-    # Verify that chat_with_rag was called with the correct parameters
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
         prompt="What is RAG?",
-        model="deepseek", # The model still defaults to deepseek
-        load_faiss_retriever=True # Verify that the retriever was explicitly enabled
+        model="deepseek",
+        load_faiss_retriever=True
     )
 
 def test_get_session_messages_success(client):
     """Tests retrieving the message history for a session."""
     test_client, mock_services = client
 
-    # Arrange: Mock the service to return a list of message objects
     mock_history = [
         models.Message(sender="user", content="Hello", created_at=datetime.now()),
         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
     ]
     mock_services.rag_service.get_message_history.return_value = mock_history
 
-    # Act
     response = test_client.get("/sessions/123/messages")
 
-    # Assert
     assert response.status_code == 200
     response_data = response.json()
     assert response_data["session_id"] == 123
@@ -163,13 +163,10 @@
 def test_get_session_messages_not_found(client):
     """Tests retrieving messages for a session that does not exist."""
     test_client, mock_services = client
 
-    # Arrange: Mock the service to return None, indicating the session wasn't found
     mock_services.rag_service.get_message_history.return_value = None
 
-    # Act
     response = test_client.get("/sessions/999/messages")
 
-    # Assert
     assert response.status_code == 404
     assert response.json()["detail"] == "Session with ID 999 not found."
 
@@ -210,3 +207,34 @@
     mock_services.document_service.delete_document.return_value = None
     response = test_client.delete("/documents/999")
     assert response.status_code == 404
+
+# --- TTS Endpoint ---
+
+@pytest.mark.anyio
+async def test_create_speech_stream_success(client):
+    """
+    Tests the /speech endpoint to ensure it can successfully generate an audio stream.
+    """
+    test_client, mock_services = client
+    app = test_client.app  # Get the FastAPI app from the TestClient
+
+    # Arrange: Define the text to convert and mock the service's response.
+    text_to_speak = "Hello, world!"
+
+    # Define the async generator
+    async def mock_audio_generator():
+        yield b'chunk1'
+        yield b'chunk2'
+        yield b'chunk3'
+
+    # Properly mock the method to return the generator
+    mock_services.tts_service.create_speech_stream = lambda text: mock_audio_generator()
+
+    # Use AsyncClient with ASGITransport to send request to the FastAPI app
+    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
+        response = await ac.post("/speech", json={"text": text_to_speak})
+
+    # Assert: Check status code and content
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/wav"
+    assert response.content == b"chunk1chunk2chunk3"
\ No newline at end of file
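
A note on why the lambda mock streams correctly: StreamingResponse accepts both sync and async iterables of bytes, so the mocked async generator is consumed as-is. A standalone sketch of that property:

    from starlette.responses import StreamingResponse

    def sync_chunks():
        yield b"a"
        yield b"b"

    async def async_chunks():
        yield b"a"
        yield b"b"

    # Both are valid content sources for a streamed response.
    StreamingResponse(sync_chunks(), media_type="audio/wav")
    StreamingResponse(async_chunks(), media_type="audio/wav")
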
- """ try: new_session = services.rag_service.create_session( db=db, @@ -34,40 +34,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - # --- Session Management Endpoints --- @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) async def chat_in_session( session_id: int, request: schemas.ChatRequest, db: Session = Depends(get_db) ): - """ - Sends a message within an existing session and gets a contextual response. - - The 'model' and 'load_faiss_retriever' can now be specified in the request body. - """ try: response_text, model_used = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, model=request.model, - load_faiss_retriever=request.load_faiss_retriever # Pass the new parameter to the service + load_faiss_retriever=request.load_faiss_retriever ) return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) def get_session_messages(session_id: int, db: Session = Depends(get_db)): - """ - Retrieves the full message history for a specific session. - """ try: - # Note: You'll need to add a `get_message_history` method to your RAGService. messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: # Service can return None if session not found + if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) @@ -77,7 +66,6 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - # (These endpoints remain unchanged) @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): @@ -114,4 +102,26 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - return router + # --- TTS Endpoint --- + @router.post( + "/speech", + summary="Generate a speech stream from text", + tags=["TTS"], + response_description="A stream of audio bytes in WAV format", + ) + async def create_speech_stream(request: schemas.SpeechRequest): + """ + Generates an audio stream from the provided text using the TTS service. + """ + try: + # Use run_in_threadpool to turn the synchronous generator into an + # async generator that StreamingResponse can handle. 
+ audio_stream = await run_in_threadpool( + services.tts_service.create_speech_stream, text=request.text + ) + return StreamingResponse(audio_stream, media_type="audio/wav") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + return router \ No newline at end of file diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 419aa44..31fb342 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -73,4 +73,7 @@ class MessageHistoryResponse(BaseModel): """Defines the response for retrieving a session's chat history.""" session_id: int - messages: List[Message] \ No newline at end of file + messages: List[Message] + +class SpeechRequest(BaseModel): + text: str \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a747ae6..a56fba6 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -6,12 +6,14 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config +from app.core.providers.factory import get_tts_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables from app.api.routes import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer +from app.core.services.tts import TTSService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -64,14 +66,30 @@ ) + # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. + app.state.vector_store = vector_store + # 3. Create the FaissDBRetriever, regardless of the embedder type retrievers: List[Retriever] = [ FaissDBRetriever(vector_store=vector_store), ] - # 4. Initialize the Service Container - - services = ServiceContainer(vector_store, retrievers) + # --- New TTS Initialization --- + # 4. Get the concrete TTS provider from the factory + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY + ) + + # 5. Initialize the TTSService + tts_service = TTSService(tts_provider=tts_provider) + + # 6. Initialize the Service Container with all services + services = ServiceContainer( + vector_store=vector_store, + retrievers=retrievers, + tts_service=tts_service # Pass the new TTS service instance + ) # Create and include the API router, injecting the service api_router = create_api_router(services=services) diff --git a/ai-hub/app/core/llm_providers.py b/ai-hub/app/core/llm_providers.py deleted file mode 100644 index 09364b3..0000000 --- a/ai-hub/app/core/llm_providers.py +++ /dev/null @@ -1,69 +0,0 @@ -# import httpx -# import logging -# import json -# from abc import ABC, abstractmethod -# from openai import OpenAI -# from typing import final -# from app.config import settings # <-- Import the centralized settings - -# # --- 1. Initialize API Clients from Central Config --- -# # All environment variable access is now gone from this file. -# deepseek_client = OpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") -# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}" - - -# # --- 2. 
Provider Interface and Implementations (Unchanged) --- -# class LLMProvider(ABC): -# """Abstract base class ('Interface') for all LLM providers.""" -# @abstractmethod -# async def generate_response(self, prompt: str) -> str: -# """Generates a response from the LLM.""" -# pass - -# @final -# class DeepSeekProvider(LLMProvider): -# """Provider for the DeepSeek API.""" -# def __init__(self, model_name: str): -# self.model = model_name - -# async def generate_response(self, prompt: str) -> str: -# messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] -# try: -# chat_completion = deepseek_client.chat.completions.create(model=self.model, messages=messages) -# return chat_completion.choices[0].message.content -# except Exception as e: -# logging.error("DeepSeek Provider Error", exc_info=True) -# raise - -# @final -# class GeminiProvider(LLMProvider): -# """Provider for the Google Gemini API.""" -# def __init__(self, api_url: str): -# self.url = api_url - -# async def generate_response(self, prompt: str) -> str: -# payload = {"contents": [{"parts": [{"text": prompt}]}]} -# headers = {"Content-Type": "application/json"} -# try: -# async with httpx.AsyncClient() as client: -# response = await client.post(self.url, json=payload, headers=headers) -# response.raise_for_status() -# data = response.json() -# return data['candidates'][0]['content']['parts'][0]['text'] -# except Exception as e: -# logging.error("Gemini Provider Error", exc_info=True) -# raise - -# # --- 3. The Factory Function --- -# # The dictionary of providers is now built using values from the settings object. -# _providers = { -# "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME), -# "gemini": GeminiProvider(api_url=GEMINI_URL) -# } - -# def get_llm_provider(model_name: str) -> LLMProvider: -# """Factory function to get the appropriate, pre-configured LLM provider.""" -# provider = _providers.get(model_name) -# if not provider: -# raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}") -# return provider \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7cf5119..53f3651 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -2,6 +2,7 @@ from .base import LLMProvider,TTSProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider +from .tts.gemini import GeminiTTSProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -23,7 +24,7 @@ return provider def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "gemini": + if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) # Add other TTS providers here if needed raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py new file mode 100644 index 0000000..c196795 --- /dev/null +++ b/ai-hub/app/core/services/tts.py @@ -0,0 +1,25 @@ +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class TTSService: + """ + Service class for generating speech from text using a TTS provider. + """ + def __init__(self, tts_provider: TTSProvider): + """ + Initializes the TTSService with a concrete TTS provider. 
+ """ + self.tts_provider = tts_provider + + async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]: + """ + Generates a stream of audio bytes from the given text using the configured + TTS provider. + + Args: + text: The text to be converted to speech. + + Returns: + An async generator that yields chunks of audio bytes. + """ + return self.tts_provider.generate_speech(text) \ No newline at end of file diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py new file mode 100644 index 0000000..5d51010 --- /dev/null +++ b/ai-hub/integration_tests/conftest.py @@ -0,0 +1,57 @@ +import httpx +import pytest_asyncio + +BASE_URL = "http://127.0.0.1:8000" + +@pytest_asyncio.fixture(scope="session") +def base_url(): + """Fixture to provide the base URL for the tests.""" + return BASE_URL + +@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +async def http_client(): + """ + Fixture to provide an async HTTP client for all tests in the session. + A new client is created and closed properly using a try/finally block + to prevent "Event loop is closed" errors. + """ + client = httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) + try: + yield client + finally: + await client.aclose() + +@pytest_asyncio.fixture(scope="function") +async def session_id(http_client): + """ + Creates a new session before a test and cleans it up after. + Returns the session ID. + """ + payload = {"user_id": "integration_tester", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + yield session_id + # No explicit session deletion is needed for this example, + # as sessions are typically managed by a database lifecycle. + +@pytest_asyncio.fixture(scope="function") +async def document_id(http_client): + """ + Creates a new document before a test and ensures it's deleted afterward. + Returns the document ID. + """ + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + response = await http_client.post("/documents", json=doc_data) + assert response.status_code == 200 + try: + message = response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + yield document_id + + # Teardown: Delete the document after the test + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 \ No newline at end of file diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py new file mode 100644 index 0000000..c6d4362 --- /dev/null +++ b/ai-hub/integration_tests/test_documents.py @@ -0,0 +1,33 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + add_response = await http_client.post("/documents", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + list_response = await http_client.get("/documents") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py new file mode 100644 index 0000000..3ff7925 --- /dev/null +++ b/ai-hub/integration_tests/test_misc.py @@ -0,0 +1,36 @@ +import pytest +import httpx + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +# @pytest.mark.asyncio +# async def test_create_speech_stream(http_client): +# """ +# Tests the /speech endpoint for a successful audio stream response. +# """ +# print("\n--- Running test_create_speech_stream ---") +# url = "/speech" +# payload = {"text": "Hello, world!"} + +# # The `stream=True` parameter tells httpx to not read the entire response body +# # at once. We'll handle it manually to check for content. +# async with http_client.stream("POST", url, json=payload) as response: +# assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" +# assert response.headers.get("content-type") == "audio/wav" + +# # Check that the response body is not empty by iterating over chunks. +# content_length = 0 +# async for chunk in response.aiter_bytes(): +# content_length += len(chunk) + +# assert content_length > 0 +# print("✅ TTS stream test passed.") \ No newline at end of file diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py new file mode 100644 index 0000000..44d63d7 --- /dev/null +++ b/ai-hub/integration_tests/test_sessions.py @@ -0,0 +1,107 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." +RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. 
Create a new session + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. + +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + add_doc_response = await http_client.post("/documents", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh index 65062f6..ef45dc0 100644 --- a/ai-hub/run_integration_tests.sh +++ b/ai-hub/run_integration_tests.sh @@ -4,15 +4,45 @@ # It starts the FastAPI server, runs the specified tests, and then shuts down the server. # --- Configuration --- -# Set the default path for tests to run. This will be used if no argument is provided. -DEFAULT_TEST_PATH="integration_tests/" -# You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py' -TEST_PATH=${1:-$DEFAULT_TEST_PATH} +# You can define aliases for your test file paths here. +TEST_SUITES=( + "All tests" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) +TEST_PATHS=( + "integration_tests/" + "integration_tests/test_sessions.py" + "integration_tests/test_documents.py" + "integration_tests/test_misc.py" +) export DB_MODE=sqlite export LOCAL_DB_PATH="data/integration_test_ai_hub.db" export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin" +# --- User Interaction --- +echo "--- AI Hub Test Runner ---" +echo "Select a test suite to run:" + +# Display the menu options +for i in "${!TEST_SUITES[@]}"; do + echo "$((i+1)). ${TEST_SUITES[$i]}" +done + +# Prompt for user input +read -p "Enter the number of your choice (1-${#TEST_SUITES[@]}): " choice + +# Validate input and set the TEST_PATH +if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#TEST_SUITES[@]}" ]; then + TEST_PATH=${TEST_PATHS[$((choice-1))]} + echo "You have selected: ${TEST_SUITES[$((choice-1))]}" +else + echo "Invalid choice. Running all tests by default." + TEST_PATH=${TEST_PATHS[0]} +fi + # --- Pre-test Cleanup --- # Check for and remove old test files to ensure a clean test environment. echo "--- Checking for and removing old test files ---" @@ -62,4 +92,4 @@ # The 'trap' will automatically call the cleanup function now. 
# Exit with the same code as the test script (0 for success, non-zero for failure). -exit $TEST_EXIT_CODE +exit $TEST_EXIT_CODE \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 65490e8..d2e0472 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -9,6 +9,7 @@ from app.api.dependencies import get_db, get_current_user, ServiceContainer from app.core.services.document import DocumentService from app.core.services.rag import RAGService +from app.core.services.tts import TTSService # Added this import from app.core.vector_store.faiss_store import FaissVectorStore from app.core.retrievers.base_retriever import Retriever @@ -97,11 +98,13 @@ # The DocumentService constructor needs a .embedder attribute on the vector_store mock_vector_store.embedder = MagicMock() mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] + mock_tts_service = MagicMock(spec=TTSService) # Act: Instantiate the ServiceContainer container = ServiceContainer( vector_store=mock_vector_store, - retrievers=mock_retrievers + retrievers=mock_retrievers, + tts_service=mock_tts_service # Passing the mock TTS service ) # Assert: Check if the services were created and configured correctly @@ -110,3 +113,7 @@ assert isinstance(container.rag_service, RAGService) assert container.rag_service.retrievers == mock_retrievers + + # Assert for the tts_service as well + assert isinstance(container.tts_service, TTSService) + assert container.tts_service == mock_tts_service diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 8d841a2..6d5cd3e 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,15 +1,18 @@ -# tests/app/api/test_routes.py +# tests/api/test_routes.py import pytest from unittest.mock import MagicMock, AsyncMock from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.orm import Session from datetime import datetime +from httpx import AsyncClient, ASGITransport + # Import the dependencies and router factory from app.api.dependencies import get_db, ServiceContainer from app.core.services.rag import RAGService from app.core.services.document import DocumentService +from app.core.services.tts import TTSService from app.api.routes import create_api_router from app.db import models # Import your SQLAlchemy models @@ -24,11 +27,15 @@ # Mock individual services mock_rag_service = MagicMock(spec=RAGService) mock_document_service = MagicMock(spec=DocumentService) - + + # Use AsyncMock for the TTS service since its methods are async + mock_tts_service = MagicMock(spec=TTSService) + # Create a mock ServiceContainer that holds the mocked services mock_services = MagicMock(spec=ServiceContainer) mock_services.rag_service = mock_rag_service mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service # Mock the database session mock_db_session = MagicMock(spec=Session) @@ -80,8 +87,6 @@ assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - # Verify that chat_with_rag was called with the default model 'deepseek' - # and the default load_faiss_retriever=False mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, @@ -102,13 +107,12 @@ assert response.status_code == 200 assert response.json() == {"answer": "Mocked response 
from Gemini", "model_used": "gemini"} - # Verify that chat_with_rag was called with the specified model 'gemini' mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="Hello there, Gemini!", model="gemini", - load_faiss_retriever=False # It should still default to False + load_faiss_retriever=False ) def test_chat_in_session_with_faiss_retriever(client): @@ -126,29 +130,25 @@ assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} - # Verify that chat_with_rag was called with the correct parameters mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, prompt="What is RAG?", - model="deepseek", # The model still defaults to deepseek - load_faiss_retriever=True # Verify that the retriever was explicitly enabled + model="deepseek", + load_faiss_retriever=True ) def test_get_session_messages_success(client): """Tests retrieving the message history for a session.""" test_client, mock_services = client - # Arrange: Mock the service to return a list of message objects mock_history = [ models.Message(sender="user", content="Hello", created_at=datetime.now()), models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] mock_services.rag_service.get_message_history.return_value = mock_history - # Act response = test_client.get("/sessions/123/messages") - # Assert assert response.status_code == 200 response_data = response.json() assert response_data["session_id"] == 123 @@ -163,13 +163,10 @@ def test_get_session_messages_not_found(client): """Tests retrieving messages for a session that does not exist.""" test_client, mock_services = client - # Arrange: Mock the service to return None, indicating the session wasn't found mock_services.rag_service.get_message_history.return_value = None - # Act response = test_client.get("/sessions/999/messages") - # Assert assert response.status_code == 404 assert response.json()["detail"] == "Session with ID 999 not found." @@ -210,3 +207,34 @@ mock_services.document_service.delete_document.return_value = None response = test_client.delete("/documents/999") assert response.status_code == 404 + +# --- TTS Endpoint --- + +@pytest.mark.anyio +async def test_create_speech_stream_success(client): + """ + Tests the /speech endpoint to ensure it can successfully generate an audio stream. + """ + test_client, mock_services = client + app = test_client.app # Get the FastAPI app from the TestClient + + # Arrange: Define the text to convert and mock the service's response. + text_to_speak = "Hello, world!" 
+ + # Define the async generator + async def mock_audio_generator(): + yield b'chunk1' + yield b'chunk2' + yield b'chunk3' + + # Properly mock the method to return the generator + mock_services.tts_service.create_speech_stream = lambda text: mock_audio_generator() + + # Use AsyncClient with ASGITransport to send request to the FastAPI app + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + response = await ac.post("/speech", json={"text": text_to_speak}) + + # Assert: Check status code and content + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert response.content == b"chunk1chunk2chunk3" \ No newline at end of file diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 73eb307..c0bb4b4 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -12,6 +12,7 @@ from app.api.dependencies import get_db, ServiceContainer from app.db import models from app.core.retrievers.base_retriever import Retriever +from app.core.vector_store.faiss_store import FaissVectorStore # Define a constant for the dimension to ensure consistency TEST_DIMENSION = 768 @@ -30,11 +31,10 @@ # We patch ServiceContainer directly to control its instantiation in create_app @patch('app.app.ServiceContainer') @patch('app.app.create_db_and_tables') -@patch('app.app.FaissVectorStore.save_index') @patch('app.app.print_config') @patch('app.core.vector_store.embedder.factory.get_embedder_from_config') @patch('faiss.read_index') # This patch is for the FaissVectorStore initialization -def test_read_root(mock_read_index, mock_get_embedder, mock_print_config, mock_save_index, mock_create_db, mock_service_container): +def test_read_root(mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container): """Test the root endpoint to ensure it's running.""" # Arrange: We patch the embedder and faiss calls to prevent real logic mock_read_index.return_value = MagicMock() @@ -53,8 +53,6 @@ assert response.status_code == 200 assert response.json() == {"status": "AI Model Hub is running!"} - -# We patch ServiceContainer directly to control its instantiation in create_app @patch('app.app.ServiceContainer') @patch('app.app.create_db_and_tables') @patch('app.core.vector_store.embedder.factory.get_embedder_from_config') @@ -334,3 +332,27 @@ assert response.status_code == 404 assert response.json()["detail"] == "Document with ID 999 not found." mock_services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=999) + +# FIX: Add a new test to explicitly check the application shutdown behavior +@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index') +@patch('app.app.ServiceContainer') +@patch('app.app.create_db_and_tables') +@patch('app.app.print_config') +@patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('faiss.read_index') +def test_shutdown_saves_index(mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index): + """ + Tests that the FAISS index is saved on application shutdown. + """ + # Arrange + mock_read_index.return_value = MagicMock() + mock_get_embedder.return_value = MagicMock() + + # Create the app and let the lifespan events run + app = create_app() + with TestClient(app) as client: + # Act: The lifespan shutdown event will run when the 'with' block is exited + pass + + # Assert + mock_save_index.assert_called_once()
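
For context, the /speech test above depends on two pieces this diff does not show: the SpeechRequest schema in app.api.schemas and the TTSService in app.core.services.tts. The sketch below is a minimal, assumed reconstruction of that contract, inferred only from the test body; the field name "text" and the method signature are the parts the tests actually pin down, everything else is illustrative.

    # Hypothetical sketch of the contract the /speech tests assume.
    # The real SpeechRequest and TTSService are not shown in this diff.
    from typing import Iterator

    from pydantic import BaseModel


    class SpeechRequest(BaseModel):
        # The test posts {"text": "Hello, world!"}, so a single required
        # string field is assumed.
        text: str


    class TTSService:
        def create_speech_stream(self, text: str) -> Iterator[bytes]:
            # A real implementation would call a TTS backend and yield
            # WAV-encoded chunks; these bytes are placeholders only.
            yield b"\x00\x01"
            yield b"\x02\x03"

Note that the test swaps this method out for an async generator even though the route comments describe a synchronous one; both work because Starlette's StreamingResponse accepts either a synchronous or an asynchronous iterable of bytes.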
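
Similarly, test_shutdown_saves_index assumes that create_app registers shutdown logic which calls FaissVectorStore.save_index. The actual app.app.create_app is not shown in this diff, so the following is only a sketch of plausible wiring using FastAPI's lifespan hook, with a stub standing in for the real vector store.

    # Minimal sketch of the shutdown behavior test_shutdown_saves_index expects.
    # The stub class and wiring are assumptions, not the project's real code.
    from contextlib import asynccontextmanager

    from fastapi import FastAPI


    class _StubVectorStore:
        """Stand-in for FaissVectorStore; only the method the test checks."""

        def save_index(self) -> None:
            print("index saved")


    def create_app() -> FastAPI:
        vector_store = _StubVectorStore()  # the real app builds a FaissVectorStore

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            yield  # the application serves requests here
            # On shutdown, persist the index so ingested documents survive restarts.
            vector_store.save_index()

        return FastAPI(lifespan=lifespan)

Because the test uses TestClient as a context manager, entering the 'with' block runs startup and leaving it runs shutdown, which is exactly what triggers save_index.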