diff --git a/.gitignore b/.gitignore
index 816e9e6..e70abdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@
@eaDir/
**/.DS_Store
ai-hub/app/config.yaml
-**/config.yaml
\ No newline at end of file
+**/config.yaml
+data/audio/*
+data/*
\ No newline at end of file
diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py
index 732bd2f..0803877 100644
--- a/ai-hub/app/api/routes/sessions.py
+++ b/ai-hub/app/api/routes/sessions.py
@@ -1,10 +1,13 @@
-from fastapi import APIRouter, HTTPException, Depends
+from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Response
+from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional
from app.db import models
from app.core.pipelines.validator import Validator
+import os
+import shutil
def create_sessions_router(services: ServiceContainer) -> APIRouter:
router = APIRouter(prefix="/sessions", tags=["Sessions"])
@@ -34,14 +37,14 @@
db: Session = Depends(get_db)
):
try:
- response_text, provider_used = await services.rag_service.chat_with_rag(
+ response_text, provider_used, message_id = await services.rag_service.chat_with_rag(
db=db,
session_id=session_id,
prompt=request.prompt,
provider_name=request.provider_name,
load_faiss_retriever=request.load_faiss_retriever
)
- return schemas.ChatResponse(answer=response_text, provider_used=provider_used)
+ return schemas.ChatResponse(answer=response_text, provider_used=provider_used, message_id=message_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
@@ -52,7 +55,16 @@
if messages is None:
raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
- return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
+ # Enhance messages with audio availability
+ enhanced_messages = []
+ for m in messages:
+ msg_dict = schemas.Message.model_validate(m).model_dump()
+ if m.audio_path and os.path.exists(m.audio_path):
+ msg_dict["has_audio"] = True
+ msg_dict["audio_url"] = f"/sessions/messages/{m.id}/audio"
+ enhanced_messages.append(msg_dict)
+
+ return schemas.MessageHistoryResponse(session_id=session_id, messages=enhanced_messages)
except HTTPException:
raise
except Exception as e:
@@ -112,6 +124,29 @@
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}")
+ @router.patch("/{session_id}", response_model=schemas.Session, summary="Update a Chat Session")
+ def update_session(session_id: int, session_update: schemas.SessionUpdate, db: Session = Depends(get_db)):
+ try:
+ session = db.query(models.Session).filter(
+ models.Session.id == session_id,
+ models.Session.is_archived == False
+ ).first()
+ if not session:
+ raise HTTPException(status_code=404, detail="Session not found.")
+
+ if session_update.title is not None:
+ session.title = session_update.title
+ if session_update.provider_name is not None:
+ session.provider_name = session_update.provider_name
+
+ db.commit()
+ db.refresh(session)
+ return session
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Failed to update session: {e}")
+
@router.delete("/{session_id}", summary="Delete a Chat Session")
def delete_session(session_id: int, db: Session = Depends(get_db)):
try:
@@ -139,5 +174,46 @@
return {"message": "All sessions deleted successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}")
+
+ @router.post("/messages/{message_id}/audio", summary="Upload audio for a specific message")
+ async def upload_message_audio(message_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)):
+ try:
+ message = db.query(models.Message).filter(models.Message.id == message_id).first()
+ if not message:
+ raise HTTPException(status_code=404, detail="Message not found.")
+        # Create the audio data directory if it does not already exist
+ audio_dir = "/app/data/audio"
+ os.makedirs(audio_dir, exist_ok=True)
+
+ # Save file
+ file_path = f"{audio_dir}/message_{message_id}.wav"
+ with open(file_path, "wb") as buffer:
+ shutil.copyfileobj(file.file, buffer)
+
+ # Update database
+ message.audio_path = file_path
+ db.commit()
+
+ return {"message": "Audio uploaded successfully.", "audio_path": file_path}
+ except Exception as e:
+ print(f"Error uploading audio: {e}")
+ raise HTTPException(status_code=500, detail=f"Failed to upload audio: {e}")
+
+ @router.get("/messages/{message_id}/audio", summary="Get audio for a specific message")
+ async def get_message_audio(message_id: int, db: Session = Depends(get_db)):
+ try:
+ message = db.query(models.Message).filter(models.Message.id == message_id).first()
+ if not message or not message.audio_path:
+ raise HTTPException(status_code=404, detail="Audio not found for this message.")
+
+ if not os.path.exists(message.audio_path):
+ raise HTTPException(status_code=404, detail="Audio file missing on disk.")
+
+ return FileResponse(message.audio_path, media_type="audio/wav")
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Failed to get audio: {e}")
+
return router
\ No newline at end of file
diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py
index 94334a8..eda6404 100644
--- a/ai-hub/app/api/routes/tts.py
+++ b/ai-hub/app/api/routes/tts.py
@@ -39,24 +39,39 @@
from app.config import settings
active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER
active_prefs = prefs.get("providers", {}).get(active_provider, {})
- if active_prefs:
- from app.core.providers.factory import get_tts_provider
- kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model", "voice"]}
- provider_override = get_tts_provider(
- provider_name=active_provider,
- api_key=active_prefs.get("api_key"),
- model_name=active_prefs.get("model", ""),
- voice_name=active_prefs.get("voice", ""),
- **kwargs
- )
+ from app.core.providers.factory import get_tts_provider
+ kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model", "voice"]}
+ provider_override = get_tts_provider(
+ provider_name=active_provider,
+ api_key=active_prefs.get("api_key"),
+ model_name=active_prefs.get("model", ""),
+ voice_name=active_prefs.get("voice", ""),
+ **kwargs
+ )
if stream:
# Pre-flight: generate first chunk before streaming to catch errors cleanly
# If we send StreamingResponse and then fail, the browser sees a network error
# instead of a meaningful error message.
- chunks = await services.tts_service._split_text_into_chunks(request.text)
+    # Split off a short first chunk for low latency, then send the entire remainder for smoothness
+ all_text = request.text
+ separators = ['.', '?', '!', '\n', '。', '?', '!', ',']
+
+ # Find first separator within a reasonable limit
+ split_idx = -1
+ chars_to_scan = min(len(all_text) - 1, 400)
+ for i in range(chars_to_scan, 50, -1):
+ if all_text[i] in separators:
+ split_idx = i + 1
+ break
+
+ if split_idx != -1 and split_idx < len(all_text):
+ chunks = [all_text[:split_idx], all_text[split_idx:]]
+ else:
+ chunks = [all_text]
+
provider = provider_override or services.tts_service.default_tts_provider
- if not chunks:
+ if not chunks or not chunks[0].strip():
raise HTTPException(status_code=400, detail="No text to synthesize.")
# Test first chunk synchronously to validate the provider works
@@ -69,23 +84,49 @@
yield _create_wav_file(first_pcm)
else:
yield first_pcm
- # Then stream the remaining chunks
- for chunk in chunks[1:]:
- try:
- pcm = await provider.generate_speech(chunk)
- if pcm:
- if as_wav:
- from app.core.services.tts import _create_wav_file
- yield _create_wav_file(pcm)
- else:
- yield pcm
- except Exception as e:
- import logging
- logging.getLogger(__name__).error(f"TTS chunk error: {e}")
- break # Stop cleanly rather than crashing the stream
+ # Then stream the remaining chunks using parallel fetching but sequential yielding
+ import asyncio
+ semaphore = asyncio.Semaphore(3) # Limit concurrent external requests
+
+ async def fetch_chunk(text_chunk):
+ retries = 3
+ delay = 1.0
+ for attempt in range(retries):
+ try:
+ async with semaphore:
+ return await provider.generate_speech(text_chunk)
+ except Exception as e:
+ error_str = str(e)
+ if "No audio in response" in error_str or "finishReason" in error_str:
+ import logging
+ logging.getLogger(__name__).error(f"TTS chunk blocked by provider formatting/safety: {e}")
+ return None
+
+ if attempt == retries - 1:
+ import logging
+ logging.getLogger(__name__).error(f"TTS chunk failed after {retries} attempts: {e}")
+ return None
+ await asyncio.sleep(delay)
+ delay *= 2
+
+ # Start all tasks concurrently
+ tasks = [asyncio.create_task(fetch_chunk(chunk)) for chunk in chunks[1:]]
+
+ for task in tasks:
+ pcm = await task
+ if pcm:
+ if as_wav:
+ from app.core.services.tts import _create_wav_file
+ yield _create_wav_file(pcm)
+ else:
+ yield pcm
media_type = "audio/wav" if as_wav else "audio/pcm"
- return StreamingResponse(full_stream(), media_type=media_type)
+ return StreamingResponse(
+ full_stream(),
+ media_type=media_type,
+ headers={"X-TTS-Chunk-Count": str(len(chunks))}
+ )
else:
# The non-streaming function only returns WAV, so this part remains the same
diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py
index a4faaa7..47f34f4 100644
--- a/ai-hub/app/api/routes/user.py
+++ b/ai-hub/app/api/routes/user.py
@@ -248,6 +248,13 @@
# Load system defaults from DB if needed
system_prefs = services.user_service.get_system_settings(db)
+
+ system_statuses = system_prefs.get("statuses", {})
+ user_statuses = prefs_dict.get("statuses", {})
+
+ def is_provider_healthy(section: str, provider_id: str) -> bool:
+ status_key = f"{section}_{provider_id}"
+ return user_statuses.get(status_key) == "success" or system_statuses.get(status_key) == "success"
user_providers = llm_prefs.get("providers", {})
if not user_providers:
@@ -264,7 +271,7 @@
llm_providers_effective = {}
for p, p_p in user_providers.items():
- if p_p:
+ if p_p and is_provider_healthy("llm", p):
llm_providers_effective[p] = {
"api_key": mask_key(p_p.get("api_key")),
"model": p_p.get("model")
@@ -286,7 +293,7 @@
tts_providers_effective = {}
for p, p_p in user_tts_providers.items():
- if p_p:
+ if p_p and is_provider_healthy("tts", p):
tts_providers_effective[p] = {
"api_key": mask_key(p_p.get("api_key")),
"model": p_p.get("model"),
@@ -308,7 +315,7 @@
stt_providers_effective = {}
for p, p_p in user_stt_providers.items():
- if p_p:
+ if p_p and is_provider_healthy("stt", p):
stt_providers_effective[p] = {
"api_key": mask_key(p_p.get("api_key")),
"model": p_p.get("model")
@@ -470,8 +477,8 @@
from app.config import settings as global_settings
# Sync LLM
- if prefs.llm and prefs.llm.get("providers"):
- global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {}))
+ if prefs.llm and "providers" in prefs.llm:
+ global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {}))
# Sync TTS
if prefs.tts and prefs.tts.get("active_provider"):
@@ -699,7 +706,7 @@
voice_name=req.voice or "",
**kwargs
)
- await provider.generate_speech("Test")
+ await provider.generate_speech("Hello there. We are testing this thing.")
return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
except Exception as e:
logger.error(f"TTS verification failed for {req.provider_name}: {e}")
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index 357ed08..3719e34 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -88,6 +88,7 @@
"""Defines the shape of a successful response from the /chat endpoint."""
answer: str
provider_used: str
+ message_id: Optional[int] = None
# --- Document Schemas ---
class DocumentCreate(BaseModel):
@@ -131,6 +132,10 @@
provider_name: str = "deepseek"
feature_name: Optional[str] = "default"
+class SessionUpdate(BaseModel):
+ title: Optional[str] = None
+ provider_name: Optional[str] = None
+
class Session(BaseModel):
"""Defines the shape of a session object returned by the API."""
id: int
@@ -143,12 +148,17 @@
class Message(BaseModel):
"""Defines the shape of a single message within a session's history."""
+ id: int
# The sender can only be one of two roles.
sender: Literal["user", "assistant"]
# The text content of the message.
content: str
# The timestamp for when the message was created.
created_at: datetime
+ # URL to the saved audio file
+ audio_url: Optional[str] = None
+ # Whether audio exists for this message
+ has_audio: bool = False
# Enables creating this schema from a SQLAlchemy database object.
model_config = ConfigDict(from_attributes=True)
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index 230d127..a6ec166 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -14,6 +14,7 @@
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.db.session import create_db_and_tables
+from app.db.migrate import run_migrations
from app.api.routes.api import create_api_router
from app.utils import print_config
from app.api.dependencies import ServiceContainer, get_db
@@ -39,6 +40,7 @@
print("Application startup...")
print_config(settings)
create_db_and_tables()
+ run_migrations()
yield
print("Application shutdown...")
# Access the vector_store from the application state to save it
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index d54d7cf..bc04c20 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -154,7 +154,7 @@
self.DEEPSEEK_MODEL_NAME = self.LLM_PROVIDERS.get("deepseek", {}).get("model") or \
get_from_yaml(["llm_providers", "deepseek_model_name"]) or "deepseek-chat"
self.GEMINI_MODEL_NAME = self.LLM_PROVIDERS.get("gemini", {}).get("model") or \
- get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-1.5-flash-latest"
+ get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-2.5-flash"
# 2. Resolve Vector / Embedding
self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
diff --git a/ai-hub/app/core/providers/tts/gcloud_tts.py b/ai-hub/app/core/providers/tts/gcloud_tts.py
index a87f9fc..0c034b4 100644
--- a/ai-hub/app/core/providers/tts/gcloud_tts.py
+++ b/ai-hub/app/core/providers/tts/gcloud_tts.py
@@ -78,7 +78,8 @@
"name": voice_to_use
},
"audioConfig": {
- "audioEncoding": "LINEAR16"
+ "audioEncoding": "LINEAR16",
+ "sampleRateHertz": 24000
}
}
diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py
index a59f644..c5353d9 100644
--- a/ai-hub/app/core/providers/tts/gemini.py
+++ b/ai-hub/app/core/providers/tts/gemini.py
@@ -3,6 +3,8 @@
import httpx
import base64
import logging
+import asyncio
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, retry_if_exception
from app.core.providers.base import TTSProvider
from fastapi import HTTPException
@@ -11,6 +13,17 @@
logger = logging.getLogger(__name__)
+def is_retryable_exception(exception):
+ """Check if the exception is one we should retry on."""
+ if isinstance(exception, httpx.TimeoutException):
+ return True
+ if isinstance(exception, httpx.NetworkError):
+ return True
+ if isinstance(exception, HTTPException):
+ # Retry on 429 (Too Many Requests) or 5xx (Server Errors)
+ return exception.status_code == 429 or 500 <= exception.status_code < 600
+ return False
+
class GeminiTTSProvider(TTSProvider):
"""TTS provider using Gemini's audio responseModalities via Google AI Studio."""
@@ -29,7 +42,7 @@
# Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model") → keep only the model id
model_id = raw_model.split("/")[-1]
# Normalise short names: "gemini-2-flash-tts" → "gemini-2.5-flash-preview-tts"
- if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts"):
+ if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts", "gemini-2.5-flash"):
model_id = "gemini-2.5-flash-preview-tts"
logger.info(f"Normalised model name to: {model_id}")
@@ -57,6 +70,13 @@
logger.debug(f"GeminiTTSProvider: model={self.model_name}, vertex={self.is_vertex}")
logger.debug(f" endpoint: {self.api_url[:80]}...")
+ @retry(
+ retry=retry_if_exception_type((httpx.TimeoutException, httpx.NetworkError)) | retry_if_exception(is_retryable_exception),
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=1, max=10),
+ reraise=True,
+ before_sleep=lambda retry_state: logger.warning(f"Retrying Gemini TTS request (attempt {retry_state.attempt_number})...")
+ )
async def generate_speech(self, text: str) -> bytes:
logger.debug(f"TTS generate_speech: '{text[:60]}...'")
@@ -64,9 +84,6 @@
# The dedicated TTS models require a system instruction to produce only audio
json_data = {
- "system_instruction": {
- "parts": [{"text": "You are a text-to-speech system. Convert the user text to speech audio only. Do not generate any text response."}]
- },
"contents": [{"role": "user", "parts": [{"text": text}]}],
"generationConfig": {
"responseModalities": ["AUDIO"],
@@ -86,7 +103,7 @@
logger.debug(f"Calling: {self.api_url}")
try:
- async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
+ async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
response = await client.post(self.api_url, headers=headers, json=json_data)
logger.debug(f"Response status: {response.status_code}")
@@ -99,7 +116,15 @@
msg = err.get("message", body[:200])
except Exception:
msg = body[:200]
- raise HTTPException(status_code=response.status_code, detail=f"Gemini TTS error: {msg}")
+
+ # Check if we should retry (429 or 5xx)
+ status_code = response.status_code
+ if status_code == 429 or 500 <= status_code < 600:
+                        # Retryable status (429 / 5xx): the @retry decorator on generate_speech re-invokes via is_retryable_exception
+ raise HTTPException(status_code=status_code, detail=f"Gemini TTS error: {msg}")
+ else:
+                        # Non-retryable client error (e.g. 400/401/403): surfaces immediately, no retry
+ raise HTTPException(status_code=status_code, detail=f"Gemini TTS error: {msg}")
resp_data = response.json()
audio_fragments = []
@@ -124,11 +149,13 @@
logger.debug(f"TTS returned {len(result)} PCM bytes")
return result
- except HTTPException:
+ except (httpx.TimeoutException, httpx.NetworkError) as e:
+ logger.error(f"Gemini TTS request ({type(e).__name__}) after 60s")
+ # tenacity will catch this and retry
raise
- except httpx.TimeoutException:
- logger.error("Gemini TTS request timed out after 30s")
- raise HTTPException(status_code=504, detail="Gemini TTS request timed out.")
+ except HTTPException:
+            # Re-raise unchanged; the @retry decorator retries only 429/5xx per is_retryable_exception
+ raise
except Exception as e:
logger.error(f"Unexpected TTS error: {type(e).__name__}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}")
\ No newline at end of file
diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py
index 9d88bd1..3eb1987 100644
--- a/ai-hub/app/core/services/rag.py
+++ b/ai-hub/app/core/services/rag.py
@@ -55,6 +55,7 @@
# Fetch user preferences for overrides
api_key_override = None
model_name_override = ""
+ llm_prefs = {}
user = session.user
if user and user.preferences:
llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {})
@@ -93,7 +94,7 @@
db.commit()
db.refresh(assistant_message)
- return answer_text, provider_name
+ return answer_text, provider_name, assistant_message.id
def get_message_history(self, db: Session, session_id: int) -> List[models.Message]:
"""
diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py
index 015b1c4..d9b5929 100644
--- a/ai-hub/app/core/services/tts.py
+++ b/ai-hub/app/core/services/tts.py
@@ -31,7 +31,7 @@
audio generation, splitting text into manageable chunks.
"""
- MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 600))
+ MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 500))
def __init__(self, tts_provider: TTSProvider):
self.default_tts_provider = tts_provider
diff --git a/ai-hub/app/db/migrate.py b/ai-hub/app/db/migrate.py
new file mode 100644
index 0000000..625b5c3
--- /dev/null
+++ b/ai-hub/app/db/migrate.py
@@ -0,0 +1,45 @@
+import logging
+from app.db.session import engine
+from sqlalchemy import text, inspect
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def run_migrations():
+ """
+ Checks for missing columns and adds them if necessary.
+ This is a simple idempotent migration system to handle schema updates.
+ """
+ logger.info("Starting database migrations...")
+
+ with engine.connect() as conn:
+ inspector = inspect(engine)
+ if not inspector.has_table("messages"):
+ logger.info("Table 'messages' does not exist, skipping migrations (will be handled by Base.metadata.create_all).")
+ return
+
+ columns = [c["name"] for c in inspector.get_columns("messages")]
+
+ # List of (column_name, column_type) to ensure existence
+ required_columns = [
+ ("audio_path", "TEXT"),
+ ("model_response_time", "INTEGER"),
+ ("token_count", "INTEGER")
+ ]
+
+ for col_name, col_type in required_columns:
+ if col_name not in columns:
+ logger.info(f"Adding column '{col_name}' to 'messages' table...")
+ try:
+ conn.execute(text(f"ALTER TABLE messages ADD COLUMN {col_name} {col_type}"))
+ conn.commit()
+ logger.info(f"Successfully added '{col_name}'.")
+ except Exception as e:
+ logger.error(f"Failed to add column '{col_name}': {e}")
+ else:
+ logger.info(f"Column '{col_name}' already exists in 'messages'.")
+
+ logger.info("Database migrations complete.")
+
+if __name__ == "__main__":
+ run_migrations()
diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py
index f7bee7f..f94df2b 100644
--- a/ai-hub/app/db/models.py
+++ b/ai-hub/app/db/models.py
@@ -136,6 +136,8 @@
# A JSON field to store unstructured metadata about the message, such as tool calls.
# This column has been renamed from 'metadata' to avoid a conflict.
message_metadata = Column(JSON, nullable=True)
+ # Path to the generated audio file for this message, if any.
+ audio_path = Column(String, nullable=True)
# Relationship back to the parent Session.
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt
index 5b0ee90..37a2533 100644
--- a/ai-hub/requirements.txt
+++ b/ai-hub/requirements.txt
@@ -17,4 +17,5 @@
dspy
aioresponses
python-multipart
-PyJWT
\ No newline at end of file
+PyJWT
+tenacity
\ No newline at end of file
diff --git a/remote_deploy.sh b/remote_deploy.sh
index 2784a66..78b2a71 100755
--- a/remote_deploy.sh
+++ b/remote_deploy.sh
@@ -1,12 +1,18 @@
#!/bin/bash
# Description: Automates deployment from the local environment to the production host 192.168.68.113
-HOST="192.168.68.113"
-USER="axieyangb"
-PASS="a6163484a"
+HOST="${REMOTE_HOST:-192.168.68.113}"
+USER="${REMOTE_USER:-axieyangb}"
+PASS="${REMOTE_PASS:-MySecurePassword}"
REMOTE_TMP="/tmp/cortex-hub/"
REMOTE_PROJ="/home/coder/project/cortex-hub"
+if [ "$PASS" = "MySecurePassword" ]; then
+ echo "Error: Please set the REMOTE_PASS environment variable before deploying."
+ echo "Example: REMOTE_PASS='your_password' ./remote_deploy.sh"
+ exit 1
+fi
+
echo "Checking if sshpass is installed..."
if ! command -v sshpass &> /dev/null; then
echo "sshpass could not be found, installing..."
diff --git a/ui/client-app/src/components/ChatArea.js b/ui/client-app/src/components/ChatArea.js
index 1b5b4c6..e5172cc 100644
--- a/ui/client-app/src/components/ChatArea.js
+++ b/ui/client-app/src/components/ChatArea.js
@@ -2,7 +2,7 @@
import ChatWindow from "./ChatWindow";
import './ChatArea.css';
-const ChatArea = ({ chatHistory, onSendMessage, isProcessing }) => {
+const ChatArea = ({ chatHistory, onSendMessage, isProcessing, featureName = "default" }) => {
const [inputValue, setInputValue] = useState("");
const inputRef = useRef(null);
const chatScrollRef = useRef(null);
@@ -33,7 +33,7 @@
{/* Scrollable ChatWindow */}
-
+
{/* Sticky Input */}
@@ -51,11 +51,10 @@
diff --git a/ui/client-app/src/components/ChatWindow.js b/ui/client-app/src/components/ChatWindow.js
index 1f53477..c2f2401 100644
--- a/ui/client-app/src/components/ChatWindow.js
+++ b/ui/client-app/src/components/ChatWindow.js
@@ -4,12 +4,78 @@
import FileListComponent from "./FileList";
import DiffViewer from "./DiffViewer";
import CodeChangePlan from "./CodeChangePlan";
-import { FaRegCopy, FaCopy } from 'react-icons/fa'; // Import the copy icon
+import { FaRegCopy, FaCopy, FaVolumeUp, FaPlay, FaPause, FaDownload, FaSyncAlt } from 'react-icons/fa'; // Import the icons
// Individual message component
-const ChatMessage = ({ message }) => {
+const ChatMessage = ({ message, index, onSynthesize, featureName = "default", activePlayingId, onPlayStateChange }) => {
const [selectedFile, setSelectedFile] = useState(null);
const [isReasoningExpanded, setIsReasoningExpanded] = useState(false);
+ const [audioUrl, setAudioUrl] = useState(null);
+ const [isPlaying, setIsPlaying] = useState(false);
+ const audioRef = useRef(null);
+ const isVoiceChat = featureName === "voice_chat";
+
+ // Unique ID for this message's audio
+ const currentMsgId = message.id || `msg-${index}`;
+
+ useEffect(() => {
+ if (message.audioBlob) {
+ const url = URL.createObjectURL(message.audioBlob);
+ setAudioUrl(url);
+ return () => URL.revokeObjectURL(url);
+ }
+ }, [message.audioBlob]);
+
+ // Handle exclusive playback: stop if someone else starts playing
+ useEffect(() => {
+ if (activePlayingId && activePlayingId !== currentMsgId && isPlaying) {
+ if (audioRef.current) {
+ audioRef.current.pause();
+ setIsPlaying(false);
+ }
+ }
+ }, [activePlayingId, currentMsgId, isPlaying]);
+
+ // Stop audio on unmount
+ useEffect(() => {
+ return () => {
+ if (audioRef.current) {
+ audioRef.current.pause();
+ audioRef.current.src = ""; // Clear source to ensure it stops immediately
+ }
+ };
+ }, []);
+
+ const handlePlayPause = () => {
+ if (audioRef.current) {
+ if (isPlaying) {
+ audioRef.current.pause();
+ onPlayStateChange(null);
+ } else {
+ audioRef.current.play();
+ onPlayStateChange(currentMsgId);
+ }
+ setIsPlaying(!isPlaying);
+ }
+ };
+
+ const handleDownload = () => {
+ if (audioUrl) {
+ const a = document.createElement("a");
+ a.href = audioUrl;
+ a.download = `voice_chat_${Date.now()}.wav`;
+ a.click();
+ }
+ };
+
+ const handleReplay = () => {
+ if (audioRef.current) {
+ audioRef.current.currentTime = 0;
+ audioRef.current.play();
+ setIsPlaying(true);
+ onPlayStateChange(currentMsgId);
+ }
+ };
const toggleReasoning = () => {
setIsReasoningExpanded(!isReasoningExpanded);
@@ -33,8 +99,15 @@
}
}
};
- const assistantMessageClasses = `p-4 rounded-lg shadow-md max-w-3xl bg-gray-200 dark:bg-gray-700 text-gray-900 dark:text-gray-100 assistant-message mr-auto`;
- const userMessageClasses = `max-w-md p-4 rounded-lg shadow-md bg-indigo-500 text-white ml-auto`;
+ const assistantMessageClasses = `p-3 pb-2 rounded-2xl shadow-sm max-w-[85%] bg-gray-200 dark:bg-gray-800 text-gray-900 dark:text-gray-100 assistant-message mr-auto border border-gray-300 dark:border-gray-700/50`;
+ const userMessageClasses = `max-w-[80%] p-3 pb-2 rounded-2xl shadow-sm bg-indigo-600 text-white ml-auto`;
+
+ const formatTime = (iso) => {
+ if (!iso) return '';
+ try {
+ return new Date(iso).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
+ } catch { return ''; }
+ };
return (
@@ -55,11 +128,73 @@
)}
{message.text}
- {message.isPureAnswer && (
-
- {/* Horizontal line */}
-
-
+ {(message.isPureAnswer || !message.isUser) && (
+
+ {/* Horizontal line - only for voice chat to separate from voice controls */}
+ {isVoiceChat && (
+
+ )}
+
+
+ {/* Audio Controls - strictly limited to voice chat feature */}
+ {isVoiceChat && (message.audioBlob ? (
+
+
+ ) : (!message.isUser && (message.isFromHistory || (message.audioProgress && message.audioProgress > 0)) && (
+
+ {message.isFromHistory && !message.audioProgress ? (
+
+ ) : (message.audioProgress && (
+ <>
+
+ Audio generating {message.audioProgress || 0}%...
+ >
+ ))}
+
+ )))}
+
+ {/* Timestamp */}
+
+ {formatTime(message.timestamp)}
+
+
{/* Copy Icon - positioned above the bottom line */}
@@ -88,8 +222,16 @@
};
// Main ChatWindow component with dynamic height calculation
-const ChatWindow = ({ chatHistory, maxHeight }) => {
+const ChatWindow = ({ chatHistory, maxHeight, onSynthesize, featureName, isStreamingPlaying, onAudioPlay }) => {
const containerRef = useRef(null);
+ const [activePlayingId, setActivePlayingId] = useState(null);
+
+ useEffect(() => {
+ // If a new stream starts playing, stop any ongoing historical audio
+ if (isStreamingPlaying) {
+ setActivePlayingId(null);
+ }
+ }, [isStreamingPlaying]);
useEffect(() => {
if (containerRef.current) {
@@ -108,7 +250,19 @@
key={index}
className={`flex ${message.isUser ? "justify-end" : "justify-start"} w-full`}
>
-
+
{
+ setActivePlayingId(id);
+ if (id && onAudioPlay) {
+ onAudioPlay(); // Notify parent to stop streaming (to prevent overlap)
+ }
+ }}
+ />
))}
diff --git a/ui/client-app/src/components/SessionSidebar.js b/ui/client-app/src/components/SessionSidebar.js
index 911210c..93f419a 100644
--- a/ui/client-app/src/components/SessionSidebar.js
+++ b/ui/client-app/src/components/SessionSidebar.js
@@ -7,7 +7,7 @@
} from '../services/apiService';
import './SessionSidebar.css';
-const SessionSidebar = ({ featureName, currentSessionId, onSwitchSession, onNewSession }) => {
+const SessionSidebar = ({ featureName, currentSessionId, onSwitchSession, onNewSession, refreshTick }) => {
const [isOpen, setIsOpen] = useState(false);
const [sessions, setSessions] = useState([]);
const [tokenHoverData, setTokenHoverData] = useState({});
@@ -15,7 +15,7 @@
useEffect(() => {
if (isOpen) fetchSessions();
- }, [isOpen, featureName, currentSessionId]);
+ }, [isOpen, featureName, currentSessionId, refreshTick]);
const fetchSessions = async () => {
setIsLoading(true);
@@ -100,16 +100,19 @@
sessions.map(s => {
const isActive = Number(currentSessionId) === s.id;
const td = tokenHoverData[s.id];
- const tooltip = td
- ? `Context: ${td.token_count.toLocaleString()} / ${td.token_limit.toLocaleString()} tokens (${td.percentage}%)`
- : 'Hover to load token usage';
-
// Derive a display title: prefer session.title, fall back gracefully
const displayTitle = s.title &&
s.title !== 'New Chat Session'
? s.title
: `Session #${s.id}`;
+ const configInfo = s.provider_name ? `AI Model: ${s.provider_name}` : 'Default AI Configuration';
+ const usageInfo = td
+ ? `Context: ${td.token_count.toLocaleString()} / ${td.token_limit.toLocaleString()} tokens (${td.percentage}%)`
+ : 'Hover to load token usage stats';
+
+ const tooltip = `${displayTitle}\n---\n${configInfo}\n${usageInfo}`;
+
return (
{
- const [chatHistory, setChatHistory] = useState([
- {
- text: "Hello! I'm an AI assistant. How can I help you today?",
- isUser: false,
- },
- ]);
+ const [chatHistory, setChatHistory] = useState([]);
const [status, setStatus] = useState("Click the microphone to start recording.");
const [isBusy, setIsBusy] = useState(false);
const [isRecording, setIsRecording] = useState(false);
@@ -62,6 +60,17 @@
const lastRequestTimeRef = useRef(0);
const streamRef = useRef(null);
+ const [isStreamingPlaying, setIsStreamingPlaying] = useState(false);
+
+ /**
+ * Stops any currently playing streaming audio.
+ */
+ const stopStreamingPlayback = useCallback(() => {
+ stopAllPlayingAudio(playingSourcesRef, audioContextRef, playbackTimeRef);
+ setIsStreamingPlaying(false);
+ setIsBusy(false);
+ }, []);
+
const fetchTokenUsage = useCallback(async () => {
if (!sessionIdRef.current) return;
try {
@@ -98,10 +107,25 @@
try {
const messagesData = await getSessionMessages(currentSessionId);
if (messagesData && messagesData.messages && messagesData.messages.length > 0) {
- const formattedHistory = messagesData.messages.map((msg) => ({
- isUser: msg.sender === "user",
- text: msg.content,
- }));
+ const formattedHistoryPromises = messagesData.messages.map(async (msg) => {
+ let audioBlob = null;
+ if (msg.has_audio) {
+ try {
+ audioBlob = await fetchMessageAudio(msg.id);
+ } catch (e) {
+ console.warn(`Failed to fetch audio for message ${msg.id}`, e);
+ }
+ }
+ return {
+ id: msg.id,
+ isUser: msg.sender === "user",
+ text: msg.content,
+ timestamp: msg.created_at,
+ isFromHistory: true,
+ audioBlob: audioBlob
+ };
+ });
+ const formattedHistory = await Promise.all(formattedHistoryPromises);
setChatHistory(formattedHistory);
}
} catch (historyErr) {
@@ -175,8 +199,9 @@
return () => {
// Pass the refs to the utility function here
stopAllMediaStreams(vadStreamRef, mediaRecorderRef, scriptProcessorRef, streamRef);
+ stopStreamingPlayback();
};
- }, []);
+ }, [stopStreamingPlayback]);
// New useEffect hook to automatically scroll to the bottom of the chat history
// The fix: `chatContainerRef` is now included in the dependency array.
@@ -186,20 +211,32 @@
}
}, [chatHistory, chatContainerRef]);
- const addMessage = (text, isUser) => {
- setChatHistory((prevHistory) => [...prevHistory, { text, isUser }]);
+ const addMessage = (text, isUser, id = null) => {
+ setChatHistory((prevHistory) => [...prevHistory, {
+ text,
+ isUser,
+ id,
+ timestamp: new Date().toISOString()
+ }]);
};
/**
* Plays a stream of audio chunks using the Web Audio API by fetching them from the API.
* This is the orchestrator that uses the stateless streamSpeech API function.
* @param {string} text - The text to be synthesized by the TTS service.
+ * @param {number} messageId - The ID of the message to associate the audio with.
*/
- const playStreamingAudio = async (text) => {
+ const playStreamingAudio = async (text, messageId = null) => {
setIsBusy(true);
+ setIsStreamingPlaying(true);
setStatus("Streaming audio...");
- // Pass the refs to the utility function
- stopAllPlayingAudio(playingSourcesRef, audioContextRef, playbackTimeRef);
+
+ // Stop any existing playback
+ stopStreamingPlayback();
+ setIsBusy(true); // stopStreamingPlayback sets it to false, we want it true during this process
+
+ // Track chunks to store in history
+ const accumulatedChunks = [];
try {
if (!audioContextRef.current) {
@@ -209,12 +246,27 @@
const audioContext = audioContextRef.current;
- const onChunkReceived = (rawFloat32Data) => {
- // This is the callback that receives processed audio data from apiService.
- // It's responsible for using the Web Audio API to play the sound.
+ const onChunkReceived = (rawFloat32Data, totalChunks, currentChunkIndex) => {
+ // Collect for storage
+ accumulatedChunks.push(new Float32Array(rawFloat32Data));
+
+ // Update UI progress
+ if (totalChunks > 0) {
+ const progress = Math.round((currentChunkIndex / totalChunks) * 100);
+ setChatHistory(prev => {
+ const next = [...prev];
+ for (let i = next.length - 1; i >= 0; i--) {
+ if (!next[i].isUser && !next[i].audioBlob) {
+ next[i].audioProgress = progress;
+ break;
+ }
+ }
+ return next;
+ });
+ }
const float32Resampled = resampleBuffer(
rawFloat32Data,
- 24000, // The model's sample rate is hardcoded to 24000
+ 24000,
audioContext.sampleRate
);
const audioBuffer = audioContext.createBuffer(
@@ -224,6 +276,16 @@
);
audioBuffer.copyToChannel(float32Resampled, 0);
+ // Apply a very short fade-in and fade-out (2ms) to eliminate "clicks" at segment boundaries
+ const fadeSamps = Math.floor(audioContext.sampleRate * 0.002);
+ const chanData = audioBuffer.getChannelData(0);
+ if (chanData.length > fadeSamps * 2) {
+ for (let i = 0; i < fadeSamps; i++) {
+ chanData[i] *= (i / fadeSamps);
+ chanData[chanData.length - 1 - i] *= (i / fadeSamps);
+ }
+ }
+
const source = audioContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(audioContext.destination);
@@ -242,16 +304,53 @@
};
};
- const onStreamDone = () => {
- // This callback is triggered when the stream finishes.
+ const onStreamDone = async () => {
console.log("TTS Stream complete.");
+ setIsStreamingPlaying(false);
+ if (accumulatedChunks.length > 0) {
+ // Concatenate all chunks and save the blob
+ const totalLen = accumulatedChunks.reduce((acc, c) => acc + c.length, 0);
+ const result = new Float32Array(totalLen);
+ let offset = 0;
+ for (const c of accumulatedChunks) {
+ result.set(c, offset);
+ offset += c.length;
+ }
+ // resample to standard 44.1k for download/blob stability
+ const finalPcm = resampleBuffer(result, 24000, 44100);
+ const wavBlob = encodeWAV(finalPcm, 44100);
+
+ // Upload to persistent storage if messageId is available
+ if (messageId) {
+ try {
+ console.log(`Uploading audio for message ${messageId}...`);
+ await uploadMessageAudio(messageId, wavBlob);
+ } catch (uploadErr) {
+ console.warn("Failed to upload persistent audio", uploadErr);
+ }
+ }
+
+ // Post-update: find the last AI message and attach this blob
+ setChatHistory(prev => {
+ const next = [...prev];
+ // Find the latest assistant message that matches this text (or just the latest)
+ for (let i = next.length - 1; i >= 0; i--) {
+ if (!next[i].isUser && !next[i].audioBlob) {
+ next[i].audioBlob = wavBlob;
+ if (messageId) next[i].id = messageId;
+ break;
+ }
+ }
+ return next;
+ });
+ }
};
- // Call the stateless API function, passing the UI-related callbacks
await streamSpeech(text, onChunkReceived, onStreamDone, localActivePrefs.tts);
} catch (err) {
console.error("Failed to stream speech:", err);
+ setIsStreamingPlaying(false);
setStatus(`Error: Failed to stream speech. ${err.message}`);
setErrorMessage(`Failed to stream speech: ${err.message}`);
setShowErrorModal(true);
@@ -268,6 +367,80 @@
}
};
+ /**
+ * Specifically for manual replay/synthesis of any message (including history)
+ */
+ const synthesizeMessageAudio = async (index, text) => {
+ if (isBusy) return;
+ const accumulatedChunks = [];
+
+ if (chatHistory[index]?.audioBlob) return;
+
+ setIsBusy(true);
+ try {
+ if (!audioContextRef.current) {
+ audioContextRef.current = new (window.AudioContext || window.webkitAudioContext)();
+ playbackTimeRef.current = audioContextRef.current.currentTime;
+ }
+ const audioContext = audioContextRef.current;
+
+ const onData = (rawFloat32Data, total, current) => {
+ accumulatedChunks.push(new Float32Array(rawFloat32Data));
+ if (total > 0) {
+ const progress = Math.round((current / total) * 100);
+ setChatHistory(prev => {
+ const next = [...prev];
+ if (next[index]) next[index].audioProgress = progress;
+ return next;
+ });
+ }
+
+ const float32Resampled = resampleBuffer(rawFloat32Data, 24000, audioContext.sampleRate);
+ const audioBuffer = audioContext.createBuffer(1, float32Resampled.length, audioContext.sampleRate);
+ audioBuffer.copyToChannel(float32Resampled, 0);
+ const source = audioContext.createBufferSource();
+ source.buffer = audioBuffer;
+ source.connect(audioContext.destination);
+ const startTime = Math.max(playbackTimeRef.current, audioContext.currentTime);
+ source.start(startTime);
+ playbackTimeRef.current = startTime + audioBuffer.duration;
+ playingSourcesRef.current.push(source);
+ };
+
+ const onDone = async () => {
+ if (accumulatedChunks.length > 0) {
+ const totalLen = accumulatedChunks.reduce((acc, c) => acc + c.length, 0);
+ const result = new Float32Array(totalLen);
+ let offset = 0;
+ for (const c of accumulatedChunks) {
+ result.set(c, offset);
+ offset += c.length;
+ }
+ const finalPcm = resampleBuffer(result, 24000, 44100);
+ const wavBlob = encodeWAV(finalPcm, 44100);
+
+ const messageId = chatHistory[index]?.id;
+ if (messageId) {
+ try {
+ await uploadMessageAudio(messageId, wavBlob);
+ } catch (e) { console.warn("Upload failed during manual synthesis", e); }
+ }
+
+ setChatHistory(prev => {
+ const next = [...prev];
+ if (next[index]) next[index].audioBlob = wavBlob;
+ return next;
+ });
+ }
+ };
+
+ await streamSpeech(text, onData, onDone, localActivePrefs.tts);
+ } catch (err) {
+ console.error("Manual synthesis failed", err);
+ } finally {
+ setIsBusy(false);
+ }
+ };
const processConversation = async (audioBlob) => {
console.log("Processing conversation...");
@@ -291,11 +464,11 @@
addMessage(userText, true);
setStatus("AI is thinking...");
- const aiText = await chatWithAI(sessionId, userText, localActivePrefs.llm || "gemini");
- addMessage(aiText, false);
+ const aiResponse = await chatWithAI(sessionId, userText, localActivePrefs.llm || "gemini");
+ addMessage(aiResponse.answer, false, aiResponse.message_id);
fetchTokenUsage();
- await playStreamingAudio(aiText);
+ await playStreamingAudio(aiResponse.answer, aiResponse.message_id);
} catch (error) {
console.error("Conversation processing failed:", error);
setStatus(`Error: ${error.message}`);
@@ -391,7 +564,7 @@
const timeSinceLastRequest = Date.now() - lastRequestTimeRef.current;
const isCooldownPassed = timeSinceLastRequest > AUTO_MODE_COOLDOWN_MS;
- if (isVoiceDetected) {
+ if (isVoiceDetected && !isBusy) {
if (silenceTimeoutRef.current) {
clearTimeout(silenceTimeoutRef.current);
silenceTimeoutRef.current = null;
@@ -484,12 +657,7 @@
};
const handleNewSession = async () => {
- setChatHistory([
- {
- text: "Hello! I'm an AI assistant. How can I help you today?",
- isUser: false,
- },
- ]);
+ setChatHistory([]);
localStorage.removeItem("sessionId_voice_chat");
setIsBusy(true);
@@ -531,10 +699,25 @@
const messagesData = await getSessionMessages(targetSessionId);
if (messagesData && messagesData.messages) {
- const mappedHistory = messagesData.messages.map(msg => ({
- text: msg.content,
- isUser: msg.sender === 'user'
- }));
+ const mappedHistoryPromises = messagesData.messages.map(async (msg) => {
+ let audioBlob = null;
+ if (msg.has_audio) {
+ try {
+ audioBlob = await fetchMessageAudio(msg.id);
+ } catch (e) {
+ console.warn(`Failed to fetch audio for message ${msg.id} during switch`, e);
+ }
+ }
+ return {
+ id: msg.id,
+ text: msg.content,
+ isUser: msg.sender === 'user',
+ timestamp: msg.created_at,
+ isFromHistory: true,
+ audioBlob: audioBlob
+ };
+ });
+ const mappedHistory = await Promise.all(mappedHistoryPromises);
setChatHistory(mappedHistory);
}
fetchTokenUsage();
@@ -567,6 +750,9 @@
handleSwitchSession,
setShowErrorModal,
setErrorMessage,
+ synthesizeMessageAudio,
+ isStreamingPlaying,
+ stopStreamingPlayback
};
};
diff --git a/ui/client-app/src/pages/CodingAssistantPage.js b/ui/client-app/src/pages/CodingAssistantPage.js
index ac1d4fe..241920e 100644
--- a/ui/client-app/src/pages/CodingAssistantPage.js
+++ b/ui/client-app/src/pages/CodingAssistantPage.js
@@ -8,6 +8,7 @@
// A custom hook to manage WebSocket connection and state
import useCodeAssistant from "../hooks/useCodeAssistant";
+import { updateSession } from "../services/apiService";
const CodeAssistantPage = () => {
// Reference for the main container to manage scrolling
@@ -40,6 +41,19 @@
const [isPanelExpanded, setIsPanelExpanded] = useState(false);
const [showConfigModal, setShowConfigModal] = useState(false);
+ const [sidebarRefreshTick, setSidebarRefreshTick] = useState(0);
+
+ const handleSaveQuickConfig = async () => {
+ try {
+ if (sessionId && localActiveLLM) {
+ await updateSession(sessionId, { provider_name: localActiveLLM });
+ setSidebarRefreshTick(t => t + 1);
+ }
+ setShowConfigModal(false);
+ } catch (e) {
+ console.error("Failed to update session configs:", e);
+ }
+ };
// Scroll to the bottom of the page when new content is added
useEffect(() => {
@@ -55,6 +69,7 @@
currentSessionId={sessionId}
onSwitchSession={handleSwitchSession}
onNewSession={() => handleSendChat("/new")}
+ refreshTick={sidebarRefreshTick}
/>
{/* Main content area */}
@@ -127,7 +142,7 @@
{/* Note: ChatArea component needs to be implemented with a
@@ -215,7 +230,7 @@
Cancel