diff --git a/.gitignore b/.gitignore index af87667..816e9e6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ ai-hub/ai_payloads/* ai-hub/.env.prod @eaDir/ -**/.DS_Store \ No newline at end of file +**/.DS_Store +ai-hub/app/config.yaml +**/config.yaml \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py index 6fc2af1..732bd2f 100644 --- a/ai-hub/app/api/routes/sessions.py +++ b/ai-hub/app/api/routes/sessions.py @@ -97,6 +97,21 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch sessions: {e}") + @router.get("/{session_id}", response_model=schemas.Session, summary="Get a Single Session") + def get_session(session_id: int, db: Session = Depends(get_db)): + try: + session = db.query(models.Session).filter( + models.Session.id == session_id, + models.Session.is_archived == False + ).first() + if not session: + raise HTTPException(status_code=404, detail="Session not found.") + return session + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}") + @router.delete("/{session_id}", summary="Delete a Chat Session") def delete_session(session_id: int, db: Session = Depends(get_db)): try: diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py index d99220e..a0fec90 100644 --- a/ai-hub/app/api/routes/stt.py +++ b/ai-hub/app/api/routes/stt.py @@ -1,6 +1,8 @@ import logging from fastapi import APIRouter, HTTPException, UploadFile, File, Depends -from app.api.dependencies import ServiceContainer +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api.routes.user import get_current_user_id from app.api import schemas from app.core.services.stt import STTService @@ -20,7 +22,10 @@ response_model=schemas.STTResponse ) async def transcribe_audio_to_text( - audio_file: UploadFile = File(...) + audio_file: UploadFile = File(...), + provider_name: str | None = None, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) ): """ Transcribes an uploaded audio file into text using the configured STT service. @@ -40,9 +45,29 @@ try: # Read the audio bytes from the uploaded file audio_bytes = await audio_file.read() + + provider_override = None + if user_id: + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + prefs = user.preferences.get("stt", {}) if user and user.preferences else {} + from app.config import settings + active_provider = provider_name or prefs.get("active_provider") or settings.STT_PROVIDER + active_prefs = prefs.get("providers", {}).get(active_provider, {}) + if active_prefs: + from app.core.providers.factory import get_stt_provider + kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model"]} + provider_override = get_stt_provider( + provider_name=active_provider, + api_key=active_prefs.get("api_key"), + model_name=active_prefs.get("model", ""), + **kwargs + ) # Use the STT service to get the transcript - transcript = await services.stt_service.transcribe(audio_bytes) + transcript = await services.stt_service.transcribe( + audio_bytes, + provider_override=provider_override + ) # Return the transcript in a simple JSON response return schemas.STTResponse(transcript=transcript) diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py index e305ee2..a6c9c4f 100644 --- a/ai-hub/app/api/routes/tts.py +++ b/ai-hub/app/api/routes/tts.py @@ -1,8 +1,10 @@ -from fastapi import APIRouter, HTTPException, Query, Response +from fastapi import APIRouter, HTTPException, Query, Response, Depends from fastapi.responses import StreamingResponse -from app.api.dependencies import ServiceContainer -from app.api import schemas from typing import AsyncGenerator +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api.routes.user import get_current_user_id +from app.api import schemas def create_tts_router(services: ServiceContainer) -> APIRouter: router = APIRouter(prefix="/speech", tags=["TTS"]) @@ -21,24 +23,75 @@ as_wav: bool = Query( True, description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true." - ) + ), + provider_name: str = Query( + None, + description="Optional session-level override for the TTS provider" + ), + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) ): try: + provider_override = None + if user_id: + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + prefs = user.preferences.get("tts", {}) if user and user.preferences else {} + from app.config import settings + active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER + active_prefs = prefs.get("providers", {}).get(active_provider, {}) + if active_prefs: + from app.core.providers.factory import get_tts_provider + kwargs = {k: v for k, v in active_prefs.items() if k not in ["api_key", "model", "voice"]} + provider_override = get_tts_provider( + provider_name=active_provider, + api_key=active_prefs.get("api_key"), + model_name=active_prefs.get("model", ""), + voice_name=active_prefs.get("voice", ""), + **kwargs + ) + if stream: - # Pass the new as_wav parameter to the streaming function - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text, - as_wav=as_wav - ) + # Pre-flight: generate first chunk before streaming to catch errors cleanly + # If we send StreamingResponse and then fail, the browser sees a network error + # instead of a meaningful error message. + chunks = await services.tts_service._split_text_into_chunks(request.text) + provider = provider_override or services.tts_service.default_tts_provider + if not chunks: + raise HTTPException(status_code=400, detail="No text to synthesize.") - # Dynamically set the media_type based on the as_wav flag + # Test first chunk synchronously to validate the provider works + first_pcm = await provider.generate_speech(chunks[0]) + + async def full_stream(): + # Yield the already-generated first chunk + if as_wav: + from app.core.services.tts import _create_wav_file + yield _create_wav_file(first_pcm) + else: + yield first_pcm + # Then stream the remaining chunks + for chunk in chunks[1:]: + try: + pcm = await provider.generate_speech(chunk) + if pcm: + if as_wav: + from app.core.services.tts import _create_wav_file + yield _create_wav_file(pcm) + else: + yield pcm + except Exception as e: + import logging + logging.getLogger(__name__).error(f"TTS chunk error: {e}") + break # Stop cleanly rather than crashing the stream + media_type = "audio/wav" if as_wav else "audio/pcm" - - return StreamingResponse(audio_stream_generator, media_type=media_type) + return StreamingResponse(full_stream(), media_type=media_type) + else: # The non-streaming function only returns WAV, so this part remains the same audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text + text=request.text, + provider_override=provider_override ) return Response(content=audio_bytes, media_type="audio/wav") @@ -48,5 +101,37 @@ raise HTTPException( status_code=500, detail=f"Failed to generate speech: {e}" ) - + + @router.get( + "/voices", + summary="List available Google Cloud TTS voices", + response_description="A list of voice names" + ) + async def list_voices( + api_key: str = Query(None, description="Optional API key override"), + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + from app.config import settings + import httpx + key_to_use = api_key or settings.TTS_API_KEY or settings.GEMINI_API_KEY + if not key_to_use: + return [] + + url = f"https://texttospeech.googleapis.com/v1/voices?key={key_to_use}" + try: + async with httpx.AsyncClient(timeout=10) as client: + res = await client.get(url) + if res.status_code == 200: + data = res.json() + voices = data.get('voices', []) + # Filter for english or interesting ones + names = [v['name'] for v in voices] + return sorted(names) + return [] + except Exception as e: + import logging + logging.getLogger(__name__).error(f"Failed to fetch voices: {e}") + return [] + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index 523fcd0..bb9d513 100644 --- a/ai-hub/app/api/routes/user.py +++ b/ai-hub/app/api/routes/user.py @@ -1,12 +1,13 @@ -from fastapi import APIRouter, HTTPException, Depends, Header, Query, Request +from fastapi import APIRouter, HTTPException, Depends, Header, Query, Request, UploadFile, File from fastapi.responses import RedirectResponse as redirect from sqlalchemy.orm import Session from app.db import models from typing import Optional, Annotated import logging import os -import requests +import httpx import jwt +import urllib.parse # Correctly import from your application's schemas and dependencies from app.api.dependencies import ServiceContainer, get_db @@ -56,17 +57,17 @@ # For simplicity, we will pass it as a query parameter in the callback. # A more robust solution would use a state parameter. - # The OIDC provider must redirect to a URL known to the backend. - # So we redirect to a backend endpoint, which in turn redirects to the frontend. - auth_url = ( - f"{OIDC_AUTHORIZATION_URL}?" - f"response_type=code&" - f"scope=openid%20profile%20email&" - f"client_id={OIDC_CLIENT_ID}&" - f"redirect_uri={OIDC_REDIRECT_URI}&" - f"state={frontend_callback_uri}" # Pass the frontend URI in the state parameter - ) - logger.debug(f"Redirecting to OIDC authorization URL: {auth_url}") + # Use urllib.parse.urlencode to properly encode parameters + params = { + "response_type": "code", + "scope": "openid profile email", + "client_id": OIDC_CLIENT_ID, + "redirect_uri": OIDC_REDIRECT_URI, + "state": frontend_callback_uri or "" + } + + auth_url = f"{OIDC_AUTHORIZATION_URL}?{urllib.parse.urlencode(params)}" + logger.info(f"Redirecting to OIDC authorization URL: {auth_url}") return redirect(url=auth_url) @router.get("/login/callback", summary="Handle OIDC Login Callback") @@ -81,9 +82,10 @@ tokens, and then redirects the user back to the frontend with the user data or a session token. """ - logger.debug(f"Received callback with authorization code: {code}") + logger.info(f"Received callback with authorization code: {code[:10]}... and state: {state}") try: + logger.info(f"Exchanging code for tokens at: {OIDC_TOKEN_URL}") # Step 1: Exchange the authorization code for an access token and an ID token token_data = { "grant_type": "authorization_code", @@ -92,51 +94,61 @@ "client_id": OIDC_CLIENT_ID, "client_secret": OIDC_CLIENT_SECRET, } - token_response = requests.post(OIDC_TOKEN_URL, data=token_data) - token_response.raise_for_status() - response_json = token_response.json() + async with httpx.AsyncClient() as client: + logger.debug(f"Sending POST to {OIDC_TOKEN_URL} with data keys: {list(token_data.keys())}") + token_response = await client.post(OIDC_TOKEN_URL, data=token_data, timeout=30.0) + token_response.raise_for_status() + response_json = token_response.json() + + logger.info("Successfully received tokens from OIDC provider.") id_token = response_json.get("id_token") if not id_token: - logger.error("Error: ID token not found.") - raise HTTPException(status_code=400, detail="Failed to get ID token.") + logger.error("Error: ID token not found in the response.") + raise HTTPException(status_code=400, detail="Failed to get ID token from OIDC provider.") # Step 2: Decode the ID token to get user information + logger.info("Decoding ID token...") decoded_id_token = jwt.decode(id_token, options={"verify_signature": False}) oidc_id = decoded_id_token.get("sub") email = decoded_id_token.get("email") - username = decoded_id_token.get("name") + # Dex and others often use 'name' for the full name, or 'preferred_username' + username = decoded_id_token.get("name") or decoded_id_token.get("preferred_username") or email - if not all([oidc_id, email, username]): - logger.error("Error: Essential user data missing.") - raise HTTPException(status_code=400, detail="Essential user data missing.") + logger.info(f"User decoded: email={email}, oidc_id={oidc_id}") + + if not all([oidc_id, email]): + logger.error(f"Error: Essential user data missing. oidc_id={oidc_id}, email={email}") + raise HTTPException(status_code=400, detail="Essential user data missing from ID token (sub and email required).") # Step 3: Save the user and get their unique ID + logger.info("Saving user to database...") user_id = services.user_service.save_user( db=db, oidc_id=oidc_id, email=email, username=username ) + logger.info(f"User saved/updated successfully with internal ID: {user_id}") - # Step 4: Redirect back to the frontend, passing the user_id or a session token - # Note: This is a simplification. A real app would set a secure HTTP-only cookie. - # We are passing the user_id as a query parameter for demonstration. + # Step 4: Redirect back to the frontend frontend_redirect_url = f"{state}?user_id={user_id}" + logger.info(f"Redirecting back to frontend: {frontend_redirect_url}") return redirect(url=frontend_redirect_url) - except requests.exceptions.RequestException as e: - logger.error(f"Token exchange error: {e}") - if hasattr(e, 'response') and e.response is not None: - logger.error(f"Token exchange response body: {e.response.text}") + except httpx.HTTPStatusError as e: + logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}") + raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"OIDC Token exchange request error: {e}") raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") except jwt.JWTDecodeError as e: logger.error(f"ID token decode error: {e}") - raise HTTPException(status_code=400, detail="Failed to decode ID token.") + raise HTTPException(status_code=400, detail="Failed to decode ID token from OIDC provider.") except Exception as e: - logger.exception(f"An unexpected error occurred during OAuth callback: {e}") + logger.exception(f"An unexpected error occurred during OIDC callback: {e}") raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") @router.get("/me", response_model=schemas.UserStatus, summary="Get Current User Status") @@ -165,6 +177,374 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + @router.get("/me/config", response_model=schemas.ConfigResponse, summary="Get Current User Preferences") + async def get_user_config( + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Gets user specific preferences (LLM, TTS, STT config overrides).""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + prefs_dict = user.preferences or {} + + # Calculate effective config + from app.config import settings + + def mask_key(k): + if not k: return None + if len(k) <= 8: return "****" + return k[:4] + "*" * (len(k)-8) + k[-4:] + + llm_prefs = prefs_dict.get("llm", {}) + tts_prefs = prefs_dict.get("tts", {}) + stt_prefs = prefs_dict.get("stt", {}) + + user_providers = llm_prefs.get("providers", {}) + + if not user_providers: + # Day zero: fall back to evaluating defaults from env/yaml + llm_providers_effective = { + "deepseek": { + "api_key": mask_key(settings.DEEPSEEK_API_KEY), + "model": settings.DEEPSEEK_MODEL_NAME + }, + "gemini": { + "api_key": mask_key(settings.GEMINI_API_KEY), + "model": settings.GEMINI_MODEL_NAME + }, + "openai": { + "api_key": mask_key(settings.OPENAI_API_KEY), + "model": "gpt-4" + } + } + else: + # User has configured providers, only return what they have (plus masking) + llm_providers_effective = {} + for p, prefs in user_providers.items(): + llm_providers_effective[p] = { + "api_key": mask_key(prefs.get("api_key")), + "model": prefs.get("model") + } + + user_tts_providers = tts_prefs.get("providers", {}) + if not user_tts_providers: + tts_providers_effective = { + settings.TTS_PROVIDER: { + "api_key": mask_key(settings.TTS_API_KEY), + "model": settings.TTS_MODEL_NAME, + "voice": settings.TTS_VOICE_NAME + } + } + else: + tts_providers_effective = {} + for p, prefs in user_tts_providers.items(): + tts_providers_effective[p] = { + "api_key": mask_key(prefs.get("api_key")), + "model": prefs.get("model"), + "voice": prefs.get("voice") + } + + user_stt_providers = stt_prefs.get("providers", {}) + if not user_stt_providers: + stt_providers_effective = { + settings.STT_PROVIDER: { + "api_key": mask_key(settings.STT_API_KEY), + "model": settings.STT_MODEL_NAME + } + } + else: + stt_providers_effective = {} + for p, prefs in user_stt_providers.items(): + stt_providers_effective[p] = { + "api_key": mask_key(prefs.get("api_key")), + "model": prefs.get("model") + } + + effective = { + "llm": { + "active_provider": llm_prefs.get("active_provider") or (next(iter(llm_providers_effective), None)) or "deepseek", + "providers": llm_providers_effective + }, + "tts": { + "active_provider": tts_prefs.get("active_provider") or (next(iter(tts_providers_effective), None)) or settings.TTS_PROVIDER, + "providers": tts_providers_effective + }, + "stt": { + "active_provider": stt_prefs.get("active_provider") or (next(iter(stt_providers_effective), None)) or settings.STT_PROVIDER, + "providers": stt_providers_effective + } + } + # Ensure we mask the preferences dict we send back to the user + def mask_section_prefs(section_dict): + import copy + masked_dict = copy.deepcopy(section_dict) + providers = masked_dict.get("providers", {}) + for p_name, p_data in providers.items(): + if p_data.get("api_key"): + p_data["api_key"] = mask_key(p_data["api_key"]) + return masked_dict + + return schemas.ConfigResponse( + preferences=schemas.UserPreferences( + llm=mask_section_prefs(llm_prefs), + tts=mask_section_prefs(tts_prefs), + stt=mask_section_prefs(stt_prefs) + ), + effective=effective + ) + + @router.put("/me/config", response_model=schemas.UserPreferences, summary="Update Current User Preferences") + async def update_user_config( + prefs: schemas.UserPreferences, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Updates user specific preferences.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # When saving, if the api_key contains ****, we must retain the old one from the DB + old_prefs = user.preferences or {} + + def preserve_masked_keys(section_name, new_section): + if not new_section or "providers" not in new_section: + return + old_section = old_prefs.get(section_name, {}).get("providers", {}) + for p_name, p_data in new_section["providers"].items(): + if p_data.get("api_key") and "***" in p_data["api_key"]: + if p_name in old_section: + p_data["api_key"] = old_section[p_name].get("api_key") + + if prefs.llm: preserve_masked_keys("llm", prefs.llm) + if prefs.tts: preserve_masked_keys("tts", prefs.tts) + if prefs.stt: preserve_masked_keys("stt", prefs.stt) + + user.preferences = {"llm": prefs.llm, "tts": prefs.tts, "stt": prefs.stt} + from sqlalchemy.orm.attributes import flag_modified + flag_modified(user, "preferences") + + # --- Day 2 Sync: Update Global Settings and YAML --- + # Update our global settings object to match the user's new preferences + # and then persist them to config.yaml if this is the "primary" user or admin. + from app.config import settings as global_settings + + # Sync LLM + if prefs.llm and prefs.llm.get("providers"): + global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {})) + + # Sync TTS + if prefs.tts and prefs.tts.get("active_provider"): + p_name = prefs.tts["active_provider"] + p_data = prefs.tts.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.TTS_PROVIDER = p_name + global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME + global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME + global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY + + # Sync STT + if prefs.stt and prefs.stt.get("active_provider"): + p_name = prefs.stt["active_provider"] + p_data = prefs.stt.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.STT_PROVIDER = p_name + global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME + global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY + + # Write to config.yaml + try: + global_settings.save_to_yaml() + except Exception as ey: + logger.error(f"Failed to sync settings to YAML: {ey}") + + logger.info(f"Saving updated preferences for user {user_id}: {list(user.preferences.keys())}") + db.add(user) + db.commit() + db.refresh(user) + + return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {})) + + @router.get("/me/config/models", response_model=list[schemas.ModelInfoResponse], summary="Get Models for Provider") + async def get_provider_models(provider_name: str, section: str = "llm"): + import litellm + from fastapi.concurrency import run_in_threadpool + + def fetch_models(): + try: + models = litellm.models_by_provider.get(provider_name, []) + out = [] + for m in models: + try: + info = litellm.get_model_info(m) + if "error" not in info: + mode = info.get("mode") + is_valid = False + if section == "llm": + is_valid = mode in ["chat", "text-completion", "custom", None] + elif section == "tts": + is_valid = mode == "audio_speech" + elif section == "stt": + is_valid = mode == "audio_transcription" + elif section == "image": + is_valid = mode == "image_generation" + else: + is_valid = True + + if is_valid: + out.append({ + "model_name": m, + "max_tokens": info.get("max_tokens"), + "max_input_tokens": info.get("max_input_tokens") + }) + except Exception: + pass + return out + except Exception: + return [] + + results = await run_in_threadpool(fetch_models) + return results + + @router.get("/me/config/providers", response_model=list[str], summary="Get All Valid Providers per Section") + async def get_all_providers(section: str = "llm"): + import litellm + from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers + + if section == "llm": + return ["general"] + [p.value for p in litellm.LlmProviders] + elif section == "tts": + # For TTS, combine registry + possibly litellm providers that support audio_speech + return ["general"] + get_registered_tts_providers() + ["openai"] + elif section == "stt": + return ["general"] + get_registered_stt_providers() + ["openai"] + elif section == "image": + return ["general"] + [p.value for p in litellm.LlmProviders] + return ["general"] + [p.value for p in litellm.LlmProviders] + + @router.post("/me/config/verify_llm", response_model=schemas.VerifyProviderResponse) + async def verify_llm( + req: schemas.VerifyProviderRequest, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + from app.core.providers.factory import get_llm_provider + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + actual_key = req.api_key + try: + llm_prefs = {} + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if user and user.preferences: + llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(req.provider_name, {}) + + kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} + llm = get_llm_provider( + provider_name=req.provider_name, + model_name=req.model or "", + api_key_override=actual_key, + **kwargs + ) + res = llm("Hello") + return schemas.VerifyProviderResponse(success=True, message="Connection successful!") + except Exception as e: + return schemas.VerifyProviderResponse(success=False, message=str(e)) + + @router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse) + async def verify_tts( + req: schemas.VerifyProviderRequest, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + from app.core.providers.factory import get_tts_provider + from app.config import settings + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + actual_key = req.api_key + try: + tts_prefs = {} + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if user and user.preferences: + tts_prefs = user.preferences.get("tts", {}).get("providers", {}).get(req.provider_name, {}) + + # Resolve the real key: masked/absent → DB key → system TTS key → system Gemini key + def is_masked(k): + return not k or k in ("None", "none", "") or "*" in str(k) + + if is_masked(actual_key): + db_key = tts_prefs.get("api_key") + if not is_masked(db_key): + actual_key = db_key + logger.debug(f"verify_tts: using DB key for {req.provider_name}") + elif not is_masked(settings.TTS_API_KEY): + actual_key = settings.TTS_API_KEY + logger.debug(f"verify_tts: using system TTS_API_KEY for {req.provider_name}") + elif not is_masked(settings.GEMINI_API_KEY): + actual_key = settings.GEMINI_API_KEY + logger.debug(f"verify_tts: using system GEMINI_API_KEY for {req.provider_name}") + + logger.info(f"verify_tts: provider={req.provider_name}, model={req.model}, key_prefix={str(actual_key)[:8] if actual_key else 'NONE'}") + + kwargs = {k: v for k, v in tts_prefs.items() if k not in ["api_key", "model", "voice"]} + tts = get_tts_provider( + provider_name=req.provider_name, + api_key=actual_key, + model_name=req.model or "", + voice_name=req.voice or "", + **kwargs + ) + await tts.generate_speech("Testing.") + return schemas.VerifyProviderResponse(success=True, message="Connection successful!") + except Exception as e: + logger.error(f"TTS verification failed for {req.provider_name}: {e}") + return schemas.VerifyProviderResponse(success=False, message=str(e)) + + @router.post("/me/config/verify_stt", response_model=schemas.VerifyProviderResponse) + async def verify_stt( + req: schemas.VerifyProviderRequest, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + from app.core.providers.factory import get_stt_provider + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + actual_key = req.api_key + from app.config import settings + def is_masked(k): + return not k or k in ("None", "none", "") or "*" in str(k) + + if is_masked(actual_key): + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if user and user.preferences: + prefs = user.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {}) + db_key = prefs.get("api_key") + if not is_masked(db_key): + actual_key = db_key + if is_masked(actual_key) and not is_masked(settings.STT_API_KEY): + actual_key = settings.STT_API_KEY + if is_masked(actual_key) and not is_masked(settings.GEMINI_API_KEY): + actual_key = settings.GEMINI_API_KEY + try: + stt = get_stt_provider(provider_name=req.provider_name, model_name=req.model or "", api_key=actual_key) + import io + import wave + wav_io = io.BytesIO() + with wave.open(wav_io, 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(16000) + wav_file.writeframes(b'\x00\x00' * 16000) # 1 second of silence + res = await stt.transcribe_audio(wav_io.getvalue()) + # Empty transcript is expected for silent audio — connection works fine + return schemas.VerifyProviderResponse(success=True, message="Connection successful!") + except Exception as e: + return schemas.VerifyProviderResponse(success=False, message=str(e)) + @router.post("/logout", summary="Log Out the Current User") async def logout(): """ @@ -172,4 +552,192 @@ """ return {"message": "Logged out successfully"} + @router.get("/me/config/export", summary="Export Configurations to YAML") + async def export_user_config_yaml( + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Exports the effective user configuration as a YAML file.""" + from fastapi.responses import PlainTextResponse + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + prefs_dict = user.preferences or {} + from app.config import settings + import yaml + + llm_prefs = prefs_dict.get("llm", {}) + tts_prefs = prefs_dict.get("tts", {}) + stt_prefs = prefs_dict.get("stt", {}) + + llm_providers_export = {} + user_providers = llm_prefs.get("providers", {}) + if not user_providers: + # Fallback to system defaults if no user config exists + llm_providers_export = { + "deepseek_api_key": settings.DEEPSEEK_API_KEY, + "deepseek_model_name": settings.DEEPSEEK_MODEL_NAME, + "gemini_api_key": settings.GEMINI_API_KEY, + "gemini_model_name": settings.GEMINI_MODEL_NAME, + "openai_api_key": settings.OPENAI_API_KEY + } + else: + for p, p_data in user_providers.items(): + llm_providers_export[f"{p}_api_key"] = p_data.get("api_key") + llm_providers_export[f"{p}_model_name"] = p_data.get("model") + + def get_provider_export(section_prefs, fallback_provider, fallback_model, fallback_api_key, fallback_voice=None): + active_p = section_prefs.get("active_provider") + providers = section_prefs.get("providers", {}) + if active_p and active_p in providers: + p_data = providers[active_p] + return { + "provider": active_p, + "model_name": p_data.get("model"), + "voice_name": p_data.get("voice"), + "api_key": p_data.get("api_key") + } + # Fallback to system settings + return { + "provider": fallback_provider, + "model_name": fallback_model, + "voice_name": fallback_voice, + "api_key": fallback_api_key + } + + yaml_data = { + "application": { + "project_name": settings.PROJECT_NAME, + "version": settings.VERSION, + "log_level": settings.LOG_LEVEL + }, + "database": { + "mode": settings.DB_MODE, + }, + "llm_providers": llm_providers_export, + "embedding_provider": { + "provider": settings.EMBEDDING_PROVIDER, + "model_name": settings.EMBEDDING_MODEL_NAME, + "api_key": settings.EMBEDDING_API_KEY + }, + "tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME), + "stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY) + } + + # Filter out None values recursively + def remove_none(obj): + if isinstance(obj, dict): + return {k: remove_none(v) for k, v in obj.items() if v is not None} + return obj + + clean_yaml_data = remove_none(yaml_data) + yaml_str = yaml.dump(clean_yaml_data, sort_keys=False, default_flow_style=False) + + return PlainTextResponse( + content=yaml_str, + media_type="application/x-yaml", + headers={"Content-Disposition": "attachment; filename=\"config.yaml\""} + ) + + @router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML") + async def import_user_config_yaml( + file: UploadFile = File(...), + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Imports user configuration from a YAML file.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + content = await file.read() + try: + import yaml + data = yaml.safe_load(content) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid YAML file: {e}") + + # Reverse mapping: YAML -> UserPreferences structure + new_llm = { "providers": {}, "active_provider": None } + new_tts = { "providers": {}, "active_provider": None } + new_stt = { "providers": {}, "active_provider": None } + + # --- LLM --- + llm_data = data.get("llm_providers", {}) + if "providers" in llm_data: # Structured + new_llm["providers"] = llm_data["providers"] + else: # Flattened (as exported) + for k, v in llm_data.items(): + if k.endswith("_api_key"): + p = k.replace("_api_key", "") + if p not in new_llm["providers"]: new_llm["providers"][p] = {} + new_llm["providers"][p]["api_key"] = v + elif k.endswith("_model_name"): + p = k.replace("_model_name", "") + if p not in new_llm["providers"]: new_llm["providers"][p] = {} + new_llm["providers"][p]["model"] = v + + if new_llm["providers"]: + new_llm["active_provider"] = next(iter(new_llm["providers"]), None) + + # --- TTS --- + tts_data = data.get("tts_provider", {}) + if tts_data: + p = tts_data.get("provider") or "google_gemini" + new_tts["active_provider"] = p + new_tts["providers"][p] = { + "api_key": tts_data.get("api_key"), + "model": tts_data.get("model_name"), + "voice": tts_data.get("voice_name") + } + + # --- STT --- + stt_data = data.get("stt_provider", {}) + if stt_data: + p = stt_data.get("provider") or "google_gemini" + new_stt["active_provider"] = p + new_stt["providers"][p] = { + "api_key": stt_data.get("api_key"), + "model": stt_data.get("model_name") + } + + user.preferences = { "llm": new_llm, "tts": new_tts, "stt": new_stt } + from sqlalchemy.orm.attributes import flag_modified + flag_modified(user, "preferences") + + # --- Day 2 Sync --- + from app.config import settings as global_settings + if new_llm.get("providers"): + global_settings.LLM_PROVIDERS.update(new_llm["providers"]) + if new_tts.get("active_provider"): + p = new_tts["active_provider"] + p_data = new_tts["providers"].get(p, {}) + if p_data: + global_settings.TTS_PROVIDER = p + global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME + global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME + global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY + if new_stt.get("active_provider"): + p = new_stt["active_provider"] + p_data = new_stt["providers"].get(p, {}) + if p_data: + global_settings.STT_PROVIDER = p + global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME + global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY + + try: + global_settings.save_to_yaml() + except Exception as ey: + logger.error(f"Failed to sync settings to YAML on import: {ey}") + + db.add(user) + db.commit() + db.refresh(user) + return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {})) + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 8cd4770..630b178 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -17,6 +17,18 @@ is_logged_in: bool = Field(True, description="Indicates if the user is currently authenticated.") is_anonymous: bool = Field(False, description="Indicates if the user is an anonymous user.") +# --- General Schemas --- +class UserPreferences(BaseModel): + """Schema for user-specific LLM, TTS, STT preferences.""" + llm: dict = Field(default_factory=dict) + tts: dict = Field(default_factory=dict) + stt: dict = Field(default_factory=dict) + +class ConfigResponse(BaseModel): + """Schema for returning user preferences alongside effective settings.""" + preferences: UserPreferences + effective: dict = Field(default_factory=dict) + # --- Chat Schemas --- class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" @@ -71,7 +83,7 @@ class SessionCreate(BaseModel): """Defines the shape for starting a new conversation session.""" user_id: str - provider_name: Literal["deepseek", "gemini"] = "deepseek" + provider_name: str = "deepseek" feature_name: Optional[str] = "default" class Session(BaseModel): @@ -114,3 +126,18 @@ class STTResponse(BaseModel): """Defines the shape of a successful response from the /stt endpoint.""" transcript: str + +class VerifyProviderRequest(BaseModel): + provider_name: str + api_key: Optional[str] = None + model: Optional[str] = None + voice: Optional[str] = None + +class VerifyProviderResponse(BaseModel): + success: bool + message: str + +class ModelInfoResponse(BaseModel): + model_name: str + max_tokens: Optional[int] = None + max_input_tokens: Optional[int] = None diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index cad3153..230d127 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -2,7 +2,9 @@ from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List +import litellm import logging +logger = logging.getLogger(__name__) # Import centralized settings and other components from app.config import settings @@ -58,56 +60,69 @@ logging.basicConfig(level=settings.LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(message)s') logging.getLogger("dspy").setLevel(logging.DEBUG) + + # Global settings for LiteLLM to handle provider-specific quirks + litellm.drop_params = True - # --- Initialize Core Services using settings --- + # --- Initialize Core Services defensively --- - # 1. Use the new, more flexible factory function to create the embedder instance - # This decouples the application from a specific embedding provider. - embedder = get_embedder_from_config( - provider=settings.EMBEDDING_PROVIDER, - dimension=settings.EMBEDDING_DIMENSION, - model_name=settings.EMBEDDING_MODEL_NAME, - api_key=settings.EMBEDDING_API_KEY - ) - - # 2. Initialize the FaissVectorStore with the chosen embedder - vector_store = FaissVectorStore( - index_file_path=settings.FAISS_INDEX_PATH, - dimension=settings.EMBEDDING_DIMENSION, - embedder=embedder - ) + # RAG Components are optional for now as requested + embedder = None + vector_store = None + retrievers = [] - # CRITICAL FIX: Assign the vector_store to the app state so it can be saved on shutdown. - app.state.vector_store = vector_store - - # 3. Create the FaissDBRetriever, regardless of the embedder type - retrievers: List[Retriever] = [ - FaissDBRetriever(vector_store=vector_store), - ] - - # --- New TTS Initialization --- - # 4. Get the concrete TTS provider from the factory - tts_provider = get_tts_provider( - provider_name=settings.TTS_PROVIDER, - api_key=settings.TTS_API_KEY, - model_name = settings.TTS_MODEL_NAME, - voice_name=settings.TTS_VOICE_NAME - ) + try: + # Resolve from config/settings + if settings.EMBEDDING_PROVIDER: + embedder = get_embedder_from_config( + provider=settings.EMBEDDING_PROVIDER, + dimension=settings.EMBEDDING_DIMENSION, + model_name=settings.EMBEDDING_MODEL_NAME, + api_key=settings.EMBEDDING_API_KEY + ) + + vector_store = FaissVectorStore( + index_file_path=settings.FAISS_INDEX_PATH, + dimension=settings.EMBEDDING_DIMENSION, + embedder=embedder + ) + app.state.vector_store = vector_store + retrievers.append(FaissDBRetriever(vector_store=vector_store)) + except Exception as e: + logger.error(f"Failed to initialize Vector Store: {e}. RAG functionality might be restricted.") - # 6. Get the concrete STT provider from the factory - stt_provider = get_stt_provider( - provider_name=settings.STT_PROVIDER, - api_key=settings.STT_API_KEY, - model_name=settings.STT_MODEL_NAME - ) + # Voice Providers (optional fallback) + tts_provider = None + stt_provider = None + try: + if settings.TTS_PROVIDER: + tts_provider = get_tts_provider( + provider_name=settings.TTS_PROVIDER, + api_key=settings.TTS_API_KEY, + model_name=settings.TTS_MODEL_NAME, + voice_name=settings.TTS_VOICE_NAME + ) + if settings.STT_PROVIDER: + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + except ValueError as e: + logger.info(f"TTS/STT will be initialized later via UI: {e}") + except Exception as e: + logger.warning(f"Failed to initialize TTS/STT: {e}") - # 9. Initialize the Service Container with all services - # This replaces the previous, redundant initialization + # 9. Initialize the Service Container with all initialized services services = ServiceContainer() services.with_rag_service(retrievers=retrievers) services.with_document_service(vector_store=vector_store) - services.with_service("stt_service",service=STTService(stt_provider=stt_provider)) - services.with_service("tts_service",service=TTSService(tts_provider=tts_provider)) + + if stt_provider: + services.with_service("stt_service", service=STTService(stt_provider=stt_provider)) + if tts_provider: + services.with_service("tts_service", service=TTSService(tts_provider=tts_provider)) + services.with_service("workspace_service", service=WorkspaceService()) services.with_service("session_service", service=SessionService()) services.with_service("user_service", service=UserService()) diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 390c990..c4fd19d 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -10,21 +10,6 @@ # --- 1. Define the Configuration Schema --- -class EmbeddingProvider(str, Enum): - """An enum for supported embedding providers.""" - GOOGLE_GEMINI = "google_gemini" - MOCK = "mock" - -class TTSProvider(str, Enum): - """An enum for supported Text-to-Speech (TTS) providers.""" - GOOGLE_GEMINI = "google_gemini" - GCLOUD_TTS = "gcloud_tts" # NEW: Add Google Cloud TTS as a supported provider - -class STTProvider(str, Enum): - """An enum for supported Speech-to-Text (STT) providers.""" - GOOGLE_GEMINI = "google_gemini" - OPENAI = "openai" - class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -35,26 +20,23 @@ url: Optional[str] = None local_path: str = "data/ai_hub.db" -class LLMProviderSettings(BaseModel): - deepseek_model_name: str = "deepseek-chat" - gemini_model_name: str = "gemini-1.5-flash-latest" +class ProviderSettings(BaseModel): + """Generic structure to hold any provider-specific config (api_key, model, etc.)""" + active_provider: Optional[str] = None + providers: dict[str, dict] = Field(default_factory=dict) + # Compatibility for top-level keys + api_key: Optional[SecretStr] = None + model_name: Optional[str] = None + voice_name: Optional[str] = None class EmbeddingProviderSettings(BaseModel): - provider: EmbeddingProvider = Field(default=EmbeddingProvider.GOOGLE_GEMINI) + provider: str = "google_gemini" model_name: str = "models/text-embedding-004" api_key: Optional[SecretStr] = None -class TTSProviderSettings(BaseModel): - provider: TTSProvider = Field(default=TTSProvider.GOOGLE_GEMINI) - # The default values are kept as originally requested - voice_name: str = "Kore" - model_name: str = "gemini-2.5-flash-preview-tts" - api_key: Optional[SecretStr] = None - -class STTProviderSettings(BaseModel): - provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) - model_name: str = "gemini-2.5-flash" - api_key: Optional[SecretStr] = None +class LLMProvidersSettings(BaseModel): + """Holds shared API keys and per-provider overrides.""" + providers: dict[str, dict] = Field(default_factory=dict) class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" @@ -64,11 +46,11 @@ """Top-level Pydantic model for application configuration.""" application: ApplicationSettings = Field(default_factory=ApplicationSettings) database: DatabaseSettings = Field(default_factory=DatabaseSettings) - llm_providers: LLMProviderSettings = Field(default_factory=LLMProviderSettings) + llm_providers: LLMProvidersSettings = Field(default_factory=LLMProvidersSettings) vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) - tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) - stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) + tts_provider: ProviderSettings = Field(default_factory=ProviderSettings) + stt_provider: ProviderSettings = Field(default_factory=ProviderSettings) # --- 2. Create the Final Settings Object --- @@ -121,19 +103,36 @@ else: self.DATABASE_URL: str = external_db_url or "sqlite:///./data/ai_hub.db" - # --- API Keys & Models --- - self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") - self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") - self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") - - self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name - self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + # --- Agnostic Provider Resolution --- + # We store everything in a flat map for the legacy settings getters, + # but also provide a dynamic map. - # --- Vector Store Settings --- + # 1. Resolve LLM Providers + self.LLM_PROVIDERS = config_from_pydantic.llm_providers.providers or {} + # Support legacy environment variables and merge them into the providers map + for env_key, env_val in os.environ.items(): + if env_key.endswith("_API_KEY") and not any(x in env_key for x in ["TTS", "STT", "EMBEDDING"]): + provider_id = env_key.replace("_API_KEY", "").lower() + if provider_id not in self.LLM_PROVIDERS: + self.LLM_PROVIDERS[provider_id] = {} + self.LLM_PROVIDERS[provider_id]["api_key"] = env_val + if env_key.endswith("_MODEL_NAME") and not any(x in env_key for x in ["TTS", "STT", "EMBEDDING"]): + provider_id = env_key.replace("_MODEL_NAME", "").lower() + if provider_id not in self.LLM_PROVIDERS: + self.LLM_PROVIDERS[provider_id] = {} + self.LLM_PROVIDERS[provider_id]["model"] = env_val + + # Explicit legacy fallback helpers (still useful for factory.py initial state) + self.DEEPSEEK_API_KEY = self.LLM_PROVIDERS.get("deepseek", {}).get("api_key") or os.getenv("DEEPSEEK_API_KEY") + self.GEMINI_API_KEY = self.LLM_PROVIDERS.get("gemini", {}).get("api_key") or os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY = self.LLM_PROVIDERS.get("openai", {}).get("api_key") or os.getenv("OPENAI_API_KEY") + + self.DEEPSEEK_MODEL_NAME = self.LLM_PROVIDERS.get("deepseek", {}).get("model") or \ + get_from_yaml(["llm_providers", "deepseek_model_name"]) or "deepseek-chat" + self.GEMINI_MODEL_NAME = self.LLM_PROVIDERS.get("gemini", {}).get("model") or \ + get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-1.5-flash-latest" + + # 2. Resolve Vector / Embedding self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ get_from_yaml(["vector_store", "index_path"]) or \ config_from_pydantic.vector_store.index_path @@ -142,69 +141,97 @@ config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) - # --- Embedding Provider Settings --- - embedding_provider_env = os.getenv("EMBEDDING_PROVIDER") - if embedding_provider_env: - embedding_provider_env = embedding_provider_env.lower() - - self.EMBEDDING_PROVIDER: EmbeddingProvider = EmbeddingProvider(embedding_provider_env or \ - get_from_yaml(["embedding_provider", "provider"]) or \ - config_from_pydantic.embedding_provider.provider) + self.EMBEDDING_PROVIDER: str = os.getenv("EMBEDDING_PROVIDER") or \ + get_from_yaml(["embedding_provider", "provider"]) or \ + config_from_pydantic.embedding_provider.provider self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ get_from_yaml(["embedding_provider", "model_name"]) or \ config_from_pydantic.embedding_provider.model_name - self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ get_from_yaml(["embedding_provider", "api_key"]) or \ self.GEMINI_API_KEY - # --- TTS Provider Settings --- - tts_provider_env = os.getenv("TTS_PROVIDER") - if tts_provider_env: - tts_provider_env = tts_provider_env.lower() - - self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + # 3. Resolve TTS (Agnostic) + self.TTS_PROVIDER: str = os.getenv("TTS_PROVIDER") or \ + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.active_provider or "google_gemini" + + # Legacy back-compat fields self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ - config_from_pydantic.tts_provider.voice_name - + config_from_pydantic.tts_provider.voice_name or "Kore" self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \ get_from_yaml(["tts_provider", "model_name"]) or \ - config_from_pydantic.tts_provider.model_name + config_from_pydantic.tts_provider.model_name or "gemini-2.5-flash-preview-tts" + self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ + get_from_yaml(["tts_provider", "api_key"]) or \ + self.GEMINI_API_KEY - # API Key logic for TTS - tts_api_key_env = os.getenv("TTS_API_KEY") or get_from_yaml(["tts_provider", "api_key"]) + # 4. Resolve STT (Agnostic) + self.STT_PROVIDER: str = os.getenv("STT_PROVIDER") or \ + get_from_yaml(["stt_provider", "provider"]) or \ + config_from_pydantic.stt_provider.active_provider or "google_gemini" - if tts_api_key_env: - self.TTS_API_KEY: Optional[str] = tts_api_key_env - else: - # If no specific TTS key is set, use the Gemini key as a fallback - self.TTS_API_KEY: Optional[str] = self.GEMINI_API_KEY - - # --- STT Provider Settings --- - stt_provider_env = os.getenv("STT_PROVIDER") - if stt_provider_env: - stt_provider_env = stt_provider_env.lower() - - self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \ - get_from_yaml(["stt_provider", "provider"]) or \ - config_from_pydantic.stt_provider.provider) self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \ get_from_yaml(["stt_provider", "model_name"]) or \ - config_from_pydantic.stt_provider.model_name - - # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY. - explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"]) + config_from_pydantic.stt_provider.model_name or "gemini-2.5-flash" + self.STT_API_KEY: Optional[str] = os.getenv("STT_API_KEY") or \ + get_from_yaml(["stt_provider", "api_key"]) or \ + self.GEMINI_API_KEY - if explicit_stt_api_key: - self.STT_API_KEY: Optional[str] = explicit_stt_api_key - elif self.STT_PROVIDER == STTProvider.OPENAI: - self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY - else: - # Fallback for Google Gemini or other providers - self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY + def save_to_yaml(self): + """Saves current settings back to config.yaml.""" + import yaml + config_path = os.getenv("CONFIG_PATH", "config.yaml") + + def get_val(v): + if hasattr(v, "get_secret_value"): + return v.get_secret_value() + return v + + # Build data dictionary by mapping current class attributes back to YAML structure + # This keeps the sync logic centralized in this class + data = { + "application": { + "project_name": self.PROJECT_NAME, + "version": self.VERSION, + "log_level": self.LOG_LEVEL + }, + "database": { + "mode": self.DB_MODE, + "local_path": self.DATABASE_URL.replace("sqlite:///./", "") if "sqlite" in self.DATABASE_URL else "data/ai_hub.db" + }, + "llm_providers": { + "providers": self.LLM_PROVIDERS + }, + "vector_store": { + "index_path": self.FAISS_INDEX_PATH, + "embedding_dimension": self.EMBEDDING_DIMENSION + }, + "embedding_provider": { + "provider": self.EMBEDDING_PROVIDER, + "model_name": self.EMBEDDING_MODEL_NAME, + "api_key": get_val(self.EMBEDDING_API_KEY) + }, + "tts_provider": { + "provider": self.TTS_PROVIDER, + "model_name": self.TTS_MODEL_NAME, + "voice_name": self.TTS_VOICE_NAME, + "api_key": get_val(self.TTS_API_KEY) + }, + "stt_provider": { + "provider": self.STT_PROVIDER, + "model_name": self.STT_MODEL_NAME, + "api_key": get_val(self.STT_API_KEY) + } + } + + # Ensure directories exist + os.makedirs(os.path.dirname(os.path.abspath(config_path)), exist_ok=True) + + with open(config_path, 'w') as f: + yaml.dump(data, f, sort_keys=False, default_flow_style=False) + print(f"🏠 Configuration synchronized to {config_path}") # Instantiate the single settings object for the application settings = Settings() \ No newline at end of file diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml deleted file mode 100644 index bacb1fb..0000000 --- a/ai-hub/app/config.yaml +++ /dev/null @@ -1,51 +0,0 @@ -# All non-key settings that can be checked into version control. -# API keys are still managed via environment variables for security. - -application: - # The log level for the application. Set to DEBUG for verbose output. - log_level: "INFO" - -database: - # The database mode. Set to "sqlite" for a local file, or "postgresql" - # for a remote server (requires DATABASE_URL to be set). - mode: "sqlite" - - # When using SQLite mode, specify the local database file path here. - # This path is relative to the project root and defaults to "./data/ai_hub.db". - local_path: "data/ai_hub.db" - -llm_providers: - # The default model name for the DeepSeek LLM provider. - deepseek_model_name: "deepseek-chat" - # The default model name for the Gemini LLM provider. - gemini_model_name: "gemini-2.0-flash" - -vector_store: - # The file path to save and load the FAISS index. - index_path: "data/faiss_index.bin" - # The dimension of the embedding vectors used by the FAISS index. - embedding_dimension: 768 - -embedding_provider: - # The provider for the embedding service. Can be "google_gemini" or "mock". - provider: "google_gemini" - # The model name for the embedding service. - model_name: "gemini-embedding-001" - -tts_provider: - # The provider for the TTS service. - # Check more at https://cloud.google.com/text-to-speech - provider: "gcloud_tts" - # The name of the voice to use for TTS. - voice_name: "en-US-Chirp3-HD-Achernar" - # The model name for the TTS service. - model_name: "gemini-2.5-pro-preview-tts" - -# The provider for the Speech-to-Text (STT) service. -stt_provider: - # The provider can be "google_gemini" or "openai". - provider: "google_gemini" - # The model name for the STT service. - # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash"). - # For "openai" this would be a Whisper model (e.g., "whisper-1"). - model_name: "gemini-2.5-flash" \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 7707338..94edb41 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -1,6 +1,8 @@ from app.config import settings from .base import TTSProvider, STTProvider -from .llm.general import GeneralProvider # Assuming GeneralProvider is now in this file or imported +from .llm.general import GeneralProvider +from .tts.general import GeneralTTSProvider +from .stt.general import GeneralSTTProvider from .tts.gemini import GeminiTTSProvider from .tts.gcloud_tts import GCloudTTSProvider from .stt.gemini import GoogleSTTProvider @@ -25,30 +27,83 @@ "gemini": settings.GEMINI_MODEL_NAME } +_tts_registry = { + "google_gemini": GeminiTTSProvider, + "gcloud_tts": GCloudTTSProvider +} + +_stt_registry = { + "google_gemini": GoogleSTTProvider +} + +def get_registered_tts_providers(): + return list(_tts_registry.keys()) + +def get_registered_stt_providers(): + return list(_stt_registry.keys()) + # --- 3. The Factory Functions --- -def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None) -> BaseLM: +def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> BaseLM: """Factory function to get the appropriate, pre-configured LLM provider, with optional system prompt.""" - providerKey = _llm_providers.get(provider_name) - if not providerKey: - raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {list(_llm_providers.keys())}") + providerKey = api_key_override or _llm_providers.get(provider_name) modelName = model_name - if modelName == "": + if not modelName: modelName = _llm_models.get(provider_name) if not modelName: - raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {list(_llm_providers.keys())}") + raise ValueError(f"No model name provided for '{provider_name}'.") - # Pass the optional system_prompt to the GeneralProvider constructor - return GeneralProvider(model_name=f'{provider_name}/{modelName}', api_key=providerKey, system_prompt=system_prompt) + full_model = f'{provider_name}/{modelName}' if '/' not in modelName else modelName + + # Pass the optional system_prompt and kwargs to the GeneralProvider constructor + return GeneralProvider(model_name=full_model, api_key=providerKey, system_prompt=system_prompt, **kwargs) -def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str) -> TTSProvider: - if provider_name == "google_gemini": - return GeminiTTSProvider(api_key=api_key, model_name=model_name, voice_name=voice_name) - elif provider_name == "gcloud_tts": - return GCloudTTSProvider(api_key=api_key, voice_name=voice_name) - raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini', 'gcloud_tts']") +def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider: + def is_masked(k): + return not k or k in ("None", "none", "") or "*" in str(k) -def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: - if provider_name == "google_gemini": - return GoogleSTTProvider(api_key=api_key, model_name=model_name) - raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']") \ No newline at end of file + actual_key = api_key + if is_masked(actual_key): + if not is_masked(settings.TTS_API_KEY): + actual_key = settings.TTS_API_KEY + elif provider_name == "google_gemini" and not is_masked(settings.GEMINI_API_KEY): + actual_key = settings.GEMINI_API_KEY + elif provider_name in settings.LLM_PROVIDERS and not is_masked(settings.LLM_PROVIDERS[provider_name].get("api_key")): + actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key") + + provider_cls = _tts_registry.get(provider_name) + if provider_cls: + if provider_name == "gcloud_tts": + return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs) + return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs) + + # Fallback to General LiteLLM implementation + full_model = model_name + if "/" not in full_model and provider_name not in ["google_gemini", "gcloud_tts"]: + full_model = f"{provider_name}/{model_name}" + + return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs) + +def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider: + def is_masked(k): + return not k or k in ("None", "none", "") or "*" in str(k) + + actual_key = api_key + if is_masked(actual_key): + if not is_masked(settings.STT_API_KEY): + actual_key = settings.STT_API_KEY + elif provider_name == "google_gemini" and not is_masked(settings.GEMINI_API_KEY): + actual_key = settings.GEMINI_API_KEY + elif provider_name in settings.LLM_PROVIDERS and not is_masked(settings.LLM_PROVIDERS[provider_name].get("api_key")): + actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key") + + provider_cls = _stt_registry.get(provider_name) + if provider_cls: + return provider_cls(api_key=actual_key, model_name=model_name, **kwargs) + + # Fallback to General LiteLLM implementation + full_model = model_name + if "/" not in full_model and provider_name not in ["google_gemini"]: + full_model = f"{provider_name}/{model_name}" + + return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs) \ No newline at end of file diff --git a/ai-hub/app/core/providers/llm/general.py b/ai-hub/app/core/providers/llm/general.py index 565f67c..c195c04 100644 --- a/ai-hub/app/core/providers/llm/general.py +++ b/ai-hub/app/core/providers/llm/general.py @@ -6,10 +6,16 @@ self.model_name = model_name self.api_key = api_key self.system_prompt = system_prompt - # Call the parent constructor + + # Determine max tokens dynamically via LiteLLM info max_tokens = 8000 - if model_name.startswith("gemini"): - max_tokens = 10000000 + try: + info = litellm.get_model_info(model_name) + if info and "max_tokens" in info: + max_tokens = info["max_tokens"] + except: + pass + super().__init__(model=model_name, max_tokens=max_tokens, **kwargs) def _prepare_messages(self, prompt=None, messages=None): diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py index 0bc57d6..9ca1bc7 100644 --- a/ai-hub/app/core/providers/stt/gemini.py +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -1,8 +1,7 @@ import os import aiohttp -import asyncio +import base64 import logging -import mimetypes from typing import Optional from fastapi import HTTPException from app.core.providers.base import STTProvider @@ -10,116 +9,107 @@ # Configure logging logger = logging.getLogger(__name__) + class GoogleSTTProvider(STTProvider): - """Concrete STT provider for Google Gemini API.""" + """Concrete STT provider for Google Gemini API using inline audio data.""" - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "gemini-2.5-flash" - ): - self.api_key = api_key or os.getenv("GEMINI_API_KEY") + def __init__(self, api_key: Optional[str] = None, model_name: str = 'gemini-1.5-flash', **kwargs): + self.api_key = api_key or os.getenv('GEMINI_API_KEY') if not self.api_key: - raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + raise ValueError('GEMINI_API_KEY environment variable not set or provided.') - self.model_name = model_name - self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" - self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + clean_model = model_name or 'gemini-1.5-flash' + model_id = clean_model.split('/')[-1] + self.model_name = model_id - logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + # Use v1beta — the only endpoint that supports audio inline_data with Gemini 2.x + self.api_url = ( + f'https://generativelanguage.googleapis.com/v1beta/models/' + f'{model_id}:generateContent?key={self.api_key}' + ) + + logger.debug(f"Initialized GoogleSTTProvider: model={self.model_name}") + + def _detect_mime(self, data: bytes) -> str: + """Sniff the audio byte signature to determine the real MIME type.""" + if data[:4] == b'RIFF': + return 'audio/wav' + elif data[:4] == b'\x1aE\xdf\xa3': + return 'audio/webm' + elif data[:3] == b'ID3' or (len(data) > 1 and data[:2] == b'\xff\xfb'): + return 'audio/mpeg' + elif data[:4] == b'OggS': + return 'audio/ogg' + elif len(data) > 8 and data[4:8] == b'ftyp': + return 'audio/mp4' + elif len(data) > 1 and data[:2] == b'\x1a\x45': + return 'audio/webm' + # Default: browsers record as webm + return 'audio/webm' async def transcribe_audio(self, audio_data: bytes) -> str: + """Transcribes audio using Gemini's inline_data approach (no Files API needed).""" logger.debug("Starting transcription process.") - mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" - num_bytes = len(audio_data) - logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + mime_type = self._detect_mime(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {len(audio_data)} bytes.") - try: - async with aiohttp.ClientSession() as session: - # Step 1: Start resumable upload - logger.debug("Starting resumable upload...") - start_headers = { - "x-goog-api-key": self.api_key, - "X-Goog-Upload-Protocol": "resumable", - "X-Goog-Upload-Command": "start", - "X-Goog-Upload-Header-Content-Length": str(num_bytes), - "X-Goog-Upload-Header-Content-Type": mime_type, - "Content-Type": "application/json", - } - start_payload = {"file": {"display_name": "AUDIO"}} + # Encode audio as base64 for inline submission + audio_b64 = base64.b64encode(audio_data).decode('utf-8') - async with session.post( - self.upload_url_base, - headers=start_headers, - json=start_payload - ) as resp: - logger.debug(f"Upload start response status: {resp.status}") - resp.raise_for_status() - upload_url = resp.headers.get("X-Goog-Upload-URL") - if not upload_url: - raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") - logger.debug(f"Received upload URL: {upload_url}") - - # Step 2: Upload the file - logger.debug("Uploading audio file...") - upload_headers = { - "Content-Length": str(num_bytes), - "X-Goog-Upload-Offset": "0", - "X-Goog-Upload-Command": "upload, finalize", - } - async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: - logger.debug(f"File upload response status: {resp.status}") - resp.raise_for_status() - file_info = await resp.json() - - file_name = file_info["file"]["name"].split("/")[-1] - file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" - logger.debug(f"Uploaded file URI: {file_uri}") - - # Step 3: Request transcription - logger.debug("Requesting transcription from Gemini API...") - transcription_headers = { - "x-goog-api-key": self.api_key, - "Content-Type": "application/json", - } - transcription_payload = { - "contents": [ + payload = { + "contents": [ + { + "role": "user", + "parts": [ { - "parts": [ - { - "fileData": { - "mimeType": mime_type, - "fileUri": file_uri - } - }, - {"text": "Transcribe this audio file."} - ] - } + "inline_data": { + "mime_type": mime_type, + "data": audio_b64 + } + }, + {"text": "Transcribe this audio. Return only the spoken words, nothing else."} ] } + ] + } - async with session.post( - self.api_url, - headers=transcription_headers, - json=transcription_payload - ) as resp: - logger.debug(f"Transcription request status: {resp.status}") - resp.raise_for_status() - data = await resp.json() + headers = {"Content-Type": "application/json"} + logger.debug(f"Sending inline audio to: {self.api_url}") - # Step 4: Extract text + try: + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(self.api_url, headers=headers, json=payload) as response: + logger.debug(f"Transcription response status: {response.status}") + if not response.ok: + body = await response.text() + logger.error(f"STT API error {response.status}: {body}") + raise HTTPException( + status_code=500, + detail=f"API failed ({response.status}): {body[:300]}" + ) + data = await response.json() + try: - transcript = data["candidates"][0]["content"]["parts"][0]["text"] - logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") - return transcript + candidate = data["candidates"][0] + parts = candidate.get("content", {}).get("parts", []) + if not parts: + # Gemini returns no parts for silent/empty audio - that's fine + logger.debug("Gemini returned no transcript parts (likely silence).") + return "" + transcript = parts[0].get("text", "") + logger.debug(f"Transcript: '{transcript[:80]}'") + return transcript.strip() except (KeyError, IndexError) as e: - logger.error(f"Malformed API response: {e}. Full response: {data}") - raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + logger.error(f"Malformed API response: {e}. Full: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini STT.") except aiohttp.ClientError as e: - logger.error(f"Aiohttp client error occurred: {e}") + logger.error(f"Network error during STT: {e}") raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except HTTPException: + raise except Exception as e: - logger.error(f"Unexpected error occurred during transcription: {e}") + logger.error(f"Unexpected STT error: {e}") raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/providers/stt/general.py b/ai-hub/app/core/providers/stt/general.py index 3c1c556..c05b4fc 100644 --- a/ai-hub/app/core/providers/stt/general.py +++ b/ai-hub/app/core/providers/stt/general.py @@ -1,62 +1,30 @@ -import os import litellm -import logging import io -from typing import Optional -from fastapi import HTTPException from app.core.providers.base import STTProvider -# Configure logging -logger = logging.getLogger(__name__) - class GeneralSTTProvider(STTProvider): - """Concrete General STT provider using litellm for Whisper transcription.""" - - def __init__( - self, - api_key: str, - model_name: str = "" - ): - if not api_key: - raise ValueError("API_KEY for general STT provider not set or provided.") - if not model_name: - raise ValueError("model_name for general STT provider not set or provided") - self.api_key = api_key + """General Speech-to-Text provider using LiteLLM.""" + def __init__(self, model_name: str, api_key: str, **kwargs): self.model_name = model_name - - logger.debug(f"Initialized GeneralSTTProvider with model: {self.model_name}") + self.api_key = api_key + self.kwargs = kwargs async def transcribe_audio(self, audio_data: bytes) -> str: - """ - Transcribes audio using the litellm Whisper transcription endpoint. - """ - logger.debug("Starting transcription process using litellm.transcription().") - + """Transcribes audio data using LiteLLM atranscription.""" try: - # Wrap audio bytes in a BytesIO object to mimic a file - audio_file = io.BytesIO(audio_data) - audio_file.name = "input.wav" # Required by some clients (like Whisper) - - # Call litellm.transcription (sync function, use thread executor) - import asyncio - loop = asyncio.get_event_loop() - response = await loop.run_in_executor( - None, - lambda: litellm.transcription(model=self.model_name, file=audio_file, api_key=self.api_key) + # We must pass file-like object for the LiteLLM (OpenAI) underlying handler + buffer = io.BytesIO(audio_data) + buffer.name = "audio.wav" + + response = await litellm.atranscription( + model=self.model_name, + file=buffer, + api_key=self.api_key, + **self.kwargs ) - - # Extract text - transcript = response.get("text", "") - logger.debug(f"Transcription succeeded. Text: '{transcript[:50]}...'") - return transcript - - except litellm.exceptions.AuthenticationError as e: - logger.error(f"LiteLLM authentication error: {e.message}") - raise HTTPException(status_code=401, detail="Authentication failed: Invalid API key.") - except litellm.exceptions.APIError as e: - logger.error(f"LiteLLM API error occurred: {e}") - status_code = getattr(e, "status_code", 500) - raise HTTPException(status_code=status_code, detail=f"API request failed: {e.message}") + + # Response object has 'text' for generic transcription + return getattr(response, "text", "") + except Exception as e: - logger.error(f"Unexpected error during transcription: {e}", exc_info=True) - raise HTTPException(status_code=500, detail="Failed to transcribe audio due to an unexpected error.") + raise RuntimeError(f"Failed to transcribe audio with LiteLLM for model '{self.model_name}': {e}") diff --git a/ai-hub/app/core/providers/tts/gcloud_tts.py b/ai-hub/app/core/providers/tts/gcloud_tts.py index 5b24bbb..a87f9fc 100644 --- a/ai-hub/app/core/providers/tts/gcloud_tts.py +++ b/ai-hub/app/core/providers/tts/gcloud_tts.py @@ -1,11 +1,10 @@ import os -import aiohttp +import httpx import asyncio import base64 import logging from typing import AsyncGenerator from app.core.providers.base import TTSProvider -from aiohttp import ClientResponseError from fastapi import HTTPException # Configure logging @@ -16,44 +15,34 @@ class GCloudTTSProvider(TTSProvider): # English voices # English voices + # English voices (Studio and Neural2 are the highest quality available in standard Cloud TTS) AVAILABLE_VOICES_EN = [ - "en-US-Chirp3-HD-Achernar", "en-US-Chirp3-HD-Achird", "en-US-Chirp3-HD-Algenib", - "en-US-Chirp3-HD-Algieba", "en-US-Chirp3-HD-Alnilam", "en-US-Chirp3-HD-Aoede", - "en-US-Chirp3-HD-Autonoe", "en-US-Chirp3-HD-Callirrhoe", "en-US-Chirp3-HD-Charon", - "en-US-Chirp3-HD-Despina", "en-US-Chirp3-HD-Enceladus", "en-US-Chirp3-HD-Erinome", - "en-US-Chirp3-HD-Fenrir", "en-US-Chirp3-HD-Gacrux", "en-US-Chirp3-HD-Iapetus", - "en-US-Chirp3-HD-Kore", "en-US-Chirp3-HD-Laomedeia", "en-US-Chirp3-HD-Leda", - "en-US-Chirp3-HD-Orus", "en-US-Chirp3-HD-Puck", "en-US-Chirp3-HD-Pulcherrima", - "en-US-Chirp3-HD-Rasalgethi", "en-US-Chirp3-HD-Sadachbia", "en-US-Chirp3-HD-Sadaltager", - "en-US-Chirp3-HD-Schedar", "en-US-Chirp3-HD-Sulafat", "en-US-Chirp3-HD-Umbriel", - "en-US-Chirp3-HD-Vindemiatrix", "en-US-Chirp3-HD-Zephyr", "en-US-Chirp3-HD-Zubenelgenubi" + "en-US-Studio-A", "en-US-Studio-O", "en-US-Neural2-A", + "en-US-Neural2-C", "en-US-Neural2-D", "en-US-Neural2-E", + "en-US-Neural2-F", "en-US-Neural2-G", "en-US-Neural2-H", + "en-US-Neural2-I", "en-US-Neural2-J", "en-US-Wavenet-A", + "en-US-Wavenet-B", "en-US-Wavenet-C", "en-US-Wavenet-D", + "en-US-Wavenet-E", "en-US-Wavenet-F", "en-US-Wavenet-G", + "en-US-Wavenet-H", "en-US-Wavenet-I", "en-US-Wavenet-J" ] - DEFAULT_VOICE_EN = "en-US-Chirp3-HD-Kore" + DEFAULT_VOICE_EN = "en-US-Studio-O" # Chinese voices AVAILABLE_VOICES_CMN = [ - "cmn-CN-Chirp3-HD-Achernar", "cmn-CN-Chirp3-HD-Achird", "cmn-CN-Chirp3-HD-Algenib", - "cmn-CN-Chirp3-HD-Algieba", "cmn-CN-Chirp3-HD-Alnilam", "cmn-CN-Chirp3-HD-Aoede", - "cmn-CN-Chirp3-HD-Autonoe", "cmn-CN-Chirp3-HD-Callirrhoe", "cmn-CN-Chirp3-HD-Charon", - "cmn-CN-Chirp3-HD-Despina", "cmn-CN-Chirp3-HD-Enceladus", "cmn-CN-Chirp3-HD-Erinome", - "cmn-CN-Chirp3-HD-Fenrir", "cmn-CN-Chirp3-HD-Gacrux", "cmn-CN-Chirp3-HD-Iapetus", - "cmn-CN-Chirp3-HD-Kore", "cmn-CN-Chirp3-HD-Laomedeia", "cmn-CN-Chirp3-HD-Leda", - "cmn-CN-Chirp3-HD-Orus", "cmn-CN-Chirp3-HD-Puck", "cmn-CN-Chirp3-HD-Pulcherrima", - "cmn-CN-Chirp3-HD-Rasalgethi", "cmn-CN-Chirp3-HD-Sadachbia", "cmn-CN-Chirp3-HD-Sadaltager", - "cmn-CN-Chirp3-HD-Schedar", "cmn-CN-Chirp3-HD-Sulafat", "cmn-CN-Chirp3-HD-Umbriel", - "cmn-CN-Chirp3-HD-Vindemiatrix", "cmn-CN-Chirp3-HD-Zephyr", "cmn-CN-Chirp3-HD-Zubenelgenubi" + "cmn-CN-Wavenet-A", "cmn-CN-Wavenet-B", "cmn-CN-Wavenet-C", + "cmn-CN-Wavenet-D", "cmn-CN-Standard-A", "cmn-CN-Standard-B", + "cmn-CN-Standard-C", "cmn-CN-Standard-D" ] - DEFAULT_VOICE_CMN = "cmn-CN-Chirp3-HD-Achernar" + DEFAULT_VOICE_CMN = "cmn-CN-Wavenet-A" - def __init__(self, api_key: str, voice_name: str = DEFAULT_VOICE_EN): - all_voices = self.AVAILABLE_VOICES_EN + self.AVAILABLE_VOICES_CMN - if voice_name not in all_voices: - raise ValueError(f"Invalid voice name: {voice_name}. Choose from {all_voices}") - + def __init__(self, api_key: str, model_name: str = "", voice_name: str = DEFAULT_VOICE_EN, **kwargs): self.api_key = api_key + # For typical TTS, it's just text:synthesize. Some new models might use different endpoints, but let's stick to standard. self.api_url = f"https://texttospeech.googleapis.com/v1/text:synthesize?key={self.api_key}" - self.voice_name = voice_name - logger.debug(f"Initialized GCloudTTSProvider with voice: {self.voice_name}") + self.voice_name = voice_name or self.DEFAULT_VOICE_EN + self.model_name = model_name + logger.debug(f"Initialized GCloudTTSProvider with voice: {self.voice_name}, model: {self.model_name}") + def _detect_language(self, text: str) -> str: # Simple heuristic: count Chinese characters vs. total chars @@ -63,22 +52,21 @@ return "en-US" async def generate_speech(self, text: str) -> bytes: - language = self._detect_language(text) - logger.debug(f"Detected language '{language}' for text: '{text[:50]}...'") + voice_to_use = self.voice_name - if language == "cmn-CN": - valid_voices = self.AVAILABLE_VOICES_CMN - default_voice = self.DEFAULT_VOICE_CMN + # Extract languageCode directly from voice name (e.g., 'en-US-Journey-F' -> 'en-US') + # Google API requires languageCode to exactly match the voice's language prefix + parts = voice_to_use.split("-") + if len(parts) >= 2: + language = f"{parts[0]}-{parts[1]}" else: - language = "en-US" - valid_voices = self.AVAILABLE_VOICES_EN - default_voice = self.DEFAULT_VOICE_EN + language = self._detect_language(text) - if self.voice_name not in valid_voices: - logger.warning(f"Voice '{self.voice_name}' not compatible with language '{language}'. Using default voice '{default_voice}'.") - voice_to_use = default_voice - else: - voice_to_use = self.voice_name + logger.debug(f"Using language '{language}' for voice: '{voice_to_use}'") + + # By default, use the user's voice_name as-is without restricting to the hardcoded list. + # This allows testing new models like 'chirp-3.0-generate-001' or 'gemini-2.5-flash-tts'. + voice_to_use = self.voice_name headers = { "Content-Type": "application/json" @@ -93,37 +81,46 @@ "audioEncoding": "LINEAR16" } } + + # Some Google Cloud STT/TTS models require the model parameter in the request body for specific endpoints, + # but for standard text:synthesize just using the name is usually enough. + # If model_name is provided, let's include it in the voice config if it acts as a model override, + # or at the top level (as seen in some v1beta documentation). + # We'll just attach it to voice config if present. + if self.model_name and self.model_name not in voice_to_use: + json_data["voice"]["name"] = self.model_name # Often the voice name IS the model name for newer ones + logger.debug(f"API Request URL: {self.api_url}") logger.debug(f"Request Payload: {json_data}") try: - async with aiohttp.ClientSession() as session: - async with session.post(self.api_url, headers=headers, json=json_data) as response: - logger.debug(f"Received API response with status code: {response.status}") - response.raise_for_status() + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(self.api_url, headers=headers, json=json_data) + + if response.status_code != 200: + logger.error(f"GCloud TTS API error ({response.status_code}): {response.text}") + if response.status_code == 429: + raise HTTPException(status_code=429, detail="Rate limit exceeded. Please try again later.") + raise HTTPException(status_code=500, detail=f"API request failed ({response.status_code}): {response.text}") - response_json = await response.json() - logger.debug("Successfully parsed API response JSON.") + response_json = response.json() + logger.debug("Successfully parsed API response JSON.") - audio_base64 = response_json.get('audioContent') - if not audio_base64: - raise KeyError("audioContent key not found in the response.") + audio_base64 = response_json.get('audioContent') + if not audio_base64: + raise KeyError("audioContent key not found in the response.") - audio_bytes = base64.b64decode(audio_base64) - logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.") + audio_bytes = base64.b64decode(audio_base64) + logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.") - return audio_bytes - except ClientResponseError as e: - if e.status == 429: - logger.error("Rate limit exceeded on Cloud TTS API.") - raise HTTPException(status_code=429, detail="Rate limit exceeded. Please try again later.") - else: - logger.error(f"Aiohttp client error occurred: {e}") - raise HTTPException(status_code=500, detail=f"API request failed: {e}") - except KeyError as e: - logger.error(f"Key error in API response: {e}. Full response: {await response.json()}") - raise HTTPException(status_code=500, detail="Malformed API response from Cloud TTS.") + return audio_bytes + + except HTTPException: + raise + except httpx.RequestError as e: + logger.error(f"Network error communicating with Google Cloud TTS: {e}") + raise HTTPException(status_code=500, detail=f"Network error: {e}") except Exception as e: logger.error(f"An unexpected error occurred during speech generation: {e}") raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") diff --git a/ai-hub/app/core/providers/tts/gemini.py b/ai-hub/app/core/providers/tts/gemini.py index 89d59e1..a59f644 100644 --- a/ai-hub/app/core/providers/tts/gemini.py +++ b/ai-hub/app/core/providers/tts/gemini.py @@ -1,53 +1,73 @@ import os -import aiohttp -import asyncio +import json +import httpx import base64 import logging -from typing import AsyncGenerator from app.core.providers.base import TTSProvider -from aiohttp import ClientResponseError from fastapi import HTTPException # Configure logging logger = logging.getLogger(__name__) -# New concrete class for Gemini TTS with the corrected voice list + class GeminiTTSProvider(TTSProvider): - # Class attribute with the corrected list of available voices + """TTS provider using Gemini's audio responseModalities via Google AI Studio.""" + AVAILABLE_VOICES = [ - "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda", - "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus", - "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome", - "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam", - "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi", + "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda", + "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus", + "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome", + "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam", + "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi", "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat" ] - def __init__(self, api_key: str, voice_name: str = "Kore", model_name: str = "gemini-2.5-flash-preview-tts"): - if voice_name not in self.AVAILABLE_VOICES: - raise ValueError(f"Invalid voice name: {voice_name}. Choose from {self.AVAILABLE_VOICES}") - + def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash-preview-tts", + voice_name: str = "Kore", **kwargs): + raw_model = model_name or "gemini-2.5-flash-preview-tts" + # Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model") → keep only the model id + model_id = raw_model.split("/")[-1] + # Normalise short names: "gemini-2-flash-tts" → "gemini-2.5-flash-preview-tts" + if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts"): + model_id = "gemini-2.5-flash-preview-tts" + logger.info(f"Normalised model name to: {model_id}") + + # Route to Vertex AI ONLY when the key is a Vertex service-account key (starting with "AQ.") + # AI Studio keys start with "AIza" and must use the generativelanguage endpoint. + is_vertex_key = bool(api_key) and api_key.startswith("AQ.") + + if is_vertex_key: + self.api_url = ( + f"https://us-central1-aiplatform.googleapis.com/v1/publishers/google/" + f"models/{model_id}:streamGenerateContent" + ) + self.is_vertex = True + else: + # Google AI Studio — v1beta is required for audio responseModalities + self.api_url = ( + f"https://generativelanguage.googleapis.com/v1beta/models/" + f"{model_id}:generateContent?key={api_key}" + ) + self.is_vertex = False + self.api_key = api_key - # The API URL is now a f-string that includes the configurable model name - self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" self.voice_name = voice_name - self.model_name = model_name - logger.debug(f"Initialized GeminiTTSProvider with model: {self.model_name}, voice: {self.voice_name}") + self.model_name = model_id + logger.debug(f"GeminiTTSProvider: model={self.model_name}, vertex={self.is_vertex}") + logger.debug(f" endpoint: {self.api_url[:80]}...") async def generate_speech(self, text: str) -> bytes: - logger.debug(f"Starting speech generation for text: '{text[:50]}...'") - - headers = { - "x-goog-api-key": self.api_key, - "Content-Type": "application/json" - } + logger.debug(f"TTS generate_speech: '{text[:60]}...'") + + headers = {"Content-Type": "application/json"} + + # The dedicated TTS models require a system instruction to produce only audio json_data = { - "contents": [{ - "parts": [{ - "text": text - }] - }], + "system_instruction": { + "parts": [{"text": "You are a text-to-speech system. Convert the user text to speech audio only. Do not generate any text response."}] + }, + "contents": [{"role": "user", "parts": [{"text": text}]}], "generationConfig": { "responseModalities": ["AUDIO"], "speechConfig": { @@ -57,41 +77,58 @@ } } } - }, - # The model is now configurable via the instance variable - "model": self.model_name + } } - - logger.debug(f"API Request URL: {self.api_url}") - logger.debug(f"Request Headers: {headers}") - logger.debug(f"Request Payload: {json_data}") + + if not self.is_vertex: + headers["x-goog-api-key"] = self.api_key + + logger.debug(f"Calling: {self.api_url}") try: - async with aiohttp.ClientSession() as session: - async with session.post(self.api_url, headers=headers, json=json_data) as response: - logger.debug(f"Received API response with status code: {response.status}") - response.raise_for_status() - - response_json = await response.json() - logger.debug("Successfully parsed API response JSON.") - - inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data'] - logger.debug("Successfully extracted audio data from JSON response.") - - audio_bytes = base64.b64decode(inline_data) - logger.debug(f"Decoded audio data, size: {len(audio_bytes)} bytes.") - - return audio_bytes - except ClientResponseError as e: - if e.status == 429: - logger.error("Rate limit exceeded on Gemini TTS API.") - raise HTTPException(status_code=429, detail="Rate limit exceeded. Please try again later.") - else: - logger.error(f"Aiohttp client error occurred: {e}") - raise HTTPException(status_code=500, detail=f"API request failed: {e}") - except KeyError as e: - logger.error(f"Key error in API response: {e}. Full response: {response_json}") - raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.post(self.api_url, headers=headers, json=json_data) + + logger.debug(f"Response status: {response.status_code}") + + if response.status_code != 200: + body = response.text + logger.error(f"TTS API error {response.status_code}: {body[:300]}") + try: + err = response.json().get("error", {}) + msg = err.get("message", body[:200]) + except Exception: + msg = body[:200] + raise HTTPException(status_code=response.status_code, detail=f"Gemini TTS error: {msg}") + + resp_data = response.json() + audio_fragments = [] + + # Handle both list (streamGenerateContent) and single object (generateContent) + segments = resp_data if isinstance(resp_data, list) else [resp_data] + for segment in segments: + candidates = segment.get("candidates", []) + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + for part in parts: + inline = part.get("inlineData", {}) + data = inline.get("data") + if data: + audio_fragments.append(base64.b64decode(data)) + + if not audio_fragments: + logger.error(f"No audio in response. Full response: {json.dumps(resp_data)[:500]}") + raise HTTPException(status_code=500, detail="No audio data in Gemini TTS response.") + + result = b"".join(audio_fragments) + logger.debug(f"TTS returned {len(result)} PCM bytes") + return result + + except HTTPException: + raise + except httpx.TimeoutException: + logger.error("Gemini TTS request timed out after 30s") + raise HTTPException(status_code=504, detail="Gemini TTS request timed out.") except Exception as e: - logger.error(f"An unexpected error occurred during speech generation: {e}") + logger.error(f"Unexpected TTS error: {type(e).__name__}: {e}") raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") \ No newline at end of file diff --git a/ai-hub/app/core/providers/tts/general.py b/ai-hub/app/core/providers/tts/general.py new file mode 100644 index 0000000..43d4d2c --- /dev/null +++ b/ai-hub/app/core/providers/tts/general.py @@ -0,0 +1,43 @@ +import litellm +import asyncio +from typing import AsyncGenerator +from app.core.providers.base import TTSProvider + +class GeneralTTSProvider(TTSProvider): + """General Text-to-Speech provider using LiteLLM.""" + def __init__(self, model_name: str, api_key: str, voice_name: str = "alloy", **kwargs): + self.model_name = model_name + self.api_key = api_key + self.voice_name = voice_name + self.kwargs = kwargs + + async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]: + """Generates speech using LiteLLM aspeech.""" + try: + # Note: litellm.aspeech returns a response object with 'stream' or 'content' + # Depending on underlying provider (OpenAI supports streaming) + response = await litellm.aspeech( + model=self.model_name, + input=text, + voice=self.voice_name, + api_key=self.api_key, + drop_params=True, + **self.kwargs + ) + + # LiteLLM's implementation often returns raw bytes for simple TTS + # or an Httpx response that we can iterate over + if hasattr(response, "content"): + yield response.content + elif hasattr(response, "iter_bytes"): + async for chunk in response.iter_bytes(): + yield chunk + elif isinstance(response, bytes): + yield response + else: + # Fallback to direct attribute + content = getattr(response, "content", b"") + yield content + + except Exception as e: + raise RuntimeError(f"Failed to generate speech with LiteLLM for model '{self.model_name}': {e}") diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 49f70c3..9d88bd1 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -52,8 +52,23 @@ db.commit() - # Get the appropriate LLM provider - llm_provider = get_llm_provider(provider_name) + # Fetch user preferences for overrides + api_key_override = None + model_name_override = "" + user = session.user + if user and user.preferences: + llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {}) + api_key_override = llm_prefs.get("api_key") + model_name_override = llm_prefs.get("model", "") + + # Get the appropriate LLM provider with all extra prefs passed as kwargs + kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} + llm_provider = get_llm_provider( + provider_name, + model_name=model_name_override, + api_key_override=api_key_override, + **kwargs + ) # Configure retrievers for the pipeline context_chunks = [] diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py index e1d558b..7325f2d 100644 --- a/ai-hub/app/core/services/stt.py +++ b/ai-hub/app/core/services/stt.py @@ -14,12 +14,13 @@ """ Initializes the STTService with a concrete STT provider. """ - self.stt_provider = stt_provider + self.default_stt_provider = stt_provider - async def transcribe(self, audio_bytes: bytes) -> str: + async def transcribe(self, audio_bytes: bytes, provider_override: STTProvider = None) -> str: """ Transcribes the provided audio bytes into text using the STT provider. """ + provider = provider_override or self.default_stt_provider logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") if not audio_bytes: @@ -27,12 +28,9 @@ raise HTTPException(status_code=400, detail="No audio data provided.") try: - transcript = await self.stt_provider.transcribe_audio(audio_bytes) - if not transcript: - logger.warning("STT provider returned an empty transcript.") - raise HTTPException(status_code=500, detail="Failed to transcribe audio.") - - logger.info(f"Successfully transcribed audio. Transcript length: {len(transcript)} characters.") + transcript = await provider.transcribe_audio(audio_bytes) + # Note: empty transcript is valid (e.g. silent audio), we let the caller decide + logger.info(f"Transcribed audio. Length: {len(transcript)} characters.") return transcript except HTTPException: diff --git a/ai-hub/app/core/services/tts.py b/ai-hub/app/core/services/tts.py index d658e6c..015b1c4 100644 --- a/ai-hub/app/core/services/tts.py +++ b/ai-hub/app/core/services/tts.py @@ -34,17 +34,21 @@ MAX_CHUNK_SIZE = int(os.getenv("TTS_MAX_CHUNK_SIZE", 600)) def __init__(self, tts_provider: TTSProvider): - self.tts_provider = tts_provider + self.default_tts_provider = tts_provider async def _split_text_into_chunks(self, text: str) -> list[str]: chunks = [] current_chunk = "" - separators = ['.', '?', '!', '\n'] + # Adding Chinese punctuation marks: 。 ? ! , + separators = ['.', '?', '!', '\n', '。', '?', '!', ','] sentences = [] + # Split text by separators, keeping the separator with the sentence + temp_text = text for separator in separators: - text = text.replace(separator, f"{separator}|") - sentences_with_empty = [s.strip() for s in text.split('|') if s.strip()] + temp_text = temp_text.replace(separator, f"{separator}|") + + sentences_with_empty = [s.strip() for s in temp_text.split('|') if s.strip()] for sentence in sentences_with_empty: sentences.append(sentence) @@ -62,31 +66,37 @@ logger.debug(f"Split text into {len(chunks)} chunks.") return chunks - async def _generate_pcm_chunks(self, text: str) -> AsyncGenerator[bytes, None]: + async def _generate_pcm_chunks(self, text: str, provider_override: TTSProvider = None) -> AsyncGenerator[bytes, None]: chunks = await self._split_text_into_chunks(text) + provider = provider_override or self.default_tts_provider for i, chunk in enumerate(chunks): logger.info(f"Generating PCM for chunk {i+1}/{len(chunks)}: '{chunk[:30]}...'") try: - pcm_data = await self.tts_provider.generate_speech(chunk) - yield pcm_data + pcm_data = await provider.generate_speech(chunk) + if pcm_data: + yield pcm_data + except HTTPException as e: + logger.error(f"HTTPException on chunk {i+1}: {e.detail}") + raise # Re-raise so StreamingResponse returns 500 cleanly except Exception as e: - logger.error(f"Error processing chunk {i+1}: {e}") + logger.error(f"Unexpected error on chunk {i+1}: {e}") raise HTTPException( status_code=500, detail=f"Error generating speech for chunk {i+1}: {e}" ) from e - async def create_speech_stream(self, text: str, as_wav: bool = True) -> AsyncGenerator[bytes, None]: - async for pcm_data in self._generate_pcm_chunks(text): + async def create_speech_stream(self, text: str, as_wav: bool = True, provider_override: TTSProvider = None) -> AsyncGenerator[bytes, None]: + async for pcm_data in self._generate_pcm_chunks(text, provider_override=provider_override): if as_wav: yield _create_wav_file(pcm_data) else: yield pcm_data - async def create_speech_non_stream(self, text: str) -> bytes: + async def create_speech_non_stream(self, text: str, provider_override: TTSProvider = None) -> bytes: chunks = await self._split_text_into_chunks(text) semaphore = asyncio.Semaphore(3) # Limit concurrency to 3 requests + provider = provider_override or self.default_tts_provider async def generate_with_limit(chunk): retries = 3 @@ -94,7 +104,7 @@ async with semaphore: for attempt in range(retries): try: - return await self.tts_provider.generate_speech(chunk) + return await provider.generate_speech(chunk) except HTTPException as e: if e.status_code == 429: logger.warning(f"429 Too Many Requests for chunk, retrying in {delay}s (attempt {attempt+1}/{retries})...") diff --git a/ai-hub/app/core/services/workspace.py b/ai-hub/app/core/services/workspace.py index c14e1d5..3ced008 100644 --- a/ai-hub/app/core/services/workspace.py +++ b/ai-hub/app/core/services/workspace.py @@ -767,7 +767,21 @@ file_request = await self._get_or_create_file_request(session_id, path, prompt) await self.send_command(websocket, "list_directory", data={"request_id": str(file_request.id)}) return - llm_provider = get_llm_provider(provider_name) + # Fetch user preferences for overrides + api_key_override = None + model_name_override = "" + user = session.user + if user and user.preferences: + llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {}) + api_key_override = llm_prefs.get("api_key") + model_name_override = llm_prefs.get("model", "") + + # Get the appropriate LLM provider + llm_provider = get_llm_provider( + provider_name, + model_name=model_name_override, + api_key_override=api_key_override + ) chat = DspyRagPipeline() with dspy.context(lm=llm_provider): answer_text = await chat(question=prompt, history=session.messages, context_chunks=[]) diff --git a/ai-hub/app/core/vector_store/embedder/factory.py b/ai-hub/app/core/vector_store/embedder/factory.py index 64fca80..14ab9e8 100644 --- a/ai-hub/app/core/vector_store/embedder/factory.py +++ b/ai-hub/app/core/vector_store/embedder/factory.py @@ -1,11 +1,12 @@ -from app.config import EmbeddingProvider from .genai import GenAIEmbedder from .mock import MockEmbedder def get_embedder_from_config(provider, dimension, model_name, api_key): - if provider == EmbeddingProvider.GOOGLE_GEMINI: + + # Use lowercase strings to match the settings values + if provider in ["google_gemini", "gemini"]: return GenAIEmbedder(model_name, api_key, dimension) - elif provider == EmbeddingProvider.MOCK: + elif provider == "mock": return MockEmbedder(dimension) else: raise ValueError(f"Unsupported embedding provider: {provider}") diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py index cf8d4eb..46d9c5c 100644 --- a/ai-hub/app/db/models.py +++ b/ai-hub/app/db/models.py @@ -27,6 +27,8 @@ username = Column(String, nullable=True) # Timestamp for when the user account was created. created_at = Column(DateTime, default=datetime.utcnow) + # User's preferences/settings (e.g. LLM/TTS/STT configs) + preferences = Column(JSON, default={}, nullable=True) # Defines a one-to-many relationship with the Session table. # 'back_populates' creates a link back to the User model from the Session model. diff --git a/ai-hub/app/db/session.py b/ai-hub/app/db/session.py index 6c90abe..417de48 100644 --- a/ai-hub/app/db/session.py +++ b/ai-hub/app/db/session.py @@ -14,6 +14,13 @@ # This prevents errors from connections that have been timed out by the DB server. engine_args["pool_pre_ping"] = True +import os + +# Create the data directory if it doesn't exist to prevent sqlite3.OperationalError +if settings.DATABASE_URL and settings.DATABASE_URL.startswith("sqlite:///"): + db_path = settings.DATABASE_URL.replace("sqlite:///", "") + os.makedirs(os.path.dirname(db_path), exist_ok=True) + # Create the SQLAlchemy engine using the centralized URL and determined arguments engine = create_engine(settings.DATABASE_URL, **engine_args) diff --git a/ai-hub/app/utils.py b/ai-hub/app/utils.py index a61449c..c17c6b1 100644 --- a/ai-hub/app/utils.py +++ b/ai-hub/app/utils.py @@ -21,7 +21,9 @@ continue if any(k in key for k in sensitive_keywords): - if isinstance(value, str) and len(value) > 8: + if not value: + masked_value = "Not Set" + elif isinstance(value, str) and len(value) > 8: masked_value = f"{value[:4]}...{value[-4:]}" else: masked_value = "***" diff --git a/docker-compose.yml b/docker-compose.yml index 3aa3a12..90990ce 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,8 +6,12 @@ build: ./ai-hub container_name: ai_hub_service restart: unless-stopped - env_file: - - ai-hub/.env.prod + environment: + - PATH_PREFIX=/api/v1 + - OIDC_CLIENT_ID=cortex-server + - OIDC_CLIENT_SECRET=aYc2j1lYUUZXkBFFUndnleZI + - OIDC_SERVER_URL=https://auth.jerxie.com + - OIDC_REDIRECT_URI=https://ai.jerxie.com/api/v1/users/login/callback volumes: # Mount the named volume to the /app/data directory in the container - ai_hub_data:/app/data:rw @@ -22,12 +26,13 @@ build: ./ui/client-app container_name: ai_frontend_service restart: unless-stopped - env_file: - - ui/client-app/.env.prod + + environment: + - REACT_APP_API_BASE_URL=https://ai.jerxie.com/api/v1 + - PORT=8000 ports: - # Map host port 8080 to container port 80 (Nginx default) - # This avoids a port conflict with the AI hub - - "8003:443" + # Map host port 8003 to container port 8000 + - "8003:8000" # Define the named volume for the AI hub's data volumes: diff --git a/local_deployment.sh b/local_deployment.sh index 41f83f5..278f298 100644 --- a/local_deployment.sh +++ b/local_deployment.sh @@ -18,9 +18,6 @@ # Set the absolute path to your project directory. PROJECT_DIR="/home/coder/project/cortex-hub" -# Set the name of your production environment file (if you are using one). -# If you are not using a separate env file, you can leave this variable empty. -PROD_ENV_FILE="ai-hub/.env" # --- Script Execution --- echo "🚀 Starting AI Hub deployment process..." @@ -40,11 +37,7 @@ # Navigate to the project directory. Exit if the directory doesn't exist. cd "$PROJECT_DIR" || { echo "Error: Project directory '$PROJECT_DIR' not found. Exiting."; exit 1; } -# Check for the existence of the production environment file, if specified. -if [ -n "$PROD_ENV_FILE" ] && [ ! -f "$PROD_ENV_FILE" ]; then - echo "Warning: Production environment file '$PROD_ENV_FILE' not found." - echo "The script will proceed, but some services may not have the correct environment variables." -fi + # Stop and remove any existing containers to ensure a clean deployment. echo "🛑 Stopping and removing old Docker containers and networks..." diff --git a/ui/client-app/src/App.js b/ui/client-app/src/App.js index 2097fad..ec2a2e9 100644 --- a/ui/client-app/src/App.js +++ b/ui/client-app/src/App.js @@ -5,7 +5,8 @@ import VoiceChatPage from "./pages/VoiceChatPage"; import CodingAssistantPage from "./pages/CodingAssistantPage"; import LoginPage from "./pages/LoginPage"; -import { getUserStatus, logout } from "./services/apiService"; +import SettingsPage from "./pages/SettingsPage"; +import { getUserStatus, logout } from "./services/apiService"; const Icon = ({ path, onClick, className }) => ( { const urlParams = new URLSearchParams(window.location.search); @@ -117,6 +118,8 @@ return ; case "coding-assistant": return ; + case "settings": + return ; case "login": return ; default: diff --git a/ui/client-app/src/components/ChatWindow.js b/ui/client-app/src/components/ChatWindow.js index d006e30..1f53477 100644 --- a/ui/client-app/src/components/ChatWindow.js +++ b/ui/client-app/src/components/ChatWindow.js @@ -4,7 +4,7 @@ import FileListComponent from "./FileList"; import DiffViewer from "./DiffViewer"; import CodeChangePlan from "./CodeChangePlan"; -import { FaRegCopy,FaCopy } from 'react-icons/fa'; // Import the copy icon +import { FaRegCopy, FaCopy } from 'react-icons/fa'; // Import the copy icon // Individual message component const ChatMessage = ({ message }) => { @@ -33,7 +33,7 @@ } } }; - const assistantMessageClasses = `p-4 rounded-lg shadow-md w-full bg-gray-200 dark:bg-gray-700 text-gray-900 dark:text-gray-100 assistant-message`; + const assistantMessageClasses = `p-4 rounded-lg shadow-md max-w-3xl bg-gray-200 dark:bg-gray-700 text-gray-900 dark:text-gray-100 assistant-message mr-auto`; const userMessageClasses = `max-w-md p-4 rounded-lg shadow-md bg-indigo-500 text-white ml-auto`; return ( @@ -47,9 +47,8 @@ {isReasoningExpanded ? "Hide Reasoning ▲" : "Show Reasoning ▼"}
{message.reasoning}
@@ -61,19 +60,19 @@ {/* Horizontal line */}
- {/* Copy Icon - positioned above the bottom line */} + {/* Copy Icon - positioned above the bottom line */} + {/* Solid icon (initially hidden) */} + + +
)} diff --git a/ui/client-app/src/components/Navbar.js b/ui/client-app/src/components/Navbar.js index 4cd1d98..3d8db68 100644 --- a/ui/client-app/src/components/Navbar.js +++ b/ui/client-app/src/components/Navbar.js @@ -12,15 +12,13 @@ return (