diff --git a/.gitignore b/.gitignore index 816e9e6..e70abdd 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,6 @@ @eaDir/ **/.DS_Store ai-hub/app/config.yaml -**/config.yaml \ No newline at end of file +**/config.yaml +data/audio/* +data/* \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py index 732bd2f..0803877 100644 --- a/ai-hub/app/api/routes/sessions.py +++ b/ai-hub/app/api/routes/sessions.py @@ -1,10 +1,13 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Response +from fastapi.responses import FileResponse from sqlalchemy.orm import Session from app.api.dependencies import ServiceContainer, get_db from app.api import schemas -from typing import AsyncGenerator, List +from typing import AsyncGenerator, List, Optional from app.db import models from app.core.pipelines.validator import Validator +import os +import shutil def create_sessions_router(services: ServiceContainer) -> APIRouter: router = APIRouter(prefix="/sessions", tags=["Sessions"]) @@ -34,14 +37,14 @@ db: Session = Depends(get_db) ): try: - response_text, provider_used = await services.rag_service.chat_with_rag( + response_text, provider_used, message_id = await services.rag_service.chat_with_rag( db=db, session_id=session_id, prompt=request.prompt, provider_name=request.provider_name, load_faiss_retriever=request.load_faiss_retriever ) - return schemas.ChatResponse(answer=response_text, provider_used=provider_used) + return schemas.ChatResponse(answer=response_text, provider_used=provider_used, message_id=message_id) except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") @@ -52,7 +55,16 @@ if messages is None: raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + # Enhance messages with 
audio availability + enhanced_messages = [] + for m in messages: + msg_dict = schemas.Message.model_validate(m).model_dump() + if m.audio_path and os.path.exists(m.audio_path): + msg_dict["has_audio"] = True + msg_dict["audio_url"] = f"/sessions/messages/{m.id}/audio" + enhanced_messages.append(msg_dict) + + return schemas.MessageHistoryResponse(session_id=session_id, messages=enhanced_messages) except HTTPException: raise except Exception as e: @@ -112,6 +124,29 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}") + @router.patch("/{session_id}", response_model=schemas.Session, summary="Update a Chat Session") + def update_session(session_id: int, session_update: schemas.SessionUpdate, db: Session = Depends(get_db)): + try: + session = db.query(models.Session).filter( + models.Session.id == session_id, + models.Session.is_archived == False + ).first() + if not session: + raise HTTPException(status_code=404, detail="Session not found.") + + if session_update.title is not None: + session.title = session_update.title + if session_update.provider_name is not None: + session.provider_name = session_update.provider_name + + db.commit() + db.refresh(session) + return session + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update session: {e}") + @router.delete("/{session_id}", summary="Delete a Chat Session") def delete_session(session_id: int, db: Session = Depends(get_db)): try: @@ -139,5 +174,46 @@ return {"message": "All sessions deleted successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}") + + @router.post("/messages/{message_id}/audio", summary="Upload audio for a specific message") + async def upload_message_audio(message_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)): + try: + message = db.query(models.Message).filter(models.Message.id == message_id).first() 
+ if not message: + raise HTTPException(status_code=404, detail="Message not found.") + # Create data directory if not exists + audio_dir = "/app/data/audio" + os.makedirs(audio_dir, exist_ok=True) + + # Save file + file_path = f"{audio_dir}/message_{message_id}.wav" + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # Update database + message.audio_path = file_path + db.commit() + + return {"message": "Audio uploaded successfully.", "audio_path": file_path} + except Exception as e: + print(f"Error uploading audio: {e}") + raise HTTPException(status_code=500, detail=f"Failed to upload audio: {e}") + + @router.get("/messages/{message_id}/audio", summary="Get audio for a specific message") + async def get_message_audio(message_id: int, db: Session = Depends(get_db)): + try: + message = db.query(models.Message).filter(models.Message.id == message_id).first() + if not message or not message.audio_path: + raise HTTPException(status_code=404, detail="Audio not found for this message.") + + if not os.path.exists(message.audio_path): + raise HTTPException(status_code=404, detail="Audio file missing on disk.") + + return FileResponse(message.audio_path, media_type="audio/wav") + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get audio: {e}") + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py index 94334a8..eda6404 100644 --- a/ai-hub/app/api/routes/tts.py +++ b/ai-hub/app/api/routes/tts.py @@ -39,24 +39,39 @@ from app.config import settings active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER active_prefs = prefs.get("providers", {}).get(active_provider, {}) - if active_prefs: - from app.core.providers.factory import get_tts_provider - kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model", "voice"]} - provider_override = get_tts_provider( - 
provider_name=active_provider, - api_key=active_prefs.get("api_key"), - model_name=active_prefs.get("model", ""), - voice_name=active_prefs.get("voice", ""), - **kwargs - ) + from app.core.providers.factory import get_tts_provider + kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model", "voice"]} + provider_override = get_tts_provider( + provider_name=active_provider, + api_key=active_prefs.get("api_key"), + model_name=active_prefs.get("model", ""), + voice_name=active_prefs.get("voice", ""), + **kwargs + ) if stream: # Pre-flight: generate first chunk before streaming to catch errors cleanly # If we send StreamingResponse and then fail, the browser sees a network error # instead of a meaningful error message. - chunks = await services.tts_service._split_text_into_chunks(request.text) + # Split into first chunk for latency, then send entire rest for smoothness + all_text = request.text + separators = ['.', '?', '!', '\n', '。', '?', '!', ','] + + # Find first separator within a reasonable limit + split_idx = -1 + chars_to_scan = min(len(all_text) - 1, 400) + for i in range(chars_to_scan, 50, -1): + if all_text[i] in separators: + split_idx = i + 1 + break + + if split_idx != -1 and split_idx < len(all_text): + chunks = [all_text[:split_idx], all_text[split_idx:]] + else: + chunks = [all_text] + provider = provider_override or services.tts_service.default_tts_provider - if not chunks: + if not chunks or not chunks[0].strip(): raise HTTPException(status_code=400, detail="No text to synthesize.") # Test first chunk synchronously to validate the provider works @@ -69,23 +84,49 @@ yield _create_wav_file(first_pcm) else: yield first_pcm - # Then stream the remaining chunks - for chunk in chunks[1:]: - try: - pcm = await provider.generate_speech(chunk) - if pcm: - if as_wav: - from app.core.services.tts import _create_wav_file - yield _create_wav_file(pcm) - else: - yield pcm - except Exception as e: - import logging - 
logging.getLogger(__name__).error(f"TTS chunk error: {e}") - break # Stop cleanly rather than crashing the stream + # Then stream the remaining chunks using parallel fetching but sequential yielding + import asyncio + semaphore = asyncio.Semaphore(3) # Limit concurrent external requests + + async def fetch_chunk(text_chunk): + retries = 3 + delay = 1.0 + for attempt in range(retries): + try: + async with semaphore: + return await provider.generate_speech(text_chunk) + except Exception as e: + error_str = str(e) + if "No audio in response" in error_str or "finishReason" in error_str: + import logging + logging.getLogger(__name__).error(f"TTS chunk blocked by provider formatting/safety: {e}") + return None + + if attempt == retries - 1: + import logging + logging.getLogger(__name__).error(f"TTS chunk failed after {retries} attempts: {e}") + return None + await asyncio.sleep(delay) + delay *= 2 + + # Start all tasks concurrently + tasks = [asyncio.create_task(fetch_chunk(chunk)) for chunk in chunks[1:]] + + for task in tasks: + pcm = await task + if pcm: + if as_wav: + from app.core.services.tts import _create_wav_file + yield _create_wav_file(pcm) + else: + yield pcm media_type = "audio/wav" if as_wav else "audio/pcm" - return StreamingResponse(full_stream(), media_type=media_type) + return StreamingResponse( + full_stream(), + media_type=media_type, + headers={"X-TTS-Chunk-Count": str(len(chunks))} + ) else: # The non-streaming function only returns WAV, so this part remains the same diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index a4faaa7..47f34f4 100644 --- a/ai-hub/app/api/routes/user.py +++ b/ai-hub/app/api/routes/user.py @@ -248,6 +248,13 @@ # Load system defaults from DB if needed system_prefs = services.user_service.get_system_settings(db) + + system_statuses = system_prefs.get("statuses", {}) + user_statuses = prefs_dict.get("statuses", {}) + + def is_provider_healthy(section: str, provider_id: str) -> bool: + status_key = 
f"{section}_{provider_id}" + return user_statuses.get(status_key) == "success" or system_statuses.get(status_key) == "success" user_providers = llm_prefs.get("providers", {}) if not user_providers: @@ -264,7 +271,7 @@ llm_providers_effective = {} for p, p_p in user_providers.items(): - if p_p: + if p_p and is_provider_healthy("llm", p): llm_providers_effective[p] = { "api_key": mask_key(p_p.get("api_key")), "model": p_p.get("model") @@ -286,7 +293,7 @@ tts_providers_effective = {} for p, p_p in user_tts_providers.items(): - if p_p: + if p_p and is_provider_healthy("tts", p): tts_providers_effective[p] = { "api_key": mask_key(p_p.get("api_key")), "model": p_p.get("model"), @@ -308,7 +315,7 @@ stt_providers_effective = {} for p, p_p in user_stt_providers.items(): - if p_p: + if p_p and is_provider_healthy("stt", p): stt_providers_effective[p] = { "api_key": mask_key(p_p.get("api_key")), "model": p_p.get("model") @@ -470,8 +477,8 @@ from app.config import settings as global_settings # Sync LLM - if prefs.llm and prefs.llm.get("providers"): - global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {})) + if prefs.llm and "providers" in prefs.llm: + global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {})) # Sync TTS if prefs.tts and prefs.tts.get("active_provider"): @@ -699,7 +706,7 @@ voice_name=req.voice or "", **kwargs ) - await provider.generate_speech("Test") + await provider.generate_speech("Hello there. 
We are testing this thing.") return schemas.VerifyProviderResponse(success=True, message="Connection successful!") except Exception as e: logger.error(f"TTS verification failed for {req.provider_name}: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 357ed08..3719e34 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -88,6 +88,7 @@ """Defines the shape of a successful response from the /chat endpoint.""" answer: str provider_used: str + message_id: Optional[int] = None # --- Document Schemas --- class DocumentCreate(BaseModel): @@ -131,6 +132,10 @@ provider_name: str = "deepseek" feature_name: Optional[str] = "default" +class SessionUpdate(BaseModel): + title: Optional[str] = None + provider_name: Optional[str] = None + class Session(BaseModel): """Defines the shape of a session object returned by the API.""" id: int @@ -143,12 +148,17 @@ class Message(BaseModel): """Defines the shape of a single message within a session's history.""" + id: int # The sender can only be one of two roles. sender: Literal["user", "assistant"] # The text content of the message. content: str # The timestamp for when the message was created. created_at: datetime + # URL to the saved audio file + audio_url: Optional[str] = None + # Whether audio exists for this message + has_audio: bool = False # Enables creating this schema from a SQLAlchemy database object. 
model_config = ConfigDict(from_attributes=True) diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 230d127..a6ec166 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -14,6 +14,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables +from app.db.migrate import run_migrations from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer, get_db @@ -39,6 +40,7 @@ print("Application startup...") print_config(settings) create_db_and_tables() + run_migrations() yield print("Application shutdown...") # Access the vector_store from the application state to save it diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index d54d7cf..bc04c20 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -154,7 +154,7 @@ self.DEEPSEEK_MODEL_NAME = self.LLM_PROVIDERS.get("deepseek", {}).get("model") or \ get_from_yaml(["llm_providers", "deepseek_model_name"]) or "deepseek-chat" self.GEMINI_MODEL_NAME = self.LLM_PROVIDERS.get("gemini", {}).get("model") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-1.5-flash-latest" + get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-2.5-flash" # 2. 
Resolve Vector / Embedding self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ diff --git a/ai-hub/app/core/providers/tts/gcloud_tts.py b/ai-hub/app/core/providers/tts/gcloud_tts.py index a87f9fc..0c034b4 100644 --- a/ai-hub/app/core/providers/tts/gcloud_tts.py +++ b/ai-hub/app/core/providers/tts/gcloud_tts.py @@ -78,7 +78,8 @@ "name": voice_to_use }, "audioConfig": { - "audioEncoding": "LINEAR16" + "audioEncoding": "LINEAR16", + "sampleRateHertz": 24000 } } diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py index a59f644..c5353d9 100644 --- a/ai-hub/app/core/providers/tts/gemini.py +++ b/ai-hub/app/core/providers/tts/gemini.py @@ -3,6 +3,8 @@ import httpx import base64 import logging +import asyncio +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, retry_if_exception from app.core.providers.base import TTSProvider from fastapi import HTTPException @@ -11,6 +13,17 @@ logger = logging.getLogger(__name__) +def is_retryable_exception(exception): + """Check if the exception is one we should retry on.""" + if isinstance(exception, httpx.TimeoutException): + return True + if isinstance(exception, httpx.NetworkError): + return True + if isinstance(exception, HTTPException): + # Retry on 429 (Too Many Requests) or 5xx (Server Errors) + return exception.status_code == 429 or 500 <= exception.status_code < 600 + return False + class GeminiTTSProvider(TTSProvider): """TTS provider using Gemini's audio responseModalities via Google AI Studio.""" @@ -29,7 +42,7 @@ # Strip any provider prefix (e.g. 
"vertex_ai/model" or "gemini/model") → keep only the model id model_id = raw_model.split("/")[-1] # Normalise short names: "gemini-2-flash-tts" → "gemini-2.5-flash-preview-tts" - if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts"): + if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts", "gemini-2.5-flash"): model_id = "gemini-2.5-flash-preview-tts" logger.info(f"Normalised model name to: {model_id}") @@ -57,6 +70,13 @@ logger.debug(f"GeminiTTSProvider: model={self.model_name}, vertex={self.is_vertex}") logger.debug(f" endpoint: {self.api_url[:80]}...") + @retry( + retry=retry_if_exception_type((httpx.TimeoutException, httpx.NetworkError)) | retry_if_exception(is_retryable_exception), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=10), + reraise=True, + before_sleep=lambda retry_state: logger.warning(f"Retrying Gemini TTS request (attempt {retry_state.attempt_number})...") + ) async def generate_speech(self, text: str) -> bytes: logger.debug(f"TTS generate_speech: '{text[:60]}...'") @@ -64,9 +84,6 @@ # The dedicated TTS models require a system instruction to produce only audio json_data = { - "system_instruction": { - "parts": [{"text": "You are a text-to-speech system. Convert the user text to speech audio only. 
Do not generate any text response."}] + }, "contents": [{"role": "user", "parts": [{"text": text}]}], "generationConfig": { "responseModalities": ["AUDIO"], @@ -86,7 +103,7 @@ logger.debug(f"Calling: {self.api_url}") try: - async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client: response = await client.post(self.api_url, headers=headers, json=json_data) logger.debug(f"Response status: {response.status_code}") @@ -99,7 +116,15 @@ msg = err.get("message", body[:200]) except Exception: msg = body[:200] - raise HTTPException(status_code=response.status_code, detail=f"Gemini TTS error: {msg}") + + # Check if we should retry (429 or 5xx) + status_code = response.status_code + if status_code == 429 or 500 <= status_code < 600: + # tenacity will catch this if we configure it to retry on HTTPException + raise HTTPException(status_code=status_code, detail=f"Gemini TTS error: {msg}") + else: + # Non-retryable error + raise HTTPException(status_code=status_code, detail=f"Gemini TTS error: {msg}") resp_data = response.json() audio_fragments = [] @@ -124,11 +149,13 @@ logger.debug(f"TTS returned {len(result)} PCM bytes") return result - except HTTPException: + except (httpx.TimeoutException, httpx.NetworkError) as e: + logger.error(f"Gemini TTS request failed ({type(e).__name__}) after 60s") + # tenacity will catch this and retry raise - except httpx.TimeoutException: - logger.error("Gemini TTS request timed out after 30s") - raise HTTPException(status_code=504, detail="Gemini TTS request timed out.") + except HTTPException: + # tenacity might catch this if it's 429 or 5xx + raise except Exception as e: logger.error(f"Unexpected TTS error: {type(e).__name__}: {e}") raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 9d88bd1..3eb1987 100644 --- 
a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -55,6 +55,7 @@ # Fetch user preferences for overrides api_key_override = None model_name_override = "" + llm_prefs = {} user = session.user if user and user.preferences: llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {}) @@ -93,7 +94,7 @@ db.commit() db.refresh(assistant_message) - return answer_text, provider_name + return answer_text, provider_name, assistant_message.id def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: """ diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index 015b1c4..d9b5929 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -31,7 +31,7 @@ audio generation, splitting text into manageable chunks. """ - MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 600)) + MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 500)) def __init__(self, tts_provider: TTSProvider): self.default_tts_provider = tts_provider diff --git a/ai-hub/app/db/migrate.py b/ai-hub/app/db/migrate.py new file mode 100644 index 0000000..625b5c3 --- /dev/null +++ b/ai-hub/app/db/migrate.py @@ -0,0 +1,45 @@ +import logging +from app.db.session import engine +from sqlalchemy import text, inspect + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def run_migrations(): + """ + Checks for missing columns and adds them if necessary. + This is a simple idempotent migration system to handle schema updates. 
+ """ + logger.info("Starting database migrations...") + + with engine.connect() as conn: + inspector = inspect(engine) + if not inspector.has_table("messages"): + logger.info("Table 'messages' does not exist, skipping migrations (will be handled by Base.metadata.create_all).") + return + + columns = [c["name"] for c in inspector.get_columns("messages")] + + # List of (column_name, column_type) to ensure existence + required_columns = [ + ("audio_path", "TEXT"), + ("model_response_time", "INTEGER"), + ("token_count", "INTEGER") + ] + + for col_name, col_type in required_columns: + if col_name not in columns: + logger.info(f"Adding column '{col_name}' to 'messages' table...") + try: + conn.execute(text(f"ALTER TABLE messages ADD COLUMN {col_name} {col_type}")) + conn.commit() + logger.info(f"Successfully added '{col_name}'.") + except Exception as e: + logger.error(f"Failed to add column '{col_name}': {e}") + else: + logger.info(f"Column '{col_name}' already exists in 'messages'.") + + logger.info("Database migrations complete.") + +if __name__ == "__main__": + run_migrations() diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py index f7bee7f..f94df2b 100644 --- a/ai-hub/app/db/models.py +++ b/ai-hub/app/db/models.py @@ -136,6 +136,8 @@ # A JSON field to store unstructured metadata about the message, such as tool calls. # This column has been renamed from 'metadata' to avoid a conflict. message_metadata = Column(JSON, nullable=True) + # Path to the generated audio file for this message, if any. + audio_path = Column(String, nullable=True) # Relationship back to the parent Session. 
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt index 5b0ee90..37a2533 100644 --- a/ai-hub/requirements.txt +++ b/ai-hub/requirements.txt @@ -17,4 +17,5 @@ dspy aioresponses python-multipart -PyJWT \ No newline at end of file +PyJWT +tenacity \ No newline at end of file diff --git a/remote_deploy.sh b/remote_deploy.sh index 2784a66..78b2a71 100755 --- a/remote_deploy.sh +++ b/remote_deploy.sh @@ -1,12 +1,18 @@ #!/bin/bash # Description: Automates deployment from the local environment to the production host 192.168.68.113 -HOST="192.168.68.113" -USER="axieyangb" -PASS="a6163484a" +HOST="${REMOTE_HOST:-192.168.68.113}" +USER="${REMOTE_USER:-axieyangb}" +PASS="${REMOTE_PASS:-MySecurePassword}" REMOTE_TMP="/tmp/cortex-hub/" REMOTE_PROJ="/home/coder/project/cortex-hub" +if [ "$PASS" = "MySecurePassword" ]; then + echo "Error: Please set the REMOTE_PASS environment variable before deploying." + echo "Example: REMOTE_PASS='your_password' ./remote_deploy.sh" + exit 1 +fi + echo "Checking if sshpass is installed..." if ! command -v sshpass &> /dev/null; then echo "sshpass could not be found, installing..." diff --git a/ui/client-app/src/components/ChatArea.js b/ui/client-app/src/components/ChatArea.js index 1b5b4c6..e5172cc 100644 --- a/ui/client-app/src/components/ChatArea.js +++ b/ui/client-app/src/components/ChatArea.js @@ -2,7 +2,7 @@ import ChatWindow from "./ChatWindow"; import './ChatArea.css'; -const ChatArea = ({ chatHistory, onSendMessage, isProcessing }) => { +const ChatArea = ({ chatHistory, onSendMessage, isProcessing, featureName = "default" }) => { const [inputValue, setInputValue] = useState(""); const inputRef = useRef(null); const chatScrollRef = useRef(null); @@ -33,7 +33,7 @@
{/* Scrollable ChatWindow */}
- +
{/* Sticky Input */} @@ -51,11 +51,10 @@ diff --git a/ui/client-app/src/components/ChatWindow.js b/ui/client-app/src/components/ChatWindow.js index 1f53477..c2f2401 100644 --- a/ui/client-app/src/components/ChatWindow.js +++ b/ui/client-app/src/components/ChatWindow.js @@ -4,12 +4,78 @@ import FileListComponent from "./FileList"; import DiffViewer from "./DiffViewer"; import CodeChangePlan from "./CodeChangePlan"; -import { FaRegCopy, FaCopy } from 'react-icons/fa'; // Import the copy icon +import { FaRegCopy, FaCopy, FaVolumeUp, FaPlay, FaPause, FaDownload, FaSyncAlt } from 'react-icons/fa'; // Import the icons // Individual message component -const ChatMessage = ({ message }) => { +const ChatMessage = ({ message, index, onSynthesize, featureName = "default", activePlayingId, onPlayStateChange }) => { const [selectedFile, setSelectedFile] = useState(null); const [isReasoningExpanded, setIsReasoningExpanded] = useState(false); + const [audioUrl, setAudioUrl] = useState(null); + const [isPlaying, setIsPlaying] = useState(false); + const audioRef = useRef(null); + const isVoiceChat = featureName === "voice_chat"; + + // Unique ID for this message's audio + const currentMsgId = message.id || `msg-${index}`; + + useEffect(() => { + if (message.audioBlob) { + const url = URL.createObjectURL(message.audioBlob); + setAudioUrl(url); + return () => URL.revokeObjectURL(url); + } + }, [message.audioBlob]); + + // Handle exclusive playback: stop if someone else starts playing + useEffect(() => { + if (activePlayingId && activePlayingId !== currentMsgId && isPlaying) { + if (audioRef.current) { + audioRef.current.pause(); + setIsPlaying(false); + } + } + }, [activePlayingId, currentMsgId, isPlaying]); + + // Stop audio on unmount + useEffect(() => { + return () => { + if (audioRef.current) { + audioRef.current.pause(); + audioRef.current.src = ""; // Clear source to ensure it stops immediately + } + }; + }, []); + + const handlePlayPause = () => { + if (audioRef.current) { + if 
(isPlaying) { + audioRef.current.pause(); + onPlayStateChange(null); + } else { + audioRef.current.play(); + onPlayStateChange(currentMsgId); + } + setIsPlaying(!isPlaying); + } + }; + + const handleDownload = () => { + if (audioUrl) { + const a = document.createElement("a"); + a.href = audioUrl; + a.download = `voice_chat_${Date.now()}.wav`; + a.click(); + } + }; + + const handleReplay = () => { + if (audioRef.current) { + audioRef.current.currentTime = 0; + audioRef.current.play(); + setIsPlaying(true); + onPlayStateChange(currentMsgId); + } + }; const toggleReasoning = () => { setIsReasoningExpanded(!isReasoningExpanded); @@ -33,8 +99,15 @@ } } }; - const assistantMessageClasses = `p-4 rounded-lg shadow-md max-w-3xl bg-gray-200 dark:bg-gray-700 text-gray-900 dark:text-gray-100 assistant-message mr-auto`; - const userMessageClasses = `max-w-md p-4 rounded-lg shadow-md bg-indigo-500 text-white ml-auto`; + const assistantMessageClasses = `p-3 pb-2 rounded-2xl shadow-sm max-w-[85%] bg-gray-200 dark:bg-gray-800 text-gray-900 dark:text-gray-100 assistant-message mr-auto border border-gray-300 dark:border-gray-700/50`; + const userMessageClasses = `max-w-[80%] p-3 pb-2 rounded-2xl shadow-sm bg-indigo-600 text-white ml-auto`; + + const formatTime = (iso) => { + if (!iso) return ''; + try { + return new Date(iso).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }); + } catch { return ''; } + }; return (
@@ -55,11 +128,73 @@
)} {message.text} - {message.isPureAnswer && ( -
- {/* Horizontal line */} -
-
+ {(message.isPureAnswer || !message.isUser) && ( +
+ {/* Horizontal line - only for voice chat to separate from voice controls */} + {isVoiceChat && ( +
+ )} +
+ + {/* Audio Controls - strictly limited to voice chat feature */} + {isVoiceChat && (message.audioBlob ? ( +
+
+ ) : (!message.isUser && (message.isFromHistory || (message.audioProgress && message.audioProgress > 0)) && ( +
+ {message.isFromHistory && !message.audioProgress ? ( + + ) : (message.audioProgress && ( + <> +
+ Audio generating {message.audioProgress || 0}%... + + ))} +
+ )))} + + {/* Timestamp */} + + {formatTime(message.timestamp)} + + {/* Copy Icon - positioned above the bottom line */}
@@ -88,8 +222,16 @@ }; // Main ChatWindow component with dynamic height calculation -const ChatWindow = ({ chatHistory, maxHeight }) => { +const ChatWindow = ({ chatHistory, maxHeight, onSynthesize, featureName, isStreamingPlaying, onAudioPlay }) => { const containerRef = useRef(null); + const [activePlayingId, setActivePlayingId] = useState(null); + + useEffect(() => { + // If a new stream starts playing, stop any ongoing historical audio + if (isStreamingPlaying) { + setActivePlayingId(null); + } + }, [isStreamingPlaying]); useEffect(() => { if (containerRef.current) { @@ -108,7 +250,19 @@ key={index} className={`flex ${message.isUser ? "justify-end" : "justify-start"} w-full`} > - + { + setActivePlayingId(id); + if (id && onAudioPlay) { + onAudioPlay(); // Notify parent to stop streaming (to prevent overlap) + } + }} + />
))}
diff --git a/ui/client-app/src/components/SessionSidebar.js b/ui/client-app/src/components/SessionSidebar.js index 911210c..93f419a 100644 --- a/ui/client-app/src/components/SessionSidebar.js +++ b/ui/client-app/src/components/SessionSidebar.js @@ -7,7 +7,7 @@ } from '../services/apiService'; import './SessionSidebar.css'; -const SessionSidebar = ({ featureName, currentSessionId, onSwitchSession, onNewSession }) => { +const SessionSidebar = ({ featureName, currentSessionId, onSwitchSession, onNewSession, refreshTick }) => { const [isOpen, setIsOpen] = useState(false); const [sessions, setSessions] = useState([]); const [tokenHoverData, setTokenHoverData] = useState({}); @@ -15,7 +15,7 @@ useEffect(() => { if (isOpen) fetchSessions(); - }, [isOpen, featureName, currentSessionId]); + }, [isOpen, featureName, currentSessionId, refreshTick]); const fetchSessions = async () => { setIsLoading(true); @@ -100,16 +100,19 @@ sessions.map(s => { const isActive = Number(currentSessionId) === s.id; const td = tokenHoverData[s.id]; - const tooltip = td - ? `Context: ${td.token_count.toLocaleString()} / ${td.token_limit.toLocaleString()} tokens (${td.percentage}%)` - : 'Hover to load token usage'; - // Derive a display title: prefer session.title, fall back gracefully const displayTitle = s.title && s.title !== 'New Chat Session' ? s.title : `Session #${s.id}`; + const configInfo = s.provider_name ? `AI Model: ${s.provider_name}` : 'Default AI Configuration'; + const usageInfo = td + ? `Context: ${td.token_count.toLocaleString()} / ${td.token_limit.toLocaleString()} tokens (${td.percentage}%)` + : 'Hover to load token usage stats'; + + const tooltip = `${displayTitle}\n---\n${configInfo}\n${usageInfo}`; + return (
{ - const [chatHistory, setChatHistory] = useState([ - { - text: "Hello! I'm an AI assistant. How can I help you today?", - isUser: false, - }, - ]); + const [chatHistory, setChatHistory] = useState([]); const [status, setStatus] = useState("Click the microphone to start recording."); const [isBusy, setIsBusy] = useState(false); const [isRecording, setIsRecording] = useState(false); @@ -62,6 +60,17 @@ const lastRequestTimeRef = useRef(0); const streamRef = useRef(null); + const [isStreamingPlaying, setIsStreamingPlaying] = useState(false); + + /** + * Stops any currently playing streaming audio. + */ + const stopStreamingPlayback = useCallback(() => { + stopAllPlayingAudio(playingSourcesRef, audioContextRef, playbackTimeRef); + setIsStreamingPlaying(false); + setIsBusy(false); + }, []); + const fetchTokenUsage = useCallback(async () => { if (!sessionIdRef.current) return; try { @@ -98,10 +107,25 @@ try { const messagesData = await getSessionMessages(currentSessionId); if (messagesData && messagesData.messages && messagesData.messages.length > 0) { - const formattedHistory = messagesData.messages.map((msg) => ({ - isUser: msg.sender === "user", - text: msg.content, - })); + const formattedHistoryPromises = messagesData.messages.map(async (msg) => { + let audioBlob = null; + if (msg.has_audio) { + try { + audioBlob = await fetchMessageAudio(msg.id); + } catch (e) { + console.warn(`Failed to fetch audio for message ${msg.id}`, e); + } + } + return { + id: msg.id, + isUser: msg.sender === "user", + text: msg.content, + timestamp: msg.created_at, + isFromHistory: true, + audioBlob: audioBlob + }; + }); + const formattedHistory = await Promise.all(formattedHistoryPromises); setChatHistory(formattedHistory); } } catch (historyErr) { @@ -175,8 +199,9 @@ return () => { // Pass the refs to the utility function here stopAllMediaStreams(vadStreamRef, mediaRecorderRef, scriptProcessorRef, streamRef); + stopStreamingPlayback(); }; - }, []); + }, [stopStreamingPlayback]); // New 
useEffect hook to automatically scroll to the bottom of the chat history // The fix: `chatContainerRef` is now included in the dependency array. @@ -186,20 +211,32 @@ } }, [chatHistory, chatContainerRef]); - const addMessage = (text, isUser) => { - setChatHistory((prevHistory) => [...prevHistory, { text, isUser }]); + const addMessage = (text, isUser, id = null) => { + setChatHistory((prevHistory) => [...prevHistory, { + text, + isUser, + id, + timestamp: new Date().toISOString() + }]); }; /** * Plays a stream of audio chunks using the Web Audio API by fetching them from the API. * This is the orchestrator that uses the stateless streamSpeech API function. * @param {string} text - The text to be synthesized by the TTS service. + * @param {number} messageId - The ID of the message to associated the audio with. */ - const playStreamingAudio = async (text) => { + const playStreamingAudio = async (text, messageId = null) => { setIsBusy(true); + setIsStreamingPlaying(true); setStatus("Streaming audio..."); - // Pass the refs to the utility function - stopAllPlayingAudio(playingSourcesRef, audioContextRef, playbackTimeRef); + + // Stop any existing playback + stopStreamingPlayback(); + setIsBusy(true); // stopStreamingPlayback sets it to false, we want it true during this process + + // Track chunks to store in history + const accumulatedChunks = []; try { if (!audioContextRef.current) { @@ -209,12 +246,27 @@ const audioContext = audioContextRef.current; - const onChunkReceived = (rawFloat32Data) => { - // This is the callback that receives processed audio data from apiService. - // It's responsible for using the Web Audio API to play the sound. 
+ const onChunkReceived = (rawFloat32Data, totalChunks, currentChunkIndex) => { + // Collect for storage + accumulatedChunks.push(new Float32Array(rawFloat32Data)); + + // Update UI progress + if (totalChunks > 0) { + const progress = Math.round((currentChunkIndex / totalChunks) * 100); + setChatHistory(prev => { + const next = [...prev]; + for (let i = next.length - 1; i >= 0; i--) { + if (!next[i].isUser && !next[i].audioBlob) { + next[i].audioProgress = progress; + break; + } + } + return next; + }); + } const float32Resampled = resampleBuffer( rawFloat32Data, - 24000, // The model's sample rate is hardcoded to 24000 + 24000, audioContext.sampleRate ); const audioBuffer = audioContext.createBuffer( @@ -224,6 +276,16 @@ ); audioBuffer.copyToChannel(float32Resampled, 0); + // Apply a very short fade-in and fade-out (2ms) to eliminate "clicks" at segment boundaries + const fadeSamps = Math.floor(audioContext.sampleRate * 0.002); + const chanData = audioBuffer.getChannelData(0); + if (chanData.length > fadeSamps * 2) { + for (let i = 0; i < fadeSamps; i++) { + chanData[i] *= (i / fadeSamps); + chanData[chanData.length - 1 - i] *= (i / fadeSamps); + } + } + const source = audioContext.createBufferSource(); source.buffer = audioBuffer; source.connect(audioContext.destination); @@ -242,16 +304,53 @@ }; }; - const onStreamDone = () => { - // This callback is triggered when the stream finishes. 
+ const onStreamDone = async () => { console.log("TTS Stream complete."); + setIsStreamingPlaying(false); + if (accumulatedChunks.length > 0) { + // Concatenate all chunks and save the blob + const totalLen = accumulatedChunks.reduce((acc, c) => acc + c.length, 0); + const result = new Float32Array(totalLen); + let offset = 0; + for (const c of accumulatedChunks) { + result.set(c, offset); + offset += c.length; + } + // resample to standard 44.1k for download/blob stability + const finalPcm = resampleBuffer(result, 24000, 44100); + const wavBlob = encodeWAV(finalPcm, 44100); + + // Upload to persistent storage if messageId is available + if (messageId) { + try { + console.log(`Uploading audio for message ${messageId}...`); + await uploadMessageAudio(messageId, wavBlob); + } catch (uploadErr) { + console.warn("Failed to upload persistent audio", uploadErr); + } + } + + // Post-update: find the last AI message and attach this blob + setChatHistory(prev => { + const next = [...prev]; + // Find the latest assistant message that matches this text (or just the latest) + for (let i = next.length - 1; i >= 0; i--) { + if (!next[i].isUser && !next[i].audioBlob) { + next[i].audioBlob = wavBlob; + if (messageId) next[i].id = messageId; + break; + } + } + return next; + }); + } }; - // Call the stateless API function, passing the UI-related callbacks await streamSpeech(text, onChunkReceived, onStreamDone, localActivePrefs.tts); } catch (err) { console.error("Failed to stream speech:", err); + setIsStreamingPlaying(false); setStatus(`Error: Failed to stream speech. 
${err.message}`); setErrorMessage(`Failed to stream speech: ${err.message}`); setShowErrorModal(true); @@ -268,6 +367,80 @@ } }; + /** + * Specifically for manual replay/synthesis of any message (including history) + */ + const synthesizeMessageAudio = async (index, text) => { + if (isBusy) return; + const accumulatedChunks = []; + + if (chatHistory[index]?.audioBlob) return; + + setIsBusy(true); + try { + if (!audioContextRef.current) { + audioContextRef.current = new (window.AudioContext || window.webkitAudioContext)(); + playbackTimeRef.current = audioContextRef.current.currentTime; + } + const audioContext = audioContextRef.current; + + const onData = (rawFloat32Data, total, current) => { + accumulatedChunks.push(new Float32Array(rawFloat32Data)); + if (total > 0) { + const progress = Math.round((current / total) * 100); + setChatHistory(prev => { + const next = [...prev]; + if (next[index]) next[index].audioProgress = progress; + return next; + }); + } + + const float32Resampled = resampleBuffer(rawFloat32Data, 24000, audioContext.sampleRate); + const audioBuffer = audioContext.createBuffer(1, float32Resampled.length, audioContext.sampleRate); + audioBuffer.copyToChannel(float32Resampled, 0); + const source = audioContext.createBufferSource(); + source.buffer = audioBuffer; + source.connect(audioContext.destination); + const startTime = Math.max(playbackTimeRef.current, audioContext.currentTime); + source.start(startTime); + playbackTimeRef.current = startTime + audioBuffer.duration; + playingSourcesRef.current.push(source); + }; + + const onDone = async () => { + if (accumulatedChunks.length > 0) { + const totalLen = accumulatedChunks.reduce((acc, c) => acc + c.length, 0); + const result = new Float32Array(totalLen); + let offset = 0; + for (const c of accumulatedChunks) { + result.set(c, offset); + offset += c.length; + } + const finalPcm = resampleBuffer(result, 24000, 44100); + const wavBlob = encodeWAV(finalPcm, 44100); + + const messageId = 
chatHistory[index]?.id; + if (messageId) { + try { + await uploadMessageAudio(messageId, wavBlob); + } catch (e) { console.warn("Upload failed during manual synthesis", e); } + } + + setChatHistory(prev => { + const next = [...prev]; + if (next[index]) next[index].audioBlob = wavBlob; + return next; + }); + } + }; + + await streamSpeech(text, onData, onDone, localActivePrefs.tts); + } catch (err) { + console.error("Manual synthesis failed", err); + } finally { + setIsBusy(false); + } + }; const processConversation = async (audioBlob) => { console.log("Processing conversation..."); @@ -291,11 +464,11 @@ addMessage(userText, true); setStatus("AI is thinking..."); - const aiText = await chatWithAI(sessionId, userText, localActivePrefs.llm || "gemini"); - addMessage(aiText, false); + const aiResponse = await chatWithAI(sessionId, userText, localActivePrefs.llm || "gemini"); + addMessage(aiResponse.answer, false, aiResponse.message_id); fetchTokenUsage(); - await playStreamingAudio(aiText); + await playStreamingAudio(aiResponse.answer, aiResponse.message_id); } catch (error) { console.error("Conversation processing failed:", error); setStatus(`Error: ${error.message}`); @@ -391,7 +564,7 @@ const timeSinceLastRequest = Date.now() - lastRequestTimeRef.current; const isCooldownPassed = timeSinceLastRequest > AUTO_MODE_COOLDOWN_MS; - if (isVoiceDetected) { + if (isVoiceDetected && !isBusy) { if (silenceTimeoutRef.current) { clearTimeout(silenceTimeoutRef.current); silenceTimeoutRef.current = null; @@ -484,12 +657,7 @@ }; const handleNewSession = async () => { - setChatHistory([ - { - text: "Hello! I'm an AI assistant. 
How can I help you today?", - isUser: false, - }, - ]); + setChatHistory([]); localStorage.removeItem("sessionId_voice_chat"); setIsBusy(true); @@ -531,10 +699,25 @@ const messagesData = await getSessionMessages(targetSessionId); if (messagesData && messagesData.messages) { - const mappedHistory = messagesData.messages.map(msg => ({ - text: msg.content, - isUser: msg.sender === 'user' - })); + const mappedHistoryPromises = messagesData.messages.map(async (msg) => { + let audioBlob = null; + if (msg.has_audio) { + try { + audioBlob = await fetchMessageAudio(msg.id); + } catch (e) { + console.warn(`Failed to fetch audio for message ${msg.id} during switch`, e); + } + } + return { + id: msg.id, + text: msg.content, + isUser: msg.sender === 'user', + timestamp: msg.created_at, + isFromHistory: true, + audioBlob: audioBlob + }; + }); + const mappedHistory = await Promise.all(mappedHistoryPromises); setChatHistory(mappedHistory); } fetchTokenUsage(); @@ -567,6 +750,9 @@ handleSwitchSession, setShowErrorModal, setErrorMessage, + synthesizeMessageAudio, + isStreamingPlaying, + stopStreamingPlayback }; }; diff --git a/ui/client-app/src/pages/CodingAssistantPage.js b/ui/client-app/src/pages/CodingAssistantPage.js index ac1d4fe..241920e 100644 --- a/ui/client-app/src/pages/CodingAssistantPage.js +++ b/ui/client-app/src/pages/CodingAssistantPage.js @@ -8,6 +8,7 @@ // A custom hook to manage WebSocket connection and state import useCodeAssistant from "../hooks/useCodeAssistant"; +import { updateSession } from "../services/apiService"; const CodeAssistantPage = () => { // Reference for the main container to manage scrolling @@ -40,6 +41,19 @@ const [isPanelExpanded, setIsPanelExpanded] = useState(false); const [showConfigModal, setShowConfigModal] = useState(false); + const [sidebarRefreshTick, setSidebarRefreshTick] = useState(0); + + const handleSaveQuickConfig = async () => { + try { + if (sessionId && localActiveLLM) { + await updateSession(sessionId, { provider_name: 
localActiveLLM }); + setSidebarRefreshTick(t => t + 1); + } + setShowConfigModal(false); + } catch (e) { + console.error("Failed to update session configs:", e); + } + }; // Scroll to the bottom of the page when new content is added useEffect(() => { @@ -55,6 +69,7 @@ currentSessionId={sessionId} onSwitchSession={handleSwitchSession} onNewSession={() => handleSendChat("/new")} + refreshTick={sidebarRefreshTick} /> {/* Main content area */} @@ -127,7 +142,7 @@
{/* Note: ChatArea component needs to be implemented with a