diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/tests/api/routes/test_documents.py b/ai-hub/tests/api/routes/test_documents.py new file mode 100644 index 0000000..670b375 --- /dev/null +++ b/ai-hub/tests/api/routes/test_documents.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock +from datetime import datetime +from app.db import models + +def test_add_document_success(client): + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 + + # Update the payload to include the default values from the Pydantic model + doc_payload = { + "title": "Test Doc", + "text": "Content here", + "source_url": None, # Add these based on your schema's defaults + "author": None, + "user_id": "default_user", + } + + # The payload sent to the endpoint is just title and text + request_payload = {"title": "Test Doc", "text": "Content here"} + + response = test_client.post("/documents", json=request_payload) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" + + # Assert against the full payload that the router will generate and pass to the service + mock_services.document_service.add_document.assert_called_once_with( + db=mock_services.document_service.add_document.call_args.kwargs['db'], + doc_data=doc_payload + ) + +def test_get_documents_success(client): + test_client, mock_services = client + + # Ensure mock attributes are valid types for the schema + mock_docs = [ + MagicMock( + spec=models.Document, + id=1, + title="Doc One", + status="ready", + created_at=datetime.now(), + source_url="http://example.com/doc1", # Explicitly set a string value + text="text one", + author="author one" + ), + MagicMock( + spec=models.Document, + id=2, + title="Doc Two", + status="processing", + created_at=datetime.now(), + source_url=None, # Or explicitly set to None if that's a valid case + text="text two", + author="author two" + ) + ] + + mock_services.document_service.get_all_documents.return_value = mock_docs + + response = test_client.get("/documents") + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["documents"]) == 2 + # You can also add more specific assertions + assert response_data["documents"][0]["title"] == "Doc One" + assert response_data["documents"][1]["status"] == "processing" + + mock_services.document_service.get_all_documents.assert_called_once() + +def test_delete_document_success(client): + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 + response = test_client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=42 + ) + +def test_delete_document_not_found(client): + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None + response = test_client.delete("/documents/999") + assert response.status_code == 404 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=999 + ) \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/tests/api/routes/test_documents.py b/ai-hub/tests/api/routes/test_documents.py new file mode 100644 index 0000000..670b375 --- /dev/null +++ b/ai-hub/tests/api/routes/test_documents.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock +from datetime import datetime +from app.db import models + +def test_add_document_success(client): + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 + + # Update the payload to include the default values from the Pydantic model + doc_payload = { + "title": "Test Doc", + "text": "Content here", + "source_url": None, # Add these based on your schema's defaults + "author": None, + "user_id": "default_user", + } + + # The payload sent to the endpoint is just title and text + request_payload = {"title": "Test Doc", "text": "Content here"} + + response = test_client.post("/documents", json=request_payload) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" + + # Assert against the full payload that the router will generate and pass to the service + mock_services.document_service.add_document.assert_called_once_with( + db=mock_services.document_service.add_document.call_args.kwargs['db'], + doc_data=doc_payload + ) + +def test_get_documents_success(client): + test_client, mock_services = client + + # Ensure mock attributes are valid types for the schema + mock_docs = [ + MagicMock( + spec=models.Document, + id=1, + title="Doc One", + status="ready", + created_at=datetime.now(), + source_url="http://example.com/doc1", # Explicitly set a string value + text="text one", + author="author one" + ), + MagicMock( + spec=models.Document, + id=2, + title="Doc Two", + status="processing", + created_at=datetime.now(), + source_url=None, # Or explicitly set to None if that's a valid case + text="text two", + author="author two" + ) + ] + + mock_services.document_service.get_all_documents.return_value = mock_docs + + response = test_client.get("/documents") + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["documents"]) == 2 + # You can also add more specific assertions + assert response_data["documents"][0]["title"] == "Doc One" + assert response_data["documents"][1]["status"] == "processing" + + mock_services.document_service.get_all_documents.assert_called_once() + +def test_delete_document_success(client): + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 + response = test_client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=42 + ) + +def test_delete_document_not_found(client): + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None + response = test_client.delete("/documents/999") + assert response.status_code == 404 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=999 + ) \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_general.py b/ai-hub/tests/api/routes/test_general.py new file mode 100644 index 0000000..cd4466e --- /dev/null +++ b/ai-hub/tests/api/routes/test_general.py @@ -0,0 +1,6 @@ +def test_read_root(client): + """Tests the root endpoint for a successful response.""" + test_client, _ = client + response = test_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/tests/api/routes/test_documents.py b/ai-hub/tests/api/routes/test_documents.py new file mode 100644 index 0000000..670b375 --- /dev/null +++ b/ai-hub/tests/api/routes/test_documents.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock +from datetime import datetime +from app.db import models + +def test_add_document_success(client): + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 + + # Update the payload to include the default values from the Pydantic model + doc_payload = { + "title": "Test Doc", + "text": "Content here", + "source_url": None, # Add these based on your schema's defaults + "author": None, + "user_id": "default_user", + } + + # The payload sent to the endpoint is just title and text + request_payload = {"title": "Test Doc", "text": "Content here"} + + response = test_client.post("/documents", json=request_payload) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" + + # Assert against the full payload that the router will generate and pass to the service + mock_services.document_service.add_document.assert_called_once_with( + db=mock_services.document_service.add_document.call_args.kwargs['db'], + doc_data=doc_payload + ) + +def test_get_documents_success(client): + test_client, mock_services = client + + # Ensure mock attributes are valid types for the schema + mock_docs = [ + MagicMock( + spec=models.Document, + id=1, + title="Doc One", + status="ready", + created_at=datetime.now(), + source_url="http://example.com/doc1", # Explicitly set a string value + text="text one", + author="author one" + ), + MagicMock( + spec=models.Document, + id=2, + title="Doc Two", + status="processing", + created_at=datetime.now(), + source_url=None, # Or explicitly set to None if that's a valid case + text="text two", + author="author two" + ) + ] + + mock_services.document_service.get_all_documents.return_value = mock_docs + + response = test_client.get("/documents") + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["documents"]) == 2 + # You can also add more specific assertions + assert response_data["documents"][0]["title"] == "Doc One" + assert response_data["documents"][1]["status"] == "processing" + + mock_services.document_service.get_all_documents.assert_called_once() + +def test_delete_document_success(client): + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 + response = test_client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=42 + ) + +def test_delete_document_not_found(client): + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None + response = test_client.delete("/documents/999") + assert response.status_code == 404 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=999 + ) \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_general.py b/ai-hub/tests/api/routes/test_general.py new file mode 100644 index 0000000..cd4466e --- /dev/null +++ b/ai-hub/tests/api/routes/test_general.py @@ -0,0 +1,6 @@ +def test_read_root(client): + """Tests the root endpoint for a successful response.""" + test_client, _ = client + response = test_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_sessions.py b/ai-hub/tests/api/routes/test_sessions.py new file mode 100644 index 0000000..10334d2 --- /dev/null +++ b/ai-hub/tests/api/routes/test_sessions.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock, AsyncMock +from datetime import datetime +from app.db import models + +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" + test_client, mock_services = client + mock_session = MagicMock(spec=models.Session) + mock_session.id = 1 + mock_session.user_id = "test_user" + mock_session.model_name = "gemini" + mock_session.title = "New Chat" + mock_session.created_at = datetime.now() + mock_services.rag_service.create_session.return_value = mock_session + + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json()["id"] == 1 + mock_services.rag_service.create_session.assert_called_once() + +def test_chat_in_session_success(client): + """ + Tests sending a message in an existing session without specifying a model + or retriever. It should default to 'deepseek' and 'False'. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there", + model="deepseek", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_model_switch(client): + """ + Tests sending a message in an existing session and explicitly switching the model. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there, Gemini!", + model="gemini", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_faiss_retriever(client): + """ + Tests sending a message and explicitly enabling the FAISS retriever. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + + response = test_client.post( + "/sessions/42/chat", + json={"prompt": "What is RAG?", "load_faiss_retriever": True} + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="What is RAG?", + model="deepseek", + load_faiss_retriever=True + ) + +def test_get_session_messages_success(client): + """Tests retrieving the message history for a session.""" + test_client, mock_services = client + mock_history = [ + MagicMock(spec=models.Message, sender="user", content="Hello", created_at=datetime.now()), + MagicMock(spec=models.Message, sender="assistant", content="Hi there!", created_at=datetime.now()) + ] + mock_services.rag_service.get_message_history.return_value = mock_history + + response = test_client.get("/sessions/123/messages") + + assert response.status_code == 200 + response_data = response.json() + assert response_data["session_id"] == 123 + assert len(response_data["messages"]) == 2 + assert response_data["messages"][0]["sender"] == "user" + assert response_data["messages"][1]["content"] == "Hi there!" + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) + +def test_get_session_messages_not_found(client): + """Tests retrieving messages for a session that does not exist.""" + test_client, mock_services = client + mock_services.rag_service.get_message_history.return_value = None + + response = test_client.get("/sessions/999/messages") + + assert response.status_code == 404 + assert response.json()["detail"] == "Session with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/tests/api/routes/test_documents.py b/ai-hub/tests/api/routes/test_documents.py new file mode 100644 index 0000000..670b375 --- /dev/null +++ b/ai-hub/tests/api/routes/test_documents.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock +from datetime import datetime +from app.db import models + +def test_add_document_success(client): + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 + + # Update the payload to include the default values from the Pydantic model + doc_payload = { + "title": "Test Doc", + "text": "Content here", + "source_url": None, # Add these based on your schema's defaults + "author": None, + "user_id": "default_user", + } + + # The payload sent to the endpoint is just title and text + request_payload = {"title": "Test Doc", "text": "Content here"} + + response = test_client.post("/documents", json=request_payload) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" + + # Assert against the full payload that the router will generate and pass to the service + mock_services.document_service.add_document.assert_called_once_with( + db=mock_services.document_service.add_document.call_args.kwargs['db'], + doc_data=doc_payload + ) + +def test_get_documents_success(client): + test_client, mock_services = client + + # Ensure mock attributes are valid types for the schema + mock_docs = [ + MagicMock( + spec=models.Document, + id=1, + title="Doc One", + status="ready", + created_at=datetime.now(), + source_url="http://example.com/doc1", # Explicitly set a string value + text="text one", + author="author one" + ), + MagicMock( + spec=models.Document, + id=2, + title="Doc Two", + status="processing", + created_at=datetime.now(), + source_url=None, # Or explicitly set to None if that's a valid case + text="text two", + author="author two" + ) + ] + + mock_services.document_service.get_all_documents.return_value = mock_docs + + response = test_client.get("/documents") + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["documents"]) == 2 + # You can also add more specific assertions + assert response_data["documents"][0]["title"] == "Doc One" + assert response_data["documents"][1]["status"] == "processing" + + mock_services.document_service.get_all_documents.assert_called_once() + +def test_delete_document_success(client): + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 + response = test_client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=42 + ) + +def test_delete_document_not_found(client): + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None + response = test_client.delete("/documents/999") + assert response.status_code == 404 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=999 + ) \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_general.py b/ai-hub/tests/api/routes/test_general.py new file mode 100644 index 0000000..cd4466e --- /dev/null +++ b/ai-hub/tests/api/routes/test_general.py @@ -0,0 +1,6 @@ +def test_read_root(client): + """Tests the root endpoint for a successful response.""" + test_client, _ = client + response = test_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_sessions.py b/ai-hub/tests/api/routes/test_sessions.py new file mode 100644 index 0000000..10334d2 --- /dev/null +++ b/ai-hub/tests/api/routes/test_sessions.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock, AsyncMock +from datetime import datetime +from app.db import models + +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" + test_client, mock_services = client + mock_session = MagicMock(spec=models.Session) + mock_session.id = 1 + mock_session.user_id = "test_user" + mock_session.model_name = "gemini" + mock_session.title = "New Chat" + mock_session.created_at = datetime.now() + mock_services.rag_service.create_session.return_value = mock_session + + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json()["id"] == 1 + mock_services.rag_service.create_session.assert_called_once() + +def test_chat_in_session_success(client): + """ + Tests sending a message in an existing session without specifying a model + or retriever. It should default to 'deepseek' and 'False'. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there", + model="deepseek", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_model_switch(client): + """ + Tests sending a message in an existing session and explicitly switching the model. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there, Gemini!", + model="gemini", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_faiss_retriever(client): + """ + Tests sending a message and explicitly enabling the FAISS retriever. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + + response = test_client.post( + "/sessions/42/chat", + json={"prompt": "What is RAG?", "load_faiss_retriever": True} + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="What is RAG?", + model="deepseek", + load_faiss_retriever=True + ) + +def test_get_session_messages_success(client): + """Tests retrieving the message history for a session.""" + test_client, mock_services = client + mock_history = [ + MagicMock(spec=models.Message, sender="user", content="Hello", created_at=datetime.now()), + MagicMock(spec=models.Message, sender="assistant", content="Hi there!", created_at=datetime.now()) + ] + mock_services.rag_service.get_message_history.return_value = mock_history + + response = test_client.get("/sessions/123/messages") + + assert response.status_code == 200 + response_data = response.json() + assert response_data["session_id"] == 123 + assert len(response_data["messages"]) == 2 + assert response_data["messages"][0]["sender"] == "user" + assert response_data["messages"][1]["content"] == "Hi there!" + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) + +def test_get_session_messages_not_found(client): + """Tests retrieving messages for a session that does not exist.""" + test_client, mock_services = client + mock_services.rag_service.get_message_history.return_value = None + + response = test_client.get("/sessions/999/messages") + + assert response.status_code == 404 + assert response.json()["detail"] == "Session with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_tts.py b/ai-hub/tests/api/routes/test_tts.py new file mode 100644 index 0000000..cd4f14e --- /dev/null +++ b/ai-hub/tests/api/routes/test_tts.py @@ -0,0 +1,46 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock + +@pytest.mark.asyncio +async def test_create_speech_response(async_client): + """Test the /speech endpoint returns audio bytes without streaming.""" + test_client, mock_services = await anext(async_client) + mock_audio_bytes = b"fake wav audio bytes" + + # Use AsyncMock for the async function create_speech_non_stream + mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes) + + response = await test_client.post("/speech", json={"text": "Hello, this is a test"}) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert response.content == mock_audio_bytes + + mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test") + +@pytest.mark.asyncio +async def test_create_speech_stream_response(async_client): + """Test the /speech endpoint with stream=true returns a streaming response.""" + test_client, mock_services = await anext(async_client) + mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"] + + # This async generator mock correctly simulates the streaming service + async def mock_async_generator(): + for chunk in mock_audio_bytes_chunks: + yield chunk + + # We mock `create_speech_stream` with a MagicMock returning the async generator + mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator()) + + response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"}) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + + # Read the streamed content and verify it matches the mocked chunks + streamed_content = b"" + async for chunk in response.aiter_bytes(): + streamed_content += chunk + + assert streamed_content == b"".join(mock_audio_bytes_chunks) + mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index ce8b2ac..fd05079 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -1,144 +1,144 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from fastapi.responses import Response, StreamingResponse -from sqlalchemy.orm import Session -from app.api.dependencies import ServiceContainer, get_db -from app.api import schemas -from typing import AsyncGenerator +# from fastapi import APIRouter, HTTPException, Depends, Query +# from fastapi.responses import Response, StreamingResponse +# from sqlalchemy.orm import Session +# from app.api.dependencies import ServiceContainer, get_db +# from app.api import schemas +# from typing import AsyncGenerator -def create_api_router(services: ServiceContainer) -> APIRouter: - """ - Creates and returns an APIRouter with all the application's endpoints. - """ - router = APIRouter() +# def create_api_router(services: ServiceContainer) -> APIRouter: +# """ +# Creates and returns an APIRouter with all the application's endpoints. +# """ +# router = APIRouter() - @router.get("/", summary="Check Service Status", tags=["General"]) - def read_root(): - return {"status": "AI Model Hub is running!"} +# @router.get("/", summary="Check Service Status", tags=["General"]) +# def read_root(): +# return {"status": "AI Model Hub is running!"} - # --- Session Management Endpoints --- +# # --- Session Management Endpoints --- - @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) - def create_session( - request: schemas.SessionCreate, - db: Session = Depends(get_db) - ): - try: - new_session = services.rag_service.create_session( - db=db, - user_id=request.user_id, - model=request.model - ) - return new_session - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") +# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) +# def create_session( +# request: schemas.SessionCreate, +# db: Session = Depends(get_db) +# ): +# try: +# new_session = services.rag_service.create_session( +# db=db, +# user_id=request.user_id, +# model=request.model +# ) +# return new_session +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) - async def chat_in_session( - session_id: int, - request: schemas.ChatRequest, - db: Session = Depends(get_db) - ): - try: - response_text, model_used = await services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=request.prompt, - model=request.model, - load_faiss_retriever=request.load_faiss_retriever - ) - return schemas.ChatResponse(answer=response_text, model_used=model_used) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") +# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) +# async def chat_in_session( +# session_id: int, +# request: schemas.ChatRequest, +# db: Session = Depends(get_db) +# ): +# try: +# response_text, model_used = await services.rag_service.chat_with_rag( +# db=db, +# session_id=session_id, +# prompt=request.prompt, +# model=request.model, +# load_faiss_retriever=request.load_faiss_retriever +# ) +# return schemas.ChatResponse(answer=response_text, model_used=model_used) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) - def get_session_messages(session_id: int, db: Session = Depends(get_db)): - try: - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - if messages is None: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") +# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) +# def get_session_messages(session_id: int, db: Session = Depends(get_db)): +# try: +# messages = services.rag_service.get_message_history(db=db, session_id=session_id) +# if messages is None: +# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Document Management Endpoints --- +# # --- Document Management Endpoints --- - @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): - try: - doc_data = doc.model_dump() - document_id = services.document_service.add_document(db=db, doc_data=doc_data) - return schemas.DocumentResponse( - message=f"Document '{doc.title}' added successfully with ID {document_id}" - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) +# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): +# try: +# doc_data = doc.model_dump() +# document_id = services.document_service.add_document(db=db, doc_data=doc_data) +# return schemas.DocumentResponse( +# message=f"Document '{doc.title}' added successfully with ID {document_id}" +# ) +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) - def get_documents(db: Session = Depends(get_db)): - try: - documents_from_db = services.document_service.get_all_documents(db=db) - return {"documents": documents_from_db} - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) +# def get_documents(db: Session = Depends(get_db)): +# try: +# documents_from_db = services.document_service.get_all_documents(db=db) +# return {"documents": documents_from_db} +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) - def delete_document(document_id: int, db: Session = Depends(get_db)): - try: - deleted_id = services.document_service.delete_document(db=db, document_id=document_id) - if deleted_id is None: - raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") +# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) +# def delete_document(document_id: int, db: Session = Depends(get_db)): +# try: +# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) +# if deleted_id is None: +# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - return schemas.DocumentDeleteResponse( - message="Document deleted successfully", - document_id=deleted_id - ) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") +# return schemas.DocumentDeleteResponse( +# message="Document deleted successfully", +# document_id=deleted_id +# ) +# except HTTPException: +# raise +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - # --- Consolidated Speech Endpoint --- +# # --- Consolidated Speech Endpoint --- - @router.post( - "/speech", - summary="Generate speech from text", - tags=["TTS"], - response_description="Audio bytes in WAV format, either as a complete file or a stream.", - ) - async def create_speech_response( - request: schemas.SpeechRequest, - stream: bool = Query( - False, - description="If true, returns a streamed audio response. Otherwise, returns a complete file." - ) - ): - """ - Generates an audio file or a streaming audio response from the provided text. - By default, it returns a complete audio file. - To get a streaming response, set the 'stream' query parameter to 'true'. - """ - try: - if stream: - # Use the streaming service method - audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( - text=request.text - ) - return StreamingResponse(audio_stream_generator, media_type="audio/wav") - else: - # Use the non-streaming service method - audio_bytes = await services.tts_service.create_speech_non_stream( - text=request.text - ) - return Response(content=audio_bytes, media_type="audio/wav") +# @router.post( +# "/speech", +# summary="Generate speech from text", +# tags=["TTS"], +# response_description="Audio bytes in WAV format, either as a complete file or a stream.", +# ) +# async def create_speech_response( +# request: schemas.SpeechRequest, +# stream: bool = Query( +# False, +# description="If true, returns a streamed audio response. Otherwise, returns a complete file." +# ) +# ): +# """ +# Generates an audio file or a streaming audio response from the provided text. +# By default, it returns a complete audio file. +# To get a streaming response, set the 'stream' query parameter to 'true'. +# """ +# try: +# if stream: +# # Use the streaming service method +# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( +# text=request.text +# ) +# return StreamingResponse(audio_stream_generator, media_type="audio/wav") +# else: +# # Use the non-streaming service method +# audio_bytes = await services.tts_service.create_speech_non_stream( +# text=request.text +# ) +# return Response(content=audio_bytes, media_type="audio/wav") - except HTTPException: - raise # Re-raise existing HTTPException - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to generate speech: {e}" - ) +# except HTTPException: +# raise # Re-raise existing HTTPException +# except Exception as e: +# raise HTTPException( +# status_code=500, detail=f"Failed to generate speech: {e}" +# ) - return router \ No newline at end of file +# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/__init__.py b/ai-hub/app/api/routes/__init__.py new file mode 100644 index 0000000..3fbb1fd --- /dev/null +++ b/ai-hub/app/api/routes/__init__.py @@ -0,0 +1 @@ +# This file can be left empty. diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py new file mode 100644 index 0000000..b052e7d --- /dev/null +++ b/ai-hub/app/api/routes/api.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from app.api.dependencies import ServiceContainer + +# Import routers +from .general import router as general_router +from .sessions import create_sessions_router +from .documents import create_documents_router +from .tts import create_tts_router + +def create_api_router(services: ServiceContainer) -> APIRouter: + """ + Creates and returns a main APIRouter that includes all sub-routers. + """ + router = APIRouter() + + # Include routers for different functionalities + router.include_router(general_router) + router.include_router(create_sessions_router(services)) + router.include_router(create_documents_router(services)) + router.include_router(create_tts_router(services)) + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/documents.py b/ai-hub/app/api/routes/documents.py new file mode 100644 index 0000000..403a712 --- /dev/null +++ b/ai-hub/app/api/routes/documents.py @@ -0,0 +1,44 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas + +def create_documents_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/documents", tags=["Documents"]) + + @router.post("/", response_model=schemas.DocumentResponse, summary="Add a New Document") + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): + try: + doc_data = doc.model_dump() + document_id = services.document_service.add_document(db=db, doc_data=doc_data) + return schemas.DocumentResponse( + message=f"Document '{doc.title}' added successfully with ID {document_id}" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.get("/", response_model=schemas.DocumentListResponse, summary="List All Documents") + def get_documents(db: Session = Depends(get_db)): + try: + documents_from_db = services.document_service.get_all_documents(db=db) + return {"documents": documents_from_db} + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + @router.delete("/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document") + def delete_document(document_id: int, db: Session = Depends(get_db)): + try: + deleted_id = services.document_service.delete_document(db=db, document_id=document_id) + if deleted_id is None: + raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") + + return schemas.DocumentDeleteResponse( + message="Document deleted successfully", + document_id=deleted_id + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/general.py b/ai-hub/app/api/routes/general.py new file mode 100644 index 0000000..c744fa3 --- /dev/null +++ b/ai-hub/app/api/routes/general.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["General"]) + +@router.get("/", summary="Check Service Status") +def read_root(): + return {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py new file mode 100644 index 0000000..2515147 --- /dev/null +++ b/ai-hub/app/api/routes/sessions.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.api.dependencies import ServiceContainer, get_db +from app.api import schemas +from typing import AsyncGenerator + +def create_sessions_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/sessions", tags=["Sessions"]) + + @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session") + def create_session( + request: schemas.SessionCreate, + db: Session = Depends(get_db) + ): + try: + new_session = services.rag_service.create_session( + db=db, + user_id=request.user_id, + model=request.model + ) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session") + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, + db: Session = Depends(get_db) + ): + try: + response_text, model_used = await services.rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt, + model=request.model, + load_faiss_retriever=request.load_faiss_retriever + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") + + @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History") + def get_session_messages(session_id: int, db: Session = Depends(get_db)): + try: + messages = services.rag_service.get_message_history(db=db, session_id=session_id) + if messages is None: + raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + + return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py new file mode 100644 index 0000000..557924a --- /dev/null +++ b/ai-hub/app/api/routes/tts.py @@ -0,0 +1,41 @@ +from fastapi import APIRouter, HTTPException, Query, Response +from fastapi.responses import StreamingResponse +from app.api.dependencies import ServiceContainer +from app.api import schemas +from typing import AsyncGenerator + +def create_tts_router(services: ServiceContainer) -> APIRouter: + router = APIRouter(prefix="/speech", tags=["TTS"]) + + @router.post( + "", + summary="Generate speech from text", + response_description="Audio bytes in WAV format, either as a complete file or a stream.", + ) + async def create_speech_response( + request: schemas.SpeechRequest, + stream: bool = Query( + False, + description="If true, returns a streamed audio response. Otherwise, returns a complete file." + ) + ): + try: + if stream: + audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( + text=request.text + ) + return StreamingResponse(audio_stream_generator, media_type="audio/wav") + else: + audio_bytes = await services.tts_service.create_speech_non_stream( + text=request.text + ) + return Response(content=audio_bytes, media_type="audio/wav") + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to generate speech: {e}" + ) + + return router \ No newline at end of file diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a56fba6..3afc104 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -10,7 +10,7 @@ from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables -from app.api.routes import create_api_router +from app.api.routes.api import create_api_router from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index 5d51010..60a854d 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -8,7 +8,7 @@ """Fixture to provide the base URL for the tests.""" return BASE_URL -@pytest_asyncio.fixture(scope="function") # <-- Change scope to "function" +@pytest_asyncio.fixture(scope="function") async def http_client(): """ Fixture to provide an async HTTP client for all tests in the session. @@ -28,7 +28,8 @@ Returns the session ID. """ payload = {"user_id": "integration_tester", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + # The URL has been updated to include the trailing slash + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] yield session_id @@ -42,7 +43,8 @@ Returns the document ID. """ doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - response = await http_client.post("/documents", json=doc_data) + # The URL has been updated to include the trailing slash + response = await http_client.post("/documents/", json=doc_data) assert response.status_code == 200 try: message = response.json().get("message", "") @@ -54,4 +56,4 @@ # Teardown: Delete the document after the test delete_response = await http_client.delete(f"/documents/{document_id}") - assert delete_response.status_code == 200 \ No newline at end of file + assert delete_response.status_code == 200 diff --git a/ai-hub/integration_tests/test_documents.py b/ai-hub/integration_tests/test_documents.py index c6d4362..ea3ecb1 100644 --- a/ai-hub/integration_tests/test_documents.py +++ b/ai-hub/integration_tests/test_documents.py @@ -10,7 +10,8 @@ # 1. Add a new document doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - add_response = await http_client.post("/documents", json=doc_data) + # Correct the URL to include the trailing slash to avoid the 307 redirect + add_response = await http_client.post("/documents/", json=doc_data) assert add_response.status_code == 200 try: message = add_response.json().get("message", "") @@ -20,7 +21,8 @@ print(f"✅ Document for lifecycle test created with ID: {document_id}") # 2. List all documents and check if the new document is present - list_response = await http_client.get("/documents") + # Correct the URL to include the trailing slash + list_response = await http_client.get("/documents/") assert list_response.status_code == 200 ids_in_response = {doc["id"] for doc in list_response.json()["documents"]} assert document_id in ids_in_response @@ -30,4 +32,5 @@ delete_response = await http_client.delete(f"/documents/{document_id}") assert delete_response.status_code == 200 assert delete_response.json()["document_id"] == document_id - print("✅ Document delete test passed.") \ No newline at end of file + print("✅ Document delete test passed.") + diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py deleted file mode 100644 index 541881f..0000000 --- a/ai-hub/integration_tests/test_integration.py +++ /dev/null @@ -1,216 +0,0 @@ -import pytest -import httpx - -# The base URL for the local server started by the run_tests.sh script -BASE_URL = "http://127.0.0.1:8000" - -# A common prompt to be used for the tests -TEST_PROMPT = "Explain the theory of relativity in one sentence." -CONTEXT_PROMPT = "Who is the CEO of Microsoft?" -FOLLOW_UP_PROMPT = "When was he born?" - -# Document and prompt for the retrieval-augmented generation test -RAG_DOC_TITLE = "Fictional Company History" -RAG_DOC_TEXT = "The company AlphaCorp was founded in 2021 by Jane Doe. Their primary product is a smart home device called 'Nexus'." -RAG_PROMPT = "Who founded AlphaCorp and what is their main product?" - -# Global variables to pass state between sequential tests -created_document_id = None -created_session_id = None - -async def test_root_endpoint(): - """ - Tests if the root endpoint is alive and returns the correct status message. - """ - print("\n--- Running test_root_endpoint ---") - async with httpx.AsyncClient() as client: - response = await client.get(f"{BASE_URL}/") - - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} - print("✅ Root endpoint test passed.") - -# --- Session and Chat Lifecycle Tests --- - -async def test_create_session(): - """Tests creating a new chat session and saves the ID for the next test.""" - global created_session_id - print("\n--- Running test_create_session ---") - url = f"{BASE_URL}/sessions" - payload = {"user_id": "integration_tester", "model": "deepseek"} - - async with httpx.AsyncClient() as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Failed to create session. Response: {response.text}" - response_data = response.json() - assert "id" in response_data - created_session_id = response_data["id"] - print(f"✅ Session created successfully with ID: {created_session_id}") - -async def test_chat_in_session_turn_1(): - """Tests sending the first message to establish context using the default model.""" - print("\n--- Running test_chat_in_session (Turn 1) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": CONTEXT_PROMPT} # Model defaults to "deepseek" as per schema - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer mentions the CEO's name (assuming DeepSeek provides this) - assert "Satya Nadella" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 1 (context) test passed.") - -async def test_chat_in_session_turn_2_follow_up(): - """ - Tests sending a follow-up question to verify conversational memory using the default model. - """ - print("\n--- Running test_chat_in_session (Turn 2 - Follow-up) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - payload = {"prompt": FOLLOW_UP_PROMPT} # Model defaults to "deepseek" - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Follow-up chat request failed. Response: {response.text}" - response_data = response.json() - # Check that the answer contains the birth year, proving it understood "he" - assert "1967" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat Turn 2 (follow-up) test passed.") - -async def test_chat_in_session_with_model_switch(): - """ - Tests sending a message in the same session, explicitly switching to 'gemini'. - """ - print("\n--- Running test_chat_in_session (Model Switch to Gemini) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'gemini' model for this turn - payload = {"prompt": "What is the capital of France?", "model": "gemini"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"Gemini chat request failed. Response: {response.text}" - response_data = response.json() - assert "Paris" in response_data["answer"] - assert response_data["model_used"] == "gemini" - print("✅ Chat (Model Switch to Gemini) test passed.") - -async def test_chat_in_session_switch_back_to_deepseek(): - """ - Tests sending another message in the same session, explicitly switching back to 'deepseek'. - """ - print("\n--- Running test_chat_in_session (Model Switch back to DeepSeek) ---") - assert created_session_id is not None, "Session ID was not set." - - url = f"{BASE_URL}/sessions/{created_session_id}/chat" - # Explicitly request 'deepseek' model for this turn - payload = {"prompt": "What is the largest ocean?", "model": "deepseek"} - - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - - assert response.status_code == 200, f"DeepSeek chat request failed. Response: {response.text}" - response_data = response.json() - assert "Pacific Ocean" in response_data["answer"] - assert response_data["model_used"] == "deepseek" - print("✅ Chat (Model Switch back to DeepSeek) test passed.") - -async def test_chat_with_document_retrieval(): - """ - Tests injecting a document and using it for retrieval-augmented generation. - This simulates the 'load_faiss_retriever' functionality. - """ - print("\n--- Running test_chat_with_document_retrieval ---") - async with httpx.AsyncClient(timeout=60.0) as client: - # Create a new session for this test - session_response = await client.post(f"{BASE_URL}/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) - assert session_response.status_code == 200 - rag_session_id = session_response.json()["id"] - - # Add a new document with specific content for retrieval - doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await client.post(f"{BASE_URL}/documents", json=doc_data) - assert add_doc_response.status_code == 200 - try: - message = add_doc_response.json().get("message", "") - rag_document_id = int(message.split(" with ID ")[-1]) - print(f"Document for RAG created with ID: {rag_document_id}") - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - - try: - # Send a chat request with the document ID to enable retrieval - chat_payload = { - "prompt": RAG_PROMPT, - "document_id": rag_document_id, - "model": "deepseek", # or any other RAG-enabled model - "load_faiss_retriever": True - } - chat_response = await client.post(f"{BASE_URL}/sessions/{rag_session_id}/chat", json=chat_payload) - - assert chat_response.status_code == 200, f"RAG chat request failed. Response: {chat_response.text}" - chat_data = chat_response.json() - - # Verify the response contains information from the document - assert "Jane Doe" in chat_data["answer"] - assert "Nexus" in chat_data["answer"] - print("✅ Chat with document retrieval test passed.") - finally: - # Clean up the document after the test - delete_response = await client.delete(f"{BASE_URL}/documents/{rag_document_id}") - assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") - -# --- Document Management Lifecycle Tests --- -async def test_add_document_for_lifecycle(): - global created_document_id - print("\n--- Running test_add_document (for lifecycle) ---") - url = f"{BASE_URL}/documents" - doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(url, json=doc_data) - - assert response.status_code == 200 - try: - message = response.json().get("message", "") - created_document_id = int(message.split(" with ID ")[-1]) - except (ValueError, IndexError): - pytest.fail("Could not parse document ID from response message.") - print(f"✅ Document for lifecycle test created with ID: {created_document_id}") - -async def test_list_documents(): - print("\n--- Running test_list_documents ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents" - async with httpx.AsyncClient() as client: - response = await client.get(url) - - assert response.status_code == 200 - ids_in_response = {doc["id"] for doc in response.json()["documents"]} - assert created_document_id in ids_in_response - print("✅ Document list test passed.") - -async def test_delete_document(): - print("\n--- Running test_delete_document ---") - assert created_document_id is not None, "Document ID was not set." - - url = f"{BASE_URL}/documents/{created_document_id}" - async with httpx.AsyncClient() as client: - response = await client.delete(url) - - assert response.status_code == 200 - assert response.json()["document_id"] == created_document_id - print("✅ Document delete test passed.") diff --git a/ai-hub/integration_tests/test_sessions.py b/ai-hub/integration_tests/test_sessions.py index 44d63d7..435ce3d 100644 --- a/ai-hub/integration_tests/test_sessions.py +++ b/ai-hub/integration_tests/test_sessions.py @@ -15,9 +15,9 @@ """ print("\n--- Running test_chat_in_session_lifecycle ---") - # 1. Create a new session + # 1. Create a new session with a trailing slash payload = {"user_id": "integration_tester_lifecycle", "model": "deepseek"} - response = await http_client.post("/sessions", json=payload) + response = await http_client.post("/sessions/", json=payload) assert response.status_code == 200 session_id = response.json()["id"] print(f"✅ Session created successfully with ID: {session_id}") @@ -71,13 +71,15 @@ print("\n--- Running test_chat_with_document_retrieval ---") # Create a new session for this RAG test - session_response = await http_client.post("/sessions", json={"user_id": "rag_tester", "model": "deepseek"}) + # Corrected URL with a trailing slash + session_response = await http_client.post("/sessions/", json={"user_id": "rag_tester", "model": "deepseek"}) assert session_response.status_code == 200 rag_session_id = session_response.json()["id"] # Add a new document with specific content for retrieval doc_data = {"title": RAG_DOC_TITLE, "text": RAG_DOC_TEXT} - add_doc_response = await http_client.post("/documents", json=doc_data) + # Corrected URL with a trailing slash + add_doc_response = await http_client.post("/documents/", json=doc_data) assert add_doc_response.status_code == 200 try: message = add_doc_response.json().get("message", "") @@ -104,4 +106,4 @@ # Clean up the document after the test delete_response = await http_client.delete(f"/documents/{rag_document_id}") assert delete_response.status_code == 200 - print(f"Document {rag_document_id} deleted successfully.") \ No newline at end of file + print(f"Document {rag_document_id} deleted successfully.") diff --git a/ai-hub/run_chat.sh b/ai-hub/run_chat.sh index 5032780..5646e5a 100644 --- a/ai-hub/run_chat.sh +++ b/ai-hub/run_chat.sh @@ -36,7 +36,8 @@ # --- 3. Create a New Conversation Session --- echo "--- Starting a new conversation session... ---" -SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions" \ +# FIX: Added a trailing slash to the /sessions endpoint to avoid a 307 redirect +SESSION_DATA=$(curl -s -X POST "$BASE_URL/sessions/" \ -H "Content-Type: application/json" \ -d '{"user_id": "local_user", "model": "'"$DEFAULT_MODEL"'"}' \ -w '\n%{http_code}') # Add a new line and the status code diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py new file mode 100644 index 0000000..4ba708e --- /dev/null +++ b/ai-hub/tests/api/routes/conftest.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import AsyncClient, ASGITransport +from sqlalchemy.orm import Session +from app.api.dependencies import get_db, ServiceContainer +from app.core.services.rag import RAGService +from app.core.services.document import DocumentService +from app.core.services.tts import TTSService +from app.api.routes.api import create_api_router + +# Change the scope to "function" so the fixture is re-created for each test +@pytest.fixture(scope="function") +def client(): + """ + Pytest fixture to create a TestClient with a fully mocked environment + for synchronous endpoints, scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + test_client = TestClient(test_app) + + # Use a yield to ensure the teardown happens after each test + yield test_client, mock_services + + # You could also add a reset call here for an extra layer of safety, + # but with scope="function" it's not strictly necessary. + +# Change the scope to "function" for the async client as well +@pytest.fixture(scope="function") +async def async_client(): + """ + Pytest fixture to create an AsyncClient for testing async endpoints, + scoped to a single function. + """ + test_app = FastAPI() + + mock_rag_service = MagicMock(spec=RAGService) + mock_document_service = MagicMock(spec=DocumentService) + mock_tts_service = MagicMock(spec=TTSService) + + mock_services = MagicMock(spec=ServiceContainer) + mock_services.rag_service = mock_rag_service + mock_services.document_service = mock_document_service + mock_services.tts_service = mock_tts_service + + mock_db_session = MagicMock(spec=Session) + + def override_get_db(): + yield mock_db_session + + api_router = create_api_router(services=mock_services) + test_app.dependency_overrides[get_db] = override_get_db + test_app.include_router(api_router) + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: + yield client, mock_services diff --git a/ai-hub/tests/api/routes/test_documents.py b/ai-hub/tests/api/routes/test_documents.py new file mode 100644 index 0000000..670b375 --- /dev/null +++ b/ai-hub/tests/api/routes/test_documents.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock +from datetime import datetime +from app.db import models + +def test_add_document_success(client): + test_client, mock_services = client + mock_services.document_service.add_document.return_value = 123 + + # Update the payload to include the default values from the Pydantic model + doc_payload = { + "title": "Test Doc", + "text": "Content here", + "source_url": None, # Add these based on your schema's defaults + "author": None, + "user_id": "default_user", + } + + # The payload sent to the endpoint is just title and text + request_payload = {"title": "Test Doc", "text": "Content here"} + + response = test_client.post("/documents", json=request_payload) + + assert response.status_code == 200 + assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" + + # Assert against the full payload that the router will generate and pass to the service + mock_services.document_service.add_document.assert_called_once_with( + db=mock_services.document_service.add_document.call_args.kwargs['db'], + doc_data=doc_payload + ) + +def test_get_documents_success(client): + test_client, mock_services = client + + # Ensure mock attributes are valid types for the schema + mock_docs = [ + MagicMock( + spec=models.Document, + id=1, + title="Doc One", + status="ready", + created_at=datetime.now(), + source_url="http://example.com/doc1", # Explicitly set a string value + text="text one", + author="author one" + ), + MagicMock( + spec=models.Document, + id=2, + title="Doc Two", + status="processing", + created_at=datetime.now(), + source_url=None, # Or explicitly set to None if that's a valid case + text="text two", + author="author two" + ) + ] + + mock_services.document_service.get_all_documents.return_value = mock_docs + + response = test_client.get("/documents") + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["documents"]) == 2 + # You can also add more specific assertions + assert response_data["documents"][0]["title"] == "Doc One" + assert response_data["documents"][1]["status"] == "processing" + + mock_services.document_service.get_all_documents.assert_called_once() + +def test_delete_document_success(client): + """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = 42 + response = test_client.delete("/documents/42") + assert response.status_code == 200 + assert response.json()["document_id"] == 42 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=42 + ) + +def test_delete_document_not_found(client): + """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" + test_client, mock_services = client + mock_services.document_service.delete_document.return_value = None + response = test_client.delete("/documents/999") + assert response.status_code == 404 + mock_services.document_service.delete_document.assert_called_once_with( + db=mock_services.document_service.delete_document.call_args.kwargs['db'], + document_id=999 + ) \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_general.py b/ai-hub/tests/api/routes/test_general.py new file mode 100644 index 0000000..cd4466e --- /dev/null +++ b/ai-hub/tests/api/routes/test_general.py @@ -0,0 +1,6 @@ +def test_read_root(client): + """Tests the root endpoint for a successful response.""" + test_client, _ = client + response = test_client.get("/") + assert response.status_code == 200 + assert response.json() == {"status": "AI Model Hub is running!"} \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_sessions.py b/ai-hub/tests/api/routes/test_sessions.py new file mode 100644 index 0000000..10334d2 --- /dev/null +++ b/ai-hub/tests/api/routes/test_sessions.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock, AsyncMock +from datetime import datetime +from app.db import models + +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" + test_client, mock_services = client + mock_session = MagicMock(spec=models.Session) + mock_session.id = 1 + mock_session.user_id = "test_user" + mock_session.model_name = "gemini" + mock_session.title = "New Chat" + mock_session.created_at = datetime.now() + mock_services.rag_service.create_session.return_value = mock_session + + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json()["id"] == 1 + mock_services.rag_service.create_session.assert_called_once() + +def test_chat_in_session_success(client): + """ + Tests sending a message in an existing session without specifying a model + or retriever. It should default to 'deepseek' and 'False'. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there", + model="deepseek", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_model_switch(client): + """ + Tests sending a message in an existing session and explicitly switching the model. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) + + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) + + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="Hello there, Gemini!", + model="gemini", + load_faiss_retriever=False + ) + +def test_chat_in_session_with_faiss_retriever(client): + """ + Tests sending a message and explicitly enabling the FAISS retriever. + """ + test_client, mock_services = client + mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) + + response = test_client.post( + "/sessions/42/chat", + json={"prompt": "What is RAG?", "load_faiss_retriever": True} + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} + + mock_services.rag_service.chat_with_rag.assert_called_once_with( + db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], + session_id=42, + prompt="What is RAG?", + model="deepseek", + load_faiss_retriever=True + ) + +def test_get_session_messages_success(client): + """Tests retrieving the message history for a session.""" + test_client, mock_services = client + mock_history = [ + MagicMock(spec=models.Message, sender="user", content="Hello", created_at=datetime.now()), + MagicMock(spec=models.Message, sender="assistant", content="Hi there!", created_at=datetime.now()) + ] + mock_services.rag_service.get_message_history.return_value = mock_history + + response = test_client.get("/sessions/123/messages") + + assert response.status_code == 200 + response_data = response.json() + assert response_data["session_id"] == 123 + assert len(response_data["messages"]) == 2 + assert response_data["messages"][0]["sender"] == "user" + assert response_data["messages"][1]["content"] == "Hi there!" + mock_services.rag_service.get_message_history.assert_called_once_with( + db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], + session_id=123 + ) + +def test_get_session_messages_not_found(client): + """Tests retrieving messages for a session that does not exist.""" + test_client, mock_services = client + mock_services.rag_service.get_message_history.return_value = None + + response = test_client.get("/sessions/999/messages") + + assert response.status_code == 404 + assert response.json()["detail"] == "Session with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/tests/api/routes/test_tts.py b/ai-hub/tests/api/routes/test_tts.py new file mode 100644 index 0000000..cd4f14e --- /dev/null +++ b/ai-hub/tests/api/routes/test_tts.py @@ -0,0 +1,46 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock + +@pytest.mark.asyncio +async def test_create_speech_response(async_client): + """Test the /speech endpoint returns audio bytes without streaming.""" + test_client, mock_services = await anext(async_client) + mock_audio_bytes = b"fake wav audio bytes" + + # Use AsyncMock for the async function create_speech_non_stream + mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes) + + response = await test_client.post("/speech", json={"text": "Hello, this is a test"}) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert response.content == mock_audio_bytes + + mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test") + +@pytest.mark.asyncio +async def test_create_speech_stream_response(async_client): + """Test the /speech endpoint with stream=true returns a streaming response.""" + test_client, mock_services = await anext(async_client) + mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"] + + # This async generator mock correctly simulates the streaming service + async def mock_async_generator(): + for chunk in mock_audio_bytes_chunks: + yield chunk + + # We mock `create_speech_stream` with a MagicMock returning the async generator + mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator()) + + response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"}) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + + # Read the streamed content and verify it matches the mocked chunks + streamed_content = b"" + async for chunk in response.aiter_bytes(): + streamed_content += chunk + + assert streamed_content == b"".join(mock_audio_bytes_chunks) + mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index c4da460..d19bb7a 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -1,277 +1,277 @@ -import pytest -from unittest.mock import MagicMock, AsyncMock -from fastapi import FastAPI, Response -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session -from datetime import datetime -from httpx import AsyncClient, ASGITransport -import asyncio +# import pytest +# from unittest.mock import MagicMock, AsyncMock +# from fastapi import FastAPI, Response +# from fastapi.testclient import TestClient +# from sqlalchemy.orm import Session +# from datetime import datetime +# from httpx import AsyncClient, ASGITransport +# import asyncio -# Import the dependencies and router factory -from app.api.dependencies import get_db, ServiceContainer -from app.core.services.rag import RAGService -from app.core.services.document import DocumentService -from app.core.services.tts import TTSService -from app.api.routes import create_api_router -from app.db import models +# # Import the dependencies and router factory +# from app.api.dependencies import get_db, ServiceContainer +# from app.core.services.rag import RAGService +# from app.core.services.document import DocumentService +# from app.core.services.tts import TTSService +# from app.api.routes import create_api_router +# from app.db import models -@pytest.fixture -def client(): - """ - Pytest fixture to create a TestClient with a fully mocked environment - for synchronous endpoints. - """ - test_app = FastAPI() +# @pytest.fixture +# def client(): +# """ +# Pytest fixture to create a TestClient with a fully mocked environment +# for synchronous endpoints. +# """ +# test_app = FastAPI() - mock_rag_service = MagicMock(spec=RAGService) - mock_document_service = MagicMock(spec=DocumentService) - mock_tts_service = MagicMock(spec=TTSService) +# mock_rag_service = MagicMock(spec=RAGService) +# mock_document_service = MagicMock(spec=DocumentService) +# mock_tts_service = MagicMock(spec=TTSService) - mock_services = MagicMock(spec=ServiceContainer) - mock_services.rag_service = mock_rag_service - mock_services.document_service = mock_document_service - mock_services.tts_service = mock_tts_service +# mock_services = MagicMock(spec=ServiceContainer) +# mock_services.rag_service = mock_rag_service +# mock_services.document_service = mock_document_service +# mock_services.tts_service = mock_tts_service - mock_db_session = MagicMock(spec=Session) +# mock_db_session = MagicMock(spec=Session) - def override_get_db(): - yield mock_db_session +# def override_get_db(): +# yield mock_db_session - api_router = create_api_router(services=mock_services) - test_app.dependency_overrides[get_db] = override_get_db - test_app.include_router(api_router) +# api_router = create_api_router(services=mock_services) +# test_app.dependency_overrides[get_db] = override_get_db +# test_app.include_router(api_router) - test_client = TestClient(test_app) +# test_client = TestClient(test_app) - yield test_client, mock_services +# yield test_client, mock_services -@pytest.fixture -async def async_client(): - """ - Pytest fixture to create an AsyncClient for testing async endpoints. - """ - test_app = FastAPI() +# @pytest.fixture +# async def async_client(): +# """ +# Pytest fixture to create an AsyncClient for testing async endpoints. +# """ +# test_app = FastAPI() - mock_rag_service = MagicMock(spec=RAGService) - mock_document_service = MagicMock(spec=DocumentService) - mock_tts_service = MagicMock(spec=TTSService) +# mock_rag_service = MagicMock(spec=RAGService) +# mock_document_service = MagicMock(spec=DocumentService) +# mock_tts_service = MagicMock(spec=TTSService) - mock_services = MagicMock(spec=ServiceContainer) - mock_services.rag_service = mock_rag_service - mock_services.document_service = mock_document_service - mock_services.tts_service = mock_tts_service +# mock_services = MagicMock(spec=ServiceContainer) +# mock_services.rag_service = mock_rag_service +# mock_services.document_service = mock_document_service +# mock_services.tts_service = mock_tts_service - mock_db_session = MagicMock(spec=Session) +# mock_db_session = MagicMock(spec=Session) - def override_get_db(): - yield mock_db_session +# def override_get_db(): +# yield mock_db_session - api_router = create_api_router(services=mock_services) - test_app.dependency_overrides[get_db] = override_get_db - test_app.include_router(api_router) +# api_router = create_api_router(services=mock_services) +# test_app.dependency_overrides[get_db] = override_get_db +# test_app.include_router(api_router) - async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: - yield client, mock_services +# async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: +# yield client, mock_services -# --- General Endpoint --- +# # --- General Endpoint --- -def test_read_root(client): - """Tests the root endpoint.""" - test_client, _ = client - response = test_client.get("/") - assert response.status_code == 200 - assert response.json() == {"status": "AI Model Hub is running!"} +# def test_read_root(client): +# """Tests the root endpoint.""" +# test_client, _ = client +# response = test_client.get("/") +# assert response.status_code == 200 +# assert response.json() == {"status": "AI Model Hub is running!"} -# --- Session and Chat Endpoints --- +# # --- Session and Chat Endpoints --- -def test_create_session_success(client): - """Tests successfully creating a new chat session.""" - test_client, mock_services = client - mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) - mock_services.rag_service.create_session.return_value = mock_session +# def test_create_session_success(client): +# """Tests successfully creating a new chat session.""" +# test_client, mock_services = client +# mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) +# mock_services.rag_service.create_session.return_value = mock_session - response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) +# response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) - assert response.status_code == 200 - assert response.json()["id"] == 1 - mock_services.rag_service.create_session.assert_called_once() +# assert response.status_code == 200 +# assert response.json()["id"] == 1 +# mock_services.rag_service.create_session.assert_called_once() -def test_chat_in_session_success(client): - """ - Tests sending a message in an existing session without specifying a model - or retriever. It should default to 'deepseek' and 'False'. - """ - test_client, mock_services = client - mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) +# def test_chat_in_session_success(client): +# """ +# Tests sending a message in an existing session without specifying a model +# or retriever. It should default to 'deepseek' and 'False'. +# """ +# test_client, mock_services = client +# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) - response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) +# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - assert response.status_code == 200 - assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} +# assert response.status_code == 200 +# assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - mock_services.rag_service.chat_with_rag.assert_called_once_with( - db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], - session_id=42, - prompt="Hello there", - model="deepseek", - load_faiss_retriever=False - ) +# mock_services.rag_service.chat_with_rag.assert_called_once_with( +# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], +# session_id=42, +# prompt="Hello there", +# model="deepseek", +# load_faiss_retriever=False +# ) -def test_chat_in_session_with_model_switch(client): - """ - Tests sending a message in an existing session and explicitly switching the model. - """ - test_client, mock_services = client - mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) +# def test_chat_in_session_with_model_switch(client): +# """ +# Tests sending a message in an existing session and explicitly switching the model. +# """ +# test_client, mock_services = client +# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) - response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) +# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - assert response.status_code == 200 - assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} +# assert response.status_code == 200 +# assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} - mock_services.rag_service.chat_with_rag.assert_called_once_with( - db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], - session_id=42, - prompt="Hello there, Gemini!", - model="gemini", - load_faiss_retriever=False - ) +# mock_services.rag_service.chat_with_rag.assert_called_once_with( +# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], +# session_id=42, +# prompt="Hello there, Gemini!", +# model="gemini", +# load_faiss_retriever=False +# ) -def test_chat_in_session_with_faiss_retriever(client): - """ - Tests sending a message and explicitly enabling the FAISS retriever. - """ - test_client, mock_services = client - mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) +# def test_chat_in_session_with_faiss_retriever(client): +# """ +# Tests sending a message and explicitly enabling the FAISS retriever. +# """ +# test_client, mock_services = client +# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) - response = test_client.post( - "/sessions/42/chat", - json={"prompt": "What is RAG?", "load_faiss_retriever": True} - ) +# response = test_client.post( +# "/sessions/42/chat", +# json={"prompt": "What is RAG?", "load_faiss_retriever": True} +# ) - assert response.status_code == 200 - assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} +# assert response.status_code == 200 +# assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} - mock_services.rag_service.chat_with_rag.assert_called_once_with( - db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], - session_id=42, - prompt="What is RAG?", - model="deepseek", - load_faiss_retriever=True - ) +# mock_services.rag_service.chat_with_rag.assert_called_once_with( +# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], +# session_id=42, +# prompt="What is RAG?", +# model="deepseek", +# load_faiss_retriever=True +# ) -def test_get_session_messages_success(client): - """Tests retrieving the message history for a session.""" - test_client, mock_services = client - mock_history = [ - models.Message(sender="user", content="Hello", created_at=datetime.now()), - models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) - ] - mock_services.rag_service.get_message_history.return_value = mock_history +# def test_get_session_messages_success(client): +# """Tests retrieving the message history for a session.""" +# test_client, mock_services = client +# mock_history = [ +# models.Message(sender="user", content="Hello", created_at=datetime.now()), +# models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) +# ] +# mock_services.rag_service.get_message_history.return_value = mock_history - response = test_client.get("/sessions/123/messages") +# response = test_client.get("/sessions/123/messages") - assert response.status_code == 200 - response_data = response.json() - assert response_data["session_id"] == 123 - assert len(response_data["messages"]) == 2 - assert response_data["messages"][0]["sender"] == "user" - assert response_data["messages"][1]["content"] == "Hi there!" - mock_services.rag_service.get_message_history.assert_called_once_with( - db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], - session_id=123 - ) +# assert response.status_code == 200 +# response_data = response.json() +# assert response_data["session_id"] == 123 +# assert len(response_data["messages"]) == 2 +# assert response_data["messages"][0]["sender"] == "user" +# assert response_data["messages"][1]["content"] == "Hi there!" +# mock_services.rag_service.get_message_history.assert_called_once_with( +# db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], +# session_id=123 +# ) -def test_get_session_messages_not_found(client): - """Tests retrieving messages for a session that does not exist.""" - test_client, mock_services = client - mock_services.rag_service.get_message_history.return_value = None +# def test_get_session_messages_not_found(client): +# """Tests retrieving messages for a session that does not exist.""" +# test_client, mock_services = client +# mock_services.rag_service.get_message_history.return_value = None - response = test_client.get("/sessions/999/messages") +# response = test_client.get("/sessions/999/messages") - assert response.status_code == 404 - assert response.json()["detail"] == "Session with ID 999 not found." +# assert response.status_code == 404 +# assert response.json()["detail"] == "Session with ID 999 not found." -# --- Document Endpoints --- +# # --- Document Endpoints --- -def test_add_document_success(client): - """Tests the /documents endpoint for adding a new document.""" - test_client, mock_services = client - mock_services.document_service.add_document.return_value = 123 - doc_payload = {"title": "Test Doc", "text": "Content here"} - response = test_client.post("/documents", json=doc_payload) - assert response.status_code == 200 - assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" +# def test_add_document_success(client): +# """Tests the /documents endpoint for adding a new document.""" +# test_client, mock_services = client +# mock_services.document_service.add_document.return_value = 123 +# doc_payload = {"title": "Test Doc", "text": "Content here"} +# response = test_client.post("/documents", json=doc_payload) +# assert response.status_code == 200 +# assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" -def test_get_documents_success(client): - """Tests the /documents endpoint for retrieving all documents.""" - test_client, mock_services = client - mock_docs = [ - models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), - models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) - ] - mock_services.document_service.get_all_documents.return_value = mock_docs - response = test_client.get("/documents") - assert response.status_code == 200 - assert len(response.json()["documents"]) == 2 +# def test_get_documents_success(client): +# """Tests the /documents endpoint for retrieving all documents.""" +# test_client, mock_services = client +# mock_docs = [ +# models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), +# models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) +# ] +# mock_services.document_service.get_all_documents.return_value = mock_docs +# response = test_client.get("/documents") +# assert response.status_code == 200 +# assert len(response.json()["documents"]) == 2 -def test_delete_document_success(client): - """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" - test_client, mock_services = client - mock_services.document_service.delete_document.return_value = 42 - response = test_client.delete("/documents/42") - assert response.status_code == 200 - assert response.json()["document_id"] == 42 +# def test_delete_document_success(client): +# """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" +# test_client, mock_services = client +# mock_services.document_service.delete_document.return_value = 42 +# response = test_client.delete("/documents/42") +# assert response.status_code == 200 +# assert response.json()["document_id"] == 42 -def test_delete_document_not_found(client): - """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" - test_client, mock_services = client - mock_services.document_service.delete_document.return_value = None - response = test_client.delete("/documents/999") - assert response.status_code == 404 +# def test_delete_document_not_found(client): +# """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" +# test_client, mock_services = client +# mock_services.document_service.delete_document.return_value = None +# response = test_client.delete("/documents/999") +# assert response.status_code == 404 -@pytest.mark.asyncio -async def test_create_speech_response(async_client): - """Test the /speech endpoint returns audio bytes.""" - test_client, mock_services = await anext(async_client) - mock_audio_bytes = b"fake wav audio bytes" +# @pytest.mark.asyncio +# async def test_create_speech_response(async_client): +# """Test the /speech endpoint returns audio bytes.""" +# test_client, mock_services = await anext(async_client) +# mock_audio_bytes = b"fake wav audio bytes" - # The route handler calls `create_speech_non_stream`, not `create_speech_stream` - # It's an async function, so we must use AsyncMock - mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes) +# # The route handler calls `create_speech_non_stream`, not `create_speech_stream` +# # It's an async function, so we must use AsyncMock +# mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes) - response = await test_client.post("/speech", json={"text": "Hello, this is a test"}) +# response = await test_client.post("/speech", json={"text": "Hello, this is a test"}) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - assert response.content == mock_audio_bytes +# assert response.status_code == 200 +# assert response.headers["content-type"] == "audio/wav" +# assert response.content == mock_audio_bytes - mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test") +# mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test") -@pytest.mark.asyncio -async def test_create_speech_stream_response(async_client): - """Test the consolidated /speech endpoint with stream=true returns a streaming response.""" - test_client, mock_services = await anext(async_client) - mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"] +# @pytest.mark.asyncio +# async def test_create_speech_stream_response(async_client): +# """Test the consolidated /speech endpoint with stream=true returns a streaming response.""" +# test_client, mock_services = await anext(async_client) +# mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"] - # This async generator mock correctly simulates the streaming service - async def mock_async_generator(): - for chunk in mock_audio_bytes_chunks: - yield chunk +# # This async generator mock correctly simulates the streaming service +# async def mock_async_generator(): +# for chunk in mock_audio_bytes_chunks: +# yield chunk - # We mock `create_speech_stream` with a MagicMock returning the async generator - mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator()) +# # We mock `create_speech_stream` with a MagicMock returning the async generator +# mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator()) - # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter - response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"}) +# # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter +# response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"}) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" +# assert response.status_code == 200 +# assert response.headers["content-type"] == "audio/wav" - # Read the streamed content and verify it matches the mocked chunks - streamed_content = b"" - async for chunk in response.aiter_bytes(): - streamed_content += chunk +# # Read the streamed content and verify it matches the mocked chunks +# streamed_content = b"" +# async for chunk in response.aiter_bytes(): +# streamed_content += chunk - assert streamed_content == b"".join(mock_audio_bytes_chunks) - mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file +# assert streamed_content == b"".join(mock_audio_bytes_chunks) +# mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file