import asyncio
import io
import logging
import os
import re
import wave
from typing import AsyncGenerator

from fastapi import HTTPException

from app.core.providers.base import TTSProvider

# --- Configure logging ---
logger = logging.getLogger(__name__)


# --- Helper Functions ---
def _create_wav_file(pcm_data: bytes) -> bytes:
    """Wrap raw 16-bit mono 24 kHz PCM audio data in a WAV container.

    Args:
        pcm_data: Raw signed 16-bit PCM samples.

    Returns:
        The complete WAV file contents as bytes.
    """
    with io.BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)      # mono
            wav_file.setsampwidth(2)      # 16-bit samples
            wav_file.setframerate(24000)  # 24 kHz sample rate
            wav_file.writeframes(pcm_data)
        return wav_buffer.getvalue()


# Sentence-boundary pattern: split *after* '.', '?', '!' or newline so each
# separator stays attached to its sentence.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.?!\n])')


# --- Define TTS Service Class ---
class TTSService:
    """
    Service class for generating speech from text using a TTS provider.

    This version is designed to handle both streaming and non-streaming
    audio generation, splitting text into manageable chunks.
    """

    # Maximum characters per chunk sent to the provider (env-overridable).
    MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", "600"))

    def __init__(self, tts_provider: TTSProvider):
        self.tts_provider = tts_provider

    async def _split_text_into_chunks(self, text: str) -> list[str]:
        """Split *text* into chunks of roughly MAX_CHUNK_SIZE characters.

        Text is split on sentence boundaries ('.', '?', '!', newline) and
        sentences are greedily packed into chunks. A single sentence longer
        than MAX_CHUNK_SIZE becomes its own (oversized) chunk.
        """
        # NOTE: a previous implementation inserted a literal '|' marker and
        # split on it, which corrupted chunking for text already containing
        # '|'. The lookbehind split has no such collision.
        sentences = [s.strip() for s in _SENTENCE_BOUNDARY.split(text) if s.strip()]

        chunks: list[str] = []
        current_chunk = ""
        for sentence in sentences:
            # Start a new chunk when adding this sentence would overflow.
            if current_chunk and len(current_chunk) + len(sentence) + 1 > self.MAX_CHUNK_SIZE:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
            else:
                current_chunk += sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())

        logger.debug("Split text into %d chunks.", len(chunks))
        return chunks

    async def _generate_pcm_chunks(self, text: str) -> AsyncGenerator[bytes, None]:
        """Yield raw PCM audio for each text chunk, in order.

        Raises:
            HTTPException: provider-raised HTTP errors propagate unchanged;
                any other failure is wrapped as a 500.
        """
        chunks = await self._split_text_into_chunks(text)
        for i, chunk in enumerate(chunks):
            logger.info(
                "Generating PCM for chunk %d/%d: '%s...'", i + 1, len(chunks), chunk[:30]
            )
            try:
                yield await self.tts_provider.generate_speech(chunk)
            except HTTPException:
                # Preserve provider-supplied status codes (e.g. 429) instead
                # of masking them as generic 500s.
                raise
            except Exception as e:
                logger.error("Error processing chunk %d: %s", i + 1, e)
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating speech for chunk {i+1}: {e}"
                ) from e

    async def create_speech_stream(
        self, text: str, as_wav: bool = True
    ) -> AsyncGenerator[bytes, None]:
        """Stream generated audio chunk by chunk.

        Args:
            text: The text to synthesize.
            as_wav: When True (default), each chunk is wrapped in its own
                WAV header; otherwise raw PCM bytes are yielded.
        """
        async for pcm_data in self._generate_pcm_chunks(text):
            if as_wav:
                yield _create_wav_file(pcm_data)
            else:
                yield pcm_data

    async def create_speech_non_stream(self, text: str) -> bytes:
        """Generate a single WAV payload for *text*.

        Chunks are synthesized concurrently (at most 3 in flight) with
        exponential-backoff retries on 429 responses, then concatenated
        into one WAV file.

        Raises:
            HTTPException: 429 if the provider keeps rate-limiting after
                retries; 500 on any other generation failure or when no
                audio is produced.
        """
        chunks = await self._split_text_into_chunks(text)
        semaphore = asyncio.Semaphore(3)  # Limit concurrency to 3 requests

        async def generate_with_limit(chunk: str) -> bytes:
            """Generate one chunk, retrying 429s with exponential backoff."""
            retries = 3
            delay = 1
            async with semaphore:
                for attempt in range(retries):
                    try:
                        return await self.tts_provider.generate_speech(chunk)
                    except HTTPException as e:
                        if e.status_code != 429:
                            raise
                        logger.warning(
                            "429 Too Many Requests for chunk, retrying in %ss "
                            "(attempt %d/%d)...",
                            delay, attempt + 1, retries,
                        )
                        # No point sleeping after the final attempt — we are
                        # about to give up anyway.
                        if attempt < retries - 1:
                            await asyncio.sleep(delay)
                            delay *= 2  # exponential backoff
            raise HTTPException(status_code=429, detail="Too many requests after retries.")

        tasks = [generate_with_limit(chunk) for chunk in chunks]
        try:
            all_pcm_data = await asyncio.gather(*tasks)
            logger.info(
                "Successfully gathered audio data for all %d chunks.", len(chunks)
            )
        except HTTPException:
            # Keep the retry helper's 429 (and any other provider status)
            # intact rather than re-wrapping it as a 500.
            raise
        except Exception as e:
            logger.error("An error occurred while gathering audio chunks: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while generating audio: {e}"
            ) from e

        if not all_pcm_data:
            logger.warning("No audio data was generated.")
            raise HTTPException(
                status_code=500,
                detail="No audio data was generated from the TTS provider."
            )

        concatenated_pcm = b''.join(all_pcm_data)
        logger.info("Concatenated %d chunks into a single PCM stream.", len(chunks))
        return _create_wav_file(concatenated_pcm)