diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
index 40860aa..220282f 100644
--- a/ai-hub/app/api/routes.py
+++ b/ai-hub/app/api/routes.py
@@ -1,11 +1,11 @@
+# Fixed routes.py
 from fastapi import APIRouter, HTTPException, Depends
-from fastapi.responses import StreamingResponse
+from fastapi.responses import Response, StreamingResponse
 from sqlalchemy.orm import Session
 from app.api.dependencies import ServiceContainer
 from app.api.dependencies import get_db
 from app.api import schemas
-from starlette.concurrency import run_in_threadpool
-
+from typing import AsyncGenerator
 
 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -26,8 +26,8 @@
     ):
         try:
             new_session = services.rag_service.create_session(
-                db=db, 
-                user_id=request.user_id, 
+                db=db,
+                user_id=request.user_id,
                 model=request.model
             )
             return new_session
@@ -58,7 +58,7 @@
             messages = services.rag_service.get_message_history(db=db, session_id=session_id)
             if messages is None:
                 raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
-            
+
             return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
         except HTTPException:
             raise
@@ -66,7 +66,7 @@
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
     # --- Document Management Endpoints ---
-    
+
     @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"])
     def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
         try:
@@ -92,7 +92,7 @@
             deleted_id = services.document_service.delete_document(db=db, document_id=document_id)
             if deleted_id is None:
                 raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.")
-            
+
             return schemas.DocumentDeleteResponse(
                 message="Document deleted successfully",
                 document_id=deleted_id
@@ -102,26 +102,50 @@
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
-    # --- TTS Endpoint ---
     @router.post(
         "/speech",
-        summary="Generate a speech stream from text",
+        summary="Generate speech from text",
         tags=["TTS"],
-        response_description="A stream of audio bytes in WAV format",
+        response_description="Audio bytes in WAV format",
     )
-    async def create_speech_stream(request: schemas.SpeechRequest):
+    async def create_speech_response(request: schemas.SpeechRequest):
         """
-        Generates an audio stream from the provided text using the TTS service.
+        Generates an audio file from the provided text using the TTS service
+        and returns it as a complete response.
         """
         try:
-            # Use run_in_threadpool to turn the synchronous generator into an
-            # async generator that StreamingResponse can handle.
-            audio_stream = await run_in_threadpool(
-                services.tts_service.create_speech_stream, text=request.text
+            # Await the coroutine that returns the complete audio data
+            audio_bytes = await services.tts_service.create_speech_non_stream(
+                text=request.text
             )
-            return StreamingResponse(audio_stream, media_type="audio/wav")
+
+            # Return a standard FastAPI Response with the complete audio bytes.
+            return Response(content=audio_bytes, media_type="audio/wav")
         except Exception as e:
+            # Catch exceptions from the TTS service
             raise HTTPException(
                 status_code=500, detail=f"Failed to generate speech: {e}"
             )
+
+    # Add a streaming endpoint as a new feature
+    @router.post(
+        "/speech/stream",
+        summary="Generate speech from text with streaming",
+        tags=["TTS"],
+        response_description="Audio bytes in WAV format (streaming)",
+    )
+    async def create_speech_stream_response(request: schemas.SpeechRequest) -> StreamingResponse:
+        """
+        Generates an audio stream from the provided text and streams it back.
+        """
+        try:
+            # The service method returns an async generator
+            audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream(
+                text=request.text
+            )
+            return StreamingResponse(audio_stream_generator, media_type="audio/wav")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Failed to stream speech: {e}")
+
     return router
\ No newline at end of file
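
A quick smoke test for the two endpoints above, as a sketch: it assumes a local dev server at http://localhost:8000 and that schemas.SpeechRequest has a single required `text` field; the host, port, and output filenames are illustrative.

    import requests

    BASE_URL = "http://localhost:8000"  # assumed local uvicorn instance

    # Non-streaming endpoint: the complete WAV arrives in one response body.
    resp = requests.post(f"{BASE_URL}/speech", json={"text": "Hello from the AI hub."})
    resp.raise_for_status()
    with open("speech.wav", "wb") as f:
        f.write(resp.content)

    # Streaming endpoint: consume the body chunk by chunk as it arrives.
    with requests.post(f"{BASE_URL}/speech/stream",
                       json={"text": "Hello again."}, stream=True) as resp:
        resp.raise_for_status()
        with open("speech_stream.wav", "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)

The same calls work under FastAPI's TestClient, so they can double as an integration test.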
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index 5c04467..65c69c9 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -5,24 +5,18 @@
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field, SecretStr
 
+# Load environment variables from .env file
 load_dotenv()
 
 # --- 1. Define the Configuration Schema ---
 
-# Define an Enum for supported embedding providers
 class EmbeddingProvider(str, Enum):
-    """
-    An enum to represent the supported embedding providers.
-    This helps in type-checking and ensures only valid providers are used.
-    """
+    """An enum for supported embedding providers."""
    GOOGLE_GENAI = "google_genai"
    MOCK = "mock"
 
-# New Enum for supported TTS providers
 class TTSProvider(str, Enum):
-    """
-    An enum to represent the supported Text-to-Speech (TTS) providers.
-    """
+    """An enum for supported Text-to-Speech (TTS) providers."""
    GOOGLE_GENAI = "google_genai"
 
 class ApplicationSettings(BaseModel):
@@ -31,26 +25,23 @@
     log_level: str = "INFO"
 
 class DatabaseSettings(BaseModel):
-    mode: str = "sqlite" # "sqlite" or "postgresql"
-    url: Optional[str] = None # Used if mode != "sqlite"
-    local_path: str = "data/ai_hub.db" # Used if mode == "sqlite"
+    mode: str = "sqlite"
+    url: Optional[str] = None
+    local_path: str = "data/ai_hub.db"
 
 class LLMProviderSettings(BaseModel):
     deepseek_model_name: str = "deepseek-chat"
     gemini_model_name: str = "gemini-1.5-flash-latest"
 
 class EmbeddingProviderSettings(BaseModel):
-    # Add a new 'provider' field to specify the embedding service
     provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI)
-    # Changed the default to match the test suite
     model_name: str = "models/text-embedding-004"
     api_key: Optional[SecretStr] = None
 
-# New settings class for TTS providers, now with a configurable model_name
 class TTSProviderSettings(BaseModel):
     provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI)
     voice_name: str = "Kore"
-    model_name: str = "gemini-2.5-flash-preview-tts" # Added configurable model name
+    model_name: str = "gemini-2.5-flash-preview-tts"
     api_key: Optional[SecretStr] = None
 
 class VectorStoreSettings(BaseModel):
@@ -58,12 +49,12 @@
     embedding_dimension: int = 768
 
 class AppConfig(BaseModel):
+    """Top-level Pydantic model for application configuration."""
     application: ApplicationSettings = Field(default_factory=ApplicationSettings)
     database: DatabaseSettings = Field(default_factory=DatabaseSettings)
     llm_providers: LLMProviderSettings = Field(default_factory=LLMProviderSettings)
     vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings)
     embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings)
-    # Add the new TTS provider settings to the main config
     tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings)
 
@@ -95,96 +86,84 @@
                                 get_from_yaml(["application", "project_name"]) or \
                                 config_from_pydantic.application.project_name
         self.VERSION: str = config_from_pydantic.application.version
-        
         self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \
-                              get_from_yaml(["application", "log_level"]) or \
-                              config_from_pydantic.application.log_level
+                        get_from_yaml(["application", "log_level"]) or \
+                        config_from_pydantic.application.log_level
 
         # --- Database Settings ---
         self.DB_MODE: str = os.getenv("DB_MODE") or \
-                            get_from_yaml(["database", "mode"]) or \
-                            config_from_pydantic.database.mode
+                      get_from_yaml(["database", "mode"]) or \
+                      config_from_pydantic.database.mode
 
-        # Get local path for SQLite, from env/yaml/pydantic
         local_db_path = os.getenv("LOCAL_DB_PATH") or \
-                        get_from_yaml(["database", "local_path"]) or \
-                        config_from_pydantic.database.local_path
-
-        # Get external DB URL, from env/yaml/pydantic
+                        get_from_yaml(["database", "local_path"]) or \
+                        config_from_pydantic.database.local_path
         external_db_url = os.getenv("DATABASE_URL") or \
-                          get_from_yaml(["database", "url"]) or \
-                          config_from_pydantic.database.url
+                          get_from_yaml(["database", "url"]) or \
+                          config_from_pydantic.database.url
 
         if self.DB_MODE == "sqlite":
-            # Ensure path does not have duplicate ./ prefix
             normalized_path = local_db_path.lstrip("./")
             self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db"
         else:
-            self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided
+            self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db"
 
-        # --- API Keys ---
+        # --- API Keys & Models ---
         self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY")
         self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
 
         self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \
                                         get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
                                         config_from_pydantic.llm_providers.deepseek_model_name
-
         self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \
-                                      get_from_yaml(["llm_providers", "gemini_model_name"]) or \
-                                      config_from_pydantic.llm_providers.gemini_model_name
+                                       get_from_yaml(["llm_providers", "gemini_model_name"]) or \
+                                       config_from_pydantic.llm_providers.gemini_model_name
 
+        # --- Vector Store Settings ---
         self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
-                                     get_from_yaml(["vector_store", "index_path"]) or \
-                                     config_from_pydantic.vector_store.index_path
-
+                                     get_from_yaml(["vector_store", "index_path"]) or \
+                                     config_from_pydantic.vector_store.index_path
         dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
                         get_from_yaml(["vector_store", "embedding_dimension"]) or \
                         config_from_pydantic.vector_store.embedding_dimension
         self.EMBEDDING_DIMENSION: int = int(dimension_str)
 
-        # New embedding provider settings
+        # --- Embedding Provider Settings ---
         embedding_provider_env = os.getenv("EMBEDDING_PROVIDER")
         if embedding_provider_env:
             embedding_provider_env = embedding_provider_env.lower()
         self.EMBEDDING_PROVIDER: EmbeddingProvider = EmbeddingProvider(embedding_provider_env or \
-            get_from_yaml(["embedding_provider", "provider"]) or \
-            config_from_pydantic.embedding_provider.provider)
-
+            get_from_yaml(["embedding_provider", "provider"]) or \
+            config_from_pydantic.embedding_provider.provider)
         self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
                                          get_from_yaml(["embedding_provider", "model_name"]) or \
                                          config_from_pydantic.embedding_provider.model_name
 
-        api_key_env = os.getenv("EMBEDDING_API_KEY")
-        api_key_yaml = get_from_yaml(["embedding_provider", "api_key"])
-        api_key_pydantic = config_from_pydantic.embedding_provider.api_key.get_secret_value() if config_from_pydantic.embedding_provider.api_key else None
-
-        self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic
+        # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
+        self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \
+            get_from_yaml(["embedding_provider", "api_key"]) or \
+            self.GEMINI_API_KEY
 
-        # --- New TTS Provider Settings ---
+        # --- TTS Provider Settings ---
         tts_provider_env = os.getenv("TTS_PROVIDER")
         if tts_provider_env:
             tts_provider_env = tts_provider_env.lower()
         self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \
-            get_from_yaml(["tts_provider", "provider"]) or \
-            config_from_pydantic.tts_provider.provider)
-
+            get_from_yaml(["tts_provider", "provider"]) or \
+            config_from_pydantic.tts_provider.provider)
         self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
-            get_from_yaml(["tts_provider", "voice_name"]) or \
-            config_from_pydantic.tts_provider.voice_name
-
-        # Added the new configurable model name
+            get_from_yaml(["tts_provider", "voice_name"]) or \
+            config_from_pydantic.tts_provider.voice_name
         self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \
-            get_from_yaml(["tts_provider", "model_name"]) or \
-            config_from_pydantic.tts_provider.model_name
+            get_from_yaml(["tts_provider", "model_name"]) or \
+            config_from_pydantic.tts_provider.model_name
 
-        tts_api_key_env = os.getenv("TTS_API_KEY")
-        tts_api_key_yaml = get_from_yaml(["tts_provider", "api_key"])
-        tts_api_key_pydantic = config_from_pydantic.tts_provider.api_key.get_secret_value() if config_from_pydantic.tts_provider.api_key else None
-
-        self.TTS_API_KEY: Optional[str] = tts_api_key_env or tts_api_key_yaml or tts_api_key_pydantic
+        # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
+        self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
+            get_from_yaml(["tts_provider", "api_key"]) or \
+            self.GEMINI_API_KEY
 
 # Instantiate the single settings object for the application
 settings = Settings()
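
Every setting above resolves in the same order: environment variable first, then the matching key in config.yaml, then the Pydantic default. The snippet below sketches the two override layers for the TTS block; the key names mirror the get_from_yaml paths in the diff, while the concrete values are only illustrative.

    # .env (highest precedence)
    TTS_PROVIDER=google_genai
    TTS_VOICE_NAME=Kore
    TTS_API_KEY=example-tts-key   # when unset, the code now falls back to GEMINI_API_KEY

    # config.yaml (consulted only when the corresponding env var is unset)
    tts_provider:
      provider: google_genai
      voice_name: Kore
      model_name: gemini-2.5-flash-preview-tts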
diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py
index 176af5d..264557c 100644
--- a/ai-hub/app/core/providers/tts/gemini.py
+++ b/ai-hub/app/core/providers/tts/gemini.py
@@ -2,9 +2,14 @@
 import aiohttp
 import asyncio
 import base64
+import logging
 from typing import AsyncGenerator
+from fastapi import HTTPException  # required: the new error paths below raise HTTPException
 from app.core.providers.base import TTSProvider
 
+# Configure logging
+logger = logging.getLogger(__name__)
+
 # New concrete class for Gemini TTS with the corrected voice list
 class GeminiTTSProvider(TTSProvider):
     # Class attribute with the corrected list of available voices
@@ -26,8 +30,11 @@
         self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
         self.voice_name = voice_name
         self.model_name = model_name
+        logger.debug(f"Initialized GeminiTTSProvider with model: {self.model_name}, voice: {self.voice_name}")
 
-    async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]:
+    async def generate_speech(self, text: str) -> bytes:
+        logger.debug(f"Starting speech generation for text: '{text[:50]}...'")
+
         headers = {
             "x-goog-api-key": self.api_key,
             "Content-Type": "application/json"
@@ -51,13 +58,33 @@
             # The model is now configurable via the instance variable
             "model": self.model_name
         }
+
+        # Do not log the headers themselves: they carry the API key.
+        logger.debug(f"API Request URL: {self.api_url}")
+        logger.debug(f"Request Payload: {json_data}")
 
-        async with aiohttp.ClientSession() as session:
-            async with session.post(self.api_url, headers=headers, json=json_data) as response:
-                response.raise_for_status()
-                response_json = await response.json()
-
-                inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
-                audio_bytes = base64.b64decode(inline_data)
-
-                yield audio_bytes
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(self.api_url, headers=headers, json=json_data) as response:
+                    logger.debug(f"Received API response with status code: {response.status}")
+                    response.raise_for_status()
+
+                    response_json = await response.json()
+                    logger.debug("Successfully parsed API response JSON.")
+
+                    inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
+                    logger.debug("Successfully extracted audio data from JSON response.")
+
+                    audio_bytes = base64.b64decode(inline_data)
+                    logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.")
+
+                    return audio_bytes
+        except aiohttp.ClientError as e:
+            logger.error(f"Aiohttp client error occurred: {e}")
+            raise HTTPException(status_code=500, detail=f"API request failed: {e}")
+        except KeyError as e:
+            logger.error(f"Key error in API response: {e}. Full response: {response_json}")
+            raise HTTPException(status_code=500, detail="Malformed API response from Gemini.")
+        except Exception as e:
+            logger.error(f"An unexpected error occurred during speech generation: {e}")
+            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}")
\ No newline at end of file
Full response: {response_json}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + except Exception as e: + logger.error(f"An unexpected error occurred during speech generation: {e}") + raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index c196795..b5887eb 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -22,4 +22,19 @@ Returns: An async generator that yields chunks of audio bytes. """ - return self.tts_provider.generate_speech(text) \ No newline at end of file + return self.tts_provider.generate_speech(text) + + async def create_speech_non_stream(self, text: str) -> bytes: + """ + Generates a complete audio file from the given text without streaming. + + Args: + text: The text to be converted to speech. + + Returns: + The complete audio file as bytes. + """ + # Awaiting the coroutine is necessary to get the result. + # The previous version was missing this 'await'. + audio_data = await self.tts_provider.generate_speech(text) + return audio_data \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 40860aa..220282f 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,11 +1,11 @@ +# Fixed routes.py from fastapi import APIRouter, HTTPException, Depends -from fastapi.responses import StreamingResponse +from fastapi.responses import Response, StreamingResponse from sqlalchemy.orm import Session from app.api.dependencies import ServiceContainer from app.api.dependencies import get_db from app.api import schemas -from starlette.concurrency import run_in_threadpool - +from typing import AsyncGenerator def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -26,8 +26,8 @@ ): try: new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, + db=db, + user_id=request.user_id, model=request.model ) return new_session @@ -58,7 +58,7 @@ messages = services.rag_service.get_message_history(db=db, session_id=session_id) if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) except HTTPException: raise @@ -66,7 +66,7 @@ raise HTTPException(status_code=500, detail=f"An error occurred: {e}") # --- Document Management Endpoints --- - + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: @@ -92,7 +92,7 @@ deleted_id = services.document_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - + return schemas.DocumentDeleteResponse( message="Document deleted successfully", document_id=deleted_id @@ -102,26 +102,50 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- TTS Endpoint --- @router.post( "/speech", - summary="Generate a speech stream from text", + summary="Generate speech from text", tags=["TTS"], - response_description="A stream of audio bytes in WAV format", + response_description="Audio bytes in WAV format", ) - async def create_speech_stream(request: schemas.SpeechRequest): + async def create_speech_response(request: 
schemas.SpeechRequest): """ - Generates an audio stream from the provided text using the TTS service. + Generates an audio file from the provided text using the TTS service + and returns it as a complete response. """ try: - # Use run_in_threadpool to turn the synchronous generator into an - # async generator that StreamingResponse can handle. - audio_stream = await run_in_threadpool( - services.tts_service.create_speech_stream, text=request.text + # Await the coroutine that returns the complete audio data + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text ) - return StreamingResponse(audio_stream, media_type="audio/wav") + + # Return a standard FastAPI Response with the complete audio bytes. + return Response(content=audio_bytes, media_type="audio/wav") + except Exception as e: + # Catch exceptions from the TTS service raise HTTPException( status_code=500, detail=f"Failed to generate speech: {e}" ) + + # Add a streaming endpoint as a new feature + @router.post( + "/speech/stream", + summary="Generate speech from text with streaming", + tags=["TTS"], + response_description="Audio bytes in WAV format (streaming)", + ) + async def create_speech_stream_response(request: schemas.SpeechRequest) -> StreamingResponse: + """ + Generates an audio stream from the provided text and streams it back. + """ + try: + # The service method returns an async generator + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to stream speech: {e}") + return router \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 5c04467..65c69c9 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -5,24 +5,18 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, SecretStr +# Load environment variables from .env file load_dotenv() # --- 1. Define the Configuration Schema --- -# Define an Enum for supported embedding providers class EmbeddingProvider(str, Enum): - """ - An enum to represent the supported embedding providers. - This helps in type-checking and ensures only valid providers are used. - """ + """An enum for supported embedding providers.""" GOOGLE_GENAI = "google_genai" MOCK = "mock" -# New Enum for supported TTS providers class TTSProvider(str, Enum): - """ - An enum to represent the supported Text-to-Speech (TTS) providers. 
- """ + """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" class ApplicationSettings(BaseModel): @@ -31,26 +25,23 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" # "sqlite" or "postgresql" - url: Optional[str] = None # Used if mode != "sqlite" - local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" + mode: str = "sqlite" + url: Optional[str] = None + local_path: str = "data/ai_hub.db" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" gemini_model_name: str = "gemini-1.5-flash-latest" class EmbeddingProviderSettings(BaseModel): - # Add a new 'provider' field to specify the embedding service provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI) - # Changed the default to match the test suite model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None -# New settings class for TTS providers, now with a configurable model_name class TTSProviderSettings(BaseModel): provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI) voice_name: str = "Kore" - model_name: str = "gemini-2.5-flash-preview-tts" # Added configurable model name + model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None class VectorStoreSettings(BaseModel): @@ -58,12 +49,12 @@ embedding_dimension: int = 768 class AppConfig(BaseModel): + """Top-level Pydantic model for application configuration.""" application: ApplicationSettings = Field(default_factory=ApplicationSettings) database: DatabaseSettings = Field(default_factory=DatabaseSettings) llm_providers: LLMProviderSettings = Field(default_factory=LLMProviderSettings) vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) - # Add the new TTS provider settings to the main config tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) @@ -95,96 +86,84 @@ get_from_yaml(["application", "project_name"]) or \ config_from_pydantic.application.project_name self.VERSION: str = config_from_pydantic.application.version - self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \ - get_from_yaml(["application", "log_level"]) or \ - config_from_pydantic.application.log_level + get_from_yaml(["application", "log_level"]) or \ + config_from_pydantic.application.log_level # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ - get_from_yaml(["database", "mode"]) or \ - config_from_pydantic.database.mode + get_from_yaml(["database", "mode"]) or \ + config_from_pydantic.database.mode - # Get local path for SQLite, from env/yaml/pydantic local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path - - # Get external DB URL, from env/yaml/pydantic + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url if self.DB_MODE == "sqlite": - # Ensure path does not have duplicate ./ prefix normalized_path = local_db_path.lstrip("./") self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided 
+ self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" - # --- API Keys --- + # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ config_from_pydantic.llm_providers.deepseek_model_name - self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name + # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path - + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ get_from_yaml(["vector_store", "embedding_dimension"]) or \ config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) - # New embedding provider settings + # --- Embedding Provider Settings --- embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() self.EMBEDDING_PROVIDER: EmbeddingProvider = EmbeddingProvider(embedding_provider_env or \ - get_from_yaml(["embedding_provider", "provider"]) or \ - config_from_pydantic.embedding_provider.provider) - + get_from_yaml(["embedding_provider", "provider"]) or \ + config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ get_from_yaml(["embedding_provider", "model_name"]) or \ config_from_pydantic.embedding_provider.model_name - api_key_env = os.getenv("EMBEDDING_API_KEY") - api_key_yaml = get_from_yaml(["embedding_provider", "api_key"]) - api_key_pydantic = config_from_pydantic.embedding_provider.api_key.get_secret_value() if config_from_pydantic.embedding_provider.api_key else None - - self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY + self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY - # --- New TTS Provider Settings --- + # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") if tts_provider_env: tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) - + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ - get_from_yaml(["tts_provider", "voice_name"]) or \ - config_from_pydantic.tts_provider.voice_name - - # Added the new configurable model name + get_from_yaml(["tts_provider", "voice_name"]) or \ + config_from_pydantic.tts_provider.voice_name self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \ - get_from_yaml(["tts_provider", "model_name"]) or \ - config_from_pydantic.tts_provider.model_name + get_from_yaml(["tts_provider", "model_name"]) or \ + 
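
Note: every Settings attribute above resolves through the same three-level fallback:
environment variable first, then the config YAML, then the Pydantic default. A minimal
sketch of that pattern (the helper name and signature are illustrative, not part of the
codebase):

    import os
    from typing import Any, Callable, List, Optional

    def resolve_setting(env_var: str,
                        yaml_path: List[str],
                        get_from_yaml: Callable[[List[str]], Optional[Any]],
                        default: Any) -> Any:
        # Environment wins, then YAML, then the Pydantic default.
        return os.getenv(env_var) or get_from_yaml(yaml_path) or default

One caveat of the `or` chain: empty strings and other falsy values fall through to the
next level, which is usually the desired behavior for unset environment variables.
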
diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py
index 176af5d..264557c 100644
--- a/ai-hub/app/core/providers/tts/gemini.py
+++ b/ai-hub/app/core/providers/tts/gemini.py
@@ -2,9 +2,14 @@
 import aiohttp
 import asyncio
 import base64
+import logging
 from typing import AsyncGenerator
+from fastapi import HTTPException
 from app.core.providers.base import TTSProvider
 
+# Configure logging
+logger = logging.getLogger(__name__)
+
 # New concrete class for Gemini TTS with the corrected voice list
 class GeminiTTSProvider(TTSProvider):
     # Class attribute with the corrected list of available voices
@@ -26,8 +31,11 @@
         self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
         self.voice_name = voice_name
         self.model_name = model_name
+        logger.debug(f"Initialized GeminiTTSProvider with model: {self.model_name}, voice: {self.voice_name}")
 
-    async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]:
+    async def generate_speech(self, text: str) -> bytes:
+        logger.debug(f"Starting speech generation for text: '{text[:50]}...'")
+
         headers = {
             "x-goog-api-key": self.api_key,
             "Content-Type": "application/json"
@@ -51,13 +59,33 @@
             # The model is now configurable via the instance variable
             "model": self.model_name
         }
+
+        logger.debug(f"API Request URL: {self.api_url}")
+        logger.debug(f"Request Headers: {headers}")
+        logger.debug(f"Request Payload: {json_data}")
 
-        async with aiohttp.ClientSession() as session:
-            async with session.post(self.api_url, headers=headers, json=json_data) as response:
-                response.raise_for_status()
-                response_json = await response.json()
-
-                inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
-                audio_bytes = base64.b64decode(inline_data)
-
-                yield audio_bytes
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(self.api_url, headers=headers, json=json_data) as response:
+                    logger.debug(f"Received API response with status code: {response.status}")
+                    response.raise_for_status()
+
+                    response_json = await response.json()
+                    logger.debug("Successfully parsed API response JSON.")
+
+                    inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
+                    logger.debug("Successfully extracted audio data from JSON response.")
+
+                    audio_bytes = base64.b64decode(inline_data)
+                    logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.")
+
+                    return audio_bytes
+        except aiohttp.ClientError as e:
+            logger.error(f"Aiohttp client error occurred: {e}")
+            raise HTTPException(status_code=500, detail=f"API request failed: {e}")
+        except KeyError as e:
+            logger.error(f"Key error in API response: {e}. Full response: {response_json}")
+            raise HTTPException(status_code=500, detail="Malformed API response from Gemini.")
+        except Exception as e:
+            logger.error(f"An unexpected error occurred during speech generation: {e}")
+            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}")
\ No newline at end of file
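
Note: the reworked provider returns the complete audio in a single await instead of
yielding one chunk, and it raises FastAPI's HTTPException on failure. The happy-path
lookup it performs on the generateContent response can be sketched in isolation
(a standalone helper for illustration, assuming the response shape shown in the hunk):

    import base64

    def extract_audio_bytes(response_json: dict) -> bytes:
        # Mirrors the provider's lookup; any missing key raises KeyError,
        # which the provider maps to an HTTP 500.
        part = response_json["candidates"][0]["content"]["parts"][0]
        return base64.b64decode(part["inlineData"]["data"])
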
Full response: {response_json}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + except Exception as e: + logger.error(f"An unexpected error occurred during speech generation: {e}") + raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index c196795..b5887eb 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -22,4 +22,19 @@ Returns: An async generator that yields chunks of audio bytes. """ - return self.tts_provider.generate_speech(text) \ No newline at end of file + return self.tts_provider.generate_speech(text) + + async def create_speech_non_stream(self, text: str) -> bytes: + """ + Generates a complete audio file from the given text without streaming. + + Args: + text: The text to be converted to speech. + + Returns: + The complete audio file as bytes. + """ + # Awaiting the coroutine is necessary to get the result. + # The previous version was missing this 'await'. + audio_data = await self.tts_provider.generate_speech(text) + return audio_data \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py index 3ff7925..effc619 100644 --- a/ai-hub/integration_tests/test_misc.py +++ b/ai-hub/integration_tests/test_misc.py @@ -12,25 +12,25 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# @pytest.mark.asyncio -# async def test_create_speech_stream(http_client): -# """ -# Tests the /speech endpoint for a successful audio stream response. -# """ -# print("\n--- Running test_create_speech_stream ---") -# url = "/speech" -# payload = {"text": "Hello, world!"} +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} -# # The `stream=True` parameter tells httpx to not read the entire response body -# # at once. We'll handle it manually to check for content. -# async with http_client.stream("POST", url, json=payload) as response: -# assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" -# assert response.headers.get("content-type") == "audio/wav" + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" -# # Check that the response body is not empty by iterating over chunks. -# content_length = 0 -# async for chunk in response.aiter_bytes(): -# content_length += len(chunk) + # Check that the response body is not empty by iterating over chunks. 
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
index 3ff7925..effc619 100644
--- a/ai-hub/integration_tests/test_misc.py
+++ b/ai-hub/integration_tests/test_misc.py
@@ -12,25 +12,25 @@
     assert response.json() == {"status": "AI Model Hub is running!"}
     print("✅ Root endpoint test passed.")
 
-# @pytest.mark.asyncio
-# async def test_create_speech_stream(http_client):
-#     """
-#     Tests the /speech endpoint for a successful audio stream response.
-#     """
-#     print("\n--- Running test_create_speech_stream ---")
-#     url = "/speech"
-#     payload = {"text": "Hello, world!"}
+@pytest.mark.asyncio
+async def test_create_speech_stream(http_client):
+    """
+    Tests the /speech endpoint for a successful audio stream response.
+    """
+    print("\n--- Running test_create_speech_stream ---")
+    url = "/speech"
+    payload = {"text": "Hello, world!"}
 
-#     # The `stream=True` parameter tells httpx to not read the entire response body
-#     # at once. We'll handle it manually to check for content.
-#     async with http_client.stream("POST", url, json=payload) as response:
-#         assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
-#         assert response.headers.get("content-type") == "audio/wav"
+    # The `stream=True` parameter tells httpx to not read the entire response body
+    # at once. We'll handle it manually to check for content.
+    async with http_client.stream("POST", url, json=payload) as response:
+        assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
+        assert response.headers.get("content-type") == "audio/wav"
 
-#     # Check that the response body is not empty by iterating over chunks.
-#     content_length = 0
-#     async for chunk in response.aiter_bytes():
-#         content_length += len(chunk)
+        # Check that the response body is not empty by iterating over chunks.
+        # The iteration must stay inside the `async with` block, since the
+        # stream is closed once the context manager exits.
+        content_length = 0
+        async for chunk in response.aiter_bytes():
+            content_length += len(chunk)
 
-#     assert content_length > 0
-#     print("✅ TTS stream test passed.")
\ No newline at end of file
+        assert content_length > 0
+    print("✅ TTS stream test passed.")
\ No newline at end of file
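
Note: the re-enabled integration test only asserts that the body is non-empty. If a
stronger check is ever wanted, a WAV payload can be sanity-checked by its RIFF header;
a small helper sketch that is not part of the test suite:

    def looks_like_wav(data: bytes) -> bool:
        # A RIFF/WAVE file starts with b"RIFF" at offset 0 and b"WAVE" at offset 8.
        return len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WAVE"
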
diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py
index 6d5cd3e..5e26ac0 100644
--- a/ai-hub/tests/api/test_routes.py
+++ b/ai-hub/tests/api/test_routes.py
@@ -1,55 +1,77 @@
-# tests/api/test_routes.py
 import pytest
 from unittest.mock import MagicMock, AsyncMock
-from fastapi import FastAPI
+from fastapi import FastAPI, Response
 from fastapi.testclient import TestClient
 from sqlalchemy.orm import Session
 from datetime import datetime
 from httpx import AsyncClient, ASGITransport
-
+import asyncio
 # Import the dependencies and router factory
 from app.api.dependencies import get_db, ServiceContainer
 from app.core.services.rag import RAGService
 from app.core.services.document import DocumentService
-from app.core.services.tts import TTSService
+from app.core.services.tts import TTSService
 from app.api.routes import create_api_router
-from app.db import models # Import your SQLAlchemy models
+from app.db import models
 
 @pytest.fixture
 def client():
     """
-    Pytest fixture to create a TestClient with a fully mocked environment,
-    including a mock ServiceContainer.
+    Pytest fixture to create a TestClient with a fully mocked environment
+    for synchronous endpoints.
     """
     test_app = FastAPI()
-
-    # Mock individual services
+
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
-
-    # Use AsyncMock for the TTS service since its methods are async
     mock_tts_service = MagicMock(spec=TTSService)
-
-    # Create a mock ServiceContainer that holds the mocked services
+
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
-
-    # Mock the database session
+
     mock_db_session = MagicMock(spec=Session)
 
     def override_get_db():
         yield mock_db_session
 
-    # Pass the mock ServiceContainer to the router factory
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)
 
-    # Return the test client and the mock services for assertion
-    yield TestClient(test_app), mock_services
+    test_client = TestClient(test_app)
+
+    yield test_client, mock_services
+
+@pytest.fixture
+async def async_client():
+    """
+    Pytest fixture to create an AsyncClient for testing async endpoints.
+    """
+    test_app = FastAPI()
+
+    mock_rag_service = MagicMock(spec=RAGService)
+    mock_document_service = MagicMock(spec=DocumentService)
+    mock_tts_service = MagicMock(spec=TTSService)
+
+    mock_services = MagicMock(spec=ServiceContainer)
+    mock_services.rag_service = mock_rag_service
+    mock_services.document_service = mock_document_service
+    mock_services.tts_service = mock_tts_service
+
+    mock_db_session = MagicMock(spec=Session)
+
+    def override_get_db():
+        yield mock_db_session
+
+    api_router = create_api_router(services=mock_services)
+    test_app.dependency_overrides[get_db] = override_get_db
+    test_app.include_router(api_router)
+
+    async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
+        yield client, mock_services
 
 # --- General Endpoint ---
@@ -67,9 +89,9 @@
     test_client, mock_services = client
     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
     mock_services.rag_service.create_session.return_value = mock_session
-
+
     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
-
+
     assert response.status_code == 200
     assert response.json()["id"] == 1
     mock_services.rag_service.create_session.assert_called_once()
@@ -81,12 +103,12 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))
-
+
     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
-
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
-
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -101,12 +123,12 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
-
+
     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
-
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
-
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -121,15 +143,15 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))
-
+
     response = test_client.post(
         "/sessions/42/chat",
         json={"prompt": "What is RAG?", "load_faiss_retriever": True}
     )
-
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
-
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -146,9 +168,9 @@
         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
     ]
     mock_services.rag_service.get_message_history.return_value = mock_history
-
+
     response = test_client.get("/sessions/123/messages")
-
+
     assert response.status_code == 200
     response_data = response.json()
     assert response_data["session_id"] == 123
@@ -156,7 +178,7 @@
     assert response_data["messages"][0]["sender"] == "user"
     assert response_data["messages"][1]["content"] == "Hi there!"
     mock_services.rag_service.get_message_history.assert_called_once_with(
-        db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
+        db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
         session_id=123
     )
@@ -164,9 +186,9 @@
     """Tests retrieving messages for a session that does not exist."""
     test_client, mock_services = client
     mock_services.rag_service.get_message_history.return_value = None
-
+
     response = test_client.get("/sessions/999/messages")
-
+
     assert response.status_code == 404
     assert response.json()["detail"] == "Session with ID 999 not found."
@@ -207,34 +229,49 @@
     mock_services.document_service.delete_document.return_value = None
     response = test_client.delete("/documents/999")
     assert response.status_code == 404
-
-# --- TTS Endpoint ---
-@pytest.mark.anyio
-async def test_create_speech_stream_success(client):
-    """
-    Tests the /speech endpoint to ensure it can successfully generate an audio stream.
-    """
-    test_client, mock_services = client
-    app = test_client.app  # Get the FastAPI app from the TestClient
-
-    # Arrange: Define the text to convert and mock the service's response.
-    text_to_speak = "Hello, world!"
-
-    # Define the async generator
-    async def mock_audio_generator():
-        yield b'chunk1'
-        yield b'chunk2'
-        yield b'chunk3'
-
-    # Properly mock the method to return the generator
-    mock_services.tts_service.create_speech_stream = lambda text: mock_audio_generator()
-
-    # Use AsyncClient with ASGITransport to send request to the FastAPI app
-    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
-        response = await ac.post("/speech", json={"text": text_to_speak})
 
-    # Assert: Check status code and content
+@pytest.mark.asyncio
+async def test_create_speech_response(async_client):
+    """Test the /speech endpoint returns audio bytes."""
+    test_client, mock_services = await anext(async_client)
+    mock_audio_bytes = b"fake wav audio bytes"
 
+    # The route handler calls `create_speech_non_stream`, not `create_speech_stream`.
+    # It's an async function, so we must use AsyncMock.
+    mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)
+
+    response = await test_client.post("/speech", json={"text": "Hello, this is a test"})
+
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
-    assert response.content == b"chunk1chunk2chunk3"
+    assert response.content == mock_audio_bytes
+
+    mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")
+
+# New test to cover the streaming endpoint
+@pytest.mark.asyncio
+async def test_create_speech_stream_response(async_client):
+    """Test the new /speech/stream endpoint returns a streaming response."""
+    test_client, mock_services = await anext(async_client)
+    mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]
+
+    # This async generator mock correctly simulates the streaming service
+    async def mock_async_generator():
+        for chunk in mock_audio_bytes_chunks:
+            yield chunk
+
+    # We mock `create_speech_stream` with a MagicMock returning the async generator
+    mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())
+
+    response = await test_client.post("/speech/stream", json={"text": "Hello, this is a test"})
+
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/wav"
+
+    # Read the streamed content and verify it matches the mocked chunks
+    streamed_content = b""
+    async for chunk in response.aiter_bytes():
+        streamed_content += chunk
+
+    assert streamed_content == b"".join(mock_audio_bytes_chunks)
+    mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test")
\ No newline at end of file
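
Note: async_client is written as a plain async-generator fixture, so each test unwraps
it manually with await anext(...), and the teardown (closing the AsyncClient) is not
guaranteed to run. With pytest-asyncio's own fixture decorator, the plugin drives the
generator itself; a sketch of that alternative, assuming pytest-asyncio is installed
(build_mocked_app is a hypothetical helper standing in for the setup shown above):

    import pytest_asyncio
    from httpx import ASGITransport, AsyncClient

    @pytest_asyncio.fixture
    async def async_client():
        test_app, mock_services = build_mocked_app()  # hypothetical helper
        async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
            # Tests receive the yielded tuple directly, no anext() needed,
            # and the client is closed on teardown.
            yield client, mock_services
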
+ return Response(content=audio_bytes, media_type="audio/wav") + except Exception as e: + # Catch exceptions from the TTS service raise HTTPException( status_code=500, detail=f"Failed to generate speech: {e}" ) + + # Add a streaming endpoint as a new feature + @router.post( + "/speech/stream", + summary="Generate speech from text with streaming", + tags=["TTS"], + response_description="Audio bytes in WAV format (streaming)", + ) + async def create_speech_stream_response(request: schemas.SpeechRequest) -> StreamingResponse: + """ + Generates an audio stream from the provided text and streams it back. + """ + try: + # The service method returns an async generator + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to stream speech: {e}") + return router \ No newline at end of file diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 5c04467..65c69c9 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -5,24 +5,18 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, SecretStr +# Load environment variables from .env file load_dotenv() # --- 1. Define the Configuration Schema --- -# Define an Enum for supported embedding providers class EmbeddingProvider(str, Enum): - """ - An enum to represent the supported embedding providers. - This helps in type-checking and ensures only valid providers are used. - """ + """An enum for supported embedding providers.""" GOOGLE_GENAI = "google_genai" MOCK = "mock" -# New Enum for supported TTS providers class TTSProvider(str, Enum): - """ - An enum to represent the supported Text-to-Speech (TTS) providers. 
- """ + """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" class ApplicationSettings(BaseModel): @@ -31,26 +25,23 @@ log_level: str = "INFO" class DatabaseSettings(BaseModel): - mode: str = "sqlite" # "sqlite" or "postgresql" - url: Optional[str] = None # Used if mode != "sqlite" - local_path: str = "data/ai_hub.db" # Used if mode == "sqlite" + mode: str = "sqlite" + url: Optional[str] = None + local_path: str = "data/ai_hub.db" class LLMProviderSettings(BaseModel): deepseek_model_name: str = "deepseek-chat" gemini_model_name: str = "gemini-1.5-flash-latest" class EmbeddingProviderSettings(BaseModel): - # Add a new 'provider' field to specify the embedding service provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI) - # Changed the default to match the test suite model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None -# New settings class for TTS providers, now with a configurable model_name class TTSProviderSettings(BaseModel): provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI) voice_name: str = "Kore" - model_name: str = "gemini-2.5-flash-preview-tts" # Added configurable model name + model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None class VectorStoreSettings(BaseModel): @@ -58,12 +49,12 @@ embedding_dimension: int = 768 class AppConfig(BaseModel): + """Top-level Pydantic model for application configuration.""" application: ApplicationSettings = Field(default_factory=ApplicationSettings) database: DatabaseSettings = Field(default_factory=DatabaseSettings) llm_providers: LLMProviderSettings = Field(default_factory=LLMProviderSettings) vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) - # Add the new TTS provider settings to the main config tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) @@ -95,96 +86,84 @@ get_from_yaml(["application", "project_name"]) or \ config_from_pydantic.application.project_name self.VERSION: str = config_from_pydantic.application.version - self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \ - get_from_yaml(["application", "log_level"]) or \ - config_from_pydantic.application.log_level + get_from_yaml(["application", "log_level"]) or \ + config_from_pydantic.application.log_level # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ - get_from_yaml(["database", "mode"]) or \ - config_from_pydantic.database.mode + get_from_yaml(["database", "mode"]) or \ + config_from_pydantic.database.mode - # Get local path for SQLite, from env/yaml/pydantic local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path - - # Get external DB URL, from env/yaml/pydantic + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ - get_from_yaml(["database", "url"]) or \ - config_from_pydantic.database.url + get_from_yaml(["database", "url"]) or \ + config_from_pydantic.database.url if self.DB_MODE == "sqlite": - # Ensure path does not have duplicate ./ prefix normalized_path = local_db_path.lstrip("./") self.DATABASE_URL: str = f"sqlite:///./{normalized_path}" if normalized_path else "sqlite:///./data/ai_hub.db" else: - self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" # fallback if no URL provided 
+ self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" - # --- API Keys --- + # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ config_from_pydantic.llm_providers.deepseek_model_name - self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name + # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path - + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ get_from_yaml(["vector_store", "embedding_dimension"]) or \ config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) - # New embedding provider settings + # --- Embedding Provider Settings --- embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") if embedding_provider_env: embedding_provider_env = embedding_provider_env.lower() self.EMBEDDING_PROVIDER: EmbeddingProvider = EmbeddingProvider(embedding_provider_env or \ - get_from_yaml(["embedding_provider", "provider"]) or \ - config_from_pydantic.embedding_provider.provider) - + get_from_yaml(["embedding_provider", "provider"]) or \ + config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ get_from_yaml(["embedding_provider", "model_name"]) or \ config_from_pydantic.embedding_provider.model_name - api_key_env = os.getenv("EMBEDDING_API_KEY") - api_key_yaml = get_from_yaml(["embedding_provider", "api_key"]) - api_key_pydantic = config_from_pydantic.embedding_provider.api_key.get_secret_value() if config_from_pydantic.embedding_provider.api_key else None - - self.EMBEDDING_API_KEY: Optional[str] = api_key_env or api_key_yaml or api_key_pydantic + # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY + self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY - # --- New TTS Provider Settings --- + # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") if tts_provider_env: tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) - + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ - get_from_yaml(["tts_provider", "voice_name"]) or \ - config_from_pydantic.tts_provider.voice_name - - # Added the new configurable model name + get_from_yaml(["tts_provider", "voice_name"]) or \ + config_from_pydantic.tts_provider.voice_name self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \ - get_from_yaml(["tts_provider", "model_name"]) or \ - config_from_pydantic.tts_provider.model_name + get_from_yaml(["tts_provider", "model_name"]) or \ + 
config_from_pydantic.tts_provider.model_name - tts_api_key_env = os.getenv("TTS_API_KEY") - tts_api_key_yaml = get_from_yaml(["tts_provider", "api_key"]) - tts_api_key_pydantic = config_from_pydantic.tts_provider.api_key.get_secret_value() if config_from_pydantic.tts_provider.api_key else None - - self.TTS_API_KEY: Optional[str] = tts_api_key_env or tts_api_key_yaml or tts_api_key_pydantic + # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY + self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ + get_from_yaml(["tts_provider", "api_key"]) or \ + self.GEMINI_API_KEY # Instantiate the single settings object for the application settings = Settings() diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py index 176af5d..264557c 100644 --- a/ai-hub/app/core/providers/tts/gemini.py +++ b/ai-hub/app/core/providers/tts/gemini.py @@ -2,9 +2,13 @@ import aiohttp import asyncio import base64 +import logging from typing import AsyncGenerator from app.core.providers.base import TTSProvider +# Configure logging +logger = logging.getLogger(__name__) + # New concrete class for Gemini TTS with the corrected voice list class GeminiTTSProvider(TTSProvider): # Class attribute with the corrected list of available voices @@ -26,8 +30,11 @@ self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" self.voice_name = voice_name self.model_name = model_name + logger.debug(f"Initialized GeminiTTSProvider with model: {self.model_name}, voice: {self.voice_name}") - async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]: + async def generate_speech(self, text: str) -> bytes: + logger.debug(f"Starting speech generation for text: '{text[:50]}...'") + headers = { "x-goog-api-key": self.api_key, "Content-Type": "application/json" @@ -51,13 +58,33 @@ # The model is now configurable via the instance variable "model": self.model_name } + + logger.debug(f"API Request URL: {self.api_url}") + logger.debug(f"Request Headers: {headers}") + logger.debug(f"Request Payload: {json_data}") - async with aiohttp.ClientSession() as session: - async with session.post(self.api_url, headers=headers, json=json_data) as response: - response.raise_for_status() - response_json = await response.json() - - inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data'] - audio_bytes = base64.b64decode(inline_data) - - yield audio_bytes + try: + async with aiohttp.ClientSession() as session: + async with session.post(self.api_url, headers=headers, json=json_data) as response: + logger.debug(f"Received API response with status code: {response.status}") + response.raise_for_status() + + response_json = await response.json() + logger.debug("Successfully parsed API response JSON.") + + inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data'] + logger.debug("Successfully extracted audio data from JSON response.") + + audio_bytes = base64.b64decode(inline_data) + logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.") + + return audio_bytes + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except KeyError as e: + logger.error(f"Key error in API response: {e}. 
Full response: {response_json}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + except Exception as e: + logger.error(f"An unexpected error occurred during speech generation: {e}") + raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") \ No newline at end of file diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index c196795..b5887eb 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -22,4 +22,19 @@ Returns: An async generator that yields chunks of audio bytes. """ - return self.tts_provider.generate_speech(text) \ No newline at end of file + return self.tts_provider.generate_speech(text) + + async def create_speech_non_stream(self, text: str) -> bytes: + """ + Generates a complete audio file from the given text without streaming. + + Args: + text: The text to be converted to speech. + + Returns: + The complete audio file as bytes. + """ + # Awaiting the coroutine is necessary to get the result. + # The previous version was missing this 'await'. + audio_data = await self.tts_provider.generate_speech(text) + return audio_data \ No newline at end of file diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py index 3ff7925..effc619 100644 --- a/ai-hub/integration_tests/test_misc.py +++ b/ai-hub/integration_tests/test_misc.py @@ -12,25 +12,25 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# @pytest.mark.asyncio -# async def test_create_speech_stream(http_client): -# """ -# Tests the /speech endpoint for a successful audio stream response. -# """ -# print("\n--- Running test_create_speech_stream ---") -# url = "/speech" -# payload = {"text": "Hello, world!"} +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} -# # The `stream=True` parameter tells httpx to not read the entire response body -# # at once. We'll handle it manually to check for content. -# async with http_client.stream("POST", url, json=payload) as response: -# assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" -# assert response.headers.get("content-type") == "audio/wav" + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" -# # Check that the response body is not empty by iterating over chunks. -# content_length = 0 -# async for chunk in response.aiter_bytes(): -# content_length += len(chunk) + # Check that the response body is not empty by iterating over chunks. 
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
index 3ff7925..effc619 100644
--- a/ai-hub/integration_tests/test_misc.py
+++ b/ai-hub/integration_tests/test_misc.py
@@ -12,25 +12,25 @@
     assert response.json() == {"status": "AI Model Hub is running!"}
     print("✅ Root endpoint test passed.")

-# @pytest.mark.asyncio
-# async def test_create_speech_stream(http_client):
-#     """
-#     Tests the /speech endpoint for a successful audio stream response.
-#     """
-#     print("\n--- Running test_create_speech_stream ---")
-#     url = "/speech"
-#     payload = {"text": "Hello, world!"}
+@pytest.mark.asyncio
+async def test_create_speech_stream(http_client):
+    """
+    Tests the /speech endpoint for a successful audio response.
+    """
+    print("\n--- Running test_create_speech_stream ---")
+    url = "/speech"
+    payload = {"text": "Hello, world!"}

-#     # The `stream=True` parameter tells httpx to not read the entire response body
-#     # at once. We'll handle it manually to check for content.
-#     async with http_client.stream("POST", url, json=payload) as response:
-#         assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
-#         assert response.headers.get("content-type") == "audio/wav"
+    # Using client.stream() tells httpx not to read the entire response body
+    # at once. We'll handle it manually to check for content.
+    async with http_client.stream("POST", url, json=payload) as response:
+        # The body has not been read yet, so the failure message reports the
+        # status code rather than response.text (which would raise here).
+        assert response.status_code == 200, f"Speech request failed with status {response.status_code}"
+        assert response.headers.get("content-type") == "audio/wav"

-#         # Check that the response body is not empty by iterating over chunks.
-#         content_length = 0
-#         async for chunk in response.aiter_bytes():
-#             content_length += len(chunk)
+        # Check that the response body is not empty by iterating over chunks.
+        content_length = 0
+        async for chunk in response.aiter_bytes():
+            content_length += len(chunk)

-#         assert content_length > 0
-#         print("✅ TTS stream test passed.")
\ No newline at end of file
+    assert content_length > 0
+    print("✅ TTS stream test passed.")
\ No newline at end of file
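The http_client fixture used above is not defined anywhere in this patch. A plausible shape for it, offered purely as an assumption about the project's integration-test conftest (the base URL and port are guesses):

# Hypothetical conftest.py sketch for the integration tests; assumes the
# server is already running locally on port 8000.
import pytest_asyncio
from httpx import AsyncClient

@pytest_asyncio.fixture
async def http_client():
    async with AsyncClient(base_url="http://127.0.0.1:8000") as client:
        yield client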
+ """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services # --- General Endpoint --- @@ -67,9 +89,9 @@ test_client, mock_services = client mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) mock_services.rag_service.create_session.return_value = mock_session - + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) - + assert response.status_code == 200 assert response.json()["id"] == 1 mock_services.rag_service.create_session.assert_called_once() @@ -81,12 +103,12 @@ """ test_client, mock_services = client mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) - + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - + assert response.status_code == 200 assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - + mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, @@ -101,12 +123,12 @@ """ test_client, mock_services = client mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) - + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - + assert response.status_code == 200 assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} - + mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, @@ -121,15 +143,15 @@ """ test_client, mock_services = client mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) - + response = test_client.post( "/sessions/42/chat", json={"prompt": "What is RAG?", "load_faiss_retriever": True} ) - + assert response.status_code == 200 assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} - + mock_services.rag_service.chat_with_rag.assert_called_once_with( db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], session_id=42, @@ -146,9 +168,9 @@ models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) ] mock_services.rag_service.get_message_history.return_value = mock_history - + response = test_client.get("/sessions/123/messages") - + assert response.status_code == 200 response_data = response.json() assert response_data["session_id"] == 123 @@ -156,7 +178,7 @@ assert response_data["messages"][0]["sender"] == "user" assert response_data["messages"][1]["content"] == "Hi there!" 
 # --- General Endpoint ---
@@ -67,9 +89,9 @@
     test_client, mock_services = client
     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
     mock_services.rag_service.create_session.return_value = mock_session
-    
+
     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
-    
+
     assert response.status_code == 200
     assert response.json()["id"] == 1
     mock_services.rag_service.create_session.assert_called_once()
@@ -81,12 +103,12 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))
-    
+
     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
-    
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
-    
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -101,12 +123,12 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
-    
+
     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
-    
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
-    
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -121,15 +143,15 @@
     """
     test_client, mock_services = client
     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))
-    
+
     response = test_client.post(
         "/sessions/42/chat",
         json={"prompt": "What is RAG?", "load_faiss_retriever": True}
     )
-    
+
     assert response.status_code == 200
     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
-    
+
     mock_services.rag_service.chat_with_rag.assert_called_once_with(
         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
         session_id=42,
@@ -146,9 +168,9 @@
         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
     ]
     mock_services.rag_service.get_message_history.return_value = mock_history
-    
+
     response = test_client.get("/sessions/123/messages")
-    
+
     assert response.status_code == 200
     response_data = response.json()
     assert response_data["session_id"] == 123
@@ -156,7 +178,7 @@
     assert response_data["messages"][0]["sender"] == "user"
     assert response_data["messages"][1]["content"] == "Hi there!"
     mock_services.rag_service.get_message_history.assert_called_once_with(
-        db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], 
+        db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
         session_id=123
     )

@@ -164,9 +186,9 @@
     """Tests retrieving messages for a session that does not exist."""
     test_client, mock_services = client
     mock_services.rag_service.get_message_history.return_value = None
-    
+
     response = test_client.get("/sessions/999/messages")
-    
+
     assert response.status_code == 404
     assert response.json()["detail"] == "Session with ID 999 not found."

@@ -207,34 +229,49 @@
     mock_services.document_service.delete_document.return_value = None
     response = test_client.delete("/documents/999")
     assert response.status_code == 404
-
-# --- TTS Endpoint ---
-@pytest.mark.anyio
-async def test_create_speech_stream_success(client):
-    """
-    Tests the /speech endpoint to ensure it can successfully generate an audio stream.
-    """
-    test_client, mock_services = client
-    app = test_client.app  # Get the FastAPI app from the TestClient
-
-    # Arrange: Define the text to convert and mock the service's response.
-    text_to_speak = "Hello, world!"
-
-    # Define the async generator
-    async def mock_audio_generator():
-        yield b'chunk1'
-        yield b'chunk2'
-        yield b'chunk3'
-
-    # Properly mock the method to return the generator
-    mock_services.tts_service.create_speech_stream = lambda text: mock_audio_generator()
-
-    # Use AsyncClient with ASGITransport to send request to the FastAPI app
-    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
-        response = await ac.post("/speech", json={"text": text_to_speak})
+
+@pytest.mark.asyncio
+async def test_create_speech_response(async_client):
+    """Test the /speech endpoint returns audio bytes."""
+    test_client, mock_services = async_client
+    mock_audio_bytes = b"fake wav audio bytes"

-    # Assert: Check status code and content
+    # The route handler calls `create_speech_non_stream`, not `create_speech_stream`.
+    # It is an async method, so we must mock it with AsyncMock.
+    mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)
+
+    response = await test_client.post("/speech", json={"text": "Hello, this is a test"})
+
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
-    assert response.content == b"chunk1chunk2chunk3"
\ No newline at end of file
+    assert response.content == mock_audio_bytes
+
+    mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")
+
+# New test to cover the streaming endpoint
+@pytest.mark.asyncio
+async def test_create_speech_stream_response(async_client):
+    """Test the new /speech/stream endpoint returns a streaming response."""
+    test_client, mock_services = async_client
+    mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]
+
+    # This async generator correctly simulates the streaming service
+    async def mock_async_generator():
+        for chunk in mock_audio_bytes_chunks:
+            yield chunk
+
+    # Mock `create_speech_stream` with a MagicMock returning the async generator
+    mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())
+
+    response = await test_client.post("/speech/stream", json={"text": "Hello, this is a test"})
+
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/wav"
+
+    # Read the streamed content and verify it matches the mocked chunks
+    streamed_content = b""
+    async for chunk in response.aiter_bytes():
+        streamed_content += chunk
+
+    assert streamed_content == b"".join(mock_audio_bytes_chunks)
+    mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test")
\ No newline at end of file
diff --git a/ai-hub/tests/core/providers/tts/test_gemini.py b/ai-hub/tests/core/providers/tts/test_gemini.py
index 47023ed..5a5e1b2 100644
--- a/ai-hub/tests/core/providers/tts/test_gemini.py
+++ b/ai-hub/tests/core/providers/tts/test_gemini.py
@@ -1,3 +1,4 @@
+# Fixed test file
 import pytest
 import aiohttp
 import asyncio
@@ -14,11 +15,11 @@
     api_key = "test_api_key"
     text_to_speak = "Hello, world!"
     model_name = "gemini-2.5-flash-preview-tts"
-    
+
     # Create a dummy base64 encoded audio response
     dummy_audio_bytes = b"This is a test audio stream."
     dummy_base64_data = base64.b64encode(dummy_audio_bytes).decode('utf-8')
-    
+
     # The mocked JSON response from the API
     mock_response_json = {
         "candidates": [{
@@ -31,7 +32,7 @@
             }
         }]
     }
-    
+
     # Configure aioresponses to intercept the API call and return our mock data
     tts_provider = GeminiTTSProvider(api_key=api_key, model_name=model_name)
     with aioresponses() as m:
@@ -41,16 +42,12 @@
             payload=mock_response_json,
             repeat=True
         )
-
-        # Call the method under test
-        audio_stream = tts_provider.generate_speech(text_to_speak)
-
-        # Iterate through the async generator to get the data
-        audio_chunks = [chunk async for chunk in audio_stream]
-
-        # Assert that the list of chunks is not empty and contains the expected data
-        assert len(audio_chunks) == 1
-        assert audio_chunks[0] == dummy_audio_bytes
+
+        # Call the method under test, now awaiting the coroutine
+        audio_data = await tts_provider.generate_speech(text_to_speak)
+
+        # Assert that the returned data is correct
+        assert audio_data == dummy_audio_bytes

 def test_init_with_valid_voice_name():
     """
@@ -81,5 +78,4 @@
     custom_model_name = "gemini-tts-beta"
     tts_provider = GeminiTTSProvider(api_key=api_key, model_name=custom_model_name)
     assert tts_provider.model_name == custom_model_name
-    assert custom_model_name in tts_provider.api_url
-
+    assert custom_model_name in tts_provider.api_url
\ No newline at end of file
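Since the provider now surfaces malformed payloads as an HTTPException, a companion test for that error path could look like the following sketch. It reuses the same aioresponses pattern as the test above; the test itself is not part of the patch:

# Hedged sketch: asserts the KeyError branch surfaces as a 500 HTTPException.
import pytest
from aioresponses import aioresponses
from fastapi import HTTPException
from app.core.providers.tts.gemini import GeminiTTSProvider

@pytest.mark.asyncio
async def test_generate_speech_malformed_response():
    provider = GeminiTTSProvider(api_key="test_api_key", model_name="gemini-2.5-flash-preview-tts")
    with aioresponses() as m:
        # Payload is missing the candidates/content/parts structure entirely.
        m.post(provider.api_url, payload={"unexpected": "shape"}, repeat=True)
        with pytest.raises(HTTPException) as exc_info:
            await provider.generate_speech("Hello, world!")
    assert exc_info.value.status_code == 500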
diff --git a/ai-hub/tests/test_config.py b/ai-hub/tests/test_config.py
index 4ded74f..f3835f4 100644
--- a/ai-hub/tests/test_config.py
+++ b/ai-hub/tests/test_config.py
@@ -178,7 +178,10 @@
 def test_tts_settings_defaults(monkeypatch, clear_all_env):
     """Tests that TTS settings fall back to Pydantic defaults if no env or YAML are present."""
     monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
-    monkeypatch.setenv("GEMINI_API_KEY", "mock_key")
+    # GEMINI_API_KEY is deliberately no longer set here: the clear_all_env
+    # fixture already guarantees a clean environment, so Settings() falls
+    # back to None for the TTS API key and the defaults can be asserted.
+
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
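The clear_all_env fixture is referenced here but never shown in the patch. A plausible implementation, offered only as an assumption about the project's test conftest (the exact variable list is a guess):

# Hypothetical sketch of the clear_all_env fixture: strips the
# configuration-related variables so the fallback chain is exercised.
import pytest

@pytest.fixture
def clear_all_env(monkeypatch):
    for var in (
        "TTS_API_KEY", "TTS_PROVIDER", "TTS_VOICE_NAME", "TTS_MODEL_NAME",
        "GEMINI_API_KEY", "DEEPSEEK_API_KEY", "EMBEDDING_API_KEY",
    ):
        monkeypatch.delenv(var, raising=False)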