# app/app.py
import asyncio
import logging
import litellm
from fastapi import FastAPI
from contextlib import asynccontextmanager
from typing import List
logger = logging.getLogger(__name__)
# Import centralized settings and other components
from app.config import settings
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.vector_store.embedder.factory import get_embedder_from_config
from app.core.providers.factory import get_tts_provider, get_stt_provider
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.db.session import create_db_and_tables
from app.db.migrate import run_migrations
from app.api.routes.api import create_api_router
from app.utils import print_config
from app.api.dependencies import ServiceContainer, get_db
from app.core.services.session import SessionService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService # NEW: Added the missing import for STTService
from app.core.services.user import UserService
from app.core.services.prompt import PromptService
from app.core.services.tool import ToolService
from app.core.services.node_registry import NodeRegistryService
# Note: the llm_clients import and client initialization were removed because
# RAGService's constructor (see services.py) does not take LLM clients.
# from app.core.llm_clients import DeepSeekClient, GeminiClient
from fastapi.middleware.cors import CORSMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application startup and shutdown.

    Startup, in order:
      - print the active configuration, create DB tables and run migrations;
      - reset node statuses via ``node_registry_service`` (best-effort);
      - start the gRPC orchestrator on port 50051, expose it on ``app.state``
        and in the service container, and launch the periodic Ghost Mirror
        cleanup task (best-effort);
      - bootstrap system skills (best-effort).

    Shutdown, in order:
      - cancel the Ghost Mirror cleanup task, if it was started;
      - stop the gRPC server, if it was started (a failure here is logged and
        does not prevent the next step);
      - save the FAISS index, if a vector store was initialized.

    Expects ``app.state.services`` to have been set by ``create_app``.
    """
    print("Application startup...")
    print_config(settings)
    create_db_and_tables()
    run_migrations()

    # --- Reset Node Statuses (Zombie Fix) ---
    try:
        app.state.services.node_registry_service.reset_all_statuses()
    except Exception as e:
        logger.warning(f"Failed to reset node statuses: {e}")

    # --- Start gRPC Orchestrator (M6) ---
    try:
        from app.core.grpc.services.grpc_server import serve_grpc
        registry = app.state.services.node_registry_service
        server, orchestrator = serve_grpc(registry, port=50051)
        app.state.grpc_server = server
        app.state.orchestrator = orchestrator
        app.state.services.with_service("orchestrator", orchestrator)
        logger.info("[M6] Agent Orchestrator gRPC server started on port 50051.")
        # Launch periodic Ghost Mirror cleanup. Keep a strong reference: the
        # event loop only holds weak references to tasks, so an unreferenced
        # task can be garbage-collected before it finishes. The reference is
        # also what lets shutdown cancel it.
        app.state.mirror_cleanup_task = asyncio.create_task(
            _periodic_mirror_cleanup(orchestrator)
        )
    except Exception as e:
        logger.exception(f"[M6] Failed to start gRPC server: {e}")

    # --- Bootstrap System Skills ---
    try:
        from app.core.skills.bootstrap import bootstrap_system_skills
        # Use the context manager to ensure session is closed
        from app.db.session import get_db_session
        with get_db_session() as db:
            bootstrap_system_skills(db)
    except Exception as e:
        logger.exception(f"Failed to bootstrap system skills: {e}")

    yield

    print("Application shutdown...")

    # --- Stop Ghost Mirror cleanup ---
    cleanup_task = getattr(app.state, "mirror_cleanup_task", None)
    if cleanup_task is not None:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass  # expected: we just cancelled it

    # --- Stop gRPC Orchestrator ---
    if hasattr(app.state, 'grpc_server'):
        logger.info("[M6] Stopping gRPC server...")
        try:
            # NOTE(review): if serve_grpc returns a grpc.aio server, stop()
            # returns a coroutine that must be awaited — confirm.
            app.state.grpc_server.stop(0)
        except Exception as e:
            # Don't let a failed gRPC stop skip persisting the FAISS index.
            logger.error(f"[M6] Failed to stop gRPC server: {e}")

    # Access the vector_store from the application state to save it
    if hasattr(app.state, 'vector_store'):
        app.state.vector_store.save_index()
async def _periodic_mirror_cleanup(orchestrator, initial_delay: float = 10, interval: float = 3600):
    """
    Periodically purge orphaned Ghost Mirror directories from the server.

    After ``initial_delay`` seconds, and then every ``interval`` seconds, one
    pass is run:
      - collect the distinct, non-null ``sync_workspace_id`` values of all
        ``Session`` rows;
      - pass them to ``orchestrator.mirror.purge_orphaned``.

    An ``Exception`` raised during a pass is logged with its traceback, and
    the loop continues. The coroutine runs until its task is cancelled.

    Args:
        orchestrator: the orchestrator returned by ``serve_grpc``. It is
            expected to expose a ``mirror`` attribute; if it does not, a
            warning is logged on each pass.
        initial_delay: seconds to wait before the first pass (default 10).
        interval: seconds to wait between passes (default 3600, i.e. hourly).
    """
    await asyncio.sleep(initial_delay)  # Initial delay to let DB settle
    while True:
        try:
            # Imported lazily (and inside the try) so an import problem is
            # logged and retried on the next pass rather than killing the task.
            from app.db.session import get_db_session
            from app.db import models
            with get_db_session() as db:
                # Fetch only the distinct, non-null workspace ids; there is no
                # need to load full Session rows for this.
                rows = (
                    db.query(models.Session.sync_workspace_id)
                    .filter(models.Session.sync_workspace_id.isnot(None))
                    .distinct()
                    .all()
                )
                active_ids = [workspace_id for (workspace_id,) in rows]
            # The purge runs after the DB session is released, so a slow
            # filesystem cleanup does not keep a connection checked out.
            # NOTE(review): the query and the purge run synchronously on the
            # event-loop thread; consider offloading if purge_orphaned is slow.
            if hasattr(orchestrator, 'mirror'):
                orchestrator.mirror.purge_orphaned(active_ids)
            else:
                logger.warning("[๐๐งน] Orchestrator missing mirror manager during cleanup pass.")
        except Exception as e:
            logger.exception(f"[๐๐งน] Ghost Mirror periodic cleanup fail: {e}")
        await asyncio.sleep(interval)
def _init_vector_store(app: FastAPI):
    """
    Build the optional RAG components: embedder, FAISS vector store and retriever.

    Does nothing when ``settings.EMBEDDING_PROVIDER`` is falsy. On success, the
    store is also published as ``app.state.vector_store`` so ``lifespan`` can
    save the index on shutdown. Any exception is logged and swallowed, so the
    application still starts without RAG.

    Returns:
        ``(vector_store, retrievers)``. ``vector_store`` is ``None`` unless it
        was constructed. ``retrievers`` contains the FAISS retriever on success
        and is empty otherwise.
    """
    vector_store = None
    retrievers: List[Retriever] = []
    try:
        # Resolve from config/settings
        if settings.EMBEDDING_PROVIDER:
            embedder = get_embedder_from_config(
                provider=settings.EMBEDDING_PROVIDER,
                dimension=settings.EMBEDDING_DIMENSION,
                model_name=settings.EMBEDDING_MODEL_NAME,
                api_key=settings.EMBEDDING_API_KEY
            )
            vector_store = FaissVectorStore(
                index_file_path=settings.FAISS_INDEX_PATH,
                dimension=settings.EMBEDDING_DIMENSION,
                embedder=embedder
            )
            app.state.vector_store = vector_store
            retrievers.append(FaissDBRetriever(vector_store=vector_store))
    except Exception as e:
        logger.error(f"Failed to initialize Vector Store: {e}. RAG functionality might be restricted.")
    return vector_store, retrievers


def _try_init_voice_provider(build):
    """Call ``build()`` and return its result; log and return ``None`` if it raises."""
    try:
        return build()
    except ValueError as e:
        # A ValueError is treated as "not configured yet": per the message,
        # the provider can be set up later via the UI.
        logger.info(f"TTS/STT will be initialized later via UI: {e}")
    except Exception as e:
        logger.warning(f"Failed to initialize TTS/STT: {e}")
    return None


def _init_voice_providers():
    """
    Build the optional TTS and STT providers from settings.

    Each provider is attempted independently, so a misconfigured TTS provider
    does not prevent STT from being initialized (and vice versa). A provider
    is ``None`` when its ``*_PROVIDER`` setting is empty or construction fails.

    Returns:
        ``(tts_provider, stt_provider)``.
    """
    def build_tts():
        if not settings.TTS_PROVIDER:
            return None
        return get_tts_provider(
            provider_name=settings.TTS_PROVIDER,
            api_key=settings.TTS_API_KEY,
            model_name=settings.TTS_MODEL_NAME,
            voice_name=settings.TTS_VOICE_NAME
        )

    def build_stt():
        if not settings.STT_PROVIDER:
            return None
        return get_stt_provider(
            provider_name=settings.STT_PROVIDER,
            api_key=settings.STT_API_KEY,
            model_name=settings.STT_MODEL_NAME
        )

    return _try_init_voice_provider(build_tts), _try_init_voice_provider(build_stt)


def _build_cors_origins():
    """
    Return the list of allowed CORS origins.

    Origins come from the comma-separated ``CORS_ORIGINS`` environment variable
    (with a localhost default), plus ``HUB_PUBLIC_URL`` when set. Each entry is
    normalized before use:
      - surrounding whitespace is removed;
      - a trailing ``/`` is removed, because a browser ``Origin`` header never
        ends with one and entries are compared exactly;
      - empty entries are dropped;
      - duplicates are dropped.
    """
    import os
    raw = os.getenv("CORS_ORIGINS", "http://localhost:8000,http://localhost:8080,http://localhost:3000")
    candidates = raw.split(",") + [os.getenv("HUB_PUBLIC_URL") or ""]
    origins = []
    for entry in candidates:
        origin = entry.strip().rstrip("/")
        if origin and origin not in origins:
            origins.append(origin)
    return origins


def create_app() -> FastAPI:
    """
    Factory function to create and configure the FastAPI application.

    Steps:
      - construct the app with ``lifespan``;
      - configure logging and LiteLLM;
      - initialize the optional RAG and voice components (failures are logged,
        not raised);
      - build the ``ServiceContainer`` and store it on ``app.state.services``;
      - mount the API router;
      - install the CORS middleware.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        # Use metadata from the central settings
        title=settings.PROJECT_NAME,
        version=settings.VERSION,
        description="A modular API to route requests to various LLMs with RAG capabilities.",
        lifespan=lifespan
    )
    logging.basicConfig(level=settings.LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(message)s')
    # Global settings for LiteLLM to handle provider-specific quirks
    litellm.drop_params = True

    # --- Initialize optional components defensively ---
    vector_store, retrievers = _init_vector_store(app)
    tts_provider, stt_provider = _init_voice_providers()

    prompt_service = PromptService()

    # --- Initialize the Service Container ---
    services = ServiceContainer()
    services.with_document_service(vector_store=vector_store)
    node_registry_service = NodeRegistryService()
    services.with_service("node_registry_service", service=node_registry_service)
    # ToolService is given the container itself; it is registered before the
    # RAG service, which takes it as a dependency.
    tool_service = ToolService(services=services)
    services.with_service("tool_service", service=tool_service)
    services.with_rag_service(
        retrievers=retrievers,
        prompt_service=prompt_service,
        tool_service=tool_service,
        node_registry_service=node_registry_service
    )
    services.with_service("stt_service", service=STTService(stt_provider=stt_provider))
    services.with_service("tts_service", service=TTSService(tts_provider=tts_provider))
    services.with_service("prompt_service", service=prompt_service)
    services.with_service("session_service", service=SessionService())
    services.with_service("user_service", service=UserService())
    # lifespan reads app.state.services on startup.
    app.state.services = services

    # Create and include the API router, injecting the service container
    api_router = create_api_router(services=services)
    app.include_router(api_router)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=_build_cors_origins(),
        allow_credentials=True,
        allow_methods=["*"],  # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
        allow_headers=["*"],  # Allows all headers
    )
    return app