from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from typing import AsyncGenerator, List
from app.db import models
from app.core.pipelines.validator import Validator
def create_sessions_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/sessions`` API router.

    Every handler closes over ``services`` (using its ``session_service`` and
    ``rag_service`` members) and receives a SQLAlchemy ``Session`` from the
    ``get_db`` dependency.

    Error-handling convention shared by all handlers: an ``HTTPException``
    raised deliberately (by the handler or by a service it calls) propagates
    unchanged; any other exception is converted into a 500 response whose
    ``detail`` includes the original error text.

    Args:
        services: Container providing the session and RAG services.

    Returns:
        The configured ``APIRouter`` mounted under ``/sessions``.
    """
    router = APIRouter(prefix="/sessions", tags=["Sessions"])

    def _get_history_or_404(db: Session, session_id: int):
        """Return the message history of ``session_id``, or raise a 404.

        ``rag_service.get_message_history`` returning ``None`` is treated as
        "the session does not exist".
        """
        messages = services.rag_service.get_message_history(db=db, session_id=session_id)
        if messages is None:
            raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
        return messages

    @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session")
    def create_session(
        request: schemas.SessionCreate,
        db: Session = Depends(get_db),
    ):
        """Create a chat session via ``session_service.create_session``.

        Raises:
            HTTPException: 400 if ``user_id`` or ``provider_name`` is missing;
                500 on any unexpected failure.
        """
        if request.user_id is None or request.provider_name is None:
            raise HTTPException(status_code=400, detail="user_id and provider_name are required to create a session.")
        try:
            return services.session_service.create_session(
                db=db,
                user_id=request.user_id,
                provider_name=request.provider_name,
                feature_name=request.feature_name,
            )
        except HTTPException:
            raise
        except Exception as e:
            # A failed flush/commit leaves the SQLAlchemy session unusable until
            # it is rolled back; reset it before reporting the error.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") from e

    @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session")
    async def chat_in_session(
        session_id: int,
        request: schemas.ChatRequest,
        db: Session = Depends(get_db),
    ):
        """Send ``request.prompt`` to the session through the RAG service.

        Returns:
            ``ChatResponse`` with the model answer and the provider that was used.

        Raises:
            HTTPException: re-raised as-is when the RAG service raises one;
                500 on any other failure.
        """
        try:
            response_text, provider_used = await services.rag_service.chat_with_rag(
                db=db,
                session_id=session_id,
                prompt=request.prompt,
                provider_name=request.provider_name,
                load_faiss_retriever=request.load_faiss_retriever,
            )
        except HTTPException:
            # Preserve deliberate status codes (e.g. 404) instead of masking them as 500.
            raise
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") from e
        return schemas.ChatResponse(answer=response_text, provider_used=provider_used)

    @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History")
    def get_session_messages(session_id: int, db: Session = Depends(get_db)):
        """Return the full message history of a session.

        Raises:
            HTTPException: 404 if the session is unknown; 500 on other failures.
        """
        try:
            messages = _get_history_or_404(db, session_id)
            return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/{session_id}/tokens", response_model=schemas.SessionTokenUsageResponse, summary="Get Session Token Usage")
    def get_session_token_usage(session_id: int, db: Session = Depends(get_db)):
        """Report how many tokens the session's history uses versus the limit.

        All message contents are joined with single spaces and encoded with the
        ``Validator``'s encoding. ``percentage`` is rounded to two decimals and
        is 0.0 when the limit is not positive (avoids division by zero).

        Raises:
            HTTPException: 404 if the session is unknown; 500 on other failures.
        """
        try:
            messages = _get_history_or_404(db, session_id)
            combined_text = " ".join(m.content for m in messages)
            validator = Validator()
            token_count = len(validator.encoding.encode(combined_text))
            token_limit = validator.token_limit
            percentage = round((token_count / token_limit) * 100, 2) if token_limit > 0 else 0.0
            return schemas.SessionTokenUsageResponse(
                token_count=token_count,
                token_limit=token_limit,
                percentage=percentage,
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/", response_model=List[schemas.Session], summary="Get All Chat Sessions")
    def get_sessions(
        user_id: str,
        feature_name: str = "default",
        db: Session = Depends(get_db),
    ):
        """List a user's non-archived sessions for a feature, newest first.

        Raises:
            HTTPException: 500 if the query fails.
        """
        try:
            return db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                # `==` (not `is`) is required so SQLAlchemy builds a SQL expression.
                models.Session.is_archived == False,  # noqa: E712
            ).order_by(models.Session.created_at.desc()).all()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch sessions: {e}") from e

    @router.get("/{session_id}", response_model=schemas.Session, summary="Get a Single Session")
    def get_session(session_id: int, db: Session = Depends(get_db)):
        """Fetch one non-archived session by id.

        Raises:
            HTTPException: 404 if missing or archived; 500 on other failures.
        """
        try:
            session = db.query(models.Session).filter(
                models.Session.id == session_id,
                models.Session.is_archived == False,  # noqa: E712
            ).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            return session
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to fetch session: {e}") from e

    @router.delete("/{session_id}", summary="Delete a Chat Session")
    def delete_session(session_id: int, db: Session = Depends(get_db)):
        """Soft-delete a session by setting ``is_archived = True``.

        The row is looked up regardless of its archived state, so deleting an
        already-archived session still reports success.

        Raises:
            HTTPException: 404 if no such session exists; 500 on other failures.
        """
        try:
            session = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            session.is_archived = True
            db.commit()
            return {"message": "Session deleted successfully."}
        except HTTPException:
            raise
        except Exception as e:
            # Undo the pending change so the session is usable after a failed commit.
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to delete session: {e}") from e

    @router.delete("/", summary="Delete All Sessions for Feature")
    def delete_all_sessions(user_id: str, feature_name: str = "default", db: Session = Depends(get_db)):
        """Soft-delete (archive) every session of a user for a feature.

        Only rows that are not yet archived are loaded; the end state is the
        same as archiving all matching rows.

        Raises:
            HTTPException: 500 if the query or commit fails.
        """
        try:
            sessions = db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                models.Session.is_archived == False,  # noqa: E712
            ).all()
            for session in sessions:
                session.is_archived = True
            db.commit()
            return {"message": "All sessions deleted successfully."}
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}") from e

    return router