# cortex-hub / ai-hub / app / core / services / session.py
import uuid
import logging
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.exc import SQLAlchemyError
from app.db import models
from app.api import schemas

logger = logging.getLogger(__name__)

class SessionService:
    """Create chat sessions and manage the nodes attached to them.

    ``services`` is an optional service container. Every attribute read from it
    is looked up defensively: ``orchestrator`` (whose ``mirror`` and
    ``assistant`` drive workspace sync) and ``node_registry_service`` (used to
    emit informational events to nodes).
    """

    def __init__(self, services=None):
        self.services = services

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _commit(db: Session, instance=None) -> None:
        """Commit ``db`` and, if given, refresh ``instance`` from the database.

        On a database error the transaction is rolled back before the error is
        re-raised, so the SQLAlchemy session remains usable afterwards.
        """
        try:
            db.commit()
            if instance is not None:
                db.refresh(instance)
        except SQLAlchemyError:
            db.rollback()
            raise

    @staticmethod
    def _new_workspace_id(session_id) -> str:
        """Build a fresh sync workspace id: ``session-<id>-<8 random hex chars>``."""
        return f"session-{session_id}-{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _notify_node(registry, node_id, message: str, workspace_id) -> None:
        """Send a best-effort ``info`` event to a node; failures are logged, never raised."""
        if not registry:
            return
        try:
            registry.emit(node_id, "info", {"message": message, "workspace_id": workspace_id})
        except Exception as emit_err:  # notification is deliberately best-effort
            logger.warning(f"Failed to notify node {node_id}: {emit_err}")

    def _preinit_mirror(self, workspace_id) -> None:
        """Pre-initialize the server mirror for ``workspace_id`` if a mirror exists; errors are logged."""
        orchestrator = getattr(self.services, "orchestrator", None)
        mirror = getattr(orchestrator, "mirror", None)
        if not mirror:
            return
        try:
            mirror.get_workspace_path(workspace_id)
        except Exception as mirror_err:
            logger.error(f"Failed to pre-initialize server mirror: {mirror_err}")

    @staticmethod
    def _clear_node_workspace(assistant, node_id, workspace_id) -> None:
        """Ask one node to clear ``workspace_id``; a failure is logged and does not propagate."""
        try:
            assistant.clear_workspace(node_id, workspace_id)
        except Exception as clear_err:
            logger.error(f"Failed to clear workspace on node {node_id}: {clear_err}")

    @staticmethod
    def _sync_attached_node(assistant, nid, workspace_id, config) -> None:
        """Send the sync commands for one explicitly attached node according to ``config``.

        ``config.source`` selects the command sequence; for ``node_local`` only
        ``config.source_node_id`` uploads its local files while the other nodes
        are locked. Nodes listed in ``config.read_only_node_ids`` are locked
        afterwards. Exceptions propagate to the caller.
        """
        source = config.source
        if source == "server":
            assistant.push_workspace(nid, workspace_id)
            assistant.control_sync(nid, workspace_id, action="LOCK")
        elif source == "empty":
            assistant.push_workspace(nid, workspace_id)
            assistant.control_sync(nid, workspace_id, action="START")
            assistant.control_sync(nid, workspace_id, action="UNLOCK")
        elif source == "node_local":
            if config.source_node_id == nid:
                local_path = config.path or "."
                assistant.request_manifest(nid, workspace_id, path=local_path)
                assistant.control_sync(nid, workspace_id, action="START", path=local_path)
                assistant.control_sync(nid, workspace_id, action="UNLOCK")
            else:
                assistant.control_sync(nid, workspace_id, action="START")
                assistant.control_sync(nid, workspace_id, action="LOCK")
                assistant.push_workspace(nid, workspace_id)

        if config.read_only_node_ids and nid in config.read_only_node_ids:
            assistant.control_sync(nid, workspace_id, action="LOCK")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(
        self,
        db: Session,
        user_id: str,
        provider_name: str,
        feature_name: str = "default",
        stt_provider_name: str = None,
        tts_provider_name: str = None
    ) -> models.Session:
        """Insert a new session row titled "New Chat Session" and return it refreshed.

        Raises:
            SQLAlchemyError: if the insert fails; the transaction is rolled back first.
        """
        try:
            new_session = models.Session(
                user_id=user_id,
                provider_name=provider_name,
                stt_provider_name=stt_provider_name,
                tts_provider_name=tts_provider_name,
                feature_name=feature_name,
                title="New Chat Session",
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    def auto_attach_default_nodes(
        self, db: Session, session: models.Session, request: schemas.SessionCreate
    ) -> models.Session:
        """Apply the requesting user's default-node preferences to a session.

        Reads ``user.preferences["nodes"]``: ``default_node_ids`` (nodes to
        attach) and ``data_source`` (sync strategy, default
        ``{"source": "empty"}``). A sync workspace id is assigned and persisted
        when the feature is ``swarm_control`` or default nodes exist. Default
        nodes are attached with a ``pending`` status, then each one is notified
        and sent its initial sync commands. Sync failures are logged, not raised.

        Returns:
            The same ``session`` object (refreshed if it was modified).

        Raises:
            SQLAlchemyError: if persisting the session fails (after rollback).
        """
        user = db.query(models.User).filter(models.User.id == request.user_id).first()
        if not user:
            return session

        # The ``or`` fallbacks also cover keys that are present but stored as null.
        node_prefs = (user.preferences or {}).get("nodes") or {}
        default_nodes = node_prefs.get("default_node_ids") or []
        node_config = node_prefs.get("data_source") or {"source": "empty"}

        if request.feature_name != "swarm_control" and not default_nodes:
            return session

        session.sync_workspace_id = self._new_workspace_id(session.id)
        self._preinit_mirror(session.sync_workspace_id)

        if default_nodes:
            session.attached_node_ids = list(default_nodes)
            session.node_sync_status = {
                nid: {"status": "pending", "last_sync": None}
                for nid in default_nodes
            }

        # Commit in both cases: a swarm_control session without default nodes
        # would otherwise never persist the workspace id its mirror was set up for.
        self._commit(db, session)

        if default_nodes:
            self._sync_default_nodes(session, default_nodes, node_config)
        return session

    def _sync_default_nodes(self, session: models.Session, node_ids, node_config) -> None:
        """Notify each auto-attached node and trigger its initial workspace sync.

        ``node_config["source"]`` selects the commands (``server``, ``empty`` or
        ``node_local``); an unrecognised source sends none. A failure on one
        node is logged and does not stop the others.
        """
        registry = getattr(self.services, "node_registry_service", None)
        orchestrator = getattr(self.services, "orchestrator", None)
        workspace_id = session.sync_workspace_id
        try:
            assistant = orchestrator.assistant if orchestrator else None
            source = node_config.get("source", "empty")
            local_path = node_config.get("path", "") or "."
        except Exception as e:
            logger.error(f"Failed to initialize orchestrator sync: {e}")
            return

        for nid in node_ids:
            self._notify_node(registry, nid, f"Auto-attached to session {session.id}", workspace_id)
            if not assistant:
                continue
            try:
                if source == "server":
                    # NOTE(review): attach_nodes also LOCKs "server" workspaces;
                    # this path does not - confirm whether that is intentional.
                    assistant.push_workspace(nid, workspace_id)
                elif source == "empty":
                    assistant.push_workspace(nid, workspace_id)
                    assistant.control_sync(nid, workspace_id, action="START")
                    assistant.control_sync(nid, workspace_id, action="UNLOCK")
                elif source == "node_local":
                    assistant.request_manifest(nid, workspace_id, path=local_path)
                    assistant.control_sync(nid, workspace_id, action="START", path=local_path)
            except Exception as sync_err:
                logger.error(f"Failed to trigger sync for node {nid}: {sync_err}")

    def attach_nodes(self, db: Session, session_id: int, request: schemas.NodeAttachRequest) -> schemas.SessionNodeStatusResponse:
        """Replace the set of nodes attached to a non-archived session.

        Assigns a sync workspace id if the session has none. Existing per-node
        sync status is kept for nodes that stay attached, new nodes get a
        ``pending`` status, and status for detached nodes is dropped. When an
        orchestrator is available, ``request.config`` (default
        ``source="empty"``) is saved and the matching sync commands are sent.

        Returns:
            The resulting node status response, or ``None`` if there is no
            non-archived session with ``session_id``.

        Raises:
            SQLAlchemyError: if saving the node attachment fails (after rollback).
        """
        session = (
            db.query(models.Session)
            .filter(
                models.Session.id == session_id,
                models.Session.is_archived == False,  # noqa: E712 - SQL expression, not a bool test
            )
            .first()
        )
        if not session:
            return None

        if not session.sync_workspace_id:
            session.sync_workspace_id = self._new_workspace_id(session_id)
        workspace_id = session.sync_workspace_id

        old_node_ids = set(session.attached_node_ids or [])
        new_node_ids = set(request.node_ids)
        detached_nodes = old_node_ids - new_node_ids

        session.attached_node_ids = list(request.node_ids)

        sync_status = dict(session.node_sync_status or {})
        registry = getattr(self.services, "node_registry_service", None)
        for nid in new_node_ids:
            sync_status.setdefault(nid, {"status": "pending", "last_sync": None})
            self._notify_node(registry, nid, f"Attached to session {session_id}", workspace_id)
        for nid in detached_nodes:
            sync_status.pop(nid, None)

        session.node_sync_status = sync_status
        # Mark the columns dirty explicitly so the new values are always flushed.
        flag_modified(session, "attached_node_ids")
        flag_modified(session, "node_sync_status")
        self._commit(db, session)

        orchestrator = getattr(self.services, "orchestrator", None)
        if not orchestrator:
            logger.warning("Orchestrator not found in ServiceContainer; cannot trigger sync.")
        else:
            self._apply_workspace_config(db, session, orchestrator, request, old_node_ids, detached_nodes)

        return schemas.SessionNodeStatusResponse(
            session_id=session_id,
            sync_workspace_id=session.sync_workspace_id,
            nodes=[
                schemas.NodeSyncStatusEntry(
                    node_id=nid,
                    status=sync_status.get(nid, {}).get("status", "pending"),
                    last_sync=sync_status.get(nid, {}).get("last_sync"),
                )
                for nid in session.attached_node_ids
            ],
            sync_config=session.sync_config or {},
        )

    def _apply_workspace_config(self, db: Session, session, orchestrator, request, old_node_ids, detached_nodes) -> None:
        """Persist the requested sync config and push it to the attached nodes.

        If a previous config exists and its ``source``, ``path`` or
        ``source_node_id`` differs, every previously attached node's workspace
        is cleared and the server mirror is purged. Otherwise only detached
        nodes are cleared. Errors are logged, never raised, and a failure on one
        node does not stop the others.
        """
        workspace_id = session.sync_workspace_id
        try:
            assistant = orchestrator.assistant
            config = request.config or schemas.NodeWorkspaceConfig(source="empty")
            old_config = session.sync_config or {}
            strategy_changed = bool(old_config) and (
                config.source != old_config.get("source")
                or config.path != old_config.get("path")
                or config.source_node_id != old_config.get("source_node_id")
            )
            session.sync_config = config.model_dump()
            db.commit()
        except SQLAlchemyError as db_err:
            # A failed commit must be rolled back or the db session stays unusable.
            db.rollback()
            logger.error(f"Failed to save session sync config: {db_err}")
            return
        except Exception as e:
            logger.error(f"Failed to trigger session node sync: {e}")
            return

        if strategy_changed:
            for nid in old_node_ids:
                self._clear_node_workspace(assistant, nid, workspace_id)
            mirror = getattr(orchestrator, "mirror", None)
            if mirror:
                try:
                    mirror.purge(workspace_id)
                except Exception as purge_err:
                    # Stop here: pushing after a failed purge may send content
                    # from the previous strategy.
                    logger.error(f"Failed to trigger session node sync: {purge_err}")
                    return
        else:
            for nid in detached_nodes:
                self._clear_node_workspace(assistant, nid, workspace_id)

        for nid in request.node_ids:
            try:
                self._sync_attached_node(assistant, nid, workspace_id, config)
            except Exception as sync_err:
                logger.error(f"Failed to trigger sync for node {nid}: {sync_err}")