# cortex-hub / agent-node / src / agent_node / core / sync.py
import hashlib
import json
import os
import shutil
import time
import zlib

from agent_node.config import SYNC_DIR
from protos import agent_pb2

class NodeSyncManager:
    """Handles local filesystem synchronization on the Agent Node.

    One directory per session is kept under ``base_sync_dir``. The server
    drives sync by sending a directory manifest (reconciled by
    ``handle_manifest``) followed by chunked file payloads (applied by
    ``write_chunk`` via a shadow-file write plus atomic final swap).
    """

    # 1 MiB read size keeps incremental hashing memory-bounded.
    _HASH_CHUNK = 1024 * 1024

    def __init__(self, base_sync_dir=None):
        """Initialize the manager and ensure the root sync directory exists.

        Args:
            base_sync_dir: Root directory for per-session sync trees.
                Falls back to the configured ``SYNC_DIR`` when ``None``.
        """
        # Resolve the default lazily so an explicit directory (e.g. in tests)
        # never touches the global config.
        self.base_sync_dir = SYNC_DIR if base_sync_dir is None else base_sync_dir
        # exist_ok avoids the check-then-create race of an exists() guard.
        os.makedirs(self.base_sync_dir, exist_ok=True)

    def get_session_dir(self, session_id: str, create: bool = False) -> str:
        """Returns the unique identifier directory for this session's sync.

        Args:
            session_id: Server-assigned session identifier.
            create: When True, create the directory if it is missing.
        """
        path = os.path.join(self.base_sync_dir, session_id)
        if create:
            os.makedirs(path, exist_ok=True)
        return path

    def purge(self, session_id: str):
        """Completely removes a session's sync directory from the node."""
        path = os.path.join(self.base_sync_dir, session_id)
        if os.path.exists(path):
            shutil.rmtree(path)
            print(f"    [📁🧹] Node sync directory deleted: {session_id}")

    def cleanup_unused_sessions(self, active_session_ids: list):
        """Removes any session directories that are no longer active on the server.

        Args:
            active_session_ids: Session ids still known to the server; any
                local ``session-*`` directory not in this list is deleted.
        """
        if not os.path.exists(self.base_sync_dir):
            return

        active_set = set(active_session_ids)
        for session_id in os.listdir(self.base_sync_dir):
            # Only touch directories we created (the "session-" prefix) so
            # unrelated entries under the sync root are never removed.
            if session_id.startswith("session-") and session_id not in active_set:
                path = os.path.join(self.base_sync_dir, session_id)
                if os.path.isdir(path):
                    shutil.rmtree(path)
                    print(f"    [📁🧹] Proactively purged unused session directory: {session_id}")

    def handle_manifest(self, session_id: str, manifest: "agent_pb2.DirectoryManifest") -> list:
        """Compares local files with the server manifest and returns paths needing update.

        First deletes extraneous local files and now-empty directories
        (server-side deletions), then hashes surviving files against the
        manifest to detect content drift.

        Returns:
            Manifest-relative paths whose content must be (re)transferred.
        """
        session_dir = self.get_session_dir(session_id, create=True)
        print(f"[📁] Reconciling Sync Directory: {session_dir}")

        from shared_core.ignore import CortexIgnore  # project-local; imported lazily
        ignore_filter = CortexIgnore(session_dir)
        expected_paths = {f.path for f in manifest.files}

        # 1. Purge extraneous local files and directories (handles Deletions).
        # topdown=False visits children first so emptied dirs can be removed.
        for root, dirs, files in os.walk(session_dir, topdown=False):
            for name in files:
                abs_path = os.path.join(root, name)
                rel_path = os.path.relpath(abs_path, session_dir)
                # Never delete ignore/config files or the browser cache.
                if rel_path in (".cortexignore", ".gitignore") or ".cortex_browser" in rel_path:
                    continue
                if rel_path not in expected_paths and not ignore_filter.is_ignored(rel_path):
                    try:
                        os.remove(abs_path)
                        print(f"    [📁🗑️] Deleted extraneous local file: {rel_path}")
                    except OSError as e:
                        print(f"    [⚠️] Failed to delete file {rel_path}: {e}")

            for name in dirs:
                abs_path = os.path.join(root, name)
                rel_path = os.path.relpath(abs_path, session_dir)
                if rel_path not in expected_paths and not ignore_filter.is_ignored(rel_path) and ".cortex_browser" not in rel_path:
                    try:
                        # Only remove directories that are now empty.
                        if not os.listdir(abs_path):
                            os.rmdir(abs_path)
                    except OSError:
                        pass  # best-effort: a busy/vanished dir is left alone

        # 2. Determine which manifest entries are missing or drifted locally.
        needs_update = []
        for file_info in manifest.files:
            target_path = self._resolve_safe(session_dir, file_info.path)
            if target_path is None:
                # Manifest entry escapes the session dir; refuse to touch it
                # (same guard write_chunk applies to payload paths).
                print(f"    [⚠️] Skipping unsafe manifest path: {file_info.path}")
                continue

            if file_info.is_dir:
                os.makedirs(target_path, exist_ok=True)
                continue

            # File Check
            if not os.path.exists(target_path):
                needs_update.append(file_info.path)
                continue

            try:
                actual_hash = self._hash_file(target_path)
            except OSError:
                # Unreadable local copy -> treat as drifted and re-fetch.
                needs_update.append(file_info.path)
                continue
            if actual_hash != file_info.hash:
                print(f"    [⚠️] Drift Detected: {file_info.path} (Local: {actual_hash[:8]} vs Remote: {file_info.hash[:8]})")
                needs_update.append(file_info.path)

        return needs_update

    def write_chunk(self, session_id: str, payload: "agent_pb2.FilePayload") -> bool:
        """Writes a file chunk to a shadow file and swaps to target on completion.

        Chunk 0 creates/truncates the shadow file (and takes an advisory
        lock); later chunks are written at ``payload.offset``. On
        ``is_final`` the shadow is hash-verified (when a hash is supplied)
        and atomically swapped into place via ``os.replace``.

        Returns:
            False for a rejected (path-traversal) target or a failed
            verification/swap of the final chunk; True otherwise.
        """
        session_dir = self.get_session_dir(session_id, create=True)
        target_path = self._resolve_safe(session_dir, payload.path)
        if target_path is None:
            return False  # Path traversal guard

        os.makedirs(os.path.dirname(target_path), exist_ok=True)

        # We always write to a temporary "shadow" file during the sync.
        tmp_path = target_path + ".cortex_tmp"
        lock_path = target_path + ".cortex_lock"

        data = self._decode_chunk(payload)

        if payload.chunk_index == 0:
            # 1. Handle locks, then initialize the shadow file (truncate).
            self._acquire_lock(lock_path, payload.path)
            with open(tmp_path, "wb") as f:
                f.write(data)
        else:
            # Random access write to the shadow file. Offset defaults to 0
            # in proto3 when unset.
            if not os.path.exists(tmp_path):
                with open(tmp_path, "wb"):
                    pass
            with open(tmp_path, "r+b") as f:
                f.seek(payload.offset)
                f.write(data)

        if payload.is_final:
            return self._finalize(payload, tmp_path, target_path, lock_path)
        return True

    def _resolve_safe(self, session_dir: str, rel_path: str):
        """Join ``rel_path`` under ``session_dir``; return None if it escapes it."""
        base = os.path.abspath(session_dir)
        target = os.path.abspath(os.path.join(base, rel_path.lstrip("/")))
        try:
            # commonpath (unlike a bare startswith) cannot be fooled by
            # sibling dirs such as "<base>-evil" or by ".." segments.
            if os.path.commonpath([base, target]) != base:
                return None
        except ValueError:
            return None  # e.g. different drives on Windows
        return target

    def _decode_chunk(self, payload) -> bytes:
        """Return the payload's chunk bytes, decompressing when flagged."""
        data = payload.chunk
        if payload.compressed:
            try:
                data = zlib.decompress(data)
            except zlib.error:
                pass  # best-effort: fall back to the raw bytes
        return data

    def _acquire_lock(self, lock_path: str, rel_path: str):
        """Best-effort advisory lock: warn if a fresh lock exists, then take it."""
        if os.path.exists(lock_path):
            try:
                with open(lock_path, "r") as lf:
                    lock_data = json.loads(lf.read())
                # Locks older than 30s are considered stale and ignored.
                if time.time() - lock_data.get("ts", 0) < 30:
                    print(f"    [📁🔒] Lock active for {rel_path}. Proceeding with shadow write...")
            except (OSError, ValueError, AttributeError):
                pass  # unreadable/corrupt lock files are ignored

        try:
            with open(lock_path, "w") as lf:
                lf.write(json.dumps({"ts": time.time(), "owner": "node", "path": rel_path}))
        except OSError:
            pass  # the lock is advisory; failing to take it must not block sync

    def _finalize(self, payload, tmp_path: str, target_path: str, lock_path: str) -> bool:
        """Verify the completed shadow file and atomically swap it into place."""
        success = True
        if payload.hash:
            success = self._verify(tmp_path, payload.hash)

        if success:
            try:
                # Atomic swap: the destination only changes once we are 100%
                # sure the file is right.
                os.replace(tmp_path, target_path)
            except OSError as e:
                print(f"    [📁❌] Atomic swap failed for {payload.path}: {e}")
                success = False

        # Cleanup: release the advisory lock; drop a bad shadow file so the
        # next attempt starts clean.
        if os.path.exists(lock_path):
            try:
                os.remove(lock_path)
            except OSError:
                pass
        if not success and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

        return success

    def _hash_file(self, path: str) -> str:
        """Memory-safe incremental SHA-256 of a file's contents."""
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK), b""):
                h.update(chunk)
        return h.hexdigest()

    def _verify(self, path: str, expected_hash: str) -> bool:
        """Return True when the file at ``path`` hashes to ``expected_hash``."""
        actual = self._hash_file(path)
        if actual != expected_hash:
            print(f"[⚠️] Sync Hash Mismatch for {path}")
            return False
        return True