import os
import pty
import select
import threading
import time
import termios
import struct
import fcntl
import tempfile
from agent_node.skills.base import BaseSkill
from protos import agent_pb2

class ShellSkill(BaseSkill):
    """Admin Console Skill: Persistent stateful Bash via PTY."""
    def __init__(self, sync_mgr=None):
        self.sync_mgr = sync_mgr
        self.sessions = {} # session_id -> {fd, pid, thread, last_activity, ...}
        self.lock = threading.Lock()
        
        # Phase 3: Prompt Patterns for Edge Intelligence
        self.PROMPT_PATTERNS = [
            r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$",  # bash/zsh: user@host:~$
            r">>>\s*$",                        # python
            r"\.\.\.\s*$",                      # python multi-line
            r">\s*$",                           # node/js
        ]
        
        # --- M7: Idle Session Reaper ---
        # Automatically kills dormant bash processes to free up system resources.
        self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper")
        self.reaper_thread.start()

    def _session_reaper(self):
        """Background thread that cleans up unused PTY sessions."""
        while True:
            time.sleep(60)
            with self.lock:
                now = time.time()
                for sid, sess in list(self.sessions.items()):
                    # Avoid reaping currently active tasks
                    if sess.get("active_task"):
                        continue
                    
                    # 10 minute idle timeout
                    if now - sess.get("last_activity", 0) > 600:
                        print(f"    [🐚🧹] Reaping idle shell session: {sid}")
                        try:
                            os.close(sess["fd"])
                            os.kill(sess["pid"], 9)
                        except: pass
                        self.sessions.pop(sid, None)

    def _ensure_session(self, session_id, cwd, on_event):
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]["last_activity"] = time.time()
                return self.sessions[session_id]

            print(f"    [🐚] Initializing Persistent Shell Session: {session_id}")
            # Spawn bash in a pty
            pid, fd = pty.fork()
            if pid == 0: # Child
                # Environment prep
                os.environ["TERM"] = "xterm-256color"
                
                # Change to CWD
                if cwd and os.path.exists(cwd):
                    os.chdir(cwd)
                
                # Launch shell
                os.execv("/bin/bash", ["/bin/bash", "--login"])
            
            # Parent
            # Set non-blocking
            fl = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

            sess = {
                "fd": fd, 
                "pid": pid, 
                "last_activity": time.time(),
                "buffer_file": None,
                "tail_buffer": "",
                "active_task": None
            }

            def reader():
                while True:
                    try:
                        r, _, _ = select.select([fd], [], [], 0.1)
                        if fd in r:
                            data = os.read(fd, 4096)
                            if not data: break
                            
                            decoded = data.decode("utf-8", errors="replace")
                            
                            # Streaming/Sync logic (Detect completion marker)
                            with self.lock:
                                active_tid = sess.get("active_task")
                                marker = sess.get("marker")
                                if active_tid and marker and sess.get("buffer_file"):
                                    # Phase 2: Persistence Offloading
                                    # Write directly to disk instead of heap memory
                                    sess["buffer_file"].write(decoded)
                                    sess["buffer_file"].flush()
                                    
                                    # Keep a tiny 4KB tail in RAM for marker detection and prompt scanning
                                    sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-4096:]

                                    if marker in sess["tail_buffer"]:
                                        # Marker found! Extract exit code
                                        try:
                                            # The tail buffer has the marker
                                            after_marker = sess["tail_buffer"].split(marker)[1].strip().split()
                                            exit_code = int(after_marker[0]) if after_marker else 0

                                            # Formulate final stdout summary from the disk file
                                            bf = sess["buffer_file"]
                                            bf.seek(0, 2)
                                            file_len = bf.tell()

                                            HEAD, TAIL = 10_000, 30_000
                                            if file_len > HEAD + TAIL:
                                                bf.seek(0)
                                                head_str = bf.read(HEAD)
                                                bf.seek(file_len - TAIL)
                                                tail_str = bf.read()
                                                omitted = file_len - HEAD - TAIL
                                                pure_stdout = head_str + f"\n\n[... {omitted:,} bytes omitted (full output safely preserved at {bf.name}) ...]\n\n" + tail_str
                                            else:
                                                bf.seek(0)
                                                pure_stdout = bf.read()

                                            # Slice off the marker string and anything after it from the final result
                                            pure_stdout = pure_stdout.split(marker)[0]

                                            sess["result"]["stdout"] = pure_stdout
                                            sess["result"]["status"] = 0 if exit_code == 0 else 1
                                            
                                            # Close the file handle (leaves file on disk)
                                            sess["buffer_file"].close()
                                            sess["buffer_file"] = None
                                            
                                            sess["event"].set()
                                            decoded = pure_stdout.split(marker)[0][-4096:] if marker in pure_stdout else pure_stdout 
                                        except Exception as e:
                                            print(f"    [🐚⚠️] Marker parsing failed: {e}")
                                            sess["event"].set()

                            # Stream terminal output back (with stealth filtering)
                            if on_event:
                                stealth_out = decoded
                                if "__CORTEX_FIN_SH_" in decoded:
                                    import re
                                    # We remove any line that contains our internal marker to hide plumbing from user.
                                    # This covers both the initial command echo and the exit code output.
                                    stealth_out = re.sub(r'.*__CORTEX_FIN_SH_.*[\r\n]*', '', decoded)

                                if stealth_out:
                                    # Phase 3: Client-Side Truncation (Stream Rate Limiting)
                                    # Limit real-time stream to 15KB/sec per session to prevent flooding the Hub over gRPC.
                                    # The full output is still safely written to the tempfile on disk.
                                    with self.lock:
                                        now = time.time()
                                        if now - sess.get("stream_window_start", 0) > 1.0:
                                            sess["stream_window_start"] = now
                                            sess["stream_bytes_sent"] = 0
                                            dropped = sess.get("stream_dropped_bytes", 0)
                                            if dropped > 0:
                                                drop_msg = f"\n[... {dropped:,} bytes truncated from live stream ...]\n"
                                                event = agent_pb2.SkillEvent(
                                                    session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg
                                                )
                                                on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                                                sess["stream_dropped_bytes"] = 0

                                        if sess.get("stream_bytes_sent", 0) + len(stealth_out) > 15_000:
                                            sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out)
                                        else:
                                            sess["stream_bytes_sent"] = sess.get("stream_bytes_sent", 0) + len(stealth_out)
                                            event = agent_pb2.SkillEvent(
                                                session_id=session_id,
                                                task_id=sess.get("active_task") or "",
                                                terminal_out=stealth_out
                                            )
                                            on_event(agent_pb2.ClientTaskMessage(skill_event=event))

                                    # EDGE INTELLIGENCE: Proactively signal prompt detection
                                    # We only check for prompts if we are actively running a task and haven't found the marker yet.
                                    if active_tid and not sess["event"].is_set():
                                        import re
                                        tail = sess["tail_buffer"][-100:] if len(sess["tail_buffer"]) > 100 else sess["tail_buffer"]
                                        for pattern in self.PROMPT_PATTERNS:
                                            if re.search(pattern, tail):
                                                # Send specific prompt signal
                                                # Use last 20 chars as the 'prompt' hint
                                                p_hint = tail[-20:].strip()
                                                prompt_event = agent_pb2.SkillEvent(
                                                    session_id=session_id,
                                                    task_id=active_tid,
                                                    prompt=p_hint
                                                )
                                                on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event))
                                                break
                    except (EOFError, OSError):
                        break
                
                # Thread Cleanup
                print(f"    [🐚] Shell Session Terminated: {session_id}")
                with self.lock:
                    self.sessions.pop(session_id, None)

            t = threading.Thread(target=reader, daemon=True, name=f"ShellReader-{session_id}")
            t.start()
            sess["thread"] = t
            
            self.sessions[session_id] = sess
            return sess


    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Processes raw TTY/Resize events synchronously (bypasses threadpool/sandbox)."""
        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            import json
            if cmd.startswith('{') and cmd.endswith('}'):
                raw_payload = json.loads(cmd)
                
                # 1. Raw Keystroke forward
                if isinstance(raw_payload, dict) and "tty" in raw_payload:
                    raw_bytes = raw_payload["tty"]
                    sess = self._ensure_session(session_id, None, on_event)
                    os.write(sess["fd"], raw_bytes.encode("utf-8"))
                    on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                    return True
                
                # 2. Window Resize
                if isinstance(raw_payload, dict) and raw_payload.get("action") == "resize":
                    cols = raw_payload.get("cols", 80)
                    rows = raw_payload.get("rows", 24)
                    sess = self._ensure_session(session_id, None, on_event)
                    import termios, struct, fcntl
                    s = struct.pack('HHHH', rows, cols, 0, 0)
                    fcntl.ioctl(sess["fd"], termios.TIOCSWINSZ, s)
                    print(f"    [🐚] Terminal Resized to {cols}x{rows}")
                    on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                    return True
        except Exception as pe:
            print(f"    [🐚] Transparent TTY Fail: {pe}")
        return False

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Dispatches command string to the persistent PTY shell and WAITS for completion."""
        session_id = task.session_id or "default-session"
        tid = task.task_id
        try:
            cmd = task.payload_json
            
            # --- Legacy Full-Command Execution (Sandboxed) ---
            allowed, status_msg = sandbox.verify(cmd)
            if not allowed:
                err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
                if on_event:
                    event = agent_pb2.SkillEvent(
                        session_id=session_id, task_id=tid,
                        terminal_out=err_msg
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                
                return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

            # Resolve CWD jail
            cwd = None
            if self.sync_mgr and task.session_id:
                cwd = self.sync_mgr.get_session_dir(task.session_id)
            elif sandbox.policy.get("WORKING_DIR_JAIL"):
                cwd = sandbox.policy["WORKING_DIR_JAIL"]
                if not os.path.exists(cwd):
                    try: os.makedirs(cwd, exist_ok=True)
                    except: pass

            # Handle Session Persistent Process
            sess = self._ensure_session(session_id, cwd, on_event)
            
            # Check for RAW mode first (bypasses busy check for interactive control)
            is_raw = cmd.startswith("!RAW:")
            if is_raw:
                input_str = cmd[5:] + "\n"
                print(f"    [🐚⌨️] RAW Input Injection: {input_str.strip()}")
                os.write(sess["fd"], input_str.encode("utf-8"))
                return on_complete(tid, {"stdout": "INJECTED", "status": 1}, task.trace_id)

            # --- 0. Busy Check: Serialize access to the PTY for standard commands ---
            with self.lock:
                if sess.get("active_task"):
                    curr_tid = sess.get("active_task")
                    return on_complete(tid, {"stderr": f"[BUSY] Session {session_id} is already running task {curr_tid}", "status": 2}, task.trace_id)

            # --- Blocking Wait Logic ---
            # --- Blocking Wait Logic ---
            marker_id = int(time.time())
            marker = f"__CORTEX_FIN_SH_{marker_id}__"
            event = threading.Event()
            result_container = {"stdout": "", "status": 1} # 1 = Success by default (node.py convention)
            
            # Register waiter in session state
            with self.lock:
                sess["active_task"] = tid
                sess["marker"] = marker
                sess["event"] = event
                # Create a persistent tempfile for stdout instead of RAM buffer
                sess["buffer_file"] = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False)
                sess["tail_buffer"] = ""
                sess["result"] = result_container
                sess["cancel_event"] = threading.Event()

            # Input injection: execute command then echo marker and exit code
            try:
                # 12-factor bash: ( cmd ) ; echo marker $?
                # We use "" concatenation in the echo command to ensure the marker literal
                # DOES NOT appear in the PTY input echo, preventing premature completion.
                full_input = f"({cmd}) ; echo \"__CORTEX_FIN_SH_\"\"{marker_id}__\" $?\n"
                os.write(sess["fd"], full_input.encode("utf-8"))

                # Wait for completion (triggered by reader) OR cancellation
                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0
                start_time = time.time()
                while time.time() - start_time < timeout:
                    # Check for completion (reader found marker)
                    if event.is_set():
                        return on_complete(tid, result_container, task.trace_id)
                    
                    # Check for cancellation (HUB sent cancel)
                    if sess["cancel_event"].is_set():
                        print(f"    [🐚🛑] Task {tid} cancelled on node.")
                        return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id)
                    
                    # Sleep slightly to avoid busy loop
                    time.sleep(0.1)

                # Timeout Case
                print(f"    [🐚⚠️] Task {tid} timed out on node.")
                with self.lock:
                    if sess.get("buffer_file"):
                        try:
                            sess["buffer_file"].seek(0, 2)
                            file_len = sess["buffer_file"].tell()
                            HEAD, TAIL = 10_000, 30_000
                            if file_len > HEAD + TAIL:
                                sess["buffer_file"].seek(0)
                                head_str = sess["buffer_file"].read(HEAD)
                                sess["buffer_file"].seek(file_len - TAIL)
                                tail_str = sess["buffer_file"].read()
                                omitted = file_len - HEAD - TAIL
                                partial_out = head_str + f"\n\n[... {omitted:,} bytes omitted (full timeout output saved to {sess['buffer_file'].name}) ...]\n\n" + tail_str
                            else:
                                sess["buffer_file"].seek(0)
                                partial_out = sess["buffer_file"].read()
                        except:
                            partial_out = ""
                    else:
                        partial_out = ""
                
                on_complete(tid, {"stdout": partial_out, "stderr": "TIMEOUT", "status": 2}, task.trace_id)
                
            finally:
                # Cleanup session task state
                with self.lock:
                    if sess.get("active_task") == tid:
                        if sess.get("buffer_file"):
                            try:
                                sess["buffer_file"].close()
                            except: pass
                            sess["buffer_file"] = None
                        sess["active_task"] = None
                        sess["marker"] = None
                        sess["event"] = None
                        sess["result"] = None
                        sess["cancel_event"] = None

        except Exception as e:
            print(f"    [🐚❌] Execute Error for {tid}: {e}")
            on_complete(tid, {"stderr": str(e), "status": 2}, task.trace_id)

    def cancel(self, task_id: str):
        """Cancels an active task — for persistent shell, this sends a SIGINT (Ctrl+C)."""
        with self.lock:
            for sid, sess in self.sessions.items():
                if sess.get("active_task") == task_id:
                    print(f"[🛑] Sending SIGINT (Ctrl+C) to shell session (Task {task_id}): {sid}")
                    # Write \x03 (Ctrl+C) to the master FD
                    os.write(sess["fd"], b"\x03")
                    # Break the wait loop in execute thread
                    if sess.get("cancel_event"):
                        sess["cancel_event"].set()
        return True


    def shutdown(self):
        """Cleanup: Terminates all persistent shells."""
        with self.lock:
            for sid, sess in list(self.sessions.items()):
                print(f"[🛑] Cleaning up persistent shell: {sid}")
                try: os.close(sess["fd"])
                except: pass
                # kill pid
                try: os.kill(sess["pid"], 9)
                except: pass
            self.sessions.clear()
