diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py
index e1aecde..a1e497f 100644
--- a/ai-hub/app/api/dependencies.py
+++ b/ai-hub/app/api/dependencies.py
@@ -7,6 +7,7 @@
 from app.core.services.document import DocumentService
 from app.core.services.rag import RAGService
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService
 from app.core.vector_store.faiss_store import FaissVectorStore
@@ -27,10 +28,11 @@ class ServiceContainer:
-    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService):
+    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService, stt_service: STTService):
         # Initialize all services within the container
         self.document_service = DocumentService(vector_store=vector_store)
         self.rag_service = RAGService(
             retrievers=retrievers
         )
         self.tts_service = tts_service
+        self.stt_service = stt_service
diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py
deleted file mode 100644
index fd05079..0000000
--- a/ai-hub/app/api/routes.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# from fastapi import APIRouter, HTTPException, Depends, Query
-# from fastapi.responses import Response, StreamingResponse
-# from sqlalchemy.orm import Session
-# from app.api.dependencies import ServiceContainer, get_db
-# from app.api import schemas
-# from typing import AsyncGenerator
-
-# def create_api_router(services: ServiceContainer) -> APIRouter:
-#     """
-#     Creates and returns an APIRouter with all the application's endpoints.
-#     """
-#     router = APIRouter()
-
-#     @router.get("/", summary="Check Service Status", tags=["General"])
-#     def read_root():
-#         return {"status": "AI Model Hub is running!"}
-
-#     # --- Session Management Endpoints ---
-
-#     @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"])
-#     def create_session(
-#         request: schemas.SessionCreate,
-#         db: Session = Depends(get_db)
-#     ):
-#         try:
-#             new_session = services.rag_service.create_session(
-#                 db=db,
-#                 user_id=request.user_id,
-#                 model=request.model
-#             )
-#             return new_session
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"Failed to create session: {e}")
-
-#     @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"])
-#     async def chat_in_session(
-#         session_id: int,
-#         request: schemas.ChatRequest,
-#         db: Session = Depends(get_db)
-#     ):
-#         try:
-#             response_text, model_used = await services.rag_service.chat_with_rag(
-#                 db=db,
-#                 session_id=session_id,
-#                 prompt=request.prompt,
-#                 model=request.model,
-#                 load_faiss_retriever=request.load_faiss_retriever
-#             )
-#             return schemas.ChatResponse(answer=response_text, model_used=model_used)
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}")
-
-#     @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"])
-#     def get_session_messages(session_id: int, db: Session = Depends(get_db)):
-#         try:
-#             messages = services.rag_service.get_message_history(db=db, session_id=session_id)
-#             if messages is None:
-#                 raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")
-
-#             return schemas.MessageHistoryResponse(session_id=session_id, messages=messages)
-#         except HTTPException:
-#             raise
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-#     # --- Document Management Endpoints ---
-
-#     @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"])
-#     def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)):
-#         try:
-#             doc_data = doc.model_dump()
-#             document_id = services.document_service.add_document(db=db, doc_data=doc_data)
-#             return schemas.DocumentResponse(
-#                 message=f"Document '{doc.title}' added successfully with ID {document_id}"
-#             )
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-#     @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"])
-#     def get_documents(db: Session = Depends(get_db)):
-#         try:
-#             documents_from_db = services.document_service.get_all_documents(db=db)
-#             return {"documents": documents_from_db}
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-#     @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"])
-#     def delete_document(document_id: int, db: Session = Depends(get_db)):
-#         try:
-#             deleted_id = services.document_service.delete_document(db=db, document_id=document_id)
-#             if deleted_id is None:
-#                 raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.")
-
-#             return schemas.DocumentDeleteResponse(
-#                 message="Document deleted successfully",
-#                 document_id=deleted_id
-#             )
-#         except HTTPException:
-#             raise
-#         except Exception as e:
-#             raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
-
-#     # --- Consolidated Speech Endpoint ---
-
-#     @router.post(
-#         "/speech",
-#         summary="Generate speech from text",
-#         tags=["TTS"],
-#         response_description="Audio bytes in WAV format, either as a complete file or a stream.",
-#     )
-#     async def create_speech_response(
-#         request: schemas.SpeechRequest,
-#         stream: bool = Query(
-#             False,
-#             description="If true, returns a streamed audio response. Otherwise, returns a complete file."
-#         )
-#     ):
-#         """
-#         Generates an audio file or a streaming audio response from the provided text.
-#         By default, it returns a complete audio file.
-#         To get a streaming response, set the 'stream' query parameter to 'true'.
-#         """
-#         try:
-#             if stream:
-#                 # Use the streaming service method
-#                 audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream(
-#                     text=request.text
-#                 )
-#                 return StreamingResponse(audio_stream_generator, media_type="audio/wav")
-#             else:
-#                 # Use the non-streaming service method
-#                 audio_bytes = await services.tts_service.create_speech_non_stream(
-#                     text=request.text
-#                 )
-#                 return Response(content=audio_bytes, media_type="audio/wav")
-
-#         except HTTPException:
-#             raise  # Re-raise existing HTTPException
-#         except Exception as e:
-#             raise HTTPException(
-#                 status_code=500, detail=f"Failed to generate speech: {e}"
-#             )
-
-#     return router
\ No newline at end of file
diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py
index b052e7d..69862b0 100644
--- a/ai-hub/app/api/routes/api.py
+++ b/ai-hub/app/api/routes/api.py
@@ -6,6 +6,7 @@
 from .sessions import create_sessions_router
 from .documents import create_documents_router
 from .tts import create_tts_router
+from .stt import create_stt_router
 
 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -18,5 +19,6 @@
     router.include_router(create_sessions_router(services))
     router.include_router(create_documents_router(services))
     router.include_router(create_tts_router(services))
+    router.include_router(create_stt_router(services))
 
     return router
\ No newline at end of file
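Note: app/core/services/stt.py itself is not part of this diff. The container and router changes above only assume an STTService exposing an async transcribe(bytes) -> str method that delegates to a configured provider. A minimal sketch of that implied interface follows; the provider protocol and its transcribe_audio method name are assumptions, not confirmed by the diff:

from typing import Protocol

class STTProviderProtocol(Protocol):
    """Assumed shape of a concrete STT provider (not shown in this diff)."""
    async def transcribe_audio(self, audio_bytes: bytes) -> str: ...

class STTService:
    """Thin wrapper so routes depend on a stable interface rather than a vendor SDK."""

    def __init__(self, stt_provider: STTProviderProtocol):
        self.stt_provider = stt_provider

    async def transcribe(self, audio_bytes: bytes) -> str:
        # Delegate to whichever concrete provider the factory selected
        return await self.stt_provider.transcribe_audio(audio_bytes)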
diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py
new file mode 100644
index 0000000..d99220e
--- /dev/null
+++ b/ai-hub/app/api/routes/stt.py
@@ -0,0 +1,59 @@
+import logging
+from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
+from app.api.dependencies import ServiceContainer
+from app.api import schemas
+from app.core.services.stt import STTService
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+def create_stt_router(services: ServiceContainer) -> APIRouter:
+    """
+    Creates and configures the API router for Speech-to-Text (STT) functionality.
+    """
+    router = APIRouter(prefix="/stt", tags=["STT"])
+
+    @router.post(
+        "/transcribe",
+        summary="Transcribe audio to text",
+        response_description="The transcribed text from the audio file.",
+        response_model=schemas.STTResponse
+    )
+    async def transcribe_audio_to_text(
+        audio_file: UploadFile = File(...)
+    ):
+        """
+        Transcribes an uploaded audio file into text using the configured STT service.
+
+        The audio file is expected to be in a common audio format such as WAV or MP3,
+        though the specific provider implementation determines the supported formats.
+        """
+        logger.info(f"Received transcription request for file: {audio_file.filename}")
+
+        if not audio_file.content_type.startswith("audio/"):
+            logger.warning(f"Invalid file type uploaded: {audio_file.content_type}")
+            raise HTTPException(
+                status_code=415,
+                detail="Unsupported media type. Please upload an audio file."
+            )
+
+        try:
+            # Read the audio bytes from the uploaded file
+            audio_bytes = await audio_file.read()
+
+            # Use the STT service to get the transcript
+            transcript = await services.stt_service.transcribe(audio_bytes)
+
+            # Return the transcript in a simple JSON response
+            return schemas.STTResponse(transcript=transcript)
+
+        except HTTPException:
+            # Re-raise FastAPI exceptions so they're handled correctly
+            raise
+        except Exception as e:
+            logger.error(f"Failed to transcribe audio file: {e}")
+            raise HTTPException(
+                status_code=500, detail=f"Failed to transcribe audio: {e}"
+            ) from e
+
+    return router
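Once the router is mounted, the endpoint can be exercised with any multipart-capable client. A sketch using httpx; the host, port, and sample file are placeholders, and the multipart field name must match the audio_file parameter above:

import httpx

with open("sample.wav", "rb") as f:
    response = httpx.post(
        "http://localhost:8000/stt/transcribe",
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )
response.raise_for_status()
print(response.json())  # e.g. {"transcript": "..."}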
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index 31fb342..a1eff32 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here
+from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Literal, Optional
 from datetime import datetime
@@ -76,4 +76,9 @@
     messages: List[Message]
 
 class SpeechRequest(BaseModel):
-    text: str
\ No newline at end of file
+    text: str
+
+# --- STT Schemas ---
+class STTResponse(BaseModel):
+    """Defines the shape of a successful response from the /stt endpoint."""
+    transcript: str
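The route and schema can be verified together without a real provider by stubbing the service layer. A test sketch; the stub classes are illustrative and not part of the codebase:

from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.api.routes.stt import create_stt_router

class StubSTTService:
    async def transcribe(self, audio_bytes: bytes) -> str:
        return "hello world"

class StubContainer:
    stt_service = StubSTTService()

app = FastAPI()
app.include_router(create_stt_router(StubContainer()))
client = TestClient(app)

def test_transcribe_returns_transcript():
    resp = client.post(
        "/stt/transcribe",
        files={"audio_file": ("a.wav", b"\x00\x01", "audio/wav")},
    )
    assert resp.status_code == 200
    assert resp.json() == {"transcript": "hello world"}

def test_rejects_non_audio_uploads():
    resp = client.post(
        "/stt/transcribe",
        files={"audio_file": ("a.txt", b"not audio", "text/plain")},
    )
    assert resp.status_code == 415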
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index 3afc104..a4ea604 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -1,3 +1,4 @@
+# app/app.py
 from fastapi import FastAPI
 from contextlib import asynccontextmanager
 from typing import List
@@ -6,7 +7,7 @@
 from app.config import settings
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.vector_store.embedder.factory import get_embedder_from_config
-from app.core.providers.factory import get_tts_provider
+from app.core.providers.factory import get_tts_provider, get_stt_provider
 from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 from app.core.retrievers.base_retriever import Retriever
 from app.db.session import create_db_and_tables
@@ -14,6 +15,7 @@
 from app.utils import print_config
 from app.api.dependencies import ServiceContainer
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService  # NEW: Added the missing import for STTService
 # Note: The llm_clients import and initialization are removed as they
 # are not used in RAGService's constructor based on your services.py
 # from app.core.llm_clients import DeepSeekClient, GeminiClient
@@ -83,12 +85,23 @@
     # 5. Initialize the TTSService
     tts_service = TTSService(tts_provider=tts_provider)
+
+    # 6. Get the concrete STT provider from the factory
+    stt_provider = get_stt_provider(
+        provider_name=settings.STT_PROVIDER,
+        api_key=settings.STT_API_KEY,
+        model_name=settings.STT_MODEL_NAME
+    )
+    # 7. Initialize the STTService
+    stt_service = STTService(stt_provider=stt_provider)
 
-    # 6. Initialize the Service Container with all services
+    # 8. Initialize the Service Container with all services
+    # This replaces the previous, redundant initialization
     services = ServiceContainer(
-        vector_store=vector_store,
+        vector_store=vector_store,
         retrievers=retrievers,
-        tts_service=tts_service # Pass the new TTS service instance
+        tts_service=tts_service,
+        stt_service=stt_service  # NEW: Pass the new STT service instance
     )
 
     # Create and include the API router, injecting the service
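The get_stt_provider factory imported here is not included in this diff. Based on the call site (provider_name, api_key, model_name) and the STTProvider enum added in config.py below, the dispatch presumably looks something like the following sketch; the concrete provider classes are hypothetical stand-ins for the real implementations:

from dataclasses import dataclass
from enum import Enum

class STTProvider(str, Enum):  # mirrors the enum added in app/config.py
    GOOGLE_GEMINI = "google_gemini"
    OPENAI = "openai"

@dataclass
class GeminiSTTProvider:  # hypothetical stand-in
    api_key: str
    model_name: str

@dataclass
class OpenAISTTProvider:  # hypothetical stand-in
    api_key: str
    model_name: str

def get_stt_provider(provider_name: STTProvider, api_key: str, model_name: str):
    """Dispatch to a concrete STT provider based on configuration."""
    if provider_name == STTProvider.GOOGLE_GEMINI:
        return GeminiSTTProvider(api_key=api_key, model_name=model_name)
    if provider_name == STTProvider.OPENAI:
        return OpenAISTTProvider(api_key=api_key, model_name=model_name)
    raise ValueError(f"Unsupported STT provider: {provider_name}")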
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
         self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
             get_from_yaml(["tts_provider", "api_key"]) or \
             self.GEMINI_API_KEY
 
+        # --- NEW STT Provider Settings ---
+        stt_provider_env = os.getenv("STT_PROVIDER")
+        if stt_provider_env:
+            stt_provider_env = stt_provider_env.lower()
+
+        self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \
+            get_from_yaml(["stt_provider", "provider"]) or \
+            config_from_pydantic.stt_provider.provider)
+        self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \
+            get_from_yaml(["stt_provider", "model_name"]) or \
+            config_from_pydantic.stt_provider.model_name
+
+        # Logic for STT_API_KEY: prioritize a dedicated STT_API_KEY.
+        # Fall back to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY.
+        explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"])
+
+        if explicit_stt_api_key:
+            self.STT_API_KEY: Optional[str] = explicit_stt_api_key
+        elif self.STT_PROVIDER == STTProvider.OPENAI:
+            self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY
+        else:
+            # Fallback for Google Gemini or other providers
+            self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY
+
 # Instantiate the single settings object for the application
-settings = Settings()
+settings = Settings()
\ No newline at end of file
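Two pieces referenced above never appear in these diffs: the STTProvider base class (imported from .base in factory.py, see below) and the STTService wrapper in app/core/services/stt.py. A minimal sketch of the shapes the call sites imply (STTService(stt_provider=...) in app.py, await services.stt_service.transcribe(audio_bytes) in the route); the provider method name transcribe_audio is an assumption, not part of the diffs:

# Sketch of the STT service layer, inferred from the call sites above.
# The provider method name transcribe_audio is an assumption.
from abc import ABC, abstractmethod

class STTProvider(ABC):
    """Abstract interface a concrete STT provider is presumed to implement."""

    @abstractmethod
    async def transcribe_audio(self, audio_bytes: bytes) -> str:
        """Return the transcript for the given raw audio bytes."""

class STTService:
    """Thin wrapper that delegates transcription to the configured provider."""

    def __init__(self, stt_provider: STTProvider):
        self.stt_provider = stt_provider

    async def transcribe(self, audio_bytes: bytes) -> str:
        # The route handler awaits this, so it must be a coroutine.
        return await self.stt_provider.transcribe_audio(audio_bytes)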
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. -# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. 
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ get_from_yaml(["tts_provider", "api_key"]) or \ self.GEMINI_API_KEY + # --- NEW STT Provider Settings --- + stt_provider_env = os.getenv("STT_PROVIDER") + if stt_provider_env: + stt_provider_env = stt_provider_env.lower() + + self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \ + get_from_yaml(["stt_provider", "provider"]) or \ + config_from_pydantic.stt_provider.provider) + self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \ + get_from_yaml(["stt_provider", "model_name"]) or \ + config_from_pydantic.stt_provider.model_name + + # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY. + # Fallback to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY. + explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"]) + + if explicit_stt_api_key: + self.STT_API_KEY: Optional[str] = explicit_stt_api_key + elif self.STT_PROVIDER == STTProvider.OPENAI: + self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY + else: + # Fallback for Google Gemini or other providers + self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY + # Instantiate the single settings object for the application -settings = Settings() +settings = Settings() \ No newline at end of file diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 49b30f7..de9d8dd 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -39,3 +39,12 @@ voice_name: "Kore" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" + +# The provider for the Speech-to-Text (STT) service. +stt_provider: + # The provider can be "google_gemini" or "openai". + provider: "google_gemini" + # The model name for the STT service. + # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash"). + # For "openai" this would be a Whisper model (e.g., "whisper-1"). 
+ model_name: "gemini-2.5-flash" \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1aecde..a1e497f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -7,6 +7,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.services.tts import TTSService +from app.core.services.stt import STTService from app.core.vector_store.faiss_store import FaissVectorStore @@ -27,10 +28,11 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService, stt_service: STTService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers ) self.tts_service = tts_service + self.stt_service = stt_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py deleted file mode 100644 index fd05079..0000000 --- a/ai-hub/app/api/routes.py +++ /dev/null @@ -1,144 +0,0 @@ -# from fastapi import APIRouter, HTTPException, Depends, Query -# from fastapi.responses import Response, StreamingResponse -# from sqlalchemy.orm import Session -# from app.api.dependencies import ServiceContainer, get_db -# from app.api import schemas -# from typing import AsyncGenerator - -# def create_api_router(services: ServiceContainer) -> APIRouter: -# """ -# Creates and returns an APIRouter with all the application's endpoints. -# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except 
HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. 
-# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. + """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." 
+ ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index 49b30f7..de9d8dd 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -39,3 +39,12 @@
   voice_name: "Kore"
   # The model name for the TTS service.
   model_name: "gemini-2.5-flash-preview-tts"
+
+# The provider for the Speech-to-Text (STT) service.
+stt_provider:
+  # The provider can be "google_gemini" or "openai".
+  provider: "google_gemini"
+  # The model name for the STT service.
+  # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash").
+  # For "openai" this would be a Whisper model (e.g., "whisper-1").
+  model_name: "gemini-2.5-flash"
\ No newline at end of file
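With the configuration above in place, the endpoint can be exercised end to end. A minimal client sketch, assuming the hub runs locally on port 8000 and a sample.wav file exists; the multipart field name must match the handler's audio_file parameter:

# Hypothetical client call against a locally running hub.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/stt/transcribe",
        # The field name must match the endpoint's parameter name: audio_file.
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )

resp.raise_for_status()
print(resp.json()["transcript"])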
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 53f3651..e3aec9b 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,8 +1,9 @@
 from app.config import settings
-from .base import LLMProvider,TTSProvider
+from .base import LLMProvider, TTSProvider, STTProvider
 from .llm.deepseek import DeepSeekProvider
 from .llm.gemini import GeminiProvider
 from .tts.gemini import GeminiTTSProvider
+from .stt.gemini import GoogleSTTProvider
 from openai import AsyncOpenAI
 
 # --- 1. Initialize API Clients from Central Config ---
@@ -15,7 +16,7 @@
     "gemini": GeminiProvider(api_url=GEMINI_URL)
 }
 
-# --- 3. The Factory Function ---
+# --- 3. The Factory Functions ---
 def get_llm_provider(model_name: str) -> LLMProvider:
     """Factory function to get the appropriate, pre-configured LLM provider."""
     provider = _llm_providers.get(model_name)
@@ -26,5 +27,9 @@
 def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
     if provider_name == "google_genai":
         return GeminiTTSProvider(api_key=api_key)
-    # Add other TTS providers here if needed
-    raise ValueError(f"Unknown TTS provider: {provider_name}")
\ No newline at end of file
+    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
+
+def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
+    if provider_name == "google_gemini":
+        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
+    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
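One inconsistency worth flagging: config.py adds OPENAI to the STTProvider enum and resolves OPENAI_API_KEY, but the factory above only handles "google_gemini", so configuring "openai" currently fails at startup with a ValueError. A sketch of the missing branch, assuming a hypothetical OpenAISTTProvider with the same constructor shape as GoogleSTTProvider:

# Hypothetical extension of get_stt_provider; OpenAISTTProvider and its
# module path are assumptions, not part of the diffs above.
from .stt.openai import OpenAISTTProvider  # assumed module

def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    if provider_name == "google_gemini":
        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
    if provider_name == "openai":
        # model_name would be a Whisper model such as "whisper-1" (see config.yaml).
        return OpenAISTTProvider(api_key=api_key, model_name=model_name)
    raise ValueError(
        f"Unsupported STT provider: '{provider_name}'. "
        f"Supported providers are: ['google_gemini', 'openai']"
    )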
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. -# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. 
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index 31fb342..a1eff32 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here
+from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Literal, Optional
 from datetime import datetime
@@ -76,4 +76,9 @@
     messages: List[Message]

 class SpeechRequest(BaseModel):
-    text: str
\ No newline at end of file
+    text: str
+
+# --- STT Schemas ---
+class STTResponse(BaseModel):
+    """Defines the shape of a successful response from the /stt/transcribe endpoint."""
+    transcript: str
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index 3afc104..a4ea604 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -1,3 +1,4 @@
+# app/app.py
 from fastapi import FastAPI
 from contextlib import asynccontextmanager
 from typing import List
@@ -6,7 +7,7 @@
 from app.config import settings
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.vector_store.embedder.factory import get_embedder_from_config
-from app.core.providers.factory import get_tts_provider
+from app.core.providers.factory import get_tts_provider, get_stt_provider
 from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
 from app.core.retrievers.base_retriever import Retriever
 from app.db.session import create_db_and_tables
@@ -14,6 +15,7 @@
 from app.utils import print_config
 from app.api.dependencies import ServiceContainer
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService  # NEW: previously missing import
 # Note: The llm_clients import and initialization are removed as they
 # are not used in RAGService's constructor based on your services.py
 # from app.core.llm_clients import DeepSeekClient, GeminiClient
@@ -83,12 +85,23 @@
     # 5. Initialize the TTSService
     tts_service = TTSService(tts_provider=tts_provider)
+
+    # 6. Get the concrete STT provider from the factory
+    stt_provider = get_stt_provider(
+        provider_name=settings.STT_PROVIDER,
+        api_key=settings.STT_API_KEY,
+        model_name=settings.STT_MODEL_NAME
+    )
+    # 7. Initialize the STTService
+    stt_service = STTService(stt_provider=stt_provider)

-    # 6. Initialize the Service Container with all services
+    # 8. Initialize the Service Container with all services
     services = ServiceContainer(
-        vector_store=vector_store,
+        vector_store=vector_store,
         retrievers=retrievers,
-        tts_service=tts_service # Pass the new TTS service instance
+        tts_service=tts_service,
+        stt_service=stt_service  # NEW: pass the STT service instance
     )

     # Create and include the API router, injecting the service
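The lifespan code above constructs `STTService(stt_provider=...)` and the route awaits `services.stt_service.transcribe(audio_bytes)`, but `app/core/services/stt.py` itself is not included in this diff. A minimal sketch consistent with those two call sites; only the constructor and `transcribe()` signatures are pinned down by the diff, and the provider-side method name `transcribe_audio` is an assumption:

```python
# app/core/services/stt.py -- hypothetical sketch, not the recorded file.
from app.core.providers.base import STTProvider

class STTService:
    """Thin orchestration layer between the API route and the STT provider."""

    def __init__(self, stt_provider: STTProvider):
        self.stt_provider = stt_provider

    async def transcribe(self, audio_bytes: bytes) -> str:
        # Delegate to whichever concrete provider the factory selected.
        return await self.stt_provider.transcribe_audio(audio_bytes)
```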
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index 65c69c9..f2de94e 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -19,6 +19,11 @@
     """An enum for supported Text-to-Speech (TTS) providers."""
     GOOGLE_GENAI = "google_genai"

+class STTProvider(str, Enum):
+    """An enum for supported Speech-to-Text (STT) providers."""
+    GOOGLE_GEMINI = "google_gemini"
+    OPENAI = "openai"  # NEW: OpenAI as a supported provider
+
 class ApplicationSettings(BaseModel):
     project_name: str = "Cortex Hub"
     version: str = "1.0.0"
@@ -44,6 +49,14 @@
     model_name: str = "gemini-2.5-flash-preview-tts"
     api_key: Optional[SecretStr] = None

+class STTProviderSettings(BaseModel):
+    provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI)
+    model_name: str = "gemini-2.5-flash"
+    api_key: Optional[SecretStr] = None
+    # NOTE: The OpenAI provider expects a different model name (e.g. 'whisper-1');
+    # that is handled through configuration, since this BaseModel only
+    # validates the schema and carries no provider-specific logic.
+
 class VectorStoreSettings(BaseModel):
     index_path: str = "data/faiss_index.bin"
     embedding_dimension: int = 768
@@ -56,6 +69,7 @@
     vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings)
     embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings)
     tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings)
+    stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings)

 # --- 2. Create the Final Settings Object ---
@@ -96,8 +110,8 @@
         config_from_pydantic.database.mode
         local_db_path = os.getenv("LOCAL_DB_PATH") or \
-            get_from_yaml(["database", "local_path"]) or \
-            config_from_pydantic.database.local_path
+                        get_from_yaml(["database", "local_path"]) or \
+                        config_from_pydantic.database.local_path
         external_db_url = os.getenv("DATABASE_URL") or \
             get_from_yaml(["database", "url"]) or \
             config_from_pydantic.database.url
@@ -111,21 +125,22 @@
         # --- API Keys & Models ---
         self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY")
         self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
+        self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")  # NEW: dedicated OpenAI API key
         self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
-            config_from_pydantic.llm_providers.deepseek_model_name
+                                        get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
+                                        config_from_pydantic.llm_providers.deepseek_model_name
         self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "gemini_model_name"]) or \
-            config_from_pydantic.llm_providers.gemini_model_name
+                                      get_from_yaml(["llm_providers", "gemini_model_name"]) or \
+                                      config_from_pydantic.llm_providers.gemini_model_name

         # --- Vector Store Settings ---
         self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
-            get_from_yaml(["vector_store", "index_path"]) or \
-            config_from_pydantic.vector_store.index_path
+                                     get_from_yaml(["vector_store", "index_path"]) or \
+                                     config_from_pydantic.vector_store.index_path
         dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
-            get_from_yaml(["vector_store", "embedding_dimension"]) or \
-            config_from_pydantic.vector_store.embedding_dimension
+                        get_from_yaml(["vector_store", "embedding_dimension"]) or \
+                        config_from_pydantic.vector_store.embedding_dimension
         self.EMBEDDING_DIMENSION: int = int(dimension_str)

         # --- Embedding Provider Settings ---
@@ -137,13 +152,12 @@
             get_from_yaml(["embedding_provider", "provider"]) or \
             config_from_pydantic.embedding_provider.provider)
         self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
-            get_from_yaml(["embedding_provider", "model_name"]) or \
-            config_from_pydantic.embedding_provider.model_name
+                                         get_from_yaml(["embedding_provider", "model_name"]) or \
+                                         config_from_pydantic.embedding_provider.model_name

-        # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \
-            get_from_yaml(["embedding_provider", "api_key"]) or \
-            self.GEMINI_API_KEY
+                                                get_from_yaml(["embedding_provider", "api_key"]) or \
+                                                self.GEMINI_API_KEY

         # --- TTS Provider Settings ---
         tts_provider_env = os.getenv("TTS_PROVIDER")
@@ -151,8 +165,8 @@
             tts_provider_env = tts_provider_env.lower()

         self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \
-            get_from_yaml(["tts_provider", "provider"]) or \
-            config_from_pydantic.tts_provider.provider)
+                                                     get_from_yaml(["tts_provider", "provider"]) or \
+                                                     config_from_pydantic.tts_provider.provider)
         self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
             get_from_yaml(["tts_provider", "voice_name"]) or \
             config_from_pydantic.tts_provider.voice_name
@@ -160,10 +174,33 @@
             get_from_yaml(["tts_provider", "model_name"]) or \
             config_from_pydantic.tts_provider.model_name

-        # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
             get_from_yaml(["tts_provider", "api_key"]) or \
             self.GEMINI_API_KEY

+        # --- NEW STT Provider Settings ---
+        stt_provider_env = os.getenv("STT_PROVIDER")
+        if stt_provider_env:
+            stt_provider_env = stt_provider_env.lower()
+
+        self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \
+            get_from_yaml(["stt_provider", "provider"]) or \
+            config_from_pydantic.stt_provider.provider)
+        self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \
+            get_from_yaml(["stt_provider", "model_name"]) or \
+            config_from_pydantic.stt_provider.model_name
+
+        # STT_API_KEY resolution: prefer a dedicated STT_API_KEY; otherwise fall
+        # back to OPENAI_API_KEY when the provider is OpenAI, else GEMINI_API_KEY.
+        explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"])
+
+        if explicit_stt_api_key:
+            self.STT_API_KEY: Optional[str] = explicit_stt_api_key
+        elif self.STT_PROVIDER == STTProvider.OPENAI:
+            self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY
+        else:
+            # Fallback for Google Gemini or other providers
+            self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY
+
 # Instantiate the single settings object for the application
-settings = Settings()
+settings = Settings()
\ No newline at end of file
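The STT key fallback chain is easiest to verify with a concrete example. A self-contained sketch of the precedence implemented above (the YAML lookup is elided; variable values are illustrative only):

```python
# Illustration of the STT_API_KEY precedence; mirrors the logic above but is
# not the recorded implementation (no get_from_yaml step here).
import os
from typing import Optional

def resolve_stt_api_key(provider: str) -> Optional[str]:
    explicit = os.getenv("STT_API_KEY")      # 1. dedicated key always wins
    if explicit:
        return explicit
    if provider == "openai":                 # 2. provider-specific fallback
        return os.getenv("OPENAI_API_KEY")
    return os.getenv("GEMINI_API_KEY")       # 3. default Gemini fallback

# e.g. with only OPENAI_API_KEY exported and the provider set to "openai":
os.environ.pop("STT_API_KEY", None)
os.environ["OPENAI_API_KEY"] = "sk-example"  # placeholder value
assert resolve_stt_api_key("openai") == "sk-example"
```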
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index 49b30f7..de9d8dd 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -39,3 +39,12 @@
   voice_name: "Kore"
   # The model name for the TTS service.
   model_name: "gemini-2.5-flash-preview-tts"
+
+# The provider for the Speech-to-Text (STT) service.
+stt_provider:
+  # The provider can be "google_gemini" or "openai".
+  provider: "google_gemini"
+  # The model name for the STT service.
+  # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash").
+  # For "openai" this would be a Whisper model (e.g., "whisper-1").
+  model_name: "gemini-2.5-flash"
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 53f3651..e3aec9b 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,8 +1,9 @@
 from app.config import settings
-from .base import LLMProvider,TTSProvider
+from .base import LLMProvider, TTSProvider, STTProvider
 from .llm.deepseek import DeepSeekProvider
 from .llm.gemini import GeminiProvider
 from .tts.gemini import GeminiTTSProvider
+from .stt.gemini import GoogleSTTProvider
 from openai import AsyncOpenAI

 # --- 1. Initialize API Clients from Central Config ---
@@ -15,7 +16,7 @@
     "gemini": GeminiProvider(api_url=GEMINI_URL)
 }

-# --- 3. The Factory Function ---
+# --- 3. The Factory Functions ---
 def get_llm_provider(model_name: str) -> LLMProvider:
     """Factory function to get the appropriate, pre-configured LLM provider."""
     provider = _llm_providers.get(model_name)
@@ -26,5 +27,9 @@
 def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
     if provider_name == "google_genai":
         return GeminiTTSProvider(api_key=api_key)
-    # Add other TTS providers here if needed
-    raise ValueError(f"Unknown TTS provider: {provider_name}")
\ No newline at end of file
+    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
+
+def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
+    if provider_name == "google_gemini":
+        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
+    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
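Two observations on the factory. First, `STTProvider` (from `.base`) and `GoogleSTTProvider` (from `.stt.gemini`) are referenced here, but neither file appears in this diff. Second, the config enum and config.yaml advertise an "openai" option that `get_stt_provider` does not yet handle, so selecting it will raise `ValueError` at startup. A minimal sketch of the base contract the factory and service presumably rely on, mirroring the existing TTS pattern (the `transcribe_audio` method name is an assumption):

```python
# app/core/providers/base.py (STT portion) -- hypothetical sketch; the
# real abstract class is not shown in the diff.
from abc import ABC, abstractmethod

class STTProvider(ABC):
    """Contract every concrete speech-to-text provider must fulfil."""

    @abstractmethod
    async def transcribe_audio(self, audio_bytes: bytes) -> str:
        """Return the transcript for the given raw audio bytes."""
        raise NotImplementedError
```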
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt
index 276e3a2..c455aa1 100644
--- a/ai-hub/requirements.txt
+++ b/ai-hub/requirements.txt
@@ -15,4 +15,5 @@
 numpy
 faiss-cpu
 dspy
-aioresponses
\ No newline at end of file
+aioresponses
+python-multipart
\ No newline at end of file
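`python-multipart` is a runtime dependency here, not a test helper: FastAPI requires it to parse `multipart/form-data` requests, and without it the `UploadFile`/`File(...)` parameter on `/stt/transcribe` fails when the route is defined at application startup.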
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. -# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. 
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ get_from_yaml(["tts_provider", "api_key"]) or \ self.GEMINI_API_KEY + # --- NEW STT Provider Settings --- + stt_provider_env = os.getenv("STT_PROVIDER") + if stt_provider_env: + stt_provider_env = stt_provider_env.lower() + + self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \ + get_from_yaml(["stt_provider", "provider"]) or \ + config_from_pydantic.stt_provider.provider) + self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \ + get_from_yaml(["stt_provider", "model_name"]) or \ + config_from_pydantic.stt_provider.model_name + + # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY. + # Fallback to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY. + explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"]) + + if explicit_stt_api_key: + self.STT_API_KEY: Optional[str] = explicit_stt_api_key + elif self.STT_PROVIDER == STTProvider.OPENAI: + self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY + else: + # Fallback for Google Gemini or other providers + self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY + # Instantiate the single settings object for the application -settings = Settings() +settings = Settings() \ No newline at end of file diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 49b30f7..de9d8dd 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -39,3 +39,12 @@ voice_name: "Kore" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" + +# The provider for the Speech-to-Text (STT) service. +stt_provider: + # The provider can be "google_gemini" or "openai". + provider: "google_gemini" + # The model name for the STT service. + # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash"). + # For "openai" this would be a Whisper model (e.g., "whisper-1"). + model_name: "gemini-2.5-flash" \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 53f3651..e3aec9b 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -1,8 +1,9 @@ from app.config import settings -from .base import LLMProvider,TTSProvider +from .base import LLMProvider, TTSProvider, STTProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider from .tts.gemini import GeminiTTSProvider +from .stt.gemini import GoogleSTTProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -15,7 +16,7 @@ "gemini": GeminiProvider(api_url=GEMINI_URL) } -# --- 3. The Factory Function --- +# --- 3. The Factory Functions --- def get_llm_provider(model_name: str) -> LLMProvider: """Factory function to get the appropriate, pre-configured LLM provider.""" provider = _llm_providers.get(model_name) @@ -26,5 +27,9 @@ def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) - # Add other TTS providers here if needed - raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. 
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index ef45dc0..46489d2 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -80,7 +80,6 @@
 # Wait a few seconds to ensure the server is fully up and running
 sleep 5
-
 echo "--- Running tests in: $TEST_PATH ---"

 # Execute the Python tests using pytest on the specified path
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. -# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. 
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ get_from_yaml(["tts_provider", "api_key"]) or \ self.GEMINI_API_KEY + # --- NEW STT Provider Settings --- + stt_provider_env = os.getenv("STT_PROVIDER") + if stt_provider_env: + stt_provider_env = stt_provider_env.lower() + + self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \ + get_from_yaml(["stt_provider", "provider"]) or \ + config_from_pydantic.stt_provider.provider) + self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \ + get_from_yaml(["stt_provider", "model_name"]) or \ + config_from_pydantic.stt_provider.model_name + + # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY. + # Fallback to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY. + explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"]) + + if explicit_stt_api_key: + self.STT_API_KEY: Optional[str] = explicit_stt_api_key + elif self.STT_PROVIDER == STTProvider.OPENAI: + self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY + else: + # Fallback for Google Gemini or other providers + self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY + # Instantiate the single settings object for the application -settings = Settings() +settings = Settings() \ No newline at end of file diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 49b30f7..de9d8dd 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -39,3 +39,12 @@ voice_name: "Kore" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" + +# The provider for the Speech-to-Text (STT) service. +stt_provider: + # The provider can be "google_gemini" or "openai". + provider: "google_gemini" + # The model name for the STT service. + # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash"). + # For "openai" this would be a Whisper model (e.g., "whisper-1"). + model_name: "gemini-2.5-flash" \ No newline at end of file diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 53f3651..e3aec9b 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -1,8 +1,9 @@ from app.config import settings -from .base import LLMProvider,TTSProvider +from .base import LLMProvider, TTSProvider, STTProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider from .tts.gemini import GeminiTTSProvider +from .stt.gemini import GoogleSTTProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -15,7 +16,7 @@ "gemini": GeminiProvider(api_url=GEMINI_URL) } -# --- 3. The Factory Function --- +# --- 3. The Factory Functions --- def get_llm_provider(model_name: str) -> LLMProvider: """Factory function to get the appropriate, pre-configured LLM provider.""" provider = _llm_providers.get(model_name) @@ -26,5 +27,9 @@ def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) - # Add other TTS providers here if needed - raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. 
diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py
index 4ba708e..e76f8bf 100644
--- a/ai-hub/tests/api/routes/conftest.py
+++ b/ai-hub/tests/api/routes/conftest.py
@@ -8,6 +8,7 @@
 from app.core.services.rag import RAGService
 from app.core.services.document import DocumentService
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService
 from app.api.routes.api import create_api_router

 # Change the scope to "function" so the fixture is re-created for each test
@@ -19,33 +20,35 @@
     """
     test_app = FastAPI()

+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)

+    # Create a mock for the ServiceContainer and attach the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service

+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)

+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session

+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)

     test_client = TestClient(test_app)

-    # Use a yield to ensure the teardown happens after each test
     yield test_client, mock_services
-
-    # You could also add a reset call here for an extra layer of safety,
-    # but with scope="function" it's not strictly necessary.

-# Change the scope to "function" for the async client as well
 @pytest.fixture(scope="function")
 async def async_client():
     """
@@ -54,23 +57,31 @@
     """
     test_app = FastAPI()

+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)

+    # Create a mock for the ServiceContainer and attach the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service

+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)

+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session

+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)

+    # Use ASGITransport for testing async code
     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
-        yield client, mock_services
+        yield client, mock_services
\ No newline at end of file
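With the mocked `stt_service` attached, a route-level test for the new endpoint is straightforward. A sketch, assuming the sync fixture is named `client` and yields `(test_client, mock_services)` as shown above:

```python
# Hypothetical test for POST /stt/transcribe, built on the updated fixtures.
from unittest.mock import AsyncMock

def test_transcribe_returns_transcript(client):
    test_client, mock_services = client
    # The route awaits this call, so it must be an AsyncMock.
    mock_services.stt_service.transcribe = AsyncMock(return_value="hello world")

    response = test_client.post(
        "/stt/transcribe",
        files={"audio_file": ("clip.wav", b"fake-bytes", "audio/wav")},
    )

    assert response.status_code == 200
    assert response.json() == {"transcript": "hello world"}
```

A companion test posting a `text/plain` upload should assert the 415 path without the mock ever being awaited.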
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. -# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. 
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2.
Create the Final Settings Object --- @@ -96,8 +110,8 @@ config_from_pydantic.database.mode local_db_path = os.getenv("LOCAL_DB_PATH") or \ - get_from_yaml(["database", "local_path"]) or \ - config_from_pydantic.database.local_path + get_from_yaml(["database", "local_path"]) or \ + config_from_pydantic.database.local_path external_db_url = os.getenv("DATABASE_URL") or \ get_from_yaml(["database", "url"]) or \ config_from_pydantic.database.url @@ -111,21 +125,22 @@ # --- API Keys & Models --- self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY") self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY") + self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY") # NEW: Add dedicated OpenAI API key self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ - config_from_pydantic.llm_providers.deepseek_model_name + get_from_yaml(["llm_providers", "deepseek_model_name"]) or \ + config_from_pydantic.llm_providers.deepseek_model_name self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \ - get_from_yaml(["llm_providers", "gemini_model_name"]) or \ - config_from_pydantic.llm_providers.gemini_model_name + get_from_yaml(["llm_providers", "gemini_model_name"]) or \ + config_from_pydantic.llm_providers.gemini_model_name # --- Vector Store Settings --- self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \ - get_from_yaml(["vector_store", "index_path"]) or \ - config_from_pydantic.vector_store.index_path + get_from_yaml(["vector_store", "index_path"]) or \ + config_from_pydantic.vector_store.index_path dimension_str = os.getenv("EMBEDDING_DIMENSION") or \ - get_from_yaml(["vector_store", "embedding_dimension"]) or \ - config_from_pydantic.vector_store.embedding_dimension + get_from_yaml(["vector_store", "embedding_dimension"]) or \ + config_from_pydantic.vector_store.embedding_dimension self.EMBEDDING_DIMENSION: int = int(dimension_str) # --- Embedding Provider Settings --- @@ -137,13 +152,12 @@ get_from_yaml(["embedding_provider", "provider"]) or \ config_from_pydantic.embedding_provider.provider) self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \ - get_from_yaml(["embedding_provider", "model_name"]) or \ - config_from_pydantic.embedding_provider.model_name + get_from_yaml(["embedding_provider", "model_name"]) or \ + config_from_pydantic.embedding_provider.model_name - # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \ - get_from_yaml(["embedding_provider", "api_key"]) or \ - self.GEMINI_API_KEY + get_from_yaml(["embedding_provider", "api_key"]) or \ + self.GEMINI_API_KEY # --- TTS Provider Settings --- tts_provider_env = os.getenv("TTS_PROVIDER") @@ -151,8 +165,8 @@ tts_provider_env = tts_provider_env.lower() self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \ - get_from_yaml(["tts_provider", "provider"]) or \ - config_from_pydantic.tts_provider.provider) + get_from_yaml(["tts_provider", "provider"]) or \ + config_from_pydantic.tts_provider.provider) self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \ get_from_yaml(["tts_provider", "voice_name"]) or \ config_from_pydantic.tts_provider.voice_name @@ -160,10 +174,33 @@ get_from_yaml(["tts_provider", "model_name"]) or \ config_from_pydantic.tts_provider.model_name - # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY 
self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \ get_from_yaml(["tts_provider", "api_key"]) or \ self.GEMINI_API_KEY + # --- NEW STT Provider Settings --- + stt_provider_env = os.getenv("STT_PROVIDER") + if stt_provider_env: + stt_provider_env = stt_provider_env.lower() + + self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \ + get_from_yaml(["stt_provider", "provider"]) or \ + config_from_pydantic.stt_provider.provider) + self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \ + get_from_yaml(["stt_provider", "model_name"]) or \ + config_from_pydantic.stt_provider.model_name + + # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY. + # Fallback to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY. + explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"]) + + if explicit_stt_api_key: + self.STT_API_KEY: Optional[str] = explicit_stt_api_key + elif self.STT_PROVIDER == STTProvider.OPENAI: + self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY + else: + # Fallback for Google Gemini or other providers + self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY + # Instantiate the single settings object for the application -settings = Settings() +settings = Settings() \ No newline at end of file
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml index 49b30f7..de9d8dd 100644 --- a/ai-hub/app/config.yaml +++ b/ai-hub/app/config.yaml @@ -39,3 +39,12 @@ voice_name: "Kore" # The model name for the TTS service. model_name: "gemini-2.5-flash-preview-tts" + +# The provider for the Speech-to-Text (STT) service. +stt_provider: + # The provider can be "google_gemini" or "openai". + provider: "google_gemini" + # The model name for the STT service. + # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash"). + # For "openai" this would be a Whisper model (e.g., "whisper-1"). + model_name: "gemini-2.5-flash" \ No newline at end of file
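As a quick illustration of the resolution order implemented above (environment variable first, then config.yaml, then the pydantic defaults) and of the API-key fallback chain, here is a small sketch; it assumes the default config.yaml is readable from the working directory, and the key value is a placeholder:

import os

# Environment variables win over config.yaml and the pydantic defaults.
os.environ["STT_PROVIDER"] = "openai"
os.environ["STT_MODEL_NAME"] = "whisper-1"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # not a real key

from app.config import Settings, STTProvider

settings = Settings()
assert settings.STT_PROVIDER == STTProvider.OPENAI
assert settings.STT_MODEL_NAME == "whisper-1"
# No explicit STT_API_KEY is set, so the OpenAI key is used as the fallback.
assert settings.STT_API_KEY == os.environ["OPENAI_API_KEY"]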
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 53f3651..e3aec9b 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -1,8 +1,9 @@ from app.config import settings -from .base import LLMProvider,TTSProvider +from .base import LLMProvider, TTSProvider, STTProvider from .llm.deepseek import DeepSeekProvider from .llm.gemini import GeminiProvider from .tts.gemini import GeminiTTSProvider +from .stt.gemini import GoogleSTTProvider from openai import AsyncOpenAI # --- 1. Initialize API Clients from Central Config --- @@ -15,7 +16,7 @@ "gemini": GeminiProvider(api_url=GEMINI_URL) } -# --- 3. The Factory Function --- +# --- 3. The Factory Functions --- def get_llm_provider(model_name: str) -> LLMProvider: """Factory function to get the appropriate, pre-configured LLM provider.""" provider = _llm_providers.get(model_name) @@ -26,5 +27,9 @@ def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider: if provider_name == "google_genai": return GeminiTTSProvider(api_key=api_key) - # Add other TTS providers here if needed - raise ValueError(f"Unknown TTS provider: {provider_name}") \ No newline at end of file + raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']") + +def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider: + if provider_name == "google_gemini": + return GoogleSTTProvider(api_key=api_key, model_name=model_name) + raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']") \ No newline at end of file
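One loose end worth noting: the STTProvider enum in config.py already lists "openai", but the factory above only wires up "google_gemini", so selecting OpenAI currently raises a ValueError. A hedged sketch of how a Whisper-backed provider could slot in; the STTProvider base class is not shown in this log, so its single async transcribe(audio_bytes) method is an assumption inferred from how the route awaits the service:

# Sketch only: assumes the STTProvider base declares
# `async def transcribe(self, audio_bytes: bytes) -> str`.
from openai import AsyncOpenAI
from .base import STTProvider

class OpenAISTTProvider(STTProvider):
    def __init__(self, api_key: str, model_name: str = "whisper-1"):
        self.client = AsyncOpenAI(api_key=api_key)
        self.model_name = model_name

    async def transcribe(self, audio_bytes: bytes) -> str:
        # The OpenAI SDK accepts a (filename, bytes) tuple for file uploads.
        result = await self.client.audio.transcriptions.create(
            model=self.model_name,
            file=("audio.wav", audio_bytes),
        )
        return result.text

# The factory would then gain a second branch:
# if provider_name == "openai":
#     return OpenAISTTProvider(api_key=api_key, model_name=model_name)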
-# Change the scope to "function" for the async client as well @pytest.fixture(scope="function") async def async_client(): """ @@ -54,23 +57,31 @@ """ test_app = FastAPI() + # Create mocks for all individual services mock_rag_service = MagicMock(spec=RAGService) mock_document_service = MagicMock(spec=DocumentService) mock_tts_service = MagicMock(spec=TTSService) + mock_stt_service = MagicMock(spec=STTService) + # Create a mock for the ServiceContainer and attach all the individual service mocks mock_services = MagicMock(spec=ServiceContainer) mock_services.rag_service = mock_rag_service mock_services.document_service = mock_document_service mock_services.tts_service = mock_tts_service + mock_services.stt_service = mock_stt_service + # Mock the database session mock_db_session = MagicMock(spec=Session) + # Dependency override for the database session def override_get_db(): yield mock_db_session + # Create the API router and include it in the test app api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) + # Use ASGITransport for testing async code async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: - yield client, mock_services + yield client, mock_services \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index d2e0472..a371623 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -9,7 +9,8 @@ from app.api.dependencies import get_db, get_current_user, ServiceContainer from app.core.services.document import DocumentService from app.core.services.rag import RAGService -from app.core.services.tts import TTSService # Added this import +from app.core.services.tts import TTSService +from app.core.services.stt import STTService # Added this import from app.core.vector_store.faiss_store import FaissVectorStore from app.core.retrievers.base_retriever import Retriever @@ -99,12 +100,16 @@ mock_vector_store.embedder = MagicMock() mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] mock_tts_service = MagicMock(spec=TTSService) + + # NEW: Create a mock for STTService + mock_stt_service = MagicMock(spec=STTService) - # Act: Instantiate the ServiceContainer + # Act: Instantiate the ServiceContainer, now with all required arguments container = ServiceContainer( vector_store=mock_vector_store, retrievers=mock_retrievers, - tts_service=mock_tts_service # Passing the mock TTS service + tts_service=mock_tts_service, + stt_service=mock_stt_service # Pass the new mock here ) # Assert: Check if the services were created and configured correctly @@ -114,6 +119,9 @@ assert isinstance(container.rag_service, RAGService) assert container.rag_service.retrievers == mock_retrievers - # Assert for the tts_service as well + # Assert for the tts_service and stt_service as well assert isinstance(container.tts_service, TTSService) assert container.tts_service == mock_tts_service + assert isinstance(container.stt_service, STTService) + assert container.stt_service == mock_stt_service + diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1aecde..a1e497f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -7,6 +7,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.services.tts import TTSService +from app.core.services.stt import 
-# Change the scope to "function" for the async client as well @pytest.fixture(scope="function") async def async_client(): """ @@ -54,23 +57,31 @@ """ test_app = FastAPI() + # Create mocks for all individual services mock_rag_service = MagicMock(spec=RAGService) mock_document_service = MagicMock(spec=DocumentService) mock_tts_service = MagicMock(spec=TTSService) + mock_stt_service = MagicMock(spec=STTService) + # Create a mock for the ServiceContainer and attach all the individual service mocks mock_services = MagicMock(spec=ServiceContainer) mock_services.rag_service = mock_rag_service mock_services.document_service = mock_document_service mock_services.tts_service = mock_tts_service + mock_services.stt_service = mock_stt_service + # Mock the database session mock_db_session = MagicMock(spec=Session) + # Dependency override for the database session def override_get_db(): yield mock_db_session + # Create the API router and include it in the test app api_router = create_api_router(services=mock_services) test_app.dependency_overrides[get_db] = override_get_db test_app.include_router(api_router) + # Use ASGITransport for testing async code async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: - yield client, mock_services + yield client, mock_services \ No newline at end of file diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index d2e0472..a371623 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -9,7 +9,8 @@ from app.api.dependencies import get_db, get_current_user, ServiceContainer from app.core.services.document import DocumentService from app.core.services.rag import RAGService -from app.core.services.tts import TTSService # Added this import +from app.core.services.tts import TTSService +from app.core.services.stt import STTService # Added this import from app.core.vector_store.faiss_store import FaissVectorStore from app.core.retrievers.base_retriever import Retriever @@ -99,12 +100,16 @@ mock_vector_store.embedder = MagicMock() mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)] mock_tts_service = MagicMock(spec=TTSService) + + # NEW: Create a mock for STTService + mock_stt_service = MagicMock(spec=STTService) - # Act: Instantiate the ServiceContainer + # Act: Instantiate the ServiceContainer, now with all required arguments container = ServiceContainer( vector_store=mock_vector_store, retrievers=mock_retrievers, - tts_service=mock_tts_service # Passing the mock TTS service + tts_service=mock_tts_service, + stt_service=mock_stt_service # Pass the new mock here ) # Assert: Check if the services were created and configured correctly @@ -114,6 +119,9 @@ assert isinstance(container.rag_service, RAGService) assert container.rag_service.retrievers == mock_retrievers - # Assert for the tts_service as well + # Assert for the tts_service and stt_service as well assert isinstance(container.tts_service, TTSService) assert container.tts_service == mock_tts_service + assert isinstance(container.stt_service, STTService) + assert container.stt_service == mock_stt_service + diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py deleted file mode 100644 index d19bb7a..0000000 --- a/ai-hub/tests/api/test_routes.py +++ /dev/null @@ -1,277 +0,0 @@ -# import pytest -# from unittest.mock import MagicMock, AsyncMock -# from fastapi import FastAPI, Response -# from fastapi.testclient import TestClient -# from sqlalchemy.orm import 
Session
-# from datetime import datetime
-# from httpx import AsyncClient, ASGITransport
-# import asyncio
-
-# # Import the dependencies and router factory
-# from app.api.dependencies import get_db, ServiceContainer
-# from app.core.services.rag import RAGService
-# from app.core.services.document import DocumentService
-# from app.core.services.tts import TTSService
-# from app.api.routes import create_api_router
-# from app.db import models
-
-# @pytest.fixture
-# def client():
-#     """
-#     Pytest fixture to create a TestClient with a fully mocked environment
-#     for synchronous endpoints.
-#     """
-#     test_app = FastAPI()
-
-#     mock_rag_service = MagicMock(spec=RAGService)
-#     mock_document_service = MagicMock(spec=DocumentService)
-#     mock_tts_service = MagicMock(spec=TTSService)
-
-#     mock_services = MagicMock(spec=ServiceContainer)
-#     mock_services.rag_service = mock_rag_service
-#     mock_services.document_service = mock_document_service
-#     mock_services.tts_service = mock_tts_service
-
-#     mock_db_session = MagicMock(spec=Session)
-
-#     def override_get_db():
-#         yield mock_db_session
-
-#     api_router = create_api_router(services=mock_services)
-#     test_app.dependency_overrides[get_db] = override_get_db
-#     test_app.include_router(api_router)
-
-#     test_client = TestClient(test_app)
-
-#     yield test_client, mock_services
-
-# @pytest.fixture
-# async def async_client():
-#     """
-#     Pytest fixture to create an AsyncClient for testing async endpoints.
-#     """
-#     test_app = FastAPI()
-
-#     mock_rag_service = MagicMock(spec=RAGService)
-#     mock_document_service = MagicMock(spec=DocumentService)
-#     mock_tts_service = MagicMock(spec=TTSService)
-
-#     mock_services = MagicMock(spec=ServiceContainer)
-#     mock_services.rag_service = mock_rag_service
-#     mock_services.document_service = mock_document_service
-#     mock_services.tts_service = mock_tts_service
-
-#     mock_db_session = MagicMock(spec=Session)
-
-#     def override_get_db():
-#         yield mock_db_session
-
-#     api_router = create_api_router(services=mock_services)
-#     test_app.dependency_overrides[get_db] = override_get_db
-#     test_app.include_router(api_router)
-
-#     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
-#         yield client, mock_services
-
-# # --- General Endpoint ---
-
-# def test_read_root(client):
-#     """Tests the root endpoint."""
-#     test_client, _ = client
-#     response = test_client.get("/")
-#     assert response.status_code == 200
-#     assert response.json() == {"status": "AI Model Hub is running!"}
-
-# # --- Session and Chat Endpoints ---
-
-# def test_create_session_success(client):
-#     """Tests successfully creating a new chat session."""
-#     test_client, mock_services = client
-#     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
-#     mock_services.rag_service.create_session.return_value = mock_session
-
-#     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
-
-#     assert response.status_code == 200
-#     assert response.json()["id"] == 1
-#     mock_services.rag_service.create_session.assert_called_once()
-
-# def test_chat_in_session_success(client):
-#     """
-#     Tests sending a message in an existing session without specifying a model
-#     or retriever. It should default to 'deepseek' and 'False'.
-#     """
-#     test_client, mock_services = client
-#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))
-
-#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
-
-#     assert response.status_code == 200
-#     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
-
-#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
-#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
-#         session_id=42,
-#         prompt="Hello there",
-#         model="deepseek",
-#         load_faiss_retriever=False
-#     )
-
-# def test_chat_in_session_with_model_switch(client):
-#     """
-#     Tests sending a message in an existing session and explicitly switching the model.
-#     """
-#     test_client, mock_services = client
-#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
-
-#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
-
-#     assert response.status_code == 200
-#     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
-
-#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
-#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
-#         session_id=42,
-#         prompt="Hello there, Gemini!",
-#         model="gemini",
-#         load_faiss_retriever=False
-#     )
-
-# def test_chat_in_session_with_faiss_retriever(client):
-#     """
-#     Tests sending a message and explicitly enabling the FAISS retriever.
-#     """
-#     test_client, mock_services = client
-#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))
-
-#     response = test_client.post(
-#         "/sessions/42/chat",
-#         json={"prompt": "What is RAG?", "load_faiss_retriever": True}
-#     )
-
-#     assert response.status_code == 200
-#     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
-
-#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
-#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
-#         session_id=42,
-#         prompt="What is RAG?",
-#         model="deepseek",
-#         load_faiss_retriever=True
-#     )
-
-# def test_get_session_messages_success(client):
-#     """Tests retrieving the message history for a session."""
-#     test_client, mock_services = client
-#     mock_history = [
-#         models.Message(sender="user", content="Hello", created_at=datetime.now()),
-#         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
-#     ]
-#     mock_services.rag_service.get_message_history.return_value = mock_history
-
-#     response = test_client.get("/sessions/123/messages")
-
-#     assert response.status_code == 200
-#     response_data = response.json()
-#     assert response_data["session_id"] == 123
-#     assert len(response_data["messages"]) == 2
-#     assert response_data["messages"][0]["sender"] == "user"
-#     assert response_data["messages"][1]["content"] == "Hi there!"
-#     mock_services.rag_service.get_message_history.assert_called_once_with(
-#         db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
-#         session_id=123
-#     )
-
-# def test_get_session_messages_not_found(client):
-#     """Tests retrieving messages for a session that does not exist."""
-#     test_client, mock_services = client
-#     mock_services.rag_service.get_message_history.return_value = None
-
-#     response = test_client.get("/sessions/999/messages")
-
-#     assert response.status_code == 404
-#     assert response.json()["detail"] == "Session with ID 999 not found."
-
-# # --- Document Endpoints ---
-
-# def test_add_document_success(client):
-#     """Tests the /documents endpoint for adding a new document."""
-#     test_client, mock_services = client
-#     mock_services.document_service.add_document.return_value = 123
-#     doc_payload = {"title": "Test Doc", "text": "Content here"}
-#     response = test_client.post("/documents", json=doc_payload)
-#     assert response.status_code == 200
-#     assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"
-
-# def test_get_documents_success(client):
-#     """Tests the /documents endpoint for retrieving all documents."""
-#     test_client, mock_services = client
-#     mock_docs = [
-#         models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
-#         models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now())
-#     ]
-#     mock_services.document_service.get_all_documents.return_value = mock_docs
-#     response = test_client.get("/documents")
-#     assert response.status_code == 200
-#     assert len(response.json()["documents"]) == 2
-
-# def test_delete_document_success(client):
-#     """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
-#     test_client, mock_services = client
-#     mock_services.document_service.delete_document.return_value = 42
-#     response = test_client.delete("/documents/42")
-#     assert response.status_code == 200
-#     assert response.json()["document_id"] == 42
-
-# def test_delete_document_not_found(client):
-#     """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
-#     test_client, mock_services = client
-#     mock_services.document_service.delete_document.return_value = None
-#     response = test_client.delete("/documents/999")
-#     assert response.status_code == 404
-
-# @pytest.mark.asyncio
-# async def test_create_speech_response(async_client):
-#     """Test the /speech endpoint returns audio bytes."""
-#     test_client, mock_services = await anext(async_client)
-#     mock_audio_bytes = b"fake wav audio bytes"
-
-#     # The route handler calls `create_speech_non_stream`, not `create_speech_stream`
-#     # It's an async function, so we must use AsyncMock
-#     mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)
-
-#     response = await test_client.post("/speech", json={"text": "Hello, this is a test"})
-
-#     assert response.status_code == 200
-#     assert response.headers["content-type"] == "audio/wav"
-#     assert response.content == mock_audio_bytes
-
-#     mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")
-
-# @pytest.mark.asyncio
-# async def test_create_speech_stream_response(async_client):
-#     """Test the consolidated /speech endpoint with stream=true returns a streaming response."""
-#     test_client, mock_services = await anext(async_client)
-#     mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]
-
-#     # This async generator mock correctly simulates the streaming service
-#     async def mock_async_generator():
-#         for chunk in mock_audio_bytes_chunks:
-#             yield chunk
-
-#     # We mock `create_speech_stream` with a MagicMock returning the async generator
-#     mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())
-
-#     # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter
-#     response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"})
-
-#     assert response.status_code == 200
-#     assert response.headers["content-type"] == "audio/wav"
"audio/wav" - -# # Read the streamed content and verify it matches the mocked chunks -# streamed_content = b"" -# async for chunk in response.aiter_bytes(): -# streamed_content += chunk - -# assert streamed_content == b"".join(mock_audio_bytes_chunks) -# mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1aecde..a1e497f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -7,6 +7,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.services.tts import TTSService +from app.core.services.stt import STTService from app.core.vector_store.faiss_store import FaissVectorStore @@ -27,10 +28,11 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService, stt_service: STTService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers ) self.tts_service = tts_service + self.stt_service = stt_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py deleted file mode 100644 index fd05079..0000000 --- a/ai-hub/app/api/routes.py +++ /dev/null @@ -1,144 +0,0 @@ -# from fastapi import APIRouter, HTTPException, Depends, Query -# from fastapi.responses import Response, StreamingResponse -# from sqlalchemy.orm import Session -# from app.api.dependencies import ServiceContainer, get_db -# from app.api import schemas -# from typing import AsyncGenerator - -# def create_api_router(services: ServiceContainer) -> APIRouter: -# """ -# Creates and returns an APIRouter with all the application's endpoints. 
-# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# 
diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py
index b052e7d..69862b0 100644
--- a/ai-hub/app/api/routes/api.py
+++ b/ai-hub/app/api/routes/api.py
@@ -6,6 +6,7 @@
 from .sessions import create_sessions_router
 from .documents import create_documents_router
 from .tts import create_tts_router
+from .stt import create_stt_router

 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -18,5 +19,6 @@
     router.include_router(create_sessions_router(services))
     router.include_router(create_documents_router(services))
     router.include_router(create_tts_router(services))
+    router.include_router(create_stt_router(services))

     return router
\ No newline at end of file
diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py
new file mode 100644
index 0000000..d99220e
--- /dev/null
+++ b/ai-hub/app/api/routes/stt.py
@@ -0,0 +1,59 @@
+import logging
+from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
+from app.api.dependencies import ServiceContainer
+from app.api import schemas
+from app.core.services.stt import STTService
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+def create_stt_router(services: ServiceContainer) -> APIRouter:
+    """
+    Creates and configures the API router for Speech-to-Text (STT) functionality.
+    """
+    router = APIRouter(prefix="/stt", tags=["STT"])
+
+    @router.post(
+        "/transcribe",
+        summary="Transcribe audio to text",
+        response_description="The transcribed text from the audio file.",
+        response_model=schemas.STTResponse
+    )
+    async def transcribe_audio_to_text(
+        audio_file: UploadFile = File(...)
+    ):
+        """
+        Transcribes an uploaded audio file into text using the configured STT service.
+
+        The audio file is expected to be a common audio format like WAV or MP3,
+        though the specific provider implementation will determine supported formats.
+        """
+        logger.info(f"Received transcription request for file: {audio_file.filename}")
+
+        if not audio_file.content_type.startswith("audio/"):
+            logger.warning(f"Invalid file type uploaded: {audio_file.content_type}")
+            raise HTTPException(
+                status_code=415,
+                detail="Unsupported media type. Please upload an audio file."
+            )
+
+        try:
+            # Read the audio bytes from the uploaded file
+            audio_bytes = await audio_file.read()
+
+            # Use the STT service to get the transcript
+            transcript = await services.stt_service.transcribe(audio_bytes)
+
+            # Return the transcript in a simple JSON response
+            return schemas.STTResponse(transcript=transcript)
+
+        except HTTPException:
+            # Re-raise FastAPI exceptions so they're handled correctly
+            raise
+        except Exception as e:
+            logger.error(f"Failed to transcribe audio file: {e}")
+            raise HTTPException(
+                status_code=500, detail=f"Failed to transcribe audio: {e}"
+            ) from e
+
+    return router
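Once the hub is running, the new endpoint can be smoke-tested with a few lines of httpx. Host, port, and the sample file name are assumptions here; the multipart field name, however, must match the `audio_file` parameter defined above:

import httpx

# Assumes the hub listens on localhost:8000 and sample.wav exists locally.
with open("sample.wav", "rb") as f:
    response = httpx.post(
        "http://localhost:8000/stt/transcribe",
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )

response.raise_for_status()
print(response.json()["transcript"])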
+ """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." + ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index 65c69c9..f2de94e 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -19,6 +19,11 @@
     """An enum for supported Text-to-Speech (TTS) providers."""
     GOOGLE_GENAI = "google_genai"

+class STTProvider(str, Enum):
+    """An enum for supported Speech-to-Text (STT) providers."""
+    GOOGLE_GEMINI = "google_gemini"
+    OPENAI = "openai"  # NEW: Add OpenAI as a supported provider
+
 class ApplicationSettings(BaseModel):
     project_name: str = "Cortex Hub"
     version: str = "1.0.0"
@@ -44,6 +49,14 @@
     model_name: str = "gemini-2.5-flash-preview-tts"
     api_key: Optional[SecretStr] = None

+class STTProviderSettings(BaseModel):
+    provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI)
+    model_name: str = "gemini-2.5-flash"
+    api_key: Optional[SecretStr] = None
+    # NOTE: The OpenAI provider requires a different model name (e.g., 'whisper-1'),
+    # but we will handle this dynamically or through configuration.
+    # The BaseModel is for schema validation, not for provider-specific logic.
+
 class VectorStoreSettings(BaseModel):
     index_path: str = "data/faiss_index.bin"
     embedding_dimension: int = 768
@@ -56,6 +69,7 @@
     vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings)
     embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings)
     tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings)
+    stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings)

 # --- 2. Create the Final Settings Object ---
@@ -96,8 +110,8 @@
             config_from_pydantic.database.mode
         local_db_path = os.getenv("LOCAL_DB_PATH") or \
-            get_from_yaml(["database", "local_path"]) or \
-            config_from_pydantic.database.local_path
+                        get_from_yaml(["database", "local_path"]) or \
+                        config_from_pydantic.database.local_path
         external_db_url = os.getenv("DATABASE_URL") or \
             get_from_yaml(["database", "url"]) or \
             config_from_pydantic.database.url
@@ -111,21 +125,22 @@
         # --- API Keys & Models ---
         self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY")
         self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
+        self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")  # NEW: Add dedicated OpenAI API key
         self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
-            config_from_pydantic.llm_providers.deepseek_model_name
+                                        get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
+                                        config_from_pydantic.llm_providers.deepseek_model_name
         self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "gemini_model_name"]) or \
-            config_from_pydantic.llm_providers.gemini_model_name
+                                      get_from_yaml(["llm_providers", "gemini_model_name"]) or \
+                                      config_from_pydantic.llm_providers.gemini_model_name

         # --- Vector Store Settings ---
         self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
-            get_from_yaml(["vector_store", "index_path"]) or \
-            config_from_pydantic.vector_store.index_path
+                                     get_from_yaml(["vector_store", "index_path"]) or \
+                                     config_from_pydantic.vector_store.index_path
         dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
-            get_from_yaml(["vector_store", "embedding_dimension"]) or \
-            config_from_pydantic.vector_store.embedding_dimension
+                        get_from_yaml(["vector_store", "embedding_dimension"]) or \
+                        config_from_pydantic.vector_store.embedding_dimension
         self.EMBEDDING_DIMENSION: int = int(dimension_str)

         # --- Embedding Provider Settings ---
@@ -137,13 +152,12 @@
             get_from_yaml(["embedding_provider", "provider"]) or \
             config_from_pydantic.embedding_provider.provider)
         self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
-            get_from_yaml(["embedding_provider", "model_name"]) or \
-            config_from_pydantic.embedding_provider.model_name
+                                         get_from_yaml(["embedding_provider", "model_name"]) or \
+                                         config_from_pydantic.embedding_provider.model_name

-        # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \
-            get_from_yaml(["embedding_provider", "api_key"]) or \
-            self.GEMINI_API_KEY
+                                                get_from_yaml(["embedding_provider", "api_key"]) or \
+                                                self.GEMINI_API_KEY

         # --- TTS Provider Settings ---
         tts_provider_env = os.getenv("TTS_PROVIDER")
@@ -151,8 +165,8 @@
             tts_provider_env = tts_provider_env.lower()

         self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \
-            get_from_yaml(["tts_provider", "provider"]) or \
-            config_from_pydantic.tts_provider.provider)
+                                                     get_from_yaml(["tts_provider", "provider"]) or \
+                                                     config_from_pydantic.tts_provider.provider)
         self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
             get_from_yaml(["tts_provider", "voice_name"]) or \
             config_from_pydantic.tts_provider.voice_name
@@ -160,10 +174,33 @@
             get_from_yaml(["tts_provider", "model_name"]) or \
             config_from_pydantic.tts_provider.model_name

-        # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
             get_from_yaml(["tts_provider", "api_key"]) or \
             self.GEMINI_API_KEY

+        # --- NEW STT Provider Settings ---
+        stt_provider_env = os.getenv("STT_PROVIDER")
+        if stt_provider_env:
+            stt_provider_env = stt_provider_env.lower()
+
+        self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \
+            get_from_yaml(["stt_provider", "provider"]) or \
+            config_from_pydantic.stt_provider.provider)
+        self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \
+            get_from_yaml(["stt_provider", "model_name"]) or \
+            config_from_pydantic.stt_provider.model_name
+
+        # Logic for STT_API_KEY: Prioritize a dedicated STT_API_KEY.
+        # Fallback to OPENAI_API_KEY if the provider is OpenAI, otherwise use GEMINI_API_KEY.
+        explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"])
+
+        if explicit_stt_api_key:
+            self.STT_API_KEY: Optional[str] = explicit_stt_api_key
+        elif self.STT_PROVIDER == STTProvider.OPENAI:
+            self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY
+        else:
+            # Fallback for Google Gemini or other providers
+            self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY
+
 # Instantiate the single settings object for the application
-settings = Settings()
+settings = Settings()
\ No newline at end of file
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index 49b30f7..de9d8dd 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -39,3 +39,12 @@
   voice_name: "Kore"
   # The model name for the TTS service.
   model_name: "gemini-2.5-flash-preview-tts"
+
+# The provider for the Speech-to-Text (STT) service.
+stt_provider:
+  # The provider can be "google_gemini" or "openai".
+  provider: "google_gemini"
+  # The model name for the STT service.
+  # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash").
+  # For "openai" this would be a Whisper model (e.g., "whisper-1").
+  model_name: "gemini-2.5-flash"
\ No newline at end of file
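The STT_API_KEY resolution order above is easy to misread, so here is a worked example with hypothetical values. The implemented precedence is: the STT_API_KEY environment variable, then stt_provider.api_key in config.yaml, then OPENAI_API_KEY (only when the provider is "openai"), and finally GEMINI_API_KEY:

# Hypothetical environment, with no STT_API_KEY and no stt_provider.api_key in YAML:
#   STT_PROVIDER=openai  OPENAI_API_KEY=sk-openai-xxx  GEMINI_API_KEY=gm-xxx
# -> settings.STT_API_KEY == "sk-openai-xxx"
#
# Adding a dedicated key:
#   STT_API_KEY=stt-dedicated-xxx
# -> settings.STT_API_KEY == "stt-dedicated-xxx"  (the dedicated key wins)
#
# With STT_PROVIDER=google_gemini and no dedicated key:
# -> settings.STT_API_KEY == "gm-xxx"  (falls back to GEMINI_API_KEY)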
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 53f3651..e3aec9b 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,8 +1,9 @@
 from app.config import settings
-from .base import LLMProvider,TTSProvider
+from .base import LLMProvider, TTSProvider, STTProvider
 from .llm.deepseek import DeepSeekProvider
 from .llm.gemini import GeminiProvider
 from .tts.gemini import GeminiTTSProvider
+from .stt.gemini import GoogleSTTProvider
 from openai import AsyncOpenAI

 # --- 1. Initialize API Clients from Central Config ---
@@ -15,7 +16,7 @@
     "gemini": GeminiProvider(api_url=GEMINI_URL)
 }

-# --- 3. The Factory Function ---
+# --- 3. The Factory Functions ---
 def get_llm_provider(model_name: str) -> LLMProvider:
     """Factory function to get the appropriate, pre-configured LLM provider."""
     provider = _llm_providers.get(model_name)
@@ -26,5 +27,9 @@
 def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
     if provider_name == "google_genai":
         return GeminiTTSProvider(api_key=api_key)
-    # Add other TTS providers here if needed
-    raise ValueError(f"Unknown TTS provider: {provider_name}")
\ No newline at end of file
+    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
+
+def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
+    if provider_name == "google_gemini":
+        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
+    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
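Note the mismatch this hunk leaves behind: config.py and config.yaml already advertise "openai" as an STT provider, but get_stt_provider raises for it. A sketch of the branch that would close the gap, assuming a Whisper-backed OpenAISTTProvider class that does not exist in this changeset:

# Hypothetical extension: OpenAISTTProvider is not part of this diff.
def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    if provider_name == "google_gemini":
        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
    if provider_name == "openai":
        return OpenAISTTProvider(api_key=api_key, model_name=model_name)  # e.g. "whisper-1"
    raise ValueError(
        f"Unsupported STT provider: '{provider_name}'. "
        "Supported providers are: ['google_gemini', 'openai']"
    )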
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt
index 276e3a2..c455aa1 100644
--- a/ai-hub/requirements.txt
+++ b/ai-hub/requirements.txt
@@ -15,4 +15,5 @@
 numpy
 faiss-cpu
 dspy
-aioresponses
\ No newline at end of file
+aioresponses
+python-multipart
\ No newline at end of file
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index ef45dc0..46489d2 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -80,7 +80,6 @@
 # Wait a few seconds to ensure the server is fully up and running
 sleep 5
-
 echo "--- Running tests in: $TEST_PATH ---"

 # Execute the Python tests using pytest on the specified path
diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py
index 4ba708e..e76f8bf 100644
--- a/ai-hub/tests/api/routes/conftest.py
+++ b/ai-hub/tests/api/routes/conftest.py
@@ -8,6 +8,7 @@
 from app.core.services.rag import RAGService
 from app.core.services.document import DocumentService
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService
 from app.api.routes.api import create_api_router

 # Change the scope to "function" so the fixture is re-created for each test
@@ -19,33 +20,35 @@
     """
     test_app = FastAPI()

+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)

+    # Create a mock for the ServiceContainer and attach all the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service

+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)

+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session

+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)

     test_client = TestClient(test_app)

-    # Use a yield to ensure the teardown happens after each test
     yield test_client, mock_services
-
-    # You could also add a reset call here for an extra layer of safety,
-    # but with scope="function" it's not strictly necessary.

-# Change the scope to "function" for the async client as well
 @pytest.fixture(scope="function")
 async def async_client():
     """
@@ -54,23 +57,31 @@
     """
     test_app = FastAPI()

+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)

+    # Create a mock for the ServiceContainer and attach all the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service

+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)

+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session

+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)

+    # Use ASGITransport for testing async code
     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
-        yield client, mock_services
+        yield client, mock_services
\ No newline at end of file
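With the stt_service mock now attached to the fixture's ServiceContainer, a first test for the new route could look like the sketch below. The exact test module is not part of this diff, and it assumes STTService exposes an async transcribe method (as the route handler implies):

from unittest.mock import AsyncMock

def test_transcribe_audio_success(client):
    test_client, mock_services = client
    # The route awaits services.stt_service.transcribe, so an AsyncMock is required.
    mock_services.stt_service.transcribe = AsyncMock(return_value="hello world")

    response = test_client.post(
        "/stt/transcribe",
        files={"audio_file": ("sample.wav", b"fake wav bytes", "audio/wav")},
    )

    assert response.status_code == 200
    assert response.json() == {"transcript": "hello world"}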
diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py
index d2e0472..a371623 100644
--- a/ai-hub/tests/api/test_dependencies.py
+++ b/ai-hub/tests/api/test_dependencies.py
@@ -9,7 +9,8 @@
 from app.api.dependencies import get_db, get_current_user, ServiceContainer
 from app.core.services.document import DocumentService
 from app.core.services.rag import RAGService
-from app.core.services.tts import TTSService # Added this import
+from app.core.services.tts import TTSService
+from app.core.services.stt import STTService  # Added this import
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.retrievers.base_retriever import Retriever
@@ -99,12 +100,16 @@
     mock_vector_store.embedder = MagicMock()
     mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]
     mock_tts_service = MagicMock(spec=TTSService)
+
+    # NEW: Create a mock for STTService
+    mock_stt_service = MagicMock(spec=STTService)

-    # Act: Instantiate the ServiceContainer
+    # Act: Instantiate the ServiceContainer, now with all required arguments
     container = ServiceContainer(
         vector_store=mock_vector_store,
         retrievers=mock_retrievers,
-        tts_service=mock_tts_service # Passing the mock TTS service
+        tts_service=mock_tts_service,
+        stt_service=mock_stt_service  # Pass the new mock here
     )

     # Assert: Check if the services were created and configured correctly
@@ -114,6 +119,9 @@
     assert isinstance(container.rag_service, RAGService)
     assert container.rag_service.retrievers == mock_retrievers

-    # Assert for the tts_service as well
+    # Assert for the tts_service and stt_service as well
     assert isinstance(container.tts_service, TTSService)
     assert container.tts_service == mock_tts_service
+    assert isinstance(container.stt_service, STTService)
+    assert container.stt_service == mock_stt_service
+
-# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) - -# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="Hello there", -# model="deepseek", -# load_faiss_retriever=False -# ) - -# def test_chat_in_session_with_model_switch(client): -# """ -# Tests sending a message in an existing session and explicitly switching the model. -# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) - -# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="Hello there, Gemini!", -# model="gemini", -# load_faiss_retriever=False -# ) - -# def test_chat_in_session_with_faiss_retriever(client): -# """ -# Tests sending a message and explicitly enabling the FAISS retriever. -# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) - -# response = test_client.post( -# "/sessions/42/chat", -# json={"prompt": "What is RAG?", "load_faiss_retriever": True} -# ) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="What is RAG?", -# model="deepseek", -# load_faiss_retriever=True -# ) - -# def test_get_session_messages_success(client): -# """Tests retrieving the message history for a session.""" -# test_client, mock_services = client -# mock_history = [ -# models.Message(sender="user", content="Hello", created_at=datetime.now()), -# models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) -# ] -# mock_services.rag_service.get_message_history.return_value = mock_history - -# response = test_client.get("/sessions/123/messages") - -# assert response.status_code == 200 -# response_data = response.json() -# assert response_data["session_id"] == 123 -# assert len(response_data["messages"]) == 2 -# assert response_data["messages"][0]["sender"] == "user" -# assert response_data["messages"][1]["content"] == "Hi there!" -# mock_services.rag_service.get_message_history.assert_called_once_with( -# db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], -# session_id=123 -# ) - -# def test_get_session_messages_not_found(client): -# """Tests retrieving messages for a session that does not exist.""" -# test_client, mock_services = client -# mock_services.rag_service.get_message_history.return_value = None - -# response = test_client.get("/sessions/999/messages") - -# assert response.status_code == 404 -# assert response.json()["detail"] == "Session with ID 999 not found." 
- -# # --- Document Endpoints --- - -# def test_add_document_success(client): -# """Tests the /documents endpoint for adding a new document.""" -# test_client, mock_services = client -# mock_services.document_service.add_document.return_value = 123 -# doc_payload = {"title": "Test Doc", "text": "Content here"} -# response = test_client.post("/documents", json=doc_payload) -# assert response.status_code == 200 -# assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123" - -# def test_get_documents_success(client): -# """Tests the /documents endpoint for retrieving all documents.""" -# test_client, mock_services = client -# mock_docs = [ -# models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), -# models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) -# ] -# mock_services.document_service.get_all_documents.return_value = mock_docs -# response = test_client.get("/documents") -# assert response.status_code == 200 -# assert len(response.json()["documents"]) == 2 - -# def test_delete_document_success(client): -# """Tests the DELETE /documents/{document_id} endpoint for successful deletion.""" -# test_client, mock_services = client -# mock_services.document_service.delete_document.return_value = 42 -# response = test_client.delete("/documents/42") -# assert response.status_code == 200 -# assert response.json()["document_id"] == 42 - -# def test_delete_document_not_found(client): -# """Tests the DELETE /documents/{document_id} endpoint when the document is not found.""" -# test_client, mock_services = client -# mock_services.document_service.delete_document.return_value = None -# response = test_client.delete("/documents/999") -# assert response.status_code == 404 - -# @pytest.mark.asyncio -# async def test_create_speech_response(async_client): -# """Test the /speech endpoint returns audio bytes.""" -# test_client, mock_services = await anext(async_client) -# mock_audio_bytes = b"fake wav audio bytes" - -# # The route handler calls `create_speech_non_stream`, not `create_speech_stream` -# # It's an async function, so we must use AsyncMock -# mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes) - -# response = await test_client.post("/speech", json={"text": "Hello, this is a test"}) - -# assert response.status_code == 200 -# assert response.headers["content-type"] == "audio/wav" -# assert response.content == mock_audio_bytes - -# mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test") - -# @pytest.mark.asyncio -# async def test_create_speech_stream_response(async_client): -# """Test the consolidated /speech endpoint with stream=true returns a streaming response.""" -# test_client, mock_services = await anext(async_client) -# mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"] - -# # This async generator mock correctly simulates the streaming service -# async def mock_async_generator(): -# for chunk in mock_audio_bytes_chunks: -# yield chunk - -# # We mock `create_speech_stream` with a MagicMock returning the async generator -# mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator()) - -# # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter -# response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"}) - -# assert response.status_code == 200 -# assert response.headers["content-type"] == 
"audio/wav" - -# # Read the streamed content and verify it matches the mocked chunks -# streamed_content = b"" -# async for chunk in response.aiter_bytes(): -# streamed_content += chunk - -# assert streamed_content == b"".join(mock_audio_bytes_chunks) -# mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py index 58d4c85..f3f25ac 100644 --- a/ai-hub/tests/core/providers/test_factory.py +++ b/ai-hub/tests/core/providers/test_factory.py @@ -1,7 +1,11 @@ import pytest -from app.core.providers.factory import get_llm_provider +from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider from app.core.providers.llm.deepseek import DeepSeekProvider from app.core.providers.llm.gemini import GeminiProvider +from app.core.providers.tts.gemini import GeminiTTSProvider +from app.core.providers.stt.gemini import GoogleSTTProvider + +# --- Existing Tests for LLM Provider --- def test_get_llm_provider_returns_deepseek_provider(): """Tests that the factory returns a DeepSeekProvider instance.""" @@ -16,4 +20,34 @@ def test_get_llm_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported provider name.""" with pytest.raises(ValueError, match="Unsupported model provider: 'unknown'"): - get_llm_provider("unknown") \ No newline at end of file + get_llm_provider("unknown") + + +# --- NEW Tests for TTS Provider --- + +def test_get_tts_provider_returns_gemini_tts_provider(): + """Tests that the factory returns a GeminiTTSProvider instance for 'google_genai'.""" + # Use a dummy key for testing + provider = get_tts_provider("google_genai", api_key="dummy_key") + assert isinstance(provider, GeminiTTSProvider) + assert provider.api_key == "dummy_key" + +def test_get_tts_provider_raises_error_for_unsupported_provider(): + """Tests that the factory raises an error for an unsupported TTS provider name.""" + with pytest.raises(ValueError, match="Unsupported TTS provider: 'unknown'"): + get_tts_provider("unknown", api_key="dummy_key") + + +# --- NEW Tests for STT Provider --- + +def test_get_stt_provider_returns_google_stt_provider(): + """Tests that the factory returns a GoogleSTTProvider instance for 'google_gemini'.""" + provider = get_stt_provider("google_gemini", api_key="dummy_key", model_name="dummy-model") + assert isinstance(provider, GoogleSTTProvider) + assert provider.api_key == "dummy_key" + assert provider.model_name == "dummy-model" + +def test_get_stt_provider_raises_error_for_unsupported_provider(): + """Tests that the factory raises an error for an unsupported STT provider name.""" + with pytest.raises(ValueError, match="Unsupported STT provider: 'unknown'"): + get_stt_provider("unknown", api_key="dummy_key", model_name="dummy-model") \ No newline at end of file diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index e1aecde..a1e497f 100644 --- a/ai-hub/app/api/dependencies.py +++ b/ai-hub/app/api/dependencies.py @@ -7,6 +7,7 @@ from app.core.services.document import DocumentService from app.core.services.rag import RAGService from app.core.services.tts import TTSService +from app.core.services.stt import STTService from app.core.vector_store.faiss_store import FaissVectorStore @@ -27,10 +28,11 @@ class ServiceContainer: - def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: 
TTSService): + def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever], tts_service: TTSService, stt_service: STTService): # Initialize all services within the container self.document_service = DocumentService(vector_store=vector_store) self.rag_service = RAGService( retrievers=retrievers ) self.tts_service = tts_service + self.stt_service = stt_service diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py deleted file mode 100644 index fd05079..0000000 --- a/ai-hub/app/api/routes.py +++ /dev/null @@ -1,144 +0,0 @@ -# from fastapi import APIRouter, HTTPException, Depends, Query -# from fastapi.responses import Response, StreamingResponse -# from sqlalchemy.orm import Session -# from app.api.dependencies import ServiceContainer, get_db -# from app.api import schemas -# from typing import AsyncGenerator - -# def create_api_router(services: ServiceContainer) -> APIRouter: -# """ -# Creates and returns an APIRouter with all the application's endpoints. -# """ -# router = APIRouter() - -# @router.get("/", summary="Check Service Status", tags=["General"]) -# def read_root(): -# return {"status": "AI Model Hub is running!"} - -# # --- Session Management Endpoints --- - -# @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) -# def create_session( -# request: schemas.SessionCreate, -# db: Session = Depends(get_db) -# ): -# try: -# new_session = services.rag_service.create_session( -# db=db, -# user_id=request.user_id, -# model=request.model -# ) -# return new_session -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") - -# @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) -# async def chat_in_session( -# session_id: int, -# request: schemas.ChatRequest, -# db: Session = Depends(get_db) -# ): -# try: -# response_text, model_used = await services.rag_service.chat_with_rag( -# db=db, -# session_id=session_id, -# prompt=request.prompt, -# model=request.model, -# load_faiss_retriever=request.load_faiss_retriever -# ) -# return schemas.ChatResponse(answer=response_text, model_used=model_used) -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred during chat: {e}") - -# @router.get("/sessions/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History", tags=["Sessions"]) -# def get_session_messages(session_id: int, db: Session = Depends(get_db)): -# try: -# messages = services.rag_service.get_message_history(db=db, session_id=session_id) -# if messages is None: -# raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") - -# return schemas.MessageHistoryResponse(session_id=session_id, messages=messages) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Document Management Endpoints --- - -# @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) -# def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): -# try: -# doc_data = doc.model_dump() -# document_id = services.document_service.add_document(db=db, doc_data=doc_data) -# return schemas.DocumentResponse( -# message=f"Document '{doc.title}' added successfully with ID {document_id}" -# ) -# except Exception as e: -# raise 
HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) -# def get_documents(db: Session = Depends(get_db)): -# try: -# documents_from_db = services.document_service.get_all_documents(db=db) -# return {"documents": documents_from_db} -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) -# def delete_document(document_id: int, db: Session = Depends(get_db)): -# try: -# deleted_id = services.document_service.delete_document(db=db, document_id=document_id) -# if deleted_id is None: -# raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") - -# return schemas.DocumentDeleteResponse( -# message="Document deleted successfully", -# document_id=deleted_id -# ) -# except HTTPException: -# raise -# except Exception as e: -# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") - -# # --- Consolidated Speech Endpoint --- - -# @router.post( -# "/speech", -# summary="Generate speech from text", -# tags=["TTS"], -# response_description="Audio bytes in WAV format, either as a complete file or a stream.", -# ) -# async def create_speech_response( -# request: schemas.SpeechRequest, -# stream: bool = Query( -# False, -# description="If true, returns a streamed audio response. Otherwise, returns a complete file." -# ) -# ): -# """ -# Generates an audio file or a streaming audio response from the provided text. -# By default, it returns a complete audio file. -# To get a streaming response, set the 'stream' query parameter to 'true'. 
-# """ -# try: -# if stream: -# # Use the streaming service method -# audio_stream_generator: AsyncGenerator[bytes, None] = services.tts_service.create_speech_stream( -# text=request.text -# ) -# return StreamingResponse(audio_stream_generator, media_type="audio/wav") -# else: -# # Use the non-streaming service method -# audio_bytes = await services.tts_service.create_speech_non_stream( -# text=request.text -# ) -# return Response(content=audio_bytes, media_type="audio/wav") - -# except HTTPException: -# raise # Re-raise existing HTTPException -# except Exception as e: -# raise HTTPException( -# status_code=500, detail=f"Failed to generate speech: {e}" -# ) - -# return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index b052e7d..69862b0 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -6,6 +6,7 @@ from .sessions import create_sessions_router from .documents import create_documents_router from .tts import create_tts_router +from .stt import create_stt_router def create_api_router(services: ServiceContainer) -> APIRouter: """ @@ -18,5 +19,6 @@ router.include_router(create_sessions_router(services)) router.include_router(create_documents_router(services)) router.include_router(create_tts_router(services)) + router.include_router(create_stt_router(services)) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/stt.py b/ai-hub/app/api/routes/stt.py new file mode 100644 index 0000000..d99220e --- /dev/null +++ b/ai-hub/app/api/routes/stt.py @@ -0,0 +1,59 @@ +import logging +from fastapi import APIRouter, HTTPException, UploadFile, File, Depends +from app.api.dependencies import ServiceContainer +from app.api import schemas +from app.core.services.stt import STTService + +# Configure logging +logger = logging.getLogger(__name__) + +def create_stt_router(services: ServiceContainer) -> APIRouter: + """ + Creates and configures the API router for Speech-to-Text (STT) functionality. + """ + router = APIRouter(prefix="/stt", tags=["STT"]) + + @router.post( + "/transcribe", + summary="Transcribe audio to text", + response_description="The transcribed text from the audio file.", + response_model=schemas.STTResponse + ) + async def transcribe_audio_to_text( + audio_file: UploadFile = File(...) + ): + """ + Transcribes an uploaded audio file into text using the configured STT service. + + The audio file is expected to be a common audio format like WAV or MP3, + though the specific provider implementation will determine supported formats. + """ + logger.info(f"Received transcription request for file: {audio_file.filename}") + + if not audio_file.content_type.startswith("audio/"): + logger.warning(f"Invalid file type uploaded: {audio_file.content_type}") + raise HTTPException( + status_code=415, + detail="Unsupported media type. Please upload an audio file." 
+ ) + + try: + # Read the audio bytes from the uploaded file + audio_bytes = await audio_file.read() + + # Use the STT service to get the transcript + transcript = await services.stt_service.transcribe(audio_bytes) + + # Return the transcript in a simple JSON response + return schemas.STTResponse(transcript=transcript) + + except HTTPException: + # Re-raise Fast API exceptions so they're handled correctly + raise + except Exception as e: + logger.error(f"Failed to transcribe audio file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to transcribe audio: {e}" + ) from e + + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 31fb342..a1eff32 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here +from pydantic import BaseModel, Field, ConfigDict from typing import List, Literal, Optional from datetime import datetime @@ -76,4 +76,9 @@ messages: List[Message] class SpeechRequest(BaseModel): - text: str \ No newline at end of file + text: str + +# --- STT Schemas --- +class STTResponse(BaseModel): + """Defines the shape of a successful response from the /stt endpoint.""" + transcript: str diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 3afc104..a4ea604 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -1,3 +1,4 @@ +# app/app.py from fastapi import FastAPI from contextlib import asynccontextmanager from typing import List @@ -6,7 +7,7 @@ from app.config import settings from app.core.vector_store.faiss_store import FaissVectorStore from app.core.vector_store.embedder.factory import get_embedder_from_config -from app.core.providers.factory import get_tts_provider +from app.core.providers.factory import get_tts_provider, get_stt_provider from app.core.retrievers.faiss_db_retriever import FaissDBRetriever from app.core.retrievers.base_retriever import Retriever from app.db.session import create_db_and_tables @@ -14,6 +15,7 @@ from app.utils import print_config from app.api.dependencies import ServiceContainer from app.core.services.tts import TTSService +from app.core.services.stt import STTService # NEW: Added the missing import for STTService # Note: The llm_clients import and initialization are removed as they # are not used in RAGService's constructor based on your services.py # from app.core.llm_clients import DeepSeekClient, GeminiClient @@ -83,12 +85,23 @@ # 5. Initialize the TTSService tts_service = TTSService(tts_provider=tts_provider) + + # 6. Get the concrete STT provider from the factory + stt_provider = get_stt_provider( + provider_name=settings.STT_PROVIDER, + api_key=settings.STT_API_KEY, + model_name=settings.STT_MODEL_NAME + ) + # 7. Initialize the STTService + stt_service = STTService(stt_provider=stt_provider) - # 6. Initialize the Service Container with all services + # 8. 
Initialize the Service Container with all services + # This replaces the previous, redundant initialization services = ServiceContainer( - vector_store=vector_store, + vector_store=vector_store, retrievers=retrievers, - tts_service=tts_service # Pass the new TTS service instance + tts_service=tts_service, + stt_service=stt_service # NEW: Pass the new STT service instance ) # Create and include the API router, injecting the service diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index 65c69c9..f2de94e 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -19,6 +19,11 @@ """An enum for supported Text-to-Speech (TTS) providers.""" GOOGLE_GENAI = "google_genai" +class STTProvider(str, Enum): + """An enum for supported Speech-to-Text (STT) providers.""" + GOOGLE_GEMINI = "google_gemini" + OPENAI = "openai" # NEW: Add OpenAI as a supported provider + class ApplicationSettings(BaseModel): project_name: str = "Cortex Hub" version: str = "1.0.0" @@ -44,6 +49,14 @@ model_name: str = "gemini-2.5-flash-preview-tts" api_key: Optional[SecretStr] = None +class STTProviderSettings(BaseModel): + provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI) + model_name: str = "gemini-2.5-flash" + api_key: Optional[SecretStr] = None + # NOTE: OpenAI provider requires a different model name (e.g., 'whisper-1') + # but we will handle this dynamically or through configuration. + # The BaseModel is for schema validation, not for provider-specific logic. + class VectorStoreSettings(BaseModel): index_path: str = "data/faiss_index.bin" embedding_dimension: int = 768 @@ -56,6 +69,7 @@ vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings) + stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings) # --- 2. 
diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py
index 65c69c9..f2de94e 100644
--- a/ai-hub/app/config.py
+++ b/ai-hub/app/config.py
@@ -19,6 +19,11 @@
     """An enum for supported Text-to-Speech (TTS) providers."""
     GOOGLE_GENAI = "google_genai"
 
+class STTProvider(str, Enum):
+    """An enum for supported Speech-to-Text (STT) providers."""
+    GOOGLE_GEMINI = "google_gemini"
+    OPENAI = "openai"
+
 class ApplicationSettings(BaseModel):
     project_name: str = "Cortex Hub"
     version: str = "1.0.0"
@@ -44,6 +49,14 @@
     model_name: str = "gemini-2.5-flash-preview-tts"
     api_key: Optional[SecretStr] = None
 
+class STTProviderSettings(BaseModel):
+    provider: STTProvider = Field(default=STTProvider.GOOGLE_GEMINI)
+    model_name: str = "gemini-2.5-flash"
+    api_key: Optional[SecretStr] = None
+    # NOTE: the OpenAI provider expects a different model name (e.g. 'whisper-1');
+    # that is handled through configuration. This BaseModel only validates the
+    # schema and carries no provider-specific logic.
+
 class VectorStoreSettings(BaseModel):
     index_path: str = "data/faiss_index.bin"
     embedding_dimension: int = 768
@@ -56,6 +69,7 @@
     vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings)
     embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings)
     tts_provider: TTSProviderSettings = Field(default_factory=TTSProviderSettings)
+    stt_provider: STTProviderSettings = Field(default_factory=STTProviderSettings)
 
 # --- 2. Create the Final Settings Object ---
@@ -96,8 +110,8 @@
             config_from_pydantic.database.mode
         local_db_path = os.getenv("LOCAL_DB_PATH") or \
-            get_from_yaml(["database", "local_path"]) or \
-            config_from_pydantic.database.local_path
+                get_from_yaml(["database", "local_path"]) or \
+                config_from_pydantic.database.local_path
         external_db_url = os.getenv("DATABASE_URL") or \
             get_from_yaml(["database", "url"]) or \
             config_from_pydantic.database.url
@@ -111,21 +125,22 @@
         # --- API Keys & Models ---
         self.DEEPSEEK_API_KEY: Optional[str] = os.getenv("DEEPSEEK_API_KEY")
         self.GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
+        self.OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
         self.DEEPSEEK_MODEL_NAME: str = os.getenv("DEEPSEEK_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
-            config_from_pydantic.llm_providers.deepseek_model_name
+                get_from_yaml(["llm_providers", "deepseek_model_name"]) or \
+                config_from_pydantic.llm_providers.deepseek_model_name
         self.GEMINI_MODEL_NAME: str = os.getenv("GEMINI_MODEL_NAME") or \
-            get_from_yaml(["llm_providers", "gemini_model_name"]) or \
-            config_from_pydantic.llm_providers.gemini_model_name
+                get_from_yaml(["llm_providers", "gemini_model_name"]) or \
+                config_from_pydantic.llm_providers.gemini_model_name
 
         # --- Vector Store Settings ---
         self.FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH") or \
-            get_from_yaml(["vector_store", "index_path"]) or \
-            config_from_pydantic.vector_store.index_path
+                get_from_yaml(["vector_store", "index_path"]) or \
+                config_from_pydantic.vector_store.index_path
         dimension_str = os.getenv("EMBEDDING_DIMENSION") or \
-            get_from_yaml(["vector_store", "embedding_dimension"]) or \
-            config_from_pydantic.vector_store.embedding_dimension
+                get_from_yaml(["vector_store", "embedding_dimension"]) or \
+                config_from_pydantic.vector_store.embedding_dimension
         self.EMBEDDING_DIMENSION: int = int(dimension_str)
 
         # --- Embedding Provider Settings ---
@@ -137,13 +152,12 @@
             get_from_yaml(["embedding_provider", "provider"]) or \
             config_from_pydantic.embedding_provider.provider)
         self.EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME") or \
-            get_from_yaml(["embedding_provider", "model_name"]) or \
-            config_from_pydantic.embedding_provider.model_name
+                get_from_yaml(["embedding_provider", "model_name"]) or \
+                config_from_pydantic.embedding_provider.model_name
 
-        # Fixed logic: Prioritize EMBEDDING_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.EMBEDDING_API_KEY: Optional[str] = os.getenv("EMBEDDING_API_KEY") or \
-            get_from_yaml(["embedding_provider", "api_key"]) or \
-            self.GEMINI_API_KEY
+                get_from_yaml(["embedding_provider", "api_key"]) or \
+                self.GEMINI_API_KEY
 
         # --- TTS Provider Settings ---
         tts_provider_env = os.getenv("TTS_PROVIDER")
@@ -151,8 +165,8 @@
             tts_provider_env = tts_provider_env.lower()
 
         self.TTS_PROVIDER: TTSProvider = TTSProvider(tts_provider_env or \
-            get_from_yaml(["tts_provider", "provider"]) or \
-            config_from_pydantic.tts_provider.provider)
+                get_from_yaml(["tts_provider", "provider"]) or \
+                config_from_pydantic.tts_provider.provider)
         self.TTS_VOICE_NAME: str = os.getenv("TTS_VOICE_NAME") or \
             get_from_yaml(["tts_provider", "voice_name"]) or \
             config_from_pydantic.tts_provider.voice_name
@@ -160,10 +174,33 @@
             get_from_yaml(["tts_provider", "model_name"]) or \
             config_from_pydantic.tts_provider.model_name
 
-        # Fixed logic: Prioritize TTS_API_KEY from env, then yaml, then fallback to GEMINI_API_KEY
         self.TTS_API_KEY: Optional[str] = os.getenv("TTS_API_KEY") or \
             get_from_yaml(["tts_provider", "api_key"]) or \
             self.GEMINI_API_KEY
 
+        # --- STT Provider Settings ---
+        stt_provider_env = os.getenv("STT_PROVIDER")
+        if stt_provider_env:
+            stt_provider_env = stt_provider_env.lower()
+
+        self.STT_PROVIDER: STTProvider = STTProvider(stt_provider_env or \
+            get_from_yaml(["stt_provider", "provider"]) or \
+            config_from_pydantic.stt_provider.provider)
+        self.STT_MODEL_NAME: str = os.getenv("STT_MODEL_NAME") or \
+            get_from_yaml(["stt_provider", "model_name"]) or \
+            config_from_pydantic.stt_provider.model_name
+
+        # STT_API_KEY resolution: prefer a dedicated STT_API_KEY; fall back to
+        # OPENAI_API_KEY when the provider is OpenAI, otherwise to GEMINI_API_KEY.
+        explicit_stt_api_key = os.getenv("STT_API_KEY") or get_from_yaml(["stt_provider", "api_key"])
+
+        if explicit_stt_api_key:
+            self.STT_API_KEY: Optional[str] = explicit_stt_api_key
+        elif self.STT_PROVIDER == STTProvider.OPENAI:
+            self.STT_API_KEY: Optional[str] = self.OPENAI_API_KEY
+        else:
+            # Fallback for Google Gemini or other providers
+            self.STT_API_KEY: Optional[str] = self.GEMINI_API_KEY
+
 # Instantiate the single settings object for the application
-settings = Settings()
+settings = Settings()
\ No newline at end of file
diff --git a/ai-hub/app/config.yaml b/ai-hub/app/config.yaml
index 49b30f7..de9d8dd 100644
--- a/ai-hub/app/config.yaml
+++ b/ai-hub/app/config.yaml
@@ -39,3 +39,12 @@
   voice_name: "Kore"
   # The model name for the TTS service.
   model_name: "gemini-2.5-flash-preview-tts"
+
+# The provider for the Speech-to-Text (STT) service.
+stt_provider:
+  # The provider can be "google_gemini" or "openai".
+  provider: "google_gemini"
+  # The model name for the STT service.
+  # For "google_gemini" this would be a Gemini model (e.g., "gemini-2.5-flash").
+  # For "openai" this would be a Whisper model (e.g., "whisper-1").
+  model_name: "gemini-2.5-flash"
\ No newline at end of file
diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py
index 53f3651..e3aec9b 100644
--- a/ai-hub/app/core/providers/factory.py
+++ b/ai-hub/app/core/providers/factory.py
@@ -1,8 +1,9 @@
 from app.config import settings
-from .base import LLMProvider,TTSProvider
+from .base import LLMProvider, TTSProvider, STTProvider
 from .llm.deepseek import DeepSeekProvider
 from .llm.gemini import GeminiProvider
 from .tts.gemini import GeminiTTSProvider
+from .stt.gemini import GoogleSTTProvider
 from openai import AsyncOpenAI
 
 # --- 1. Initialize API Clients from Central Config ---
@@ -15,7 +16,7 @@
     "gemini": GeminiProvider(api_url=GEMINI_URL)
 }
 
-# --- 3. The Factory Function ---
+# --- 3. The Factory Functions ---
 def get_llm_provider(model_name: str) -> LLMProvider:
     """Factory function to get the appropriate, pre-configured LLM provider."""
     provider = _llm_providers.get(model_name)
@@ -26,5 +27,9 @@
 def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
     if provider_name == "google_genai":
         return GeminiTTSProvider(api_key=api_key)
-    # Add other TTS providers here if needed
-    raise ValueError(f"Unknown TTS provider: {provider_name}")
\ No newline at end of file
+    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
+
+def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
+    if provider_name == "google_gemini":
+        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
+    raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
\ No newline at end of file
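GoogleSTTProvider (app/core/providers/stt/gemini.py) is likewise not part of this diff. Judging from the factory call and the factory tests below, which check provider.api_key and provider.model_name, its minimal shape could be as follows; the base-class location and the async transcribe() contract are assumptions.

    # app/core/providers/stt/gemini.py -- hypothetical sketch; only the
    # constructor arguments (api_key, model_name) are confirmed by the
    # factory tests in this patch.
    from ..base import STTProvider  # assumed to define the transcribe() contract

    class GoogleSTTProvider(STTProvider):
        """Speech-to-text backed by a Gemini model."""

        def __init__(self, api_key: str, model_name: str):
            self.api_key = api_key
            self.model_name = model_name

        async def transcribe(self, audio_bytes: bytes) -> str:
            # The real implementation would call the Gemini API here;
            # it is omitted in this sketch.
            raise NotImplementedError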
diff --git a/ai-hub/requirements.txt b/ai-hub/requirements.txt
index 276e3a2..c455aa1 100644
--- a/ai-hub/requirements.txt
+++ b/ai-hub/requirements.txt
@@ -15,4 +15,5 @@
 numpy
 faiss-cpu
 dspy
-aioresponses
\ No newline at end of file
+aioresponses
+python-multipart
\ No newline at end of file
diff --git a/ai-hub/run_integration_tests.sh b/ai-hub/run_integration_tests.sh
index ef45dc0..46489d2 100644
--- a/ai-hub/run_integration_tests.sh
+++ b/ai-hub/run_integration_tests.sh
@@ -80,7 +80,6 @@
 # Wait a few seconds to ensure the server is fully up and running
 sleep 5
 
-
 echo "--- Running tests in: $TEST_PATH ---"
 
 # Execute the Python tests using pytest on the specified path
diff --git a/ai-hub/tests/api/routes/conftest.py b/ai-hub/tests/api/routes/conftest.py
index 4ba708e..e76f8bf 100644
--- a/ai-hub/tests/api/routes/conftest.py
+++ b/ai-hub/tests/api/routes/conftest.py
@@ -8,6 +8,7 @@
 from app.core.services.rag import RAGService
 from app.core.services.document import DocumentService
 from app.core.services.tts import TTSService
+from app.core.services.stt import STTService
 from app.api.routes.api import create_api_router
 
 # Change the scope to "function" so the fixture is re-created for each test
@@ -19,33 +20,35 @@
     """
     test_app = FastAPI()
 
+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)
 
+    # Create a mock for the ServiceContainer and attach all the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service
 
+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)
 
+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session
 
+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)
 
     test_client = TestClient(test_app)
 
-    # Use a yield to ensure the teardown happens after each test
     yield test_client, mock_services
-
-    # You could also add a reset call here for an extra layer of safety,
-    # but with scope="function" it's not strictly necessary.
-# Change the scope to "function" for the async client as well
 @pytest.fixture(scope="function")
 async def async_client():
     """
@@ -54,23 +57,31 @@
     """
     test_app = FastAPI()
 
+    # Create mocks for all individual services
     mock_rag_service = MagicMock(spec=RAGService)
     mock_document_service = MagicMock(spec=DocumentService)
     mock_tts_service = MagicMock(spec=TTSService)
+    mock_stt_service = MagicMock(spec=STTService)
 
+    # Create a mock for the ServiceContainer and attach all the individual service mocks
     mock_services = MagicMock(spec=ServiceContainer)
     mock_services.rag_service = mock_rag_service
     mock_services.document_service = mock_document_service
     mock_services.tts_service = mock_tts_service
+    mock_services.stt_service = mock_stt_service
 
+    # Mock the database session
     mock_db_session = MagicMock(spec=Session)
 
+    # Dependency override for the database session
     def override_get_db():
         yield mock_db_session
 
+    # Create the API router and include it in the test app
     api_router = create_api_router(services=mock_services)
     test_app.dependency_overrides[get_db] = override_get_db
     test_app.include_router(api_router)
 
+    # Use ASGITransport for testing async code
     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
-        yield client, mock_services
+        yield client, mock_services
\ No newline at end of file
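With mock_stt_service now attached to the container, a route test for the new endpoint could look roughly like this sketch. The /stt path and the "audio_file" multipart field are inferred from the handler above, and AsyncMock mirrors how the removed TTS tests stubbed async service methods.

    # Hypothetical test for the new STT route, using the `client` fixture above.
    from unittest.mock import AsyncMock

    def test_transcribe_audio_success(client):
        test_client, mock_services = client

        # transcribe() is awaited in the route handler, so stub it with AsyncMock
        mock_services.stt_service.transcribe = AsyncMock(return_value="hello world")

        response = test_client.post(
            "/stt",
            files={"audio_file": ("test.wav", b"fake wav bytes", "audio/wav")},
        )

        assert response.status_code == 200
        assert response.json() == {"transcript": "hello world"}
        mock_services.stt_service.transcribe.assert_awaited_once()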
diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py
index d2e0472..a371623 100644
--- a/ai-hub/tests/api/test_dependencies.py
+++ b/ai-hub/tests/api/test_dependencies.py
@@ -9,7 +9,8 @@
 from app.api.dependencies import get_db, get_current_user, ServiceContainer
 from app.core.services.document import DocumentService
 from app.core.services.rag import RAGService
-from app.core.services.tts import TTSService # Added this import
+from app.core.services.tts import TTSService
+from app.core.services.stt import STTService
 from app.core.vector_store.faiss_store import FaissVectorStore
 from app.core.retrievers.base_retriever import Retriever
@@ -99,12 +100,16 @@
     mock_vector_store.embedder = MagicMock()
     mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]
     mock_tts_service = MagicMock(spec=TTSService)
+
+    # Create a mock for STTService
+    mock_stt_service = MagicMock(spec=STTService)
 
-    # Act: Instantiate the ServiceContainer
+    # Act: Instantiate the ServiceContainer with all required arguments
     container = ServiceContainer(
         vector_store=mock_vector_store,
         retrievers=mock_retrievers,
-        tts_service=mock_tts_service # Passing the mock TTS service
+        tts_service=mock_tts_service,
+        stt_service=mock_stt_service
     )
 
     # Assert: Check if the services were created and configured correctly
@@ -114,6 +119,9 @@
     assert isinstance(container.rag_service, RAGService)
     assert container.rag_service.retrievers == mock_retrievers
 
-    # Assert for the tts_service as well
+    # Assert for the tts_service and stt_service as well
     assert isinstance(container.tts_service, TTSService)
     assert container.tts_service == mock_tts_service
+    assert isinstance(container.stt_service, STTService)
+    assert container.stt_service == mock_stt_service
+
diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py
deleted file mode 100644
index d19bb7a..0000000
--- a/ai-hub/tests/api/test_routes.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# import pytest
-# from unittest.mock import MagicMock, AsyncMock
-# from fastapi import FastAPI, Response
-# from fastapi.testclient import TestClient
-# from sqlalchemy.orm import Session
-# from datetime import datetime
-# from httpx import AsyncClient, ASGITransport
-# import asyncio
-
-# # Import the dependencies and router factory
-# from app.api.dependencies import get_db, ServiceContainer
-# from app.core.services.rag import RAGService
-# from app.core.services.document import DocumentService
-# from app.core.services.tts import TTSService
-# from app.api.routes import create_api_router
-# from app.db import models
-
-# @pytest.fixture
-# def client():
-#     """
-#     Pytest fixture to create a TestClient with a fully mocked environment
-#     for synchronous endpoints.
-#     """
-#     test_app = FastAPI()
-
-#     mock_rag_service = MagicMock(spec=RAGService)
-#     mock_document_service = MagicMock(spec=DocumentService)
-#     mock_tts_service = MagicMock(spec=TTSService)
-
-#     mock_services = MagicMock(spec=ServiceContainer)
-#     mock_services.rag_service = mock_rag_service
-#     mock_services.document_service = mock_document_service
-#     mock_services.tts_service = mock_tts_service
-
-#     mock_db_session = MagicMock(spec=Session)
-
-#     def override_get_db():
-#         yield mock_db_session
-
-#     api_router = create_api_router(services=mock_services)
-#     test_app.dependency_overrides[get_db] = override_get_db
-#     test_app.include_router(api_router)
-
-#     test_client = TestClient(test_app)
-
-#     yield test_client, mock_services
-
-# @pytest.fixture
-# async def async_client():
-#     """
-#     Pytest fixture to create an AsyncClient for testing async endpoints.
-#     """
-#     test_app = FastAPI()
-
-#     mock_rag_service = MagicMock(spec=RAGService)
-#     mock_document_service = MagicMock(spec=DocumentService)
-#     mock_tts_service = MagicMock(spec=TTSService)
-
-#     mock_services = MagicMock(spec=ServiceContainer)
-#     mock_services.rag_service = mock_rag_service
-#     mock_services.document_service = mock_document_service
-#     mock_services.tts_service = mock_tts_service
-
-#     mock_db_session = MagicMock(spec=Session)
-
-#     def override_get_db():
-#         yield mock_db_session
-
-#     api_router = create_api_router(services=mock_services)
-#     test_app.dependency_overrides[get_db] = override_get_db
-#     test_app.include_router(api_router)
-
-#     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
-#         yield client, mock_services
-
-# # --- General Endpoint ---
-
-# def test_read_root(client):
-#     """Tests the root endpoint."""
-#     test_client, _ = client
-#     response = test_client.get("/")
-#     assert response.status_code == 200
-#     assert response.json() == {"status": "AI Model Hub is running!"}
-
-# # --- Session and Chat Endpoints ---
-
-# def test_create_session_success(client):
-#     """Tests successfully creating a new chat session."""
-#     test_client, mock_services = client
-#     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
-#     mock_services.rag_service.create_session.return_value = mock_session
-
-#     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
-
-#     assert response.status_code == 200
-#     assert response.json()["id"] == 1
-#     mock_services.rag_service.create_session.assert_called_once()
-
-# def test_chat_in_session_success(client):
-#     """
-#     Tests sending a message in an existing session without specifying a model
-#     or retriever. It should default to 'deepseek' and 'False'.
-# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek")) - -# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="Hello there", -# model="deepseek", -# load_faiss_retriever=False -# ) - -# def test_chat_in_session_with_model_switch(client): -# """ -# Tests sending a message in an existing session and explicitly switching the model. -# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini")) - -# response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="Hello there, Gemini!", -# model="gemini", -# load_faiss_retriever=False -# ) - -# def test_chat_in_session_with_faiss_retriever(client): -# """ -# Tests sending a message and explicitly enabling the FAISS retriever. -# """ -# test_client, mock_services = client -# mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek")) - -# response = test_client.post( -# "/sessions/42/chat", -# json={"prompt": "What is RAG?", "load_faiss_retriever": True} -# ) - -# assert response.status_code == 200 -# assert response.json() == {"answer": "Response with context", "model_used": "deepseek"} - -# mock_services.rag_service.chat_with_rag.assert_called_once_with( -# db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'], -# session_id=42, -# prompt="What is RAG?", -# model="deepseek", -# load_faiss_retriever=True -# ) - -# def test_get_session_messages_success(client): -# """Tests retrieving the message history for a session.""" -# test_client, mock_services = client -# mock_history = [ -# models.Message(sender="user", content="Hello", created_at=datetime.now()), -# models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()) -# ] -# mock_services.rag_service.get_message_history.return_value = mock_history - -# response = test_client.get("/sessions/123/messages") - -# assert response.status_code == 200 -# response_data = response.json() -# assert response_data["session_id"] == 123 -# assert len(response_data["messages"]) == 2 -# assert response_data["messages"][0]["sender"] == "user" -# assert response_data["messages"][1]["content"] == "Hi there!" -# mock_services.rag_service.get_message_history.assert_called_once_with( -# db=mock_services.rag_service.get_message_history.call_args.kwargs['db'], -# session_id=123 -# ) - -# def test_get_session_messages_not_found(client): -# """Tests retrieving messages for a session that does not exist.""" -# test_client, mock_services = client -# mock_services.rag_service.get_message_history.return_value = None - -# response = test_client.get("/sessions/999/messages") - -# assert response.status_code == 404 -# assert response.json()["detail"] == "Session with ID 999 not found." 
-
-# # --- Document Endpoints ---
-
-# def test_add_document_success(client):
-#     """Tests the /documents endpoint for adding a new document."""
-#     test_client, mock_services = client
-#     mock_services.document_service.add_document.return_value = 123
-#     doc_payload = {"title": "Test Doc", "text": "Content here"}
-#     response = test_client.post("/documents", json=doc_payload)
-#     assert response.status_code == 200
-#     assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"
-
-# def test_get_documents_success(client):
-#     """Tests the /documents endpoint for retrieving all documents."""
-#     test_client, mock_services = client
-#     mock_docs = [
-#         models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
-#         models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now())
-#     ]
-#     mock_services.document_service.get_all_documents.return_value = mock_docs
-#     response = test_client.get("/documents")
-#     assert response.status_code == 200
-#     assert len(response.json()["documents"]) == 2
-
-# def test_delete_document_success(client):
-#     """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
-#     test_client, mock_services = client
-#     mock_services.document_service.delete_document.return_value = 42
-#     response = test_client.delete("/documents/42")
-#     assert response.status_code == 200
-#     assert response.json()["document_id"] == 42
-
-# def test_delete_document_not_found(client):
-#     """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
-#     test_client, mock_services = client
-#     mock_services.document_service.delete_document.return_value = None
-#     response = test_client.delete("/documents/999")
-#     assert response.status_code == 404
-
-# @pytest.mark.asyncio
-# async def test_create_speech_response(async_client):
-#     """Test the /speech endpoint returns audio bytes."""
-#     test_client, mock_services = await anext(async_client)
-#     mock_audio_bytes = b"fake wav audio bytes"
-
-#     # The route handler calls `create_speech_non_stream`, not `create_speech_stream`
-#     # It's an async function, so we must use AsyncMock
-#     mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)
-
-#     response = await test_client.post("/speech", json={"text": "Hello, this is a test"})
-
-#     assert response.status_code == 200
-#     assert response.headers["content-type"] == "audio/wav"
-#     assert response.content == mock_audio_bytes
-
-#     mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")
-
-# @pytest.mark.asyncio
-# async def test_create_speech_stream_response(async_client):
-#     """Test the consolidated /speech endpoint with stream=true returns a streaming response."""
-#     test_client, mock_services = await anext(async_client)
-#     mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]
-
-#     # This async generator mock correctly simulates the streaming service
-#     async def mock_async_generator():
-#         for chunk in mock_audio_bytes_chunks:
-#             yield chunk
-
-#     # We mock `create_speech_stream` with a MagicMock returning the async generator
-#     mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())
-
-#     # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter
-#     response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"})
-
-#     assert response.status_code == 200
-#     assert response.headers["content-type"] == "audio/wav"
"audio/wav" - -# # Read the streamed content and verify it matches the mocked chunks -# streamed_content = b"" -# async for chunk in response.aiter_bytes(): -# streamed_content += chunk - -# assert streamed_content == b"".join(mock_audio_bytes_chunks) -# mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test") \ No newline at end of file diff --git a/ai-hub/tests/core/providers/test_factory.py b/ai-hub/tests/core/providers/test_factory.py index 58d4c85..f3f25ac 100644 --- a/ai-hub/tests/core/providers/test_factory.py +++ b/ai-hub/tests/core/providers/test_factory.py @@ -1,7 +1,11 @@ import pytest -from app.core.providers.factory import get_llm_provider +from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider from app.core.providers.llm.deepseek import DeepSeekProvider from app.core.providers.llm.gemini import GeminiProvider +from app.core.providers.tts.gemini import GeminiTTSProvider +from app.core.providers.stt.gemini import GoogleSTTProvider + +# --- Existing Tests for LLM Provider --- def test_get_llm_provider_returns_deepseek_provider(): """Tests that the factory returns a DeepSeekProvider instance.""" @@ -16,4 +20,34 @@ def test_get_llm_provider_raises_error_for_unsupported_provider(): """Tests that the factory raises an error for an unsupported provider name.""" with pytest.raises(ValueError, match="Unsupported model provider: 'unknown'"): - get_llm_provider("unknown") \ No newline at end of file + get_llm_provider("unknown") + + +# --- NEW Tests for TTS Provider --- + +def test_get_tts_provider_returns_gemini_tts_provider(): + """Tests that the factory returns a GeminiTTSProvider instance for 'google_genai'.""" + # Use a dummy key for testing + provider = get_tts_provider("google_genai", api_key="dummy_key") + assert isinstance(provider, GeminiTTSProvider) + assert provider.api_key == "dummy_key" + +def test_get_tts_provider_raises_error_for_unsupported_provider(): + """Tests that the factory raises an error for an unsupported TTS provider name.""" + with pytest.raises(ValueError, match="Unsupported TTS provider: 'unknown'"): + get_tts_provider("unknown", api_key="dummy_key") + + +# --- NEW Tests for STT Provider --- + +def test_get_stt_provider_returns_google_stt_provider(): + """Tests that the factory returns a GoogleSTTProvider instance for 'google_gemini'.""" + provider = get_stt_provider("google_gemini", api_key="dummy_key", model_name="dummy-model") + assert isinstance(provider, GoogleSTTProvider) + assert provider.api_key == "dummy_key" + assert provider.model_name == "dummy-model" + +def test_get_stt_provider_raises_error_for_unsupported_provider(): + """Tests that the factory raises an error for an unsupported STT provider name.""" + with pytest.raises(ValueError, match="Unsupported STT provider: 'unknown'"): + get_stt_provider("unknown", api_key="dummy_key", model_name="dummy-model") \ No newline at end of file diff --git a/ai-hub/tests/test_config.py b/ai-hub/tests/test_config.py index f3835f4..2eca11c 100644 --- a/ai-hub/tests/test_config.py +++ b/ai-hub/tests/test_config.py @@ -1,15 +1,14 @@ import pytest import importlib import yaml -from app.config import EmbeddingProvider, TTSProvider, Settings +from app.config import EmbeddingProvider, TTSProvider, STTProvider, Settings @pytest.fixture def tmp_config_file(tmp_path): """ Creates a temporary config.yaml file and returns its path. - Corrected the 'provider' value to be lowercase 'mock' to match the Enum. 
-    Added database and TTS settings for testing.
+    Adds STT provider settings for testing.
     """
     config_content = {
         "application": {
@@ -18,7 +17,6 @@
         },
         "llm_providers": {"deepseek_model_name": "deepseek-from-yaml"},
         "embedding_provider": {
-            # This value must be lowercase to match the Pydantic Enum member
            "provider": "mock",
            "model_name": "embedding-model-from-yaml",
            "api_key": "embedding-api-from-yaml"
@@ -26,13 +24,19 @@
        "database": {
            "mode": "sqlite",
            "local_path": "custom_folder/test_ai_hub.db",
-            "url": "postgresql://user:pass@host/dbname" # Should be ignored for sqlite mode
+            "url": "postgresql://user:pass@host/dbname"
        },
        "tts_provider": {
            "provider": "google_genai",
            "voice_name": "Laomedeia",
-            "model_name": "tts-model-from-yaml", # Added configurable model name
+            "model_name": "tts-model-from-yaml",
            "api_key": "tts-api-from-yaml"
+        },
+        # STT provider settings for the YAML fixture
+        "stt_provider": {
+            "provider": "openai",
+            "model_name": "stt-model-from-yaml",
+            "api_key": "stt-api-from-yaml"
        }
    }
    config_path = tmp_path / "test_config.yaml"
@@ -45,15 +49,20 @@
 def clear_all_env(monkeypatch):
     """
     A fixture to clear all relevant environment variables for test isolation.
+    Adds the new STT environment variables.
     """
     monkeypatch.delenv("CONFIG_PATH", raising=False)
     monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
     monkeypatch.delenv("GEMINI_API_KEY", raising=False)
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
     monkeypatch.delenv("EMBEDDING_API_KEY", raising=False)
     monkeypatch.delenv("TTS_PROVIDER", raising=False)
     monkeypatch.delenv("TTS_VOICE_NAME", raising=False)
-    monkeypatch.delenv("TTS_MODEL_NAME", raising=False) # Added for the new setting
+    monkeypatch.delenv("TTS_MODEL_NAME", raising=False)
     monkeypatch.delenv("TTS_API_KEY", raising=False)
+    monkeypatch.delenv("STT_PROVIDER", raising=False)
+    monkeypatch.delenv("STT_MODEL_NAME", raising=False)
+    monkeypatch.delenv("STT_API_KEY", raising=False)
     monkeypatch.delenv("DB_MODE", raising=False)
     monkeypatch.delenv("LOCAL_DB_PATH", raising=False)
     monkeypatch.delenv("DATABASE_URL", raising=False)
@@ -66,8 +75,10 @@
     """
     monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_deepseek_key")
     monkeypatch.setenv("GEMINI_API_KEY", "mock_gemini_key")
+    monkeypatch.setenv("OPENAI_API_KEY", "mock_openai_key")
     monkeypatch.setenv("EMBEDDING_API_KEY", "mock_embedding_key")
     monkeypatch.setenv("TTS_API_KEY", "mock_tts_key")
+    monkeypatch.setenv("STT_API_KEY", "mock_stt_key")
 
 
 def test_sqlite_db_url_from_yaml(monkeypatch, tmp_config_file, clear_all_env):
@@ -110,7 +121,6 @@
 def test_external_db_url_from_yaml_when_not_sqlite(monkeypatch, tmp_path, clear_all_env):
     """Tests DATABASE_URL uses YAML url when DB_MODE != sqlite and no env DATABASE_URL."""
-    # Write YAML with postgresql mode and url
     config_content = {
         "database": {
             "mode": "postgresql",
@@ -153,17 +163,16 @@
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Laomedeia"
-    assert settings.TTS_MODEL_NAME == "tts-model-from-yaml" # Test for new model name
+    assert settings.TTS_MODEL_NAME == "tts-model-from-yaml"
     assert settings.TTS_API_KEY == "tts-api-from-yaml"
 
 
 def test_tts_settings_from_env(monkeypatch, tmp_config_file, clear_all_env):
     """Tests that TTS environment variables override the YAML file."""
     monkeypatch.setenv("CONFIG_PATH", tmp_config_file)
-    # Explicitly set all TTS env vars for this test
     monkeypatch.setenv("TTS_PROVIDER", "google_genai")
     monkeypatch.setenv("TTS_VOICE_NAME", "Zephyr")
-    monkeypatch.setenv("TTS_MODEL_NAME", "env-tts-model") # Added for the new setting
+    monkeypatch.setenv("TTS_MODEL_NAME", "env-tts-model")
     monkeypatch.setenv("TTS_API_KEY", "env_tts_key")
     monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
     monkeypatch.setenv("GEMINI_API_KEY", "mock_key")
@@ -171,20 +180,71 @@
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Zephyr"
-    assert settings.TTS_MODEL_NAME == "env-tts-model" # Assert new setting is loaded
+    assert settings.TTS_MODEL_NAME == "env-tts-model"
     assert settings.TTS_API_KEY == "env_tts_key"
 
 
 def test_tts_settings_defaults(monkeypatch, clear_all_env):
-    """Tests that TTS settings fall back to Pydantic defaults if no env or YAML are present."""
+    """Tests that TTS settings fall back to Pydantic defaults."""
     monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
-    # We remove the line below that sets GEMINI_API_KEY.
-    # The clear_all_env fixture already ensures no env vars are set initially.
-    # settings = Settings() will be able to fall back to None for the API key.
-
+    # Set GEMINI_API_KEY to test its fallback behavior
+    monkeypatch.setenv("GEMINI_API_KEY", "fallback_gemini_key")
+
     settings = Settings()
     assert settings.TTS_PROVIDER == TTSProvider.GOOGLE_GENAI
     assert settings.TTS_VOICE_NAME == "Kore"
-    assert settings.TTS_MODEL_NAME == "gemini-2.5-flash-preview-tts" # Assert default value
-    assert settings.TTS_API_KEY is None
+    assert settings.TTS_MODEL_NAME == "gemini-2.5-flash-preview-tts"
+    assert settings.TTS_API_KEY == "fallback_gemini_key"
+
+
+# --- Tests for STT Configuration ---
+def test_stt_settings_from_yaml(monkeypatch, tmp_config_file, clear_all_env):
+    """Tests that STT settings are loaded correctly from a YAML file."""
+    monkeypatch.setenv("CONFIG_PATH", tmp_config_file)
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
+    monkeypatch.setenv("GEMINI_API_KEY", "mock_gemini_key")
+
+    settings = Settings()
+    assert settings.STT_PROVIDER == STTProvider.OPENAI
+    assert settings.STT_MODEL_NAME == "stt-model-from-yaml"
+    assert settings.STT_API_KEY == "stt-api-from-yaml"
+
+
+def test_stt_settings_from_env_openai(monkeypatch, tmp_config_file, clear_all_env):
+    """Tests that STT environment variables override the YAML file for the OpenAI provider."""
+    monkeypatch.setenv("CONFIG_PATH", tmp_config_file)
+    monkeypatch.setenv("STT_PROVIDER", "openai")
+    monkeypatch.setenv("STT_MODEL_NAME", "env-stt-model")
+    monkeypatch.setenv("STT_API_KEY", "env_stt_key")
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
+    monkeypatch.setenv("GEMINI_API_KEY", "mock_gemini_key")
+
+    settings = Settings()
+    assert settings.STT_PROVIDER == STTProvider.OPENAI
+    assert settings.STT_MODEL_NAME == "env-stt-model"
+    assert settings.STT_API_KEY == "env_stt_key"
+
+
+def test_stt_api_key_fallback_to_openai_api_key(monkeypatch, clear_all_env):
+    """Tests that STT_API_KEY falls back to OPENAI_API_KEY when the provider is OpenAI and no STT_API_KEY is set."""
+    monkeypatch.setenv("STT_PROVIDER", "openai")
+    monkeypatch.setenv("OPENAI_API_KEY", "fallback_openai_key")
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
+
+    settings = Settings()
+
+    assert settings.STT_PROVIDER == STTProvider.OPENAI
+    assert settings.STT_API_KEY == "fallback_openai_key"
+
+
+def test_stt_api_key_fallback_to_gemini_api_key(monkeypatch, clear_all_env):
+    """Tests that STT_API_KEY falls back to GEMINI_API_KEY when the provider is not OpenAI and no STT_API_KEY is set."""
+    monkeypatch.setenv("STT_PROVIDER", "google_gemini")
+    monkeypatch.setenv("GEMINI_API_KEY", "fallback_gemini_key")
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "mock_key")
+
+    settings = Settings()
+
+    assert settings.STT_PROVIDER == STTProvider.GOOGLE_GEMINI
+    assert settings.STT_API_KEY == "fallback_gemini_key"
\ No newline at end of file