diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py
index 0803877..e54e16e 100644
--- a/ai-hub/app/api/routes/sessions.py
+++ b/ai-hub/app/api/routes/sessions.py
@@ -42,7 +42,8 @@
             session_id=session_id,
             prompt=request.prompt,
             provider_name=request.provider_name,
-            load_faiss_retriever=request.load_faiss_retriever
+            load_faiss_retriever=request.load_faiss_retriever,
+            user_service=services.user_service
         )
         return schemas.ChatResponse(answer=response_text, provider_used=provider_used, message_id=message_id)
     except Exception as e:
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index e0ded0d..78e6179 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -86,10 +86,13 @@
     if not modelName:
         modelName = settings.LLM_PROVIDERS.get(provider_name, {}).get("model")
         if not modelName:
-            if provider_name == "gemini": modelName = settings.GEMINI_MODEL_NAME
-            elif provider_name == "deepseek": modelName = settings.DEEPSEEK_MODEL_NAME
-            else:
-                raise ValueError(f"No model name provided for '{provider_name}'.")
+            # Fallback: check base type if user-selected instance is missing from global settings
+            if provider_name == "gemini": modelName = settings.GEMINI_MODEL_NAME
+            elif provider_name == "deepseek": modelName = settings.DEEPSEEK_MODEL_NAME
+            elif "gemini" in provider_name.lower(): modelName = settings.GEMINI_MODEL_NAME
+            elif "deepseek" in provider_name.lower(): modelName = settings.DEEPSEEK_MODEL_NAME
+            else:
+                raise ValueError(f"No model name provided for '{provider_name}'.")
 
     # Extract base type (e.g. 'gemini_2' -> 'gemini')
     litellm_providers = [p.value for p in litellm.LlmProviders]
diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py
index 3eb1987..9af3909 100644
--- a/ai-hub/app/core/services/rag.py
+++ b/ai-hub/app/core/services/rag.py
@@ -24,8 +24,9 @@
         session_id: int,
         prompt: str,
         provider_name: str,
-        load_faiss_retriever: bool = False
-    ) -> Tuple[str, str]:
+        load_faiss_retriever: bool = False,
+        user_service = None
+    ) -> Tuple[str, str, int]:
         """
         Processes a user prompt within a session, saves the chat history, and returns a response.
         """
@@ -49,18 +50,26 @@
 
         # Keep provider_name in sync with the model actually being used
         if session.provider_name != provider_name:
             session.provider_name = provider_name
-
+            db.commit()
 
-        # Fetch user preferences for overrides
-        api_key_override = None
-        model_name_override = ""
+        # Resolve provider: User Prefs > System Settings
         llm_prefs = {}
         user = session.user
         if user and user.preferences:
             llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {})
-            api_key_override = llm_prefs.get("api_key")
-            model_name_override = llm_prefs.get("model", "")
+
+        # System Settings Fallback
+        if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and user_service:
+            system_prefs = user_service.get_system_settings(db)
+            system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(provider_name, {})
+            if system_provider_prefs:
+                merged = system_provider_prefs.copy()
+                if llm_prefs: merged.update({k: v for k, v in llm_prefs.items() if v})
+                llm_prefs = merged
+
+        api_key_override = llm_prefs.get("api_key")
+        model_name_override = llm_prefs.get("model", "")
 
         # Get the appropriate LLM provider with all extra prefs passed as kwargs
         kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}