diff --git a/bin/dmshell_client.py b/bin/dmshell_client.py new file mode 100644 index 00000000000..816a60b0937 --- /dev/null +++ b/bin/dmshell_client.py @@ -0,0 +1,937 @@ +#!/usr/bin/env python3 + +import argparse +import os +import queue +import random +import shutil +import socket +import select +import subprocess +import sys +import tempfile +import termios +import threading +import time +import tty +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional, TextIO + + +START1 = 0x94 +START2 = 0xC3 +HEADER_LEN = 4 +DEFAULT_API_PORT = 4403 +DEFAULT_HOP_LIMIT = 0 +LOCAL_ESCAPE_BYTE = b"\x1d" # Ctrl+] +MISSING_SEQ_RETRY_INTERVAL_SEC = 1.0 +INPUT_BATCH_WINDOW_SEC = .5 +INPUT_BATCH_MAX_BYTES = 64 +HEARTBEAT_IDLE_DELAY_SEC = 5.0 +HEARTBEAT_REPEAT_SEC = 15.0 +HEARTBEAT_POLL_INTERVAL_SEC = 0.25 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Tiny DMShell client for Meshtastic native TCP API", + epilog=( + "Examples:\n" + " bin/dmshell_client.py --to !170896f7\n" + " bin/dmshell_client.py --to 0x170896f7 --command 'uname -a' --command 'id'" + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--host", default="127.0.0.1", help="meshtasticd API host") + parser.add_argument("--port", type=int, default=DEFAULT_API_PORT, help="meshtasticd API port") + parser.add_argument( + "--serial", + nargs="?", + const="auto", + default=None, + help="use USB serial transport (optionally provide device path, default: auto-detect)", + ) + parser.add_argument("--baud", type=int, default=115200, help="serial baud rate when using --serial") + parser.add_argument("--to", required=True, help="destination node number, e.g. !170896f7 or 0x170896f7") + parser.add_argument("--channel", type=int, default=0, help="channel index to use") + parser.add_argument("--cols", type=int, default=None, help="initial terminal columns (default: detect local terminal)") + parser.add_argument("--rows", type=int, default=None, help="initial terminal rows (default: detect local terminal)") + parser.add_argument("--command", action="append", default=[], help="send a command line after opening") + parser.add_argument("--close-after", type=float, default=2.0, help="seconds to wait before closing in command mode") + parser.add_argument("--timeout", type=float, default=10.0, help="seconds to wait for API/session events") + parser.add_argument("--verbose", action="store_true", help="print extra protocol events") + return parser.parse_args() + + +def repo_root() -> Path: + return Path(__file__).resolve().parent.parent + + +def load_proto_modules() -> object: + try: + import google.protobuf # noqa: F401 + except ImportError as exc: + raise SystemExit("python package 'protobuf' is required to run this client") from exc + + protoc = shutil.which("protoc") + if not protoc: + raise SystemExit("'protoc' is required to generate temporary Python protobuf bindings") + + out_dir = Path(tempfile.mkdtemp(prefix="meshtastic_dmshell_proto_")) + proto_dir = repo_root() / "protobufs" + + # Compile all required protos for DMShell client (mesh and dependencies) + # Excludes nanopb.proto and other complex build artifacts + required_protos = [ + "mesh.proto", + "channel.proto", + "config.proto", + "device_ui.proto", + "module_config.proto", + "atak.proto", + "portnums.proto", + "telemetry.proto", + "xmodem.proto", + ] + proto_files = [proto_dir / "meshtastic" / name for name in required_protos] + for pf in proto_files: + if not pf.exists(): + raise SystemExit(f"could not find required proto file: {pf}") + + # Create __init__.py to make meshtastic a package + (out_dir / "meshtastic").mkdir(exist_ok=True) + (out_dir / "meshtastic" / "__init__.py").touch() + + # Build protoc command with just the meshtastic proto directory as include path + # protoc will use its built-in includes for standard google protobuf types + cmd = [protoc, f"-I{proto_dir}", f"--python_out={out_dir}", *[str(path) for path in proto_files]] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print(f"protoc stderr: {result.stderr}", file=sys.stderr) + print(f"protoc stdout: {result.stdout}", file=sys.stderr) + print(f"protoc command: {' '.join(cmd)}", file=sys.stderr) + raise SystemExit(f"protoc failed with return code {result.returncode}") + + # Create _pb2_grpc module stub if not present (protoc 3.20+) + mesh_pb2_file = out_dir / "meshtastic" / "mesh_pb2.py" + if not mesh_pb2_file.exists(): + raise SystemExit(f"protoc did not generate mesh_pb2.py in {out_dir / 'meshtastic'}") + + sys.path.insert(0, str(out_dir)) + try: + from meshtastic import mesh_pb2, portnums_pb2 # type: ignore + except ImportError as exc: + print(f"Failed to import protobuf modules. Output dir contents:", file=sys.stderr) + for item in (out_dir / "meshtastic").iterdir(): + print(f" {item.name}", file=sys.stderr) + raise SystemExit(f"could not import meshtastic proto modules: {exc}") from exc + + # Return an object that has both modules accessible + class ProtoModules: + pass + + pb2 = ProtoModules() + pb2.mesh = mesh_pb2 + pb2.portnums = portnums_pb2 + return pb2 + + +def parse_node_num(raw: str) -> int: + value = raw.strip() + if value.startswith("!"): + value = value[1:] + if value.lower().startswith("0x"): + return int(value, 16) + if any(ch in "abcdefABCDEF" for ch in value): + return int(value, 16) + return int(value, 10) + + +class SerialTransport: + def __init__(self, serial_obj): + self._serial = serial_obj + + def recv(self, length: int) -> bytes: + return self._serial.read(length) + + def sendall(self, data: bytes) -> None: + self._serial.write(data) + self._serial.flush() + + def close(self) -> None: + self._serial.close() + + +def detect_meshtastic_serial_port() -> str: + try: + from serial.tools import list_ports + except ImportError as exc: + raise SystemExit("python package 'pyserial' is required for --serial mode") from exc + + ports = list(list_ports.comports()) + if not ports: + raise SystemExit("no serial ports found for --serial mode") + + scored: list[tuple[int, str]] = [] + for port in ports: + text = " ".join( + filter( + None, + [port.device, port.description, port.manufacturer, port.product, port.hwid], + ) + ).lower() + score = 0 + if "meshtastic" in text: + score += 100 + if "lora" in text or "mesh" in text: + score += 10 + if "ttyacm" in (port.device or "").lower() or "ttyusb" in (port.device or "").lower(): + score += 1 + scored.append((score, port.device)) + + scored.sort(reverse=True) + best_score, best_device = scored[0] + if best_score <= 0 and len(scored) > 1: + raise SystemExit( + "could not confidently auto-detect a Meshtastic serial port; pass --serial /dev/ttyXXX explicitly" + ) + return best_device + + +def open_transport(args: argparse.Namespace): + if args.serial is None: + sock = socket.create_connection((args.host, args.port), timeout=args.timeout) + sock.settimeout(None) + return sock + + serial_path = args.serial + if serial_path == "auto": + serial_path = detect_meshtastic_serial_port() + print(f"[dmshell] using serial port {serial_path}", file=sys.stderr) + + try: + import serial + except ImportError as exc: + raise SystemExit("python package 'pyserial' is required for --serial mode") from exc + + try: + serial_obj = serial.Serial(serial_path, baudrate=args.baud, timeout=None, write_timeout=2) + except Exception as exc: + raise SystemExit(f"failed to open serial device {serial_path}: {exc}") from exc + + return SerialTransport(serial_obj) + + +def recv_exact(transport, length: int) -> bytes: + chunks = bytearray() + while len(chunks) < length: + piece = transport.recv(length - len(chunks)) + if not piece: + raise ConnectionError("connection closed by transport") + chunks.extend(piece) + return bytes(chunks) + + +def detect_local_terminal_size() -> tuple[int, int]: + size = shutil.get_terminal_size(fallback=(100, 40)) + cols = max(1, int(size.columns)) + rows = max(1, int(size.lines)) + return cols, rows + + +def resolve_initial_terminal_size(cols_override: Optional[int], rows_override: Optional[int]) -> tuple[int, int]: + detected_cols, detected_rows = detect_local_terminal_size() + cols = detected_cols if cols_override is None else max(1, cols_override) + rows = detected_rows if rows_override is None else max(1, rows_override) + return cols, rows + + +def recv_stream_frame(transport) -> bytes: + while True: + start = recv_exact(transport, 1)[0] + if start != START1: + continue + if recv_exact(transport, 1)[0] != START2: + continue + header = recv_exact(transport, 2) + length = (header[0] << 8) | header[1] + return recv_exact(transport, length) + + +def send_stream_frame(transport, payload: bytes) -> None: + if len(payload) > 0xFFFF: + raise ValueError("payload too large for stream API") + header = bytes((START1, START2, (len(payload) >> 8) & 0xFF, len(payload) & 0xFF)) + transport.sendall(header + payload) + + +@dataclass +class SentShellFrame: + op: int + session_id: int + seq: int + ack_seq: int + payload: bytes = b"" + cols: int = 0 + rows: int = 0 + flags: int = 0 + last_tx_seq: int = 0 + last_rx_seq: int = 0 + + +@dataclass +class SessionState: + pb2: object # ProtoModules with mesh and portnums attributes + target: int + channel: int + verbose: bool + session_id: int = field(default_factory=lambda: random.randint(1, 0x7FFFFFFF)) + next_seq: int = 1 + last_rx_seq: int = 0 + next_expected_rx_seq: int = 1 + highest_seen_rx_seq: int = 0 + active: bool = False + stopped: bool = False + opened_event: threading.Event = field(default_factory=threading.Event) + closed_event: threading.Event = field(default_factory=threading.Event) + event_queue: "queue.Queue[str]" = field(default_factory=queue.Queue) + tx_lock: threading.Lock = field(default_factory=threading.Lock) + socket_lock: threading.Lock = field(default_factory=threading.Lock) + tx_history: deque[SentShellFrame] = field(default_factory=lambda: deque(maxlen=50)) + pending_rx_frames: dict[int, object] = field(default_factory=dict) + last_requested_missing_seq: int = 0 + last_missing_request_time: float = 0.0 + requested_missing_seqs: set[int] = field(default_factory=set) + replay_log_lock: threading.Lock = field(default_factory=threading.Lock) + replay_log_file: Optional[TextIO] = None + replay_log_path: Optional[Path] = None + last_transport_activity_time: float = field(default_factory=time.monotonic) + last_heartbeat_sent_time: float = 0.0 + + def alloc_seq(self) -> int: + with self.tx_lock: + value = self.next_seq + self.next_seq += 1 + return value + + def current_ack_seq(self) -> int: + with self.tx_lock: + return self.last_rx_seq + + def highest_sent_seq(self) -> int: + with self.tx_lock: + return max(0, self.next_seq - 1) + + def note_outbound_packet(self, heartbeat: bool = False) -> None: + with self.tx_lock: + now = time.monotonic() + if heartbeat: + self.last_heartbeat_sent_time = now + else: + self.last_transport_activity_time = now + + def note_inbound_packet(self) -> None: + with self.tx_lock: + self.last_transport_activity_time = time.monotonic() + + def heartbeat_due(self) -> bool: + with self.tx_lock: + now = time.monotonic() + if (now - self.last_transport_activity_time) < HEARTBEAT_IDLE_DELAY_SEC: + return False + if self.last_heartbeat_sent_time <= self.last_transport_activity_time: + return True + return (now - self.last_heartbeat_sent_time) >= HEARTBEAT_REPEAT_SEC + + def note_peer_reported_tx_seq(self, seq: int) -> None: + with self.tx_lock: + if seq > self.highest_seen_rx_seq: + self.highest_seen_rx_seq = seq + + def note_received_seq(self, seq: int) -> tuple[str, Optional[int]]: + with self.tx_lock: + if seq == 0: + return ("process", None) + if seq < self.next_expected_rx_seq: + if self.highest_seen_rx_seq >= self.next_expected_rx_seq: + return ("gap", self.next_expected_rx_seq) + return ("duplicate", None) + if seq > self.next_expected_rx_seq: + if seq > self.highest_seen_rx_seq: + self.highest_seen_rx_seq = seq + return ("gap", self.next_expected_rx_seq) + self.last_rx_seq = seq + self.next_expected_rx_seq = seq + 1 + if self.last_requested_missing_seq != 0 and self.next_expected_rx_seq > self.last_requested_missing_seq: + self.last_requested_missing_seq = 0 + if seq > self.highest_seen_rx_seq: + self.highest_seen_rx_seq = seq + if self.highest_seen_rx_seq < self.next_expected_rx_seq: + self.highest_seen_rx_seq = 0 + return ("process", None) + + def remember_out_of_order_frame(self, shell) -> None: + with self.tx_lock: + if shell.seq <= self.next_expected_rx_seq: + return + if shell.seq not in self.pending_rx_frames: + self.pending_rx_frames[shell.seq] = shell + if shell.seq > self.highest_seen_rx_seq: + self.highest_seen_rx_seq = shell.seq + + def pop_next_buffered_frame(self): + with self.tx_lock: + return self.pending_rx_frames.pop(self.next_expected_rx_seq, None) + + def pending_missing_seq(self) -> Optional[int]: + with self.tx_lock: + if self.highest_seen_rx_seq >= self.next_expected_rx_seq: + return self.next_expected_rx_seq + return None + + def request_missing_seq_once(self) -> Optional[int]: + with self.tx_lock: + if self.highest_seen_rx_seq < self.next_expected_rx_seq: + return None + now = time.monotonic() + if ( + self.last_requested_missing_seq == self.next_expected_rx_seq + and (now - self.last_missing_request_time) < MISSING_SEQ_RETRY_INTERVAL_SEC + ): + return None + self.last_requested_missing_seq = self.next_expected_rx_seq + self.last_missing_request_time = now + return self.last_requested_missing_seq + + def set_receive_cursor(self, seq: int) -> None: + with self.tx_lock: + self.last_rx_seq = seq + self.next_expected_rx_seq = seq + 1 + self.highest_seen_rx_seq = seq + + def open_replay_log(self, session_id: int) -> None: + with self.replay_log_lock: + if self.replay_log_file is not None: + return + path = Path.cwd() / f"{session_id:08x}.log" + self.replay_log_file = path.open("a", encoding="utf-8") + self.replay_log_path = path + self.replay_log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} session_open session=0x{session_id:08x}\n") + self.replay_log_file.flush() + + def log_replay_event(self, event: str, seq: int, detail: str = "") -> None: + with self.replay_log_lock: + if self.replay_log_file is None: + return + extra = f" {detail}" if detail else "" + self.replay_log_file.write( + f"{time.strftime('%Y-%m-%d %H:%M:%S')} {event} seq={seq}{extra}\n" + ) + self.replay_log_file.flush() + + def note_missing_seq_requested(self, seq: int, reason: str) -> None: + with self.tx_lock: + self.requested_missing_seqs.add(seq) + self.log_replay_event("missing_requested", seq, f"reason={reason}") + + def note_replayed_seq_received(self, seq: int) -> None: + with self.tx_lock: + was_requested = seq in self.requested_missing_seqs + if was_requested: + self.requested_missing_seqs.remove(seq) + if was_requested: + self.log_replay_event("replay_received", seq) + + def close_replay_log(self) -> None: + with self.replay_log_lock: + if self.replay_log_file is None: + return + self.replay_log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} session_close\n") + self.replay_log_file.flush() + self.replay_log_file.close() + self.replay_log_file = None + + def remember_sent_frame(self, frame: SentShellFrame) -> None: + if frame.seq == 0 or frame.op == self.pb2.mesh.RemoteShell.ACK: + return + with self.tx_lock: + self.tx_history.append(frame) + + def prune_sent_frames(self, ack_seq: int) -> None: + if ack_seq <= 0: + return + with self.tx_lock: + self.tx_history = deque((frame for frame in self.tx_history if frame.seq > ack_seq), maxlen=50) + + def replay_frames_from(self, start_seq: int) -> list[SentShellFrame]: + with self.tx_lock: + return [frame for frame in self.tx_history if frame.seq >= start_seq] + + +def send_toradio(transport, toradio) -> None: + send_stream_frame(transport, toradio.SerializeToString()) + + +def make_toradio_packet(pb2, state: SessionState, shell_msg) -> object: + packet = pb2.mesh.MeshPacket() + packet.id = random.randint(1, 0x7FFFFFFF) + packet.to = state.target + # The 'from' field is a reserved keyword in Python, so use setattr + setattr(packet, "from", 0) + packet.channel = state.channel + packet.hop_limit = DEFAULT_HOP_LIMIT + packet.want_ack = False + packet.decoded.portnum = pb2.portnums.REMOTE_SHELL_APP + packet.decoded.payload = shell_msg.SerializeToString() + packet.decoded.want_response = False + packet.decoded.dest = state.target + packet.decoded.source = 0 + + toradio = pb2.mesh.ToRadio() + toradio.packet.CopyFrom(packet) + return toradio + + +def send_shell_frame( + transport, + state: SessionState, + op: int, + payload: bytes = b"", + cols: int = 0, + rows: int = 0, + session_id: Optional[int] = None, + ack_seq: Optional[int] = None, + seq: Optional[int] = None, + flags: int = 0, + last_tx_seq: int = 0, + last_rx_seq: int = 0, + remember: bool = True, + heartbeat: bool = False, +) -> int: + if seq is None: + seq = 0 if op == state.pb2.mesh.RemoteShell.ACK else state.alloc_seq() + if ack_seq is None: + ack_seq = state.current_ack_seq() + if session_id is None: + session_id = state.session_id + + shell = state.pb2.mesh.RemoteShell() + shell.op = op + shell.session_id = session_id + shell.seq = seq + shell.ack_seq = ack_seq + shell.cols = cols + shell.rows = rows + shell.flags = flags + shell.last_tx_seq = last_tx_seq + shell.last_rx_seq = last_rx_seq + if payload: + shell.payload = payload + with state.socket_lock: + send_toradio(transport, make_toradio_packet(state.pb2, state, shell)) + if remember: + state.remember_sent_frame( + SentShellFrame( + op=op, + session_id=session_id, + seq=seq, + ack_seq=ack_seq, + payload=payload, + cols=cols, + rows=rows, + flags=flags, + last_tx_seq=last_tx_seq, + last_rx_seq=last_rx_seq, + ) + ) + state.note_outbound_packet(heartbeat=heartbeat) + return seq + + +def send_ack_frame(transport, state: SessionState, replay_from: Optional[int] = None) -> None: + send_shell_frame( + transport, + state, + state.pb2.mesh.RemoteShell.ACK, + seq=0, + last_rx_seq=0 if replay_from is None else replay_from - 1, + remember=False, + ) + + +def replay_frames_from(transport, state: SessionState, start_seq: int) -> None: + frame = next((f for f in state.replay_frames_from(start_seq) if f.seq == start_seq), None) + if frame is None: + #state.event_queue.put(f"replay unavailable from seq={start_seq}") + state.log_replay_event("replay_unavailable", start_seq) + return + state.log_replay_event("replay_sent", start_seq) + #state.event_queue.put(f"replay frame seq={start_seq}") + send_shell_frame( + transport, + state, + frame.op, + payload=frame.payload, + cols=frame.cols, + rows=frame.rows, + session_id=frame.session_id, + ack_seq=frame.ack_seq, + seq=frame.seq, + flags=frame.flags, + last_tx_seq=frame.last_tx_seq, + last_rx_seq=frame.last_rx_seq, + remember=False, + ) + + +def wait_for_config_complete(transport, pb2, timeout: float, verbose: bool) -> None: + nonce = random.randint(1, 0x7FFFFFFF) + toradio = pb2.mesh.ToRadio() + toradio.want_config_id = nonce + send_toradio(transport, toradio) + + deadline = time.time() + timeout + while time.time() < deadline: + fromradio = pb2.mesh.FromRadio() + fromradio.ParseFromString(recv_stream_frame(transport)) + variant = fromradio.WhichOneof("payload_variant") + if verbose and variant: + print(f"[api] fromradio {variant}", file=sys.stderr) + if variant == "config_complete_id" and fromradio.config_complete_id == nonce: + return + raise TimeoutError("timed out waiting for config handshake to complete") + + +def decode_shell_packet(state: SessionState, packet) -> Optional[object]: + if packet.WhichOneof("payload_variant") != "decoded": + return None + if packet.decoded.portnum != state.pb2.portnums.REMOTE_SHELL_APP: + return None + shell = state.pb2.mesh.RemoteShell() + shell.ParseFromString(packet.decoded.payload) + return shell + + +def reader_loop(transport, state: SessionState) -> None: + def handle_in_order_shell(shell) -> bool: + state.note_replayed_seq_received(shell.seq) + if shell.op == state.pb2.mesh.RemoteShell.OPEN_OK: + state.session_id = shell.session_id + state.open_replay_log(state.session_id) + state.set_receive_cursor(shell.seq) + state.active = True + state.opened_event.set() + state.event_queue.put( + f"opened session=0x{shell.session_id:08x} cols={shell.cols} rows={shell.rows}" + ) + if state.replay_log_path is not None: + state.event_queue.put(f"replay log: {state.replay_log_path}") + elif shell.op == state.pb2.mesh.RemoteShell.OUTPUT: + if shell.payload: + sys.stdout.buffer.write(shell.payload) + sys.stdout.buffer.flush() + elif shell.op == state.pb2.mesh.RemoteShell.ERROR: + message = shell.payload.decode("utf-8", errors="replace") + state.event_queue.put(f"remote error: {message}") + elif shell.op == state.pb2.mesh.RemoteShell.CLOSED: + message = shell.payload.decode("utf-8", errors="replace") + state.event_queue.put(f"session closed: {message}") + state.closed_event.set() + state.active = False + return True + elif shell.op == state.pb2.mesh.RemoteShell.PONG: + remote_last_tx_seq = shell.last_tx_seq + remote_last_rx_seq = shell.last_rx_seq + local_latest_tx_seq = state.highest_sent_seq() + if remote_last_rx_seq < local_latest_tx_seq: + replay_frames_from(transport, state, remote_last_rx_seq + 1) + if remote_last_tx_seq > state.current_ack_seq(): + state.note_peer_reported_tx_seq(remote_last_tx_seq) + req = state.request_missing_seq_once() + if req is not None: + state.note_missing_seq_requested(req, "heartbeat_status") + send_ack_frame(transport, state, replay_from=req) + #state.event_queue.put("pong") + return False + + while not state.stopped: + try: + fromradio = state.pb2.mesh.FromRadio() + fromradio.ParseFromString(recv_stream_frame(transport)) + except Exception as exc: + if not state.stopped: + state.event_queue.put(f"connection error: {exc}") + state.closed_event.set() + return + + variant = fromradio.WhichOneof("payload_variant") + if variant == "packet": + shell = decode_shell_packet(state, fromradio.packet) + if not shell: + continue + state.note_inbound_packet() + #state.prune_sent_frames(shell.ack_seq) + if shell.op == state.pb2.mesh.RemoteShell.ACK: + #state.event_queue.put("peer requested replay") + replay_from = shell.last_rx_seq + 1 if shell.last_rx_seq > 0 else None + if replay_from is not None: + #state.event_queue.put(f"peer requested replay from seq={replay_from}") + replay_frames_from(transport, state, replay_from) + continue + + action, missing_from = state.note_received_seq(shell.seq) + if action == "duplicate": + req = state.request_missing_seq_once() + if req is not None: + state.note_missing_seq_requested(req, "duplicate") + send_ack_frame(transport, state, replay_from=req) + continue + if action == "gap": + state.remember_out_of_order_frame(shell) + req = state.request_missing_seq_once() + if req is not None: + state.note_missing_seq_requested(req, "gap") + send_ack_frame(transport, state, replay_from=req) + continue + + if handle_in_order_shell(shell): + return + + while True: + buffered_shell = state.pop_next_buffered_frame() + if buffered_shell is None: + break + buffered_action, _ = state.note_received_seq(buffered_shell.seq) + if buffered_action != "process": + state.remember_out_of_order_frame(buffered_shell) + break + if handle_in_order_shell(buffered_shell): + return + + req = state.request_missing_seq_once() + if req is not None: + state.note_missing_seq_requested(req, "post_process_gap") + send_ack_frame(transport, state, replay_from=req) + elif state.verbose and variant: + state.event_queue.put(f"fromradio {variant}") + + +def drain_events(state: SessionState) -> None: + while True: + try: + event = state.event_queue.get_nowait() + except queue.Empty: + return + print(f"[dmshell] {event}", file=sys.stderr) + + +def heartbeat_loop(transport, state: SessionState) -> None: + while not state.stopped and not state.closed_event.is_set(): + if not state.active: + time.sleep(HEARTBEAT_POLL_INTERVAL_SEC) + continue + if state.heartbeat_due(): + try: + send_shell_frame( + transport, + state, + state.pb2.mesh.RemoteShell.PING, + last_tx_seq=state.highest_sent_seq(), + last_rx_seq=state.current_ack_seq(), + remember=True, + heartbeat=True, + ) + except Exception as exc: + if not state.stopped: + state.event_queue.put(f"heartbeat error: {exc}") + state.closed_event.set() + return + time.sleep(HEARTBEAT_POLL_INTERVAL_SEC) + + +def run_command_mode(transport, state: SessionState, commands: list[str], close_after: float) -> None: + for command in commands: + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.INPUT, (command + "\n").encode("utf-8")) + time.sleep(close_after) + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.CLOSE) + state.closed_event.wait(timeout=close_after + 5.0) + + +def run_interactive_mode(transport, state: SessionState) -> None: + def read_local_command() -> str: + prompt = "\r\n[dmshell] local command (resume|close|ping|resize C R): " + sys.stderr.write(prompt) + sys.stderr.flush() + buf = bytearray() + + while True: + ch = os.read(sys.stdin.fileno(), 1) + if not ch: + sys.stderr.write("\r\n") + sys.stderr.flush() + return "close" + + b = ch[0] + if b in (10, 13): + sys.stderr.write("\r\n") + sys.stderr.flush() + return buf.decode("utf-8", errors="replace").strip() + + if b in (8, 127): + if buf: + buf.pop() + sys.stderr.write("\b \b") + sys.stderr.flush() + continue + + if b < 32: + continue + + buf.append(b) + sys.stderr.write(chr(b)) + sys.stderr.flush() + + def handle_local_command(cmd: str) -> bool: + if cmd in ("", "resume"): + return True + if cmd == "close": + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.CLOSE) + return False + if cmd == "ping": + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.PING) + return True + if cmd.startswith("resize "): + parts = cmd.split() + if len(parts) != 3: + state.event_queue.put("usage: resize COLS ROWS") + return True + try: + cols = int(parts[1]) + rows = int(parts[2]) + except ValueError: + state.event_queue.put("usage: resize COLS ROWS") + return True + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.RESIZE, cols=cols, rows=rows) + return True + + state.event_queue.put(f"unknown local command: {cmd}") + return True + + print( + "Raw input mode active. All keys (including Ctrl+C/Ctrl+X) are sent to remote. Ctrl+] for local commands.", + file=sys.stderr, + ) + + if not sys.stdin.isatty(): + # Fallback for non-TTY stdin: still send input as it arrives. + while not state.closed_event.is_set(): + drain_events(state) + data = sys.stdin.buffer.read(INPUT_BATCH_MAX_BYTES) + if not data: + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.CLOSE) + break + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.INPUT, data) + return + + fd = sys.stdin.fileno() + old_attrs = termios.tcgetattr(fd) + try: + tty.setraw(fd) + while not state.closed_event.is_set(): + drain_events(state) + ready, _, _ = select.select([sys.stdin], [], [], 0.05) + if not ready: + continue + + data = os.read(fd, 1) + if not data: + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.CLOSE) + break + + if data == LOCAL_ESCAPE_BYTE: + keep_running = handle_local_command(read_local_command()) + if not keep_running: + break + continue + + # Coalesce a short burst of bytes to reduce packet overhead for fast typing. + batched = bytearray(data) + enter_local_command = False + deadline = time.monotonic() + INPUT_BATCH_WINDOW_SEC + while len(batched) < INPUT_BATCH_MAX_BYTES: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + more_ready, _, _ = select.select([sys.stdin], [], [], remaining) + if not more_ready: + break + next_byte = os.read(fd, 1) + if not next_byte: + break + if next_byte == LOCAL_ESCAPE_BYTE: + enter_local_command = True + break + batched.extend(next_byte) + if next_byte == b'\r' or next_byte == b'\t': + break + deadline = time.monotonic() + INPUT_BATCH_WINDOW_SEC + + if batched: + send_shell_frame(transport, state, state.pb2.mesh.RemoteShell.INPUT, bytes(batched)) + + if enter_local_command: + keep_running = handle_local_command(read_local_command()) + if not keep_running: + break + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_attrs) + + +def main() -> int: + args = parse_args() + pb2 = load_proto_modules() + + state = SessionState( + pb2=pb2, + target=parse_node_num(args.to), + channel=args.channel, + verbose=args.verbose, + ) + + cols, rows = resolve_initial_terminal_size(args.cols, args.rows) + + transport = open_transport(args) + try: + wait_for_config_complete(transport, pb2, args.timeout, args.verbose) + + reader = threading.Thread(target=reader_loop, args=(transport, state), daemon=True) + reader.start() + + send_shell_frame(transport, state, pb2.mesh.RemoteShell.OPEN, cols=cols, rows=rows) + if not state.opened_event.wait(timeout=args.timeout): + raise SystemExit("timed out waiting for OPEN_OK from remote DMShell") + + heartbeat = threading.Thread(target=heartbeat_loop, args=(transport, state), daemon=True) + heartbeat.start() + + drain_events(state) + if args.command: + run_command_mode(transport, state, args.command, args.close_after) + else: + run_interactive_mode(transport, state) + + state.stopped = True + drain_events(state) + reader.join(timeout=1.0) + heartbeat.join(timeout=1.0) + state.close_replay_log() + finally: + transport.close() + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/src/mesh/RadioLibInterface.h b/src/mesh/RadioLibInterface.h index 310ca76bb24..fb5fe312f68 100644 --- a/src/mesh/RadioLibInterface.h +++ b/src/mesh/RadioLibInterface.h @@ -172,6 +172,8 @@ class RadioLibInterface : public RadioInterface, protected concurrency::Notified /** Attempt to find a packet in the TxQueue. Returns true if the packet was found. */ virtual bool findInTxQueue(NodeNum from, PacketId id) override; + uint8_t packetsInTxQueue() { return txQueue.getMaxLen() - txQueue.getFree(); } + /** * Request randomness sourced from the LoRa modem, if supported by the active RadioLib interface. * @return true if len bytes were produced, false otherwise. diff --git a/src/modules/DMShell.cpp b/src/modules/DMShell.cpp new file mode 100644 index 00000000000..b2cf6c0d933 --- /dev/null +++ b/src/modules/DMShell.cpp @@ -0,0 +1,602 @@ +#include "DMShell.h" + +#if defined(ARCH_PORTDUINO) + +#include "Channels.h" +#include "MeshService.h" +#include "NodeDB.h" +#include "Throttle.h" +#include "configuration.h" +#include "mesh/generated/meshtastic/mesh.pb.h" +#include "mesh/mesh-pb-constants.h" +#include "pb_decode.h" +#include "pb_encode.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DMShellModule *dmShellModule; + +namespace +{ +constexpr uint16_t PTY_COLS_DEFAULT = 120; +constexpr uint16_t PTY_ROWS_DEFAULT = 40; +constexpr size_t MAX_MESSAGE_SIZE = 200; +} // namespace + +DMShellModule::DMShellModule() + : SinglePortModule("DMShellModule", meshtastic_PortNum_REMOTE_SHELL_APP), concurrency::OSThread("DMShell", 100) +{ + LOG_WARN("DMShell enabled on Portduino: remote shell access is dangerous and intended for trusted debugging only"); +} + +ProcessMessage DMShellModule::handleReceived(const meshtastic_MeshPacket &mp) +{ + meshtastic_RemoteShell frame = meshtastic_RemoteShell_init_zero; + if (!mp.pki_encrypted) { + LOG_WARN("DMShell: ignoring packet without PKI from 0x%x", mp.from); + return ProcessMessage::STOP; + } + + if (!parseFrame(mp, frame)) { + LOG_WARN("DMShell: ignoring malformed frame"); + return ProcessMessage::STOP; + } + + if (frame.op == meshtastic_RemoteShell_OpCode_ACK) { + if (session.active && frame.session_id == session.sessionId && getFrom(&mp) == session.peer && frame.last_rx_seq > 0) { + resendFramesFrom(frame.last_rx_seq + 1); + } + return ProcessMessage::CONTINUE; + } + + if (frame.op >= 64) { + LOG_WARN("DMShell: ignoring frame with op code %d, seq %d", frame.op, frame.seq); + return ProcessMessage::CONTINUE; + } + + if (!isAuthorizedPacket(mp)) { + LOG_WARN("DMShell: unauthorized sender 0x%x, %u", mp.from, frame.op); + myReply = allocErrorResponse(meshtastic_Routing_Error_NOT_AUTHORIZED, &mp); + return ProcessMessage::STOP; + } + + if (frame.op == meshtastic_RemoteShell_OpCode_OPEN) { + LOG_WARN("DMShell: received OPEN from 0x%x sessionId=0x%x", mp.from, frame.session_id); + if (!openSession(mp, frame)) { + sendError("open_failed", getFrom(&mp)); + } + return ProcessMessage::STOP; + } + + if (!session.active || frame.session_id != session.sessionId || getFrom(&mp) != session.peer) { + if (!session.active) { + LOG_WARN("DMShell: no active session, rejecting op %d from 0x%x", frame.op, mp.from); + } else { + LOG_WARN("DMShell: session ID mismatch (got 0x%x expected 0x%x) or peer mismatch (got 0x%x expected 0x%x), rejecting " + "op %d", + frame.session_id, session.sessionId, mp.from, session.peer, frame.op); + } + sendError("invalid_session", getFrom(&mp)); + return ProcessMessage::STOP; + } + + if (!shouldProcessIncomingFrame(frame)) { + return ProcessMessage::STOP; + } + + session.lastActivityMs = millis(); + + switch (frame.op) { + case meshtastic_RemoteShell_OpCode_INPUT: + if (!writeSessionInput(frame)) { + sendError("input_write_failed"); + } else { + uint8_t outBuf[MAX_MESSAGE_SIZE]; + const ssize_t bytesRead = read(session.masterFd, outBuf, sizeof(outBuf)); + if (bytesRead > 0) { + LOG_WARN("DMShell: read %zd bytes from PTY", bytesRead); + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_OUTPUT, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + }; + assert(bytesRead <= sizeof(frame.payload.bytes)); + memcpy(frame.payload.bytes, outBuf, bytesRead); + frame.payload.size = bytesRead; + sendFrameToPeer(session.peer, frame, true); + session.lastActivityMs = millis(); + } + } + break; + case meshtastic_RemoteShell_OpCode_RESIZE: + if (frame.rows > 0 && frame.cols > 0) { + struct winsize ws = {}; + ws.ws_row = frame.rows; + ws.ws_col = frame.cols; + if (session.masterFd >= 0) { + ioctl(session.masterFd, TIOCSWINSZ, &ws); + } + } + break; + case meshtastic_RemoteShell_OpCode_PING: { + uint32_t peerLastRxSeq = frame.ack_seq; + if (frame.last_rx_seq > 0) { + peerLastRxSeq = frame.last_rx_seq; + } + + const uint32_t nextMissingForPeer = peerLastRxSeq + 1; + if (nextMissingForPeer > 0 && nextMissingForPeer < session.nextTxSeq) { + resendFramesFrom(nextMissingForPeer); + } + + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_PONG, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + .last_tx_seq = session.nextTxSeq > 0 ? session.nextTxSeq - 1 : 0, + .last_rx_seq = session.lastAckedRxSeq, + }; + frame.payload.size = 0; + sendFrameToPeer(session.peer, frame, true); + break; + } + case meshtastic_RemoteShell_OpCode_CLOSE: + closeSession("peer_close", true); + break; + default: + sendError("unsupported_op"); + break; + } + + return ProcessMessage::STOP; +} + +int32_t DMShellModule::runOnce() +{ + processPendingChildReap(); + + if (!session.active) { + return 100; + } + + reapChildIfExited(); + if (!session.active) { + return 100; + } + + if (Throttle::isWithinTimespanMs(session.lastActivityMs, SESSION_IDLE_TIMEOUT_MS) == false) { + closeSession("idle_timeout", true); + return 100; + } + + if (RadioLibInterface::instance->packetsInTxQueue() > 1) { + return 50; + } + + uint8_t outBuf[MAX_MESSAGE_SIZE]; + while (session.masterFd >= 0) { + const ssize_t bytesRead = read(session.masterFd, outBuf, sizeof(outBuf)); + if (bytesRead > 0) { + LOG_WARN("DMShell: read %zd bytes from PTY", bytesRead); + + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_OUTPUT, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + }; + assert(bytesRead <= sizeof(frame.payload.bytes)); + memcpy(frame.payload.bytes, outBuf, bytesRead); + frame.payload.size = bytesRead; + sendFrameToPeer(session.peer, frame, true); + + session.lastActivityMs = millis(); + // continue; + // do we want to ack every data message, and only send the next on ack? + // would require some retry logic. Maybe re-use the wantAck bit + return 50; + } + + if (bytesRead == 0) { + closeSession("pty_eof", true); + break; + } + + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + + LOG_WARN("DMShell: PTY read error errno=%d", errno); + closeSession("pty_read_error", true); + break; + } + + return 100; +} + +bool DMShellModule::parseFrame(const meshtastic_MeshPacket &mp, meshtastic_RemoteShell &outFrame) +{ + if (mp.which_payload_variant != meshtastic_MeshPacket_decoded_tag) { + return false; + } + + if (pb_decode_from_bytes(mp.decoded.payload.bytes, mp.decoded.payload.size, meshtastic_RemoteShell_fields, &outFrame)) { + LOG_INFO("Received a DMShell message"); + } else { + LOG_ERROR("Error decoding DMShell message!"); + return false; + } + + return true; +} + +bool DMShellModule::isAuthorizedPacket(const meshtastic_MeshPacket &mp) const +{ + if (mp.from == 0) { + return !config.security.is_managed; + } + + const meshtastic_Channel *ch = &channels.getByIndex(mp.channel); + if (strcasecmp(ch->settings.name, Channels::adminChannel) == 0) { + return config.security.admin_channel_enabled; + } + + if (mp.pki_encrypted) { + for (uint8_t i = 0; i < 3; ++i) { + if (config.security.admin_key[i].size == 32 && + memcmp(mp.public_key.bytes, config.security.admin_key[i].bytes, 32) == 0) { + return true; + } + } + } + + return false; +} + +bool DMShellModule::openSession(const meshtastic_MeshPacket &mp, const meshtastic_RemoteShell &frame) +{ + if (session.active) { + closeSession("preempted", false); + } + + int masterFd = -1; + struct winsize ws = {}; + if (frame.rows > 0) { + ws.ws_row = frame.rows; + } else { + ws.ws_row = PTY_ROWS_DEFAULT; + } + if (frame.cols > 0) { + ws.ws_col = frame.cols; + } else { + ws.ws_col = PTY_COLS_DEFAULT; + } + const pid_t childPid = forkpty(&masterFd, nullptr, nullptr, &ws); + if (childPid < 0) { + LOG_ERROR("DMShell: forkpty failed errno=%d", errno); + return false; + } + + if (childPid == 0) { + const char *shell = getenv("SHELL"); + if (!shell || !*shell) { + shell = "/bin/sh"; + } + execl(shell, shell, "-i", static_cast(nullptr)); + _exit(127); + } + + const int flags = fcntl(masterFd, F_GETFL, 0); + if (flags >= 0) { + fcntl(masterFd, F_SETFL, flags | O_NONBLOCK); + } + + session.active = true; + session.sessionId = (frame.session_id != 0) ? frame.session_id : static_cast(random(1, 0x7fffffff)); + session.peer = getFrom(&mp); + session.channel = mp.channel; + session.masterFd = masterFd; + session.childPid = childPid; + session.nextTxSeq = 1; + session.lastAckedRxSeq = frame.seq; + session.nextExpectedRxSeq = frame.seq + 1; + session.highestSeenRxSeq = frame.seq; + session.lastActivityMs = millis(); + + meshtastic_RemoteShell newFrame = { + .op = meshtastic_RemoteShell_OpCode_OPEN_OK, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = frame.seq, + .cols = ws.ws_col, + .rows = ws.ws_row, + .flags = 0, + }; + newFrame.payload.size = 0; + sendFrameToPeer(session.peer, newFrame, true); + + LOG_INFO("DMShell: opened session=0x%x peer=0x%x pid=%d", session.sessionId, session.peer, session.childPid); + return true; +} + +bool DMShellModule::writeSessionInput(const meshtastic_RemoteShell &frame) +{ + if (session.masterFd < 0) { + return false; + } + if (frame.payload.size == 0) { + return true; + } + + const ssize_t bytesWritten = write(session.masterFd, frame.payload.bytes, frame.payload.size); + return bytesWritten >= 0; +} + +void DMShellModule::closeSession(const char *reason, bool notifyPeer) +{ + if (!session.active) { + return; + } + + if (notifyPeer) { + const size_t reasonLen = strnlen(reason, 256); + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_CLOSED, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + }; + assert(reasonLen <= sizeof(frame.payload.bytes)); + memcpy(frame.payload.bytes, reason, reasonLen); + frame.payload.size = reasonLen; + sendFrameToPeer(session.peer, frame, true); + } + + if (session.masterFd >= 0) { + close(session.masterFd); + session.masterFd = -1; + } + + if (session.childPid > 0) { + // Run this to avoid forgetting a child + processPendingChildReap(); + + if (kill(session.childPid, SIGTERM) < 0 && errno != ESRCH) { + LOG_WARN("DMShell: failed to send SIGTERM to pid=%d errno=%d", session.childPid, errno); + } + + pendingChildPid = session.childPid; + session.childPid = -1; + } + + LOG_INFO("DMShell: closed session=0x%x reason=%s", session.sessionId, reason); + session = DMShellSession{}; +} + +void DMShellModule::reapChildIfExited() +{ + if (!session.active || session.childPid <= 0) { + return; + } + + int status = 0; + const pid_t result = waitpid(session.childPid, &status, WNOHANG); + if (result == session.childPid) { + closeSession("shell_exited", true); + } +} + +void DMShellModule::processPendingChildReap() +{ + if (pendingChildPid <= 0) { + return; + } + + int status = 0; + const pid_t result = waitpid(pendingChildPid, &status, WNOHANG); + + if (result == pendingChildPid || (result < 0 && errno == ECHILD)) { + pendingChildPid = -1; + return; + } + + if (result < 0) { + LOG_WARN("DMShell: waitpid failed for pid=%d errno=%d", pendingChildPid, errno); + pendingChildPid = -1; + return; + } + + if (pendingChildPid > 0) { + if (kill(pendingChildPid, SIGKILL) < 0 && errno != ESRCH) { + LOG_WARN("DMShell: failed to send SIGKILL to pid=%d errno=%d", pendingChildPid, errno); + } + pendingChildPid = -1; + } +} + +void DMShellModule::rememberSentFrame(meshtastic_RemoteShell frame) +{ + if (frame.seq == 0 || frame.op == meshtastic_RemoteShell_OpCode_ACK) { + return; + } + + auto &entry = session.txHistory[session.txHistoryNext]; + entry.valid = true; + entry.op = frame.op; + entry.sessionId = frame.session_id; + entry.seq = frame.seq; + entry.ackSeq = frame.ack_seq; + entry.cols = frame.cols; + entry.rows = frame.rows; + entry.flags = frame.flags; + entry.payloadLen = frame.payload.size; + if (frame.payload.size > 0) { + memcpy(entry.payload, frame.payload.bytes, frame.payload.size); + } + + session.txHistoryNext = (session.txHistoryNext + 1) % session.txHistory.size(); +} + +void DMShellModule::resendFramesFrom(uint32_t startSeq) +{ + if (startSeq == 0) { + return; + } + + DMShellSession::SentFrame *match = nullptr; + for (auto &entry : session.txHistory) { + if (!entry.valid || entry.seq != startSeq) { + continue; + } + match = &entry; + break; + } + + if (!match) { + LOG_WARN("DMShell: replay request for seq=%u not found in history", startSeq); + return; + } + + LOG_INFO("DMShell: replaying frame seq=%u op=%d", match->seq, match->op); + meshtastic_RemoteShell frame = { + .op = match->op, + .session_id = match->sessionId, + .seq = match->seq, + .ack_seq = match->ackSeq, + .cols = match->cols, + .rows = match->rows, + .flags = match->flags, + }; + assert(match->payloadLen <= sizeof(frame.payload.bytes)); + memcpy(frame.payload.bytes, match->payload, match->payloadLen); + frame.payload.size = match->payloadLen; + sendFrameToPeer(session.peer, frame, false); +} + +void DMShellModule::sendAck(uint32_t replayFromSeq) +{ + if (replayFromSeq > 0) { + LOG_WARN("DMShell: requesting replay from seq=%u", replayFromSeq); + } + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_ACK, + .session_id = session.sessionId, + .seq = 0, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + .last_rx_seq = replayFromSeq - 1, + }; + frame.payload.size = 0; + sendFrameToPeer(session.peer, frame, false); +} + +bool DMShellModule::shouldProcessIncomingFrame(const meshtastic_RemoteShell &frame) +{ + if (frame.seq == 0) { + return true; + } + + if (frame.seq < session.nextExpectedRxSeq) { + if (session.highestSeenRxSeq >= session.nextExpectedRxSeq) { + sendAck(session.nextExpectedRxSeq); + } else { + sendAck(); + } + return false; + } + + if (frame.seq > session.nextExpectedRxSeq) { + if (frame.seq > session.highestSeenRxSeq) { + session.highestSeenRxSeq = frame.seq; + } + sendAck(session.nextExpectedRxSeq); + return false; + } + + session.lastAckedRxSeq = frame.seq; + session.nextExpectedRxSeq = frame.seq + 1; + if (frame.seq > session.highestSeenRxSeq) { + session.highestSeenRxSeq = frame.seq; + } + if (session.highestSeenRxSeq >= session.nextExpectedRxSeq) { + sendAck(session.nextExpectedRxSeq); + } else { + session.highestSeenRxSeq = 0; + } + return true; +} + +void DMShellModule::sendFrameToPeer(NodeNum peer, meshtastic_RemoteShell frame, bool remember) +{ + meshtastic_MeshPacket *packet = allocDataPacket(); + if (!packet) { + return; + } + LOG_WARN("DMShell: building packet op=%u session=0x%x seq=%u payloadLen=%zu", frame.op, frame.session_id, frame.seq, + frame.payload.size); + const size_t encoded = pb_encode_to_bytes(packet->decoded.payload.bytes, sizeof(packet->decoded.payload.bytes), + meshtastic_RemoteShell_fields, &frame); + if (encoded == 0) { + return; + } + packet->decoded.payload.size = encoded; + + if (remember) { + rememberSentFrame(frame); + } + + packet->to = peer; + packet->channel = 0; + packet->want_ack = false; + packet->pki_encrypted = true; + packet->priority = meshtastic_MeshPacket_Priority_RELIABLE; + service->sendToMesh(packet); +} + +void DMShellModule::sendError(const char *message, NodeNum peer) +{ + const size_t len = strnlen(message, MAX_MESSAGE_SIZE); + meshtastic_RemoteShell frame = { + .op = meshtastic_RemoteShell_OpCode_ERROR, + .session_id = session.sessionId, + .seq = session.nextTxSeq++, + .ack_seq = session.lastAckedRxSeq, + .cols = 0, + .rows = 0, + .flags = 0, + }; + if (message && len > 0) { + assert(len <= sizeof(frame.payload.bytes)); + memcpy(frame.payload.bytes, message, len); + frame.payload.size = len; + } + if (peer == 0) { + peer = session.peer; + } + sendFrameToPeer(peer, frame, true); +} +#endif \ No newline at end of file diff --git a/src/modules/DMShell.h b/src/modules/DMShell.h new file mode 100644 index 00000000000..51e4e710edd --- /dev/null +++ b/src/modules/DMShell.h @@ -0,0 +1,77 @@ +#pragma once + +#include "MeshModule.h" +#include "Router.h" +#include "SinglePortModule.h" +#include "concurrency/OSThread.h" +#include "configuration.h" +#include "mesh/generated/meshtastic/mesh.pb.h" +#include +#include +#include + +#if defined(ARCH_PORTDUINO) + +struct DMShellSession { + bool active = false; + uint32_t sessionId = 0; + NodeNum peer = 0; + uint8_t channel = 0; + int masterFd = -1; + int childPid = -1; + uint32_t nextTxSeq = 1; + uint32_t lastAckedRxSeq = 0; + uint32_t nextExpectedRxSeq = 1; + uint32_t highestSeenRxSeq = 0; + uint32_t lastActivityMs = 0; + struct SentFrame { + bool valid = false; + meshtastic_RemoteShell_OpCode op = meshtastic_RemoteShell_OpCode_ERROR; + uint32_t sessionId = 0; + uint32_t seq = 0; + uint32_t ackSeq = 0; + uint32_t cols = 0; + uint32_t rows = 0; + uint32_t flags = 0; + uint8_t payload[meshtastic_Constants_DATA_PAYLOAD_LEN] = {0}; + size_t payloadLen = 0; + }; + std::array txHistory = {}; + size_t txHistoryNext = 0; +}; + +class DMShellModule : private concurrency::OSThread, public SinglePortModule +{ + + public: + DMShellModule(); + + protected: + virtual ProcessMessage handleReceived(const meshtastic_MeshPacket &mp) override; + virtual int32_t runOnce() override; + + private: + static constexpr uint32_t SESSION_IDLE_TIMEOUT_MS = 5 * 60 * 1000; + + DMShellSession session; + pid_t pendingChildPid = -1; + + bool parseFrame(const meshtastic_MeshPacket &mp, meshtastic_RemoteShell &outFrame); + bool isAuthorizedPacket(const meshtastic_MeshPacket &mp) const; + bool openSession(const meshtastic_MeshPacket &mp, const meshtastic_RemoteShell &frame); + bool shouldProcessIncomingFrame(const meshtastic_RemoteShell &frame); + bool writeSessionInput(const meshtastic_RemoteShell &frame); + void closeSession(const char *reason, bool notifyPeer); + void reapChildIfExited(); + void processPendingChildReap(); + + void rememberSentFrame(meshtastic_RemoteShell frame); + void resendFramesFrom(uint32_t startSeq); + void sendAck(uint32_t replayFromSeq = 0); + void sendFrameToPeer(NodeNum peer, meshtastic_RemoteShell frame, bool remember = true); + void sendError(const char *message, NodeNum peer = 0); +}; + +extern DMShellModule *dmShellModule; + +#endif \ No newline at end of file diff --git a/src/modules/Modules.cpp b/src/modules/Modules.cpp index d3ab9076d33..ea90c1d0348 100644 --- a/src/modules/Modules.cpp +++ b/src/modules/Modules.cpp @@ -49,6 +49,7 @@ #include "modules/WaypointModule.h" #endif #if ARCH_PORTDUINO +#include "modules/DMShell.h" #include "modules/Telemetry/HostMetrics.h" #if !MESHTASTIC_EXCLUDE_STOREFORWARD #include "modules/StoreForwardModule.h" @@ -195,6 +196,7 @@ void setupModules() #endif #if ARCH_PORTDUINO new HostMetricsModule(); + dmShellModule = new DMShellModule(); #endif #if HAS_TELEMETRY new DeviceTelemetryModule(); diff --git a/test/test_meshpacket_serializer/ports/test_dmshell.cpp b/test/test_meshpacket_serializer/ports/test_dmshell.cpp new file mode 100644 index 00000000000..29bc8bef69c --- /dev/null +++ b/test/test_meshpacket_serializer/ports/test_dmshell.cpp @@ -0,0 +1,97 @@ +#include "../test_helpers.h" +#include "mesh/mesh-pb-constants.h" + +namespace +{ +struct BytesDecodeState { + uint8_t *buffer; + size_t capacity; + size_t length; +}; + +struct BytesEncodeState { + const uint8_t *buffer; + size_t length; +}; + +bool decodeBytesField(pb_istream_t *stream, const pb_field_iter_t *field, void **arg) +{ + (void)field; + auto *state = static_cast(*arg); + if (!state) { + return false; + } + + const size_t fieldLen = stream->bytes_left; + if (fieldLen > state->capacity) { + return false; + } + + if (!pb_read(stream, state->buffer, fieldLen)) { + return false; + } + + state->length = fieldLen; + return true; +} + +bool encodeBytesField(pb_ostream_t *stream, const pb_field_iter_t *field, void *const *arg) +{ + auto *state = static_cast(*arg); + if (!state || !state->buffer || state->length == 0) { + return true; + } + + if (!pb_encode_tag_for_field(stream, field)) { + return false; + } + + return pb_encode_string(stream, state->buffer, state->length); +} + +void assert_dmshell_roundtrip(meshtastic_RemoteShell_OpCode op, uint32_t sessionId, uint32_t seq, const uint8_t *payload, + size_t payloadLen, uint32_t cols = 0, uint32_t rows = 0) +{ + meshtastic_RemoteShell tx = meshtastic_RemoteShell_init_zero; + tx.op = op; + tx.session_id = sessionId; + tx.seq = seq; + tx.cols = cols; + tx.rows = rows; + + uint8_t encoded[meshtastic_Constants_DATA_PAYLOAD_LEN] = {0}; + size_t encodedLen = pb_encode_to_bytes(encoded, sizeof(encoded), meshtastic_RemoteShell_fields, &tx); + TEST_ASSERT_GREATER_THAN_UINT32(0, encodedLen); + + meshtastic_RemoteShell rx = meshtastic_RemoteShell_init_zero; + + TEST_ASSERT_TRUE(pb_decode_from_bytes(encoded, encodedLen, meshtastic_RemoteShell_fields, &rx)); + TEST_ASSERT_EQUAL(op, rx.op); + TEST_ASSERT_EQUAL_UINT32(sessionId, rx.session_id); + TEST_ASSERT_EQUAL_UINT32(seq, rx.seq); + TEST_ASSERT_EQUAL_UINT32(cols, rx.cols); + TEST_ASSERT_EQUAL_UINT32(rows, rx.rows); +} +} // namespace + +void test_dmshell_open_roundtrip() +{ + assert_dmshell_roundtrip(meshtastic_RemoteShell_OpCode_OPEN, 0x101, 1, nullptr, 0, 120, 40); +} + +void test_dmshell_input_roundtrip() +{ + const uint8_t payload[] = {'l', 's', '\n'}; + assert_dmshell_roundtrip(meshtastic_RemoteShell_OpCode_INPUT, 0x202, 2, payload, sizeof(payload)); +} + +void test_dmshell_resize_roundtrip() +{ + assert_dmshell_roundtrip(meshtastic_RemoteShell_OpCode_RESIZE, 0x303, 3, nullptr, 0, 180, 55); +} + +void test_dmshell_close_roundtrip() +{ + const uint8_t reason[] = {'b', 'y', 'e'}; + assert_dmshell_roundtrip(meshtastic_RemoteShell_OpCode_CLOSE, 0x404, 4, reason, sizeof(reason)); +} \ No newline at end of file diff --git a/test/test_meshpacket_serializer/test_serializer.cpp b/test/test_meshpacket_serializer/test_serializer.cpp index 484db8d7486..3a19745d824 100644 --- a/test/test_meshpacket_serializer/test_serializer.cpp +++ b/test/test_meshpacket_serializer/test_serializer.cpp @@ -19,6 +19,10 @@ void test_telemetry_environment_metrics_complete_coverage(); void test_telemetry_environment_metrics_unset_fields(); void test_encrypted_packet_serialization(); void test_empty_encrypted_packet(); +void test_dmshell_open_roundtrip(); +void test_dmshell_input_roundtrip(); +void test_dmshell_resize_roundtrip(); +void test_dmshell_close_roundtrip(); void setup() { @@ -52,6 +56,12 @@ void setup() RUN_TEST(test_encrypted_packet_serialization); RUN_TEST(test_empty_encrypted_packet); + // DMShell protobuf transport tests + RUN_TEST(test_dmshell_open_roundtrip); + RUN_TEST(test_dmshell_input_roundtrip); + RUN_TEST(test_dmshell_resize_roundtrip); + RUN_TEST(test_dmshell_close_roundtrip); + UNITY_END(); }