Newer
Older
cortex-hub / agent-node / src / agent_node / skills / shell_bridge.py
@Antigravity AI Antigravity AI 4 hours ago 21 KB Fix terminal PTY hang from thread crash
import codecs
import fcntl
import json
import os
import pty
import re
import select
import signal
import struct
import tempfile
import termios
import threading
import time

from agent_node.skills.base import BaseSkill
from protos import agent_pb2

class ShellSkill(BaseSkill):
    """Admin Console Skill: Persistent stateful Bash via PTY.

    One long-lived login shell is spawned per session id with ``pty.fork()``.
    A dedicated reader thread per session drains the PTY master fd. It:
      * persists task output to a temp file;
      * detects the completion marker that ``execute`` appends to every command;
      * streams rate-limited output, with marker lines hidden, through ``on_event``.

    All mutable session state lives in ``self.sessions[session_id]`` (a plain
    dict) and is guarded by ``self.lock``. The reader thread is the only code
    that closes a session's fd and reaps its child process. The reaper and
    ``shutdown`` merely request termination, so an fd number is never closed
    twice or recycled underneath a live ``select``.
    """

    # Idle sessions (no active task) are reaped after this many seconds.
    _IDLE_TIMEOUT_S = 600
    # Interval between reaper sweeps, in seconds.
    _REAPER_INTERVAL_S = 60
    # Max characters streamed live per session per one-second window.
    _STREAM_BUDGET = 15_000
    # Bytes of head/tail kept when summarizing a large task output file.
    _SUMMARY_HEAD = 10_000
    _SUMMARY_TAIL = 30_000
    # Characters of recent task output kept in RAM for marker/prompt scanning.
    _TAIL_SIZE = 4096
    # Every completion marker starts with this prefix.
    _MARKER_PREFIX = "__CORTEX_FIN_SH_"
    # Matches any line carrying the marker prefix (the echoed command and the exit-code line).
    _MARKER_LINE_RE = re.compile(r".*__CORTEX_FIN_SH_.*[\r\n]*")

    def __init__(self, sync_mgr=None):
        self.sync_mgr = sync_mgr
        self.sessions = {}  # session_id -> {fd, pid, thread, last_activity, ...}
        self.lock = threading.Lock()

        # Phase 3: Prompt Patterns for Edge Intelligence
        self.PROMPT_PATTERNS = [
            r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$",  # bash/zsh: user@host:~$
            r">>>\s*$",                        # python
            r"\.\.\.\s*$",                      # python multi-line
            r">\s*$",                           # node/js
        ]

        # --- M7: Idle Session Reaper ---
        # Automatically kills dormant bash processes to free up system resources.
        self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper")
        self.reaper_thread.start()

    def _session_reaper(self):
        """Background thread: retire sessions idle for longer than ``_IDLE_TIMEOUT_S``.

        Sessions running a task are never reaped. Reaping only requests
        termination (see ``_terminate_locked``). The session's reader thread
        then closes the fd and collects the child's exit status.
        """
        while True:
            time.sleep(self._REAPER_INTERVAL_S)
            with self.lock:
                now = time.time()
                for sid, sess in list(self.sessions.items()):
                    # Avoid reaping currently active tasks
                    if sess.get("active_task"):
                        continue
                    if now - sess.get("last_activity", 0) > self._IDLE_TIMEOUT_S:
                        print(f"    [๐Ÿš๐Ÿงน] Reaping idle shell session: {sid}")
                        self._terminate_locked(sess)
                        self.sessions.pop(sid, None)

    @staticmethod
    def _terminate_locked(sess):
        """Request shutdown of ``sess``: stop its reader loop and SIGKILL its shell.

        Caller must hold ``self.lock``. The fd is intentionally left open
        because the reader thread closes it on exit.

        Once the reader has marked the session ``closed``, the child may
        already be reaped. Its pid is then not signalled, since it could have
        been reused by an unrelated process.
        """
        if sess.get("closed"):
            return
        stop = sess.get("stop_event")
        if stop is not None:
            stop.set()
        try:
            os.kill(sess["pid"], signal.SIGKILL)
        except OSError:
            pass  # already exited

    def _ensure_session(self, session_id, cwd, on_event):
        """Return the session dict for ``session_id``, spawning a shell on first use.

        Args:
            session_id: Key into ``self.sessions``.
            cwd: Working directory for a newly spawned shell. It is ignored
                when the session already exists or the path does not exist.
            on_event: Streaming callback bound to the reader thread when the
                session is created. Callbacks passed for an existing session
                are not used.

        Returns:
            The live session dict, with its ``last_activity`` refreshed.
        """
        with self.lock:
            sess = self.sessions.get(session_id)
            if sess is not None:
                sess["last_activity"] = time.time()
                return sess

            print(f"    [๐Ÿš] Initializing Persistent Shell Session: {session_id}")
            # Spawn bash in a pty
            pid, fd = pty.fork()
            if pid == 0:  # Child: never returns
                self._exec_shell(cwd)

            # Parent: set the master non-blocking
            fl = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

            sess = {
                "fd": fd,
                "pid": pid,
                "last_activity": time.time(),
                "buffer_file": None,
                "tail_buffer": "",
                "active_task": None,
                # Set to ask the reader loop to exit (reaper/shutdown).
                "stop_event": threading.Event(),
                # True once the reader has started tearing the session down.
                "closed": False,
            }
            t = threading.Thread(
                target=self._reader_loop,
                args=(session_id, sess, on_event),
                daemon=True,
                name=f"ShellReader-{session_id}",
            )
            sess["thread"] = t
            self.sessions[session_id] = sess
            t.start()
            return sess

    @staticmethod
    def _exec_shell(cwd):
        """Child side of ``pty.fork()``: prepare the environment and exec a login shell.

        Never returns. If anything fails before exec succeeds, the child exits
        via ``os._exit``. Otherwise it would keep executing the parent's copied
        Python code (threads, locks, connections) in the forked process.
        """
        try:
            # Environment prep
            os.environ["TERM"] = "xterm-256color"
            # Change to CWD
            if cwd and os.path.exists(cwd):
                os.chdir(cwd)
            # Launch shell
            shell_path = "/bin/bash"
            if not os.path.exists(shell_path):
                shell_path = "/bin/sh"
            os.execv(shell_path, [shell_path, "--login"])
        except BaseException as exc:
            try:
                # fd 2 is the PTY slave here, so the user sees the reason.
                os.write(2, f"shell bootstrap failed: {exc}\r\n".encode("utf-8", "replace"))
            except OSError:
                pass
        finally:
            os._exit(127)

    def _reader_loop(self, session_id, sess, on_event):
        """Reader thread body: drain the PTY master until EOF, error, or a stop request.

        Always finishes with ``_finalize_session``. A dying reader therefore
        never leaves a waiting ``execute`` hanging, nor leaks the fd or the
        child process.
        """
        fd = sess["fd"]
        stop = sess["stop_event"]
        # Incremental decoding keeps multi-byte UTF-8 characters intact when
        # they straddle two 4 KiB reads. A per-chunk decode would turn each
        # half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            while not stop.is_set():
                try:
                    r, _, _ = select.select([fd], [], [], 0.1)
                    if fd not in r:
                        continue
                    try:
                        data = os.read(fd, 4096)
                    except BlockingIOError:
                        continue
                    if not data:
                        break
                    decoded = decoder.decode(data)
                    if decoded:
                        self._handle_output(session_id, sess, decoded, on_event)
                except (EOFError, OSError):
                    break
                except Exception as catch_all:
                    print(f"    [๐ŸšโŒ] Reader thread FATAL exception: {catch_all}")
                    break
        finally:
            # Thread Cleanup
            print(f"    [๐Ÿš] Shell Session Terminated: {session_id}")
            self._finalize_session(session_id, sess)

    def _handle_output(self, session_id, sess, decoded, on_event):
        """Process one decoded chunk of PTY output for ``sess``.

        * While a task is active, the chunk is persisted and checked for the
          completion marker (``_record_task_output_locked``).
        * With an ``on_event`` callback, the chunk is streamed with marker lines
          stripped, subject to the per-second budget (``_stream_messages_locked``).
        * If the task is still running, a ``prompt`` event is emitted when the
          recent output looks like an interactive prompt.

        State is updated under ``self.lock``, but callbacks run only after it
        is released. A slow or re-entrant ``on_event`` therefore cannot block
        or deadlock other threads on the non-reentrant lock.
        """
        # Stealth filtering: drop every line carrying the marker prefix. This
        # hides both the echoed command and the exit-code line from the user.
        stealth_out = decoded
        if self._MARKER_PREFIX in decoded:
            stealth_out = self._MARKER_LINE_RE.sub("", decoded)
        streaming = bool(on_event and stealth_out)

        outgoing = []
        with self.lock:
            active_tid = sess.get("active_task")
            if active_tid and sess.get("marker") and sess.get("buffer_file"):
                self._record_task_output_locked(sess, decoded)
            if streaming:
                outgoing = self._stream_messages_locked(session_id, sess, stealth_out)
            current_event = sess.get("event")
            tail = sess.get("tail_buffer", "")[-100:]

        if not streaming:
            return
        for message in outgoing:
            on_event(message)

        # EDGE INTELLIGENCE: proactively signal prompt detection while the task
        # is still running (marker not yet seen).
        if active_tid and current_event is not None and not current_event.is_set():
            for pattern in self.PROMPT_PATTERNS:
                if re.search(pattern, tail):
                    # Use last 20 chars as the 'prompt' hint
                    prompt_event = agent_pb2.SkillEvent(
                        session_id=session_id,
                        task_id=active_tid,
                        prompt=tail[-20:].strip(),
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event))
                    break

    def _record_task_output_locked(self, sess, chunk):
        """Persist ``chunk`` for the active task and finish the task once its marker line arrives.

        Phase 2: Persistence Offloading. Output goes to the task's temp file
        rather than the heap. Only a small tail is kept in RAM for marker
        detection and prompt scanning.

        On completion, ``sess["result"]`` is filled in, the temp file handle is
        closed (the file stays on disk) and ``sess["event"]`` is set. Caller
        must hold ``self.lock``.
        """
        marker = sess["marker"]
        bf = sess["buffer_file"]
        bf.write(chunk)
        bf.flush()
        tail = (sess.get("tail_buffer", "") + chunk)[-self._TAIL_SIZE:]
        sess["tail_buffer"] = tail

        # Wait for the complete "<marker> <exit code>" line. Completing on the
        # bare marker could read an exit code that has not arrived yet.
        match = re.search(re.escape(marker) + r"[ \t]*(\d*)[^\r\n]*[\r\n]", tail)
        if match is None:
            return
        exit_code = int(match.group(1)) if match.group(1) else 0
        result = sess["result"]
        try:
            stdout = self._summarize_buffer(bf, "full output safely preserved at")
            # Slice off the marker string and anything after it from the final result
            result["stdout"] = stdout.split(marker)[0]
            result["status"] = 0 if exit_code == 0 else 1
        except Exception as e:
            print(f"    [๐Ÿšโš ๏ธ] Marker parsing failed: {e}")
            # Do not report an empty stdout as success.
            result["stderr"] = f"Failed to collect task output: {e}"
            result["status"] = 2
        finally:
            try:
                bf.close()  # leaves the file on disk (delete=False)
            except OSError:
                pass
            sess["buffer_file"] = None
            sess["event"].set()

    def _stream_messages_locked(self, session_id, sess, text):
        """Apply the live-stream budget to ``text`` and return the messages to send.

        Phase 3: Client-Side Truncation (Stream Rate Limiting). At most
        ``_STREAM_BUDGET`` characters per session are forwarded per one-second
        window, to avoid flooding the Hub over gRPC. Excess chunks are dropped
        whole, and their size is reported when the next window opens.

        Task output is still written in full to the task's temp file. Caller
        must hold ``self.lock``.
        """
        messages = []
        now = time.time()
        if now - sess.get("stream_window_start", 0) > 1.0:
            sess["stream_window_start"] = now
            sess["stream_bytes_sent"] = 0
            dropped = sess.get("stream_dropped_bytes", 0)
            if dropped > 0:
                drop_msg = f"\n[... {dropped:,} bytes truncated from live stream ...]\n"
                messages.append(self._terminal_message(session_id, sess.get("active_task") or "", drop_msg))
                sess["stream_dropped_bytes"] = 0

        if sess.get("stream_bytes_sent", 0) + len(text) > self._STREAM_BUDGET:
            sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(text)
        else:
            sess["stream_bytes_sent"] = sess.get("stream_bytes_sent", 0) + len(text)
            messages.append(self._terminal_message(session_id, sess.get("active_task") or "", text))
        return messages

    def _finalize_session(self, session_id, sess):
        """Tear down ``sess`` after its reader loop has ended.

        Steps:
          1. Unregister the session, but only if it is still the registered
             one. The reaper or ``shutdown`` may have unregistered it already,
             and a new shell may now own the id.
          2. Fail any task still waiting on it.
          3. Close the PTY master.
          4. Reap the shell process so it does not linger as a zombie.
        """
        with self.lock:
            sess["closed"] = True
            if self.sessions.get(session_id) is sess:
                del self.sessions[session_id]
            pending = sess.get("event")
            result = sess.get("result")
            if pending is not None and not pending.is_set():
                if result is not None:
                    result["stderr"] = "SESSION_TERMINATED: shell exited before the command completed"
                    result["status"] = 2
                pending.set()
        try:
            os.close(sess["fd"])
        except OSError:
            pass
        self._reap_child(sess["pid"])

    @staticmethod
    def _reap_child(pid):
        """Collect the shell's exit status, SIGKILLing it first if it is still running."""
        try:
            done_pid, _ = os.waitpid(pid, os.WNOHANG)
            if done_pid == 0:
                os.kill(pid, signal.SIGKILL)
                os.waitpid(pid, 0)
        except OSError:
            pass  # already reaped or gone

    def _summarize_buffer(self, bf, note):
        """Return the text of task output file ``bf``, eliding the middle when large.

        Files longer than ``_SUMMARY_HEAD + _SUMMARY_TAIL`` bytes keep only the
        first ``_SUMMARY_HEAD`` and last ``_SUMMARY_TAIL`` bytes. The two parts
        are joined by an omission notice naming the file (``note`` precedes
        the path).

        The file is re-read by name in binary mode. Seeking a text-mode file
        to an arbitrary byte offset is undefined, and can land mid-character
        where a strict UTF-8 read would raise.
        """
        bf.flush()
        head_n, tail_n = self._SUMMARY_HEAD, self._SUMMARY_TAIL
        with open(bf.name, "rb") as raw:
            size = raw.seek(0, os.SEEK_END)
            raw.seek(0)
            if size <= head_n + tail_n:
                return self._decode_text(raw.read())
            head = raw.read(head_n)
            raw.seek(size - tail_n)
            tail = raw.read()
        omitted = size - head_n - tail_n
        return (
            self._decode_text(head)
            + f"\n\n[... {omitted:,} bytes omitted ({note} {bf.name}) ...]\n\n"
            + self._decode_text(tail)
        )

    @staticmethod
    def _decode_text(raw):
        """Decode UTF-8 bytes leniently and normalize newlines like a universal-newlines text read."""
        return raw.decode("utf-8", errors="replace").replace("\r\n", "\n").replace("\r", "\n")

    @staticmethod
    def _terminal_message(session_id, task_id, text):
        """Wrap ``text`` as a ``terminal_out`` skill-event message."""
        event = agent_pb2.SkillEvent(session_id=session_id, task_id=task_id, terminal_out=text)
        return agent_pb2.ClientTaskMessage(skill_event=event)

    @staticmethod
    def _write_all(fd, data, timeout=5.0):
        """Write every byte of ``data`` to the non-blocking PTY master ``fd``.

        ``os.write`` on a non-blocking fd may write only part of the buffer, or
        raise ``BlockingIOError`` when the PTY input queue is full. Loop and
        wait for writability instead of silently dropping input.

        Raises:
            TimeoutError: if no progress is possible for ``timeout`` seconds.
        """
        total = len(data)
        offset = 0
        deadline = time.monotonic() + timeout
        while offset < total:
            try:
                written = os.write(fd, data[offset:])
            except BlockingIOError:
                written = 0
            offset += written
            if offset >= total:
                break
            if written:
                continue  # progress made; retry immediately
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(f"PTY write stalled with {total - offset} bytes pending")
            select.select([], [fd], [], min(remaining, 0.1))

    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Processes raw TTY/Resize events synchronously (bypasses threadpool/sandbox).

        ``task.payload_json`` is handled when it is a JSON object of one of
        these forms:
          * ``{"tty": "<keystrokes>"}``: the keystrokes are written to the PTY.
          * ``{"action": "resize", "cols": C, "rows": R}``: the window size is
            set (defaults 80x24).

        Returns:
            True if the payload was handled and ``on_complete`` was called.
            False if it is not such a payload, or handling failed (the failure
            is logged).
        """
        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            if not (cmd.startswith('{') and cmd.endswith('}')):
                return False
            raw_payload = json.loads(cmd)
            if not isinstance(raw_payload, dict):
                return False

            # 1. Raw Keystroke forward
            if "tty" in raw_payload:
                sess = self._ensure_session(session_id, None, on_event)
                self._write_all(sess["fd"], raw_payload["tty"].encode("utf-8"))
                on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                return True

            # 2. Window Resize
            if raw_payload.get("action") == "resize":
                cols = raw_payload.get("cols", 80)
                rows = raw_payload.get("rows", 24)
                sess = self._ensure_session(session_id, None, on_event)
                winsize = struct.pack('HHHH', rows, cols, 0, 0)
                fcntl.ioctl(sess["fd"], termios.TIOCSWINSZ, winsize)
                print(f"    [๐Ÿš] Terminal Resized to {cols}x{rows}")
                on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                return True
        except Exception as pe:
            print(f"    [๐Ÿš] Transparent TTY Fail: {pe}")
        return False

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Dispatches command string to the persistent PTY shell and WAITS for completion.

        The command is sent as ``(cmd) ; echo <marker> $?``. The reader thread
        completes the task when the marker line appears.

        ``on_complete(task_id, result, trace_id)`` receives one of:
          * ``{"stdout", "status": 0|1}``: completed (1 = non-zero exit code).
          * ``{"stderr": "SANDBOX_VIOLATION: ...", "status": 2}``: blocked by
            ``sandbox.verify``.
          * ``{"stdout": "INJECTED", "status": 0}``: ``!RAW:`` input injection.
          * ``{"stderr": "[BUSY] ...", "status": 1}``: another task owns the
            session.
          * ``{"stderr": "ABORTED", "status": 2}``: cancelled via ``cancel()``.
          * ``{"stdout": <partial>, "stderr": "TIMEOUT", "status": 2}``: timed
            out. The timeout is ``task.timeout_ms``, or 60 s when that is <= 0.
          * ``{"stderr": <error>, "status": 2}``: the shell died or an
            unexpected error occurred.
        """
        session_id = task.session_id or "default-session"
        tid = task.task_id
        try:
            cmd = task.payload_json

            # --- Legacy Full-Command Execution (Sandboxed) ---
            allowed, status_msg = sandbox.verify(cmd)
            if not allowed:
                err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
                if on_event:
                    on_event(self._terminal_message(session_id, tid, err_msg))
                return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

            # Resolve CWD jail
            cwd = None
            if self.sync_mgr and task.session_id:
                cwd = self.sync_mgr.get_session_dir(task.session_id)
            elif sandbox.policy.get("WORKING_DIR_JAIL"):
                cwd = sandbox.policy["WORKING_DIR_JAIL"]
                try:
                    os.makedirs(cwd, exist_ok=True)
                except OSError:
                    pass  # best effort: the shell keeps its default CWD if the jail is missing

            # Handle Session Persistent Process
            sess = self._ensure_session(session_id, cwd, on_event)

            # RAW mode bypasses the busy check so input can reach an interactive program.
            if cmd.startswith("!RAW:"):
                input_str = cmd[5:] + "\n"
                print(f"    [๐ŸšโŒจ๏ธ] RAW Input Injection: {input_str.strip()}")
                self._write_all(sess["fd"], input_str.encode("utf-8"))
                return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id)

            # Random marker id. A timestamp repeats for tasks started in the
            # same second, letting a timed-out command's late marker complete
            # the next task.
            marker_id = os.urandom(6).hex()
            marker = f"__CORTEX_FIN_SH_{marker_id}__"
            event = threading.Event()
            cancel_event = threading.Event()
            result_container = {"stdout": "", "status": 0}  # 0 = Success (Protobuf Convention)

            # Busy check and registration happen in ONE critical section. Two
            # concurrent calls therefore cannot both see an idle session and
            # clobber each other's state.
            with self.lock:
                curr_tid = sess.get("active_task")
                closed = sess.get("closed", False)
                if not curr_tid and not closed:
                    # Create the file first: if this raises, no task state has
                    # been registered, so the session is not left stuck busy.
                    buffer_file = tempfile.NamedTemporaryFile(
                        "w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False
                    )
                    sess["active_task"] = tid
                    sess["marker"] = marker
                    sess["event"] = event
                    sess["buffer_file"] = buffer_file
                    sess["tail_buffer"] = ""
                    sess["result"] = result_container
                    sess["cancel_event"] = cancel_event
            if curr_tid:
                return on_complete(tid, {"stderr": f"[BUSY] Session {session_id} is already running task {curr_tid}", "status": 1}, task.trace_id)
            if closed:
                return on_complete(tid, {"stderr": f"SESSION_TERMINATED: shell session {session_id} is closing", "status": 2}, task.trace_id)

            try:
                # ( cmd ) ; echo marker $?
                # The marker is echoed as two adjacent quoted pieces, so the
                # literal marker never appears in the PTY's echo of this input
                # line. That prevents premature completion.
                full_input = f"({cmd}) ; echo \"__CORTEX_FIN_SH_\"\"{marker_id}__\" $?\n"
                self._write_all(sess["fd"], full_input.encode("utf-8"))
                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0
                outcome = self._wait_for_task(sess, tid, event, cancel_event, result_container, timeout)
            finally:
                self._clear_task(sess, tid)

            # Report only after the session is released. A follow-up task
            # dispatched as soon as this one completes is then not rejected
            # as BUSY.
            return on_complete(tid, outcome, task.trace_id)

        except Exception as e:
            print(f"    [๐ŸšโŒ] Execute Error for {tid}: {e}")
            on_complete(tid, {"stderr": str(e), "status": 2}, task.trace_id)

    def _wait_for_task(self, sess, tid, event, cancel_event, result_container, timeout):
        """Block until task ``tid`` completes, is cancelled, or ``timeout`` seconds elapse.

        Returns the dict to hand to ``on_complete``:
          * ``result_container`` (filled in by the reader) on completion;
          * an ABORTED result on cancellation;
          * a TIMEOUT result carrying whatever output was captured so far.
        """
        deadline = time.monotonic() + timeout
        while True:
            # Check for completion (reader found marker)
            if event.is_set():
                return result_container
            # Check for cancellation (HUB sent cancel)
            if cancel_event.is_set():
                print(f"    [๐Ÿš๐Ÿ›‘] Task {tid} cancelled on node.")
                return {"stderr": "ABORTED", "status": 2}
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Wake immediately on completion; poll cancellation every 100 ms.
            event.wait(min(remaining, 0.1))

        # Timeout Case
        print(f"    [๐Ÿšโš ๏ธ] Task {tid} timed out on node.")
        partial_out = ""
        with self.lock:
            bf = sess.get("buffer_file") if sess.get("active_task") == tid else None
            if bf is not None:
                try:
                    partial_out = self._summarize_buffer(bf, "full timeout output saved to")
                except (OSError, ValueError):
                    partial_out = ""  # best effort: never mask the TIMEOUT result
        return {"stdout": partial_out, "stderr": "TIMEOUT", "status": 2}

    def _clear_task(self, sess, tid):
        """Release ``sess`` from task ``tid``. This is a no-op if another task now owns it."""
        with self.lock:
            if sess.get("active_task") != tid:
                return
            bf = sess.get("buffer_file")
            if bf is not None:
                try:
                    bf.close()
                except OSError:
                    pass
                sess["buffer_file"] = None
            sess["active_task"] = None
            sess["marker"] = None
            sess["event"] = None
            sess["result"] = None
            sess["cancel_event"] = None
            # Restart the idle clock. Otherwise a long-running task makes the
            # session look idle to the reaper the moment it ends.
            sess["last_activity"] = time.time()

    def cancel(self, task_id: str):
        """Cancels an active task - for persistent shell, this sends a SIGINT (Ctrl+C).

        Writes the Ctrl+C byte (0x03) to the PTY of every session running
        ``task_id``, and sets that task's cancel event so the waiting
        ``execute`` returns ABORTED. Always returns True.
        """
        with self.lock:
            for sid, sess in self.sessions.items():
                if sess.get("active_task") != task_id:
                    continue
                print(f"[๐Ÿ›‘] Sending SIGINT (Ctrl+C) to shell session (Task {task_id}): {sid}")
                try:
                    # Write \x03 (Ctrl+C) to the master FD
                    os.write(sess["fd"], b"\x03")
                except OSError:
                    pass  # PTY gone or full; still release the waiter below
                # Break the wait loop in execute thread
                if sess.get("cancel_event"):
                    sess["cancel_event"].set()
        return True

    def shutdown(self):
        """Cleanup: Terminates all persistent shells.

        Every session is unregistered and its shell SIGKILLed. The reader
        threads then close the PTY fds and reap the children. Waits up to one
        second per reader thread so the fds are released before returning.
        """
        with self.lock:
            sessions = list(self.sessions.items())
            for sid, sess in sessions:
                print(f"[๐Ÿ›‘] Cleaning up persistent shell: {sid}")
                self._terminate_locked(sess)
            self.sessions.clear()
        current = threading.current_thread()
        for _, sess in sessions:
            reader = sess.get("thread")
            if reader is not None and reader is not current:
                reader.join(timeout=1.0)