# cortex-hub / agent-node / src / agent_node / skills / shell_bridge.py
from agent_node.skills.base import BaseSkill
from agent_node.skills.terminal_backends import get_terminal_backend
from protos import agent_pb2

class ShellSkill(BaseSkill):
    """Admin Console Skill: Persistent stateful Shell via Abstract Terminal Backend.

    One terminal backend is kept alive per ``session_id``. Each session is a
    plain dict stored in ``self.sessions`` with (at least) these keys:

    * ``backend``       - object returned by ``get_terminal_backend()``; this
      class calls ``spawn``, ``read``, ``write``, ``resize``, ``is_alive`` and
      ``kill`` on it.
    * ``thread``        - the reader thread draining the backend.
    * ``last_activity`` - ``time.time()`` of the last ``_ensure_session`` call.
    * ``active_task``   - id of the task currently being captured, or ``None``.
    * ``buffer_file`` / ``tail_buffer`` / ``result`` / ``event`` /
      ``cancel_event`` - per-task capture state installed by ``execute``.

    All mutation of ``self.sessions`` and of the per-task keys happens while
    holding ``self.lock``.
    """

    # Seconds between two passes of the idle-session reaper.
    _REAPER_INTERVAL_S = 60
    # A session with no running task is reaped after this many idle seconds.
    _IDLE_TIMEOUT_S = 600
    # Number of trailing characters kept in memory for fence detection.
    _TAIL_BUFFER_CHARS = 16384
    # Maximum characters forwarded to the UI per one-second window.
    _STREAM_BUDGET_PER_WINDOW = 100_000
    # Task timeout used when the task does not carry a positive ``timeout_ms``.
    _DEFAULT_TIMEOUT_S = 60.0
    # Strips two-character escapes (ESC + one of ``@``..``_``) and CSI
    # sequences. NOTE: the first alternative also matches ``ESC ]``, i.e. the
    # OSC introducer of our own fences, so after stripping an OSC fence reads
    # ``1337;Task...<BEL>`` with no leading ``ESC ]``.
    _ANSI_PATTERN = r'\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])'

    def __init__(self, sync_mgr=None):
        """Create the skill and start the background idle-session reaper.

        Args:
            sync_mgr: optional object providing ``get_session_dir(session_id,
                create=True)``; when present it decides the shell's working
                directory in ``execute``.
        """
        self.sync_mgr = sync_mgr
        self.sessions = {}  # session_id -> {backend, thread, last_activity, ...}
        self.lock = threading.Lock()

        # Phase 3: Prompt Patterns for Edge Intelligence
        self.PROMPT_PATTERNS = [
            r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$",  # bash/zsh: user@host:~$
            r">>>\s*$",                        # python
            r"\.\.\.\s*$",                      # python multi-line
            r">\s*$",                           # node/js
            r"PS\s+.*>\s*$",                   # powershell
        ]

        # --- M7: Idle Session Reaper ---
        self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper")
        self.reaper_thread.start()

    @staticmethod
    def _remove_quietly(path):
        """Delete ``path`` if possible; a failure only leaves a stray temp file."""
        try:
            os.remove(path)
        except OSError:
            # Deliberate best-effort: cleanup must never fail a task.
            pass

    def _session_reaper(self):
        """Background thread that cleans up unused Shell sessions.

        Runs forever (daemon thread). On every pass, sessions without an
        active task whose ``last_activity`` is older than the idle timeout
        have their backend killed and are removed from ``self.sessions``.
        """
        while True:
            time.sleep(self._REAPER_INTERVAL_S)
            with self.lock:
                now = time.time()
                # Iterate over a snapshot because entries are removed below.
                for sid, sess in list(self.sessions.items()):
                    if sess.get("active_task"):
                        continue
                    if now - sess.get("last_activity", 0) <= self._IDLE_TIMEOUT_S:
                        continue

                    print(f"    [🐚🧹] Reaping idle shell session: {sid}")
                    try:
                        sess["backend"].kill()
                    except Exception as kill_err:
                        # Still drop the session: a backend that cannot be
                        # killed is of no further use to us.
                        print(f"    [🐚⚠️] Failed to kill idle shell {sid}: {kill_err}")
                    self.sessions.pop(sid, None)

    def _ensure_session(self, session_id, cwd, on_event):
        """Return the session dict for ``session_id``, spawning it if needed.

        Args:
            session_id: key into ``self.sessions``.
            cwd: working directory handed to ``backend.spawn`` (only used when
                the session is created; may be ``None``).
            on_event: callback receiving ``agent_pb2.ClientTaskMessage``
                objects for live output. It is bound to the reader thread at
                creation time; later calls for an existing session do not
                replace it.

        Returns:
            The session dict (see the class docstring for its keys).
        """
        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                existing["last_activity"] = time.time()
                return existing

            print(f"    [🐚] Initializing Persistent Shell Session: {session_id}")
            backend = get_terminal_backend()
            backend.spawn(cwd=cwd, env=os.environ.copy())
            print(f"    [🐚] Terminal Spawned (PID Check: {backend.is_alive()})")

            sess = {
                "backend": backend,
                "last_activity": time.time(),
                "buffer_file": None,
                "tail_buffer": "",
                "active_task": None,
            }

            reader_thread = threading.Thread(
                target=self._reader_loop,
                args=(session_id, sess, on_event),
                daemon=True,
                name=f"ShellReader-{session_id}",
            )
            reader_thread.start()
            sess["thread"] = reader_thread

            self.sessions[session_id] = sess
            return sess

    def _reader_loop(self, session_id, sess, on_event):
        """Drain the backend of one session until it dies.

        Every chunk is (a) appended to the active task's capture buffer and
        checked for the end-of-task fence, then (b) filtered and forwarded to
        ``on_event`` for the live UI stream.
        """
        import codecs

        backend = sess["backend"]
        # Incremental decoding keeps a multi-byte UTF-8 character intact when
        # it straddles two read() calls.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

        while True:
            try:
                data = backend.read(4096)
                if not data:
                    if not backend.is_alive():
                        break
                    time.sleep(0.05)
                    continue

                decoded = data if isinstance(data, str) else decoder.decode(data)

                # M7: Protocol-Aware Framing (OSC 1337)
                with self.lock:
                    active_tid = sess.get("active_task")
                    if active_tid and sess.get("buffer_file"):
                        decoded = self._capture_task_output(sess, active_tid, decoded)

                # Stream terminal output back to UI
                if on_event:
                    self._stream_to_ui(session_id, sess, active_tid, decoded, on_event)
            except (EOFError, OSError):
                break
            except Exception as catch_all:
                print(f"    [🐚❌] Reader thread FATAL exception: {catch_all}")
                break

        print(f"    [🐚] Shell Session Terminated: {session_id}")
        with self.lock:
            # Only remove our own record: after a reap the same session_id may
            # already map to a freshly spawned session.
            if self.sessions.get(session_id) is sess:
                self.sessions.pop(session_id, None)

    def _capture_task_output(self, sess, active_tid, decoded):
        """Record ``decoded`` for the active task and complete it on TaskEnd.

        Must be called with ``self.lock`` held and with ``sess["buffer_file"]``
        open. When the end fence for ``active_tid`` is seen, the output between
        the last start fence and the end fence is stored in ``sess["result"]``
        and ``sess["event"]`` is set.

        Returns:
            ``decoded``, with complete protocol fences removed if the task
            finished in this chunk; otherwise unchanged.
        """
        import re

        buffer_file = sess["buffer_file"]
        buffer_file.write(decoded)
        buffer_file.flush()

        # Rolling tail (in characters) used for cheap fence detection.
        sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-self._TAIL_BUFFER_CHARS:]

        # ANSI is stripped first so that cursor-positioning sequences injected
        # by ConPTY inside a marker do not hide it. Stripping also removes the
        # "ESC ]" introducer of the OSC fence, so both fence flavours are
        # matched in their stripped form:
        #   OSC:     1337;TaskEnd;id=<tid>;exit=<n><BEL>
        #   bracket: [[1337;TaskEnd;id=<tid>;exit=<n>]]
        # Requiring digits plus the terminator keeps the shell's echo of the
        # command line ("exit=$__ctx_exit") and half-received fences from
        # matching.
        ansi_escape = re.compile(self._ANSI_PATTERN)
        tid_rx = re.escape(str(active_tid))
        start_fence = re.compile(r'(?:\[\[)?1337;TaskStart;id=' + tid_rx + r'(?:\x07|\]\])')
        end_fence = re.compile(r'(?:\[\[)?1337;TaskEnd;id=' + tid_rx + r';exit=(-?\d+)(?:\x07|\]\])')

        end_match = end_fence.search(ansi_escape.sub('', sess["tail_buffer"]))
        if end_match is None:
            return decoded

        # Task completed via protocol fence!
        try:
            exit_code = int(end_match.group(1))

            buffer_file.seek(0)
            clean_full_raw = ansi_escape.sub('', buffer_file.read())

            print(f"    [🐚DEBUG] Fence Match! Buffer: {len(clean_full_raw)} bytes. Tail: {repr(clean_full_raw[-200:])}")

            # Content starts AFTER the last start fence (avoids echo-back
            # collisions) and stops at the first end fence that follows it.
            start_hits = list(start_fence.finditer(clean_full_raw))
            body_from = start_hits[-1].end() if start_hits else 0
            end_in_full = end_fence.search(clean_full_raw, body_from)
            body_to = end_in_full.start() if end_in_full else len(clean_full_raw)
            content = clean_full_raw[body_from:body_to]

            # Minimal post-processing: remove the echo of the end command itself
            content = re.sub(r'echo (?:\x1b\]|\[\[)1337;TaskEnd;.*', '', content).strip()

            sess["result"]["stdout"] = content
            sess["result"]["status"] = 0 if exit_code == 0 else 1

            buffer_file.close()
            sess["buffer_file"] = None
            # The capture file was created with delete=False; nothing refers
            # to it after a successful run, so remove it here.
            self._remove_quietly(buffer_file.name)
            sess["event"].set()

            # Strip the protocol fences from the live UI stream to keep it clean (ANSI and Bracket)
            decoded = re.sub(r'\x1b]1337;Task(Start|End);id=.*?\x07', '', decoded)
            decoded = re.sub(r'\[\[1337;Task(Start|End);id=.*?\]\]', '', decoded)
        except Exception as e:
            print(f"    [🐚⚠️] Protocol parsing failed: {e}")
            # Wake the waiter anyway so execute() does not sit until timeout.
            waiter = sess.get("event")
            if waiter:
                waiter.set()
        return decoded

    def _stream_to_ui(self, session_id, sess, active_tid, decoded, on_event):
        """Forward one chunk to the UI, hiding plumbing and rate limiting.

        Lines carrying protocol markers are dropped, at most
        ``_STREAM_BUDGET_PER_WINDOW`` characters are sent per one-second
        window (the number of dropped characters is reported when the next
        window opens), and a ``prompt`` event is emitted when the tail of the
        output looks like an interactive prompt while a task is still pending.
        """
        import re

        # M9: Filter Native Escaped cmd Echo framing from bouncing back to the UI
        # e.g., "echo [[1337;Task^Start;id=xyz]] & "
        decoded = re.sub(r'echo \s*\[\[1337;Task\^Start;id=[a-zA-Z0-9-]*\]\]\s*&\s*', '', decoded)
        decoded = re.sub(r'\s*&\s*echo \s*\[\[1337;Task\^End;id=[a-zA-Z0-9-]*;exit=%errorlevel%\]\]', '', decoded)

        # M7: Line-Aware Hyper-Aggressive Stealthing - any line that carries
        # our internal protocol baggage is plumbing and is dropped whole.
        ansi_escape_ui = re.compile(self._ANSI_PATTERN)
        visible_lines = []
        for line in decoded.splitlines(keepends=True):
            if "1337;Task" in ansi_escape_ui.sub('', line) or "`e]" in line or "\\033]" in line:
                continue
            visible_lines.append(line)
        stealth_out = "".join(visible_lines)

        if not stealth_out.strip():
            return

        with self.lock:
            now = time.time()
            if now - sess.get("stream_window_start", 0) > 1.0:
                # New rate-limit window: reset the budget and report what the
                # previous window had to drop.
                sess["stream_window_start"] = now
                sess["stream_bytes_sent"] = 0
                dropped = sess.get("stream_dropped_bytes", 0)
                if dropped > 0:
                    drop_msg = f"\n[... {dropped:,} bytes truncated from live stream ...]\n"
                    event = agent_pb2.SkillEvent(
                        session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                    sess["stream_dropped_bytes"] = 0

            if sess.get("stream_bytes_sent", 0) + len(stealth_out) > self._STREAM_BUDGET_PER_WINDOW:
                sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out)
            else:
                sess["stream_bytes_sent"] = sess.get("stream_bytes_sent", 0) + len(stealth_out)
                event = agent_pb2.SkillEvent(
                    session_id=session_id,
                    task_id=sess.get("active_task") or "",
                    terminal_out=stealth_out
                )
                on_event(agent_pb2.ClientTaskMessage(skill_event=event))

        # EDGE INTELLIGENCE: Proactively signal prompt detection
        current_event = sess.get("event")
        if active_tid and current_event and not current_event.is_set():
            tail = sess["tail_buffer"][-100:]
            for pattern in self.PROMPT_PATTERNS:
                if re.search(pattern, tail):
                    p_hint = tail[-20:].strip()
                    prompt_event = agent_pb2.SkillEvent(
                        session_id=session_id,
                        task_id=active_tid,
                        prompt=p_hint
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event))
                    break

    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Processes raw TTY/Resize events synchronously.

        ``task.payload_json`` is inspected for one of two JSON object shapes:
        ``{"tty": "<keystrokes>"}`` (written verbatim to the shell) or
        ``{"action": "resize", "cols": C, "rows": R}``.

        Returns:
            ``True`` when the payload was one of those shapes and was handled
            (``on_complete`` has been called); ``False`` otherwise, including
            when handling raised.
        """
        import json

        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            if not (cmd.startswith('{') and cmd.endswith('}')):
                return False
            raw_payload = json.loads(cmd)
            if not isinstance(raw_payload, dict):
                return False

            # 1. Raw Keystroke forward
            if "tty" in raw_payload:
                raw_bytes = raw_payload["tty"]
                sess = self._ensure_session(session_id, None, on_event)
                sess["backend"].write(raw_bytes.encode("utf-8"))
                on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                return True

            # 2. Window Resize
            if raw_payload.get("action") == "resize":
                cols = raw_payload.get("cols", 80)
                rows = raw_payload.get("rows", 24)
                sess = self._ensure_session(session_id, None, on_event)
                sess["backend"].resize(cols, rows)
                print(f"    [🐚] Terminal Resized to {cols}x{rows}")
                on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                return True
        except Exception as pe:
            print(f"    [🐚] Transparent TTY Fail: {pe}")
        return False

    def _frame_command(self, tid, cmd):
        """Wrap ``cmd`` in TaskStart/TaskEnd fences and return the shell input.

        On Windows the framed command is spooled to a self-deleting ``.bat``
        file that prints bracket fences (``[[1337;...]]``) and the returned
        input is just the quoted path. Elsewhere the command is wrapped in
        ``echo -e -n`` calls that emit OSC 1337 fences (ESC ] ... BEL).
        """
        import platform

        if platform.system() == "Windows":
            # M8: Ultimate Windows Shell Boundary Method (File Spooling)
            # Bypasses Conhost VTP Redraw byte swallowing caused by line wrapping in PTY
            spool_dir = os.path.join(tempfile.gettempdir(), "cortex_pty_tasks")
            os.makedirs(spool_dir, exist_ok=True)
            task_path = os.path.join(spool_dir, f"{tid}.bat")

            # newline="" writes the explicit "\r\n" below verbatim; default
            # text mode on Windows would expand them to "\r\r\n".
            with open(task_path, "w", encoding="utf-8", newline="") as f:
                f.write("@echo off\r\n")
                f.write(f"echo [[1337;TaskStart;id={tid}]]\r\n")
                f.write(f"{cmd}\r\n")
                f.write(f"echo [[1337;TaskEnd;id={tid};exit=%errorlevel%]]\r\n")
                # optionally clean up itself
                f.write("del \"%~f0\"\r\n")

            return f"\"{task_path}\"\r\n"

        # M7: Protocol-Aware Command Framing (OSC 1337)
        # Format: ESC ] 1337 ; <Metadata> BEL, produced via octal escapes.
        start_marker = f"1337;TaskStart;id={tid}"
        end_marker = f"1337;TaskEnd;id={tid}"
        s_m = f"\\033]{start_marker}\\007"
        e_m = f"\\033]{end_marker};exit=$__ctx_exit\\007"
        return f"echo -e -n \"{s_m}\"; {cmd}; __ctx_exit=$?; echo -e -n \"{e_m}\"\n"

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Dispatches command string to the abstract terminal backend and WAITS for completion.

        Args:
            task: object exposing ``task_id``, ``session_id``, ``payload_json``
                (the command line), ``timeout_ms`` and ``trace_id``.
            sandbox: object exposing ``verify(cmd) -> (allowed, message)`` and
                a ``policy`` mapping (``WORKING_DIR_JAIL`` is read from it).
            on_complete: called as ``on_complete(task_id, result, trace_id)``
                exactly once with a dict holding ``status`` and ``stdout``
                and/or ``stderr``.
            on_event: optional live-output callback, see ``_ensure_session``.

        Result ``status`` is 0/1 from the command's exit code, or 2 for a
        sandbox violation, cancellation (``ABORTED``), ``TIMEOUT`` or an
        internal error. A command prefixed with ``!RAW:`` is written to the
        shell unframed and completes immediately with ``INJECTED`` - unless
        the task id starts with ``task-``, in which case the prefix is
        stripped and normal framing applies.
        """
        session_id = task.session_id or "default-session"
        tid = task.task_id
        try:
            cmd = task.payload_json

            allowed, status_msg = sandbox.verify(cmd)
            if not allowed:
                err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
                if on_event:
                    event = agent_pb2.SkillEvent(
                        session_id=session_id, task_id=tid,
                        terminal_out=err_msg
                    )
                    on_event(agent_pb2.ClientTaskMessage(skill_event=event))

                return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

            cwd = None
            if self.sync_mgr and task.session_id:
                cwd = self.sync_mgr.get_session_dir(task.session_id, create=True)
            elif sandbox.policy.get("WORKING_DIR_JAIL"):
                cwd = sandbox.policy["WORKING_DIR_JAIL"]
                # Uses the module-level ``os``. A function-scope ``import os``
                # further down used to make the name local to this function
                # and turned this branch into an UnboundLocalError.
                try:
                    os.makedirs(cwd, exist_ok=True)
                except OSError as mk_err:
                    # Best-effort: the backend reports a bad cwd itself.
                    print(f"    [🐚⚠️] Could not create working dir {cwd}: {mk_err}")

            sess = self._ensure_session(session_id, cwd, on_event)

            if cmd.startswith("!RAW:"):
                # M7 Fix: Agentic tasks (starting with 'task-') MUST use framing
                # to ensure results are captured. Forced bypass is only allowed for manual UI typing.
                if tid.startswith("task-"):
                    cmd = cmd[5:]
                else:
                    input_str = cmd[5:] + "\n"
                    print(f"    [🐚⌨️] RAW Input Injection: {input_str.strip()}")
                    sess["backend"].write(input_str.encode("utf-8"))
                    return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id)

            event = threading.Event()
            cancel_event = threading.Event()
            result_container = {"stdout": "", "status": 0}
            timed_out = False

            with self.lock:
                sess["active_task"] = tid
                sess["event"] = event
                sess["buffer_file"] = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False)
                sess["tail_buffer"] = ""
                sess["result"] = result_container
                sess["cancel_event"] = cancel_event

            try:
                full_input = self._frame_command(tid, cmd)
                sess["backend"].write(full_input.encode("utf-8"))

                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else self._DEFAULT_TIMEOUT_S
                start_time = time.time()
                while time.time() - start_time < timeout:
                    if event.is_set():
                        return on_complete(tid, result_container, task.trace_id)
                    if cancel_event.is_set():
                        print(f"    [🐚🛑] Task {tid} cancelled on node.")
                        return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id)
                    time.sleep(0.1)

                timed_out = True
                print(f"    [🐚⚠️] Task {tid} timed out on node.")
                partial_out = ""
                with self.lock:
                    capture = sess.get("buffer_file")
                    if capture:
                        try:
                            capture.seek(0, 2)
                            file_len = capture.tell()
                            HEAD, TAIL = 10_000, 30_000
                            if file_len > HEAD + TAIL:
                                # Too large to return whole: keep head and tail.
                                capture.seek(0)
                                head_str = capture.read(HEAD)
                                capture.seek(file_len - TAIL)
                                tail_str = capture.read()
                                omitted = file_len - HEAD - TAIL
                                partial_out = head_str + f"\n\n[... {omitted:,} bytes omitted (full timeout output saved to {capture.name}) ...]\n\n" + tail_str
                            else:
                                capture.seek(0)
                                partial_out = capture.read()
                        except Exception:
                            # Partial output is a courtesy; TIMEOUT is still reported.
                            partial_out = ""

                on_complete(tid, {"stdout": partial_out, "stderr": "TIMEOUT", "status": 2}, task.trace_id)

            finally:
                with self.lock:
                    # Only tear down the capture state if it is still ours.
                    if sess.get("active_task") == tid:
                        leftover = sess.get("buffer_file")
                        if leftover:
                            try:
                                leftover.close()
                            except Exception as close_err:
                                print(f"    [🐚⚠️] Could not close capture buffer for {tid}: {close_err}")
                            # Keep the file after a timeout (the message may
                            # cite its path); otherwise do not leak it.
                            if not timed_out:
                                self._remove_quietly(leftover.name)
                            sess["buffer_file"] = None
                        sess["active_task"] = None
                        sess["event"] = None
                        sess["result"] = None
                        sess["cancel_event"] = None

        except Exception as e:
            print(f"    [🐚❌] Execute Error for {tid}: {e}")
            on_complete(tid, {"stderr": str(e), "status": 2}, task.trace_id)

    def cancel(self, task_id: str):
        """Cancels an active task - for persistent shell, this sends a SIGINT (Ctrl+C).

        Every session whose ``active_task`` equals ``task_id`` receives a
        ``\\x03`` byte and has its ``cancel_event`` set so the waiting
        ``execute`` call reports ``ABORTED``.

        Returns:
            Always ``True``, whether or not a matching task was found.
        """
        with self.lock:
            for sid, sess in self.sessions.items():
                if sess.get("active_task") != task_id:
                    continue
                print(f"[🛑] Sending SIGINT (Ctrl+C) to shell session (Task {task_id}): {sid}")
                sess["backend"].write(b"\x03")
                cancel_event = sess.get("cancel_event")
                if cancel_event:
                    cancel_event.set()
        return True

    def shutdown(self):
        """Cleanup: Terminates all persistent shells via backends."""
        with self.lock:
            for sid, sess in list(self.sessions.items()):
                print(f"[🛑] Cleaning up persistent shell: {sid}")
                try:
                    sess["backend"].kill()
                except Exception as kill_err:
                    # Keep going: the remaining shells still need to be killed.
                    print(f"    [🐚⚠️] Failed to kill shell {sid}: {kill_err}")
            self.sessions.clear()

import os
import threading
import time
import tempfile