diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index a4ea604..0a70fb3 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -80,7 +80,8 @@
    # 4. Get the concrete TTS provider from the factory
    tts_provider = get_tts_provider(
        provider_name=settings.TTS_PROVIDER,
-        api_key=settings.TTS_API_KEY
+        api_key=settings.TTS_API_KEY,
+        voice_name=settings.TTS_VOICE_NAME
    )

    # 5. Initialize the TTSService
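The only change in app.py is that the factory call now forwards the configured voice name alongside the provider and API key. The Settings object itself is not part of this diff; the sketch below is a hypothetical illustration of how flat attributes such as settings.TTS_VOICE_NAME could be exposed on top of the nested TTSProviderSettings model from config.py.

# Hypothetical sketch only; the real Settings class is not shown in this diff.
from typing import Optional
from app.config import TTSProviderSettings  # assumed import path

class Settings:
    def __init__(self, tts: TTSProviderSettings):
        self._tts = tts

    @property
    def TTS_PROVIDER(self) -> str:
        # The enum is a str subclass, so .value yields "google_gemini"
        return self._tts.provider.value

    @property
    def TTS_VOICE_NAME(self) -> str:
        return self._tts.voice_name

    @property
    def TTS_API_KEY(self) -> Optional[str]:
        key = self._tts.api_key
        return key.get_secret_value() if key else None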
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index f2de94e..858c8ad 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -12,12 +12,12 @@
class EmbeddingProvider(str, Enum):
    """An enum for supported embedding providers."""
-    GOOGLE_GENAI = "google_genai"
+    GOOGLE_GEMINI = "google_gemini"
    MOCK = "mock"

class TTSProvider(str, Enum):
    """An enum for supported Text-to-Speech (TTS) providers."""
-    GOOGLE_GENAI = "google_genai"
+    GOOGLE_GEMINI = "google_gemini"

class STTProvider(str, Enum):
    """An enum for supported Speech-to-Text (STT) providers."""
@@ -39,12 +39,12 @@
    gemini_model_name: str = "gemini-1.5-flash-latest"

class EmbeddingProviderSettings(BaseModel):
-    provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI)
+    provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GEMINI)
    model_name: str = "models/text-embedding-004"
    api_key: Optional[SecretStr] = None

class TTSProviderSettings(BaseModel):
-    provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI)
+    provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GEMINI)
    voice_name: str = "Kore"
    model_name: str = "gemini-2.5-flash-preview-tts"
    api_key: Optional[SecretStr] = None
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index de9d8dd..99ec019 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -27,16 +27,16 @@
embedding_dimension: 768

embedding_provider:
-  # The provider for the embedding service. Can be "google_genai" or "mock".
-  provider: "google_genai"
+  # The provider for the embedding service. Can be "google_gemini" or "mock".
+  provider: "google_gemini"
  # The model name for the embedding service.
  model_name: "gemini-embedding-001"

tts_provider:
  # The provider for the TTS service.
-  provider: "google_genai"
+  provider: "google_gemini"
  # The name of the voice to use for TTS.
-  voice_name: "Kore"
+  voice_name: "Zephyr"
  # The model name for the TTS service.
  model_name: "gemini-2.5-flash-preview-tts"
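The factories key off the string value of these enums, so renaming the members also renames the accepted configuration strings. The snippet below is illustrative only (the project's actual config loader is not shown in this diff); it demonstrates that the updated YAML values parse cleanly into the renamed enum members, while a leftover "google_genai" value would now fail validation.

# Illustrative only: the real config loader is not part of this diff.
import yaml
from app.config import EmbeddingProviderSettings, TTSProviderSettings  # assumed import path

with open("ai-hub/app/config.yaml") as f:
    raw = yaml.safe_load(f)

# "google_gemini" coerces into the renamed enum members.
embedding_settings = EmbeddingProviderSettings(**raw["embedding_provider"])
tts_settings = TTSProviderSettings(**raw["tts_provider"])
assert tts_settings.voice_name == "Zephyr"

# A stale "google_genai" string would raise a pydantic ValidationError here,
# which is why this rename is a breaking change for existing config files.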
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index e3aec9b..6cfbed0 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -24,10 +24,10 @@
        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}")
    return provider

-def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
-    if provider_name == "google_genai":
-        return GeminiTTSProvider(api_key=api_key)
-    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
+def get_tts_provider(provider_name: str, api_key: str, voice_name: str) -> TTSProvider:
+    if provider_name == "google_gemini":
+        return GeminiTTSProvider(api_key=api_key, voice_name=voice_name)
+    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini']")

def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    if provider_name == "google_gemini":
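get_tts_provider now threads the voice name through to the provider, which implies GeminiTTSProvider accepts it in its constructor. That class is not part of this diff; a minimal sketch of the assumed shape:

# Assumed shape of GeminiTTSProvider; its real definition is not in this diff.
class GeminiTTSProvider:
    def __init__(self, api_key: str, voice_name: str = "Kore"):
        self.api_key = api_key
        self.voice_name = voice_name  # e.g. "Kore" or "Zephyr", used when building speech requests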
diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py
index d958d3f..64fca80 100644
--- a/ai-hub/app/core/vector_store/embedder/factory.py
+++ b/ai-hub/app/core/vector_store/embedder/factory.py
@@ -3,7 +3,7 @@
from .mock import MockEmbedder

def get_embedder_from_config(provider, dimension, model_name, api_key):
-    if provider == EmbeddingProvider.GOOGLE_GENAI:
+    if provider == EmbeddingProvider.GOOGLE_GEMINI:
        return GenAIEmbedder(model_name, api_key, dimension)
    elif provider == EmbeddingProvider.MOCK:
        return MockEmbedder(dimension)
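Because both factories match on the literal provider string, the rename is easy to cover with a regression test. The test below is not in the commit; it is a small sketch written against the factory code shown above.

# Sketch of a regression test (not part of this commit): the old provider
# string must now be rejected by the TTS factory.
import pytest
from app.core.providers.factory import get_tts_provider  # assumed import path

def test_old_google_genai_string_is_rejected():
    with pytest.raises(ValueError):
        get_tts_provider(provider_name="google_genai", api_key="dummy", voice_name="Kore")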
diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py
deleted file mode 100644
index ea3ecb1..0000000
--- a/ai-hub/integration_tests/test_documents.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import pytest
-
-@pytest.mark.asyncio
-async def test_document_lifecycle(http_client):
-    """
-    Tests the full lifecycle of a document: add, list, and delete.
-    This is run as a single, sequential test for a clean state.
-    """
-    print("\n--- Running test_document_lifecycle ---")
-
-    # 1. Add a new document
-    doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}
-    # Correct the URL to include the trailing slash to avoid the 307 redirect
-    add_response = await http_client.post("/documents/", json=doc_data)
-    assert add_response.status_code == 200
-    try:
-        message = add_response.json().get("message", "")
-        document_id = int(message.split(" with ID ")[-1])
-    except (ValueError, IndexError):
-        pytest.fail("Could not parse document ID from response message.")
-    print(f"✅ Document for lifecycle test created with ID: {document_id}")
-
-    # 2. List all documents and check if the new document is present
-    # Correct the URL to include the trailing slash
-    list_response = await http_client.get("/documents/")
-    assert list_response.status_code == 200
-    ids_in_response = {doc["id"] for doc in list_response.json()["documents"]}
-    assert document_id in ids_in_response
-    print("✅ Document list test passed.")
-
-    # 3. Delete the document
-    delete_response = await http_client.delete(f"/documents/{document_id}")
-    assert delete_response.status_code == 200
-    assert delete_response.json()["document_id"] == document_id
-    print("✅ Document delete test passed.")
-
diff --git a/ai-hub/integration_tests/test_documents_api.py b/ai-hub/integration_tests/test_documents_api.py
new file mode 100644
index 0000000..ea3ecb1
--- /dev/null
+++ b/ai-hub/integration_tests/test_documents_api.py
@@ -0,0 +1,36 @@
+import pytest
+
+@pytest.mark.asyncio
+async def test_document_lifecycle(http_client):
+    """
+    Tests the full lifecycle of a document: add, list, and delete.
+    This is run as a single, sequential test for a clean state.
+    """
+    print("\n--- Running test_document_lifecycle ---")
+
+    # 1. Add a new document
+    doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."}
+    # Correct the URL to include the trailing slash to avoid the 307 redirect
+    add_response = await http_client.post("/documents/", json=doc_data)
+    assert add_response.status_code == 200
+    try:
+        message = add_response.json().get("message", "")
+        document_id = int(message.split(" with ID ")[-1])
+    except (ValueError, IndexError):
+        pytest.fail("Could not parse document ID from response message.")
+    print(f"✅ Document for lifecycle test created with ID: {document_id}")
+
+    # 2. List all documents and check if the new document is present
+    # Correct the URL to include the trailing slash
+    list_response = await http_client.get("/documents/")
+    assert list_response.status_code == 200
+    ids_in_response = {doc["id"] for doc in list_response.json()["documents"]}
+    assert document_id in ids_in_response
+    print("✅ Document list test passed.")
+
+    # 3. Delete the document
+    delete_response = await http_client.delete(f"/documents/{document_id}")
+    assert delete_response.status_code == 200
+    assert delete_response.json()["document_id"] == document_id
+    print("✅ Document delete test passed.")
+
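The integration tests depend on an http_client fixture defined outside this diff, most likely in conftest.py. A minimal sketch of what such a fixture could look like, assuming the hub is reachable on localhost port 8000:

# Assumed fixture; the real conftest.py is not shown in this diff.
import httpx
import pytest_asyncio

@pytest_asyncio.fixture
async def http_client():
    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=60.0) as client:
        yield client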
diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py
deleted file mode 100644
index 552a948..0000000
--- a/ai-hub/integration_tests/test_misc.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import pytest
-import httpx
-import wave
-import io
-
-@pytest.mark.asyncio
-async def test_root_endpoint(http_client):
-    """
-    Tests if the root endpoint is alive and returns the correct status message.
-    """
-    print("\n--- Running test_root_endpoint ---")
-    response = await http_client.get("/")
-    assert response.status_code == 200
-    assert response.json() == {"status": "AI Model Hub is running!"}
-    print("✅ Root endpoint test passed.")
-
-@pytest.mark.asyncio
-async def test_create_speech_stream(http_client):
-    """
-    Tests the /speech endpoint for a successful audio stream response.
-    """
-    print("\n--- Running test_create_speech_stream ---")
-    url = "/speech"
-    payload = {"text": "Hello, world!"}
-
-    # The `stream=True` parameter tells httpx to not read the entire response body
-    # at once. We'll handle it manually to check for content.
-    async with http_client.stream("POST", url, json=payload) as response:
-        assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}"
-        assert response.headers.get("content-type") == "audio/wav"
-
-        # Check that the response body is not empty by iterating over chunks.
-        content_length = 0
-        async for chunk in response.aiter_bytes():
-            content_length += len(chunk)
-
-        assert content_length > 0
-    print("✅ TTS stream test passed.")
-
-@pytest.mark.asyncio
-async def test_stt_transcribe_endpoint(http_client):
-    """
-    Tests the /stt/transcribe endpoint by uploading a dummy audio file
-    and verifying the transcription response.
-    """
-    print("\n--- Running test_stt_transcribe_endpoint ---")
-    url = "/stt/transcribe"
-
-    # --- Use a real audio file from the integration test data ---
-    audio_file_path = "integration_tests/test_data/test-audio.wav"
-
-    with open(audio_file_path, "rb") as audio_file:
-        files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')}
-
-        # --- Send the POST request to the endpoint ---
-        response = await http_client.post(url, files=files)
-
-    # --- Assertions ---
-    assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}"
-    response_json = response.json()
-    assert "transcript" in response_json, "Response JSON is missing the 'transcript' key."
-    assert isinstance(response_json["transcript"], str), "Transcript value is not a string."
-
-    # Assert that the transcript matches the expected text
-    expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project."
-    assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'"
-
-    print("✅ STT transcription test passed.")
-
+ """ + print("\n--- Running test_stt_transcribe_endpoint ---") + url = "/stt/transcribe" + + # --- Use a real audio file from the integration test data --- + audio_file_path = "integration_tests/test_data/test-audio.wav" + + with open(audio_file_path, "rb") as audio_file: + files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} + + # --- Send the POST request to the endpoint --- + response = await http_client.post(url, files=files) + + # --- Assertions --- + assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" + response_json = response.json() + assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." + assert isinstance(response_json["transcript"], str), "Transcript value is not a string." + + # Assert that the transcript matches the expected text + expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." + assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" + + print("✅ STT transcription test passed.") + diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a4ea604..0a70fb3 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -80,7 +80,8 @@ # 4. Get the concrete TTS provider from the factory tts_provider = get_tts_provider( provider_name=settings.TTS_PROVIDER, - api_key=settings.TTS_API_KEY + api_key=settings.TTS_API_KEY, + voice_name=settings.TTS_VOICE_NAME ) # 5. Initialize the TTSService diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index f2de94e..858c8ad 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -12,12 +12,12 @@ class EmbeddingProvider(str, Enum): """An enum for supported embedding providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" MOCK = "mock" class TTSProvider(str, Enum): """An enum for supported Text-to-Speech (TTS) providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" class STTProvider(str, Enum): """An enum for supported Speech-to-Text (STT) providers.""" @@ -39,12 +39,12 @@ gemini_model_name: str = "gemini-1.5-flash-latest" class EmbeddingProviderSettings(BaseModel): - provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI) + provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GEMINI) model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None class TTSProviderSettings(BaseModel): - provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI) + provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GEMINI) voice_name: str = "Kore" model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index de9d8dd..99ec019 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -27,16 +27,16 @@ embedding_dimension: 768 embedding_provider: - # The provider for the embedding service. Can be "google_genai" or "mock". - provider: "google_genai" + # The provider for the embedding service. Can be "google_gemini" or "mock". + provider: "google_gemini" # The model name for the embedding service. model_name: "gemini-embedding-001" tts_provider: # The provider for the TTS service. - provider: "google_genai" + provider: "google_gemini" # The name of the voice to use for TTS. - voice_name: "Kore" + voice_name: "Zephyr" # The model name for the TTS service. 
model_name: "gemini-2.5-flash-preview-tts" diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index e3aec9b..6cfbed0 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -24,10 +24,10 @@ raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}") return provider -def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "google_genai": - return GeminiTTSProvider(api_key=api_key) - raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']") +def get_tts_provider(provider_name: str, api_key: str, voice_name: str) -> TTSProvider: + if provider_name == "google_gemini": + return GeminiTTSProvider(api_key=api_key, voice_name = voice_name) + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini']") def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: if provider_name == "google_gemini": diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py index d958d3f..64fca80 100644 --- a/ai-hub/app/core/vector_store/embedder/factory.py +++ b/ai-hub/app/core/vector_store/embedder/factory.py @@ -3,7 +3,7 @@ from .mock import MockEmbedder def get_embedder_from_config(provider, dimension, model_name, api_key): - if provider == EmbeddingProvider.GOOGLE_GENAI: + if provider == EmbeddingProvider.GOOGLE_GEMINI: return GenAIEmbedder(model_name, api_key, dimension) elif provider == EmbeddingProvider.MOCK: return MockEmbedder(dimension) diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py deleted file mode 100644 index ea3ecb1..0000000 --- a/ai-hub/integration_tests/test_documents.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -@pytest.mark.asyncio -async def test_document_lifecycle(http_client): - """ - Tests the full lifecycle of a document: add, list, and delete. - This is run as a single, sequential test for a clean state. - """ - print("\n--- Running test_document_lifecycle ---") - - # 1. Add a new document - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - # Correct the URL to include the trailing slash to avoid the 307 redirect - add_response = await http_client.post("/documents/", json=doc_data) - assert add_response.status_code == 200 - try: - message = add_response.json().get("message", "") - document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {document_id}") - - # 2. List all documents and check if the new document is present - # Correct the URL to include the trailing slash - list_response = await http_client.get("/documents/") - assert list_response.status_code == 200 - ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} - assert document_id in ids_in_response - print("✅ Document list test passed.") - - # 3. 
Delete the document - delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 - assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") - diff --git a/ai-hub/integration_tests/test_documents_api.py b/ai-hub/integration_tests/test_documents_api.py new file mode 100644 index 0000000..ea3ecb1 --- /dev/null +++ b/ai-hub/integration_tests/test_documents_api.py @@ -0,0 +1,36 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py deleted file mode 100644 index 552a948..0000000 --- a/ai-hub/integration_tests/test_misc.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import httpx -import wave -import io - -@pytest.mark.asyncio -async def test_root_endpoint(http_client): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - response = await http_client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -@pytest.mark.asyncio -async def test_create_speech_stream(http_client): - """ - Tests the /speech endpoint for a successful audio stream response. - """ - print("\n--- Running test_create_speech_stream ---") - url = "/speech" - payload = {"text": "Hello, world!"} - - # The `stream=True` parameter tells httpx to not read the entire response body - # at once. We'll handle it manually to check for content. - async with http_client.stream("POST", url, json=payload) as response: - assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" - assert response.headers.get("content-type") == "audio/wav" - - # Check that the response body is not empty by iterating over chunks. 
- content_length = 0 - async for chunk in response.aiter_bytes(): - content_length += len(chunk) - - assert content_length > 0 - print("✅ TTS stream test passed.") - -@pytest.mark.asyncio -async def test_stt_transcribe_endpoint(http_client): - """ - Tests the /stt/transcribe endpoint by uploading a dummy audio file - and verifying the transcription response. - """ - print("\n--- Running test_stt_transcribe_endpoint ---") - url = "/stt/transcribe" - - # --- Use a real audio file from the integration test data --- - audio_file_path = "integration_tests/test_data/test-audio.wav" - - with open(audio_file_path, "rb") as audio_file: - files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} - - # --- Send the POST request to the endpoint --- - response = await http_client.post(url, files=files) - - # --- Assertions --- - assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" - response_json = response.json() - assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." - assert isinstance(response_json["transcript"], str), "Transcript value is not a string." - - # Assert that the transcript matches the expected text - expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." - assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" - - print("✅ STT transcription test passed.") - diff --git a/ai-hub/integration_tests/test_misc_api.py b/ai-hub/integration_tests/test_misc_api.py new file mode 100644 index 0000000..552a948 --- /dev/null +++ b/ai-hub/integration_tests/test_misc_api.py @@ -0,0 +1,69 @@ +import pytest +import httpx +import wave +import io + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} + + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + + # Check that the response body is not empty by iterating over chunks. + content_length = 0 + async for chunk in response.aiter_bytes(): + content_length += len(chunk) + + assert content_length > 0 + print("✅ TTS stream test passed.") + +@pytest.mark.asyncio +async def test_stt_transcribe_endpoint(http_client): + """ + Tests the /stt/transcribe endpoint by uploading a dummy audio file + and verifying the transcription response. 
+ """ + print("\n--- Running test_stt_transcribe_endpoint ---") + url = "/stt/transcribe" + + # --- Use a real audio file from the integration test data --- + audio_file_path = "integration_tests/test_data/test-audio.wav" + + with open(audio_file_path, "rb") as audio_file: + files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} + + # --- Send the POST request to the endpoint --- + response = await http_client.post(url, files=files) + + # --- Assertions --- + assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" + response_json = response.json() + assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." + assert isinstance(response_json["transcript"], str), "Transcript value is not a string." + + # Assert that the transcript matches the expected text + expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." + assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" + + print("✅ STT transcription test passed.") + diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py deleted file mode 100644 index 435ce3d..0000000 --- a/ai-hub/integration_tests/test_sessions.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -# Test prompts and data -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -@pytest.mark.asyncio -async def test_chat_in_session_lifecycle(http_client): - """ - Tests a full session lifecycle from creation to conversational memory. - This test is a single, sequential unit. - """ - print("\n--- Running test_chat_in_session_lifecycle ---") - - # 1. Create a new session with a trailing slash - payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions/", json=payload) - assert response.status_code == 200 - session_id = response.json()["id"] - print(f"✅ Session created successfully with ID: {session_id}") - - # 2. First chat turn to establish context - chat_payload_1 = {"prompt": CONTEXT_PROMPT} - response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) - assert response_1.status_code == 200 - assert "Satya Nadella" in response_1.json()["answer"] - assert response_1.json()["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - - # 3. Second chat turn (follow-up) to test conversational memory - chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} - response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) - assert response_2.status_code == 200 - assert "1967" in response_2.json()["answer"] - assert response_2.json()["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - - # 4. Cleanup (optional, but good practice if not using a test database that resets) - # The session data would typically be cleaned up by the database teardown. 
-
-@pytest.mark.asyncio
-async def test_chat_with_model_switch(http_client, session_id):
-    """Tests switching models within an existing session."""
-    print("\n--- Running test_chat_with_model_switch ---")
-
-    # Send a message to the new session with a different model
-    payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
-    response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
-    assert response_gemini.status_code == 200
-    assert "Paris" in response_gemini.json()["answer"]
-    assert response_gemini.json()["model_used"] == "gemini"
-    print("✅ Chat (Model Switch to Gemini) test passed.")
-
-    # Switch back to the original model
-    payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
-    response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
-    assert response_deepseek.status_code == 200
-    assert "Pacific Ocean" in response_deepseek.json()["answer"]
-    assert response_deepseek.json()["model_used"] == "deepseek"
-    print("✅ Chat (Model Switch back to DeepSeek) test passed.")
-
-@pytest.mark.asyncio
-async def test_chat_with_document_retrieval(http_client):
-    """
-    Tests injecting a document and using it for retrieval-augmented generation.
-    This test creates its own session and document for isolation.
-    """
-    print("\n--- Running test_chat_with_document_retrieval ---")
-
-    # Create a new session for this RAG test
-    # Corrected URL with a trailing slash
-    session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"})
-    assert session_response.status_code == 200
-    rag_session_id = session_response.json()["id"]
-
-    # Add a new document with specific content for retrieval
-    doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
-    # Corrected URL with a trailing slash
-    add_doc_response = await http_client.post("/documents/", json=doc_data)
-    assert add_doc_response.status_code == 200
-    try:
-        message = add_doc_response.json().get("message", "")
-        rag_document_id = int(message.split(" with ID ")[-1])
-        print(f"Document for RAG created with ID: {rag_document_id}")
-    except (ValueError, IndexError):
-        pytest.fail("Could not parse document ID from response message.")
-
-    try:
-        chat_payload = {
-            "prompt": RAG_PROMPT,
-            "document_id": rag_document_id,
-            "model": "deepseek",
-            "load_faiss_retriever": True
-        }
-        chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload)
-
-        assert chat_response.status_code == 200
-        chat_data = chat_response.json()
-        assert "Jane Doe" in chat_data["answer"]
-        assert "Nexus" in chat_data["answer"]
-        print("✅ Chat with document retrieval test passed.")
-    finally:
-        # Clean up the document after the test
-        delete_response = await http_client.delete(f"/documents/{rag_document_id}")
-        assert delete_response.status_code == 200
-        print(f"Document {rag_document_id} deleted successfully.")
diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py
new file mode 100644
index 0000000..435ce3d
--- /dev/null
+++ b/ai-hub/integration_tests/test_sessions_api.py
@@ -0,0 +1,109 @@
+import pytest
+
+# Test prompts and data
+CONTEXT_PROMPT = "Who is the CEO of Microsoft?"
+FOLLOW_UP_PROMPT = "When was he born?"
+RAG_DOC_TITLE = "Fictional Company History"
+RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'."
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?"
+
+@pytest.mark.asyncio
+async def test_chat_in_session_lifecycle(http_client):
+    """
+    Tests a full session lifecycle from creation to conversational memory.
+    This test is a single, sequential unit.
+    """
+    print("\n--- Running test_chat_in_session_lifecycle ---")
+
+    # 1. Create a new session with a trailing slash
+    payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"}
+    response = await http_client.post("/sessions/", json=payload)
+    assert response.status_code == 200
+    session_id = response.json()["id"]
+    print(f"✅ Session created successfully with ID: {session_id}")
+
+    # 2. First chat turn to establish context
+    chat_payload_1 = {"prompt": CONTEXT_PROMPT}
+    response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1)
+    assert response_1.status_code == 200
+    assert "Satya Nadella" in response_1.json()["answer"]
+    assert response_1.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 1 (context) test passed.")
+
+    # 3. Second chat turn (follow-up) to test conversational memory
+    chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT}
+    response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2)
+    assert response_2.status_code == 200
+    assert "1967" in response_2.json()["answer"]
+    assert response_2.json()["model_used"] == "deepseek"
+    print("✅ Chat Turn 2 (follow-up) test passed.")
+
+    # 4. Cleanup (optional, but good practice if not using a test database that resets)
+    # The session data would typically be cleaned up by the database teardown.
+
+@pytest.mark.asyncio
+async def test_chat_with_model_switch(http_client, session_id):
+    """Tests switching models within an existing session."""
+    print("\n--- Running test_chat_with_model_switch ---")
+
+    # Send a message to the new session with a different model
+    payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"}
+    response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini)
+    assert response_gemini.status_code == 200
+    assert "Paris" in response_gemini.json()["answer"]
+    assert response_gemini.json()["model_used"] == "gemini"
+    print("✅ Chat (Model Switch to Gemini) test passed.")
+
+    # Switch back to the original model
+    payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"}
+    response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek)
+    assert response_deepseek.status_code == 200
+    assert "Pacific Ocean" in response_deepseek.json()["answer"]
+    assert response_deepseek.json()["model_used"] == "deepseek"
+    print("✅ Chat (Model Switch back to DeepSeek) test passed.")
+
+@pytest.mark.asyncio
+async def test_chat_with_document_retrieval(http_client):
+    """
+    Tests injecting a document and using it for retrieval-augmented generation.
+    This test creates its own session and document for isolation.
+    """
+    print("\n--- Running test_chat_with_document_retrieval ---")
+
+    # Create a new session for this RAG test
+    # Corrected URL with a trailing slash
+    session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"})
+    assert session_response.status_code == 200
+    rag_session_id = session_response.json()["id"]
+
+    # Add a new document with specific content for retrieval
+    doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT}
+    # Corrected URL with a trailing slash
+    add_doc_response = await http_client.post("/documents/", json=doc_data)
+    assert add_doc_response.status_code == 200
+    try:
+        message = add_doc_response.json().get("message", "")
+        rag_document_id = int(message.split(" with ID ")[-1])
+        print(f"Document for RAG created with ID: {rag_document_id}")
+    except (ValueError, IndexError):
+        pytest.fail("Could not parse document ID from response message.")
+
+    try:
+        chat_payload = {
+            "prompt": RAG_PROMPT,
+            "document_id": rag_document_id,
+            "model": "deepseek",
+            "load_faiss_retriever": True
+        }
+        chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload)
+
+        assert chat_response.status_code == 200
+        chat_data = chat_response.json()
+        assert "Jane Doe" in chat_data["answer"]
+        assert "Nexus" in chat_data["answer"]
+        print("✅ Chat with document retrieval test passed.")
+    finally:
+        # Clean up the document after the test
+        delete_response = await http_client.delete(f"/documents/{rag_document_id}")
+        assert delete_response.status_code == 200
+        print(f"Document {rag_document_id} deleted successfully.")
- provider: "google_genai" + # The provider for the embedding service. Can be "google_gemini" or "mock". + provider: "google_gemini" # The model name for the embedding service. model_name: "gemini-embedding-001" tts_provider: # The provider for the TTS service. - provider: "google_genai" + provider: "google_gemini" # The name of the voice to use for TTS. - voice_name: "Kore" + voice_name: "Zephyr" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index e3aec9b..6cfbed0 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -24,10 +24,10 @@ raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}") return provider -def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "google_genai": - return GeminiTTSProvider(api_key=api_key) - raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']") +def get_tts_provider(provider_name: str, api_key: str, voice_name: str) -> TTSProvider: + if provider_name == "google_gemini": + return GeminiTTSProvider(api_key=api_key, voice_name = voice_name) + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini']") def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: if provider_name == "google_gemini": diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py index d958d3f..64fca80 100644 --- a/ai-hub/app/core/vector_store/embedder/factory.py +++ b/ai-hub/app/core/vector_store/embedder/factory.py @@ -3,7 +3,7 @@ from .mock import MockEmbedder def get_embedder_from_config(provider, dimension, model_name, api_key): - if provider == EmbeddingProvider.GOOGLE_GENAI: + if provider == EmbeddingProvider.GOOGLE_GEMINI: return GenAIEmbedder(model_name, api_key, dimension) elif provider == EmbeddingProvider.MOCK: return MockEmbedder(dimension) diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py deleted file mode 100644 index ea3ecb1..0000000 --- a/ai-hub/integration_tests/test_documents.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -@pytest.mark.asyncio -async def test_document_lifecycle(http_client): - """ - Tests the full lifecycle of a document: add, list, and delete. - This is run as a single, sequential test for a clean state. - """ - print("\n--- Running test_document_lifecycle ---") - - # 1. Add a new document - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - # Correct the URL to include the trailing slash to avoid the 307 redirect - add_response = await http_client.post("/documents/", json=doc_data) - assert add_response.status_code == 200 - try: - message = add_response.json().get("message", "") - document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {document_id}") - - # 2. 
List all documents and check if the new document is present - # Correct the URL to include the trailing slash - list_response = await http_client.get("/documents/") - assert list_response.status_code == 200 - ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} - assert document_id in ids_in_response - print("✅ Document list test passed.") - - # 3. Delete the document - delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 - assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") - diff --git a/ai-hub/integration_tests/test_documents_api.py b/ai-hub/integration_tests/test_documents_api.py new file mode 100644 index 0000000..ea3ecb1 --- /dev/null +++ b/ai-hub/integration_tests/test_documents_api.py @@ -0,0 +1,36 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py deleted file mode 100644 index 552a948..0000000 --- a/ai-hub/integration_tests/test_misc.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import httpx -import wave -import io - -@pytest.mark.asyncio -async def test_root_endpoint(http_client): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - response = await http_client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -@pytest.mark.asyncio -async def test_create_speech_stream(http_client): - """ - Tests the /speech endpoint for a successful audio stream response. - """ - print("\n--- Running test_create_speech_stream ---") - url = "/speech" - payload = {"text": "Hello, world!"} - - # The `stream=True` parameter tells httpx to not read the entire response body - # at once. We'll handle it manually to check for content. 
- async with http_client.stream("POST", url, json=payload) as response: - assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" - assert response.headers.get("content-type") == "audio/wav" - - # Check that the response body is not empty by iterating over chunks. - content_length = 0 - async for chunk in response.aiter_bytes(): - content_length += len(chunk) - - assert content_length > 0 - print("✅ TTS stream test passed.") - -@pytest.mark.asyncio -async def test_stt_transcribe_endpoint(http_client): - """ - Tests the /stt/transcribe endpoint by uploading a dummy audio file - and verifying the transcription response. - """ - print("\n--- Running test_stt_transcribe_endpoint ---") - url = "/stt/transcribe" - - # --- Use a real audio file from the integration test data --- - audio_file_path = "integration_tests/test_data/test-audio.wav" - - with open(audio_file_path, "rb") as audio_file: - files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} - - # --- Send the POST request to the endpoint --- - response = await http_client.post(url, files=files) - - # --- Assertions --- - assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" - response_json = response.json() - assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." - assert isinstance(response_json["transcript"], str), "Transcript value is not a string." - - # Assert that the transcript matches the expected text - expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." - assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" - - print("✅ STT transcription test passed.") - diff --git a/ai-hub/integration_tests/test_misc_api.py b/ai-hub/integration_tests/test_misc_api.py new file mode 100644 index 0000000..552a948 --- /dev/null +++ b/ai-hub/integration_tests/test_misc_api.py @@ -0,0 +1,69 @@ +import pytest +import httpx +import wave +import io + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} + + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + + # Check that the response body is not empty by iterating over chunks. 
+ content_length = 0 + async for chunk in response.aiter_bytes(): + content_length += len(chunk) + + assert content_length > 0 + print("✅ TTS stream test passed.") + +@pytest.mark.asyncio +async def test_stt_transcribe_endpoint(http_client): + """ + Tests the /stt/transcribe endpoint by uploading a dummy audio file + and verifying the transcription response. + """ + print("\n--- Running test_stt_transcribe_endpoint ---") + url = "/stt/transcribe" + + # --- Use a real audio file from the integration test data --- + audio_file_path = "integration_tests/test_data/test-audio.wav" + + with open(audio_file_path, "rb") as audio_file: + files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} + + # --- Send the POST request to the endpoint --- + response = await http_client.post(url, files=files) + + # --- Assertions --- + assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" + response_json = response.json() + assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." + assert isinstance(response_json["transcript"], str), "Transcript value is not a string." + + # Assert that the transcript matches the expected text + expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." + assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" + + print("✅ STT transcription test passed.") + diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py deleted file mode 100644 index 435ce3d..0000000 --- a/ai-hub/integration_tests/test_sessions.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -# Test prompts and data -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -@pytest.mark.asyncio -async def test_chat_in_session_lifecycle(http_client): - """ - Tests a full session lifecycle from creation to conversational memory. - This test is a single, sequential unit. - """ - print("\n--- Running test_chat_in_session_lifecycle ---") - - # 1. Create a new session with a trailing slash - payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions/", json=payload) - assert response.status_code == 200 - session_id = response.json()["id"] - print(f"✅ Session created successfully with ID: {session_id}") - - # 2. First chat turn to establish context - chat_payload_1 = {"prompt": CONTEXT_PROMPT} - response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) - assert response_1.status_code == 200 - assert "Satya Nadella" in response_1.json()["answer"] - assert response_1.json()["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - - # 3. Second chat turn (follow-up) to test conversational memory - chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} - response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) - assert response_2.status_code == 200 - assert "1967" in response_2.json()["answer"] - assert response_2.json()["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - - # 4. 
Cleanup (optional, but good practice if not using a test database that resets) - # The session data would typically be cleaned up by the database teardown. - -@pytest.mark.asyncio -async def test_chat_with_model_switch(http_client, session_id): - """Tests switching models within an existing session.""" - print("\n--- Running test_chat_with_model_switch ---") - - # Send a message to the new session with a different model - payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} - response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) - assert response_gemini.status_code == 200 - assert "Paris" in response_gemini.json()["answer"] - assert response_gemini.json()["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - - # Switch back to the original model - payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} - response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) - assert response_deepseek.status_code == 200 - assert "Pacific Ocean" in response_deepseek.json()["answer"] - assert response_deepseek.json()["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -@pytest.mark.asyncio -async def test_chat_with_document_retrieval(http_client): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This test creates its own session and document for isolation. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - - # Create a new session for this RAG test - # Corrected URL with a trailing slash - session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - # Corrected URL with a trailing slash - add_doc_response = await http_client.post("/documents/", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", - "load_faiss_retriever": True - } - chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200 - chat_data = chat_response.json() - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await http_client.delete(f"/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py new file mode 100644 index 0000000..435ce3d --- /dev/null +++ b/ai-hub/integration_tests/test_sessions_api.py @@ -0,0 +1,109 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" 
+RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." +RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. Create a new session with a trailing slash + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions/", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. + +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py index f3f25ac..3d48c03 100644 --- a/ai-hub/tests/core/providers/test_factory.py +++ b/ai-hub/tests/core/providers/test_factory.py @@ -26,23 +26,38 @@ # --- NEW Tests for TTS Provider --- def test_get_tts_provider_returns_gemini_tts_provider(): - """Tests that the factory returns a GeminiTTSProvider instance for 'google_genai'.""" - # Use a dummy key for testing - provider = get_tts_provider("google_genai", api_key="dummy_key") + """Tests that the factory returns a GeminiTTSProvider instance for 'google_gemini'.""" + # Use a valid voice from AVAILABLE_VOICES to avoid ValueError + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] + provider = get_tts_provider( + "google_gemini", + api_key="dummy_key", + voice_name=valid_voice + ) assert isinstance(provider, GeminiTTSProvider) assert provider.api_key == "dummy_key" + assert provider.voice_name == valid_voice def test_get_tts_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported TTS provider name.""" + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] with pytest.raises(ValueError, match="Unsupported TTS provider: 'unknown'"): - get_tts_provider("unknown", api_key="dummy_key") + get_tts_provider( + "unknown", + api_key="dummy_key", + voice_name=valid_voice + ) # --- NEW Tests for STT Provider --- def test_get_stt_provider_returns_google_stt_provider(): """Tests that the factory returns a GoogleSTTProvider instance for 'google_gemini'.""" - provider = get_stt_provider("google_gemini", api_key="dummy_key", model_name="dummy-model") + provider = get_stt_provider( + "google_gemini", + api_key="dummy_key", + model_name="dummy-model" + ) assert isinstance(provider, GoogleSTTProvider) assert provider.api_key == "dummy_key" assert provider.model_name == "dummy-model" @@ -50,4 +65,8 @@ def 
test_get_stt_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported STT provider name.""" with pytest.raises(ValueError, match="Unsupported STT provider: 'unknown'"): - get_stt_provider("unknown", api_key="dummy_key", model_name="dummy-model") \ No newline at end of file + get_stt_provider( + "unknown", + api_key="dummy_key", + model_name="dummy-model" + ) diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a4ea604..0a70fb3 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -80,7 +80,8 @@ # 4. Get the concrete TTS provider from the factory tts_provider = get_tts_provider( provider_name=settings.TTS_PROVIDER, - api_key=settings.TTS_API_KEY + api_key=settings.TTS_API_KEY, + voice_name=settings.TTS_VOICE_NAME ) # 5. Initialize the TTSService diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index f2de94e..858c8ad 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -12,12 +12,12 @@ class EmbeddingProvider(str, Enum): """An enum for supported embedding providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" MOCK = "mock" class TTSProvider(str, Enum): """An enum for supported Text-to-Speech (TTS) providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" class STTProvider(str, Enum): """An enum for supported Speech-to-Text (STT) providers.""" @@ -39,12 +39,12 @@ gemini_model_name: str = "gemini-1.5-flash-latest" class EmbeddingProviderSettings(BaseModel): - provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI) + provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GEMINI) model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None class TTSProviderSettings(BaseModel): - provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI) + provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GEMINI) voice_name: str = "Kore" model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index de9d8dd..99ec019 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -27,16 +27,16 @@ embedding_dimension: 768 embedding_provider: - # The provider for the embedding service. Can be "google_genai" or "mock". - provider: "google_genai" + # The provider for the embedding service. Can be "google_gemini" or "mock". + provider: "google_gemini" # The model name for the embedding service. model_name: "gemini-embedding-001" tts_provider: # The provider for the TTS service. - provider: "google_genai" + provider: "google_gemini" # The name of the voice to use for TTS. - voice_name: "Kore" + voice_name: "Zephyr" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index e3aec9b..6cfbed0 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -24,10 +24,10 @@ raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}") return provider -def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "google_genai": - return GeminiTTSProvider(api_key=api_key) - raise ValueError(f"Unsupported TTS provider: '{provider_name}'. 
Supported providers are: ['google_genai']") +def get_tts_provider(provider_name: str, api_key: str, voice_name: str) -> TTSProvider: + if provider_name == "google_gemini": + return GeminiTTSProvider(api_key=api_key, voice_name = voice_name) + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini']") def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: if provider_name == "google_gemini": diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py index d958d3f..64fca80 100644 --- a/ai-hub/app/core/vector_store/embedder/factory.py +++ b/ai-hub/app/core/vector_store/embedder/factory.py @@ -3,7 +3,7 @@ from .mock import MockEmbedder def get_embedder_from_config(provider, dimension, model_name, api_key): - if provider == EmbeddingProvider.GOOGLE_GENAI: + if provider == EmbeddingProvider.GOOGLE_GEMINI: return GenAIEmbedder(model_name, api_key, dimension) elif provider == EmbeddingProvider.MOCK: return MockEmbedder(dimension) diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py deleted file mode 100644 index ea3ecb1..0000000 --- a/ai-hub/integration_tests/test_documents.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -@pytest.mark.asyncio -async def test_document_lifecycle(http_client): - """ - Tests the full lifecycle of a document: add, list, and delete. - This is run as a single, sequential test for a clean state. - """ - print("\n--- Running test_document_lifecycle ---") - - # 1. Add a new document - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - # Correct the URL to include the trailing slash to avoid the 307 redirect - add_response = await http_client.post("/documents/", json=doc_data) - assert add_response.status_code == 200 - try: - message = add_response.json().get("message", "") - document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {document_id}") - - # 2. List all documents and check if the new document is present - # Correct the URL to include the trailing slash - list_response = await http_client.get("/documents/") - assert list_response.status_code == 200 - ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} - assert document_id in ids_in_response - print("✅ Document list test passed.") - - # 3. Delete the document - delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 - assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") - diff --git a/ai-hub/integration_tests/test_documents_api.py b/ai-hub/integration_tests/test_documents_api.py new file mode 100644 index 0000000..ea3ecb1 --- /dev/null +++ b/ai-hub/integration_tests/test_documents_api.py @@ -0,0 +1,36 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. 
Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py deleted file mode 100644 index 552a948..0000000 --- a/ai-hub/integration_tests/test_misc.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import httpx -import wave -import io - -@pytest.mark.asyncio -async def test_root_endpoint(http_client): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - response = await http_client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -@pytest.mark.asyncio -async def test_create_speech_stream(http_client): - """ - Tests the /speech endpoint for a successful audio stream response. - """ - print("\n--- Running test_create_speech_stream ---") - url = "/speech" - payload = {"text": "Hello, world!"} - - # The `stream=True` parameter tells httpx to not read the entire response body - # at once. We'll handle it manually to check for content. - async with http_client.stream("POST", url, json=payload) as response: - assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" - assert response.headers.get("content-type") == "audio/wav" - - # Check that the response body is not empty by iterating over chunks. - content_length = 0 - async for chunk in response.aiter_bytes(): - content_length += len(chunk) - - assert content_length > 0 - print("✅ TTS stream test passed.") - -@pytest.mark.asyncio -async def test_stt_transcribe_endpoint(http_client): - """ - Tests the /stt/transcribe endpoint by uploading a dummy audio file - and verifying the transcription response. - """ - print("\n--- Running test_stt_transcribe_endpoint ---") - url = "/stt/transcribe" - - # --- Use a real audio file from the integration test data --- - audio_file_path = "integration_tests/test_data/test-audio.wav" - - with open(audio_file_path, "rb") as audio_file: - files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} - - # --- Send the POST request to the endpoint --- - response = await http_client.post(url, files=files) - - # --- Assertions --- - assert response.status_code == 200, f"STT request failed with status code {response.status_code}. 
Response: {response.text}" - response_json = response.json() - assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." - assert isinstance(response_json["transcript"], str), "Transcript value is not a string." - - # Assert that the transcript matches the expected text - expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." - assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" - - print("✅ STT transcription test passed.") - diff --git a/ai-hub/integration_tests/test_misc_api.py b/ai-hub/integration_tests/test_misc_api.py new file mode 100644 index 0000000..552a948 --- /dev/null +++ b/ai-hub/integration_tests/test_misc_api.py @@ -0,0 +1,69 @@ +import pytest +import httpx +import wave +import io + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} + + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + + # Check that the response body is not empty by iterating over chunks. + content_length = 0 + async for chunk in response.aiter_bytes(): + content_length += len(chunk) + + assert content_length > 0 + print("✅ TTS stream test passed.") + +@pytest.mark.asyncio +async def test_stt_transcribe_endpoint(http_client): + """ + Tests the /stt/transcribe endpoint by uploading a dummy audio file + and verifying the transcription response. + """ + print("\n--- Running test_stt_transcribe_endpoint ---") + url = "/stt/transcribe" + + # --- Use a real audio file from the integration test data --- + audio_file_path = "integration_tests/test_data/test-audio.wav" + + with open(audio_file_path, "rb") as audio_file: + files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} + + # --- Send the POST request to the endpoint --- + response = await http_client.post(url, files=files) + + # --- Assertions --- + assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" + response_json = response.json() + assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." + assert isinstance(response_json["transcript"], str), "Transcript value is not a string." + + # Assert that the transcript matches the expected text + expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." 
+ assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" + + print("✅ STT transcription test passed.") + diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py deleted file mode 100644 index 435ce3d..0000000 --- a/ai-hub/integration_tests/test_sessions.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -# Test prompts and data -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -@pytest.mark.asyncio -async def test_chat_in_session_lifecycle(http_client): - """ - Tests a full session lifecycle from creation to conversational memory. - This test is a single, sequential unit. - """ - print("\n--- Running test_chat_in_session_lifecycle ---") - - # 1. Create a new session with a trailing slash - payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions/", json=payload) - assert response.status_code == 200 - session_id = response.json()["id"] - print(f"✅ Session created successfully with ID: {session_id}") - - # 2. First chat turn to establish context - chat_payload_1 = {"prompt": CONTEXT_PROMPT} - response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) - assert response_1.status_code == 200 - assert "Satya Nadella" in response_1.json()["answer"] - assert response_1.json()["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - - # 3. Second chat turn (follow-up) to test conversational memory - chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} - response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) - assert response_2.status_code == 200 - assert "1967" in response_2.json()["answer"] - assert response_2.json()["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - - # 4. Cleanup (optional, but good practice if not using a test database that resets) - # The session data would typically be cleaned up by the database teardown. 
- -@pytest.mark.asyncio -async def test_chat_with_model_switch(http_client, session_id): - """Tests switching models within an existing session.""" - print("\n--- Running test_chat_with_model_switch ---") - - # Send a message to the new session with a different model - payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} - response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) - assert response_gemini.status_code == 200 - assert "Paris" in response_gemini.json()["answer"] - assert response_gemini.json()["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - - # Switch back to the original model - payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} - response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) - assert response_deepseek.status_code == 200 - assert "Pacific Ocean" in response_deepseek.json()["answer"] - assert response_deepseek.json()["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -@pytest.mark.asyncio -async def test_chat_with_document_retrieval(http_client): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This test creates its own session and document for isolation. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - - # Create a new session for this RAG test - # Corrected URL with a trailing slash - session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - # Corrected URL with a trailing slash - add_doc_response = await http_client.post("/documents/", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", - "load_faiss_retriever": True - } - chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200 - chat_data = chat_response.json() - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await http_client.delete(f"/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py new file mode 100644 index 0000000..435ce3d --- /dev/null +++ b/ai-hub/integration_tests/test_sessions_api.py @@ -0,0 +1,109 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." 
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. Create a new session with a trailing slash + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions/", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. + +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py index f3f25ac..3d48c03 100644 --- a/ai-hub/tests/core/providers/test_factory.py +++ b/ai-hub/tests/core/providers/test_factory.py @@ -26,23 +26,38 @@ # --- NEW Tests for TTS Provider --- def test_get_tts_provider_returns_gemini_tts_provider(): - """Tests that the factory returns a GeminiTTSProvider instance for 'google_genai'.""" - # Use a dummy key for testing - provider = get_tts_provider("google_genai", api_key="dummy_key") + """Tests that the factory returns a GeminiTTSProvider instance for 'google_gemini'.""" + # Use a valid voice from AVAILABLE_VOICES to avoid ValueError + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] + provider = get_tts_provider( + "google_gemini", + api_key="dummy_key", + voice_name=valid_voice + ) assert isinstance(provider, GeminiTTSProvider) assert provider.api_key == "dummy_key" + assert provider.voice_name == valid_voice def test_get_tts_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported TTS provider name.""" + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] with pytest.raises(ValueError, match="Unsupported TTS provider: 'unknown'"): - get_tts_provider("unknown", api_key="dummy_key") + get_tts_provider( + "unknown", + api_key="dummy_key", + voice_name=valid_voice + ) # --- NEW Tests for STT Provider --- def test_get_stt_provider_returns_google_stt_provider(): """Tests that the factory returns a GoogleSTTProvider instance for 'google_gemini'.""" - provider = get_stt_provider("google_gemini", api_key="dummy_key", model_name="dummy-model") + provider = get_stt_provider( + "google_gemini", + api_key="dummy_key", + model_name="dummy-model" + ) assert isinstance(provider, GoogleSTTProvider) assert provider.api_key == "dummy_key" assert provider.model_name == "dummy-model" @@ -50,4 +65,8 @@ def 
test_get_stt_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported STT provider name.""" with pytest.raises(ValueError, match="Unsupported STT provider: 'unknown'"): - get_stt_provider("unknown", api_key="dummy_key", model_name="dummy-model") \ No newline at end of file + get_stt_provider( + "unknown", + api_key="dummy_key", + model_name="dummy-model" + ) diff --git a/ai-hub/tests/core/vector_store/test_embedder_factory.py b/ai-hub/tests/core/vector_store/test_embedder_factory.py index 7413376..e7c3875 100644 --- a/ai-hub/tests/core/vector_store/test_embedder_factory.py +++ b/ai-hub/tests/core/vector_store/test_embedder_factory.py @@ -9,7 +9,7 @@ def test_get_genai_embedder(): embedder = get_embedder_from_config( - EmbeddingProvider.GOOGLE_GENAI, 768, "gemini-embedding-001", "fake_key" + EmbeddingProvider.GOOGLE_GEMINI, 768, "gemini-embedding-001", "fake_key" ) assert embedder.model_name == "gemini-embedding-001" assert embedder.api_key == "fake_key" diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a4ea604..0a70fb3 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -80,7 +80,8 @@ # 4. Get the concrete TTS provider from the factory tts_provider = get_tts_provider( provider_name=settings.TTS_PROVIDER, - api_key=settings.TTS_API_KEY + api_key=settings.TTS_API_KEY, + voice_name=settings.TTS_VOICE_NAME ) # 5. Initialize the TTSService diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index f2de94e..858c8ad 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -12,12 +12,12 @@ class EmbeddingProvider(str, Enum): """An enum for supported embedding providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" MOCK = "mock" class TTSProvider(str, Enum): """An enum for supported Text-to-Speech (TTS) providers.""" - GOOGLE_GENAI = "google_genai" + GOOGLE_GEMINI = "google_gemini" class STTProvider(str, Enum): """An enum for supported Speech-to-Text (STT) providers.""" @@ -39,12 +39,12 @@ gemini_model_name: str = "gemini-1.5-flash-latest" class EmbeddingProviderSettings(BaseModel): - provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GENAI) + provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GEMINI) model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None class TTSProviderSettings(BaseModel): - provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GENAI) + provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GEMINI) voice_name: str = "Kore" model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index de9d8dd..99ec019 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -27,16 +27,16 @@ embedding_dimension: 768 embedding_provider: - # The provider for the embedding service. Can be "google_genai" or "mock". - provider: "google_genai" + # The provider for the embedding service. Can be "google_gemini" or "mock". + provider: "google_gemini" # The model name for the embedding service. model_name: "gemini-embedding-001" tts_provider: # The provider for the TTS service. - provider: "google_genai" + provider: "google_gemini" # The name of the voice to use for TTS. - voice_name: "Kore" + voice_name: "Zephyr" # The model name for the TTS service. 
model_name: "gemini-2.5-flash-preview-tts" diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index e3aec9b..6cfbed0 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -24,10 +24,10 @@ raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}") return provider -def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: - if provider_name == "google_genai": - return GeminiTTSProvider(api_key=api_key) - raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']") +def get_tts_provider(provider_name: str, api_key: str, voice_name: str) -> TTSProvider: + if provider_name == "google_gemini": + return GeminiTTSProvider(api_key=api_key, voice_name = voice_name) + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini']") def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: if provider_name == "google_gemini": diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py index d958d3f..64fca80 100644 --- a/ai-hub/app/core/vector_store/embedder/factory.py +++ b/ai-hub/app/core/vector_store/embedder/factory.py @@ -3,7 +3,7 @@ from .mock import MockEmbedder def get_embedder_from_config(provider, dimension, model_name, api_key): - if provider == EmbeddingProvider.GOOGLE_GENAI: + if provider == EmbeddingProvider.GOOGLE_GEMINI: return GenAIEmbedder(model_name, api_key, dimension) elif provider == EmbeddingProvider.MOCK: return MockEmbedder(dimension) diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py deleted file mode 100644 index ea3ecb1..0000000 --- a/ai-hub/integration_tests/test_documents.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -@pytest.mark.asyncio -async def test_document_lifecycle(http_client): - """ - Tests the full lifecycle of a document: add, list, and delete. - This is run as a single, sequential test for a clean state. - """ - print("\n--- Running test_document_lifecycle ---") - - # 1. Add a new document - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - # Correct the URL to include the trailing slash to avoid the 307 redirect - add_response = await http_client.post("/documents/", json=doc_data) - assert add_response.status_code == 200 - try: - message = add_response.json().get("message", "") - document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {document_id}") - - # 2. List all documents and check if the new document is present - # Correct the URL to include the trailing slash - list_response = await http_client.get("/documents/") - assert list_response.status_code == 200 - ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} - assert document_id in ids_in_response - print("✅ Document list test passed.") - - # 3. 
Delete the document - delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 - assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") - diff --git a/ai-hub/integration_tests/test_documents_api.py b/ai-hub/integration_tests/test_documents_api.py new file mode 100644 index 0000000..ea3ecb1 --- /dev/null +++ b/ai-hub/integration_tests/test_documents_api.py @@ -0,0 +1,36 @@ +import pytest + +@pytest.mark.asyncio +async def test_document_lifecycle(http_client): + """ + Tests the full lifecycle of a document: add, list, and delete. + This is run as a single, sequential test for a clean state. + """ + print("\n--- Running test_document_lifecycle ---") + + # 1. Add a new document + doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) + assert add_response.status_code == 200 + try: + message = add_response.json().get("message", "") + document_id = int(message.split(" with ID ")[-1]) + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + print(f"✅ Document for lifecycle test created with ID: {document_id}") + + # 2. List all documents and check if the new document is present + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") + assert list_response.status_code == 200 + ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} + assert document_id in ids_in_response + print("✅ Document list test passed.") + + # 3. Delete the document + delete_response = await http_client.delete(f"/documents/{document_id}") + assert delete_response.status_code == 200 + assert delete_response.json()["document_id"] == document_id + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_misc.py b/ai-hub/integration_tests/test_misc.py deleted file mode 100644 index 552a948..0000000 --- a/ai-hub/integration_tests/test_misc.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import httpx -import wave -import io - -@pytest.mark.asyncio -async def test_root_endpoint(http_client): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - response = await http_client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -@pytest.mark.asyncio -async def test_create_speech_stream(http_client): - """ - Tests the /speech endpoint for a successful audio stream response. - """ - print("\n--- Running test_create_speech_stream ---") - url = "/speech" - payload = {"text": "Hello, world!"} - - # The `stream=True` parameter tells httpx to not read the entire response body - # at once. We'll handle it manually to check for content. - async with http_client.stream("POST", url, json=payload) as response: - assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" - assert response.headers.get("content-type") == "audio/wav" - - # Check that the response body is not empty by iterating over chunks. 
- content_length = 0 - async for chunk in response.aiter_bytes(): - content_length += len(chunk) - - assert content_length > 0 - print("✅ TTS stream test passed.") - -@pytest.mark.asyncio -async def test_stt_transcribe_endpoint(http_client): - """ - Tests the /stt/transcribe endpoint by uploading a dummy audio file - and verifying the transcription response. - """ - print("\n--- Running test_stt_transcribe_endpoint ---") - url = "/stt/transcribe" - - # --- Use a real audio file from the integration test data --- - audio_file_path = "integration_tests/test_data/test-audio.wav" - - with open(audio_file_path, "rb") as audio_file: - files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} - - # --- Send the POST request to the endpoint --- - response = await http_client.post(url, files=files) - - # --- Assertions --- - assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" - response_json = response.json() - assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." - assert isinstance(response_json["transcript"], str), "Transcript value is not a string." - - # Assert that the transcript matches the expected text - expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." - assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" - - print("✅ STT transcription test passed.") - diff --git a/ai-hub/integration_tests/test_misc_api.py b/ai-hub/integration_tests/test_misc_api.py new file mode 100644 index 0000000..552a948 --- /dev/null +++ b/ai-hub/integration_tests/test_misc_api.py @@ -0,0 +1,69 @@ +import pytest +import httpx +import wave +import io + +@pytest.mark.asyncio +async def test_root_endpoint(http_client): + """ + Tests if the root endpoint is alive and returns the correct status message. + """ + print("\n--- Running test_root_endpoint ---") + response = await http_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} + print("✅ Root endpoint test passed.") + +@pytest.mark.asyncio +async def test_create_speech_stream(http_client): + """ + Tests the /speech endpoint for a successful audio stream response. + """ + print("\n--- Running test_create_speech_stream ---") + url = "/speech" + payload = {"text": "Hello, world!"} + + # The `stream=True` parameter tells httpx to not read the entire response body + # at once. We'll handle it manually to check for content. + async with http_client.stream("POST", url, json=payload) as response: + assert response.status_code == 200, f"Speech stream request failed. Response: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + + # Check that the response body is not empty by iterating over chunks. + content_length = 0 + async for chunk in response.aiter_bytes(): + content_length += len(chunk) + + assert content_length > 0 + print("✅ TTS stream test passed.") + +@pytest.mark.asyncio +async def test_stt_transcribe_endpoint(http_client): + """ + Tests the /stt/transcribe endpoint by uploading a dummy audio file + and verifying the transcription response. 
+ """ + print("\n--- Running test_stt_transcribe_endpoint ---") + url = "/stt/transcribe" + + # --- Use a real audio file from the integration test data --- + audio_file_path = "integration_tests/test_data/test-audio.wav" + + with open(audio_file_path, "rb") as audio_file: + files = {'audio_file': ('test-audio.wav', audio_file, 'audio/wav')} + + # --- Send the POST request to the endpoint --- + response = await http_client.post(url, files=files) + + # --- Assertions --- + assert response.status_code == 200, f"STT request failed with status code {response.status_code}. Response: {response.text}" + response_json = response.json() + assert "transcript" in response_json, "Response JSON is missing the 'transcript' key." + assert isinstance(response_json["transcript"], str), "Transcript value is not a string." + + # Assert that the transcript matches the expected text + expected_transcript = "This audio is for integration testing of Cortex Hub, which is a wonderful project." + assert response_json["transcript"] == expected_transcript, f"Expected: '{expected_transcript}', Got: '{response_json['transcript']}'" + + print("✅ STT transcription test passed.") + diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py deleted file mode 100644 index 435ce3d..0000000 --- a/ai-hub/integration_tests/test_sessions.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest - -# Test prompts and data -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -@pytest.mark.asyncio -async def test_chat_in_session_lifecycle(http_client): - """ - Tests a full session lifecycle from creation to conversational memory. - This test is a single, sequential unit. - """ - print("\n--- Running test_chat_in_session_lifecycle ---") - - # 1. Create a new session with a trailing slash - payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions/", json=payload) - assert response.status_code == 200 - session_id = response.json()["id"] - print(f"✅ Session created successfully with ID: {session_id}") - - # 2. First chat turn to establish context - chat_payload_1 = {"prompt": CONTEXT_PROMPT} - response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) - assert response_1.status_code == 200 - assert "Satya Nadella" in response_1.json()["answer"] - assert response_1.json()["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - - # 3. Second chat turn (follow-up) to test conversational memory - chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} - response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) - assert response_2.status_code == 200 - assert "1967" in response_2.json()["answer"] - assert response_2.json()["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - - # 4. Cleanup (optional, but good practice if not using a test database that resets) - # The session data would typically be cleaned up by the database teardown. 
- -@pytest.mark.asyncio -async def test_chat_with_model_switch(http_client, session_id): - """Tests switching models within an existing session.""" - print("\n--- Running test_chat_with_model_switch ---") - - # Send a message to the new session with a different model - payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} - response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) - assert response_gemini.status_code == 200 - assert "Paris" in response_gemini.json()["answer"] - assert response_gemini.json()["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - - # Switch back to the original model - payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} - response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) - assert response_deepseek.status_code == 200 - assert "Pacific Ocean" in response_deepseek.json()["answer"] - assert response_deepseek.json()["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -@pytest.mark.asyncio -async def test_chat_with_document_retrieval(http_client): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This test creates its own session and document for isolation. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - - # Create a new session for this RAG test - # Corrected URL with a trailing slash - session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - # Corrected URL with a trailing slash - add_doc_response = await http_client.post("/documents/", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", - "load_faiss_retriever": True - } - chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200 - chat_data = chat_response.json() - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await http_client.delete(f"/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/integration_tests/test_sessions_api.py b/ai-hub/integration_tests/test_sessions_api.py new file mode 100644 index 0000000..435ce3d --- /dev/null +++ b/ai-hub/integration_tests/test_sessions_api.py @@ -0,0 +1,109 @@ +import pytest + +# Test prompts and data +CONTEXT_PROMPT = "Who is the CEO of Microsoft?" +FOLLOW_UP_PROMPT = "When was he born?" +RAG_DOC_TITLE = "Fictional Company History" +RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." 
+RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" + +@pytest.mark.asyncio +async def test_chat_in_session_lifecycle(http_client): + """ + Tests a full session lifecycle from creation to conversational memory. + This test is a single, sequential unit. + """ + print("\n--- Running test_chat_in_session_lifecycle ---") + + # 1. Create a new session with a trailing slash + payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} + response = await http_client.post("/sessions/", json=payload) + assert response.status_code == 200 + session_id = response.json()["id"] + print(f"✅ Session created successfully with ID: {session_id}") + + # 2. First chat turn to establish context + chat_payload_1 = {"prompt": CONTEXT_PROMPT} + response_1 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_1) + assert response_1.status_code == 200 + assert "Satya Nadella" in response_1.json()["answer"] + assert response_1.json()["model_used"] == "deepseek" + print("✅ Chat Turn 1 (context) test passed.") + + # 3. Second chat turn (follow-up) to test conversational memory + chat_payload_2 = {"prompt": FOLLOW_UP_PROMPT} + response_2 = await http_client.post(f"/sessions/{session_id}/chat", json=chat_payload_2) + assert response_2.status_code == 200 + assert "1967" in response_2.json()["answer"] + assert response_2.json()["model_used"] == "deepseek" + print("✅ Chat Turn 2 (follow-up) test passed.") + + # 4. Cleanup (optional, but good practice if not using a test database that resets) + # The session data would typically be cleaned up by the database teardown. + +@pytest.mark.asyncio +async def test_chat_with_model_switch(http_client, session_id): + """Tests switching models within an existing session.""" + print("\n--- Running test_chat_with_model_switch ---") + + # Send a message to the new session with a different model + payload_gemini = {"prompt": "What is the capital of France?", "model": "gemini"} + response_gemini = await http_client.post(f"/sessions/{session_id}/chat", json=payload_gemini) + assert response_gemini.status_code == 200 + assert "Paris" in response_gemini.json()["answer"] + assert response_gemini.json()["model_used"] == "gemini" + print("✅ Chat (Model Switch to Gemini) test passed.") + + # Switch back to the original model + payload_deepseek = {"prompt": "What is the largest ocean?", "model": "deepseek"} + response_deepseek = await http_client.post(f"/sessions/{session_id}/chat", json=payload_deepseek) + assert response_deepseek.status_code == 200 + assert "Pacific Ocean" in response_deepseek.json()["answer"] + assert response_deepseek.json()["model_used"] == "deepseek" + print("✅ Chat (Model Switch back to DeepSeek) test passed.") + +@pytest.mark.asyncio +async def test_chat_with_document_retrieval(http_client): + """ + Tests injecting a document and using it for retrieval-augmented generation. + This test creates its own session and document for isolation. 
+ """ + print("\n--- Running test_chat_with_document_retrieval ---") + + # Create a new session for this RAG test + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) + assert session_response.status_code == 200 + rag_session_id = session_response.json()["id"] + + # Add a new document with specific content for retrieval + doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) + assert add_doc_response.status_code == 200 + try: + message = add_doc_response.json().get("message", "") + rag_document_id = int(message.split(" with ID ")[-1]) + print(f"Document for RAG created with ID: {rag_document_id}") + except (ValueError, IndexError): + pytest.fail("Could not parse document ID from response message.") + + try: + chat_payload = { + "prompt": RAG_PROMPT, + "document_id": rag_document_id, + "model": "deepseek", + "load_faiss_retriever": True + } + chat_response = await http_client.post(f"/sessions/{rag_session_id}/chat", json=chat_payload) + + assert chat_response.status_code == 200 + chat_data = chat_response.json() + assert "Jane Doe" in chat_data["answer"] + assert "Nexus" in chat_data["answer"] + print("✅ Chat with document retrieval test passed.") + finally: + # Clean up the document after the test + delete_response = await http_client.delete(f"/documents/{rag_document_id}") + assert delete_response.status_code == 200 + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py index f3f25ac..3d48c03 100644 --- a/ai-hub/tests/core/providers/test_factory.py +++ b/ai-hub/tests/core/providers/test_factory.py @@ -26,23 +26,38 @@ # --- NEW Tests for TTS Provider --- def test_get_tts_provider_returns_gemini_tts_provider(): - """Tests that the factory returns a GeminiTTSProvider instance for 'google_genai'.""" - # Use a dummy key for testing - provider = get_tts_provider("google_genai", api_key="dummy_key") + """Tests that the factory returns a GeminiTTSProvider instance for 'google_gemini'.""" + # Use a valid voice from AVAILABLE_VOICES to avoid ValueError + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] + provider = get_tts_provider( + "google_gemini", + api_key="dummy_key", + voice_name=valid_voice + ) assert isinstance(provider, GeminiTTSProvider) assert provider.api_key == "dummy_key" + assert provider.voice_name == valid_voice def test_get_tts_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported TTS provider name.""" + valid_voice = GeminiTTSProvider.AVAILABLE_VOICES[0] with pytest.raises(ValueError, match="Unsupported TTS provider: 'unknown'"): - get_tts_provider("unknown", api_key="dummy_key") + get_tts_provider( + "unknown", + api_key="dummy_key", + voice_name=valid_voice + ) # --- NEW Tests for STT Provider --- def test_get_stt_provider_returns_google_stt_provider(): """Tests that the factory returns a GoogleSTTProvider instance for 'google_gemini'.""" - provider = get_stt_provider("google_gemini", api_key="dummy_key", model_name="dummy-model") + provider = get_stt_provider( + "google_gemini", + api_key="dummy_key", + model_name="dummy-model" + ) assert isinstance(provider, GoogleSTTProvider) assert provider.api_key == "dummy_key" assert provider.model_name == "dummy-model" @@ -50,4 +65,8 @@ def 
test_get_stt_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported STT provider name.""" with pytest.raises(ValueError, match="Unsupported STT provider: 'unknown'"): - get_stt_provider("unknown", api_key="dummy_key", model_name="dummy-model") \ No newline at end of file + get_stt_provider( + "unknown", + api_key="dummy_key", + model_name="dummy-model" + ) diff --git a/ai-hub/tests/core/vector_store/test_embedder_factory.py b/ai-hub/tests/core/vector_store/test_embedder_factory.py index 7413376..e7c3875 100644 --- a/ai-hub/tests/core/vector_store/test_embedder_factory.py +++ b/ai-hub/tests/core/vector_store/test_embedder_factory.py @@ -9,7 +9,7 @@ def test_get_genai_embedder(): embedder = get_embedder_from_config( - EmbeddingProvider.GOOGLE_GENAI, 768, "gemini-embedding-001", "fake_key" + EmbeddingProvider.GOOGLE_GEMINI, 768, "gemini-embedding-001", "fake_key" ) assert embedder.model_name == "gemini-embedding-001" assert embedder.api_key == "fake_key" diff --git a/ai-hub/tests/test_config.py b/ai-hub/tests/test_config.py index 2eca11c..b7c9647 100644 --- a/ai-hub/tests/test_config.py +++ b/ai-hub/tests/test_config.py @@ -27,7 +27,7 @@ "url": "postgresql://user:pass@host/dbname" }, "tts_provider": { - "provider": "google_genai", + "provider": "google_gemini", "voice_name": "Laomedeia", "model_name": "tts-model-from-yaml", "api_key": "tts-api-from-yaml" @@ -161,7 +161,7 @@ monkeypatch.setenv("GEMINI_API_KEY", "mock_key") settings = Settings() - assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI + assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GEMINI assert settings.TTS_VOICE_NAME == "Laomedeia" assert settings.TTS_MODEL_NAME == "tts-model-from-yaml" assert settings.TTS_API_KEY == "tts-api-from-yaml" @@ -170,7 +170,7 @@ def test_tts_settings_from_env(monkeypatch, tmp_config_file, clear_all_env): """Tests that TTS environment variables override the YAML file.""" monkeypatch.setenv("CONFIG_PATH", tmp_config_file) - monkeypatch.setenv("TTS_PROVIDER", "google_genai") + monkeypatch.setenv("TTS_PROVIDER", "google_gemini") monkeypatch.setenv("TTS_VOICE_NAME", "Zephyr") monkeypatch.setenv("TTS_MODEL_NAME", "env-tts-model") monkeypatch.setenv("TTS_API_KEY", "env_tts_key") @@ -178,7 +178,7 @@ monkeypatch.setenv("GEMINI_API_KEY", "mock_key") settings = Settings() - assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI + assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GEMINI assert settings.TTS_VOICE_NAME == "Zephyr" assert settings.TTS_MODEL_NAME == "env-tts-model" assert settings.TTS_API_KEY == "env_tts_key" @@ -192,7 +192,7 @@ settings = Settings() - assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI + assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GEMINI assert settings.TTS_VOICE_NAME == "Kore" assert settings.TTS_MODEL_NAME == "gemini-2.5-flash-preview-tts" assert settings.TTS_API_KEY == "fallback_gemini_key"
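Taken together, the patches rename the provider key from "google_genai" to "google_gemini" and thread a required voice_name argument through get_tts_provider(). The sketch below is not part of the diff; it only illustrates how a caller would use the updated factory. The import path follows app/core/providers/factory.py from the diff, GeminiTTSProvider.AVAILABLE_VOICES is taken from the unit tests above, and "Zephyr" is assumed to be a valid voice because config.yaml now selects it.

# Illustrative sketch only, under the assumptions noted above; not code from the diff.
from app.core.providers.factory import get_tts_provider

# The renamed provider key plus the new required voice_name argument.
provider = get_tts_provider(
    provider_name="google_gemini",
    api_key="dummy_key",   # real callers pass settings.TTS_API_KEY, as in app.py
    voice_name="Zephyr",   # expected to be in GeminiTTSProvider.AVAILABLE_VOICES
)

# Any other provider name raises ValueError, matching the factory and its unit test.
try:
    get_tts_provider("unknown", api_key="dummy_key", voice_name="Zephyr")
except ValueError as err:
    print(err)  # Unsupported TTS provider: 'unknown'. Supported providers are: ['google_gemini']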