import asyncio
import io
import logging
import os
import wave
from typing import AsyncGenerator

from fastapi import HTTPException

from app.core.providers.base import TTSProvider

# --- Configure logging ---
logger = logging.getLogger(__name__)


# --- Define TTS Service Class ---
class TTSService:
    """
    Service class for generating speech from text using a TTS provider.

    This version handles both streaming and non-streaming audio generation,
    splitting the input text into manageable chunks before synthesis.
    """

    # Use an environment variable or a default value for the max chunk size
    MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", "200"))

    def __init__(self, tts_provider: TTSProvider):
        """Initializes the TTSService with a concrete TTS provider."""
        self.tts_provider = tts_provider

    async def _split_text_into_chunks(self, text: str) -> list[str]:
        """
        Splits the input text into chunks of at most MAX_CHUNK_SIZE characters,
        breaking on sentence-ending punctuation. A single sentence longer than
        the limit is kept as its own chunk rather than split mid-sentence.
        """
        chunks: list[str] = []
        current_chunk = ""

        # Split on sentence-ending punctuation while keeping the delimiters.
        separators = ['.', '?', '!', '\n']
        for separator in separators:
            text = text.replace(separator, f"{separator}|")
        sentences = [s.strip() for s in text.split('|') if s.strip()]

        for sentence in sentences:
            # Start a new chunk if appending this sentence would exceed the limit.
            if len(current_chunk) + len(sentence) + 1 > self.MAX_CHUNK_SIZE and current_chunk:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
            else:
                current_chunk += sentence + " "

        if current_chunk:
            chunks.append(current_chunk.strip())

        logger.debug(f"Split text into {len(chunks)} chunks.")
        return chunks

    async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]:
        """
        Generates a stream of complete, playable WAV files, one per text chunk.

        This provides a streaming-like experience even with a non-streaming
        backend by yielding each chunk's audio as soon as it is generated.
        """
        chunks = await self._split_text_into_chunks(text)

        for i, chunk in enumerate(chunks):
            logger.info(f"Processing chunk {i + 1}/{len(chunks)} for streaming...")
            try:
                # Get the raw PCM audio data for this chunk
                pcm_data = await self.tts_provider.generate_speech(chunk)

                # Wrap the PCM data in a WAV header to make it a playable file
                with io.BytesIO() as wav_buffer:
                    with wave.open(wav_buffer, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(pcm_data)
                    # Yield a complete, playable WAV file for the chunk
                    yield wav_buffer.getvalue()
            except Exception as e:
                logger.error(f"Error processing chunk {i + 1}: {e}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating speech for chunk {i + 1}: {e}",
                ) from e

    async def create_speech_non_stream(self, text: str) -> bytes:
        """
        Generates a complete audio file from the given text by splitting it
        into chunks and concatenating the audio into a single WAV file.
        All chunks are processed concurrently for speed.
        """
        chunks = await self._split_text_into_chunks(text)

        # Create a task for each chunk so the provider calls run concurrently.
        tasks = [self.tts_provider.generate_speech(chunk) for chunk in chunks]
        try:
            # Gather the results from all tasks; this runs all API calls to
            # the TTS provider concurrently.
            all_pcm_data = await asyncio.gather(*tasks)
            logger.info(f"Successfully gathered audio data for all {len(chunks)} chunks.")
        except Exception as e:
            logger.error(f"An error occurred while gathering audio chunks: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while generating audio: {e}",
            ) from e

        if not all_pcm_data:
            logger.warning("No audio data was generated.")
            raise HTTPException(
                status_code=500,
                detail="No audio data was generated from the TTS provider.",
            )

        # Concatenate all the raw PCM data into a single stream
        concatenated_pcm = b''.join(all_pcm_data)
        logger.info(f"Concatenated {len(chunks)} chunks into a single PCM stream.")

        # Wrap the complete PCM stream in a single WAV container
        with io.BytesIO() as wav_buffer:
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                # The Gemini API returns 24kHz audio; adjust if using a different provider
                wav_file.setframerate(24000)
                wav_file.writeframes(concatenated_pcm)
            return wav_buffer.getvalue()
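

# --- Minimal usage sketch (illustrative, not part of the service itself) ---
# This demo shows how TTSService might be exercised end to end, assuming the
# TTSProvider interface exposes an async `generate_speech(text) -> bytes`
# method returning raw 16-bit PCM. `_SilentProvider` is a hypothetical stand-in
# that returns silence; swap in a real provider in production.
if __name__ == "__main__":

    class _SilentProvider(TTSProvider):
        """Stand-in provider: returns 0.5 s of silent 16-bit mono PCM at 24 kHz."""

        async def generate_speech(self, text: str) -> bytes:
            return b"\x00\x00" * 12000

    async def _demo() -> None:
        service = TTSService(tts_provider=_SilentProvider())
        wav_bytes = await service.create_speech_non_stream(
            "Hello world. This is a short demo sentence."
        )
        print(f"Generated a {len(wav_bytes)}-byte WAV file.")

    asyncio.run(_demo())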