# cortex-hub / ai-hub / app / app.py
# app/app.py
from fastapi import FastAPI
from contextlib import asynccontextmanager
from typing import List
import litellm
import logging
logger = logging.getLogger(__name__)

# Import centralized settings and other components
from app.config import settings
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.factory import get_embedder_from_config
from app.core.providers.factory import get_tts_provider, get_stt_provider
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.db.session import create_db_and_tables
from app.api.routes.api import create_api_router
from app.utils import print_config 
from app.api.dependencies import ServiceContainer, get_db
from app.core.services.session import SessionService
from app.core.services import SessionService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService # NEW: Added the missing import for STTService
from app.core.services.user import UserService
from app.core.services.workspace import WorkspaceService # NEW: Added the missing import for WorkspaceService
# Note: The llm_clients import and initialization are removed as they
# are not used in RAGService's constructor based on your services.py
# from app.core.llm_clients import DeepSeekClient, GeminiClient
from fastapi.middleware.cors import CORSMiddleware


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manages application startup and shutdown events.

    - On startup, logs the effective configuration and creates database tables.
    - On shutdown, persists the FAISS index to disk, if one was initialized.
    """
    # Use the module logger instead of print() so startup/shutdown messages
    # follow the logging configuration set up in create_app().
    logger.info("Application startup...")
    print_config(settings)
    create_db_and_tables()
    yield
    logger.info("Application shutdown...")
    # The vector store is attached to app.state in create_app(); it may be
    # absent when RAG initialization was disabled or failed, so guard first.
    if hasattr(app.state, 'vector_store'):
        app.state.vector_store.save_index()

def _init_rag_components(app: FastAPI):
    """Initialize the optional RAG stack (embedder, FAISS store, retrievers).

    Returns:
        (vector_store, retrievers) — vector_store is None and retrievers is
        empty when EMBEDDING_PROVIDER is unset or initialization fails.
    """
    vector_store = None
    retrievers: List[Retriever] = []
    try:
        # Resolve from config/settings
        if settings.EMBEDDING_PROVIDER:
            embedder = get_embedder_from_config(
                provider=settings.EMBEDDING_PROVIDER,
                dimension=settings.EMBEDDING_DIMENSION,
                model_name=settings.EMBEDDING_MODEL_NAME,
                api_key=settings.EMBEDDING_API_KEY
            )

            vector_store = FaissVectorStore(
                index_file_path=settings.FAISS_INDEX_PATH,
                dimension=settings.EMBEDDING_DIMENSION,
                embedder=embedder
            )
            # Expose the store on app.state so the lifespan shutdown hook
            # can save the index to disk.
            app.state.vector_store = vector_store
            retrievers.append(FaissDBRetriever(vector_store=vector_store))
    except Exception as e:
        logger.error(f"Failed to initialize Vector Store: {e}. RAG functionality might be restricted.")
    return vector_store, retrievers


def _init_voice_providers():
    """Initialize optional TTS/STT providers from settings.

    Returns:
        (tts_provider, stt_provider) — either may be None when not configured,
        misconfigured (ValueError, expected when keys come later via the UI),
        or failing unexpectedly (logged as a warning).
    """
    tts_provider = None
    stt_provider = None
    try:
        if settings.TTS_PROVIDER:
            tts_provider = get_tts_provider(
                provider_name=settings.TTS_PROVIDER,
                api_key=settings.TTS_API_KEY,
                model_name=settings.TTS_MODEL_NAME,
                voice_name=settings.TTS_VOICE_NAME
            )
        if settings.STT_PROVIDER:
            stt_provider = get_stt_provider(
                provider_name=settings.STT_PROVIDER,
                api_key=settings.STT_API_KEY,
                model_name=settings.STT_MODEL_NAME
            )
    except ValueError as e:
        logger.info(f"TTS/STT will be initialized later via UI: {e}")
    except Exception as e:
        logger.warning(f"Failed to initialize TTS/STT: {e}")
    return tts_provider, stt_provider


def _build_service_container(vector_store, retrievers, tts_provider, stt_provider) -> ServiceContainer:
    """Assemble the ServiceContainer from whatever services initialized.

    Voice services are registered only when their provider is available;
    the remaining services are always registered.
    """
    services = ServiceContainer()
    services.with_rag_service(retrievers=retrievers)
    services.with_document_service(vector_store=vector_store)

    if stt_provider:
        services.with_service("stt_service", service=STTService(stt_provider=stt_provider))
    if tts_provider:
        services.with_service("tts_service", service=TTSService(tts_provider=tts_provider))

    services.with_service("workspace_service", service=WorkspaceService())
    # NOTE(review): SessionService is imported twice at the top of the file
    # (app.core.services.session and app.core.services); the later import
    # wins — confirm both refer to the same class and drop the duplicate.
    services.with_service("session_service", service=SessionService())
    services.with_service("user_service", service=UserService())
    return services


def create_app() -> FastAPI:
    """
    Factory function to create and configure the FastAPI application.
    This encapsulates all setup logic, making the main entry point clean.

    Returns:
        A fully wired FastAPI instance with routes, middleware, and services.
    """
    app = FastAPI(
        # Use metadata from the central settings
        title=settings.PROJECT_NAME,
        version=settings.VERSION,
        description="A modular API to route requests to various LLMs with RAG capabilities.",
        lifespan=lifespan
    )

    logging.basicConfig(level=settings.LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(message)s')
    logging.getLogger("dspy").setLevel(logging.DEBUG)

    # Global settings for LiteLLM to handle provider-specific quirks
    litellm.drop_params = True

    # --- Initialize core services defensively; RAG and voice are optional ---
    vector_store, retrievers = _init_rag_components(app)
    tts_provider, stt_provider = _init_voice_providers()

    # Initialize the service container with all successfully created services.
    services = _build_service_container(vector_store, retrievers, tts_provider, stt_provider)

    # Create and include the API router, injecting the services
    api_router = create_api_router(services=services)
    app.include_router(api_router)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["https://ai.jerxie.com", "http://localhost:8000", "http://localhost:8080", "http://localhost:443"],
        allow_credentials=True,
        allow_methods=["*"],  # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
        allow_headers=["*"],  # Allows all headers
    )
    return app