# cortex-hub / ai-hub / app / api / routes / sessions.py
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Response
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from typing import AsyncGenerator, List, Optional
from app.db import models
from app.core.orchestration.validator import Validator
import os
import shutil

def create_sessions_router(services: ServiceContainer) -> APIRouter:
    router = APIRouter(prefix="/sessions", tags=["Sessions"])

    @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session")
    def create_session(
        request: schemas.SessionCreate,
        db: Session = Depends(get_db)
    ):
        """
        Create a chat session and auto-attach the user's default nodes.

        Raises:
            HTTPException(400): ``user_id`` or ``provider_name`` is missing.
            HTTPException: Re-raised unchanged if the service layer raises one,
                instead of being masked as a 500.
            HTTPException(500): Any other failure; the DB transaction is rolled back.
        """
        if request.user_id is None or request.provider_name is None:
            raise HTTPException(status_code=400, detail="user_id and provider_name are required to create a session.")
        try:
            new_session = services.session_service.create_session(
                db=db,
                user_id=request.user_id,
                provider_name=request.provider_name,
                feature_name=request.feature_name,
                stt_provider_name=request.stt_provider_name,
                tts_provider_name=request.tts_provider_name
            )

            # M3: Auto-attach user's default nodes from preferences
            new_session = services.session_service.auto_attach_default_nodes(db, new_session, request)

            return new_session
        except HTTPException:
            # Deliberate HTTP errors from the service layer keep their status code.
            raise
        except Exception as e:
            # Discard the partially-applied transaction so the DB session stays usable.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") from e


    @router.post("/{session_id}/chat", summary="Send a Message in a Session (Streaming)")
    async def chat_in_session(
        session_id: int,
        request: schemas.ChatRequest,
        db: Session = Depends(get_db)
    ):
        """
        Stream the assistant's reply for ``session_id`` as Server-Sent Events.

        Every event yielded by ``rag_service.chat_with_rag`` is JSON-encoded and
        framed as an SSE ``data:`` message. If the pipeline raises, the error is
        logged and a final ``{"type": "error", "content": ...}`` event is sent
        instead of failing the HTTP response.
        """
        # A fresh prompt clears any cancellation left over from an earlier turn.
        existing = db.query(models.Session).filter(models.Session.id == session_id).first()
        if existing:
            existing.is_cancelled = False
            db.commit()

        from fastapi.responses import StreamingResponse
        import json

        def sse_frame(payload) -> str:
            # One SSE message: "data: <json>" terminated by a blank line.
            return "data: " + json.dumps(payload) + "\n\n"

        async def stream_events():
            try:
                async for item in services.rag_service.chat_with_rag(
                    db=db,
                    session_id=session_id,
                    prompt=request.prompt,
                    provider_name=request.provider_name,
                    load_faiss_retriever=request.load_faiss_retriever,
                    user_service=services.user_service
                ):
                    yield sse_frame(item)
            except Exception as exc:
                import logging
                logging.exception("Error in chat streaming")
                yield sse_frame({"type": "error", "content": str(exc)})

        return StreamingResponse(stream_events(), media_type="text/event-stream")


    @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History")
    def get_session_messages(session_id: int, db: Session = Depends(get_db)):
        """
        Return the chat history of a session.

        A message whose ``audio_path`` points at an existing file is annotated
        with ``has_audio=True`` and an ``audio_url`` for the audio endpoint.
        """
        def serialize(msg):
            # Schema dump plus optional audio metadata for the client.
            payload = schemas.Message.model_validate(msg).model_dump()
            if msg.audio_path and os.path.exists(msg.audio_path):
                payload["has_audio"] = True
                payload["audio_url"] = f"/sessions/messages/{msg.id}/audio"
            return payload

        try:
            history = services.rag_service.get_message_history(db=db, session_id=session_id)
            if history is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            serialized = [serialize(msg) for msg in history]
            return schemas.MessageHistoryResponse(session_id=session_id, messages=serialized)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}")

    @router.post("/{session_id}/clear-history", summary="Clear Chat History (Preserve Session)")
    def clear_session_history(session_id: int, db: Session = Depends(get_db)):
        """
        Delete every message in a session while keeping the session row itself.

        Node attachments, workspace ID, sync config and title are left untouched,
        so Swarm Control can start a fresh chat without losing the file-sync
        workspace.
        """
        try:
            target = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not target:
                raise HTTPException(status_code=404, detail="Session not found.")

            message_query = db.query(models.Message).filter(models.Message.session_id == session_id)
            removed_count = message_query.delete()
            db.commit()
        except HTTPException:
            raise
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to clear history: {e}")

        return {"message": f"Cleared {removed_count} messages. Session and workspace preserved."}


    @router.get("/{session_id}/tokens", response_model=schemas.SessionTokenUsageResponse, summary="Get Session Token Usage")
    def get_session_token_usage(session_id: int, db: Session = Depends(get_db)):
        """
        Report how much of the provider's token limit a session's history uses.

        All message contents are joined with single spaces, tokenized with the
        ``Validator``'s encoding, and compared with the limit returned by
        ``get_model_limit`` for the session's provider.

        Raises:
            HTTPException(404): The session does not exist.
            HTTPException(500): Any other failure.
        """
        try:
            session = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not session:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            # get_message_history may return None; treat that as an empty
            # history. A message with no content contributes an empty string
            # rather than breaking the join with a TypeError.
            messages = services.rag_service.get_message_history(db=db, session_id=session_id) or []
            combined_text = " ".join(m.content or "" for m in messages)

            # Resolve dynamic token limit from model info
            from app.core.providers.factory import get_model_limit
            token_limit = get_model_limit(session.provider_name)

            validator = Validator(token_limit=token_limit)
            token_count = len(validator.encoding.encode(combined_text))
            # Guard against a non-positive limit to avoid division by zero.
            percentage = round((token_count / token_limit) * 100, 2) if token_limit > 0 else 0.0

            return schemas.SessionTokenUsageResponse(
                token_count=token_count,
                token_limit=token_limit,
                percentage=percentage
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/", response_model=List[schemas.Session], summary="Get All Chat Sessions")
    def get_sessions(
        user_id: str,
        feature_name: str = "default",
        db: Session = Depends(get_db)
    ):
        """
        List a user's non-archived sessions for one feature, newest first.
        """
        try:
            criteria = (
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                # `== False` builds a SQL comparison; `is False` would not.
                models.Session.is_archived == False,  # noqa: E712
            )
            newest_first = models.Session.created_at.desc()
            return db.query(models.Session).filter(*criteria).order_by(newest_first).all()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch sessions: {e}")

    @router.get("/{session_id}", response_model=schemas.Session, summary="Get a Single Session")
    def get_session(session_id: int, db: Session = Depends(get_db)):
        """Fetch one session by ID; archived sessions are reported as not found."""
        try:
            found = (
                db.query(models.Session)
                .filter(
                    models.Session.id == session_id,
                    models.Session.is_archived == False,  # noqa: E712 - SQL comparison
                )
                .first()
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}")

        if not found:
            raise HTTPException(status_code=404, detail="Session not found.")
        return found

    @router.patch("/{session_id}", response_model=schemas.Session, summary="Update a Chat Session")
    def update_session(session_id: int, session_update: schemas.SessionUpdate, db: Session = Depends(get_db)):
        """
        Apply a partial update to a non-archived session.

        Only fields that are not ``None`` in ``session_update`` are written, so a
        field cannot be cleared back to ``None`` through this endpoint.

        Raises:
            HTTPException(404): The session does not exist or is archived.
            HTTPException(500): Any other failure; the transaction is rolled back.
        """
        # Client-editable columns; a None value in the request means "unchanged".
        updatable_fields = ("title", "provider_name", "stt_provider_name", "tts_provider_name")
        try:
            session = db.query(models.Session).filter(
                models.Session.id == session_id,
                models.Session.is_archived == False  # noqa: E712 - SQL comparison
            ).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")

            for field in updatable_fields:
                value = getattr(session_update, field)
                if value is not None:
                    setattr(session, field, value)

            db.commit()
            db.refresh(session)
            return session
        except HTTPException:
            raise
        except Exception as e:
            # A failed flush/commit leaves the DB session unusable until rolled back.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to update session: {e}") from e

    @router.delete("/{session_id}", summary="Delete a Chat Session")
    def delete_session(session_id: int, db: Session = Depends(get_db)):
        """
        Soft-delete a session by setting ``is_archived``.

        The row is kept in the database. Archiving an already-archived session
        also reports success, because the lookup does not filter on
        ``is_archived``.

        Raises:
            HTTPException(404): No session with this ID exists.
            HTTPException(500): Any other failure; the transaction is rolled back.
        """
        try:
            session = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            session.is_archived = True
            db.commit()
            return {"message": "Session deleted successfully."}
        except HTTPException:
            raise
        except Exception as e:
            # Keep the DB session usable after a failed commit.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to delete session: {e}") from e

    @router.delete("/", summary="Delete All Sessions for Feature")
    def delete_all_sessions(user_id: str, feature_name: str = "default", db: Session = Depends(get_db)):
        """
        Soft-delete (archive) every session a user has for one feature.

        Succeeds even when no sessions match.

        Raises:
            HTTPException(500): Any failure; the transaction is rolled back.
        """
        try:
            sessions = db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name
            ).all()
            for session in sessions:
                session.is_archived = True
            db.commit()
            return {"message": "All sessions deleted successfully."}
        except Exception as e:
            # Keep the DB session usable after a failed commit.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}") from e

    @router.post("/messages/{message_id}/audio", summary="Upload audio for a specific message")
    async def upload_message_audio(message_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)):
        """
        Store an uploaded audio file for a message and record its path.

        The upload is written to ``/app/data/audio/message_<message_id>.wav``,
        overwriting any earlier upload for the same message. The path is saved on
        ``Message.audio_path``.

        Raises:
            HTTPException(404): The message does not exist.
            HTTPException(500): Writing the file or updating the DB failed.
        """
        import logging

        try:
            message = db.query(models.Message).filter(models.Message.id == message_id).first()
            if not message:
                raise HTTPException(status_code=404, detail="Message not found.")

            # Create data directory if not exists
            audio_dir = "/app/data/audio"
            os.makedirs(audio_dir, exist_ok=True)

            # The filename comes only from the integer path parameter, so the
            # client cannot choose where the file is written.
            file_path = f"{audio_dir}/message_{message_id}.wav"
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            # Update database
            message.audio_path = file_path
            db.commit()

            return {"message": "Audio uploaded successfully.", "audio_path": file_path}
        except HTTPException:
            # Let deliberate HTTP errors (e.g. the 404 above) through instead of
            # re-reporting them as a 500.
            raise
        except Exception as e:
            logging.getLogger(__name__).exception("Error uploading audio for message %s", message_id)
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to upload audio: {e}") from e

    @router.get("/messages/{message_id}/audio", summary="Get audio for a specific message")
    async def get_message_audio(message_id: int, db: Session = Depends(get_db)):
        """
        Serve the stored WAV file for a message.

        Responds 404 when the message is unknown, has no recorded audio path,
        or the recorded file no longer exists on disk.
        """
        try:
            msg = db.query(models.Message).filter(models.Message.id == message_id).first()
            audio_path = msg.audio_path if msg else None
            if not audio_path:
                raise HTTPException(status_code=404, detail="Audio not found for this message.")
            if not os.path.exists(audio_path):
                raise HTTPException(status_code=404, detail="Audio file missing on disk.")
            return FileResponse(audio_path, media_type="audio/wav")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to get audio: {e}")

    # ==================================================================
    # M3: Session ↔ Node Attachment
    # ==================================================================

    @router.post("/{session_id}/nodes", response_model=schemas.SessionNodeStatusResponse,
                 summary="Attach Nodes to Session")
    def attach_nodes_to_session(
        session_id: int,
        request: schemas.NodeAttachRequest,
        db: Session = Depends(get_db)
    ):
        """
        Attach one or more Agent Nodes to a chat session.

        Delegates to ``session_service.attach_nodes``. A falsy result from the
        service is reported as a missing session (404).
        """
        result = services.session_service.attach_nodes(db, session_id, request)
        if result:
            return result
        raise HTTPException(status_code=404, detail="Session not found.")

    @router.delete("/{session_id}/nodes/{node_id}", summary="Detach Node from Session")
    def detach_node_from_session(
        session_id: int,
        node_id: str,
        db: Session = Depends(get_db)
    ):
        """
        Detach a node from a non-archived session and purge its workspace copy.

        The node is removed from ``attached_node_ids`` and its entry is dropped
        from ``node_sync_status``, then committed. If an orchestrator is
        configured, the session's workspace on that node is cleared afterwards.

        Raises:
            HTTPException(404): The session does not exist or is archived, or the
                node is not attached to it.
        """
        import logging

        session = db.query(models.Session).filter(
            models.Session.id == session_id,
            models.Session.is_archived == False  # noqa: E712 - SQL comparison
        ).first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found.")

        # Work on copies and reassign them, so the ORM sees new values rather
        # than in-place mutations of the stored list/dict.
        nodes = list(session.attached_node_ids or [])
        if node_id not in nodes:
            raise HTTPException(status_code=404, detail=f"Node '{node_id}' not attached to this session.")

        nodes.remove(node_id)
        session.attached_node_ids = nodes

        status = dict(session.node_sync_status or {})
        status.pop(node_id, None)
        session.node_sync_status = status

        db.commit()

        # Purge the workspace on the detached node. The detach is already
        # committed, so a purge failure must not become a 500 (a client retry
        # would then get "not attached"). Log it and report the detach.
        orchestrator = getattr(services, "orchestrator", None)
        if orchestrator:
            workspace_id = session.sync_workspace_id
            try:
                orchestrator.assistant.clear_workspace(node_id, workspace_id)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to purge workspace %s on node %s after detaching it from session %s",
                    workspace_id, node_id, session_id,
                )

        return {"message": f"Node '{node_id}' detached from session {session_id}."}

    @router.get("/{session_id}/nodes", response_model=schemas.SessionNodeStatusResponse,
                summary="Get Session Node Status")
    def get_session_nodes(session_id: int, db: Session = Depends(get_db)):
        """
        List the nodes attached to a session together with their sync status.

        The persisted ``node_sync_status`` is combined with live registry state.
        A node still recorded as "pending" that is currently present in the
        registry is reported as "connected". Otherwise the persisted status is
        used, defaulting to "pending".
        """
        session = db.query(models.Session).filter(
            models.Session.id == session_id,
            models.Session.is_archived == False  # noqa: E712 - SQL comparison
        ).first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found.")

        registry = services.node_registry_service
        saved_by_node = session.node_sync_status or {}

        def to_entry(node_id):
            is_live = registry.get_node(node_id)
            saved = saved_by_node.get(node_id, {})
            if is_live and saved.get("status") == "pending":
                effective_status = "connected"
            else:
                effective_status = saved.get("status", "pending")
            return schemas.NodeSyncStatusEntry(
                node_id=node_id,
                status=effective_status,
                last_sync=saved.get("last_sync"),
                error=saved.get("error"),
            )

        node_entries = [to_entry(node_id) for node_id in (session.attached_node_ids or [])]

        return schemas.SessionNodeStatusResponse(
            session_id=session_id,
            sync_workspace_id=session.sync_workspace_id,
            nodes=node_entries,
            sync_config=session.sync_config or {},
        )


    @router.post("/{session_id}/cancel", summary="Cancel Running AI Task")
    def cancel_session_task(
        session_id: int,
        db: Session = Depends(get_db)
    ):
        """
        Request cancellation of a session's running AI task.

        This only sets ``Session.is_cancelled`` and commits; it does not stop
        anything directly. The flag is cleared again when a new chat request
        starts on the session.
        """
        target = db.query(models.Session).filter(models.Session.id == session_id).first()
        if not target:
            raise HTTPException(status_code=404, detail="Session not found.")

        target.is_cancelled = True
        db.commit()
        return {"message": "Cancellation request sent (Watchdog will interrupt on next turn)."}

    return router