diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index dee8679..d405903 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -56,10 +56,17 @@ external_endpoint: Optional[str] = None grpc_endpoint: Optional[str] = None +class JournalSettings(BaseModel): + stream_head_chars: int = 10000 + stream_tail_chars: int = 30000 + thought_head_count: int = 5 + thought_tail_count: int = 15 + class AppConfig(BaseModel): """Top-level Pydantic model for application configuration.""" application: ApplicationSettings = Field(default_factory=ApplicationSettings) database: DatabaseSettings = Field(default_factory=DatabaseSettings) + journal: JournalSettings = Field(default_factory=JournalSettings) llm_providers: dict[str, dict] = Field(default_factory=dict) active_llm_provider: Optional[str] = None vector_store: VectorStoreSettings = Field(default_factory=VectorStoreSettings) @@ -165,6 +172,20 @@ self.CORS_ORIGINS: list[str] = (os.getenv("CORS_ORIGINS") or "http://localhost:8000,http://localhost:8080,http://localhost:3000").split(",") self.HUB_PUBLIC_URL: Optional[str] = os.getenv("HUB_PUBLIC_URL") + # --- Journal Settings --- + self.STREAM_HEAD_CHARS: int = int(os.getenv("STREAM_HEAD_CHARS") or \ + get_from_yaml(["journal", "stream_head_chars"]) or \ + config_from_pydantic.journal.stream_head_chars) + self.STREAM_TAIL_CHARS: int = int(os.getenv("STREAM_TAIL_CHARS") or \ + get_from_yaml(["journal", "stream_tail_chars"]) or \ + config_from_pydantic.journal.stream_tail_chars) + self.THOUGHT_HEAD_COUNT: int = int(os.getenv("THOUGHT_HEAD_COUNT") or \ + get_from_yaml(["journal", "thought_head_count"]) or \ + config_from_pydantic.journal.thought_head_count) + self.THOUGHT_TAIL_COUNT: int = int(os.getenv("THOUGHT_TAIL_COUNT") or \ + get_from_yaml(["journal", "thought_tail_count"]) or \ + config_from_pydantic.journal.thought_tail_count) + # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ get_from_yaml(["database", "mode"]) or \ @@ -338,6 +359,12 @@ "active_stt_provider": self.STT_PROVIDER, "swarm": { "external_endpoint": self.GRPC_EXTERNAL_ENDPOINT + }, + "journal": { + "stream_head_chars": self.STREAM_HEAD_CHARS, + "stream_tail_chars": self.STREAM_TAIL_CHARS, + "thought_head_count": self.THOUGHT_HEAD_COUNT, + "thought_tail_count": self.THOUGHT_TAIL_COUNT } } diff --git a/ai-hub/app/core/grpc/core/journal.py b/ai-hub/app/core/grpc/core/journal.py index 4de45cf..231d619 100644 --- a/ai-hub/app/core/grpc/core/journal.py +++ b/ai-hub/app/core/grpc/core/journal.py @@ -1,6 +1,6 @@ import threading import time - +from app.config import settings class TaskJournal: """ @@ -16,20 +16,19 @@ """ def __init__(self): - self.lock = threading.Lock() - self.tasks = {} # task_id -> { event, result, node_id, buffers... } - - # ---- Memory pressure limits ---- - # stream_buffer hard limit = HEAD + TAIL = 40KB - self.STREAM_HEAD_CHARS = 10_000 # ~10KB — preserve initial command/echo - self.STREAM_TAIL_CHARS = 30_000 # ~30KB — preserve most recent output + # ---- Memory pressure limits from config ---- + self.STREAM_HEAD_CHARS = settings.STREAM_HEAD_CHARS + self.STREAM_TAIL_CHARS = settings.STREAM_TAIL_CHARS self.STREAM_MAX_CHARS = self.STREAM_HEAD_CHARS + self.STREAM_TAIL_CHARS - # thought_history hard limit = HEAD + TAIL = 20 entries - self.THOUGHT_HEAD_COUNT = 5 # first 5 thoughts (task inception context) - self.THOUGHT_TAIL_COUNT = 15 # last 15 thoughts (most recent AI reasoning) + self.THOUGHT_HEAD_COUNT = settings.THOUGHT_HEAD_COUNT + self.THOUGHT_TAIL_COUNT = settings.THOUGHT_TAIL_COUNT self.THOUGHT_MAX_COUNT = self.THOUGHT_HEAD_COUNT + self.THOUGHT_TAIL_COUNT + # Sharded locking for high concurrency + self.NUM_SHARDS = 16 + self.shards = [{"tasks": {}, "lock": threading.Lock()} for _ in range(self.NUM_SHARDS)] + threading.Thread( target=self._cleanup_loop, daemon=True, name="JournalCleanup" ).start() @@ -38,14 +37,12 @@ # Internal helpers # ------------------------------------------------------------------ - def _trim_stream(self, buf: str, chunk: str) -> str: - """ - Append chunk to buf, enforcing the head+tail memory limit. + def _get_shard(self, task_id: str): + """Returns the shard (dict and lock) for a given task_id.""" + return self.shards[hash(task_id) % self.NUM_SHARDS] - When the total length exceeds STREAM_MAX_CHARS, the middle section is - replaced with one human-readable '[... N bytes omitted ...]' marker. - The head (STREAM_HEAD_CHARS) and tail (STREAM_TAIL_CHARS) are always kept. - """ + def _trim_stream(self, buf: str, chunk: str) -> str: + """Append chunk to buf, enforcing the head+tail memory limit.""" combined = buf + chunk if len(combined) <= self.STREAM_MAX_CHARS: return combined @@ -57,13 +54,8 @@ return head + marker + tail def _trim_thoughts(self, thoughts: list, new_entry: dict) -> list: - """ - Append new_entry to thoughts, enforcing the head+tail thought limit. - - When total entries exceed THOUGHT_MAX_COUNT, the middle entries are - collapsed into a single sentinel entry. - """ - thoughts = list(thoughts) # shallow copy to avoid mutating in-place + """Append new_entry to thoughts, enforcing the head+tail thought limit.""" + thoughts = list(thoughts) thoughts.append(new_entry) if len(thoughts) <= self.THOUGHT_MAX_COUNT: return thoughts @@ -103,14 +95,15 @@ """Initializes state for a new task and returns its completion event.""" event = threading.Event() prompt_event = threading.Event() - with self.lock: - self.tasks[task_id] = { + shard = self._get_shard(task_id) + with shard["lock"]: + shard["tasks"][task_id] = { "event": event, "prompt_event": prompt_event, "result": None, "node_id": node_id, - "stream_buffer": "", # head+tail bounded raw stdout - "thought_history": [], # head+tail bounded AI reasoning log + "stream_buffer": "", + "thought_history": [], "created_at": time.time(), "completed_at": None, } @@ -118,68 +111,72 @@ def add_thought(self, task_id: str, thought: str) -> bool: """Adds an AI reasoning entry to the task's history (head+tail bounded).""" - with self.lock: - if task_id not in self.tasks: + shard = self._get_shard(task_id) + with shard["lock"]: + if task_id not in shard["tasks"]: return False entry = {"time": time.time(), "thought": thought} - self.tasks[task_id]["thought_history"] = self._trim_thoughts( - self.tasks[task_id]["thought_history"], entry + shard["tasks"][task_id]["thought_history"] = self._trim_thoughts( + shard["tasks"][task_id]["thought_history"], entry ) return True def append_stream(self, task_id: str, chunk: str) -> bool: """Appends a real-time output chunk to the task's stream buffer (head+tail bounded).""" - with self.lock: - if task_id not in self.tasks: + shard = self._get_shard(task_id) + with shard["lock"]: + if task_id not in shard["tasks"]: return False - self.tasks[task_id]["stream_buffer"] = self._trim_stream( - self.tasks[task_id]["stream_buffer"], chunk + shard["tasks"][task_id]["stream_buffer"] = self._trim_stream( + shard["tasks"][task_id]["stream_buffer"], chunk ) return True def signal_prompt(self, task_id: str) -> bool: """Signals that an agent node detected an interactive prompt.""" - with self.lock: - if task_id in self.tasks: - self.tasks[task_id]["prompt_event"].set() + shard = self._get_shard(task_id) + with shard["lock"]: + if task_id in shard["tasks"]: + shard["tasks"][task_id]["prompt_event"].set() return True return False def fulfill(self, task_id: str, result: dict) -> bool: - """ - Stores the final result from a node and wakes all waiting threads. - Also trims the final stdout blob if it exceeds the memory limit. - """ - with self.lock: - if task_id not in self.tasks: + """Stores the final result from a node and wakes all waiting threads.""" + shard = self._get_shard(task_id) + with shard["lock"]: + if task_id not in shard["tasks"]: return False - # Trim oversized final stdout before storing if isinstance(result, dict) and result.get("stdout"): result["stdout"] = self._trim_final_stdout(result["stdout"]) - self.tasks[task_id]["result"] = result - self.tasks[task_id]["completed_at"] = time.time() - self.tasks[task_id]["event"].set() - self.tasks[task_id]["prompt_event"].set() # wake sub-agent edge sleeps + shard["tasks"][task_id]["result"] = result + shard["tasks"][task_id]["completed_at"] = time.time() + shard["tasks"][task_id]["event"].set() + shard["tasks"][task_id]["prompt_event"].set() return True def fail_node_tasks(self, node_id: str, error_msg: str = "Node disconnected") -> int: """Fulfills all pending tasks for a disconnected node with an error.""" - with self.lock: - to_fail = [ - tid for tid, t in self.tasks.items() - if t.get("node_id") == node_id and t.get("result") is None - ] - for tid in to_fail: - self.tasks[tid]["result"] = {"error": error_msg, "status": "ERROR"} - self.tasks[tid]["completed_at"] = time.time() - self.tasks[tid]["event"].set() - self.tasks[tid]["prompt_event"].set() - return len(to_fail) + total_failed = 0 + for shard in self.shards: + with shard["lock"]: + to_fail = [ + tid for tid, t in shard["tasks"].items() + if t.get("node_id") == node_id and t.get("result") is None + ] + for tid in to_fail: + shard["tasks"][tid]["result"] = {"error": error_msg, "status": "ERROR"} + shard["tasks"][tid]["completed_at"] = time.time() + shard["tasks"][tid]["event"].set() + shard["tasks"][tid]["prompt_event"].set() + total_failed += len(to_fail) + return total_failed def get_result(self, task_id: str): """Returns the enriched result for a task, or a partial snapshot if still running.""" - with self.lock: - data = self.tasks.get(task_id) + shard = self._get_shard(task_id) + with shard["lock"]: + data = shard["tasks"].get(task_id) if data is None: return None @@ -187,7 +184,6 @@ history = data.get("thought_history", []) if res is None: - # Task still running — return partial stream buffer return { "stdout": data["stream_buffer"], "stderr": "", @@ -196,25 +192,26 @@ "thought_history": history, } - # Enrich final result with reasoning history if isinstance(res, dict): - res = dict(res) # don't mutate stored result + res = dict(res) res["thought_history"] = history return res def pop(self, task_id: str): - """Removes the task's state from the journal (call after result is consumed).""" - with self.lock: - return self.tasks.pop(task_id, None) + """Removes the task's state from the journal.""" + shard = self._get_shard(task_id) + with shard["lock"]: + return shard["tasks"].pop(task_id, None) def cleanup(self, max_age_s: int = 900): """Purges stale tasks to prevent slow memory accumulation.""" now = time.time() - with self.lock: - to_remove = [ - tid for tid, t in self.tasks.items() - if (t["completed_at"] and (now - t["completed_at"]) > 120) # finished: keep 2m - or (now - t["created_at"]) > 900 # pending: keep 15m - ] - for tid in to_remove: - del self.tasks[tid] + for shard in self.shards: + with shard["lock"]: + to_remove = [ + tid for tid, t in shard["tasks"].items() + if (t["completed_at"] and (now - t["completed_at"]) > 120) + or (now - t["created_at"]) > max_age_s + ] + for tid in to_remove: + del shard["tasks"][tid] diff --git a/ai-hub/app/core/grpc/services/grpc_server.py b/ai-hub/app/core/grpc/services/grpc_server.py index 7bd42fd..265cbfc 100644 --- a/ai-hub/app/core/grpc/services/grpc_server.py +++ b/ai-hub/app/core/grpc/services/grpc_server.py @@ -5,6 +5,7 @@ import logging import json import uuid +from weakref import WeakValueDictionary # removed requests import as we now validate tokens directly @@ -28,7 +29,7 @@ self.journal = TaskJournal() self.pool = GlobalWorkPool() self.mirror = GhostMirrorManager(storage_root=os.path.join(settings.DATA_DIR, "mirrors")) - self.io_locks = {} # key -> threading.Lock + self.io_locks = WeakValueDictionary() # key -> threading.Lock (weakly referenced) self.io_locks_lock = threading.Lock() self.assistant = TaskAssistant(self.registry, self.journal, self.pool, self.mirror) self.pool.on_new_work = self._broadcast_work diff --git a/ai-hub/app/core/services/document.py b/ai-hub/app/core/services/document.py index bd4c60a..cca2aab 100644 --- a/ai-hub/app/core/services/document.py +++ b/ai-hub/app/core/services/document.py @@ -58,6 +58,14 @@ return None # CRITICAL SECURITY: Delete associated vector metadata to prevent RAG 'Ghost Results' + # 1. Get faiss IDs + faiss_ids = [r.id for r in db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).all()] + + # 2. Remove from FAISS index + if faiss_ids: + self.vector_store.remove_vectors(faiss_ids) + + # 3. Delete from DB db.query(models.VectorMetadata).filter(models.VectorMetadata.document_id == document_id).delete() db.delete(doc_to_delete) diff --git a/ai-hub/app/core/vector_store/faiss_store.py b/ai-hub/app/core/vector_store/faiss_store.py index 93f4fea..0a41108 100644 --- a/ai-hub/app/core/vector_store/faiss_store.py +++ b/ai-hub/app/core/vector_store/faiss_store.py @@ -110,6 +110,22 @@ logging.info(f"Added {len(new_faiss_ids)} documents to FAISS index.") return new_faiss_ids + def remove_vectors(self, faiss_ids: List[int]) -> int: + """Removes vectors with specified IDs from the FAISS index.""" + if self.index is None: + logging.warning("FAISS index is not initialized.") + return 0 + + logging.info(f"Removing {len(faiss_ids)} vectors from FAISS index...") + selector = faiss.IDSelectorBatch(faiss_ids) + num_removed = self.index.remove_ids(selector) + + self.doc_id_map = [i for i in self.doc_id_map if i not in faiss_ids] + + self.save_index() + logging.info(f"Successfully removed {num_removed} vectors from FAISS index.") + return num_removed + def search_similar_documents(self, query_text: str, k: int = 5, prefilter_tags: Optional[Dict[str, Any]] = None, db_session: Session = None) -> List[int]: diff --git a/ai-hub/tests/api/test_dependencies.py b/ai-hub/tests/api/test_dependencies.py index 0cf88d5..fae896a 100644 --- a/ai-hub/tests/api/test_dependencies.py +++ b/ai-hub/tests/api/test_dependencies.py @@ -86,7 +86,7 @@ assert excinfo.value.status_code == 401 assert "X-User-ID header is missing" in str(excinfo.value.detail) -@patch('app.api.dependencies.settings') +@patch('app.config.settings') def test_get_current_user_valid_secret_required_but_missing(mock_settings, mock_session): mock_settings.SECRET_KEY = "super_secret" with pytest.raises(HTTPException) as excinfo: @@ -94,7 +94,7 @@ assert excinfo.value.status_code == 403 assert "Invalid Proxy Secret" in str(excinfo.value.detail) -@patch('app.api.dependencies.settings') +@patch('app.config.settings') def test_get_current_user_invalid_secret(mock_settings, mock_session): mock_settings.SECRET_KEY = "super_secret" with pytest.raises(HTTPException) as excinfo: @@ -102,7 +102,7 @@ assert excinfo.value.status_code == 403 assert "Invalid Proxy Secret" in str(excinfo.value.detail) -@patch('app.api.dependencies.settings') +@patch('app.config.settings') def test_get_current_user_valid_secret(mock_settings, mock_session): mock_settings.SECRET_KEY = "super_secret" @@ -114,7 +114,7 @@ user = asyncio.run(get_current_user(db=mock_session, x_user_id="test_user", x_proxy_secret="super_secret")) assert user == mock_user -@patch('app.api.dependencies.settings') +@patch('app.config.settings') def test_get_current_user_secret_not_required(mock_settings, mock_session): mock_settings.SECRET_KEY = "dev" @@ -188,6 +188,28 @@ assert "object has no service named 'non_existent_service'" in str(excinfo.value) +def test_service_container_explicit_attributes(): + """ + Tests that ServiceContainer has explicit attributes initialized to None. + """ + container = ServiceContainer() + assert container.document_service is None + assert container.rag_service is None + assert container.orchestrator is None + assert container.settings is None + assert container.node_registry_service is None + assert container.browser_service is None + assert container.tool_service is None + assert container.stt_service is None + assert container.tts_service is None + assert container.prompt_service is None + assert container.session_service is None + assert container.user_service is None + assert container.mesh_service is None + assert container.auth_service is None + assert container.preference_service is None + assert container.agent_scheduler is None + def test_service_container_chaining(mock_faiss_vector_store): """ Tests that the with_* methods can be chained together. diff --git a/ai-hub/tests/core/grpc/test_grpc_server.py b/ai-hub/tests/core/grpc/test_grpc_server.py new file mode 100644 index 0000000..c852be5 --- /dev/null +++ b/ai-hub/tests/core/grpc/test_grpc_server.py @@ -0,0 +1,53 @@ +import pytest +import threading +import gc +from unittest.mock import MagicMock, patch +from weakref import WeakValueDictionary + +from app.core.grpc.services.grpc_server import AgentOrchestrator +from app.protos import agent_pb2 + +@pytest.fixture +def mock_registry(): + return MagicMock() + +@pytest.fixture +def orchestrator(mock_registry, tmp_path): + with patch('threading.Thread'), \ + patch('app.core.grpc.services.grpc_server.settings') as mock_settings: + mock_settings.DATA_DIR = str(tmp_path) + return AgentOrchestrator(registry=mock_registry) + +def test_io_locks_is_weak_value_dict(orchestrator): + assert isinstance(orchestrator.io_locks, WeakValueDictionary) + +def test_io_locks_garbage_collection(orchestrator): + # Simulate adding a lock + lock_key = "session1:file1" + + # We need to create a lock and put it in + lock = threading.Lock() + orchestrator.io_locks[lock_key] = lock + + assert lock_key in orchestrator.io_locks + + # Now remove the strong reference + del lock + + # Force garbage collection + gc.collect() + + # It should be gone! + assert lock_key not in orchestrator.io_locks + +def test_build_sandbox_policy_default_strict(orchestrator): + policy = orchestrator._build_sandbox_policy({}) + assert policy.mode == agent_pb2.SandboxPolicy.STRICT + +def test_build_sandbox_policy_explicit_strict(orchestrator): + policy = orchestrator._build_sandbox_policy({"shell": {"sandbox": {"mode": "STRICT"}}}) + assert policy.mode == agent_pb2.SandboxPolicy.STRICT + +def test_build_sandbox_policy_explicit_permissive(orchestrator): + policy = orchestrator._build_sandbox_policy({"shell": {"sandbox": {"mode": "PERMISSIVE"}}}) + assert policy.mode == agent_pb2.SandboxPolicy.PERMISSIVE diff --git a/ai-hub/tests/core/services/test_document.py b/ai-hub/tests/core/services/test_document.py index 850b0a9..cb08e7a 100644 --- a/ai-hub/tests/core/services/test_document.py +++ b/ai-hub/tests/core/services/test_document.py @@ -106,21 +106,49 @@ def test_delete_document_success(document_service: DocumentService): """ - Tests that delete_document correctly deletes a document. + Tests that delete_document correctly deletes a document and its vectors. """ # Arrange mock_db = MagicMock(spec=Session) doc_id_to_delete = 1 doc_to_delete = models.Document(id=doc_id_to_delete) - mock_db.query.return_value.filter.return_value.first.return_value = doc_to_delete + + # Mock queries using side_effect based on the model passed to query() + mock_query = MagicMock() + mock_db.query = mock_query + + # Mock for Document + mock_query_doc = MagicMock() + mock_query_doc.filter.return_value.first.return_value = doc_to_delete + + # Mock for VectorMetadata + mock_query_meta = MagicMock() + mock_meta = MagicMock() + mock_meta.id = 999 + mock_query_meta.filter.return_value.all.return_value = [mock_meta] + + def query_side_effect(model): + if model == models.Document: + return mock_query_doc + if model == models.VectorMetadata: + return mock_query_meta + return MagicMock() + + mock_query.side_effect = query_side_effect # Act deleted_id = document_service.delete_document(db=mock_db, document_id=doc_id_to_delete) # Assert assert deleted_id == doc_id_to_delete - mock_db.query.assert_called_once_with(models.Document) mock_db.delete.assert_called_once_with(doc_to_delete) + + # Verify FAISS cleanup + document_service.vector_store.remove_vectors.assert_called_once_with([999]) + + # Verify DB cleanup of metadata + mock_query_meta.filter.return_value.delete.assert_called_once() + mock_db.commit.assert_called_once() def test_delete_document_not_found(document_service: DocumentService):