from agent_node.skills.base import BaseSkill
from agent_node.skills.terminal_backends import get_terminal_backend
from protos import agent_pb2
import re
import threading
import time
import tempfile
import os
from agent_node.core.regex_patterns import (
COMPILED_PROMPT_PATTERNS,
ANSI_ESCAPE,
EXIT_CODE_PATTERN,
ECHO_CLEANUP_ANSI,
ECHO_CLEANUP_BRACKET,
STRIP_START_FENCE,
STRIP_BRACKET_FENCE,
ECHO_START_PATTERN,
ECHO_END_PATTERN,
PROTOCOL_HINT_PATTERN
)
class ShellSkill(BaseSkill):
    """Admin Console Skill: persistent, stateful shell sessions via an abstract terminal backend.

    Each ``session_id`` maps to one long-lived terminal (created by
    ``get_terminal_backend()``) plus a reader thread that streams its output.
    Commands sent through :meth:`execute` are wrapped in TaskStart/TaskEnd
    fences (OSC 1337 escapes on POSIX, bracketed ``[[1337;...]]`` markers on
    Windows) so the reader thread can find where a command's output starts and
    ends and recover its exit code.
    """

    # Sessions with no active task whose last activity is older than this are reaped.
    _IDLE_TIMEOUT_S = 600
    # How often the reaper thread wakes up to look for idle sessions.
    _REAP_INTERVAL_S = 60
    # Leading `powershell` / `pwsh` invocation (optionally with `.exe`), used on Windows.
    _POWERSHELL_CMD = re.compile(r'^(powershell|pwsh)(\.exe)?', re.IGNORECASE)

    def __init__(self, sync_mgr=None):
        self.sync_mgr = sync_mgr
        # session_id -> {backend, thread, last_activity, buffer_file, tail_buffer,
        #                active_task, write_lock, event, result, cancel_event, stream_*}
        self.sessions = {}
        # Guards self.sessions and the per-session task/stream state shared with reader threads.
        self.lock = threading.Lock()
        # Patterns live in core/regex_patterns.py.
        self.PROMPT_PATTERNS = COMPILED_PROMPT_PATTERNS
        # --- M7: Idle Session Reaper ---
        self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper")
        self.reaper_thread.start()

    def _session_reaper(self):
        """Background loop that kills and forgets idle sessions.

        Every ``_REAP_INTERVAL_S`` seconds, each session without an active task
        whose ``last_activity`` is more than ``_IDLE_TIMEOUT_S`` seconds old has
        its backend killed and is removed from ``self.sessions``.
        """
        while True:
            time.sleep(self._REAP_INTERVAL_S)
            with self.lock:
                now = time.time()
                for sid, sess in list(self.sessions.items()):
                    if sess.get("active_task"):
                        continue  # never reap a session with a task in flight
                    if now - sess.get("last_activity", 0) > self._IDLE_TIMEOUT_S:
                        print(f" [๐๐งน] Reaping idle shell session: {sid}")
                        try:
                            sess["backend"].kill()
                        except Exception:
                            pass  # best effort: the backend may already be dead
                        self.sessions.pop(sid, None)

    def _ensure_session(self, session_id, cwd, on_event):
        """Return the session for ``session_id``, spawning a terminal and reader thread if needed.

        Args:
            session_id: Key into ``self.sessions``.
            cwd: Working directory for a newly spawned terminal. Ignored when the
                session already exists.
            on_event: Streaming callback bound to the reader thread when the
                session is created. Callbacks passed on later calls are not used
                for streaming.

        Returns:
            The session dict. Its ``last_activity`` is refreshed if it already existed.
        """
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]["last_activity"] = time.time()
                return self.sessions[session_id]

            print(f" [๐] Initializing Persistent Shell Session: {session_id}")
            backend = get_terminal_backend()
            backend.spawn(cwd=cwd, env=os.environ.copy())
            print(f" [๐] Terminal Spawned (PID Check: {backend.is_alive()})")

            sess = {
                "backend": backend,
                "last_activity": time.time(),
                "buffer_file": None,
                "tail_buffer": "",
                "active_task": None,
                "write_lock": threading.Lock(),
            }
            # Register before the reader starts so it always finds its own session.
            self.sessions[session_id] = sess
            reader = threading.Thread(
                target=self._reader_loop,
                args=(session_id, on_event),
                daemon=True,
                name=f"ShellReader-{session_id}",
            )
            sess["thread"] = reader
            reader.start()
            return sess

    def _reader_loop(self, session_id, on_event):
        """Reader-thread body: pump terminal output into task framing and UI streaming.

        The loop runs until the backend reports EOF/OSError, or until a read
        returns nothing while the backend is no longer alive. On exit the session
        is removed from ``self.sessions``, but only if it is still the registered
        session for ``session_id``. A replacement created after a reap is left alone.
        """
        import codecs

        with self.lock:
            sess = self.sessions.get(session_id)
        if not sess:
            return
        backend = sess["backend"]
        # Incremental decoding keeps multi-byte UTF-8 characters that straddle
        # two reads intact, instead of turning each half into U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

        while True:
            try:
                data = backend.read(4096)
                if not data:
                    if not backend.is_alive():
                        break
                    time.sleep(0.05)
                    continue
                decoded = data if isinstance(data, str) else decoder.decode(data)
                if not decoded:
                    continue  # only a partial multi-byte sequence so far; wait for the rest

                with self.lock:
                    active_tid = sess.get("active_task")
                    if active_tid and sess.get("buffer_file"):
                        self._process_protocol_fences(sess, active_tid, decoded)

                # Stream and Edge Intelligence (takes self.lock itself).
                if on_event:
                    self._handle_ui_streaming(sess, session_id, active_tid, decoded, on_event)
            except (EOFError, OSError):
                break
            except Exception as e:
                print(f" [๐โ] Reader thread FATAL exception: {e}")
                break

        print(f" [๐] Shell Session Terminated: {session_id}")
        with self.lock:
            if self.sessions.get(session_id) is sess:
                self.sessions.pop(session_id, None)

    @staticmethod
    def _discard_buffer_file(sess):
        """Close the session's task spool file, if any, delete it from disk, and clear the slot."""
        spool = sess.get("buffer_file")
        sess["buffer_file"] = None
        if spool is None:
            return
        path = getattr(spool, "name", None)
        try:
            spool.close()
        except OSError:
            pass  # best effort: nothing useful to do if the final flush fails
        if path:
            try:
                os.unlink(path)
            except OSError:
                pass  # already gone, or not removable; do not fail the task over it

    def _process_protocol_fences(self, sess, active_tid, decoded):
        """Spool output for the active task and finish the task once its end fence appears.

        Must be called with ``self.lock`` held. ``decoded`` is appended to the
        task's spool file and to ``tail_buffer``, which keeps the last 16384
        characters. Once the ANSI-stripped tail contains this task's TaskEnd
        prefix (OSC or bracketed form), the function:

        - writes the text between the start and end fences into
          ``sess["result"]["stdout"]``;
        - sets ``status`` to 0 for exit code 0 and to 1 otherwise (2 if parsing fails);
        - discards the spool file and sets ``sess["event"]``.
        """
        start_fence = f"\x1b]1337;TaskStart;id={active_tid}\x07"
        end_fence_prefix = f"\x1b]1337;TaskEnd;id={active_tid};exit="
        bracket_start_fence = f"[[1337;TaskStart;id={active_tid}]]"
        bracket_end_fence_prefix = f"[[1337;TaskEnd;id={active_tid};exit="

        spool = sess["buffer_file"]
        spool.write(decoded)
        spool.flush()
        sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-16384:]
        clean_tail = ANSI_ESCAPE.sub('', sess["tail_buffer"])

        is_bracket = bracket_end_fence_prefix in clean_tail
        if not is_bracket and end_fence_prefix not in clean_tail:
            return  # task still running

        try:
            active_end_prefix = bracket_end_fence_prefix if is_bracket else end_fence_prefix
            active_start_fence = bracket_start_fence if is_bracket else start_fence

            # NOTE(review): this parses as soon as the prefix is seen. If the exit
            # code's digits are split across reads, EXIT_CODE_PATTERN may see only
            # part of them. Confirm the pattern requires a terminator.
            after_end = clean_tail.split(active_end_prefix)[1]
            exit_match = EXIT_CODE_PATTERN.search(after_end)
            exit_code = int(exit_match.group(1)) if exit_match else 0

            spool.seek(0)
            clean_full_raw = ANSI_ESCAPE.sub('', spool.read())
            # Extract content between fences (last start fence wins).
            if active_start_fence in clean_full_raw:
                content = clean_full_raw.split(active_start_fence)[-1].split(active_end_prefix)[0]
            else:
                content = clean_full_raw.split(active_end_prefix)[0]
            # Clean up the echoed framing command.
            content = ECHO_CLEANUP_ANSI.sub('', content)
            content = ECHO_CLEANUP_BRACKET.sub('', content).strip()

            sess["result"]["stdout"] = content
            sess["result"]["status"] = 0 if exit_code == 0 else 1
        except Exception as e:
            print(f" [๐โ ๏ธ] Protocol parsing failed: {e}")
            # Report the failure instead of a silent, empty "success".
            result = sess.get("result")
            if result is not None:
                result["stderr"] = f"PROTOCOL_PARSE_ERROR: {e}"
                result["status"] = 2
        finally:
            # Clearing buffer_file also stops the reader from re-parsing this task.
            self._discard_buffer_file(sess)
            if sess.get("event"):
                sess["event"].set()

    def _handle_ui_streaming(self, sess, session_id, active_tid, decoded, on_event):
        """Strip framing/plumbing text from ``decoded`` and stream what remains to the client.

        The ``ECHO_*`` and ``STRIP_*`` patterns are removed first. Any line whose
        ANSI-stripped text matches ``PROTOCOL_HINT_PATTERN`` is then dropped. If
        anything non-blank is left, it goes through throttling and prompt detection.

        NOTE(review): ``on_event`` is invoked while ``self.lock`` is held. A
        callback that re-enters this skill (e.g. ``cancel()``) would deadlock.
        """
        decoded = ECHO_START_PATTERN.sub('', decoded)
        decoded = ECHO_END_PATTERN.sub('', decoded)
        decoded = STRIP_START_FENCE.sub('', decoded)
        decoded = STRIP_BRACKET_FENCE.sub('', decoded)

        # Line-aware stealthing for extra safety.
        lines = decoded.splitlines(keepends=True)
        clean_lines = [line for line in lines if not PROTOCOL_HINT_PATTERN.search(ANSI_ESCAPE.sub('', line))]
        stealth_out = "".join(clean_lines)

        if stealth_out.strip():
            with self.lock:
                self._apply_stream_throttling(sess, session_id, stealth_out, on_event)
                self._detect_edge_prompts(sess, session_id, active_tid, on_event)

    def _apply_stream_throttling(self, sess, session_id, stealth_out, on_event):
        """Forward ``stealth_out`` to the client, capped at 100,000 characters per 1-second window.

        Must be called with ``self.lock`` held. Output over the cap is dropped
        and counted. When the next window opens, a single truncation notice
        reports how much was dropped.
        """
        now = time.time()
        if now - sess.get("stream_window_start", 0) > 1.0:
            sess["stream_window_start"], sess["stream_bytes_sent"] = now, 0
            if sess.get("stream_dropped_bytes", 0) > 0:
                drop_msg = f"\n[... {sess['stream_dropped_bytes']:,} bytes truncated from live stream ...]\n"
                event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg)
                on_event(agent_pb2.ClientTaskMessage(skill_event=event))
                sess["stream_dropped_bytes"] = 0

        if sess.get("stream_bytes_sent", 0) + len(stealth_out) > 100_000:
            sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out)
        else:
            sess["stream_bytes_sent"] += len(stealth_out)
            event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=stealth_out)
            on_event(agent_pb2.ClientTaskMessage(skill_event=event))

    def _detect_edge_prompts(self, sess, session_id, active_tid, on_event):
        """Emit a ``prompt`` SkillEvent when the recent output matches a prompt pattern.

        Must be called with ``self.lock`` held. Only runs while a task is active
        and its completion event is not yet set. The last 100 characters of
        ``tail_buffer`` are checked against ``PROMPT_PATTERNS`` (presumably
        things like ``login:`` / ``password:``; see core/regex_patterns.py).
        On the first match, the last 20 characters are sent as the hint.
        """
        current_event = sess.get("event")
        if active_tid and current_event and not current_event.is_set():
            tail = sess["tail_buffer"][-100:]
            for pattern in self.PROMPT_PATTERNS:
                if pattern.search(tail):
                    p_hint = tail[-20:].strip()
                    prompt_event = agent_pb2.SkillEvent(session_id=session_id, task_id=active_tid, prompt=p_hint)
                    on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event))
                    break

    def handle_transparent_tty(self, task, on_complete, on_event=None):
        """Handle raw keystroke and resize payloads synchronously.

        ``task.payload_json`` is recognized when it is a JSON object of one of
        these forms:

        - ``{"tty": "<text>"}``: the text is written straight to the terminal;
        - ``{"action": "resize", "cols": N, "rows": M}``: the terminal is resized
          (defaults 80x24).

        Returns:
            True if the payload was handled and ``on_complete`` was called.
            False otherwise, including when handling raised.
        """
        import json

        cmd = task.payload_json
        session_id = task.session_id or "default-session"
        try:
            if cmd.startswith('{') and cmd.endswith('}'):
                raw_payload = json.loads(cmd)

                # 1. Raw keystroke forward. write_lock is not taken here, so
                # keystrokes can reach a command that execute() is still running.
                if isinstance(raw_payload, dict) and "tty" in raw_payload:
                    sess = self._ensure_session(session_id, None, on_event)
                    sess["backend"].write(raw_payload["tty"].encode("utf-8"))
                    on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id)
                    return True

                # 2. Window resize.
                if isinstance(raw_payload, dict) and raw_payload.get("action") == "resize":
                    cols, rows = raw_payload.get("cols", 80), raw_payload.get("rows", 24)
                    sess = self._ensure_session(session_id, None, on_event)
                    sess["backend"].resize(cols, rows)
                    on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id)
                    return True
        except Exception as pe:
            print(f" [๐] Transparent TTY Fail: {pe}")
        return False

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Run ``task.payload_json`` in the session's terminal and wait for its framed result.

        1. The command is checked with ``sandbox.verify``. A blocked command is
           reported with status 2.
        2. A ``!RAW:`` prefix is handled specially. If the task id does not start
           with ``task-``, the rest is written to the terminal unframed and
           reported as ``INJECTED``. Otherwise the prefix is stripped and the
           command runs framed.
        3. The framed command is written, and the call polls until the reader
           thread completes the task, ``cancel()`` aborts it, or
           ``task.timeout_ms`` (default 60 s) elapses. On timeout, the captured
           output is returned with status 2.

        The whole sequence holds the session's ``write_lock``, so framed tasks in
        one session run one at a time.
        """
        session_id = task.session_id or "default-session"
        tid = task.task_id
        cmd = task.payload_json

        allowed, status_msg = sandbox.verify(cmd)
        if not allowed:
            err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n"
            if on_event:
                on_event(agent_pb2.ClientTaskMessage(skill_event=agent_pb2.SkillEvent(session_id=session_id, task_id=tid, terminal_out=err_msg)))
            return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

        cwd = self.sync_mgr.get_session_dir(task.session_id, create=True) if self.sync_mgr and task.session_id else None
        sess = self._ensure_session(session_id, cwd, on_event)

        with sess["write_lock"]:
            if cmd.startswith("!RAW:"):
                if not tid.startswith("task-"):
                    sess["backend"].write((cmd[5:] + "\n").encode("utf-8"))
                    return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id)
                cmd = cmd[5:]

            event, cancel_event = threading.Event(), threading.Event()
            result_container = {"stdout": "", "status": 0}
            with self.lock:
                sess["active_task"] = tid
                sess["event"] = event
                sess["buffer_file"] = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False)
                sess["tail_buffer"] = ""
                sess["result"] = result_container
                sess["cancel_event"] = cancel_event

            try:
                full_input = self._build_framed_command(tid, cmd)
                sess["backend"].write(full_input.encode("utf-8"))

                timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0
                # Monotonic clock: wall-clock jumps must not stretch or cut the timeout.
                deadline = time.monotonic() + timeout
                while time.monotonic() < deadline:
                    if event.is_set():
                        return on_complete(tid, result_container, task.trace_id)
                    if cancel_event.is_set():
                        return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id)
                    event.wait(0.1)  # wakes early as soon as the reader completes the task
                return on_complete(tid, {"stdout": self._get_timeout_output(sess), "stderr": "TIMEOUT", "status": 2}, task.trace_id)
            finally:
                self._cleanup_task_state(sess, tid, event, cancel_event)

    def _build_framed_command(self, tid, cmd):
        """Return the text to type into the terminal so that ``cmd`` runs between task fences.

        - On Windows: a ``.bat`` wrapper is written to
          ``<tempdir>/cortex_pty_tasks/<tid>.bat``. It echoes bracketed
          TaskStart/TaskEnd markers around ``cmd /c "<cmd>"`` and includes the
          exit code. The return value is a ``call`` of that file. Leading
          ``powershell`` / ``pwsh`` commands get ``-NoProfile -NonInteractive``
          unless ``-noprofile`` is already present.
        - Elsewhere: a one-line shell command that prints the OSC 1337 fences
          with ``echo -e -n`` around ``cmd`` and reports ``$?``.
        """
        import platform

        if platform.system() == "Windows":
            # M7: Enhanced PowerShell stability - auto-inject NoProfile and NonInteractive
            if self._POWERSHELL_CMD.match(cmd) and "-noprofile" not in cmd.lower():
                cmd = self._POWERSHELL_CMD.sub(r'\1 -NoProfile -NonInteractive', cmd)

            spool_dir = os.path.join(tempfile.gettempdir(), "cortex_pty_tasks")
            os.makedirs(spool_dir, exist_ok=True)
            task_path = os.path.join(spool_dir, f"{tid}.bat")
            # The wrapper always prints the TaskEnd fence, even when the command fails.
            # newline="" keeps the explicit \r\n from being translated to \r\r\n.
            with open(task_path, "w", encoding="utf-8", newline="") as f:
                f.write(f"@echo off\r\n"
                        f"echo [[1337;TaskStart;id={tid}]]\r\n"
                        f"rem Execute command and capture exit code\r\n"
                        f"cmd /c \"{cmd}\"\r\n"
                        f"set __ctx_err=%errorlevel%\r\n"
                        f"echo [[1337;TaskEnd;id={tid};exit=%__ctx_err%]]\r\n"
                        f"exit /b %__ctx_err%\r\n")
            return f"call \"{task_path}\"\r\n"

        return f"echo -e -n \"\\033]1337;TaskStart;id={tid}\\007\"; {cmd}; __ctx_exit=$?; echo -e -n \"\\033]1337;TaskEnd;id={tid};exit=$__ctx_exit\\007\"\n"

    def _get_timeout_output(self, sess):
        """Return the active task's captured output, abbreviated to head + tail if large.

        Output of up to 40,000 bytes is returned whole. Larger output yields the
        first 10,000 bytes, an omission marker, and the last 30,000 bytes. The
        spool file is re-read in binary mode and decoded with
        ``errors="replace"``: a text-mode seek to an arbitrary byte offset is
        undefined and can land inside a multi-byte character.

        Returns:
            The (possibly abbreviated) output, or "" if there is no spool file
            or it cannot be read.
        """
        head_len, tail_len = 10_000, 30_000
        try:
            with self.lock:
                spool = sess.get("buffer_file")
                if spool is None:
                    return ""
                spool.flush()
                with open(spool.name, "rb") as raw:
                    size = raw.seek(0, os.SEEK_END)
                    raw.seek(0)
                    if size <= head_len + tail_len:
                        return raw.read().decode("utf-8", errors="replace")
                    head = raw.read(head_len)
                    raw.seek(size - tail_len)
                    tail = raw.read()
        except Exception:
            return ""  # best effort: the timeout itself must still be reported
        return (
            head.decode("utf-8", errors="replace")
            + f"\n\n[... {size - head_len - tail_len:,} bytes omitted ...] \n\n"
            + tail.decode("utf-8", errors="replace")
        )

    def _cleanup_task_state(self, sess, tid, event, cancel_event):
        """Reset per-task session state after ``execute`` finishes, whether it succeeded or failed.

        The spool file is closed and deleted, and ``active_task`` is cleared if
        it still refers to ``tid``. ``event`` / ``cancel_event`` are cleared only
        if they are still this task's objects. ``last_activity`` is refreshed so
        that a long-running task does not leave the session immediately
        eligible for reaping.
        """
        with self.lock:
            if sess.get("active_task") == tid:
                self._discard_buffer_file(sess)
                sess["active_task"] = None
            if sess.get("event") is event:
                sess["event"] = None
            if sess.get("cancel_event") is cancel_event:
                sess["cancel_event"] = None
            sess["last_activity"] = time.time()

    def cancel(self, task_id: str) -> bool:
        """Cancel an active task. For a persistent shell this means sending Ctrl+C.

        Writes ``\\x03`` to the terminal running ``task_id``; the terminal
        typically delivers it as SIGINT. It also sets the task's cancel event,
        so :meth:`execute` returns ``ABORTED``.

        Returns:
            True if a session running ``task_id`` was found, otherwise False.
        """
        with self.lock:
            for sess in self.sessions.values():
                if sess.get("active_task") == task_id:
                    sess["backend"].write(b"\x03")
                    if sess.get("cancel_event"):
                        sess["cancel_event"].set()
                    return True
        return False

    def shutdown(self):
        """Kill every session's terminal backend and forget all sessions."""
        with self.lock:
            for sess in list(self.sessions.values()):
                try:
                    sess["backend"].kill()
                except Exception:
                    pass  # best effort: keep shutting down the remaining sessions
            self.sessions.clear()