# cortex-hub / agent-node / agent_node / skills / shell.py
import codecs
import fcntl
import json
import os
import pty
import select
import signal
import struct
import termios
import threading
import time
import uuid

from protos import agent_pb2

from .base import BaseSkill

class ShellSkill(BaseSkill):
    """Admin Console Skill: Persistent stateful Bash via PTY.

    Each session id owns one ``/bin/bash --login`` process attached to a
    pseudo-terminal. A daemon reader thread per session streams terminal
    output to ``on_event``. While a blocking :meth:`execute` call is pending,
    the reader also watches the output for a completion marker line that
    carries the command's exit code.

    Result ``status`` values used by :meth:`execute`:
    1 = success (exit code 0),
    2 = failure (non-zero exit, unparsable marker, sandbox violation,
    timeout, shell death).
    Raw TTY and resize tasks complete with status 0.
    """

    def __init__(self, sync_mgr=None):
        self.sync_mgr = sync_mgr
        # session_id -> {"fd", "pid", "thread"} plus the blocking-task state
        # ("active_task", "marker", "event", "buffer", "result") used by
        # execute() and the reader thread.
        self.sessions = {}
        # Guards self.sessions and every session's task state. It is shared
        # with the reader threads and is a plain (non-reentrant) Lock.
        self.lock = threading.Lock()

    def _ensure_session(self, session_id, cwd, on_event):
        """Return the session dict for ``session_id``, spawning bash if needed.

        Args:
            session_id: Key identifying the persistent shell.
            cwd: Start directory for a newly spawned shell. It is ignored if it
                does not exist, or if the session already exists.
            on_event: Callback that receives ``ClientTaskMessage`` objects with
                streamed terminal output. It is bound only when the session is
                first created.
        """
        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                return existing

            print(f"    [🐚] Initializing Persistent Shell Session: {session_id}")
            # Spawn bash in a pty
            pid, fd = pty.fork()
            if pid == 0:  # Child: never returns
                self._exec_shell_child(cwd)

            # Parent: make the master fd non-blocking so reads after select()
            # never stall the reader thread.
            fl = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

            # Build the session record BEFORE starting the reader, so the
            # thread has a real object to consult.
            sess = {
                "fd": fd,
                "pid": pid,
                "thread": None,
                "active_task": None,
                "marker": None,
                "event": None,
                "buffer": "",
                "result": None,
            }
            t = threading.Thread(
                target=self._reader_loop,
                args=(session_id, sess, on_event),
                daemon=True,
            )
            sess["thread"] = t
            self.sessions[session_id] = sess
            t.start()
            return sess

    @staticmethod
    def _exec_shell_child(cwd):
        """Child side of ``pty.fork()``: prepare the environment and exec bash.

        This never returns. If exec fails, the child exits with status 127
        instead of falling back into the parent's Python code.
        """
        try:
            os.environ["TERM"] = "xterm-256color"
            # Fallback prompt. The login profile that `bash --login` sources
            # may override it.
            os.environ["PS1"] = "\\s-\\v\\$ "

            # Change to CWD. If that fails, still start the shell in the
            # inherited directory.
            if cwd and os.path.exists(cwd):
                try:
                    os.chdir(cwd)
                except OSError:
                    pass

            os.execv("/bin/bash", ["/bin/bash", "--login"])
        finally:
            os._exit(127)

    @staticmethod
    def _write_all(fd, data, timeout=5.0):
        """Write every byte of ``data`` to the non-blocking PTY master ``fd``.

        Because the fd is O_NONBLOCK, a single ``os.write`` may accept only part
        of the data, or raise BlockingIOError when the PTY buffer is full.

        Raises:
            TimeoutError: if the fd stays unwritable for ``timeout`` seconds.
            OSError: for any other write failure (e.g. a closed fd).
        """
        view = memoryview(data)
        while view:
            try:
                written = os.write(fd, view)
            except BlockingIOError:
                written = 0
            if written:
                view = view[written:]
                continue
            _, writable, _ = select.select([], [fd], [], timeout)
            if not writable:
                raise TimeoutError("PTY not writable")

    @staticmethod
    def _consume_output(sess, text):
        """Feed a chunk of terminal output into the session's pending task, if any.

        Must be called with ``self.lock`` held. The output is appended to the
        task buffer. Once the full ``<marker> <exit_code>`` line has arrived,
        the task result is filled in and its event is set.

        Returns:
            The part of ``text`` to stream to the UI. The marker line itself is
            withheld.
        """
        marker = sess.get("marker")
        if not (sess.get("active_task") and marker):
            return text

        chunk_start = len(sess["buffer"])
        sess["buffer"] += text
        buf = sess["buffer"]

        # Search the whole buffer so a marker split across reads is still found.
        idx = buf.find(marker)
        if idx < 0:
            return text

        # Only the part of THIS chunk that precedes the marker is new, streamable
        # output. Earlier chunks were already streamed.
        visible = buf[chunk_start:idx] if idx > chunk_start else ""

        line_end = buf.find("\n", idx + len(marker))
        if line_end < 0:
            # The exit code hasn't fully arrived yet; wait for the next chunk.
            return visible

        fields = buf[idx + len(marker):line_end].split()
        try:
            exit_code = int(fields[0])
        except (IndexError, ValueError):
            print(f"    [🐚⚠️] Marker parsing failed: {fields!r}")
            exit_code = None

        result = sess["result"]
        # stdout is everything the PTY produced before the marker. This
        # includes the terminal's echo of the typed command line.
        result["stdout"] = buf[:idx]
        result["status"] = 1 if exit_code == 0 else 2  # Success=1 for Skill mgr
        sess["marker"] = None  # stop scanning; later output streams untouched
        sess["event"].set()

        # Whatever follows the marker line in this chunk (e.g. the next prompt)
        # is ordinary terminal output again.
        return visible + buf[max(chunk_start, line_end + 1):]

    @staticmethod
    def _fail_pending(sess):
        """Wake a blocked execute() whose shell ended before printing its marker.

        Must be called with ``self.lock`` held.
        """
        if not (sess.get("active_task") and sess.get("marker")):
            return
        result = sess.get("result")
        if result is not None:
            result["stdout"] = sess.get("buffer", "")
            result["stderr"] = "SHELL_TERMINATED"
            result["status"] = 2
        sess["marker"] = None
        if sess.get("event") is not None:
            sess["event"].set()

    def _reader_loop(self, session_id, sess, on_event):
        """Reader-thread body: stream PTY output until the shell goes away."""
        fd = sess["fd"]
        # The incremental decoder keeps UTF-8 sequences that straddle two reads
        # intact, instead of turning each half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            if sess["fd"] != fd:  # shutdown() closed and invalidated the fd
                break
            try:
                ready, _, _ = select.select([fd], [], [], 0.1)
                if fd not in ready:
                    continue
                data = os.read(fd, 4096)
            except (BlockingIOError, InterruptedError):
                continue
            except (EOFError, OSError, ValueError):
                break
            if not data:
                break

            text = decoder.decode(data)
            if not text:
                continue  # only a partial multi-byte sequence so far

            with self.lock:
                text = self._consume_output(sess, text)
                task_id = sess.get("active_task") or ""

            # Stream raw terminal output back
            if on_event and text:
                event = agent_pb2.SkillEvent(
                    session_id=session_id,
                    task_id=task_id,
                    terminal_out=text
                )
                try:
                    on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                except Exception as e:
                    # A streaming failure must not kill the shell session.
                    print(f"    [🐚⚠️] Output stream failed for {session_id}: {e}")

        print(f"    [🐚] Shell Session Terminated: {session_id}")
        with self.lock:
            # Only the current owner closes the fd. shutdown() closes it itself
            # and removes the entry first, so we never close a reused fd number.
            if self.sessions.get(session_id) is sess:
                del self.sessions[session_id]
                try:
                    os.close(fd)
                except OSError:
                    pass
            sess["fd"] = -1  # stale holders now fail fast with EBADF
            self._fail_pending(sess)
        # Reap bash so it does not linger as a zombie.
        try:
            os.waitpid(sess["pid"], 0)
        except OSError:
            pass

    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Process raw TTY keystroke or resize payloads synchronously.

        This path bypasses the threadpool and the sandbox.
        ``task.payload_json`` must be a JSON object of one of two forms:

        - ``{"tty": "<text>"}``: the text is written verbatim to the session's PTY.
        - ``{"action": "resize", "cols": N, "rows": M}``: the PTY window size
          is set (defaults 80x24).

        Returns:
            True if the payload was handled (``on_complete`` was called with
            status 0). False if it was not a recognised JSON object or
            handling failed.
        """
        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            if not (cmd.startswith('{') and cmd.endswith('}')):
                return False
            raw_payload = json.loads(cmd)
            if not isinstance(raw_payload, dict):
                return False

            # 1. Raw Keystroke forward
            if "tty" in raw_payload:
                sess = self._ensure_session(session_id, None, on_event)
                self._write_all(sess["fd"], raw_payload["tty"].encode("utf-8"))
                on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                return True

            # 2. Window Resize
            if raw_payload.get("action") == "resize":
                cols = int(raw_payload.get("cols", 80))
                rows = int(raw_payload.get("rows", 24))
                sess = self._ensure_session(session_id, None, on_event)
                # Packed in winsize order: rows, cols, xpixel, ypixel.
                winsize = struct.pack('HHHH', rows, cols, 0, 0)
                fcntl.ioctl(sess["fd"], termios.TIOCSWINSZ, winsize)
                print(f"    [🐚] Terminal Resized to {cols}x{rows}")
                on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                return True
        except Exception as pe:
            print(f"    [🐚] Transparent TTY Fail: {pe}")
        return False

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Run ``task.payload_json`` in the session's persistent shell and wait for it.

        Steps:

        1. The command is checked with ``sandbox.verify``.
        2. It is typed into the PTY as ``(cmd) ; echo "<marker> $?"``.
        3. This call blocks until the reader thread sees the marker line, the
           shell dies, or ``task.timeout_ms`` elapses (default 60 s).

        ``on_complete(task_id, result, trace_id)`` receives a dict with
        ``stdout`` and/or ``stderr`` and ``status`` (1 = success, 2 = failure).
        """
        session_id = task.session_id or "default-session"
        tid = task.task_id
        try:
            cmd = task.payload_json

            # --- Legacy Full-Command Execution (Sandboxed) ---
            allowed, status_msg = sandbox.verify(cmd)
            if not allowed:
                err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
                if on_event:
                    blocked_event = agent_pb2.SkillEvent(
                        session_id=session_id, task_id=tid,
                        terminal_out=err_msg
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=blocked_event))

                return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

            # Resolve CWD jail. It only takes effect when the session is spawned.
            cwd = None
            if self.sync_mgr and task.session_id:
                cwd = self.sync_mgr.get_session_dir(task.session_id)
            elif sandbox.policy.get("WORKING_DIR_JAIL"):
                cwd = sandbox.policy["WORKING_DIR_JAIL"]
                if not os.path.exists(cwd):
                    try:
                        os.makedirs(cwd, exist_ok=True)
                    except OSError as e:
                        print(f"    [🐚⚠️] Could not create working dir {cwd}: {e}")

            # Handle Session Persistent Process
            sess = self._ensure_session(session_id, cwd, on_event)

            # --- Blocking Wait Logic ---
            # The marker is unique per call, so a late marker from an earlier
            # (timed-out) command can never complete this one.
            marker = f"__CORTEX_FIN_SH_{uuid.uuid4().hex}__"
            done = threading.Event()
            # 2 = failure until the reader proves otherwise (1 = success).
            result_container = {"stdout": "", "status": 2}

            # Register waiter in session state
            with self.lock:
                sess["active_task"] = tid
                sess["marker"] = marker
                sess["event"] = done
                sess["buffer"] = ""
                sess["result"] = result_container

            try:
                print(f"    [🐚] Executing (Blocking): {cmd}")
                # The PTY echoes the typed line, so the marker must not appear
                # literally in it. Bash joins the two adjacent double-quoted
                # halves into the real marker when printing, while the echo
                # shows only the split form.
                half = len(marker) // 2
                full_input = f'({cmd}) ; echo "{marker[:half]}""{marker[half:]} $?"\n'
                self._write_all(sess["fd"], full_input.encode("utf-8"))

                # Wait for completion (signalled by the reader thread).
                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0
                if done.wait(timeout):
                    on_complete(tid, result_container, task.trace_id)
                else:
                    print(f"    [🐚⚠️] Task {tid} timed out on node.")
                    with self.lock:
                        partial = sess.get("buffer", "")
                    on_complete(tid, {"stdout": partial, "stderr": "TIMEOUT", "status": 2}, task.trace_id)
            finally:
                # Cleanup session task state, even if the write or callback failed.
                with self.lock:
                    if sess.get("active_task") == tid:
                        sess["active_task"] = None
                        sess["marker"] = None

        except Exception as e:
            print(f"    [🐚❌] Execute Error for {tid}: {e}")
            on_complete(tid, {"stderr": str(e), "status": 2}, task.trace_id)

    def cancel(self, task_id: str):
        """Interrupt a running task by writing Ctrl+C (0x03) to its shell's PTY.

        The session whose active task is ``task_id`` is targeted. If no session
        matches (e.g. a raw TTY task), every session is interrupted instead.
        Always returns True.
        """
        with self.lock:
            targets = [
                (sid, sess) for sid, sess in self.sessions.items()
                if sess.get("active_task") == task_id
            ]
            if not targets:
                targets = list(self.sessions.items())
            for sid, sess in targets:
                print(f"[🛑] Sending SIGINT (Ctrl+C) to shell session: {sid}")
                # Write \x03 (Ctrl+C) to the master FD
                try:
                    os.write(sess["fd"], b"\x03")
                except OSError as e:
                    print(f"    [🐚⚠️] Failed to interrupt {sid}: {e}")
        return True

    def shutdown(self):
        """Cleanup: close every PTY and SIGKILL every persistent shell.

        The reader threads notice the closed fd, wake any pending execute(),
        and reap the killed processes.
        """
        with self.lock:
            for sid, sess in list(self.sessions.items()):
                print(f"[🛑] Cleaning up persistent shell: {sid}")
                try:
                    os.close(sess["fd"])
                except OSError:
                    pass
                sess["fd"] = -1  # tells the reader the fd is no longer valid
                try:
                    os.kill(sess["pid"], signal.SIGKILL)
                except OSError:
                    pass
            self.sessions.clear()