import asyncio
import io
import logging
import os
import re
import wave
from typing import AsyncGenerator

from fastapi import HTTPException

from app.core.providers.base import TTSProvider
# --- Configure logging ---
logger = logging.getLogger(__name__)
# --- Helper Functions ---
def _create_wav_file(pcm_data: bytes) -> bytes:
"""
Wraps raw 16-bit PCM audio data in a WAV header.
"""
with io.BytesIO() as wav_buffer:
with wave.open(wav_buffer, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(24000)
wav_file.writeframes(pcm_data)
return wav_buffer.getvalue()
# --- Define TTS Service Class ---
class TTSService:
    """
    Service for generating speech from text via a pluggable TTS provider.

    Text is split into sentence-aligned chunks of at most MAX_CHUNK_SIZE
    characters; chunks are synthesized either as an async stream
    (create_speech_stream) or concatenated into one WAV file
    (create_speech_non_stream).
    """

    # Maximum characters per provider request; override via TTS_MAX_CHUNK_SIZE.
    MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 600))

    def __init__(self, tts_provider: "TTSProvider"):
        """Store the provider used for all speech synthesis calls."""
        self.tts_provider = tts_provider

    async def _split_text_into_chunks(self, text: str) -> list[str]:
        """
        Split *text* into chunks of at most MAX_CHUNK_SIZE characters,
        preferring sentence boundaries ('.', '?', '!', newline).

        A sentence longer than MAX_CHUNK_SIZE is hard-split so no chunk
        ever exceeds the limit.
        """
        # Zero-width lookbehind split keeps the terminator attached to its
        # sentence and needs no sentinel character, so text that already
        # contains '|' is handled correctly (the old '|'-sentinel approach
        # split such text at the wrong places).
        sentences = [s.strip() for s in re.split(r"(?<=[.?!\n])", text) if s.strip()]
        chunks: list[str] = []
        current = ""
        for sentence in sentences:
            # Hard-split any single sentence longer than the chunk limit.
            while len(sentence) > self.MAX_CHUNK_SIZE:
                if current:
                    chunks.append(current.strip())
                    current = ""
                chunks.append(sentence[:self.MAX_CHUNK_SIZE])
                sentence = sentence[self.MAX_CHUNK_SIZE:]
            if not sentence:
                continue
            if current and len(current) + len(sentence) + 1 > self.MAX_CHUNK_SIZE:
                chunks.append(current.strip())
                current = sentence + " "
            else:
                current += sentence + " "
        if current.strip():
            chunks.append(current.strip())
        logger.debug("Split text into %d chunks.", len(chunks))
        return chunks

    async def _generate_pcm_chunks(self, text: str) -> AsyncGenerator[bytes, None]:
        """
        Yield raw PCM audio for each text chunk, in input order.

        Raises:
            HTTPException: 500 if the provider fails on any chunk.
        """
        chunks = await self._split_text_into_chunks(text)
        for i, chunk in enumerate(chunks):
            logger.info("Generating PCM for chunk %d/%d: '%s...'", i + 1, len(chunks), chunk[:30])
            try:
                pcm_data = await self.tts_provider.generate_speech(chunk)
                yield pcm_data
            except Exception as e:
                logger.error("Error processing chunk %d: %s", i + 1, e)
                raise HTTPException(
                    status_code=500,
                    detail=f"Error generating speech for chunk {i+1}: {e}"
                ) from e

    async def create_speech_stream(self, text: str, as_wav: bool = True) -> AsyncGenerator[bytes, None]:
        """
        Stream synthesized audio chunk by chunk.

        Args:
            text: Input text to synthesize.
            as_wav: When True each yielded chunk is a self-contained WAV
                file (header + data); otherwise raw PCM bytes are yielded.
        """
        async for pcm_data in self._generate_pcm_chunks(text):
            yield _create_wav_file(pcm_data) if as_wav else pcm_data

    async def create_speech_non_stream(self, text: str) -> bytes:
        """
        Synthesize the whole text and return it as a single WAV file.

        Chunks are generated concurrently (at most 3 in flight) with
        exponential backoff on provider rate limiting (HTTP 429).

        Raises:
            HTTPException: 500 on provider failure or empty output; the
                429 raised after exhausted retries is wrapped in a 500 by
                the gather error handler, matching the general failure path.
        """
        chunks = await self._split_text_into_chunks(text)
        semaphore = asyncio.Semaphore(3)  # limit concurrent provider requests

        async def generate_with_limit(chunk: str) -> bytes:
            retries = 3
            delay = 1
            for attempt in range(retries):
                # Hold the semaphore only while a request is in flight so a
                # chunk that is backing off does not block other chunks.
                async with semaphore:
                    try:
                        return await self.tts_provider.generate_speech(chunk)
                    except HTTPException as e:
                        if e.status_code != 429:
                            raise
                logger.warning(
                    "429 Too Many Requests for chunk, retrying in %ss (attempt %d/%d)...",
                    delay, attempt + 1, retries,
                )
                await asyncio.sleep(delay)
                delay *= 2  # exponential backoff
            raise HTTPException(status_code=429, detail="Too many requests after retries.")

        tasks = [generate_with_limit(chunk) for chunk in chunks]
        try:
            all_pcm_data = await asyncio.gather(*tasks)
            logger.info("Successfully gathered audio data for all %d chunks.", len(chunks))
        except Exception as e:
            logger.error("An error occurred while gathering audio chunks: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while generating audio: {e}"
            ) from e
        if not all_pcm_data:
            logger.warning("No audio data was generated.")
            raise HTTPException(status_code=500, detail="No audio data was generated from the TTS provider.")
        concatenated_pcm = b''.join(all_pcm_data)
        logger.info("Concatenated %d chunks into a single PCM stream.", len(chunks))
        return _create_wav_file(concatenated_pcm)