# cortex-hub / agent-node / src / agent_node / skills / shell_bridge.py
from agent_node.skills.base import BaseSkill
from agent_node.skills.terminal_backends import get_terminal_backend
from protos import agent_pb2
import re
import threading
import time
import tempfile
import os
from agent_node.core.regex_patterns import (
    COMPILED_PROMPT_PATTERNS,
    ANSI_ESCAPE,
    EXIT_CODE_PATTERN,
    ECHO_CLEANUP_ANSI,
    ECHO_CLEANUP_BRACKET,
    STRIP_START_FENCE,
    STRIP_BRACKET_FENCE,
    ECHO_START_PATTERN,
    ECHO_END_PATTERN,
    PROTOCOL_HINT_PATTERN
)

class ShellSkill(BaseSkill):
    """Admin Console Skill: persistent, stateful shell via an abstract terminal backend.

    One terminal (from ``get_terminal_backend()``) is kept alive per
    ``session_id``. A daemon reader thread per session streams terminal output
    to the client and, while a task is active, captures it to a temporary file.
    ``execute`` wraps each command in start/end fences -- OSC 1337 escape
    sequences on POSIX, ``[[1337;...]]`` bracket markers on Windows -- so the
    reader can locate the task's output and its exit code.

    Per-session dict (mutated under ``self.lock``):
        backend        terminal backend (spawn/read/write/resize/kill/is_alive)
        thread         the session's reader thread
        last_activity  ``time.time()`` of the last use; drives idle reaping
        buffer_file    temp file capturing the active task's output, or None
        tail_buffer    last ``TAIL_BUFFER_CHARS`` characters of the active task's output
        active_task    task_id currently executing, or None
        write_lock     serializes command executions on this session
        event / cancel_event / result
                       per-task completion signalling and result container
        stream_window_start / stream_bytes_sent / stream_dropped_bytes
                       live-stream throttling state
    """

    # Tunables (class attributes so subclasses can override them).
    REAPER_INTERVAL_S = 60         # how often the idle reaper wakes up
    IDLE_TIMEOUT_S = 600           # sessions idle longer than this are killed
    READ_CHUNK = 4096              # size requested per backend.read()
    TAIL_BUFFER_CHARS = 16384      # rolling window searched for the end fence
    STREAM_BUDGET_PER_S = 100_000  # max characters streamed to the UI per 1 s window
    DEFAULT_TIMEOUT_S = 60.0       # used when task.timeout_ms <= 0
    TIMEOUT_HEAD_BYTES = 10_000    # leading bytes of output returned on timeout
    TIMEOUT_TAIL_BYTES = 30_000    # trailing bytes of output returned on timeout

    # A leading "powershell"/"pwsh" executable name (optionally ".exe") as a whole word.
    _POWERSHELL_PREFIX = re.compile(r'^(?:powershell|pwsh)(?:\.exe)?(?=\s|$)', re.IGNORECASE)

    def __init__(self, sync_mgr=None):
        """Create the skill and start the idle-session reaper thread.

        Args:
            sync_mgr: Optional sync manager. When present, ``execute`` uses
                ``sync_mgr.get_session_dir(session_id, create=True)`` as the
                working directory for newly spawned shells.
        """
        self.sync_mgr = sync_mgr
        self.sessions = {}  # session_id -> session dict (see class docstring)
        self.lock = threading.Lock()

        # Prompt patterns are defined in core/regex_patterns.py.
        self.PROMPT_PATTERNS = COMPILED_PROMPT_PATTERNS

        # --- M7: Idle Session Reaper ---
        # Set by shutdown() so the reaper loop exits instead of running forever.
        self._stop_event = threading.Event()
        self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper")
        self.reaper_thread.start()

    def _session_reaper(self):
        """Background loop: kill sessions idle for longer than ``IDLE_TIMEOUT_S``.

        Sessions with an active task are never reaped. The loop exits once
        ``shutdown()`` sets ``self._stop_event``.
        """
        while not self._stop_event.wait(self.REAPER_INTERVAL_S):
            with self.lock:
                now = time.time()
                for sid, sess in list(self.sessions.items()):
                    if sess.get("active_task"):
                        continue
                    if now - sess.get("last_activity", 0) > self.IDLE_TIMEOUT_S:
                        print(f"    [๐Ÿš๐Ÿงน] Reaping idle shell session: {sid}")
                        self._kill_backend(sid, sess)
                        self.sessions.pop(sid, None)

    @staticmethod
    def _kill_backend(session_id, sess):
        """Kill a session's terminal on a best-effort basis; failures are logged, not raised."""
        try:
            sess["backend"].kill()
        except Exception as e:  # backend-specific errors; teardown must continue
            print(f"    [ShellSkill] Failed to kill shell session {session_id}: {e}")

    def _ensure_session(self, session_id, cwd, on_event):
        """Return the session for ``session_id``, spawning a new terminal if needed.

        Refreshes ``last_activity`` on an existing session.

        Args:
            session_id: Session key.
            cwd: Working directory for a newly spawned terminal (None = backend default).
            on_event: Streaming callback handed to the reader thread. It is only
                used when a new session is created.

        Returns:
            The session dict.
        """
        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                existing["last_activity"] = time.time()
                return existing

            print(f"    [๐Ÿš] Initializing Persistent Shell Session: {session_id}")
            backend = get_terminal_backend()
            backend.spawn(cwd=cwd, env=os.environ.copy())
            print(f"    [๐Ÿš] Terminal Spawned (PID Check: {backend.is_alive()})")

            sess = {
                "backend": backend,
                "last_activity": time.time(),
                "buffer_file": None,
                "tail_buffer": "",
                "active_task": None,
                "write_lock": threading.Lock(),
            }

            # The reader looks the session up under self.lock. Because we still
            # hold the lock, it blocks until the session is registered below.
            reader = threading.Thread(
                target=self._reader_loop,
                args=(session_id, on_event),
                daemon=True,
                name=f"ShellReader-{session_id}",
            )
            reader.start()
            sess["thread"] = reader

            self.sessions[session_id] = sess
            return sess

    def _reader_loop(self, session_id, on_event):
        """Reader thread body: feed terminal output to task capture and the UI stream.

        Runs until the backend raises EOFError/OSError, reports not-alive after
        an empty read, or an unexpected error occurs. It then unregisters the
        session, but only if the registry still holds *this* session (it may
        have been reaped and replaced under the same id).
        """
        import codecs  # stdlib; imported locally like json/platform elsewhere in this class

        with self.lock:
            sess = self.sessions.get(session_id)
        if not sess:
            return

        backend = sess["backend"]
        # Incremental decoding reassembles a multi-byte UTF-8 character split
        # across two reads instead of turning each half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            try:
                data = backend.read(self.READ_CHUNK)
                if not data:
                    if not backend.is_alive():
                        break
                    time.sleep(0.05)
                    continue

                decoded = data if isinstance(data, str) else decoder.decode(data)
                if not decoded:
                    continue  # only part of a character so far

                with self.lock:
                    active_tid = sess.get("active_task")
                    if active_tid and sess.get("buffer_file"):
                        self._process_protocol_fences(sess, active_tid, decoded)

                # Stream and Edge Intelligence
                if on_event:
                    self._handle_ui_streaming(sess, session_id, active_tid, decoded, on_event)

            except (EOFError, OSError):
                break
            except Exception as e:
                print(f"    [๐ŸšโŒ] Reader thread FATAL exception: {e}")
                break

        print(f"    [๐Ÿš] Shell Session Terminated: {session_id}")
        with self.lock:
            if self.sessions.get(session_id) is sess:
                del self.sessions[session_id]

    def _process_protocol_fences(self, sess, active_tid, decoded):
        """Capture task output and complete the task once its end fence arrives.

        Must be called with ``self.lock`` held. Appends ``decoded`` to the
        task's buffer file and rolling tail. Once the ANSI-stripped tail holds a
        complete end fence (prefix, exit code, and at least one following
        character), it stores the output between the fences in
        ``sess["result"]``, releases the buffer file and sets ``sess["event"]``.
        """
        start_fence = f"\x1b]1337;TaskStart;id={active_tid}\x07"
        end_fence_prefix = f"\x1b]1337;TaskEnd;id={active_tid};exit="
        bracket_start_fence = f"[[1337;TaskStart;id={active_tid}]]"
        bracket_end_fence_prefix = f"[[1337;TaskEnd;id={active_tid};exit="

        bf = sess["buffer_file"]
        bf.write(decoded)
        bf.flush()
        sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-self.TAIL_BUFFER_CHARS:]

        clean_tail = ANSI_ESCAPE.sub('', sess["tail_buffer"])

        # Bracket fences (Windows wrapper) take precedence over OSC fences.
        if bracket_end_fence_prefix in clean_tail:
            end_prefix, start = bracket_end_fence_prefix, bracket_start_fence
        elif end_fence_prefix in clean_tail:
            end_prefix, start = end_fence_prefix, start_fence
        else:
            return

        try:
            after_end = clean_tail.split(end_prefix, 1)[1]
            exit_match = EXIT_CODE_PATTERN.search(after_end)
            if exit_match is None or exit_match.end() >= len(after_end):
                # The fence may have been split across reads (e.g. only "exit="
                # or "exit=12" of "exit=127" so far). Wait for the fence
                # terminator or the next prompt instead of guessing the code.
                return
            exit_code = int(exit_match.group(1))

            bf.seek(0)
            clean_full_raw = ANSI_ESCAPE.sub('', bf.read())

            # Extract content between fences
            if start in clean_full_raw:
                content = clean_full_raw.split(start)[-1].split(end_prefix)[0]
            else:
                content = clean_full_raw.split(end_prefix)[0]

            # Strip the shell's echo of the framing command itself.
            content = ECHO_CLEANUP_ANSI.sub('', content)
            content = ECHO_CLEANUP_BRACKET.sub('', content).strip()

            sess["result"]["stdout"] = content
            sess["result"]["status"] = 0 if exit_code == 0 else 1
        except Exception as e:
            print(f"    [๐Ÿšโš ๏ธ] Protocol parsing failed: {e}")

        # Done (or failed irrecoverably): release the capture file so this task
        # is not re-parsed on every later chunk, then wake the waiting execute().
        self._close_buffer_file(sess)
        if sess.get("event"):
            sess["event"].set()

    @staticmethod
    def _close_buffer_file(sess):
        """Close the session's task capture file, if any, and clear its slot.

        Caller must hold ``self.lock``. Files created by ``execute`` use
        ``delete=True``, so closing also removes them. Close errors are logged.
        """
        bf = sess.get("buffer_file")
        sess["buffer_file"] = None
        if bf is None:
            return
        try:
            bf.close()
        except OSError as e:
            print(f"    [ShellSkill] Failed to close task buffer file: {e}")

    def _handle_ui_streaming(self, sess, session_id, active_tid, decoded, on_event):
        """Strip protocol plumbing from a chunk and stream what remains to the client."""
        # Clean framing echoes from the live stream
        for pattern in (ECHO_START_PATTERN, ECHO_END_PATTERN, STRIP_START_FENCE, STRIP_BRACKET_FENCE):
            decoded = pattern.sub('', decoded)

        # Line-aware stealthing: drop any line that still carries a protocol hint.
        stealth_out = "".join(
            line for line in decoded.splitlines(keepends=True)
            if not PROTOCOL_HINT_PATTERN.search(ANSI_ESCAPE.sub('', line))
        )

        if stealth_out.strip():
            with self.lock:
                self._apply_stream_throttling(sess, session_id, stealth_out, on_event)
                self._detect_edge_prompts(sess, session_id, active_tid, on_event)

    def _apply_stream_throttling(self, sess, session_id, stealth_out, on_event):
        """Limit live output to ``STREAM_BUDGET_PER_S`` characters per 1-second window.

        Chunks over budget are dropped from the live stream only. Task capture
        happens earlier in the reader loop and is unaffected. The dropped count
        is reported when the next window opens. Caller must hold ``self.lock``.
        """
        now = time.time()
        if now - sess.get("stream_window_start", 0) > 1.0:
            sess["stream_window_start"], sess["stream_bytes_sent"] = now, 0
            dropped = sess.get("stream_dropped_bytes", 0)
            if dropped > 0:
                drop_msg = f"\n[... {dropped:,} bytes truncated from live stream ...]\n"
                event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg)
                on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                sess["stream_dropped_bytes"] = 0

        if sess.get("stream_bytes_sent", 0) + len(stealth_out) > self.STREAM_BUDGET_PER_S:
            sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out)
        else:
            sess["stream_bytes_sent"] = sess.get("stream_bytes_sent", 0) + len(stealth_out)
            event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=stealth_out)
            on_event(agent_pb2.ClientTaskMessage(skill_event=event))

    def _detect_edge_prompts(self, sess, session_id, active_tid, on_event):
        """Send a prompt hint to the client when the task's output tail matches a prompt pattern.

        Only applies while a task is running and not yet completed. The last 100
        characters of the tail are checked against ``self.PROMPT_PATTERNS``.
        On a match, the last 20 characters (stripped) are sent as the hint.
        Caller must hold ``self.lock``.
        """
        current_event = sess.get("event")
        if active_tid and current_event and not current_event.is_set():
            tail = sess["tail_buffer"][-100:]
            for pattern in self.PROMPT_PATTERNS:
                if pattern.search(tail):
                    p_hint = tail[-20:].strip()
                    prompt_event = agent_pb2.SkillEvent(session_id=session_id, task_id=active_tid, prompt=p_hint)
                    on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event))
                    break

    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Handle raw-keystroke and resize payloads synchronously.

        Recognised ``task.payload_json`` objects:
            ``{"tty": "<text>"}``                        -- written verbatim to the terminal
            ``{"action": "resize", "cols": N, "rows": M}`` -- resizes the terminal (defaults 80x24)

        Returns:
            True if the payload was handled (``on_complete`` has been called).
            False otherwise, including on any error, which is logged.
        """
        import json

        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            if not (cmd.startswith('{') and cmd.endswith('}')):
                return False
            raw_payload = json.loads(cmd)
            if not isinstance(raw_payload, dict):
                return False

            # 1. Raw Keystroke forward
            if "tty" in raw_payload:
                sess = self._ensure_session(session_id, None, on_event)
                sess["backend"].write(raw_payload["tty"].encode("utf-8"))
                on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                return True

            # 2. Window Resize
            if raw_payload.get("action") == "resize":
                cols, rows = raw_payload.get("cols", 80), raw_payload.get("rows", 24)
                sess = self._ensure_session(session_id, None, on_event)
                sess["backend"].resize(cols, rows)
                on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                return True
        except Exception as pe:
            print(f"    [๐Ÿš] Transparent TTY Fail: {pe}")
        return False

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Run ``task.payload_json`` in the session's shell and wait for its framed result.

        Steps:
            1. Sandbox check.
            2. Ensure the session exists.
            3. Optional ``!RAW:`` handling.
            4. Write the fenced command.
            5. Wait until the reader completes the task, the task is cancelled,
               or the timeout (``task.timeout_ms``, else ``DEFAULT_TIMEOUT_S``)
               elapses.

        ``!RAW:<text>``: for task ids not starting with ``task-``, the text is
        written straight to the terminal and the task completes at once with
        "INJECTED". Otherwise the prefix is stripped and the rest runs framed.

        ``on_complete(task_id, result, trace_id)`` receives one of:
            {"stdout": ..., "status": 0 | 1}                       completed (1 = non-zero exit)
            {"stderr": "SANDBOX_VIOLATION: ...", "status": 2}      rejected by the sandbox
            {"stderr": "ABORTED", "status": 2}                     cancelled
            {"stdout": <head/tail>, "stderr": "TIMEOUT", "status": 2}  timed out

        Returns ``on_complete``'s return value on the early-exit paths, else None.
        """
        session_id = task.session_id or "default-session"
        tid = task.task_id
        cmd = task.payload_json

        allowed, status_msg = sandbox.verify(cmd)
        if not allowed:
            err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
            if on_event:
                blocked = agent_pb2.SkillEvent(session_id=session_id, task_id=tid, terminal_out=err_msg)
                on_event(agent_pb2.ClientTaskMessage(skill_event=blocked))
            return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

        cwd = self.sync_mgr.get_session_dir(task.session_id, create=True) if self.sync_mgr and task.session_id else None
        sess = self._ensure_session(session_id, cwd, on_event)

        with sess["write_lock"]:
            if cmd.startswith("!RAW:"):
                if not tid.startswith("task-"):
                    sess["backend"].write((cmd[5:] + "\n").encode("utf-8"))
                    return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id)
                cmd = cmd[5:]

            event, cancel_event = threading.Event(), threading.Event()
            result_container = {"stdout": "", "status": 0}
            # delete=True: the capture file is removed when closed. It is only
            # ever accessed through this handle, so reopening by name (which
            # Windows forbids for such files) is never needed.
            buffer_file = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_")

            with self.lock:
                sess["active_task"] = tid
                sess["event"] = event
                sess["buffer_file"] = buffer_file
                sess["tail_buffer"] = ""
                sess["result"] = result_container
                sess["cancel_event"] = cancel_event

            try:
                full_input = self._build_framed_command(tid, cmd)
                sess["backend"].write(full_input.encode("utf-8"))

                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else self.DEFAULT_TIMEOUT_S
                deadline = time.monotonic()  + timeout  # monotonic: immune to wall-clock jumps
                while True:
                    if event.is_set():
                        return on_complete(tid, result_container, task.trace_id)
                    if cancel_event.is_set():
                        return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id)
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # Returns as soon as the reader completes the task; the
                    # cancel flag is polled at this granularity.
                    event.wait(min(0.1, remaining))

                on_complete(tid, {"stdout": self._get_timeout_output(sess), "stderr": "TIMEOUT", "status": 2}, task.trace_id)
            finally:
                self._cleanup_task_state(sess, tid, event, cancel_event)

    def _build_framed_command(self, tid, cmd):
        """Return the terminal input that runs ``cmd`` between TaskStart/TaskEnd fences.

        POSIX: one line that prints the OSC 1337 TaskStart fence, runs ``cmd``,
        then prints the TaskEnd fence carrying ``$?``.

        Windows: writes ``<tempdir>/cortex_pty_tasks/<tid>.bat``, which prints
        bracket fences around ``cmd /c "<cmd>"``, and returns a ``call`` line
        for it. PowerShell invocations get ``-NoProfile -NonInteractive``
        unless ``-noprofile`` is already present.
        """
        import platform

        # A trailing newline would end the shell line early and turn the
        # following "; __ctx_exit=..." into a syntax error, so the end fence
        # would never be printed.
        cmd = cmd.rstrip("\r\n")

        if platform.system() == "Windows":
            # M7: Enhanced PowerShell stability - auto-inject NoProfile and NonInteractive
            if self._POWERSHELL_PREFIX.match(cmd) and "-noprofile" not in cmd.lower():
                cmd = self._POWERSHELL_PREFIX.sub(r'\g<0> -NoProfile -NonInteractive', cmd, count=1)

            spool_dir = os.path.join(tempfile.gettempdir(), "cortex_pty_tasks")
            os.makedirs(spool_dir, exist_ok=True)
            # NOTE(review): these .bat files are never removed; deleting one
            # while cmd.exe may still be reading it is unsafe, so cleanup is
            # left for a separate change.
            task_path = os.path.join(spool_dir, f"{tid}.bat")

            # Use a robust wrapper that ensures the TaskEnd fence is ALWAYS printed.
            # newline="" writes the explicit "\r\n" endings verbatim instead of
            # letting text mode translate them into "\r\r\n".
            with open(task_path, "w", encoding="utf-8", newline="") as f:
                f.write(f"@echo off\r\n"
                        f"echo [[1337;TaskStart;id={tid}]]\r\n"
                        f"rem Execute command and capture exit code\r\n"
                        f"cmd /c \"{cmd}\"\r\n"
                        f"set __ctx_err=%errorlevel%\r\n"
                        f"echo [[1337;TaskEnd;id={tid};exit=%__ctx_err%]]\r\n"
                        f"exit /b %__ctx_err%\r\n")
            return f"call \"{task_path}\"\r\n"

        return f"echo -e -n \"\\033]1337;TaskStart;id={tid}\\007\"; {cmd}; __ctx_exit=$?; echo -e -n \"\\033]1337;TaskEnd;id={tid};exit=$__ctx_exit\\007\"\n"

    def _get_timeout_output(self, sess):
        """Return the active task's captured output, trimmed to head + tail if large.

        Works on the underlying byte buffer. Seeking a text-mode file to an
        arbitrary offset is undefined, and a cut through a multi-byte character
        would otherwise raise and lose all output. Returns "" when there is no
        capture file or it cannot be read.
        """
        head_n, tail_n = self.TIMEOUT_HEAD_BYTES, self.TIMEOUT_TAIL_BYTES
        try:
            with self.lock:
                bf = sess.get("buffer_file")
                if not bf:
                    return ""
                bf.flush()  # push pending text-layer writes into the byte buffer
                raw = bf.buffer
                size = raw.seek(0, os.SEEK_END)
                raw.seek(0)
                if size <= head_n + tail_n:
                    return raw.read().decode("utf-8", errors="replace")
                head = raw.read(head_n)
                raw.seek(size - tail_n)
                tail = raw.read()  # leaves the position at EOF for further appends
        except (OSError, ValueError) as e:
            print(f"    [ShellSkill] Could not read timeout output: {e}")
            return ""
        return (head.decode("utf-8", errors="replace")
                + f"\n\n[... {size - head_n - tail_n:,} bytes omitted ...] \n\n"
                + tail.decode("utf-8", errors="replace"))

    def _cleanup_task_state(self, sess, tid, event, cancel_event):
        """Reset per-task session state after completion, cancel, timeout or error.

        This is a no-op unless ``tid`` is still the session's active task.
        Otherwise it:
            - closes the capture file,
            - clears ``active_task`` and this task's event objects,
            - refreshes ``last_activity`` so a long task is not reaped right
              after it finishes.
        """
        with self.lock:
            if sess.get("active_task") != tid:
                return
            self._close_buffer_file(sess)
            sess["active_task"] = None
            sess["last_activity"] = time.time()
            if sess.get("event") is event:
                sess["event"] = None
            if sess.get("cancel_event") is cancel_event:
                sess["cancel_event"] = None

    def cancel(self, task_id: str):
        """Cancel an active task by sending Ctrl+C (0x03) to its shell.

        The shell session itself is kept. ``execute`` reports ABORTED once it
        sees the cancel event. Always returns True.
        """
        with self.lock:
            for sid, sess in self.sessions.items():
                if sess.get("active_task") != task_id:
                    continue
                try:
                    sess["backend"].write(b"\x03")
                except OSError as e:
                    print(f"    [ShellSkill] Failed to send Ctrl+C to session {sid}: {e}")
                if sess.get("cancel_event"):
                    sess["cancel_event"].set()
        return True

    def shutdown(self):
        """Stop the reaper and terminate all persistent shells (best effort)."""
        self._stop_event.set()
        with self.lock:
            for sid, sess in self.sessions.items():
                self._kill_backend(sid, sess)
            self.sessions.clear()