diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index ac5cbb5..5c04467 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -46,10 +46,11 @@
     model_name: str = "models/text-embedding-004"
     api_key: Optional[SecretStr] = None
 
-# New settings class for TTS providers
+# New settings class for TTS providers, now with a configurable model_name
 class TTSProviderSettings(BaseModel):
     provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI)
     voice_name: str = "Kore"
+    model_name: str = "gemini-2.5-flash-preview-tts"  # Added configurable model name
     api_key: Optional[SecretStr] = None
 
 class VectorStoreSettings(BaseModel):
@@ -96,13 +97,13 @@
         self.VERSION: str = config_from_pydantic.application.version
 
         self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \
-            get_from_yaml(["application", "log_level"]) or \
-            config_from_pydantic.application.log_level
+            get_from_yaml(["application", "log_level"]) or \
+            config_from_pydantic.application.log_level
 
         # --- Database Settings ---
         self.DB_MODE: str = os.getenv("DB_MODE") or \
-            get_from_yaml(["database", "mode"]) or \
-            config_from_pydantic.database.mode
+            get_from_yaml(["database", "mode"]) or \
+            config_from_pydantic.database.mode
 
         # Get local path for SQLite, from env/yaml/pydantic
         local_db_path = os.getenv("LOCAL_DB_PATH") or \
@@ -130,13 +131,13 @@
             config_from_pydantic.llm_providers.deepseek_model_name
 
         self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "gemini_model_name"]) or \
-            config_from_pydantic.llm_providers.gemini_model_name
+            get_from_yaml(["llm_providers", "gemini_model_name"]) or \
+            config_from_pydantic.llm_providers.gemini_model_name
 
         self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
-            get_from_yaml(["vector_store", "index_path"]) or \
-            config_from_pydantic.vector_store.index_path
-
+            get_from_yaml(["vector_store", "index_path"]) or \
+            config_from_pydantic.vector_store.index_path
+
         dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
             get_from_yaml(["vector_store", "embedding_dimension"]) or \
             config_from_pydantic.vector_store.embedding_dimension
@@ -148,8 +149,8 @@
             embedding_provider_env = embedding_provider_env.lower()
 
         self.EMBEDDING_PROVIDER: EmbeddingProvider = EmbeddingProvider(embedding_provider_env or \
-            get_from_yaml(["embedding_provider", "provider"]) or \
-            config_from_pydantic.embedding_provider.provider)
+            get_from_yaml(["embedding_provider", "provider"]) or \
+            config_from_pydantic.embedding_provider.provider)
 
         self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
             get_from_yaml(["embedding_provider", "model_name"]) or \
@@ -167,12 +168,17 @@
             tts_provider_env = tts_provider_env.lower()
 
         self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \
-            get_from_yaml(["tts_provider", "provider"]) or \
-            config_from_pydantic.tts_provider.provider)
-
+            get_from_yaml(["tts_provider", "provider"]) or \
+            config_from_pydantic.tts_provider.provider)
+
         self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
-            get_from_yaml(["tts_provider", "voice_name"]) or \
-            config_from_pydantic.tts_provider.voice_name
+            get_from_yaml(["tts_provider", "voice_name"]) or \
+            config_from_pydantic.tts_provider.voice_name
+
+        # Added the new configurable model name
+        self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \
+            get_from_yaml(["tts_provider", "model_name"]) or \
+            config_from_pydantic.tts_provider.model_name
 
         tts_api_key_env = os.getenv("TTS_API_KEY")
         tts_api_key_yaml = get_from_yaml(["tts_provider", "api_key"])
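
For reference, TTS_MODEL_NAME resolves the same way as every other setting in Settings.__init__: environment variable first, then config.yaml, then the Pydantic default. A minimal, self-contained sketch of that chain (the get_from_yaml helper below is a stand-in for the one defined in config.py, not the real implementation):

```python
import os
from typing import Any, List, Optional

# Stand-in for the parsed config.yaml; in config.py this comes from the YAML loader.
_yaml_config: dict = {"tts_provider": {"model_name": "gemini-2.5-flash-preview-tts"}}

def get_from_yaml(path: List[str]) -> Optional[Any]:
    """Walk the parsed YAML mapping, returning None if any key along the path is missing."""
    node: Any = _yaml_config
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node

# Same precedence used throughout Settings.__init__: env var > YAML > Pydantic default.
PYDANTIC_DEFAULT_TTS_MODEL = "gemini-2.5-flash-preview-tts"
tts_model_name = (
    os.getenv("TTS_MODEL_NAME")
    or get_from_yaml(["tts_provider", "model_name"])
    or PYDANTIC_DEFAULT_TTS_MODEL
)
print(tts_model_name)  # whichever source is set first wins
```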
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index 2f0c0aa..49b30f7 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -36,4 +36,6 @@
   # The provider for the TTS service.
   provider: "google_genai"
   # The name of the voice to use for TTS.
-  voice_name: "Kore"
\ No newline at end of file
+  voice_name: "Kore"
+  # The model name for the TTS service.
+  model_name: "gemini-2.5-flash-preview-tts"
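
As a quick sanity check that the new key parses where config.py expects it (under a top-level tts_provider mapping), the file can be read back with PyYAML. This is only a sketch; the path assumes the repository root as the working directory:

```python
import yaml  # PyYAML

# Path as it appears in this repo; adjust if your working directory differs.
with open("ai-hub/app/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

tts = cfg.get("tts_provider", {})
print(tts.get("voice_name"))  # expected: "Kore"
print(tts.get("model_name"))  # expected: "gemini-2.5-flash-preview-tts"
```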
diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py
index 98b191f..176af5d 100644
--- a/ai-hub/app/core/providers/tts/gemini.py
+++ b/ai-hub/app/core/providers/tts/gemini.py
@@ -17,13 +17,15 @@
         "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
     ]
 
-    def __init__(self, api_key: str, voice_name: str = "Kore"):
+    def __init__(self, api_key: str, voice_name: str = "Kore", model_name: str = "gemini-2.5-flash-preview-tts"):
         if voice_name not in self.AVAILABLE_VOICES:
             raise ValueError(f"Invalid voice name: {voice_name}. Choose from {self.AVAILABLE_VOICES}")
 
         self.api_key = api_key
-        self.api_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent"
+        # The API URL is now a f-string that includes the configurable model name
+        self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
         self.voice_name = voice_name
+        self.model_name = model_name
 
     async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]:
         headers = {
@@ -46,7 +48,8 @@
                     }
                 }
             },
-            "model": "gemini-2.5-flash-preview-tts"
+            # The model is now configurable via the instance variable
+            "model": self.model_name
         }
 
         async with aiohttp.ClientSession() as session:
@@ -57,4 +60,4 @@
 
                 inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
                 audio_bytes = base64.b64decode(inline_data)
-                yield audio_bytes
\ No newline at end of file
+                yield audio_bytes
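
A minimal usage sketch for the updated provider, assuming the constructor shown above; the API key and output filename are placeholders, and the model name is simply the current default:

```python
import asyncio

from app.core.providers.tts.gemini import GeminiTTSProvider

async def main() -> None:
    # model_name is now configurable; any supported Gemini TTS model id can be passed.
    provider = GeminiTTSProvider(
        api_key="YOUR_GEMINI_API_KEY",  # placeholder, not a real key
        voice_name="Kore",
        model_name="gemini-2.5-flash-preview-tts",
    )

    # generate_speech is an async generator that yields decoded audio bytes.
    audio = bytearray()
    async for chunk in provider.generate_speech("Hello from the configurable TTS model."):
        audio.extend(chunk)

    # Raw bytes exactly as returned by the API; container/format depends on the service.
    with open("speech_output.bin", "wb") as f:
        f.write(audio)

if __name__ == "__main__":
    asyncio.run(main())
```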
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index 8e7a120..65062f6 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -9,10 +9,23 @@
 # You can override the default with a command-line argument, e.g., './run_integration_tests.sh tests/test_app.py'
 TEST_PATH=${1:-$DEFAULT_TEST_PATH}
 
-DB_MODE=sqlite
+export DB_MODE=sqlite
 export LOCAL_DB_PATH="data/integration_test_ai_hub.db"
 export FAISS_INDEX_PATH="data/integration_test_faiss_index.bin"
 
+# --- Pre-test Cleanup ---
+# Check for and remove old test files to ensure a clean test environment.
+echo "--- Checking for and removing old test files ---"
+if [ -f "$LOCAL_DB_PATH" ]; then
+    echo "Removing old database file: $LOCAL_DB_PATH"
+    rm "$LOCAL_DB_PATH"
+fi
+if [ -f "$FAISS_INDEX_PATH" ]; then
+    echo "Removing old FAISS index file: $FAISS_INDEX_PATH"
+    rm "$FAISS_INDEX_PATH"
+fi
+echo "Cleanup complete."
+
 echo "--- Starting AI Hub Server for Tests ---"
 
 # Start the uvicorn server in the background
@@ -49,4 +62,4 @@
 
 # The 'trap' will automatically call the cleanup function now.
 # Exit with the same code as the test script (0 for success, non-zero for failure).
-exit $TEST_EXIT_CODE
\ No newline at end of file
+exit $TEST_EXIT_CODE
diff --git a/ai-hub/tests/core/providers/tts/test_gemini.py b/ai-hub/tests/core/providers/tts/test_gemini.py
index 02ca463..47023ed 100644
--- a/ai-hub/tests/core/providers/tts/test_gemini.py
+++ b/ai-hub/tests/core/providers/tts/test_gemini.py
@@ -4,8 +4,7 @@
 import base64
 from aioresponses import aioresponses
 from app.core.providers.tts.gemini import GeminiTTSProvider
-
-# Note: The mock_aioresponse fixture is not needed and has been removed.
+from app.core.providers.base import TTSProvider
 
 @pytest.mark.asyncio
 async def test_generate_speech_success():
@@ -14,6 +13,7 @@
     """
     api_key = "test_api_key"
     text_to_speak = "Hello, world!"
+    model_name = "gemini-2.5-flash-preview-tts"
 
     # Create a dummy base64 encoded audio response
     dummy_audio_bytes = b"This is a test audio stream."
@@ -33,7 +33,7 @@
     }
 
     # Configure aioresponses to intercept the API call and return our mock data
-    tts_provider = GeminiTTSProvider(api_key=api_key)
+    tts_provider = GeminiTTSProvider(api_key=api_key, model_name=model_name)
     with aioresponses() as m:
         m.post(
             tts_provider.api_url,
@@ -52,16 +52,34 @@
     assert len(audio_chunks) == 1
     assert audio_chunks[0] == dummy_audio_bytes
 
-# The other tests for __init__ are not affected and can remain as they are.
 def test_init_with_valid_voice_name():
+    """
+    Tests that initialization succeeds with a valid voice name.
+    """
     api_key = "test_api_key"
     voice_name = "Zephyr"
     tts_provider = GeminiTTSProvider(api_key=api_key, voice_name=voice_name)
     assert tts_provider.api_key == api_key
     assert tts_provider.voice_name == voice_name
+    assert tts_provider.model_name == "gemini-2.5-flash-preview-tts"
+    assert "gemini-2.5-flash-preview-tts" in tts_provider.api_url
 
 def test_init_with_invalid_voice_name():
+    """
+    Tests that initialization fails with an invalid voice name.
+    """
     api_key = "test_api_key"
     invalid_voice_name = "InvalidVoice"
     with pytest.raises(ValueError, match="Invalid voice name"):
-        GeminiTTSProvider(api_key=api_key, voice_name=invalid_voice_name)
\ No newline at end of file
+        GeminiTTSProvider(api_key=api_key, voice_name=invalid_voice_name)
+
+def test_init_with_custom_model_name():
+    """
+    Tests that the provider can be initialized with a custom model name.
+    """
+    api_key = "test_api_key"
+    custom_model_name = "gemini-tts-beta"
+    tts_provider = GeminiTTSProvider(api_key=api_key, model_name=custom_model_name)
+    assert tts_provider.model_name == custom_model_name
+    assert custom_model_name in tts_provider.api_url
+
diff --git a/ai-hub/tests/test_config.py b/ai-hub/tests/test_config.py
index 5687698..4ded74f 100644
--- a/ai-hub/tests/test_config.py
+++ b/ai-hub/tests/test_config.py
@@ -31,6 +31,7 @@
     "tts_provider": {
         "provider": "google_genai",
         "voice_name": "Laomedeia",
+        "model_name": "tts-model-from-yaml",  # Added configurable model name
         "api_key": "tts-api-from-yaml"
     }
 }
@@ -51,6 +52,7 @@
     monkeypatch.delenv("EMBEDDING_API_KEY", raising=False)
     monkeypatch.delenv("TTS_PROVIDER", raising=False)
     monkeypatch.delenv("TTS_VOICE_NAME", raising=False)
+    monkeypatch.delenv("TTS_MODEL_NAME", raising=False)  # Added for the new setting
     monkeypatch.delenv("TTS_API_KEY", raising=False)
     monkeypatch.delenv("DB_MODE", raising=False)
     monkeypatch.delenv("LOCAL_DB_PATH", raising=False)
@@ -151,6 +153,7 @@
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Laomedeia"
+    assert settings.TTS_MODEL_NAME == "tts-model-from-yaml"  # Test for new model name
     assert settings.TTS_API_KEY == "tts-api-from-yaml"
 
 
@@ -160,6 +163,7 @@
     # Explicitly set all TTS env vars for this test
     monkeypatch.setenv("TTS_PROVIDER", "google_genai")
     monkeypatch.setenv("TTS_VOICE_NAME", "Zephyr")
+    monkeypatch.setenv("TTS_MODEL_NAME", "env-tts-model")  # Added for the new setting
     monkeypatch.setenv("TTS_API_KEY", "env_tts_key")
     monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
     monkeypatch.setenv("GEMINI_API_KEY", "mock_key")
@@ -167,6 +171,7 @@
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Zephyr"
+    assert settings.TTS_MODEL_NAME == "env-tts-model"  # Assert new setting is loaded
     assert settings.TTS_API_KEY == "env_tts_key"
 
 
@@ -178,4 +183,5 @@
 
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Kore"
+    assert settings.TTS_MODEL_NAME == "gemini-2.5-flash-preview-tts"  # Assert default value
     assert settings.TTS_API_KEY is None