diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index ae9243b..0a2cc57 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any -from sqlalchemy.orm import Session +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError import dspy @@ -12,25 +12,77 @@ class RAGService: """ - Service class for managing the RAG (Retrieval-Augmented Generation) pipeline. - This class acts as a high-level orchestrator. + Service class for managing documents and conversational RAG sessions. """ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """ + Creates a new chat session in the database. + """ + try: + # Create a default title; this could be updated later by the AI + new_session = models.Session( + user_id=user_id, + model_name=model, + title=f"New Chat Session" + ) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + if not prompt or not prompt.strip(): + raise ValueError("Prompt cannot be empty.") + + # 1. Find the session and its history + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + # 2. Save the user's new message to the database + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + # 3. Configure DSPy with the session's model and execute the pipeline + llm_provider = get_llm_provider(session.model_name) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + dspy.configure(lm=dspy_llm) + + rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # (Optional) You could pass `session.messages` to the pipeline for context + answer_text = await rag_pipeline.forward(question=prompt, db=db) + + # 4. Save the assistant's response to the database + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, session.model_name + + # --- Document Management (Unchanged) --- + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - """ - Adds a document to both the database and the vector store. - """ + """Adds a document to the database and vector store.""" + # ... (implementation is unchanged) try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( document_id=document_db.id, faiss_index=faiss_index, @@ -38,61 +90,27 @@ ) db.add(vector_metadata) db.commit() - print(f"Document with ID {document_db.id} successfully added.") return document_db.id except SQLAlchemyError as e: db.rollback() - # **FIXED LINE**: Added the missing '})' - print(f"Database error while adding document: {e}") raise - except Exception as e: - db.rollback() - print(f"An unexpected error occurred: {e}") - raise - - async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str: - """ - Generates a response to a user prompt by orchestrating the RAG pipeline. - """ - print(f"Received Prompt: {prompt}") - if not prompt or not prompt.strip(): - raise ValueError("The prompt cannot be null, empty, or contain only whitespace.") - - llm_provider_instance = get_llm_provider(model) - dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model) - dspy.configure(lm=dspy_llm_provider) - - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - answer = await rag_pipeline.forward(question=prompt, db=db) - - return answer def get_all_documents(self, db: Session) -> List[models.Document]: - """ - Retrieves all documents from the database. - """ - try: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - except SQLAlchemyError as e: - print(f"Database error while retrieving documents: {e}") - raise + """Retrieves all documents from the database.""" + # ... (implementation is unchanged) + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + def delete_document(self, db: Session, document_id: int) -> int: - """ - Deletes a document and its associated vector metadata from the database. - Returns the ID of the deleted document, or None if not found. - """ + """Deletes a document from the database.""" + # ... (implementation is unchanged) try: doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: return None - db.delete(doc_to_delete) db.commit() - return document_id except SQLAlchemyError as e: db.rollback() - print(f"Database error while deleting document: {e}") raise \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index ae9243b..0a2cc57 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any -from sqlalchemy.orm import Session +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError import dspy @@ -12,25 +12,77 @@ class RAGService: """ - Service class for managing the RAG (Retrieval-Augmented Generation) pipeline. - This class acts as a high-level orchestrator. + Service class for managing documents and conversational RAG sessions. """ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """ + Creates a new chat session in the database. + """ + try: + # Create a default title; this could be updated later by the AI + new_session = models.Session( + user_id=user_id, + model_name=model, + title=f"New Chat Session" + ) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + if not prompt or not prompt.strip(): + raise ValueError("Prompt cannot be empty.") + + # 1. Find the session and its history + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + # 2. Save the user's new message to the database + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + # 3. Configure DSPy with the session's model and execute the pipeline + llm_provider = get_llm_provider(session.model_name) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + dspy.configure(lm=dspy_llm) + + rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # (Optional) You could pass `session.messages` to the pipeline for context + answer_text = await rag_pipeline.forward(question=prompt, db=db) + + # 4. Save the assistant's response to the database + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, session.model_name + + # --- Document Management (Unchanged) --- + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - """ - Adds a document to both the database and the vector store. - """ + """Adds a document to the database and vector store.""" + # ... (implementation is unchanged) try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( document_id=document_db.id, faiss_index=faiss_index, @@ -38,61 +90,27 @@ ) db.add(vector_metadata) db.commit() - print(f"Document with ID {document_db.id} successfully added.") return document_db.id except SQLAlchemyError as e: db.rollback() - # **FIXED LINE**: Added the missing '})' - print(f"Database error while adding document: {e}") raise - except Exception as e: - db.rollback() - print(f"An unexpected error occurred: {e}") - raise - - async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str: - """ - Generates a response to a user prompt by orchestrating the RAG pipeline. - """ - print(f"Received Prompt: {prompt}") - if not prompt or not prompt.strip(): - raise ValueError("The prompt cannot be null, empty, or contain only whitespace.") - - llm_provider_instance = get_llm_provider(model) - dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model) - dspy.configure(lm=dspy_llm_provider) - - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - answer = await rag_pipeline.forward(question=prompt, db=db) - - return answer def get_all_documents(self, db: Session) -> List[models.Document]: - """ - Retrieves all documents from the database. - """ - try: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - except SQLAlchemyError as e: - print(f"Database error while retrieving documents: {e}") - raise + """Retrieves all documents from the database.""" + # ... (implementation is unchanged) + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + def delete_document(self, db: Session, document_id: int) -> int: - """ - Deletes a document and its associated vector metadata from the database. - Returns the ID of the deleted document, or None if not found. - """ + """Deletes a document from the database.""" + # ... (implementation is unchanged) try: doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: return None - db.delete(doc_to_delete) db.commit() - return document_id except SQLAlchemyError as e: db.rollback() - print(f"Database error while deleting document: {e}") raise \ No newline at end of file diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 0b84764..dd20190 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -5,9 +5,9 @@ BASE_URL = "http://127.0.0.1:8000" TEST_PROMPT = "Explain the theory of relativity in one sentence." -# This global variable will store the ID of a document created by one test -# so that it can be used by subsequent tests for listing and deletion. +# Global variables to pass state between sequential tests created_document_id = None +created_session_id = None async def test_root_endpoint(): """Tests if the root endpoint is alive.""" @@ -18,33 +18,48 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# --- Chat Endpoint Tests --- +# --- Session and Chat Lifecycle Tests --- -async def test_chat_endpoint_deepseek(): - """Tests the /chat endpoint using the 'deepseek' model.""" - print("\n--- Running test_chat_endpoint_deepseek ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "deepseek"} +async def test_create_session(): + """ + Tests creating a new chat session and saves the ID for the next test. + """ + global created_session_id + print("\n--- Running test_create_session ---") + url = f"{BASE_URL}/sessions" + payload = {"user_id": "integration_tester", "model": "deepseek"} + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"Failed to create session. Response: {response.text}" + response_data = response.json() + assert "id" in response_data + assert response_data["user_id"] == "integration_tester" + assert response_data["model_name"] == "deepseek" + + created_session_id = response_data["id"] + print(f"✅ Session created successfully with ID: {created_session_id}") + +async def test_chat_in_session(): + """ + Tests sending a message within the session created by the previous test. + """ + print("\n--- Running test_chat_in_session ---") + assert created_session_id is not None, "Session ID was not set by the create_session test." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + payload = {"prompt": TEST_PROMPT} + async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "deepseek" - print("✅ DeepSeek chat test passed.") - -async def test_chat_endpoint_gemini(): - """Tests the /chat endpoint using the 'gemini' model.""" - print("\n--- Running test_chat_endpoint_gemini ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "gemini"} - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "gemini" - print("✅ Gemini chat test passed.") + + assert response.status_code == 200, f"Chat request failed. Response: {response.text}" + response_data = response.json() + assert "answer" in response_data + assert len(response_data["answer"]) > 0 + assert response_data["model_used"] == "deepseek" + print("✅ Chat in session test passed.") # --- Document Management Lifecycle Tests --- @@ -56,6 +71,7 @@ print("\n--- Running test_add_document (for lifecycle) ---") url = f"{BASE_URL}/documents" doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post(url, json=doc_data) @@ -64,7 +80,6 @@ message = response_data.get("message", "") assert "added successfully with ID" in message - # Extract the ID from the success message for the next tests try: created_document_id = int(message.split(" with ID ")[-1]) except (ValueError, IndexError): @@ -74,8 +89,7 @@ async def test_list_documents(): """ - Tests the GET /documents endpoint to ensure it returns a list - that includes the document created in the previous test. + Tests listing documents to ensure the previously created one appears. """ print("\n--- Running test_list_documents ---") assert created_document_id is not None, "Document ID was not set by the add test." @@ -88,15 +102,13 @@ response_data = response.json() assert "documents" in response_data - # Check if the list of documents contains the one we just created ids_in_response = {doc["id"] for doc in response_data["documents"]} - assert created_document_id in ids_in_response, f"Document ID {created_document_id} not found in list." + assert created_document_id in ids_in_response print("✅ Document list test passed.") async def test_delete_document(): """ - Tests the DELETE /documents/{id} endpoint to remove the document - created at the start of the lifecycle tests. + Tests deleting the document created at the start of the lifecycle. """ print("\n--- Running test_delete_document ---") assert created_document_id is not None, "Document ID was not set by the add test." diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index ae9243b..0a2cc57 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any -from sqlalchemy.orm import Session +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError import dspy @@ -12,25 +12,77 @@ class RAGService: """ - Service class for managing the RAG (Retrieval-Augmented Generation) pipeline. - This class acts as a high-level orchestrator. + Service class for managing documents and conversational RAG sessions. """ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """ + Creates a new chat session in the database. + """ + try: + # Create a default title; this could be updated later by the AI + new_session = models.Session( + user_id=user_id, + model_name=model, + title=f"New Chat Session" + ) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + if not prompt or not prompt.strip(): + raise ValueError("Prompt cannot be empty.") + + # 1. Find the session and its history + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + # 2. Save the user's new message to the database + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + # 3. Configure DSPy with the session's model and execute the pipeline + llm_provider = get_llm_provider(session.model_name) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + dspy.configure(lm=dspy_llm) + + rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # (Optional) You could pass `session.messages` to the pipeline for context + answer_text = await rag_pipeline.forward(question=prompt, db=db) + + # 4. Save the assistant's response to the database + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, session.model_name + + # --- Document Management (Unchanged) --- + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - """ - Adds a document to both the database and the vector store. - """ + """Adds a document to the database and vector store.""" + # ... (implementation is unchanged) try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( document_id=document_db.id, faiss_index=faiss_index, @@ -38,61 +90,27 @@ ) db.add(vector_metadata) db.commit() - print(f"Document with ID {document_db.id} successfully added.") return document_db.id except SQLAlchemyError as e: db.rollback() - # **FIXED LINE**: Added the missing '})' - print(f"Database error while adding document: {e}") raise - except Exception as e: - db.rollback() - print(f"An unexpected error occurred: {e}") - raise - - async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str: - """ - Generates a response to a user prompt by orchestrating the RAG pipeline. - """ - print(f"Received Prompt: {prompt}") - if not prompt or not prompt.strip(): - raise ValueError("The prompt cannot be null, empty, or contain only whitespace.") - - llm_provider_instance = get_llm_provider(model) - dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model) - dspy.configure(lm=dspy_llm_provider) - - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - answer = await rag_pipeline.forward(question=prompt, db=db) - - return answer def get_all_documents(self, db: Session) -> List[models.Document]: - """ - Retrieves all documents from the database. - """ - try: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - except SQLAlchemyError as e: - print(f"Database error while retrieving documents: {e}") - raise + """Retrieves all documents from the database.""" + # ... (implementation is unchanged) + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + def delete_document(self, db: Session, document_id: int) -> int: - """ - Deletes a document and its associated vector metadata from the database. - Returns the ID of the deleted document, or None if not found. - """ + """Deletes a document from the database.""" + # ... (implementation is unchanged) try: doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: return None - db.delete(doc_to_delete) db.commit() - return document_id except SQLAlchemyError as e: db.rollback() - print(f"Database error while deleting document: {e}") raise \ No newline at end of file diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 0b84764..dd20190 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -5,9 +5,9 @@ BASE_URL = "http://127.0.0.1:8000" TEST_PROMPT = "Explain the theory of relativity in one sentence." -# This global variable will store the ID of a document created by one test -# so that it can be used by subsequent tests for listing and deletion. +# Global variables to pass state between sequential tests created_document_id = None +created_session_id = None async def test_root_endpoint(): """Tests if the root endpoint is alive.""" @@ -18,33 +18,48 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# --- Chat Endpoint Tests --- +# --- Session and Chat Lifecycle Tests --- -async def test_chat_endpoint_deepseek(): - """Tests the /chat endpoint using the 'deepseek' model.""" - print("\n--- Running test_chat_endpoint_deepseek ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "deepseek"} +async def test_create_session(): + """ + Tests creating a new chat session and saves the ID for the next test. + """ + global created_session_id + print("\n--- Running test_create_session ---") + url = f"{BASE_URL}/sessions" + payload = {"user_id": "integration_tester", "model": "deepseek"} + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"Failed to create session. Response: {response.text}" + response_data = response.json() + assert "id" in response_data + assert response_data["user_id"] == "integration_tester" + assert response_data["model_name"] == "deepseek" + + created_session_id = response_data["id"] + print(f"✅ Session created successfully with ID: {created_session_id}") + +async def test_chat_in_session(): + """ + Tests sending a message within the session created by the previous test. + """ + print("\n--- Running test_chat_in_session ---") + assert created_session_id is not None, "Session ID was not set by the create_session test." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + payload = {"prompt": TEST_PROMPT} + async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "deepseek" - print("✅ DeepSeek chat test passed.") - -async def test_chat_endpoint_gemini(): - """Tests the /chat endpoint using the 'gemini' model.""" - print("\n--- Running test_chat_endpoint_gemini ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "gemini"} - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "gemini" - print("✅ Gemini chat test passed.") + + assert response.status_code == 200, f"Chat request failed. Response: {response.text}" + response_data = response.json() + assert "answer" in response_data + assert len(response_data["answer"]) > 0 + assert response_data["model_used"] == "deepseek" + print("✅ Chat in session test passed.") # --- Document Management Lifecycle Tests --- @@ -56,6 +71,7 @@ print("\n--- Running test_add_document (for lifecycle) ---") url = f"{BASE_URL}/documents" doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post(url, json=doc_data) @@ -64,7 +80,6 @@ message = response_data.get("message", "") assert "added successfully with ID" in message - # Extract the ID from the success message for the next tests try: created_document_id = int(message.split(" with ID ")[-1]) except (ValueError, IndexError): @@ -74,8 +89,7 @@ async def test_list_documents(): """ - Tests the GET /documents endpoint to ensure it returns a list - that includes the document created in the previous test. + Tests listing documents to ensure the previously created one appears. """ print("\n--- Running test_list_documents ---") assert created_document_id is not None, "Document ID was not set by the add test." @@ -88,15 +102,13 @@ response_data = response.json() assert "documents" in response_data - # Check if the list of documents contains the one we just created ids_in_response = {doc["id"] for doc in response_data["documents"]} - assert created_document_id in ids_in_response, f"Document ID {created_document_id} not found in list." + assert created_document_id in ids_in_response print("✅ Document list test passed.") async def test_delete_document(): """ - Tests the DELETE /documents/{id} endpoint to remove the document - created at the start of the lifecycle tests. + Tests deleting the document created at the start of the lifecycle. """ print("\n--- Running test_delete_document ---") assert created_document_id is not None, "Document ID was not set by the add test." diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 9a3fda9..e650342 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -9,13 +9,11 @@ from app.core.services import RAGService from app.api.dependencies import get_db from app.api.routes import create_api_router +from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """ - Pytest fixture to create a TestClient with a fully mocked environment. - This creates an isolated FastAPI app for each test. - """ + """Pytest fixture to create a TestClient with a fully mocked environment.""" test_app = FastAPI() mock_rag_service = MagicMock(spec=RAGService) mock_db_session = MagicMock(spec=Session) @@ -29,44 +27,48 @@ yield TestClient(test_app), mock_rag_service -# --- Root Endpoint --- +# --- General Endpoint --- def test_read_root(client): - """Tests the root endpoint to ensure the API is running.""" + """Tests the root endpoint.""" test_client, _ = client response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"status": "AI Model Hub is running!"} -# --- Chat Endpoints --- +# --- Session and Chat Endpoints --- -def test_chat_handler_success(client): - """Tests a successful chat request.""" +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" test_client, mock_rag_service = client - mock_rag_service.chat_with_rag.return_value = "This is a mocked RAG response." + # Arrange: Mock the service to return a new session object + mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) + mock_rag_service.create_session.return_value = mock_session - response = test_client.post("/chat", json={"prompt": "Hello!", "model": "gemini"}) + # Act + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + # Assert assert response.status_code == 200 - assert response.json()["answer"] == "This is a mocked RAG response." - mock_rag_service.chat_with_rag.assert_called_once() + response_data = response.json() + assert response_data["id"] == 1 + assert response_data["user_id"] == "test_user" + assert response_data["model_name"] == "gemini" + mock_rag_service.create_session.assert_called_once() -def test_chat_handler_validation_error(client): - """Tests the chat endpoint with invalid data (an empty prompt).""" - test_client, _ = client - response = test_client.post("/chat", json={"prompt": "", "model": "deepseek"}) - assert response.status_code == 422 - -def test_chat_handler_internal_error(client): - """Tests the chat endpoint when the RAG service raises an exception.""" +def test_chat_in_session_success(client): + """Tests sending a message in an existing session.""" test_client, mock_rag_service = client - error_message = "LLM provider is down" - mock_rag_service.chat_with_rag.side_effect = Exception(error_message) + # Arrange: Mock the chat service to return a tuple (answer, model_name) + mock_rag_service.chat_with_rag.return_value = ("Mocked response", "deepseek") - response = test_client.post("/chat", json={"prompt": "A valid question", "model": "deepseek"}) + # Act: Send a chat message to a hypothetical session 42 + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - assert response.status_code == 500 - assert error_message in response.json()["detail"] + # Assert + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + mock_rag_service.chat_with_rag.assert_called_once() # --- Document Endpoints --- @@ -76,7 +78,6 @@ mock_rag_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} - # FIX: Use the correct plural URL '/documents' response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 @@ -86,10 +87,10 @@ def test_get_documents_success(client): """Tests successfully retrieving a list of all documents.""" test_client, mock_rag_service = client - # Arrange: Create mock document data that the service will return + # Arrange: Your mock service should return objects that match the schema attributes mock_docs = [ - {"id": 1, "title": "Doc One", "status": "ready", "created_at": datetime.now()}, - {"id": 2, "title": "Doc Two", "status": "processing", "created_at": datetime.now()} + models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), + models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] mock_rag_service.get_all_documents.return_value = mock_docs @@ -106,13 +107,10 @@ def test_delete_document_success(client): """Tests successfully deleting a document.""" test_client, mock_rag_service = client - # Arrange: Mock the service to confirm deletion of document ID 42 mock_rag_service.delete_document.return_value = 42 - # Act response = test_client.delete("/documents/42") - # Assert assert response.status_code == 200 assert response.json()["message"] == "Document deleted successfully" assert response.json()["document_id"] == 42 @@ -121,12 +119,9 @@ def test_delete_document_not_found(client): """Tests attempting to delete a document that does not exist.""" test_client, mock_rag_service = client - # Arrange: Mock the service to indicate the document was not found mock_rag_service.delete_document.return_value = None - # Act response = test_client.delete("/documents/999") - # Assert assert response.status_code == 404 assert response.json()["detail"] == "Document with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index ae9243b..0a2cc57 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any -from sqlalchemy.orm import Session +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError import dspy @@ -12,25 +12,77 @@ class RAGService: """ - Service class for managing the RAG (Retrieval-Augmented Generation) pipeline. - This class acts as a high-level orchestrator. + Service class for managing documents and conversational RAG sessions. """ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """ + Creates a new chat session in the database. + """ + try: + # Create a default title; this could be updated later by the AI + new_session = models.Session( + user_id=user_id, + model_name=model, + title=f"New Chat Session" + ) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + if not prompt or not prompt.strip(): + raise ValueError("Prompt cannot be empty.") + + # 1. Find the session and its history + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + # 2. Save the user's new message to the database + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + # 3. Configure DSPy with the session's model and execute the pipeline + llm_provider = get_llm_provider(session.model_name) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + dspy.configure(lm=dspy_llm) + + rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # (Optional) You could pass `session.messages` to the pipeline for context + answer_text = await rag_pipeline.forward(question=prompt, db=db) + + # 4. Save the assistant's response to the database + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, session.model_name + + # --- Document Management (Unchanged) --- + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - """ - Adds a document to both the database and the vector store. - """ + """Adds a document to the database and vector store.""" + # ... (implementation is unchanged) try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( document_id=document_db.id, faiss_index=faiss_index, @@ -38,61 +90,27 @@ ) db.add(vector_metadata) db.commit() - print(f"Document with ID {document_db.id} successfully added.") return document_db.id except SQLAlchemyError as e: db.rollback() - # **FIXED LINE**: Added the missing '})' - print(f"Database error while adding document: {e}") raise - except Exception as e: - db.rollback() - print(f"An unexpected error occurred: {e}") - raise - - async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str: - """ - Generates a response to a user prompt by orchestrating the RAG pipeline. - """ - print(f"Received Prompt: {prompt}") - if not prompt or not prompt.strip(): - raise ValueError("The prompt cannot be null, empty, or contain only whitespace.") - - llm_provider_instance = get_llm_provider(model) - dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model) - dspy.configure(lm=dspy_llm_provider) - - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - answer = await rag_pipeline.forward(question=prompt, db=db) - - return answer def get_all_documents(self, db: Session) -> List[models.Document]: - """ - Retrieves all documents from the database. - """ - try: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - except SQLAlchemyError as e: - print(f"Database error while retrieving documents: {e}") - raise + """Retrieves all documents from the database.""" + # ... (implementation is unchanged) + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + def delete_document(self, db: Session, document_id: int) -> int: - """ - Deletes a document and its associated vector metadata from the database. - Returns the ID of the deleted document, or None if not found. - """ + """Deletes a document from the database.""" + # ... (implementation is unchanged) try: doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: return None - db.delete(doc_to_delete) db.commit() - return document_id except SQLAlchemyError as e: db.rollback() - print(f"Database error while deleting document: {e}") raise \ No newline at end of file diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 0b84764..dd20190 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -5,9 +5,9 @@ BASE_URL = "http://127.0.0.1:8000" TEST_PROMPT = "Explain the theory of relativity in one sentence." -# This global variable will store the ID of a document created by one test -# so that it can be used by subsequent tests for listing and deletion. +# Global variables to pass state between sequential tests created_document_id = None +created_session_id = None async def test_root_endpoint(): """Tests if the root endpoint is alive.""" @@ -18,33 +18,48 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# --- Chat Endpoint Tests --- +# --- Session and Chat Lifecycle Tests --- -async def test_chat_endpoint_deepseek(): - """Tests the /chat endpoint using the 'deepseek' model.""" - print("\n--- Running test_chat_endpoint_deepseek ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "deepseek"} +async def test_create_session(): + """ + Tests creating a new chat session and saves the ID for the next test. + """ + global created_session_id + print("\n--- Running test_create_session ---") + url = f"{BASE_URL}/sessions" + payload = {"user_id": "integration_tester", "model": "deepseek"} + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"Failed to create session. Response: {response.text}" + response_data = response.json() + assert "id" in response_data + assert response_data["user_id"] == "integration_tester" + assert response_data["model_name"] == "deepseek" + + created_session_id = response_data["id"] + print(f"✅ Session created successfully with ID: {created_session_id}") + +async def test_chat_in_session(): + """ + Tests sending a message within the session created by the previous test. + """ + print("\n--- Running test_chat_in_session ---") + assert created_session_id is not None, "Session ID was not set by the create_session test." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + payload = {"prompt": TEST_PROMPT} + async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "deepseek" - print("✅ DeepSeek chat test passed.") - -async def test_chat_endpoint_gemini(): - """Tests the /chat endpoint using the 'gemini' model.""" - print("\n--- Running test_chat_endpoint_gemini ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "gemini"} - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "gemini" - print("✅ Gemini chat test passed.") + + assert response.status_code == 200, f"Chat request failed. Response: {response.text}" + response_data = response.json() + assert "answer" in response_data + assert len(response_data["answer"]) > 0 + assert response_data["model_used"] == "deepseek" + print("✅ Chat in session test passed.") # --- Document Management Lifecycle Tests --- @@ -56,6 +71,7 @@ print("\n--- Running test_add_document (for lifecycle) ---") url = f"{BASE_URL}/documents" doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post(url, json=doc_data) @@ -64,7 +80,6 @@ message = response_data.get("message", "") assert "added successfully with ID" in message - # Extract the ID from the success message for the next tests try: created_document_id = int(message.split(" with ID ")[-1]) except (ValueError, IndexError): @@ -74,8 +89,7 @@ async def test_list_documents(): """ - Tests the GET /documents endpoint to ensure it returns a list - that includes the document created in the previous test. + Tests listing documents to ensure the previously created one appears. """ print("\n--- Running test_list_documents ---") assert created_document_id is not None, "Document ID was not set by the add test." @@ -88,15 +102,13 @@ response_data = response.json() assert "documents" in response_data - # Check if the list of documents contains the one we just created ids_in_response = {doc["id"] for doc in response_data["documents"]} - assert created_document_id in ids_in_response, f"Document ID {created_document_id} not found in list." + assert created_document_id in ids_in_response print("✅ Document list test passed.") async def test_delete_document(): """ - Tests the DELETE /documents/{id} endpoint to remove the document - created at the start of the lifecycle tests. + Tests deleting the document created at the start of the lifecycle. """ print("\n--- Running test_delete_document ---") assert created_document_id is not None, "Document ID was not set by the add test." diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 9a3fda9..e650342 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -9,13 +9,11 @@ from app.core.services import RAGService from app.api.dependencies import get_db from app.api.routes import create_api_router +from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """ - Pytest fixture to create a TestClient with a fully mocked environment. - This creates an isolated FastAPI app for each test. - """ + """Pytest fixture to create a TestClient with a fully mocked environment.""" test_app = FastAPI() mock_rag_service = MagicMock(spec=RAGService) mock_db_session = MagicMock(spec=Session) @@ -29,44 +27,48 @@ yield TestClient(test_app), mock_rag_service -# --- Root Endpoint --- +# --- General Endpoint --- def test_read_root(client): - """Tests the root endpoint to ensure the API is running.""" + """Tests the root endpoint.""" test_client, _ = client response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"status": "AI Model Hub is running!"} -# --- Chat Endpoints --- +# --- Session and Chat Endpoints --- -def test_chat_handler_success(client): - """Tests a successful chat request.""" +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" test_client, mock_rag_service = client - mock_rag_service.chat_with_rag.return_value = "This is a mocked RAG response." + # Arrange: Mock the service to return a new session object + mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) + mock_rag_service.create_session.return_value = mock_session - response = test_client.post("/chat", json={"prompt": "Hello!", "model": "gemini"}) + # Act + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + # Assert assert response.status_code == 200 - assert response.json()["answer"] == "This is a mocked RAG response." - mock_rag_service.chat_with_rag.assert_called_once() + response_data = response.json() + assert response_data["id"] == 1 + assert response_data["user_id"] == "test_user" + assert response_data["model_name"] == "gemini" + mock_rag_service.create_session.assert_called_once() -def test_chat_handler_validation_error(client): - """Tests the chat endpoint with invalid data (an empty prompt).""" - test_client, _ = client - response = test_client.post("/chat", json={"prompt": "", "model": "deepseek"}) - assert response.status_code == 422 - -def test_chat_handler_internal_error(client): - """Tests the chat endpoint when the RAG service raises an exception.""" +def test_chat_in_session_success(client): + """Tests sending a message in an existing session.""" test_client, mock_rag_service = client - error_message = "LLM provider is down" - mock_rag_service.chat_with_rag.side_effect = Exception(error_message) + # Arrange: Mock the chat service to return a tuple (answer, model_name) + mock_rag_service.chat_with_rag.return_value = ("Mocked response", "deepseek") - response = test_client.post("/chat", json={"prompt": "A valid question", "model": "deepseek"}) + # Act: Send a chat message to a hypothetical session 42 + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - assert response.status_code == 500 - assert error_message in response.json()["detail"] + # Assert + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + mock_rag_service.chat_with_rag.assert_called_once() # --- Document Endpoints --- @@ -76,7 +78,6 @@ mock_rag_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} - # FIX: Use the correct plural URL '/documents' response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 @@ -86,10 +87,10 @@ def test_get_documents_success(client): """Tests successfully retrieving a list of all documents.""" test_client, mock_rag_service = client - # Arrange: Create mock document data that the service will return + # Arrange: Your mock service should return objects that match the schema attributes mock_docs = [ - {"id": 1, "title": "Doc One", "status": "ready", "created_at": datetime.now()}, - {"id": 2, "title": "Doc Two", "status": "processing", "created_at": datetime.now()} + models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), + models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] mock_rag_service.get_all_documents.return_value = mock_docs @@ -106,13 +107,10 @@ def test_delete_document_success(client): """Tests successfully deleting a document.""" test_client, mock_rag_service = client - # Arrange: Mock the service to confirm deletion of document ID 42 mock_rag_service.delete_document.return_value = 42 - # Act response = test_client.delete("/documents/42") - # Assert assert response.status_code == 200 assert response.json()["message"] == "Document deleted successfully" assert response.json()["document_id"] == 42 @@ -121,12 +119,9 @@ def test_delete_document_not_found(client): """Tests attempting to delete a document that does not exist.""" test_client, mock_rag_service = client - # Arrange: Mock the service to indicate the document was not found mock_rag_service.delete_document.return_value = None - # Act response = test_client.delete("/documents/999") - # Assert assert response.status_code == 404 assert response.json()["detail"] == "Document with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py index 44b567a..dbfb3e5 100644 --- a/ai-hub/tests/core/test_services.py +++ b/ai-hub/tests/core/test_services.py @@ -1,60 +1,80 @@ +import pytest import asyncio from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session -# Import the service being tested +# Import the service and its dependencies from app.core.services import RAGService - -# Import dependencies that need to be referenced in mocks +from app.db import models +from app.core.vector_store import FaissVectorStore from app.core.retrievers import Retriever from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider from app.core.llm_providers import LLMProvider +@pytest.fixture +def rag_service(): + """Pytest fixture to create a RAGService instance with mocked dependencies.""" + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_retriever = MagicMock(spec=Retriever) + return RAGService(vector_store=mock_vector_store, retrievers=[mock_retriever]) + +# --- Session Management Tests --- + +def test_create_session(rag_service: RAGService): + """Tests that the create_session method correctly creates a new session.""" + # Arrange + mock_db = MagicMock(spec=Session) + + # Act + session = rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") + + # Assert + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + mock_db.refresh.assert_called_once() + + # Check that the object passed to db.add was a Session instance + added_object = mock_db.add.call_args[0][0] + assert isinstance(added_object, models.Session) + assert added_object.user_id == "test_user" + assert added_object.model_name == "gemini" @patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') # Patched the new class name +@patch('app.core.services.DspyRagPipeline') @patch('dspy.configure') -def test_rag_service_orchestration(mock_configure, mock_dspy_pipeline, mock_get_llm_provider): +def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): """ - Tests that RAGService.chat_with_rag correctly orchestrates its dependencies. - It should: - 1. Get the correct LLM provider. - 2. Configure DSPy with a wrapped provider. - 3. Instantiate and call the pipeline with the correct arguments. + Tests the full orchestration of a chat message within a session. """ # --- Arrange --- - # Mock the dependencies that RAGService uses + # Mock the database to return a session when queried + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=42, model_name="deepseek") + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Mock the LLM provider and the DSPy pipeline mock_llm_provider = MagicMock(spec=LLMProvider) mock_get_llm_provider.return_value = mock_llm_provider - mock_db = MagicMock(spec=Session) - mock_retriever = MagicMock(spec=Retriever) - - # Mock the pipeline instance and its return value mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") mock_dspy_pipeline.return_value = mock_pipeline_instance - # Instantiate the service class we are testing - rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) - prompt = "Test prompt." - model = "deepseek" - # --- Act --- - response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model=model)) + answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt")) # --- Assert --- - # 1. Assert that the correct LLM provider was requested - mock_get_llm_provider.assert_called_once_with(model) + # 1. Assert the session was fetched correctly + mock_db.query.assert_called_once_with(models.Session) - # 2. Assert that dspy was configured with a correctly wrapped provider - mock_configure.assert_called_once() - lm_instance = mock_configure.call_args.kwargs['lm'] - assert isinstance(lm_instance, DSPyLLMProvider) - assert lm_instance.provider == mock_llm_provider + # 2. Assert the user and assistant messages were saved + assert mock_db.add.call_count == 2 + assert mock_db.commit.call_count == 2 - # 3. Assert that the pipeline was instantiated and called correctly - mock_dspy_pipeline.assert_called_once_with(retrievers=[mock_retriever]) - mock_pipeline_instance.forward.assert_called_once_with(question=prompt, db=mock_db) - - # 4. Assert the final response is returned - assert response_text == "Final RAG response" \ No newline at end of file + # 3. Assert the RAG pipeline was orchestrated correctly + mock_get_llm_provider.assert_called_once_with("deepseek") + mock_dspy_pipeline.assert_called_once() + mock_pipeline_instance.forward.assert_called_once() + + # 4. Assert the correct response was returned + assert answer == "Final RAG response" + assert model_name == "deepseek" \ No newline at end of file diff --git a/ai-hub/app/api/routes.py b/ai-hub/app/api/routes.py index 9b64f8f..a632398 100644 --- a/ai-hub/app/api/routes.py +++ b/ai-hub/app/api/routes.py @@ -14,37 +14,58 @@ def read_root(): return {"status": "AI Model Hub is running!"} - # --- Chat Endpoint --- - @router.post("/chat", response_model=schemas.ChatResponse, summary="Get AI-Generated Response", tags=["Chat"]) - async def chat_handler( - request: schemas.ChatRequest, + # --- Session Management Endpoints --- + + @router.post("/sessions", response_model=schemas.Session, summary="Create a New Chat Session", tags=["Sessions"]) + def create_session( + request: schemas.SessionCreate, db: Session = Depends(get_db) ): """ - Handles a chat request using the prompt and model from the request body. + Starts a new conversation session and returns its details. + The returned session_id should be used for subsequent chat messages. """ try: - response_text = await rag_service.chat_with_rag( - db=db, - prompt=request.prompt, + # Note: You'll need to add a `create_session` method to your RAGService. + new_session = rag_service.create_session( + db=db, + user_id=request.user_id, model=request.model ) - return schemas.ChatResponse(answer=response_text, model_used=request.model) + return new_session + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") + + @router.post("/sessions/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session", tags=["Sessions"]) + async def chat_in_session( + session_id: int, + request: schemas.ChatRequest, # We can reuse ChatRequest + db: Session = Depends(get_db) + ): + """ + Sends a message within an existing session and gets a contextual response. + The model used is determined by the session, not the request. + """ + try: + # Note: You'll need to update `chat_with_rag` to accept a session_id + # and use it to retrieve chat history for context. + response_text, model_used = await rag_service.chat_with_rag( + db=db, + session_id=session_id, + prompt=request.prompt + ) + return schemas.ChatResponse(answer=response_text, model_used=model_used) except Exception as e: raise HTTPException( status_code=500, - detail=f"An unexpected error occurred with the {request.model} API: {e}" + detail=f"An error occurred during chat: {e}" ) # --- Document Management Endpoints --- + # (These endpoints remain unchanged) + @router.post("/documents", response_model=schemas.DocumentResponse, summary="Add a New Document", tags=["Documents"]) - def add_document( - doc: schemas.DocumentCreate, - db: Session = Depends(get_db) - ): - """ - Adds a new document to the database and vector store. - """ + def add_document(doc: schemas.DocumentCreate, db: Session = Depends(get_db)): try: doc_data = doc.model_dump() document_id = rag_service.add_document(db=db, doc_data=doc_data) @@ -56,25 +77,15 @@ @router.get("/documents", response_model=schemas.DocumentListResponse, summary="List All Documents", tags=["Documents"]) def get_documents(db: Session = Depends(get_db)): - """ - Retrieves a list of all documents in the knowledge base. - """ try: documents_from_db = rag_service.get_all_documents(db=db) - # **SIMPLIFICATION**: Just return the list of ORM objects. - # FastAPI will use your Pydantic schema's new ORM mode to convert it. return {"documents": documents_from_db} except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/documents/{document_id}", response_model=schemas.DocumentDeleteResponse, summary="Delete a Document", tags=["Documents"]) def delete_document(document_id: int, db: Session = Depends(get_db)): - """ - Deletes a document from the database and vector store by its ID. - """ try: - # Note: You'll need to implement the `delete_document` method in your RAGService. - # This method should return the ID of the deleted doc or raise an error if not found. deleted_id = rag_service.delete_document(db=db, document_id=document_id) if deleted_id is None: raise HTTPException(status_code=404, detail=f"Document with ID {document_id} not found.") @@ -84,7 +95,6 @@ document_id=deleted_id ) except HTTPException: - # Re-raise HTTPException to preserve the 404 status raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 5e30e68..d76cd7f 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -1,12 +1,13 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # <-- Add ConfigDict here from typing import List, Literal, Optional -from datetime import datetime # <-- Add this import +from datetime import datetime # --- Chat Schemas --- - class ChatRequest(BaseModel): """Defines the shape of a request to the /chat endpoint.""" prompt: str = Field(..., min_length=1) + # The 'model' is now part of the Session, but we can keep it here for stateless requests if needed. + # For session-based chat, this field might be ignored. model: Literal["deepseek", "gemini"] = Field("deepseek") class ChatResponse(BaseModel): @@ -15,9 +16,7 @@ model_used: str # --- Document Schemas --- - class DocumentCreate(BaseModel): - """Defines the shape for creating a new document.""" title: str text: str source_url: Optional[str] = None @@ -25,7 +24,6 @@ user_id: str = "default_user" class DocumentResponse(BaseModel): - """Defines the response after creating a document.""" message: str class DocumentInfo(BaseModel): @@ -34,10 +32,26 @@ source_url: Optional[str] = None status: str created_at: datetime + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): documents: List[DocumentInfo] class DocumentDeleteResponse(BaseModel): message: str - document_id: int \ No newline at end of file + document_id: int + +# --- Session Schemas --- +class SessionCreate(BaseModel): + """Defines the shape for starting a new conversation session.""" + user_id: str + model: Literal["deepseek", "gemini"] = "deepseek" + +class Session(BaseModel): + """Defines the shape of a session object returned by the API.""" + id: int + user_id: str + title: str + model_name: str + created_at: datetime + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/ai-hub/app/core/services.py b/ai-hub/app/core/services.py index ae9243b..0a2cc57 100644 --- a/ai-hub/app/core/services.py +++ b/ai-hub/app/core/services.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any -from sqlalchemy.orm import Session +from typing import List, Dict, Any, Tuple +from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError import dspy @@ -12,25 +12,77 @@ class RAGService: """ - Service class for managing the RAG (Retrieval-Augmented Generation) pipeline. - This class acts as a high-level orchestrator. + Service class for managing documents and conversational RAG sessions. """ def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]): self.vector_store = vector_store self.retrievers = retrievers + # --- Session Management --- + + def create_session(self, db: Session, user_id: str, model: str) -> models.Session: + """ + Creates a new chat session in the database. + """ + try: + # Create a default title; this could be updated later by the AI + new_session = models.Session( + user_id=user_id, + model_name=model, + title=f"New Chat Session" + ) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + except SQLAlchemyError as e: + db.rollback() + raise + + async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]: + """ + Handles a message within a session, including saving history and getting a response. + """ + if not prompt or not prompt.strip(): + raise ValueError("Prompt cannot be empty.") + + # 1. Find the session and its history + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: + raise ValueError(f"Session with ID {session_id} not found.") + + # 2. Save the user's new message to the database + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + db.add(user_message) + db.commit() + + # 3. Configure DSPy with the session's model and execute the pipeline + llm_provider = get_llm_provider(session.model_name) + dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=session.model_name) + dspy.configure(lm=dspy_llm) + + rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) + # (Optional) You could pass `session.messages` to the pipeline for context + answer_text = await rag_pipeline.forward(question=prompt, db=db) + + # 4. Save the assistant's response to the database + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + db.add(assistant_message) + db.commit() + + return answer_text, session.model_name + + # --- Document Management (Unchanged) --- + def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int: - """ - Adds a document to both the database and the vector store. - """ + """Adds a document to the database and vector store.""" + # ... (implementation is unchanged) try: document_db = models.Document(**doc_data) db.add(document_db) db.commit() db.refresh(document_db) - faiss_index = self.vector_store.add_document(document_db.text) - vector_metadata = models.VectorMetadata( document_id=document_db.id, faiss_index=faiss_index, @@ -38,61 +90,27 @@ ) db.add(vector_metadata) db.commit() - print(f"Document with ID {document_db.id} successfully added.") return document_db.id except SQLAlchemyError as e: db.rollback() - # **FIXED LINE**: Added the missing '})' - print(f"Database error while adding document: {e}") raise - except Exception as e: - db.rollback() - print(f"An unexpected error occurred: {e}") - raise - - async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str: - """ - Generates a response to a user prompt by orchestrating the RAG pipeline. - """ - print(f"Received Prompt: {prompt}") - if not prompt or not prompt.strip(): - raise ValueError("The prompt cannot be null, empty, or contain only whitespace.") - - llm_provider_instance = get_llm_provider(model) - dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model) - dspy.configure(lm=dspy_llm_provider) - - rag_pipeline = DspyRagPipeline(retrievers=self.retrievers) - answer = await rag_pipeline.forward(question=prompt, db=db) - - return answer def get_all_documents(self, db: Session) -> List[models.Document]: - """ - Retrieves all documents from the database. - """ - try: - return db.query(models.Document).order_by(models.Document.created_at.desc()).all() - except SQLAlchemyError as e: - print(f"Database error while retrieving documents: {e}") - raise + """Retrieves all documents from the database.""" + # ... (implementation is unchanged) + return db.query(models.Document).order_by(models.Document.created_at.desc()).all() + def delete_document(self, db: Session, document_id: int) -> int: - """ - Deletes a document and its associated vector metadata from the database. - Returns the ID of the deleted document, or None if not found. - """ + """Deletes a document from the database.""" + # ... (implementation is unchanged) try: doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first() - if not doc_to_delete: return None - db.delete(doc_to_delete) db.commit() - return document_id except SQLAlchemyError as e: db.rollback() - print(f"Database error while deleting document: {e}") raise \ No newline at end of file diff --git a/ai-hub/integration_tests/test_integration.py b/ai-hub/integration_tests/test_integration.py index 0b84764..dd20190 100644 --- a/ai-hub/integration_tests/test_integration.py +++ b/ai-hub/integration_tests/test_integration.py @@ -5,9 +5,9 @@ BASE_URL = "http://127.0.0.1:8000" TEST_PROMPT = "Explain the theory of relativity in one sentence." -# This global variable will store the ID of a document created by one test -# so that it can be used by subsequent tests for listing and deletion. +# Global variables to pass state between sequential tests created_document_id = None +created_session_id = None async def test_root_endpoint(): """Tests if the root endpoint is alive.""" @@ -18,33 +18,48 @@ assert response.json() == {"status": "AI Model Hub is running!"} print("✅ Root endpoint test passed.") -# --- Chat Endpoint Tests --- +# --- Session and Chat Lifecycle Tests --- -async def test_chat_endpoint_deepseek(): - """Tests the /chat endpoint using the 'deepseek' model.""" - print("\n--- Running test_chat_endpoint_deepseek ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "deepseek"} +async def test_create_session(): + """ + Tests creating a new chat session and saves the ID for the next test. + """ + global created_session_id + print("\n--- Running test_create_session ---") + url = f"{BASE_URL}/sessions" + payload = {"user_id": "integration_tester", "model": "deepseek"} + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload) + + assert response.status_code == 200, f"Failed to create session. Response: {response.text}" + response_data = response.json() + assert "id" in response_data + assert response_data["user_id"] == "integration_tester" + assert response_data["model_name"] == "deepseek" + + created_session_id = response_data["id"] + print(f"✅ Session created successfully with ID: {created_session_id}") + +async def test_chat_in_session(): + """ + Tests sending a message within the session created by the previous test. + """ + print("\n--- Running test_chat_in_session ---") + assert created_session_id is not None, "Session ID was not set by the create_session test." + + url = f"{BASE_URL}/sessions/{created_session_id}/chat" + payload = {"prompt": TEST_PROMPT} + async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "deepseek" - print("✅ DeepSeek chat test passed.") - -async def test_chat_endpoint_gemini(): - """Tests the /chat endpoint using the 'gemini' model.""" - print("\n--- Running test_chat_endpoint_gemini ---") - url = f"{BASE_URL}/chat" - payload = {"prompt": TEST_PROMPT, "model": "gemini"} - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post(url, json=payload) - assert response.status_code == 200 - data = response.json() - assert "answer" in data - assert data["model_used"] == "gemini" - print("✅ Gemini chat test passed.") + + assert response.status_code == 200, f"Chat request failed. Response: {response.text}" + response_data = response.json() + assert "answer" in response_data + assert len(response_data["answer"]) > 0 + assert response_data["model_used"] == "deepseek" + print("✅ Chat in session test passed.") # --- Document Management Lifecycle Tests --- @@ -56,6 +71,7 @@ print("\n--- Running test_add_document (for lifecycle) ---") url = f"{BASE_URL}/documents" doc_data = {"title": "Lifecycle Test Doc", "text": "This doc will be listed and deleted."} + async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post(url, json=doc_data) @@ -64,7 +80,6 @@ message = response_data.get("message", "") assert "added successfully with ID" in message - # Extract the ID from the success message for the next tests try: created_document_id = int(message.split(" with ID ")[-1]) except (ValueError, IndexError): @@ -74,8 +89,7 @@ async def test_list_documents(): """ - Tests the GET /documents endpoint to ensure it returns a list - that includes the document created in the previous test. + Tests listing documents to ensure the previously created one appears. """ print("\n--- Running test_list_documents ---") assert created_document_id is not None, "Document ID was not set by the add test." @@ -88,15 +102,13 @@ response_data = response.json() assert "documents" in response_data - # Check if the list of documents contains the one we just created ids_in_response = {doc["id"] for doc in response_data["documents"]} - assert created_document_id in ids_in_response, f"Document ID {created_document_id} not found in list." + assert created_document_id in ids_in_response print("✅ Document list test passed.") async def test_delete_document(): """ - Tests the DELETE /documents/{id} endpoint to remove the document - created at the start of the lifecycle tests. + Tests deleting the document created at the start of the lifecycle. """ print("\n--- Running test_delete_document ---") assert created_document_id is not None, "Document ID was not set by the add test." diff --git a/ai-hub/tests/api/test_routes.py b/ai-hub/tests/api/test_routes.py index 9a3fda9..e650342 100644 --- a/ai-hub/tests/api/test_routes.py +++ b/ai-hub/tests/api/test_routes.py @@ -9,13 +9,11 @@ from app.core.services import RAGService from app.api.dependencies import get_db from app.api.routes import create_api_router +from app.db import models # Import your SQLAlchemy models @pytest.fixture def client(): - """ - Pytest fixture to create a TestClient with a fully mocked environment. - This creates an isolated FastAPI app for each test. - """ + """Pytest fixture to create a TestClient with a fully mocked environment.""" test_app = FastAPI() mock_rag_service = MagicMock(spec=RAGService) mock_db_session = MagicMock(spec=Session) @@ -29,44 +27,48 @@ yield TestClient(test_app), mock_rag_service -# --- Root Endpoint --- +# --- General Endpoint --- def test_read_root(client): - """Tests the root endpoint to ensure the API is running.""" + """Tests the root endpoint.""" test_client, _ = client response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"status": "AI Model Hub is running!"} -# --- Chat Endpoints --- +# --- Session and Chat Endpoints --- -def test_chat_handler_success(client): - """Tests a successful chat request.""" +def test_create_session_success(client): + """Tests successfully creating a new chat session.""" test_client, mock_rag_service = client - mock_rag_service.chat_with_rag.return_value = "This is a mocked RAG response." + # Arrange: Mock the service to return a new session object + mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now()) + mock_rag_service.create_session.return_value = mock_session - response = test_client.post("/chat", json={"prompt": "Hello!", "model": "gemini"}) + # Act + response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + # Assert assert response.status_code == 200 - assert response.json()["answer"] == "This is a mocked RAG response." - mock_rag_service.chat_with_rag.assert_called_once() + response_data = response.json() + assert response_data["id"] == 1 + assert response_data["user_id"] == "test_user" + assert response_data["model_name"] == "gemini" + mock_rag_service.create_session.assert_called_once() -def test_chat_handler_validation_error(client): - """Tests the chat endpoint with invalid data (an empty prompt).""" - test_client, _ = client - response = test_client.post("/chat", json={"prompt": "", "model": "deepseek"}) - assert response.status_code == 422 - -def test_chat_handler_internal_error(client): - """Tests the chat endpoint when the RAG service raises an exception.""" +def test_chat_in_session_success(client): + """Tests sending a message in an existing session.""" test_client, mock_rag_service = client - error_message = "LLM provider is down" - mock_rag_service.chat_with_rag.side_effect = Exception(error_message) + # Arrange: Mock the chat service to return a tuple (answer, model_name) + mock_rag_service.chat_with_rag.return_value = ("Mocked response", "deepseek") - response = test_client.post("/chat", json={"prompt": "A valid question", "model": "deepseek"}) + # Act: Send a chat message to a hypothetical session 42 + response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"}) - assert response.status_code == 500 - assert error_message in response.json()["detail"] + # Assert + assert response.status_code == 200 + assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"} + mock_rag_service.chat_with_rag.assert_called_once() # --- Document Endpoints --- @@ -76,7 +78,6 @@ mock_rag_service.add_document.return_value = 123 doc_payload = {"title": "Test Doc", "text": "Content here"} - # FIX: Use the correct plural URL '/documents' response = test_client.post("/documents", json=doc_payload) assert response.status_code == 200 @@ -86,10 +87,10 @@ def test_get_documents_success(client): """Tests successfully retrieving a list of all documents.""" test_client, mock_rag_service = client - # Arrange: Create mock document data that the service will return + # Arrange: Your mock service should return objects that match the schema attributes mock_docs = [ - {"id": 1, "title": "Doc One", "status": "ready", "created_at": datetime.now()}, - {"id": 2, "title": "Doc Two", "status": "processing", "created_at": datetime.now()} + models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()), + models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()) ] mock_rag_service.get_all_documents.return_value = mock_docs @@ -106,13 +107,10 @@ def test_delete_document_success(client): """Tests successfully deleting a document.""" test_client, mock_rag_service = client - # Arrange: Mock the service to confirm deletion of document ID 42 mock_rag_service.delete_document.return_value = 42 - # Act response = test_client.delete("/documents/42") - # Assert assert response.status_code == 200 assert response.json()["message"] == "Document deleted successfully" assert response.json()["document_id"] == 42 @@ -121,12 +119,9 @@ def test_delete_document_not_found(client): """Tests attempting to delete a document that does not exist.""" test_client, mock_rag_service = client - # Arrange: Mock the service to indicate the document was not found mock_rag_service.delete_document.return_value = None - # Act response = test_client.delete("/documents/999") - # Assert assert response.status_code == 404 assert response.json()["detail"] == "Document with ID 999 not found." \ No newline at end of file diff --git a/ai-hub/tests/core/test_services.py b/ai-hub/tests/core/test_services.py index 44b567a..dbfb3e5 100644 --- a/ai-hub/tests/core/test_services.py +++ b/ai-hub/tests/core/test_services.py @@ -1,60 +1,80 @@ +import pytest import asyncio from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session -# Import the service being tested +# Import the service and its dependencies from app.core.services import RAGService - -# Import dependencies that need to be referenced in mocks +from app.db import models +from app.core.vector_store import FaissVectorStore from app.core.retrievers import Retriever from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider from app.core.llm_providers import LLMProvider +@pytest.fixture +def rag_service(): + """Pytest fixture to create a RAGService instance with mocked dependencies.""" + mock_vector_store = MagicMock(spec=FaissVectorStore) + mock_retriever = MagicMock(spec=Retriever) + return RAGService(vector_store=mock_vector_store, retrievers=[mock_retriever]) + +# --- Session Management Tests --- + +def test_create_session(rag_service: RAGService): + """Tests that the create_session method correctly creates a new session.""" + # Arrange + mock_db = MagicMock(spec=Session) + + # Act + session = rag_service.create_session(db=mock_db, user_id="test_user", model="gemini") + + # Assert + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + mock_db.refresh.assert_called_once() + + # Check that the object passed to db.add was a Session instance + added_object = mock_db.add.call_args[0][0] + assert isinstance(added_object, models.Session) + assert added_object.user_id == "test_user" + assert added_object.model_name == "gemini" @patch('app.core.services.get_llm_provider') -@patch('app.core.services.DspyRagPipeline') # Patched the new class name +@patch('app.core.services.DspyRagPipeline') @patch('dspy.configure') -def test_rag_service_orchestration(mock_configure, mock_dspy_pipeline, mock_get_llm_provider): +def test_chat_with_rag_success(mock_configure, mock_dspy_pipeline, mock_get_llm_provider, rag_service: RAGService): """ - Tests that RAGService.chat_with_rag correctly orchestrates its dependencies. - It should: - 1. Get the correct LLM provider. - 2. Configure DSPy with a wrapped provider. - 3. Instantiate and call the pipeline with the correct arguments. + Tests the full orchestration of a chat message within a session. """ # --- Arrange --- - # Mock the dependencies that RAGService uses + # Mock the database to return a session when queried + mock_db = MagicMock(spec=Session) + mock_session = models.Session(id=42, model_name="deepseek") + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Mock the LLM provider and the DSPy pipeline mock_llm_provider = MagicMock(spec=LLMProvider) mock_get_llm_provider.return_value = mock_llm_provider - mock_db = MagicMock(spec=Session) - mock_retriever = MagicMock(spec=Retriever) - - # Mock the pipeline instance and its return value mock_pipeline_instance = MagicMock(spec=DspyRagPipeline) mock_pipeline_instance.forward = AsyncMock(return_value="Final RAG response") mock_dspy_pipeline.return_value = mock_pipeline_instance - # Instantiate the service class we are testing - rag_service = RAGService(vector_store=MagicMock(), retrievers=[mock_retriever]) - prompt = "Test prompt." - model = "deepseek" - # --- Act --- - response_text = asyncio.run(rag_service.chat_with_rag(db=mock_db, prompt=prompt, model=model)) + answer, model_name = asyncio.run(rag_service.chat_with_rag(db=mock_db, session_id=42, prompt="Test prompt")) # --- Assert --- - # 1. Assert that the correct LLM provider was requested - mock_get_llm_provider.assert_called_once_with(model) + # 1. Assert the session was fetched correctly + mock_db.query.assert_called_once_with(models.Session) - # 2. Assert that dspy was configured with a correctly wrapped provider - mock_configure.assert_called_once() - lm_instance = mock_configure.call_args.kwargs['lm'] - assert isinstance(lm_instance, DSPyLLMProvider) - assert lm_instance.provider == mock_llm_provider + # 2. Assert the user and assistant messages were saved + assert mock_db.add.call_count == 2 + assert mock_db.commit.call_count == 2 - # 3. Assert that the pipeline was instantiated and called correctly - mock_dspy_pipeline.assert_called_once_with(retrievers=[mock_retriever]) - mock_pipeline_instance.forward.assert_called_once_with(question=prompt, db=mock_db) - - # 4. Assert the final response is returned - assert response_text == "Final RAG response" \ No newline at end of file + # 3. Assert the RAG pipeline was orchestrated correctly + mock_get_llm_provider.assert_called_once_with("deepseek") + mock_dspy_pipeline.assert_called_once() + mock_pipeline_instance.forward.assert_called_once() + + # 4. Assert the correct response was returned + assert answer == "Final RAG response" + assert model_name == "deepseek" \ No newline at end of file diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 5851da9..ae6226b 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -1,13 +1,18 @@ from fastapi.testclient import TestClient from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.orm import Session +from datetime import datetime from app.app import create_app from app.api.dependencies import get_db +from app.db import models # Import your SQLAlchemy models -# --- Dependency Override for Testing --- +# --- Test Setup --- + +# A mock DB session that can be used across tests mock_db = MagicMock(spec=Session) def override_get_db(): + """Dependency override to replace the real database with a mock.""" try: yield mock_db finally: @@ -24,50 +29,59 @@ assert response.json() == {"status": "AI Model Hub is running!"} @patch('app.app.RAGService') -def test_chat_handler_success(mock_rag_service_class): +def test_create_session_success(mock_rag_service_class): """ - Test the /chat endpoint with a successful, mocked RAG service response. + Tests successfully creating a new chat session via the POST /sessions endpoint. """ # Arrange mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.chat_with_rag = AsyncMock(return_value="This is a mock response.") + # The service should return a SQLAlchemy Session object + mock_session_obj = models.Session( + id=1, + user_id="test_user", + model_name="gemini", + title="New Chat Session", + created_at=datetime.now() + ) + mock_rag_service_instance.create_session.return_value = mock_session_obj + app = create_app() app.dependency_overrides[get_db] = override_get_db client = TestClient(app) - # This payload is now valid according to the ChatRequest Pydantic model - payload = {"prompt": "Hello there", "model": "deepseek"} - + # Act - response = client.post("/chat", json=payload) + response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"}) + + # Assert + assert response.status_code == 200 + response_data = response.json() + assert response_data["id"] == 1 + assert response_data["user_id"] == "test_user" + mock_rag_service_instance.create_session.assert_called_once_with( + db=mock_db, user_id="test_user", model="gemini" + ) + +@patch('app.app.RAGService') +def test_chat_in_session_success(mock_rag_service_class): + """ + Test the session-based chat endpoint with a successful, mocked response. + """ + # Arrange + mock_rag_service_instance = mock_rag_service_class.return_value + # The service now returns a tuple: (answer_text, model_used) + mock_rag_service_instance.chat_with_rag = AsyncMock(return_value=("This is a mock response.", "gemini")) + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + + # Act + response = client.post("/sessions/123/chat", json={"prompt": "Hello there"}) # Assert assert response.status_code == 200 assert response.json()["answer"] == "This is a mock response." - assert response.json()["model_used"] == "deepseek" + assert response.json()["model_used"] == "gemini" mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, prompt="Hello there", model="deepseek" - ) - -@patch('app.app.RAGService') -def test_chat_handler_api_failure(mock_rag_service_class): - """ - Test the /chat endpoint when the RAG service encounters an error. - """ - # Arrange - mock_rag_service_instance = mock_rag_service_class.return_value - mock_rag_service_instance.chat_with_rag = AsyncMock(side_effect=Exception("API connection error")) - app = create_app() - app.dependency_overrides[get_db] = override_get_db - client = TestClient(app) - # This payload is now valid according to the ChatRequest Pydantic model - payload = {"prompt": "This request will fail", "model": "deepseek"} - - # Act - response = client.post("/chat", json=payload) - - # Assert - assert response.status_code == 500 - assert "An unexpected error occurred with the deepseek API" in response.json()["detail"] - mock_rag_service_instance.chat_with_rag.assert_called_once_with( - db=mock_db, prompt="This request will fail", model="deepseek" + db=mock_db, session_id=123, prompt="Hello there" ) \ No newline at end of file