diff --git a/ai-hub/app/core/pipelines/dspy_rag.py b/ai-hub/app/core/pipelines/dspy_rag.py index 22ef056..9942888 100644 --- a/ai-hub/app/core/pipelines/dspy_rag.py +++ b/ai-hub/app/core/pipelines/dspy_rag.py @@ -25,27 +25,22 @@ def __init__( self, - retrievers: List[Retriever], signature_class: dspy.Signature = AnswerWithHistory, context_postprocessor: Optional[Callable[[List[str]], str]] = None, history_formatter: Optional[Callable[[List[models.Message]], str]] = None, response_postprocessor: Optional[Callable[[str], str]] = None, ): super().__init__() - self.retrievers = retrievers self.generate_answer = dspy.Predict(signature_class) self.context_postprocessor = context_postprocessor or self._default_context_postprocessor self.history_formatter = history_formatter or self._default_history_formatter self.response_postprocessor = response_postprocessor - async def forward(self, question: str, history: List[models.Message], db: Session) -> str: + async def forward(self, question: str, history: List[models.Message], context_chunks: List[str]) -> str: logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'") # Step 1: Retrieve all document contexts - context_chunks = [] - for retriever in self.retrievers: - context_chunks.extend(retriever.retrieve_context(question, db)) context_text = self.context_postprocessor(context_chunks) diff --git a/ai-hub/app/core/retrievers/file_retriever.py b/ai-hub/app/core/retrievers/file_retriever.py new file mode 100644 index 0000000..8e316af --- /dev/null +++ b/ai-hub/app/core/retrievers/file_retriever.py @@ -0,0 +1,63 @@ +from typing import Dict, Any, Optional +from app.db import models +from sqlalchemy.orm import Session, joinedload +import uuid + +class FileRetriever: + """ + A retriever specifically for accessing file and 
directory content + based on a FileRetrievalRequest ID. + """ + + def retrieve_by_request_id(self, db: Session, request_id: str) -> Optional[Dict[str, Any]]: + """ + Retrieves a FileRetrievalRequest and all its associated files from the database, + returning the data in a well-formatted JSON-like dictionary. + + Args: + db: The SQLAlchemy database session. + request_id: The UUID of the FileRetrievalRequest. + + Returns: + A dictionary containing the request and file data, or None if the request is not found. + """ + try: + # Convert string request_id to UUID object for the query + request_uuid = uuid.UUID(request_id) + except ValueError: + print(f"Invalid UUID format for request_id: {request_id}") + return None + + # Fetch the request and its related files in a single query using join + request = db.query(models.FileRetrievalRequest).filter( + models.FileRetrievalRequest.id == request_uuid + ).options( + # Eagerly load the retrieved_files to avoid N+1 query problem + joinedload(models.FileRetrievalRequest.retrieved_files) + ).first() + + if not request: + return None + + # Build the dictionary to represent the JSON structure + retrieved_data = { + "request_id": str(request.id), + "question": request.question, + "directory_path": request.directory_path, + "session_id": request.session_id, + "created_at": request.created_at.isoformat() if request.created_at else None, + "retrieved_files": [ + { + "file_id": file.id, + "file_path": file.file_path, + "file_name": file.file_name, + "content": file.content, + "type": file.type, + "last_updated": file.last_updated.isoformat() if file.last_updated else None, + "created_at": file.created_at.isoformat() if file.created_at else None, + } + for file in request.retrieved_files + ] + } + + return retrieved_data \ No newline at end of file diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 1f4848c..56189ed 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ 
-46,21 +46,20 @@ llm_provider = get_llm_provider(provider_name) # Configure retrievers for the pipeline - current_retrievers = [] + context_chunks = [] if load_faiss_retriever: if self.faiss_retriever: - current_retrievers.append(self.faiss_retriever) + context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db)) # Ensure FAISS index is loaded else: print("Warning: FaissDBRetriever requested but not available. Proceeding without it.") - rag_pipeline = DspyRagPipeline(retrievers=current_retrievers) + rag_pipeline = DspyRagPipeline() - # Run the RAG pipeline to get a response with dspy.context(lm=llm_provider): answer_text = await rag_pipeline.forward( question=prompt, history=session.messages, - db=db + context_chunks = context_chunks ) # Save assistant's response diff --git a/ai-hub/app/core/services/workspace.py b/ai-hub/app/core/services/workspace.py index 76f76a3..aa1e3c5 100644 --- a/ai-hub/app/core/services/workspace.py +++ b/ai-hub/app/core/services/workspace.py @@ -205,6 +205,12 @@ self.db.add(user_message) self.db.commit() self.db.refresh(user_message) + + # path = data.get("path", "") + # if path: + # # If the path is provided, list the directory contents first. 
+ # await self.send_command(websocket, "list_directory", data={"path": path}) + # return llm_provider = get_llm_provider(provider_name) - chat = DspyRagPipeline(retrievers=[]) + chat = DspyRagPipeline() with dspy.context(lm=llm_provider): diff --git a/ai-hub/app/db/file_retriever_models.py b/ai-hub/app/db/file_retriever_models.py new file mode 100644 index 0000000..9b562e1 --- /dev/null +++ b/ai-hub/app/db/file_retriever_models.py @@ -0,0 +1,74 @@ +from datetime import datetime +from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Boolean, JSON +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +from sqlalchemy.orm import relationship +import uuid + +# Assuming Base is imported from your database.py +from .database import Base + +class FileRetrievalRequest(Base): + """ + SQLAlchemy model for the 'file_retrieval_requests' table. + + Each entry represents a single user request to process a directory + for the AI coding assistant. + """ + __tablename__ = 'file_retrieval_requests' + + # Primary key, a UUID for unique and secure identification. + id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + # The client's main question or instruction for the AI. + question = Column(Text, nullable=False) + # The path or name of the directory provided by the client. + directory_path = Column(String, nullable=False) + # Foreign key to link this request to a specific chat session. + session_id = Column(Integer, ForeignKey('sessions.id'), nullable=False) + # Timestamp for when the request was made. + created_at = Column(DateTime, default=datetime.now, nullable=False) + + # Defines a one-to-many relationship with the RetrievedFile table. + # 'cascade' ensures that all associated files are deleted when the request is deleted. 
+ retrieved_files = relationship( + "RetrievedFile", + back_populates="retrieval_request", + cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"<FileRetrievalRequest(id={self.id}, session_id={self.session_id})>"  # NOTE(review): repr text reconstructed — original angle-bracketed string was lost in extraction + +class RetrievedFile(Base): + """ + SQLAlchemy model for the 'retrieved_files' table. + + This table stores the content and metadata for each file retrieved + during a file processing request. + """ + __tablename__ = 'retrieved_files' + + # Primary key for the file entry. + id = Column(Integer, primary_key=True, index=True) + # Foreign key linking this file back to its parent retrieval request. + request_id = Column(PG_UUID(as_uuid=True), ForeignKey('file_retrieval_requests.id'), nullable=False) + # The full path to the file. + file_path = Column(String, nullable=False) + # The name of the file. + file_name = Column(String, nullable=False) + # The actual content of the file. + content = Column(Text, nullable=False) + # The type of the file content (e.g., 'original', 'updated'). + type = Column(String, nullable=False, default='original') + # Timestamp for when the file was last modified. + last_updated = Column(DateTime, default=datetime.now) + # Timestamp for when this entry was created in the database. + created_at = Column(DateTime, default=datetime.now, nullable=False) + + # Defines a many-to-one relationship back to the FileRetrievalRequest. 
+ retrieval_request = relationship( + "FileRetrievalRequest", + back_populates="retrieved_files" + ) + + def __repr__(self): + return f"<RetrievedFile(id={self.id}, file_name='{self.file_name}')>"  # NOTE(review): repr text reconstructed — original angle-bracketed string was lost in extraction \ No newline at end of file diff --git a/ai-hub/tests/core/pipelines/test_dspy_rag.py b/ai-hub/tests/core/pipelines/test_dspy_rag.py index 4a5580b..4fc3028 100644 --- a/ai-hub/tests/core/pipelines/test_dspy_rag.py +++ b/ai-hub/tests/core/pipelines/test_dspy_rag.py @@ -47,28 +47,16 @@ # --- Test Cases --- -# def test_pipeline_initializes_with_defaults(mock_dspy_predict_instance): -# """Test that the pipeline initializes correctly with default processors.""" -# pipeline = DspyRagPipeline(retrievers=[MockRetriever("test", ["test context"])]) -# assert pipeline.retrievers is not None -# assert pipeline.context_postprocessor is pipeline._default_context_postprocessor -# assert pipeline.history_formatter is pipeline._default_history_formatter -# assert pipeline.response_postprocessor is None -# # Verify that dspy.Predict was instantiated once -# dspy.Predict.assert_called_once() -# # Verify the aforward method was not called yet -# mock_dspy_predict_instance.aforward.assert_not_called() - @pytest.mark.asyncio async def test_forward_pass_with_defaults(mock_db, mock_dspy_predict_instance): """Test a successful forward pass using default processors.""" - retriever = MockRetriever("test_retriever", ["Context 1.", "Context 2."]) - pipeline = DspyRagPipeline(retrievers=[retriever]) + pipeline = DspyRagPipeline() question = "What is the capital of France?" history = [models.Message(sender="user", content="Hello there."), models.Message(sender="assistant", content="Hi.")] + context_chunks = ["Context 1.", "Context 2."] - response = await pipeline(question, history, mock_db) + response = await pipeline.forward(question=question, history=history, context_chunks=context_chunks) expected_context = "Context 1.\n\nContext 2." expected_history = "Human: Hello there.\nAssistant: Hi." 
@@ -93,9 +81,7 @@ def custom_response_processor(response: str) -> str: return f"FINAL: {response.upper()}" - retriever = MockRetriever("test_retriever", ["Context A", "Context B"]) pipeline = DspyRagPipeline( - retrievers=[retriever], context_postprocessor=custom_context_processor, history_formatter=custom_history_formatter, response_postprocessor=custom_response_processor @@ -103,8 +89,9 @@ question = "Custom question?" history = [models.Message(sender="user", content="User message.")] + context_chunks = ["Context A", "Context B"] - response = await pipeline(question, history, mock_db) + response = await pipeline.forward(question=question, history=history, context_chunks=context_chunks) mock_dspy_predict_instance.aforward.assert_called_once_with( context="CUSTOM_CONTEXT: Context A | Context B", @@ -116,13 +103,13 @@ @pytest.mark.asyncio async def test_empty_context_and_history_handling(mock_db, mock_dspy_predict_instance): """Test behavior with empty context and chat history.""" - retriever = MockRetriever("empty_retriever", []) - pipeline = DspyRagPipeline(retrievers=[retriever]) + pipeline = DspyRagPipeline() question = "No context question." 
history = [] + context_chunks = [] - response = await pipeline(question, history, mock_db) + response = await pipeline.forward(question=question, history=history, context_chunks=context_chunks) mock_dspy_predict_instance.aforward.assert_called_once_with( context="No context provided.", diff --git a/ai-hub/tests/core/services/test_rag.py b/ai-hub/tests/core/services/test_rag.py index 7f87239..3ed53a3 100644 --- a/ai-hub/tests/core/services/test_rag.py +++ b/ai-hub/tests/core/services/test_rag.py @@ -82,12 +82,14 @@ assert mock_db.add.call_count == 2 mock_get_llm_provider.assert_called_once_with("deepseek") - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + # Assert that DspyRagPipeline was called without any arguments + mock_dspy_pipeline.assert_called_once_with() + # Assert that the forward method received the correct arguments mock_pipeline_instance.forward.assert_called_once_with( question="Test prompt", history=mock_session.messages, - db=mock_db + context_chunks=[] # It was called with an empty list ) assert answer == "Final RAG response" @@ -129,12 +131,14 @@ assert mock_db.add.call_count == 2 mock_get_llm_provider.assert_called_once_with("gemini") - mock_dspy_pipeline.assert_called_once_with(retrievers=[]) + # Assert that DspyRagPipeline was called without any arguments + mock_dspy_pipeline.assert_called_once_with() + # Assert that the forward method received the correct arguments mock_pipeline_instance.forward.assert_called_once_with( question="Test prompt for Gemini", history=mock_session.messages, - db=mock_db + context_chunks=[] # It was called with an empty list ) assert answer == "Final RAG response from Gemini" @@ -150,6 +154,9 @@ mock_db = MagicMock(spec=Session) mock_session = models.Session(id=44, provider_name="deepseek", messages=[]) mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Mock FaissDBRetriever to return some chunks + 
rag_service.faiss_retriever.retrieve_context.return_value = ["faiss_chunk_1", "faiss_chunk_2"] with patch('app.core.services.rag.get_llm_provider') as mock_get_llm_provider, \ patch('app.core.services.rag.DspyRagPipeline') as mock_dspy_pipeline, \ @@ -173,13 +180,17 @@ ) # --- Assert --- - expected_retrievers = [rag_service.faiss_retriever] - mock_dspy_pipeline.assert_called_once_with(retrievers=expected_retrievers) + # The DspyRagPipeline is still called without arguments + mock_dspy_pipeline.assert_called_once_with() + # The retriever's context method is now called + rag_service.faiss_retriever.retrieve_context.assert_called_once_with(query="Test prompt with FAISS", db=mock_db) + + # The forward method receives the retrieved chunks mock_pipeline_instance.forward.assert_called_once_with( question="Test prompt with FAISS", history=mock_session.messages, - db=mock_db + context_chunks=["faiss_chunk_1", "faiss_chunk_2"] ) assert answer == "Response with FAISS context"