# cortex-hub / ai-hub / app / core / services / session.py
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

from app.api import schemas
from app.db import models

logger = logging.getLogger(__name__)

class SessionService:
    def __init__(self, services=None):
        self.services = services

    def _mount_skills_to_workspace(self, db: Session, session: models.Session):
        """Standardizes workspace skill availability by linking local skill files."""
        if not session.sync_workspace_id or not self.services: return
        try:
            orchestrator = getattr(self.services, "orchestrator", None)
            tool_service = getattr(self.services, "tool_service", None)
            if not orchestrator or not orchestrator.mirror or not tool_service: return
            
            workspace_path = orchestrator.mirror.get_workspace_path(session.sync_workspace_id)
            skills_dir = os.path.join(workspace_path, ".skills")
            os.makedirs(skills_dir, exist_ok=True)
            
            tools = tool_service.get_available_tools(db, user_id=session.user_id, feature=session.feature_name)
            valid_tool_names = {t["function"]["name"] for t in tools}
            
            from app.core.skills.fs_loader import fs_loader
            from app.config import settings
            
            for fs_skill in fs_loader.get_all_skills():
                skill_name = fs_skill.get("name")
                if skill_name in valid_tool_names:
                    self._create_skill_symlink(fs_skill, skills_dir, settings.DATA_DIR)
        except Exception as e:
            logger.error(f"Failed to mount skills to workspace: {e}")

    def _create_skill_symlink(self, fs_skill: dict, skills_dir: str, data_dir: str):
        """Internal helper to safely create skill symlinks."""
        skill_name = fs_skill.get("name")
        feature = fs_skill.get("features", ["chat"])[0]
        skill_id = fs_skill.get("id", "").replace("fs-", "")
        skill_path = os.path.join(data_dir, "skills", feature, skill_id)
        link_path = os.path.join(skills_dir, skill_name)
        
        if os.path.exists(skill_path) and not os.path.exists(link_path):
            try: os.symlink(skill_path, link_path, target_is_directory=True)
            except OSError: pass

    def create_session(
        self, 
        db: Session, 
        user_id: str, 
        provider_name: str,
        model_name: str = None,
        feature_name: str = "default",
        stt_provider_name: str = None,
        tts_provider_name: str = None
    ) -> models.Session:
        try:
            new_session = models.Session(
                user_id=user_id, 
                provider_name=provider_name, 
                model_name=model_name,
                stt_provider_name=stt_provider_name,
                tts_provider_name=tts_provider_name,
                feature_name=feature_name, 
                title=f"New Chat Session"
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError as e:
            db.rollback()
            raise

    def auto_attach_default_nodes(self, db: Session, session: models.Session, request: schemas.SessionCreate):
        """Automatically attaches a user's default nodes to a new session."""
        user = db.query(models.User).filter(models.User.id == request.user_id).first()
        if not user: return session
            
        node_prefs = (user.preferences or {}).get("nodes", {})
        default_nodes = node_prefs.get("default_node_ids", [])
        node_config = node_prefs.get("data_source", {"source": "empty"})

        if request.feature_name == "swarm_control" or default_nodes:
            session.sync_workspace_id = f"session-{session.id}-{uuid.uuid4().hex[:8]}"
            db.commit()
            db.refresh(session)
            
        if default_nodes:
            session.attached_node_ids = list(default_nodes)
            session.node_sync_status = {nid: {"status": "pending", "last_sync": None} for nid in default_nodes}
            db.commit()

            orchestrator = getattr(self.services, "orchestrator", None)
            if orchestrator and orchestrator.assistant:
                for nid in default_nodes:
                    try:
                        self._trigger_orchestrator_sync(orchestrator.assistant, nid, session.sync_workspace_id, node_config)
                        if self.services.node_registry_service:
                            self.services.node_registry_service.emit(nid, "info", {
                                "message": f"Auto-attached to session {session.id}", "workspace_id": session.sync_workspace_id,
                            })
                    except Exception as e: logger.error(f"Auto-attach sync failed for {nid}: {e}")

        self._mount_skills_to_workspace(db, session)
        return session

    def attach_nodes(self, db: Session, session_id: int, request: schemas.NodeAttachRequest) -> schemas.SessionNodeStatusResponse:
        session = db.query(models.Session).filter(
            models.Session.id == session_id,
            models.Session.is_archived == False
        ).first()
        if not session:
            return None

        if not session.sync_workspace_id:
            session.sync_workspace_id = f"session-{session_id}-{uuid.uuid4().hex[:8]}"

        old_node_ids = set(session.attached_node_ids or [])
        new_node_ids = set(request.node_ids)
        detached_nodes = old_node_ids - new_node_ids
        
        session.attached_node_ids = list(request.node_ids)

        sync_status = dict(session.node_sync_status or {})
        registry = getattr(self.services, "node_registry_service", None)
        
        for nid in new_node_ids:
            if nid not in sync_status:
                sync_status[nid] = {"status": "pending", "last_sync": None}

            if registry:
                try:
                    registry.emit(
                        nid, "info",
                        {"message": f"Attached to session {session_id}", "workspace_id": session.sync_workspace_id},
                    )
                except Exception:
                    pass

        for nid in detached_nodes:
            sync_status.pop(nid, None)

        session.node_sync_status = sync_status
        flag_modified(session, "attached_node_ids")
        flag_modified(session, "node_sync_status")
        db.commit()
        db.refresh(session)

        orchestrator = getattr(self.services, "orchestrator", None)
        if not orchestrator:
            logger.warning("Orchestrator not found in ServiceContainer; cannot trigger sync.")
        else:
            try:
                assistant = orchestrator.assistant
                session.sync_config = config.model_dump()
                db.commit()
                
                if strategy_changed:
                    for nid in old_node_ids: assistant.clear_workspace(nid, session.sync_workspace_id)
                    if getattr(orchestrator, "mirror", None): orchestrator.mirror.purge(session.sync_workspace_id)
                else:
                    for nid in detached_nodes: assistant.clear_workspace(nid, session.sync_workspace_id)
                
                for nid in request.node_ids:
                    try: self._trigger_orchestrator_sync(assistant, nid, session.sync_workspace_id, session.sync_config)
                    except Exception as e: logger.error(f"Manual sync failed for {nid}: {e}")
                        
            except Exception as e:
                logger.error(f"Failed to trigger session node sync: {e}")

        self._mount_skills_to_workspace(db, session)
        return schemas.SessionNodeStatusResponse(
            session_id=session_id,
            sync_workspace_id=session.sync_workspace_id,
            nodes=[
                schemas.NodeSyncStatusEntry(
                    node_id=nid,
                    status=sync_status.get(nid, {}).get("status", "pending"),
                    last_sync=sync_status.get(nid, {}).get("last_sync"),
                )
                for nid in session.attached_node_ids
            ],
            sync_config=session.sync_config or {}
        )

    def get_token_usage(self, db: Session, session_id: int) -> schemas.SessionTokenUsageResponse:
        """Centralized token counter with effective LLM model resolution."""
        session = db.query(models.Session).filter(models.Session.id == session_id).first()
        if not session: raise HTTPException(status_code=404, detail="Session not found.")

        messages = self.services.rag_service.get_message_history(db=db, session_id=session_id)
        combined_text = " ".join([m.content for m in messages])
        
        # Resolve effective LLM model
        user = session.user
        if not user:
             from app.config import settings
             admin_email = settings.SUPER_ADMINS[0] if settings.SUPER_ADMINS else None
             user = db.query(models.User).filter(models.User.email == admin_email).first()

        provider, provider_name = self.services.preference_service.resolve_llm_provider(db, user, session.provider_name, session.model_name)
        
        from app.core.providers.factory import get_model_limit
        try:
            # Fallback model name if provider object is complex
            m_name = getattr(provider, "model", "") if hasattr(provider, "model") else (session.model_name or "")
            token_limit = get_model_limit(provider_name, model_name=m_name)
        except:
            token_limit = 100000 

        from app.core.orchestration.validator import Validator
        validator = Validator(token_limit=token_limit)
        token_count = validator.get_token_count(combined_text)
        percentage = round((token_count / token_limit) * 100, 2)

        return schemas.SessionTokenUsageResponse(token_count=token_count, token_limit=token_limit, percentage=percentage)

    def archive_session(self, db: Session, session_id: int):
        """Archives a session and purges associated node workspaces."""
        session = db.query(models.Session).filter(models.Session.id == session_id).first()
        if not session: raise HTTPException(status_code=404, detail="Session not found.")
        if session.is_locked: raise HTTPException(status_code=403, detail="Session is locked.")

        session.is_archived = True
        wid = session.sync_workspace_id
        db.commit()
        if wid: self._broadcast_workspace_purge([wid])

    def archive_all_feature_sessions(self, db: Session, user_id: str, feature_name: str) -> int:
        """Archives all non-locked sessions for a specific feature and user."""
        sessions = db.query(models.Session).filter(
            models.Session.user_id == user_id, models.Session.feature_name == feature_name,
            models.Session.is_archived == False, models.Session.is_locked == False
        ).all()
        
        wids = [s.sync_workspace_id for s in sessions if s.sync_workspace_id]
        count = len(sessions)
        for s in sessions: s.is_archived = True
        db.commit()
        
        if wids: self._broadcast_workspace_purge(wids)
        return count

    def _broadcast_workspace_purge(self, workspace_ids: List[str]):
        """Helper to send PURGE commands to all active nodes and clean up Hub local mirror."""
        import shutil
        from app.config import settings
        from app.protos import agent_pb2

        orchestrator = getattr(self.services, "orchestrator", None)
        if not orchestrator: return

        live_nodes = orchestrator.registry.list_nodes()
        for nid in live_nodes:
            live = orchestrator.registry.get_node(nid)
            if not live: continue
            for wid in workspace_ids:
                try:
                    live.send_message(agent_pb2.ServerTaskMessage(
                        file_sync=agent_pb2.FileSyncMessage(
                            session_id=wid,
                            control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE)
                        )
                    ), priority=0)
                except: pass

        for wid in workspace_ids:
            path = os.path.join(settings.DATA_DIR, "mirrors", wid)
            if os.path.exists(path):
                shutil.rmtree(path, ignore_errors=True)

    def _trigger_orchestrator_sync(self, assistant, nid, workspace_id, config):
        """Unified sync dispatcher based on configured source strategy."""
        source = config.get("source", "empty")
        path = config.get("path", ".")
        source_nid = config.get("source_node_id")
        read_only = (nid in config.get("read_only_node_ids", [])) if config.get("read_only_node_ids") else False

        if source == "server":
            assistant.push_workspace(nid, workspace_id)
            assistant.control_sync(nid, workspace_id, action="LOCK")
        elif source == "empty":
            assistant.push_workspace(nid, workspace_id)
            assistant.control_sync(nid, workspace_id, action="START")
            assistant.control_sync(nid, workspace_id, action="UNLOCK")
        elif source == "node_local":
            if source_nid == nid:
                assistant.request_manifest(nid, workspace_id, path=path)
                assistant.control_sync(nid, workspace_id, action="START", path=path)
                assistant.control_sync(nid, workspace_id, action="UNLOCK")
            else:
                assistant.control_sync(nid, workspace_id, action="START")
                assistant.control_sync(nid, workspace_id, action="LOCK")
                assistant.push_workspace(nid, workspace_id)
                
        if read_only:
            assistant.control_sync(nid, workspace_id, action="LOCK")