Newer
Older
cortex-hub / ai-hub / app / api / routes / sessions.py
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from typing import AsyncGenerator, List
from app.db import models
from app.core.pipelines.validator import Validator

def create_sessions_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/sessions`` router.

    Every route closes over ``services`` and receives a SQLAlchemy ``Session``
    from the ``get_db`` dependency.

    Error-handling convention shared by all routes: an ``HTTPException`` raised
    by a handler (or by a service it calls) propagates unchanged; any other
    exception is reported as a 500 whose detail includes the original message.
    Routes that may write to the database roll the session back before
    reporting a 500, so no partially-applied changes stay pending on it.

    Args:
        services: container providing ``session_service`` and ``rag_service``.

    Returns:
        The configured ``APIRouter`` (prefix ``/sessions``, tag ``Sessions``).
    """
    router = APIRouter(prefix="/sessions", tags=["Sessions"])

    def _rollback_quietly(db: Session) -> None:
        # Best-effort cleanup on the error path: discard pending changes, but never
        # let a failing rollback replace the exception that is being reported.
        try:
            db.rollback()
        except Exception:
            pass

    @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session")
    def create_session(
        request: schemas.SessionCreate,
        db: Session = Depends(get_db)
    ):
        """Create a chat session. Returns 400 if `user_id` or `provider_name` is missing."""
        if request.user_id is None or request.provider_name is None:
            raise HTTPException(status_code=400, detail="user_id and provider_name are required to create a session.")
        try:
            new_session = services.session_service.create_session(
                db=db,
                user_id=request.user_id,
                provider_name=request.provider_name,
                feature_name=request.feature_name
            )
            return new_session
        except HTTPException:
            # Keep status codes chosen by the service layer instead of masking them as 500.
            raise
        except Exception as e:
            _rollback_quietly(db)
            raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") from e

    @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session")
    async def chat_in_session(
        session_id: int,
        request: schemas.ChatRequest,
        db: Session = Depends(get_db)
    ):
        """Send a prompt within the session; returns the answer and the provider that produced it."""
        try:
            response_text, provider_used = await services.rag_service.chat_with_rag(
                db=db,
                session_id=session_id,
                prompt=request.prompt,
                provider_name=request.provider_name,
                load_faiss_retriever=request.load_faiss_retriever
            )
            return schemas.ChatResponse(answer=response_text, provider_used=provider_used)
        except HTTPException:
            # Keep status codes chosen by the service layer instead of masking them as 500.
            raise
        except Exception as e:
            # chat_with_rag receives the db session and may have staged writes.
            _rollback_quietly(db)
            raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") from e

    @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History")
    def get_session_messages(session_id: int, db: Session = Depends(get_db)):
        """Return the session's message history, or 404 if the session does not exist."""
        try:
            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
            # The service signals an unknown session with None (an empty list is a valid history).
            if messages is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/{session_id}/tokens", response_model=schemas.SessionTokenUsageResponse, summary="Get Session Token Usage")
    def get_session_token_usage(session_id: int, db: Session = Depends(get_db)):
        """Report how many tokens the session's history uses relative to the validator's token limit."""
        try:
            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
            if messages is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            # Skip messages without content rather than failing the whole request with a
            # TypeError from str.join (NOTE(review): assumes content may be NULL in the DB).
            combined_text = " ".join(m.content for m in messages if m.content is not None)
            validator = Validator()
            # Count tokens with the validator's own encoding so the number is comparable
            # to validator.token_limit.
            token_count = len(validator.encoding.encode(combined_text))
            token_limit = validator.token_limit
            # Guard against division by zero for a non-positive limit.
            percentage = round((token_count / token_limit) * 100, 2) if token_limit > 0 else 0.0

            return schemas.SessionTokenUsageResponse(
                token_count=token_count,
                token_limit=token_limit,
                percentage=percentage
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/", response_model=List[schemas.Session], summary="Get All Chat Sessions")
    def get_sessions(
        user_id: str,
        feature_name: str = "default",
        db: Session = Depends(get_db)
    ):
        """List the user's non-archived sessions for a feature, newest first."""
        try:
            sessions = db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                models.Session.is_archived == False  # noqa: E712 - SQLAlchemy column comparison
            ).order_by(models.Session.created_at.desc()).all()
            return sessions
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch sessions: {e}") from e

    @router.get("/{session_id}", response_model=schemas.Session, summary="Get a Single Session")
    def get_session(session_id: int, db: Session = Depends(get_db)):
        """Return a single session; archived (deleted) sessions are reported as 404."""
        try:
            session = db.query(models.Session).filter(
                models.Session.id == session_id,
                models.Session.is_archived == False  # noqa: E712 - SQLAlchemy column comparison
            ).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            return session
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}") from e

    @router.delete("/{session_id}", summary="Delete a Chat Session")
    def delete_session(session_id: int, db: Session = Depends(get_db)):
        """Soft-delete a session by marking it archived; the row itself is kept."""
        try:
            session = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            session.is_archived = True
            db.commit()
            return {"message": "Session deleted successfully."}
        except HTTPException:
            raise
        except Exception as e:
            # Don't leave the archive flag pending on the session after a failed commit.
            _rollback_quietly(db)
            raise HTTPException(status_code=500, detail=f"Failed to delete session: {e}") from e

    @router.delete("/", summary="Delete All Sessions for Feature")
    def delete_all_sessions(user_id: str, feature_name: str = "default", db: Session = Depends(get_db)):
        """Soft-delete (archive) every session of the user for the given feature."""
        try:
            sessions = db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name
            ).all()
            for session in sessions:
                session.is_archived = True
            db.commit()
            return {"message": "All sessions deleted successfully."}
        except Exception as e:
            # Don't leave archive flags pending on the session after a failed commit.
            _rollback_quietly(db)
            raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}") from e

    return router