diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index a6c782c..72e11ce 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -38,7 +38,7 @@ print("Application shutdown...") # Access the vector_store from the application state to save it if hasattr(app.state, 'vector_store'): - app.state.vector_store.save_index() + app.state.vector_store.save_index_and_metadata() def create_app() -> FastAPI: """ diff --git a/ai-hub/app/core/services/workspace.py b/ai-hub/app/core/services/workspace.py index 8ca08b5..76f76a3 100644 --- a/ai-hub/app/core/services/workspace.py +++ b/ai-hub/app/core/services/workspace.py @@ -1,14 +1,20 @@ import dspy import json import uuid +import logging import ast # Import the Abstract Syntax Trees module from typing import Dict, Any, Callable, Awaitable, List -from fastapi import WebSocket +from fastapi import WebSocket,Depends +from sqlalchemy.orm import Session,joinedload +from app.db import models +from app.db.session import SessionLocal from app.core.providers.factory import get_llm_provider from app.core.pipelines.file_selector import CodeRagFileSelector - +from app.core.pipelines.dspy_rag import DspyRagPipeline # A type hint for our handler functions MessageHandler = Callable[[WebSocket, Dict[str, Any]], Awaitable[None]] +# Configure logging +logger = logging.getLogger(__name__) class WorkspaceService: """ @@ -22,6 +28,7 @@ "list_directory_response": self.handle_list_directory_response, "file_content_response": self.handle_files_content_response, "execute_command_response": self.handle_command_output, + "chat_message": self.handle_chat_message, # Add more message types here as needed } # Centralized map of commands that can be sent to the client @@ -33,6 +40,7 @@ } # Per-websocket session state management self.sessions: Dict[str, Dict[str, Any]] = {} + self.db = SessionLocal() def generate_request_id(self) -> str: """Generates a unique request ID.""" @@ -65,12 +73,13 @@ Retrieves session state to maintain context. """ message_type = message.get("type") - request_id = message.get("request_id") - round_num = message.get("round") + # request_id = message.get("request_id") + # round_num = message.get("round") # In a real-world app, you'd retrieve historical data based on request_id or session_id # For this example, we'll just print it. - print(f"Received message of type '{message_type}' (request_id: {request_id}, round: {round_num})") + # print(f"Received message of type '{message_type}' (request_id: {request_id}, round: {round_num})") + logger.info(f"Received message: {message}") handler = self.message_handlers.get(message_type) if handler: @@ -102,7 +111,6 @@ raw_answer_text = await cfs( question="Please help to refactor my code", # The history will be retrieved from a database in a real application - # For this example, we'll pass an empty history history="", file_list=files ) @@ -169,4 +177,48 @@ await websocket.send_text(json.dumps({ "type": "thinking_log", "content": f"Command '{command}' completed. Analyzing output." - })) \ No newline at end of file + })) + + async def handle_chat_message(self, websocket: WebSocket, data: Dict[str, Any]): + """Handles incoming chat messages from the client.""" + # TODO: Enhance this function to process the chat message and determine the next action. + prompt = data.get("content") + provider_name = data.get("provider_name", "gemini") + session_id = data.get("session_id") + if session_id is None: + await websocket.send_text(json.dumps({ + "type": "error", + "content": "Error: session_id is required for chat messages." + })) + return + session = self.db.query(models.Session).options( + joinedload(models.Session.messages) + ).filter(models.Session.id == session_id).first() + + if not session: + await websocket.send_text(json.dumps({ + "type": "error", + "content": f"Error: Session with ID {session_id} not found." + })) + return + user_message = models.Message(session_id=session_id, sender="user", content=prompt) + self.db.add(user_message) + self.db.commit() + self.db.refresh(user_message) + llm_provider = get_llm_provider(provider_name) + chat = DspyRagPipeline(retrievers=[]) + with dspy.context(lm=llm_provider): + answer_text = await chat(question=prompt, history=session.messages, db=self.db) + # Save assistant's response + assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text) + self.db.add(assistant_message) + self.db.commit() + self.db.refresh(assistant_message) + + # 📝 Add this section to send the response back to the client + # The client-side `handleChatMessage` handler will process this message + await websocket.send_text(json.dumps({ + "type": "chat_message", + "content": answer_text + })) + logger.info(f"Sent chat response to client: {answer_text}") \ No newline at end of file diff --git a/ai-hub/app/core/vector_store/faiss_store.py b/ai-hub/app/core/vector_store/faiss_store.py index 5573c98..5846e7c 100644 --- a/ai-hub/app/core/vector_store/faiss_store.py +++ b/ai-hub/app/core/vector_store/faiss_store.py @@ -2,101 +2,156 @@ import logging import faiss import numpy as np +import pickle +from typing import List, Optional, Dict, Any + from .base import VectorStore from .utils import save_faiss_index, load_faiss_index -from typing import List, Optional, Dict, Any + class FaissVectorStore(VectorStore): """ - An in-memory vector store using the FAISS library for efficient similarity search. - This implementation handles the persistence of the FAISS index to a file. + An in-memory vector store using the FAISS library with support for filtering + by metadata tags and persistence of both the index and the tags. """ def __init__(self, index_file_path: str, dimension: int, embedder): - """ - Initializes the FaissVectorStore. - """ self.index_file_path = index_file_path self.dimension = dimension self.embedder = embedder - + + self.doc_tags = {} # Metadata per document + self.doc_vectors = {} # Store vectors for filtered search + if os.path.exists(self.index_file_path): logging.info(f"Loading FAISS index from {self.index_file_path}") self.index = faiss.read_index(self.index_file_path) - self.doc_id_map = list(range(self.index.ntotal)) + self.load_metadata() + self.load_vectors() + self.doc_id_map = list(self.doc_tags.keys()) else: logging.info("Creating a new FAISS index.") - self.index = faiss.IndexFlatL2(dimension) + quantizer = faiss.IndexFlatL2(dimension) + self.index = faiss.IndexIDMap(quantizer) self.doc_id_map = [] - - def add_document(self, text: str) -> int: - """ - Embeds a document's text and adds the vector to the FAISS index. - This is now a synchronous method. - """ + + def add_document(self, text: str, tags: Optional[Dict[str, Any]] = None) -> int: logging.debug("Embedding document text for FAISS index...") - vector = self.embedder.embed_text(text) - vector = vector.reshape(1, -1) - self.index.add(vector) - - new_doc_id = self.index.ntotal - 1 + vector = self.embedder.embed_text(text).reshape(1, -1).astype('float32') + + new_doc_id = self.index.ntotal + self.index.add_with_ids(vector, np.array([new_doc_id], dtype='int64')) + self.doc_id_map.append(new_doc_id) - - self.save_index() + self.doc_tags[new_doc_id] = tags if tags else {} + self.doc_vectors[new_doc_id] = vector.flatten() + + self.save_index_and_metadata() logging.info(f"Document added to FAISS index with ID: {new_doc_id}") - + return new_doc_id - def add_multiple_documents(self, texts: List[str]) -> List[int]: - """ - Embeds multiple documents' texts and adds the vectors to the FAISS index. - This is now a synchronous method. - """ + def add_multiple_documents(self, texts: List[str], tags: Optional[List[Dict[str, Any]]] = None) -> List[int]: logging.debug("Embedding multiple document texts for FAISS index...") - # Embed each text synchronously - vectors = [self.embedder.embed_text(text) for text in texts] - - # Reshape the vectors to be suitable for FAISS - vectors = np.vstack([v.reshape(1, -1) for v in vectors]).astype('float32') - self.index.add(vectors) - - new_doc_ids = list(range(self.index.ntotal - len(texts), self.index.ntotal)) + vectors = np.vstack([ + self.embedder.embed_text(text).reshape(1, -1) for text in texts + ]).astype('float32') + + start_id = self.index.ntotal + new_doc_ids = list(range(start_id, start_id + len(texts))) + + self.index.add_with_ids(vectors, np.array(new_doc_ids, dtype='int64')) + self.doc_id_map.extend(new_doc_ids) - self.save_index() - + for i, doc_id in enumerate(new_doc_ids): + self.doc_tags[doc_id] = tags[i] if tags and len(tags) == len(texts) else {} + self.doc_vectors[doc_id] = vectors[i] + + self.save_index_and_metadata() logging.info(f"Added {len(new_doc_ids)} documents to FAISS index.") + return new_doc_ids - - def search_similar_documents(self, query_text: str, k: int = 5) -> List[int]: - """ - Embeds a query string and performs a similarity search in the FAISS index. - This is now a synchronous method. - """ + + def search_similar_documents(self, query_text: str, k: int = 5, prefilter_tags: Optional[Dict[str, Any]] = None) -> List[int]: logging.debug(f"Searching FAISS index for similar documents to query: '{query_text[:50]}...'") if self.index.ntotal == 0: logging.warning("FAISS index is empty, no documents to search.") return [] - - query_vector = self.embedder.embed_text(query_text) - query_vector = query_vector.reshape(1, -1) - - D, I = self.index.search(query_vector, k) - - result_ids = [self.doc_id_map[int(i)] for i in I.flatten() if i >= 0] + + query_vector = self.embedder.embed_text(query_text).reshape(1, -1).astype('float32') + + if prefilter_tags: + valid_ids = [ + doc_id for doc_id, tags in self.doc_tags.items() + if all(tags.get(key) == value for key, value in prefilter_tags.items()) + ] + + if not valid_ids: + logging.warning("No documents match the filter criteria.") + return [] + + try: + filtered_vectors = np.vstack([ + self.doc_vectors[doc_id].reshape(1, -1) + for doc_id in valid_ids + ]).astype('float32') + + temp_index = faiss.IndexFlatL2(self.dimension) + temp_index.add(filtered_vectors) + + D, I = temp_index.search(query_vector, min(k, len(valid_ids))) + result_ids = [int(valid_ids[i]) for i in I.flatten() if i >= 0] + + except Exception as e: + logging.error(f"Error during filtered search: {e}") + return [] + else: + D, I = self.index.search(query_vector, k) + result_ids = [int(i) for i in I.flatten() if i >= 0] + logging.info(f"Search complete, found {len(result_ids)} similar documents.") return result_ids - def save_index(self): - """ - Saves the FAISS index to the specified file path. - """ + def save_index_and_metadata(self): if self.index: logging.info(f"Saving FAISS index to {self.index_file_path}") faiss.write_index(self.index, self.index_file_path) - - def load_index(self): - """ - Loads a FAISS index from the specified file path. - """ - if os.path.exists(self.index_file_path): - logging.info(f"Loading FAISS index from {self.index_file_path}") - self.index = faiss.read_index(self.index_file_path) + + # Save metadata + tags_file_path = self.index_file_path + ".tags" + with open(tags_file_path, 'wb') as f: + pickle.dump(self.doc_tags, f) + logging.info(f"Saved metadata to {tags_file_path}") + + # Save vectors + vectors_file_path = self.index_file_path + ".vecs" + with open(vectors_file_path, 'wb') as f: + pickle.dump(self.doc_vectors, f) + logging.info(f"Saved document vectors to {vectors_file_path}") + + def load_metadata(self): + tags_file_path = self.index_file_path + ".tags" + if os.path.exists(tags_file_path): + try: + with open(tags_file_path, 'rb') as f: + self.doc_tags = pickle.load(f) + logging.info(f"Loaded metadata from {tags_file_path}") + except Exception as e: + logging.error(f"Failed to load metadata file: {e}") + self.doc_tags = {} + else: + logging.warning("Metadata file not found, initializing empty tags dictionary.") + self.doc_tags = {} + + def load_vectors(self): + vectors_file_path = self.index_file_path + ".vecs" + if os.path.exists(vectors_file_path): + try: + with open(vectors_file_path, 'rb') as f: + self.doc_vectors = pickle.load(f) + logging.info(f"Loaded vectors from {vectors_file_path}") + except Exception as e: + logging.error(f"Failed to load vectors file: {e}") + self.doc_vectors = {} + else: + logging.warning("Vectors file not found, initializing empty vector dictionary.") + self.doc_vectors = {} diff --git a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py index 695d5ae..00a3735 100644 --- a/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py +++ b/ai-hub/tests/core/retrievers/test_faiss_db_retriever.py @@ -44,7 +44,7 @@ return [0] # --- E2E test setup and fixtures --- -SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" +SQLALCHEMY_DATABASE_URL = "sqlite:///./data/test.db" engine = create_engine( SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) diff --git a/ai-hub/tests/core/services/test_workspace.py b/ai-hub/tests/core/services/test_workspace.py new file mode 100644 index 0000000..2177e11 --- /dev/null +++ b/ai-hub/tests/core/services/test_workspace.py @@ -0,0 +1,267 @@ +import unittest +import json +from unittest.mock import MagicMock, AsyncMock, patch +from app.core.services.workspace import WorkspaceService +from fastapi import WebSocket +import asyncio +from app.db import models +from dspy.teleprompt import LabeledFewShot +from app.core.pipelines.dspy_rag import DspyRagPipeline +from app.core.pipelines.file_selector import CodeRagFileSelector +from app.core.providers.factory import get_llm_provider + +# Use IsolatedAsyncioTestCase for cleaner async testing +class TestWorkspaceService(unittest.IsolatedAsyncioTestCase): + """ + Unit tests for the WorkspaceService class. + """ + + def setUp(self): + """ + Set up a new WorkspaceService instance for each test. + """ + # Patch the database session to prevent real database connections + self.db_patcher = patch('app.core.services.workspace.SessionLocal') + self.mock_session_local = self.db_patcher.start() + self.mock_db = MagicMock() + self.mock_session_local.return_value = self.mock_db + + # Create a WebSocket mock object + self.mock_websocket = AsyncMock() + self.mock_websocket.scope = {"client": ("127.0.0.1", 8000)} + + self.service = WorkspaceService() + + def tearDown(self): + """ + Clean up patches after each test. + """ + self.db_patcher.stop() + + async def test_generate_request_id(self): + """ + Test that generate_request_id returns a valid UUID string. + """ + request_id = self.service.generate_request_id() + self.assertIsInstance(request_id, str) + self.assertTrue(len(request_id) > 0) + + @patch('app.core.services.workspace.WorkspaceService.generate_request_id', return_value="test-uuid") + async def test_send_command_success(self, mock_generate_id): + """ + Test that send_command sends a correctly formatted message. + """ + data = {"path": "/test/path"} + await self.service.send_command(self.mock_websocket, "list_directory", data) + + # Verify the correct message was sent + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "list_directory") + self.assertEqual(sent_message["request_id"], "test-uuid") + self.assertEqual(sent_message["round"], 1) + self.assertEqual(sent_message["path"], "/test/path") + + # Verify session state was updated + self.assertEqual(self.service.sessions[self.mock_websocket.scope["client"]]["round"], 1) + + async def test_send_command_unknown_command(self): + """ + Test that send_command raises an error for an unknown command. + """ + with self.assertRaises(ValueError) as context: + await self.service.send_command(self.mock_websocket, "unknown_command") + self.assertIn("Unknown command: unknown_command", str(context.exception)) + + @patch('app.core.services.workspace.WorkspaceService.handle_select_folder_response', new_callable=AsyncMock) + async def test_dispatch_message_valid_type(self, mock_handler): + """ + Test that dispatch_message calls the correct handler for a known message type. + """ + message = {"type": "select_folder_response", "path": "/test/path"} + + # Recreate the service after patching so the patched method is used + self.service = WorkspaceService() + + await self.service.dispatch_message(self.mock_websocket, message) + mock_handler.assert_called_once_with(self.mock_websocket, message) + + async def test_dispatch_message_unknown_type(self): + """ + Test that dispatch_message handles an unknown message type gracefully. + """ + message = {"type": "unknown_type"} + await self.service.dispatch_message(self.mock_websocket, message) + + # Verify an error message was sent back to the client + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "error") + self.assertIn("Unknown message type", sent_message["content"]) + + @patch('app.core.services.workspace.WorkspaceService.send_command', new_callable=AsyncMock) + async def test_handle_select_folder_response(self, mock_send_command): + """ + Test that the folder response handler correctly triggers a `list_directory` command. + """ + data = {"path": "/user/project", "request_id": "req-123"} + await self.service.handle_select_folder_response(self.mock_websocket, data) + + mock_send_command.assert_called_once_with(self.mock_websocket, "list_directory", data={"path": "/user/project"}) + + @patch('app.core.services.workspace.ast.literal_eval') + @patch('app.core.services.workspace.CodeRagFileSelector') + @patch('app.core.services.workspace.get_llm_provider') + async def test_handle_list_directory_response_success(self, mock_get_llm, mock_cfs, mock_literal_eval): + """ + Test the list directory handler's success path, ensuring it requests file content. + """ + # Mock LLM and dspy pipeline behavior + mock_llm_provider = MagicMock() + mock_get_llm.return_value = mock_llm_provider + mock_pipeline_instance = AsyncMock() + mock_cfs.return_value = mock_pipeline_instance + mock_pipeline_instance.return_value = '["file1.py", "file2.js"]' # LLM output + mock_literal_eval.return_value = ["file1.py", "file2.js"] # Parsed list + + # Mock send_command to verify it's called + self.service.send_command = AsyncMock() + + data = {"files": ["file1.py", "file2.js", "README.md"], "provider_name": "gemini"} + await self.service.handle_list_directory_response(self.mock_websocket, data) + + # Verify the thinking log message was sent + self.mock_websocket.send_text.assert_called() + log_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(log_message["type"], "thinking_log") + self.assertIn("AI selected files: ['file1.py', 'file2.js']", log_message["content"]) + + # Verify the get_file_content command was sent + self.service.send_command.assert_called_once_with( + self.mock_websocket, "get_file_content", data={"filenames": ["file1.py", "file2.js"]} + ) + + @patch('app.core.services.workspace.ast.literal_eval', side_effect=ValueError("Syntax error")) + @patch('app.core.services.workspace.CodeRagFileSelector') + @patch('app.core.services.workspace.get_llm_provider') + async def test_handle_list_directory_response_eval_error(self, mock_get_llm, mock_cfs, mock_literal_eval): + """ + Test that the handler gracefully handles an error from `ast.literal_eval`. + """ + mock_llm_provider = MagicMock() + mock_get_llm.return_value = mock_llm_provider + mock_pipeline_instance = AsyncMock() + mock_cfs.return_value = mock_pipeline_instance + mock_pipeline_instance.return_value = '{"not a list": "!"}' + + data = {"files": ["file1.py"], "provider_name": "gemini"} + await self.service.handle_list_directory_response(self.mock_websocket, data) + + # Verify that an error log was sent to the client + self.mock_websocket.send_text.assert_called() + error_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(error_message["type"], "thinking_log") + self.assertIn("Warning: AI's file list could not be parsed.", error_message["content"]) + + async def test_handle_files_content_response_success(self): + """ + Test that the file content handler processes files correctly. + """ + data = { + "request_id": "req-123", + "files": [ + {"filename": "file1.py", "content": "print('hello world')"}, + {"filename": "file2.js", "content": "console.log('hi')"} + ] + } + await self.service.handle_files_content_response(self.mock_websocket, data) + + # Verify that thinking logs were sent for each file + self.assertEqual(self.mock_websocket.send_text.call_count, 2) + sent_message = json.loads(self.mock_websocket.send_text.call_args_list[0][0][0]) + self.assertEqual(sent_message["content"], "Analyzing the content of file: file1.py") + + sent_message = json.loads(self.mock_websocket.send_text.call_args_list[1][0][0]) + self.assertEqual(sent_message["content"], "Analyzing the content of file: file2.js") + + async def test_handle_files_content_response_empty_data(self): + """ + Test that the handler does nothing if no files are provided. + """ + data = {"request_id": "req-123", "files": []} + await self.service.handle_files_content_response(self.mock_websocket, data) + self.mock_websocket.send_text.assert_not_called() + + async def test_handle_command_output(self): + """ + Test that the command output handler sends a thinking log. + """ + data = {"request_id": "req-123", "command": "ls -l", "output": "file.txt"} + await self.service.handle_command_output(self.mock_websocket, data) + + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "thinking_log") + self.assertIn("Command 'ls -l' completed. Analyzing output.", sent_message["content"]) + + @patch('app.core.services.workspace.DspyRagPipeline') + @patch('app.core.services.workspace.get_llm_provider') + async def test_handle_chat_message_success(self, mock_get_llm, mock_dspy_rag): + """ + Test that a chat message is handled, saved, and a response is sent. + """ + # Mock database session and query behavior + mock_session = MagicMock() + mock_message_instance = models.Message(session_id="test-session-id", sender="user", content="hello") + mock_session.messages = [mock_message_instance] + self.mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = mock_session + + # Mock LLM and dspy pipeline + mock_llm_provider = MagicMock() + mock_get_llm.return_value = mock_llm_provider + mock_dspy_pipeline = AsyncMock() + mock_dspy_rag.return_value = mock_dspy_pipeline + mock_dspy_pipeline.return_value = "Hello! How can I help you?" + + data = {"content": "Hello", "session_id": "test-session-id", "provider_name": "gemini"} + await self.service.handle_chat_message(self.mock_websocket, data) + + # Verify database calls + self.mock_db.add.assert_called() + self.mock_db.commit.assert_called() + self.mock_db.refresh.assert_called() + + # Verify the response was sent to the client + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "chat_message") + self.assertEqual(sent_message["content"], "Hello! How can I help you?") + + async def test_handle_chat_message_no_session_id(self): + """ + Test that the handler returns an error if session_id is missing. + """ + data = {"content": "Hello"} + await self.service.handle_chat_message(self.mock_websocket, data) + + self.mock_db.query.assert_not_called() + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "error") + self.assertIn("session_id is required", sent_message["content"]) + + async def test_handle_chat_message_session_not_found(self): + """ + Test that the handler returns an error if the session is not found in the DB. + """ + # Mock database query to return None + self.mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = None + + data = {"content": "Hello", "session_id": "non-existent-id"} + await self.service.handle_chat_message(self.mock_websocket, data) + + self.mock_db.add.assert_not_called() + self.mock_websocket.send_text.assert_called_once() + sent_message = json.loads(self.mock_websocket.send_text.call_args[0][0]) + self.assertEqual(sent_message["type"], "error") + self.assertIn("Session with ID non-existent-id not found", sent_message["content"]) diff --git a/ai-hub/tests/core/vector_store/test_faiss_store.py b/ai-hub/tests/core/vector_store/test_faiss_store.py index 73c440e..ee43b22 100644 --- a/ai-hub/tests/core/vector_store/test_faiss_store.py +++ b/ai-hub/tests/core/vector_store/test_faiss_store.py @@ -1,35 +1,119 @@ import os +import shutil +import tempfile import pytest +import pickle +import numpy as np + from app.core.vector_store.faiss_store import FaissVectorStore -def test_add_document(faiss_store: FaissVectorStore): + +# ----------------------------- +# Fixtures +# ----------------------------- + +@pytest.fixture +def mock_embedder(): + class MockEmbedder: + def embed_text(self, text): + # Return a deterministic fake vector based on hash of text + np.random.seed(abs(hash(text)) % 2**32) + return np.random.rand(768).astype('float32') + return MockEmbedder() + +@pytest.fixture +def temp_faiss_file(): + tmp_dir = tempfile.mkdtemp() + index_file = os.path.join(tmp_dir, "test.index") + yield index_file + shutil.rmtree(tmp_dir) + +@pytest.fixture +def faiss_store(temp_faiss_file, mock_embedder): + return FaissVectorStore(index_file_path=temp_faiss_file, dimension=768, embedder=mock_embedder) + + +# ----------------------------- +# Tests +# ----------------------------- + +def test_add_document(faiss_store): test_text = "This is a test document." + test_tags = {"author": "John Doe", "year": 2023} + assert faiss_store.index.ntotal == 0 - faiss_id = faiss_store.add_document(test_text) + assert faiss_store.doc_tags == {} + + doc_id = faiss_store.add_document(test_text, tags=test_tags) + assert faiss_store.index.ntotal == 1 - assert faiss_id == 0 + assert doc_id == 0 assert os.path.exists(faiss_store.index_file_path) + assert os.path.exists(faiss_store.index_file_path + ".tags") + assert os.path.exists(faiss_store.index_file_path + ".vecs") + assert faiss_store.doc_tags[doc_id] == test_tags + assert isinstance(faiss_store.doc_vectors[doc_id], np.ndarray) -def test_add_multiple_documents(faiss_store: FaissVectorStore): + +def test_add_multiple_documents(faiss_store): docs = ["Doc 1", "Doc 2", "Doc 3"] + tags = [{"type": "a"}, {"type": "b"}, {"type": "a"}] + assert faiss_store.index.ntotal == 0 - faiss_ids = faiss_store.add_multiple_documents(docs) + + doc_ids = faiss_store.add_multiple_documents(docs, tags=tags) + assert faiss_store.index.ntotal == 3 - assert faiss_ids == [0, 1, 2] + assert doc_ids == [0, 1, 2] + assert faiss_store.doc_tags[0] == {"type": "a"} + assert faiss_store.doc_tags[1] == {"type": "b"} + assert faiss_store.doc_tags[2] == {"type": "a"} + assert all(isinstance(faiss_store.doc_vectors[i], np.ndarray) for i in doc_ids) -def test_load_existing_index(temp_faiss_file, mock_embedder): + +def test_load_existing_index_with_metadata(temp_faiss_file, mock_embedder): store1 = FaissVectorStore(temp_faiss_file, 768, mock_embedder) - store1.add_document("Persistence test.") + store1.add_document("Persistence test with tags.", tags={"status": "complete"}) + # Reload store2 = FaissVectorStore(temp_faiss_file, 768, mock_embedder) + assert store2.index.ntotal == 1 assert store2.doc_id_map == [0] + assert store2.doc_tags[0] == {"status": "complete"} + assert isinstance(store2.doc_vectors[0], np.ndarray) -def test_search_similar_documents(faiss_store: FaissVectorStore): - faiss_store.add_document("The sun is a star.") - faiss_store.add_document("Mars is a planet.") - faiss_store.add_document("The moon orbits the Earth.") + +def test_search_similar_documents_without_filter(faiss_store): + faiss_store.add_document("The sun is a star.", tags={"category": "astronomy"}) + faiss_store.add_document("Mars is a planet.", tags={"category": "astronomy"}) + faiss_store.add_document("The moon orbits the Earth.", tags={"category": "astronomy"}) results = faiss_store.search_similar_documents("What is a star?", k=2) + assert len(results) == 2 - assert isinstance(results[0], int) + assert all(isinstance(doc_id, int) for doc_id in results) + + +def test_search_similar_documents_with_filter(faiss_store): + faiss_store.add_document("Python is a programming language.", tags={"type": "programming"}) + faiss_store.add_document("A dog is a loyal pet.", tags={"type": "animal"}) + faiss_store.add_document("Java is another programming language.", tags={"type": "programming"}) + + results = faiss_store.search_similar_documents( + "Which is a programming language?", k=2, prefilter_tags={"type": "programming"} + ) + + assert len(results) == 2 + for doc_id in results: + assert faiss_store.doc_tags[doc_id]["type"] == "programming" + + +def test_search_with_no_matching_filter(faiss_store): + faiss_store.add_document("A document about cats.", tags={"species": "feline"}) + + results = faiss_store.search_similar_documents( + "What is a dog?", k=5, prefilter_tags={"species": "canine"} + ) + + assert len(results) == 0 diff --git a/ai-hub/tests/test_app.py b/ai-hub/tests/test_app.py index 96e9973..9d7e916 100644 --- a/ai-hub/tests/test_app.py +++ b/ai-hub/tests/test_app.py @@ -333,26 +333,43 @@ assert response.json()["detail"] == "Document with ID 999 not found." mock_services.document_service.delete_document.assert_called_once_with(db=mock_db, document_id=999) -# FIX: Add a new test to explicitly check the application shutdown behavior -@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index') +@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index_and_metadata') @patch('app.app.ServiceContainer') @patch('app.app.create_db_and_tables') @patch('app.app.print_config') @patch('app.core.vector_store.embedder.factory.get_embedder_from_config') +@patch('os.path.exists', return_value=True) @patch('faiss.read_index') -def test_shutdown_saves_index(mock_read_index, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index): +def test_shutdown_saves_index_and_metadata(mock_read_index, mock_os_exists, mock_get_embedder, mock_print_config, mock_create_db, mock_service_container, mock_save_index_and_metadata): """ - Tests that the FAISS index is saved on application shutdown. + Tests that the FAISS index and its associated metadata are saved on application shutdown. """ # Arrange - mock_read_index.return_value = MagicMock() + # Mock FAISS components + mock_index = MagicMock() + mock_index.ntotal = 10 + mock_read_index.return_value = mock_index + + # Mock the embedder mock_get_embedder.return_value = MagicMock() - - # Create the app and let the lifespan events run + + # We need to simulate the FaissVectorStore instance being created and having data + # to be saved. We can't just mock the method, we must also ensure the instance is + # properly created and accessible by the lifespan event. + mock_faiss_vector_store = MagicMock(spec=FaissVectorStore) + mock_faiss_vector_store.doc_tags = {i: {"tag": "value"} for i in range(10)} + mock_faiss_vector_store.index = mock_index + mock_faiss_vector_store.dimension = 768 + + # We need to mock the ServiceContainer to return our mocked FaissVectorStore instance. + mock_service_container.return_value.vector_store = mock_faiss_vector_store + + # Create the app and let the lifespan events run. app = create_app() with TestClient(app) as client: - # Act: The lifespan shutdown event will run when the 'with' block is exited + # Act: The lifespan shutdown event will run when the 'with' block is exited. pass - + # Assert - mock_save_index.assert_called_once() + # Check that the new save_index_and_metadata method was called exactly once. + mock_save_index_and_metadata.assert_called_once() diff --git a/ui/client-app/src/hooks/useCodeAssistant.js b/ui/client-app/src/hooks/useCodeAssistant.js index b878e83..1dc69e8 100644 --- a/ui/client-app/src/hooks/useCodeAssistant.js +++ b/ui/client-app/src/hooks/useCodeAssistant.js @@ -1,4 +1,3 @@ -// src/hooks/useCodeAssistant.js import { useState, useEffect, useRef, useCallback } from "react"; import { connectToWebSocket } from "../services/websocket"; import { v4 as uuidv4 } from 'uuid'; @@ -56,8 +55,8 @@ } // setThinkingProcess((prev) => [ - // ...prev, - // { type: "system", message: `Scanning directory...`, round }, + // ...prev, + // { type: "system", message: `Scanning directory...`, round }, // ]); try { @@ -89,6 +88,7 @@ files, request_id, round, + session_id: sessionId, }) ); @@ -108,17 +108,18 @@ content: "Failed to access folder contents.", request_id, round, + session_id: sessionId, }) ); } - }, []); + }, [sessionId]); const handleReadFilesRequest = useCallback(async (message) => { console.log(message); const { filenames, request_id, round } = message; const dirHandle = dirHandleRef.current; if (!dirHandle) { - ws.current.send(JSON.stringify({ type: "error", content: "No folder selected.", request_id, round })); + ws.current.send(JSON.stringify({ type: "error", content: "No folder selected.", request_id, round, session_id: sessionId })); return; } @@ -145,6 +146,7 @@ content: `Could not read file: ${filename}`, request_id, round, + session_id: sessionId, })); } } @@ -154,13 +156,14 @@ files: filesData, request_id, round, + session_id: sessionId, })); setThinkingProcess((prev) => [ ...prev, { type: "local", message: `Sent content for ${filesData.length} files to server.`, round }, ]); - }, []); + }, [sessionId]); const handleExecuteCommandRequest = useCallback((message) => { const { command, request_id, round } = message; @@ -172,13 +175,14 @@ output, request_id, round, + session_id: sessionId, })); setThinkingProcess((prev) => [ ...prev, { type: "system", message: `Simulated execution of command: '${command}'`, round }, ]); - }, []); + }, [sessionId]); // Main message handler that routes messages to the correct function const handleIncomingMessage = useCallback((message) => { @@ -246,13 +250,14 @@ }, [handleIncomingMessage]); // Send chat message to server - const handleSendChat = useCallback((text) => { + const handleSendChat = useCallback(async (text) => { if (ws.current && ws.current.readyState === WebSocket.OPEN) { setChatHistory((prev) => [...prev, { isUser: true, text }]); setIsProcessing(true); - ws.current.send(JSON.stringify({ type: "chat_message", content: text })); + // Removed the extra call to getSessionId as it is already in state + ws.current.send(JSON.stringify({ type: "chat_message",content: text,round:0, session_id: sessionId })); } - }, []); + }, [sessionId]); // Open folder picker and store handle const handleSelectFolder = useCallback(async (directoryHandle) => { @@ -266,7 +271,7 @@ // Send the initial message to the server with a unique request_id const request_id = uuidv4(); - ws.current.send(JSON.stringify({ type: "select_folder_response", path: directoryHandle.name, request_id })); + ws.current.send(JSON.stringify({ type: "select_folder_response", path: directoryHandle.name, request_id, session_id: sessionId })); setThinkingProcess((prev) => [ ...prev, @@ -275,20 +280,20 @@ } catch (error) { console.error("Folder selection canceled or failed:", error); } - }, []); + }, [sessionId]); // Control functions const handlePause = useCallback(() => { if (ws.current && ws.current.readyState === WebSocket.OPEN) { - ws.current.send(JSON.stringify({ type: "control", command: "pause" })); + ws.current.send(JSON.stringify({ type: "control", command: "pause", session_id: sessionId })); } - }, []); + }, [sessionId]); const handleStop = useCallback(() => { if (ws.current && ws.current.readyState === WebSocket.OPEN) { - ws.current.send(JSON.stringify({ type: "control", command: "stop" })); + ws.current.send(JSON.stringify({ type: "control", command: "stop", session_id: sessionId })); } - }, []); + }, [sessionId]); return { chatHistory, @@ -307,4 +312,4 @@ }; }; -export default useCodeAssistant; \ No newline at end of file +export default useCodeAssistant; diff --git a/ui/client-app/src/services/websocket.js b/ui/client-app/src/services/websocket.js index 1fe4458..2976ef7 100644 --- a/ui/client-app/src/services/websocket.js +++ b/ui/client-app/src/services/websocket.js @@ -3,6 +3,26 @@ import { createSession } from "./apiService"; /** + * Gets a session ID from localStorage or creates a new one via the API. + * @returns {Promise} The session ID. + */ +export const getSessionId = async () => { + let sessionId = localStorage.getItem("sessionId"); + + if (!sessionId) { + // No existing session, so create one via API + const session = await createSession(); + sessionId = session.id; + + // Store it in localStorage for reuse + localStorage.setItem("sessionId", sessionId); + } + + console.log("Using session ID:", sessionId); + return sessionId; +}; + +/** * Connects to the WebSocket server, establishing a new session first. * @param {function(Object): void} onMessageCallback - Callback for incoming messages. * @param {function(): void} onOpenCallback - Callback when the connection is opened.