Newer
Older
cortex-hub / ai-hub / app / config.py
import os
import yaml
from enum import Enum
from typing import Optional
from dotenv import load_dotenv
from pydantic import BaseModel, Field, SecretStr

# Load variables from a .env file into the process environment so that the
# os.getenv() / os.environ lookups performed by Settings below can see them.
load_dotenv()

# --- 1. Define the Configuration Schema ---

class ApplicationSettings(BaseModel):
    """General application metadata; mirrors the ``application:`` YAML section."""
    project_name: str = "Cortex Hub"
    version: str = "1.0.0"
    # Log level name such as "INFO" (TODO confirm which names the logging setup accepts).
    log_level: str = "INFO"
    # Identities granted super-admin rights (presumably user emails or IDs; confirm
    # against consumers). default_factory gives every instance its own list.
    super_admins: list[str] = Field(default_factory=list)

class OIDCSettings(BaseModel):
    """OIDC client settings; mirrors the ``oidc:`` YAML section.

    Every field defaults to an empty string, i.e. "not configured".
    """
    client_id: str = ""
    # Plain str (unlike the SecretStr api_key fields in this file), so the value is
    # not masked when the model is printed or serialized.
    client_secret: str = ""
    # Base URL of the OIDC provider (TODO confirm: issuer vs. discovery URL).
    server_url: str = ""
    # Callback URL presumably registered with the provider for the auth redirect.
    redirect_uri: str = ""

class DatabaseSettings(BaseModel):
    """Database connection settings; mirrors the ``database:`` YAML section."""
    # "sqlite" makes Settings build a SQLite URL from `local_path`; any other
    # value makes it use `url` instead.
    mode: str = "sqlite"
    # Connection URL used when `mode` is not "sqlite".
    url: Optional[str] = None
    # SQLite database file path used when `mode` is "sqlite".
    local_path: str = "data/ai_hub.db"

class ProviderSettings(BaseModel):
    """Generic structure to hold any provider-specific config (api_key, model, etc.)

    Used for both the ``tts_provider:`` and ``stt_provider:`` YAML sections.
    """
    # Selected provider. Settings first looks for a raw `provider` key in the YAML
    # section, then this field, then falls back to "google_gemini".
    active_provider: Optional[str] = None
    # Per-provider config dicts, presumably keyed by provider name; the inner
    # schema is not validated here.
    providers: dict[str, dict] = Field(default_factory=dict)
    # Compatibility for top-level keys (older flat layout of the section).
    api_key: Optional[SecretStr] = None
    model_name: Optional[str] = None
    voice_name: Optional[str] = None

class EmbeddingProviderSettings(BaseModel):
    """Embedding model settings; mirrors the ``embedding_provider:`` YAML section."""
    provider: str = "google_gemini"
    model_name: str = "models/text-embedding-004"
    # SecretStr masks the key when the model is printed; None means "not set here".
    api_key: Optional[SecretStr] = None

class LLMProvidersSettings(BaseModel):
    """Holds shared API keys and per-provider overrides.

    Mirrors the ``llm_providers:`` YAML section.
    """
    # Provider id -> config dict. Settings merges `<ID>_API_KEY` / `<ID>_MODEL_NAME`
    # environment variables into these entries under the keys "api_key" / "model".
    providers: dict[str, dict] = Field(default_factory=dict)

class VectorStoreSettings(BaseModel):
    """Vector index settings; mirrors the ``vector_store:`` YAML section."""
    # Location of the FAISS index file.
    index_path: str = "data/faiss_index.bin"
    # Length of the stored embedding vectors; presumably must match the output
    # size of the configured embedding model (TODO confirm).
    embedding_dimension: int = 768

class AppConfig(BaseModel):
    """Top-level Pydantic model for application configuration.

    Each field corresponds to a top-level key of the YAML config file. Any
    section missing from the YAML is default-constructed via its default_factory.
    """
    application: ApplicationSettings = Field(default_factory=ApplicationSettings)
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    llm_providers: LLMProvidersSettings = Field(default_factory=LLMProvidersSettings)
    vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings)
    embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings)
    # TTS and STT share the same generic provider schema.
    tts_provider: ProviderSettings = Field(default_factory=ProviderSettings)
    stt_provider: ProviderSettings = Field(default_factory=ProviderSettings)
    oidc: OIDCSettings = Field(default_factory=OIDCSettings)


# --- 2. Create the Final Settings Object ---
class Settings:
    """
    Holds all application settings, validated and structured by Pydantic.
    Priority Order: Environment Variables > YAML File > Pydantic Defaults

    Values are resolved once, at construction time, and exposed as UPPER_CASE
    instance attributes. The YAML file path is taken from ``$CONFIG_PATH``
    (default ``config.yaml``); a missing file is not an error. Because every
    lookup chain is built with ``or``, an empty value (``""``, ``0``, ``[]``)
    at a higher-priority layer falls through to the next layer instead of
    overriding it.
    """

    # Fallback database URL when no usable SQLite path / external URL is configured.
    _DEFAULT_SQLITE_URL = "sqlite:///./data/ai_hub.db"

    def __init__(self):
        """Resolve every setting from env vars, the YAML file and model defaults.

        Raises:
            ValueError: if the YAML file's top level is not a mapping, or if the
                resolved embedding dimension is not an integer.
        """
        config_path = os.getenv("CONFIG_PATH", "config.yaml")
        yaml_data = {}
        if os.path.exists(config_path):
            print(f"✅ Loading configuration from {config_path}")
            # Explicit UTF-8 so loading does not depend on the platform's locale encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                yaml_data = yaml.safe_load(f) or {}
            if not isinstance(yaml_data, dict):
                raise ValueError(
                    f"Configuration file '{config_path}' must contain a mapping at the "
                    f"top level, got {type(yaml_data).__name__}."
                )
        else:
            print(f"⚠️ '{config_path}' not found. Using defaults and environment variables.")

        # Pydantic v2 renamed `parse_obj` to `model_validate` (the old name is
        # deprecated there); use whichever the installed version provides.
        validate = getattr(AppConfig, "model_validate", None) or AppConfig.parse_obj
        config_from_pydantic = validate(yaml_data)

        def get_from_yaml(keys):
            """Walk nested YAML keys; return None if any level is missing or not a dict."""
            d = yaml_data
            for key in keys:
                d = d.get(key) if isinstance(d, dict) else None
            return d

        self.PROJECT_NAME: str = os.getenv("PROJECT_NAME") or \
                                 get_from_yaml(["application", "project_name"]) or \
                                 config_from_pydantic.application.project_name
        self.VERSION: str = config_from_pydantic.application.version
        self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \
                                 get_from_yaml(["application", "log_level"]) or \
                                 config_from_pydantic.application.log_level
        self.SUPER_ADMINS: list[str] = get_from_yaml(["application", "super_admins"]) or \
                                        config_from_pydantic.application.super_admins

        # --- OIDC Settings ---
        self.OIDC_CLIENT_ID: str = os.getenv("OIDC_CLIENT_ID") or \
                                   get_from_yaml(["oidc", "client_id"]) or \
                                   config_from_pydantic.oidc.client_id
        self.OIDC_CLIENT_SECRET: str = os.getenv("OIDC_CLIENT_SECRET") or \
                                       get_from_yaml(["oidc", "client_secret"]) or \
                                       config_from_pydantic.oidc.client_secret
        self.OIDC_SERVER_URL: str = os.getenv("OIDC_SERVER_URL") or \
                                     get_from_yaml(["oidc", "server_url"]) or \
                                     config_from_pydantic.oidc.server_url
        self.OIDC_REDIRECT_URI: str = os.getenv("OIDC_REDIRECT_URI") or \
                                       get_from_yaml(["oidc", "redirect_uri"]) or \
                                       config_from_pydantic.oidc.redirect_uri

        # --- Database Settings ---
        self.DB_MODE: str = os.getenv("DB_MODE") or \
                                get_from_yaml(["database", "mode"]) or \
                                config_from_pydantic.database.mode

        local_db_path = os.getenv("LOCAL_DB_PATH") or \
                                 get_from_yaml(["database", "local_path"]) or \
                                 config_from_pydantic.database.local_path
        # Kept so save_to_yaml() can persist the configured path in every DB mode.
        self.LOCAL_DB_PATH: str = local_db_path
        external_db_url = os.getenv("DATABASE_URL") or \
                                 get_from_yaml(["database", "url"]) or \
                                 config_from_pydantic.database.url

        if self.DB_MODE == "sqlite":
            self.DATABASE_URL: str = self._sqlite_url(local_db_path)
        else:
            self.DATABASE_URL: str = external_db_url or self._DEFAULT_SQLITE_URL

        # --- Agnostic Provider Resolution ---
        # We store everything in a flat map for the legacy settings getters,
        # but also provide a dynamic map.

        # 1. Resolve LLM Providers
        self.LLM_PROVIDERS = config_from_pydantic.llm_providers.providers or {}
        # Merge legacy "<PROVIDER>_API_KEY" / "<PROVIDER>_MODEL_NAME" environment
        # variables into the providers map (env wins over YAML). Variables whose
        # name mentions TTS, STT or EMBEDDING belong to those services and are skipped.
        env_suffix_to_key = {"_API_KEY": "api_key", "_MODEL_NAME": "model"}
        for env_key, env_val in os.environ.items():
            if any(x in env_key for x in ("TTS", "STT", "EMBEDDING")):
                continue
            for suffix, target_key in env_suffix_to_key.items():
                if env_key.endswith(suffix):
                    provider_id = env_key.removesuffix(suffix).lower()
                    self.LLM_PROVIDERS.setdefault(provider_id, {})[target_key] = env_val

        # Explicit legacy fallback helpers (still useful for factory.py initial state)
        self.DEEPSEEK_API_KEY = self.LLM_PROVIDERS.get("deepseek", {}).get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        self.GEMINI_API_KEY = self.LLM_PROVIDERS.get("gemini", {}).get("api_key") or os.getenv("GEMINI_API_KEY")
        self.OPENAI_API_KEY = self.LLM_PROVIDERS.get("openai", {}).get("api_key") or os.getenv("OPENAI_API_KEY")

        self.DEEPSEEK_MODEL_NAME = self.LLM_PROVIDERS.get("deepseek", {}).get("model") or \
                                    get_from_yaml(["llm_providers", "deepseek_model_name"]) or "deepseek-chat"
        self.GEMINI_MODEL_NAME = self.LLM_PROVIDERS.get("gemini", {}).get("model") or \
                                   get_from_yaml(["llm_providers", "gemini_model_name"]) or "gemini-1.5-flash-latest"

        # 2. Resolve Vector / Embedding
        self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
                                             get_from_yaml(["vector_store", "index_path"]) or \
                                             config_from_pydantic.vector_store.index_path
        dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
                                 get_from_yaml(["vector_store", "embedding_dimension"]) or \
                                 config_from_pydantic.vector_store.embedding_dimension
        try:
            self.EMBEDDING_DIMENSION: int = int(dimension_str)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"EMBEDDING_DIMENSION must be an integer, got {dimension_str!r}"
            ) from err

        self.EMBEDDING_PROVIDER: str = os.getenv("EMBEDDING_PROVIDER") or \
                                        get_from_yaml(["embedding_provider", "provider"]) or \
                                        config_from_pydantic.embedding_provider.provider
        self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
                                             get_from_yaml(["embedding_provider", "model_name"]) or \
                                             config_from_pydantic.embedding_provider.model_name
        self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \
                                                 get_from_yaml(["embedding_provider", "api_key"]) or \
                                                 self.GEMINI_API_KEY

        # 3. Resolve TTS (Agnostic)
        self.TTS_PROVIDER: str = os.getenv("TTS_PROVIDER") or \
                                 get_from_yaml(["tts_provider", "provider"]) or \
                                 config_from_pydantic.tts_provider.active_provider or "google_gemini"

        # Legacy back-compat fields
        self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
                                     get_from_yaml(["tts_provider", "voice_name"]) or \
                                     config_from_pydantic.tts_provider.voice_name or "Kore"
        self.TTS_MODEL_NAME: str = os.getenv("TTS_MODEL_NAME") or \
                                     get_from_yaml(["tts_provider", "model_name"]) or \
                                     config_from_pydantic.tts_provider.model_name or "gemini-2.5-flash-preview-tts"
        self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
                                          get_from_yaml(["tts_provider", "api_key"]) or \
                                          self.GEMINI_API_KEY

        # 4. Resolve STT (Agnostic)
        self.STT_PROVIDER: str = os.getenv("STT_PROVIDER") or \
                                 get_from_yaml(["stt_provider", "provider"]) or \
                                 config_from_pydantic.stt_provider.active_provider or "google_gemini"

        self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \
                                     get_from_yaml(["stt_provider", "model_name"]) or \
                                     config_from_pydantic.stt_provider.model_name or "gemini-2.5-flash"
        self.STT_API_KEY: Optional[str] = os.getenv("STT_API_KEY") or \
                                          get_from_yaml(["stt_provider", "api_key"]) or \
                                          self.GEMINI_API_KEY

    @classmethod
    def _sqlite_url(cls, local_path) -> str:
        """Build a ``sqlite:///`` URL for *local_path*.

        Relative paths are anchored at the working directory (``sqlite:///./...``)
        after dropping leading ``./`` segments. Only the literal ``./`` prefix is
        removed; ``str.lstrip("./")`` would also eat ``../`` and leading dots.
        Absolute paths stay absolute (``sqlite:////abs/path`` on POSIX). An empty
        path, or one that reduces to the current directory, yields the default URL.
        """
        path = str(local_path or "")
        if path and os.path.isabs(path):
            return f"sqlite:///{path}"
        while path.startswith("./"):
            path = path[2:]
        if path in ("", "."):
            return cls._DEFAULT_SQLITE_URL
        return f"sqlite:///./{path}"

    def _local_path_for_yaml(self) -> str:
        """Return the ``database.local_path`` value to persist.

        In SQLite mode the path is recovered from ``DATABASE_URL`` (so runtime
        changes to the URL are honored, including absolute paths). Otherwise the
        resolved local path is kept, so switching back to SQLite later still
        finds the same file.
        """
        url = self.DATABASE_URL or ""
        prefix = "sqlite:///"
        if self.DB_MODE == "sqlite" and url.startswith(prefix):
            path = url[len(prefix):].removeprefix("./")
            if path:
                return path
        return self.LOCAL_DB_PATH

    def save_to_yaml(self):
        """Saves current settings back to config.yaml (or ``$CONFIG_PATH``).

        Writes the ``application``, ``database`` (``mode``/``local_path``),
        ``vector_store`` and ``oidc`` sections from the current attribute values,
        merged key by key into the existing file. Anything else already in the
        file (e.g. ``llm_providers``, ``embedding_provider``, ``tts_provider``,
        ``stt_provider``, ``database.url``) is preserved rather than discarded.

        Note: values that were resolved from environment variables are written
        too, so they become YAML values on the next load.
        """
        config_path = os.getenv("CONFIG_PATH", "config.yaml")

        # Start from what is on disk so sections this method does not manage survive.
        data = {}
        if os.path.exists(config_path):
            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    loaded = yaml.safe_load(f)
            except yaml.YAMLError as err:
                # Unparseable file: fall back to writing a fresh one, as before.
                print(f"⚠️ Could not parse existing '{config_path}' ({err}); rewriting it.")
                loaded = None
            if isinstance(loaded, dict):
                data = loaded

        # Build data dictionary by mapping current class attributes back to YAML structure
        # This keeps the sync logic centralized in this class
        managed = {
            "application": {
                "project_name": self.PROJECT_NAME,
                "version": self.VERSION,
                "log_level": self.LOG_LEVEL,
                "super_admins": self.SUPER_ADMINS
            },
            "database": {
                "mode": self.DB_MODE,
                "local_path": self._local_path_for_yaml()
            },
            "vector_store": {
                "index_path": self.FAISS_INDEX_PATH,
                "embedding_dimension": self.EMBEDDING_DIMENSION
            },
            "oidc": {
                "client_id": self.OIDC_CLIENT_ID,
                "client_secret": self.OIDC_CLIENT_SECRET,
                "server_url": self.OIDC_SERVER_URL,
                "redirect_uri": self.OIDC_REDIRECT_URI
            }
        }
        for section, values in managed.items():
            current = data.get(section)
            if not isinstance(current, dict):
                current = {}
            current.update(values)
            data[section] = current

        # Ensure directories exist
        os.makedirs(os.path.dirname(os.path.abspath(config_path)), exist_ok=True)

        # safe_dump guarantees the file can be read back by yaml.safe_load in __init__.
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(data, f, sort_keys=False, default_flow_style=False)
        print(f"🏠 Configuration synchronized to {config_path}")

# Instantiate the single settings object for the application.
# This runs at import time: Settings() reads $CONFIG_PATH (default "config.yaml")
# and the process environment, and prints a status line about the config file.
settings = Settings()