diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a6ec166..dced2c7 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -120,10 +120,8 @@ services.with_rag_service(retrievers=retrievers) services.with_document_service(vector_store=vector_store) - if stt_provider: - services.with_service("stt_service", service=STTService(stt_provider=stt_provider)) - if tts_provider: - services.with_service("tts_service", service=TTSService(tts_provider=tts_provider)) + services.with_service("stt_service", service=STTService(stt_provider=stt_provider)) + services.with_service("tts_service", service=TTSService(tts_provider=tts_provider)) services.with_service("workspace_service", service=WorkspaceService()) services.with_service("session_service", service=SessionService()) diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py index 7325f2d..fed6830 100644 --- a/ai-hub/app/core/services/stt.py +++ b/ai-hub/app/core/services/stt.py @@ -10,9 +10,9 @@ Service class for transcribing audio into text using an STT provider. """ - def __init__(self, stt_provider: STTProvider): + def __init__(self, stt_provider: STTProvider = None): """ - Initializes the STTService with a concrete STT provider. + Initializes the STTService with an optional STT provider. """ self.default_stt_provider = stt_provider @@ -21,6 +21,8 @@ Transcribes the provided audio bytes into text using the STT provider. """ provider = provider_override or self.default_stt_provider + if not provider: + raise HTTPException(status_code=400, detail="No active STT provider is configured.") logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") if not audio_bytes: diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index d9b5929..5711d88 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -33,7 +33,7 @@ MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 500)) - def __init__(self, tts_provider: TTSProvider): + def __init__(self, tts_provider: TTSProvider = None): self.default_tts_provider = tts_provider async def _split_text_into_chunks(self, text: str) -> list[str]: @@ -69,6 +69,8 @@ async def _generate_pcm_chunks(self, text: str, provider_override: TTSProvider = None) -> AsyncGenerator[bytes, None]: chunks = await self._split_text_into_chunks(text) provider = provider_override or self.default_tts_provider + if not provider: + raise HTTPException(status_code=400, detail="No active TTS provider is configured.") for i, chunk in enumerate(chunks): logger.info(f"Generating PCM for chunk {i+1}/{len(chunks)}: '{chunk[:30]}...'") @@ -97,6 +99,8 @@ chunks = await self._split_text_into_chunks(text) semaphore = asyncio.Semaphore(3) # Limit concurrency to 3 requests provider = provider_override or self.default_tts_provider + if not provider: + raise HTTPException(status_code=400, detail="No active TTS provider is configured.") async def generate_with_limit(chunk): retries = 3