# cortex-hub / ai-hub / app / core / services / tts.py
import io
import wave
import logging
import asyncio
from typing import AsyncGenerator
from app.core.providers.base import TTSProvider
from fastapi import HTTPException
import os

# --- Configure logging ---
# Module-level logger named after this module's import path (standard practice).
logger = logging.getLogger(__name__)

# --- Helper Functions ---
def _create_wav_file(pcm_data: bytes) -> bytes:
    """
    Wraps raw 16-bit PCM audio data in a WAV header.
    """
    with io.BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(24000) 
            wav_file.writeframes(pcm_data)
        return wav_buffer.getvalue()

# --- Define TTS Service Class ---
class TTSService:
    """
    Service class for generating speech from text using a TTS provider.

    Handles both streaming and non-streaming audio generation, splitting
    long input text into sentence-aligned chunks so each provider request
    stays within a manageable size.
    """

    # Maximum characters per provider request; overridable via environment.
    MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 600))

    def __init__(self, tts_provider: TTSProvider):
        self.tts_provider = tts_provider

    async def _split_text_into_chunks(self, text: str) -> list[str]:
        """
        Split *text* into chunks of at most MAX_CHUNK_SIZE characters,
        breaking only at sentence boundaries ('.', '?', '!', newline).

        NOTE: a single sentence longer than MAX_CHUNK_SIZE is kept intact,
        so one chunk may exceed the limit.
        """
        # Mark each sentence boundary with a sentinel, then split on it.
        for separator in ('.', '?', '!', '\n'):
            text = text.replace(separator, f"{separator}|")
        sentences = [s.strip() for s in text.split('|') if s.strip()]

        chunks: list[str] = []
        current_chunk = ""
        for sentence in sentences:
            # Start a new chunk when appending would exceed the limit
            # (never emit an empty chunk).
            if current_chunk and len(current_chunk) + len(sentence) + 1 > self.MAX_CHUNK_SIZE:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
            else:
                current_chunk += sentence + " "

        if current_chunk:
            chunks.append(current_chunk.strip())

        logger.debug("Split text into %d chunks.", len(chunks))
        return chunks

    async def _generate_pcm_chunks(self, text: str) -> AsyncGenerator[bytes, None]:
        """
        Yield raw PCM audio for each text chunk, one provider call at a time.

        Raises:
            HTTPException: 500 if the provider fails for any chunk.
        """
        chunks = await self._split_text_into_chunks(text)

        for i, chunk in enumerate(chunks):
            logger.info(f"Generating PCM for chunk {i+1}/{len(chunks)}: '{chunk[:30]}...'")
            try:
                pcm_data = await self.tts_provider.generate_speech(chunk)
            except Exception as e:
                logger.error(f"Error processing chunk {i+1}: {e}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating speech for chunk {i+1}: {e}"
                ) from e
            # Yield outside the try so exceptions thrown into the generator
            # by the consumer are not misreported as provider failures.
            yield pcm_data

    async def create_speech_stream(self, text: str, as_wav: bool = True) -> AsyncGenerator[bytes, None]:
        """
        Stream audio for *text*, one chunk per yielded item.

        NOTE: with as_wav=True each yielded item is a complete standalone
        WAV file (own header); concatenating the items does NOT produce a
        single valid WAV. Use as_wav=False for a raw joinable PCM stream.
        """
        async for pcm_data in self._generate_pcm_chunks(text):
            yield _create_wav_file(pcm_data) if as_wav else pcm_data

    async def create_speech_non_stream(self, text: str) -> bytes:
        """
        Generate a single WAV file for all of *text*.

        Chunks are synthesized concurrently (at most 3 in flight), retrying
        429 responses with exponential backoff, then concatenated into one
        PCM stream and wrapped in a WAV header.

        Raises:
            HTTPException: 429 when retries are exhausted, 500 otherwise.
        """
        chunks = await self._split_text_into_chunks(text)
        semaphore = asyncio.Semaphore(3)  # cap concurrent provider requests

        async def generate_with_limit(chunk: str) -> bytes:
            retries = 3
            delay = 1
            async with semaphore:
                for attempt in range(1, retries + 1):
                    try:
                        return await self.tts_provider.generate_speech(chunk)
                    except HTTPException as e:
                        if e.status_code != 429:
                            raise
                        if attempt == retries:
                            break  # out of retries — don't sleep pointlessly
                        logger.warning(
                            f"429 Too Many Requests for chunk, retrying in {delay}s "
                            f"(attempt {attempt}/{retries})..."
                        )
                        await asyncio.sleep(delay)
                        delay *= 2  # exponential backoff
                raise HTTPException(status_code=429, detail="Too many requests after retries.")

        tasks = [generate_with_limit(chunk) for chunk in chunks]

        try:
            all_pcm_data = await asyncio.gather(*tasks)
            logger.info(f"Successfully gathered audio data for all {len(chunks)} chunks.")
        except HTTPException:
            # Preserve deliberate status codes (e.g. the 429 above) instead
            # of masking every failure as a generic 500.
            raise
        except Exception as e:
            logger.error(f"An error occurred while gathering audio chunks: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while generating audio: {e}"
            ) from e

        if not all_pcm_data:
            logger.warning("No audio data was generated.")
            raise HTTPException(status_code=500, detail="No audio data was generated from the TTS provider.")

        concatenated_pcm = b''.join(all_pcm_data)
        logger.info(f"Concatenated {len(chunks)} chunks into a single PCM stream.")

        return _create_wav_file(concatenated_pcm)