
import time
import os
import hashlib
import zlib
try:
    from watchdog.observers import Observer
    from watchdog.events import FileSystemEventHandler
    HAS_WATCHDOG = True
except ImportError:
    # Optional dependency: Only required for live file sync/push-to-node features.
    Observer = object
    FileSystemEventHandler = object
    HAS_WATCHDOG = False
from shared_core.ignore import CortexIgnore
from protos import agent_pb2

class SyncEventHandler(FileSystemEventHandler):
    """Listens for FS events and triggers gRPC delta pushes.

    Watches a session workspace rooted at ``root_path`` and streams changed
    files to the server via ``callback`` as chunked, zlib-compressed
    ``FilePayload`` messages, followed by a final sentinel carrying the
    file's SHA-256 for integrity verification.
    """

    def __init__(self, session_id, root_path, callback):
        """
        Args:
            session_id: Identifier of the session owning this workspace.
            root_path: Absolute path of the watched workspace root.
            callback: Callable ``(session_id, message)`` that enqueues a
                protobuf message (FilePayload / FileSyncMessage) for gRPC.
        """
        self.session_id = session_id
        self.root_path = root_path
        self.callback = callback
        self.ignore_filter = CortexIgnore(root_path)
        self.last_sync = {} # path -> last_hash (sha256 hex of last synced content)
        self.locked = False # When True, all local edits are dropped
        self.suppressed_paths = set() # Paths currently being modified by the system
        self.syncing_paths = set() # Paths currently being scanned/pushed

    def on_modified(self, event):
        if not event.is_directory:
            self._process_change(event.src_path)

    def on_created(self, event):
        if not event.is_directory:
            self._process_change(event.src_path)

    def on_closed(self, event):
        # critical for large writes like 'dd' or 'cp' that trigger many modified events
        if not event.is_directory:
            self._process_change(event.src_path, force=True)

    def on_deleted(self, event):
        if not event.is_directory:
            rel_path = os.path.normpath(os.path.relpath(event.src_path, self.root_path))
            if not self.ignore_filter.is_ignored(rel_path):
                self.callback(self.session_id, agent_pb2.FileSyncMessage(
                    session_id=self.session_id,
                    control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.DELETE, path=rel_path)
                ))

    def on_moved(self, event):
        # Treat as delete of src and create of dest
        self.on_deleted(event)
        # Fix: guard against directory moves. on_deleted filters directories
        # itself, but previously _process_change was called unconditionally on
        # the destination and would try to open a *directory* as a file,
        # logging a spurious "[!] Watcher Error".
        if not event.is_directory:
            self._process_change(event.dest_path, force=True)

    def _process_change(self, abs_path, force=False):
        """Debounce, hash, chunk, compress and push a single file.

        Args:
            abs_path: Absolute path of the changed file.
            force: Skip the settle (debounce) check — used for close/move
                events where the write is known to be complete.
        """
        if self.locked:
            return # Block all user edits when session is locked

        rel_path = os.path.normpath(os.path.relpath(abs_path, self.root_path))

        if rel_path in self.suppressed_paths:
            return # Ignore changes coming from the sync manager

        # Phase 3: Dynamic reload if .cortexignore / .gitignore changed
        if rel_path in [".cortexignore", ".gitignore"]:
            print(f"    [*] Reloading Ignore Filter for {self.session_id}")
            self.ignore_filter = CortexIgnore(self.root_path)

        if self.ignore_filter.is_ignored(rel_path):
            return

        # Re-entrancy guard: watchdog may deliver overlapping events for the
        # same path while an earlier push is still streaming.
        if rel_path in self.syncing_paths:
            return

        self.syncing_paths.add(rel_path)

        try:
            # Step 0: Settle Check (Debounce)
            if not force:
                # Wait a moment to see if the file is still being written to.
                # This is critical for tools like 'dd' or 'cp' that write in bursts.
                try:
                    initial_mtime = os.path.getmtime(abs_path)
                    initial_size = os.path.getsize(abs_path)
                    time.sleep(1.0)
                    if not os.path.exists(abs_path): return # File deleted during wait

                    current_mtime = os.path.getmtime(abs_path)
                    current_size = os.path.getsize(abs_path)

                    if current_mtime != initial_mtime or current_size != initial_size:
                        # Still being modified. We'll skip this event and let the next
                        # 'modified' event trigger the actual sync.
                        return
                except OSError:
                    # FileNotFoundError is a subclass of OSError, so a single
                    # handler covers both deletion and stat failures.
                    return

            if not os.path.exists(abs_path):
                return

            file_size = os.path.getsize(abs_path)
            chunk_size = 1024 * 1024 # 1MB buffer for hashing/stream
            total_chunks = (file_size + chunk_size - 1) // chunk_size if file_size > 0 else 1

            # Memory-safe incremental hashing
            hasher = hashlib.sha256()

            with open(abs_path, "rb") as f:
                offset = 0
                index = 0
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break

                    hasher.update(chunk)

                    # Compress Chunk for transit
                    compressed_chunk = zlib.compress(chunk)

                    payload_fields = {
                        "path": rel_path,
                        "chunk": compressed_chunk,
                        "chunk_index": index,
                        "is_final": False,
                        "offset": offset,
                        "compressed": True,
                        "hash": "",
                    }
                    # Older proto definitions may lack these fields; only set
                    # them when the generated message supports them.
                    if hasattr(agent_pb2.FilePayload, "total_chunks"):
                        payload_fields["total_chunks"] = total_chunks
                        payload_fields["total_size"] = file_size

                    payload = agent_pb2.FilePayload(**payload_fields)

                    # Callback pushes to gRPC queue
                    self.callback(self.session_id, payload)

                    offset += len(chunk)
                    index += 1

                # Update internal tracking with the final hash
                file_hash = hasher.hexdigest()

                # Signal completion with the hash to the server
                # (The server uses this to verify integrity and perform the atomic swap)
                sentinel_fields = {
                    "path": rel_path,
                    "is_final": True,
                    "hash": file_hash,
                    "chunk_index": index,
                    "offset": offset,
                }
                if hasattr(agent_pb2.FilePayload, "total_chunks"):
                    sentinel_fields["total_chunks"] = total_chunks
                    sentinel_fields["total_size"] = file_size

                self.callback(self.session_id, agent_pb2.FilePayload(**sentinel_fields))

                if self.last_sync.get(rel_path) == file_hash:
                    # Chunks were already sent, so we must send sentinel above,
                    # but we can skip the log message.
                    return

                self.last_sync[rel_path] = file_hash
                print(f"    [📁📤] Streaming Sync Complete: {rel_path} ({file_size} bytes)")

        except Exception as e:
            # Best-effort: a single failed file must not kill the observer thread.
            print(f"    [!] Watcher Error for {rel_path}: {e}")
        finally:
            self.syncing_paths.discard(rel_path)

class WorkspaceWatcher:
    """Manages FS observers for active synchronization.

    Keeps one (observer, handler) pair per session and exposes lifecycle
    controls (start/stop/lock) plus echo-suppression hooks used while the
    sync manager itself writes into a workspace.
    """

    def __init__(self, callback):
        self.callback = callback
        self.observers = {} # session_id -> (observer, handler)

    def set_lock(self, session_id, locked=True):
        """Enable/disable the handler's edit-blocking lock for a session."""
        entry = self.observers.get(session_id)
        if entry is None:
            return
        print(f"[*] Workspace LOCK for {session_id}: {locked}")
        entry[1].locked = locked

    def start_watching(self, session_id, root_path):
        """Begin recursive watching of root_path, replacing any prior watcher."""
        if session_id in self.observers:
            self.stop_watching(session_id)

        if not HAS_WATCHDOG:
            print(f"[!] Warning: 'watchdog' not installed. File sync disabled for session {session_id}")
            return

        print(f"[*] Starting Watcher for Session {session_id} at {root_path}")
        os.makedirs(root_path, exist_ok=True)

        sync_handler = SyncEventHandler(session_id, root_path, self.callback)
        watcher = Observer()
        watcher.schedule(sync_handler, root_path, recursive=True)
        watcher.start()
        self.observers[session_id] = (watcher, sync_handler)

    def stop_watching(self, session_id):
        """Stop and join the session's observer, then forget it."""
        if session_id not in self.observers:
            return
        print(f"[*] Stopping Watcher for Session {session_id}")
        watcher = self.observers[session_id][0]
        watcher.stop()
        watcher.join()
        del self.observers[session_id]

    def get_watch_path(self, session_id):
        """Return the watched root for a session, or None if not watching."""
        entry = self.observers.get(session_id)
        return entry[1].root_path if entry is not None else None

    def acknowledge_remote_write(self, session_id, rel_path, file_hash):
        """Updates the internal hash record to match a remote write, preventing an echo-back."""
        entry = self.observers.get(session_id)
        if entry is not None:
            entry[1].last_sync[rel_path] = file_hash

    def suppress_path(self, session_id, rel_path):
        """Tells the watcher to ignore events for a specific path (e.g. during sync)."""
        entry = self.observers.get(session_id)
        if entry is not None:
            entry[1].suppressed_paths.add(rel_path)

    def unsuppress_path(self, session_id, rel_path):
        """Resumes watching a specific path."""
        entry = self.observers.get(session_id)
        if entry is not None:
            # discard() tolerates a path that was never suppressed
            entry[1].suppressed_paths.discard(rel_path)

    def shutdown(self):
        """Stop every active watcher."""
        for session_id in list(self.observers):
            self.stop_watching(session_id)
