# cortex-hub / ai-hub / app / api / routes.py
from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import Response, StreamingResponse
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from typing import AsyncGenerator

def create_api_router(services: ServiceContainer) -> APIRouter:
    """
    Creates and returns an APIRouter with all the application's endpoints.

    Args:
        services: Container whose ``rag_service``, ``document_service`` and
            ``tts_service`` attributes are called by the route handlers.

    Returns:
        An APIRouter with the status, session, document and speech routes
        registered on it.

    Error handling is uniform across handlers: an HTTPException raised while
    handling a request is re-raised unchanged (so its status code survives),
    and any other exception is wrapped in a 500 HTTPException chained to the
    original error.
    """
    # Handler notes below are `#` comments rather than docstrings because
    # FastAPI publishes a handler's docstring as its OpenAPI description.
    router = APIRouter()

    # Liveness probe: returns a fixed status payload.
    @router.get("/", summary="Check Service Status", tags=["General"])
    def read_root():
        return {"status": "AI Model Hub is running!"}

    # --- Session Management Endpoints ---

    # Creates a session through rag_service for the given user and model.
    @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"])
    def create_session(
        request: schemas.SessionCreate,
        db: Session = Depends(get_db)
    ):
        try:
            new_session = services.rag_service.create_session(
                db=db,
                user_id=request.user_id,
                model=request.model
            )
            return new_session
        except HTTPException:
            # Keep the status code chosen further down instead of masking it as 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") from e

    # Sends one prompt to rag_service.chat_with_rag and returns the answer
    # together with the model that produced it.
    @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"])
    async def chat_in_session(
        session_id: int,
        request: schemas.ChatRequest,
        db: Session = Depends(get_db)
    ):
        try:
            response_text, model_used = await services.rag_service.chat_with_rag(
                db=db,
                session_id=session_id,
                prompt=request.prompt,
                model=request.model,
                load_faiss_retriever=request.load_faiss_retriever
            )
            return schemas.ChatResponse(answer=response_text, model_used=model_used)
        except HTTPException:
            # Keep the status code chosen further down instead of masking it as 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") from e

    # Returns the message history for a session; a None result from the
    # service is reported as 404.
    @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"])
    def get_session_messages(session_id: int, db: Session = Depends(get_db)):
        try:
            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
            if messages is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    # --- Document Management Endpoints ---

    # Stores a document via document_service and reports the ID it returned.
    @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"])
    def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
        try:
            doc_data = doc.model_dump()
            document_id = services.document_service.add_document(db=db, doc_data=doc_data)
            return schemas.DocumentResponse(
                message=f"Document '{doc.title}' added successfully with ID {document_id}"
            )
        except HTTPException:
            # Keep the status code chosen further down instead of masking it as 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    # Lists every document returned by document_service.get_all_documents.
    @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"])
    def get_documents(db: Session = Depends(get_db)):
        try:
            documents_from_db = services.document_service.get_all_documents(db=db)
            return {"documents": documents_from_db}
        except HTTPException:
            # Keep the status code chosen further down instead of masking it as 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    # Deletes a document; a None result from the service is reported as 404.
    @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"])
    def delete_document(document_id: int, db: Session = Depends(get_db)):
        try:
            deleted_id = services.document_service.delete_document(db=db, document_id=document_id)
            if deleted_id is None:
                raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.")

            return schemas.DocumentDeleteResponse(
                message="Document deleted successfully",
                document_id=deleted_id
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    # --- Consolidated Speech Endpoint ---

    @router.post(
        "/speech",
        summary="Generate speech from text",
        tags=["TTS"],
        response_description="Audio bytes in WAV format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        )
    ):
        """
        Generates an audio file or a streaming audio response from the provided text.
        By default, it returns a complete audio file.
        To get a streaming response, set the 'stream' query parameter to 'true'.
        """
        try:
            if stream:
                # Use the streaming service method
                audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream(
                    text=request.text
                )
                # NOTE(review): if create_speech_stream is an async generator
                # function, the call above only creates the generator; errors
                # raised while it is iterated would occur after the response
                # has started and would not reach the except clauses below.
                # Confirm against the TTS service implementation.
                return StreamingResponse(audio_stream_generator, media_type="audio/wav")
            else:
                # Use the non-streaming service method
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text
                )
                return Response(content=audio_bytes, media_type="audio/wav")

        except HTTPException:
            raise  # Re-raise existing HTTPException
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    return router