# cortex-hub/ai-hub/app/core/services/tts.py
import io
import wave
import logging
import asyncio
from typing import AsyncGenerator
from app.core.providers.base import TTSProvider
from fastapi import HTTPException
import os

# --- Configure logging ---
# Module-level logger named after this module. Handlers and levels are
# presumably configured by the application at startup (not shown here).
logger = logging.getLogger(__name__)

# --- Define TTS Service Class ---
class TTSService:
    """
    Service class for generating speech from text using a TTS provider.

    Text is split into chunks of at most ``MAX_CHUNK_SIZE`` characters. Each
    chunk is sent to the provider's ``generate_speech`` coroutine, and the raw
    PCM it returns is wrapped in a WAV container. There are two entry points:

    * ``create_speech_stream`` yields one complete, playable WAV file per chunk.
    * ``create_speech_non_stream`` synthesizes all chunks concurrently and
      returns a single WAV file.
    """

    # Maximum number of characters per chunk sent to the provider. It is read
    # once, when the class body runs at import time. Override it with the
    # TTS_MAX_CHUNK_SIZE environment variable.
    MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 200))

    # Format of the raw PCM returned by the provider: mono, 2 bytes per sample,
    # 24 kHz. Per the original implementation, 24 kHz is what the Gemini API
    # returns. Adjust these values if a different provider is used.
    CHANNELS = 1
    SAMPLE_WIDTH = 2
    SAMPLE_RATE = 24000

    # Punctuation that ends a sentence when followed by whitespace or by the
    # end of the text. A newline always ends a sentence.
    _SENTENCE_END_CHARS = ".?!"

    def __init__(self, tts_provider: TTSProvider):
        """
        Initializes the TTSService with a concrete TTS provider.

        Args:
            tts_provider: Provider whose ``generate_speech(text)`` coroutine
                returns the raw PCM audio bytes for ``text``.
        """
        self.tts_provider = tts_provider

    async def _split_text_into_chunks(self, text: str) -> list[str]:
        """
        Split ``text`` into chunks of at most ``MAX_CHUNK_SIZE`` characters.

        The text is first broken into sentences (see ``_split_sentences``).
        Sentences are then packed greedily into chunks, joined by single
        spaces. A sentence longer than the limit on its own is wrapped at word
        boundaries, and a single over-long word is cut into limit-sized slices.
        As a result, no returned chunk exceeds the limit.

        The method is ``async`` only because callers ``await`` it; it does no
        I/O.

        Returns:
            The chunks, in order. The list is empty for empty or
            whitespace-only text.
        """
        # Guard against a zero/negative env override, which would otherwise
        # make every piece "too long" and break the packing below.
        limit = max(1, self.MAX_CHUNK_SIZE)

        pieces = [
            piece
            for sentence in self._split_sentences(text)
            for piece in self._fit_sentence(sentence, limit)
        ]
        chunks = self._pack(pieces, limit)

        logger.debug(f"Split text into {len(chunks)} chunks.")
        return chunks

    @classmethod
    def _split_sentences(cls, text: str) -> list[str]:
        """
        Split text into stripped, non-empty sentences.

        A sentence ends at a newline, or at '.', '?' or '!' followed by
        whitespace or the end of the text. Requiring the following whitespace
        keeps decimals ("3.14"), domains ("example.com") and punctuation runs
        ("Wait...", "Really?!") intact. No text other than the surrounding
        whitespace is removed.
        """
        sentences = []
        start = 0
        last_index = len(text) - 1
        for index, char in enumerate(text):
            at_boundary = char == "\n" or (
                char in cls._SENTENCE_END_CHARS
                and (index == last_index or text[index + 1].isspace())
            )
            if at_boundary:
                sentence = text[start:index + 1].strip()
                if sentence:
                    sentences.append(sentence)
                start = index + 1

        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return sentences

    @classmethod
    def _fit_sentence(cls, sentence: str, limit: int) -> list[str]:
        """Break a sentence longer than ``limit`` into pieces of at most ``limit`` chars."""
        if len(sentence) <= limit:
            return [sentence]
        words = []
        for word in sentence.split():
            # A word longer than the limit cannot be wrapped, so slice it.
            words.extend(word[i:i + limit] for i in range(0, len(word), limit))
        return cls._pack(words, limit)

    @staticmethod
    def _pack(pieces: list[str], limit: int) -> list[str]:
        """Greedily join pieces (each at most ``limit`` chars) with spaces into strings of at most ``limit`` chars."""
        packed = []
        current = ""
        for piece in pieces:
            if current and len(current) + 1 + len(piece) > limit:
                packed.append(current)
                current = piece
            else:
                current = f"{current} {piece}" if current else piece
        if current:
            packed.append(current)
        return packed

    def _pcm_to_wav(self, pcm_data: bytes) -> bytes:
        """Wrap raw PCM samples in a WAV container using the class audio format."""
        with io.BytesIO() as wav_buffer:
            # wave does not close a file object it did not open, so the
            # buffer is still readable after the inner block exits.
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(self.CHANNELS)
                wav_file.setsampwidth(self.SAMPLE_WIDTH)
                wav_file.setframerate(self.SAMPLE_RATE)
                wav_file.writeframes(pcm_data)
            return wav_buffer.getvalue()

    async def create_speech_stream(self, text: str) -> AsyncGenerator[bytes, None]:
        """
        Generate a stream of complete, playable WAV files, one per text chunk.

        This gives a streaming-like experience even with a non-streaming
        backend: each chunk is yielded as soon as its audio is generated.
        Chunks are processed sequentially.

        Raises:
            HTTPException: status 500 if the provider call or the WAV encoding
                fails for a chunk. Chunks yielded before the failure have
                already been delivered to the consumer.
        """
        chunks = await self._split_text_into_chunks(text)
        total = len(chunks)

        for number, chunk in enumerate(chunks, start=1):
            logger.info(f"Processing chunk {number}/{total} for streaming...")

            try:
                pcm_data = await self.tts_provider.generate_speech(chunk)
                wav_bytes = self._pcm_to_wav(pcm_data)
            except Exception as e:
                logger.exception(f"Error processing chunk {number}: {e}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating speech for chunk {number}: {e}"
                ) from e

            # Yield outside the try block, so that an exception thrown into
            # this generator by its consumer is not misreported as a TTS
            # failure.
            yield wav_bytes

    async def create_speech_non_stream(self, text: str) -> bytes:
        """
        Generate a single WAV file for the whole text.

        The text is split into chunks and every chunk is synthesized
        concurrently. The PCM results are concatenated in the original chunk
        order.

        Raises:
            HTTPException: status 500 if any provider call fails (the
                remaining calls are cancelled), or if no audio was produced.
                The second case includes empty or whitespace-only text.
        """
        chunks = await self._split_text_into_chunks(text)

        # Schedule every provider call as a task so that the calls run
        # concurrently and can be cancelled if one of them fails.
        tasks = [
            asyncio.ensure_future(self.tts_provider.generate_speech(chunk))
            for chunk in chunks
        ]

        try:
            # gather() returns results in the same order as ``tasks``, which
            # keeps the audio in sequence.
            all_pcm_data = await asyncio.gather(*tasks)
            logger.info(f"Successfully gathered audio data for all {len(chunks)} chunks.")
        except Exception as e:
            # gather() does not cancel the other awaitables when one fails.
            # Stop them explicitly so they do not keep calling the provider.
            for task in tasks:
                task.cancel()
            logger.exception(f"An error occurred while gathering audio chunks: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while generating audio: {e}"
            ) from e

        if not all_pcm_data:
            logger.warning("No audio data was generated.")
            raise HTTPException(status_code=500, detail="No audio data was generated from the TTS provider.")

        # Concatenate all the raw PCM data into a single stream.
        concatenated_pcm = b''.join(all_pcm_data)
        logger.info(f"Concatenated {len(chunks)} chunks into a single PCM stream.")

        return self._pcm_to_wav(concatenated_pcm)