#!/usr/bin/env python3
"""Docker integration: remote SSH reconnect after host restart."""

from __future__ import annotations

import glob
import hashlib
import json
import os
import secrets
import shutil
import socket
import struct
import subprocess
import sys
import tempfile
import time
from base64 import b64encode
from pathlib import Path

# Make the sibling ``cmux`` test client importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent))
from cmux import cmux, cmuxError  # noqa: E402

SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock")
REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43173"))
REMOTE_WS_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_WS_PORT", "43174"))
DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1")
DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1")


def _must(cond: bool, msg: str) -> None:
    """Raise ``cmuxError(msg)`` unless *cond* is true."""
    if not cond:
        raise cmuxError(msg)


def _find_cli_binary() -> str:
    """Return the path of an executable cmux CLI binary.

    Lookup order: the ``CMUXTERM_CLI`` environment variable, the fixed
    ``cmux-tests-v2`` DerivedData location, then the most recently modified
    match among the DerivedData and ``/tmp/cmux-*`` debug build globs.

    Raises:
        cmuxError: if no executable candidate is found.
    """
    env_cli = os.environ.get("CMUXTERM_CLI")
    if env_cli and os.path.isfile(env_cli) and os.access(env_cli, os.X_OK):
        return env_cli

    fixed = os.path.expanduser(
        "~/Library/Developer/Xcode/DerivedData/cmux-tests-v2/Build/Products/Debug/cmux"
    )
    if os.path.isfile(fixed) and os.access(fixed, os.X_OK):
        return fixed

    candidates = glob.glob(
        os.path.expanduser("~/Library/Developer/Xcode/DerivedData/**/Build/Products/Debug/cmux"),
        recursive=True,
    )
    candidates += glob.glob("/tmp/cmux-*/Build/Products/Debug/cmux")
    candidates = [p for p in candidates if os.path.isfile(p) and os.access(p, os.X_OK)]
    if not candidates:
        raise cmuxError("Could not locate cmux CLI binary; set CMUXTERM_CLI")
    # Newest build first.
    candidates.sort(key=os.path.getmtime, reverse=True)
    return candidates[0]


def _run(
    cmd: list[str],
    *,
    env: dict[str, str] | None = None,
    check: bool = True,
) -> subprocess.CompletedProcess[str]:
    """Run *cmd* capturing text output.

    Args:
        cmd: argument vector passed to ``subprocess.run``.
        env: optional environment for the child process.
        check: when true, a non-zero exit status raises ``cmuxError``.

    Returns:
        The completed process (stdout/stderr captured as text).

    Raises:
        cmuxError: if *check* is true and the command exits non-zero.
    """
    proc = subprocess.run(cmd, capture_output=True, text=True, env=env, check=False)
    if check and proc.returncode != 0:
        merged = f"{proc.stdout}\n{proc.stderr}".strip()
        raise cmuxError(f"Command failed ({' '.join(cmd)}): {merged}")
    return proc


def _run_cli_json(cli: str, args: list[str]) -> dict:
    """Run the cmux CLI with ``--json`` against ``SOCKET_PATH`` and parse stdout.

    The ``CMUX_WORKSPACE_ID``, ``CMUX_SURFACE_ID`` and ``CMUX_TAB_ID``
    variables are removed from the child environment before the call.

    Raises:
        cmuxError: if the command fails or stdout is not valid JSON.
    """
    env = dict(os.environ)
    for name in ("CMUX_WORKSPACE_ID", "CMUX_SURFACE_ID", "CMUX_TAB_ID"):
        env.pop(name, None)
    proc = _run([cli, "--socket", SOCKET_PATH, "--json", *args], env=env)
    try:
        return json.loads(proc.stdout or "{}")
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        raise cmuxError(
            f"Invalid JSON output for {' '.join(args)}: {proc.stdout!r} ({exc})"
        ) from exc


def _docker_available() -> bool:
    """Return True when a ``docker`` binary exists and ``docker info`` succeeds."""
    if shutil.which("docker") is None:
        return False
    probe = _run(["docker", "info"], check=False)
    return probe.returncode == 0


def _curl_via_socks(proxy_port: int, target_url: str) -> str:
    """Fetch *target_url* with curl through the SOCKS5 proxy on 127.0.0.1.

    Returns:
        curl's stdout (the response body).

    Raises:
        cmuxError: if curl is not installed or exits non-zero.
    """
    if shutil.which("curl") is None:
        raise cmuxError("curl is required for SOCKS proxy verification")
    proc = _run(
        [
            "curl",
            "--silent",
            "--show-error",
            "--max-time",
            "5",
            "--socks5-hostname",
            f"127.0.0.1:{proxy_port}",
            target_url,
        ],
        check=False,
    )
    if proc.returncode != 0:
        merged = f"{proc.stdout}\n{proc.stderr}".strip()
        raise cmuxError(f"curl via SOCKS proxy failed: {merged}")
    return proc.stdout


def _find_free_loopback_port() -> int:
    """Return a TCP port that was free on 127.0.0.1 at the time of the call."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return int(sock.getsockname()[1])


def _close_quietly(sock: socket.socket) -> None:
    """Close *sock*, ignoring OS-level errors (best-effort cleanup)."""
    try:
        sock.close()
    except OSError:
        pass


def _coerce_port(value: object) -> int | None:
    """Return *value* as an int port when it is an int or a digit string, else None."""
    if isinstance(value, int):
        return int(value)
    if isinstance(value, str) and value.isdigit():
        return int(value)
    return None


def _http_status_code(status_line: str) -> str:
    """Return the status-code token of an HTTP status line, or ``""`` if absent."""
    parts = status_line.split()
    return parts[1] if len(parts) >= 2 else ""


def _recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly *n* bytes from *sock*.

    Raises:
        cmuxError: if the peer closes the connection before *n* bytes arrive.
    """
    out = bytearray()
    while len(out) < n:
        chunk = sock.recv(n - len(out))
        if not chunk:
            raise cmuxError("unexpected EOF while reading socket")
        out.extend(chunk)
    return bytes(out)


def _recv_until(sock: socket.socket, marker: bytes, limit: int = 16384) -> bytes:
    """Read from *sock* until *marker* appears in the accumulated data.

    The returned bytes may extend past the marker, because data is read in
    1024-byte chunks.

    Raises:
        cmuxError: on EOF before the marker, or once more than *limit* bytes
            have been accumulated.
    """
    out = bytearray()
    while marker not in out:
        chunk = sock.recv(1024)
        if not chunk:
            raise cmuxError("unexpected EOF while reading response headers")
        out.extend(chunk)
        if len(out) > limit:
            raise cmuxError("response headers too large")
    return bytes(out)


def _read_socks5_connect_reply(sock: socket.socket) -> None:
    """Consume and validate a SOCKS5 CONNECT reply from *sock*.

    Raises:
        cmuxError: if the version byte is not 0x05, the status byte is not
            0x00, or the address type is unknown.
    """
    head = _recv_exact(sock, 4)
    if len(head) != 4 or head[0] != 0x05:
        raise cmuxError(f"invalid SOCKS5 reply: {head!r}")
    if head[1] != 0x00:
        raise cmuxError(f"SOCKS5 connect failed with status=0x{head[1]:02x}")

    # Skip the bound address; its length depends on the address type.
    reply_atyp = head[3]
    if reply_atyp == 0x01:
        _recv_exact(sock, 4)
    elif reply_atyp == 0x03:
        ln = _recv_exact(sock, 1)[0]
        _recv_exact(sock, ln)
    elif reply_atyp == 0x04:
        _recv_exact(sock, 16)
    else:
        raise cmuxError(f"invalid SOCKS5 atyp in reply: 0x{reply_atyp:02x}")
    # Bound port.
    _recv_exact(sock, 2)


def _read_http_response_from_connected_socket(sock: socket.socket) -> str:
    """Read one HTTP response from *sock* and return its body as text.

    With a parseable ``Content-Length`` header the body is read up to that
    length (or EOF). Otherwise the body is read until EOF or a socket timeout.

    Raises:
        cmuxError: if the status code is not 200, or the headers cannot be read.
    """
    response = _recv_until(sock, b"\r\n\r\n")
    header_end = response.index(b"\r\n\r\n") + 4
    header_text = response[:header_end].decode("utf-8", errors="replace")
    body = bytearray(response[header_end:])

    header_lines = header_text.split("\r\n")
    status_line = header_lines[0]
    # Compare the status-code token rather than searching the whole line, so a
    # "200" inside the reason phrase of an error response does not pass.
    if _http_status_code(status_line) != "200":
        raise cmuxError(f"HTTP over SOCKS tunnel failed: {status_line!r}")

    content_length: int | None = None
    for line in header_lines[1:]:
        if line.lower().startswith("content-length:"):
            try:
                content_length = int(line.split(":", 1)[1].strip())
            except ValueError:
                content_length = None
            break

    if content_length is not None:
        while len(body) < content_length:
            chunk = sock.recv(4096)
            if not chunk:
                break
            body.extend(chunk)
    else:
        while True:
            try:
                chunk = sock.recv(4096)
            except socket.timeout:
                break
            if not chunk:
                break
            body.extend(chunk)

    return bytes(body).decode("utf-8", errors="replace")


def _socks5_address(target_host: str) -> bytes:
    """Encode *target_host* as the SOCKS5 ``ATYP + address`` field.

    Hosts accepted by ``socket.inet_aton`` use the IPv4 form (0x01); anything
    else is sent as a length-prefixed domain name (0x03).

    Raises:
        cmuxError: if the UTF-8 encoded host exceeds 255 bytes.
    """
    try:
        return b"\x01" + socket.inet_aton(target_host)
    except OSError:
        pass
    host_encoded = target_host.encode("utf-8")
    if len(host_encoded) > 255:
        raise cmuxError("target host too long for SOCKS5 domain form")
    return b"\x03" + bytes([len(host_encoded)]) + host_encoded


def _socks5_connect(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
    """Open a no-auth SOCKS5 CONNECT tunnel and return the connected socket.

    The caller owns the returned socket and must close it. On any failure
    during the handshake the socket is closed before the error propagates.

    Raises:
        cmuxError: if the greeting or CONNECT reply is rejected or malformed.
    """
    sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
    try:
        sock.settimeout(6)
        # Greeting: version 5, one method offered, "no authentication".
        sock.sendall(b"\x05\x01\x00")
        greeting = _recv_exact(sock, 2)
        if greeting != b"\x05\x00":
            raise cmuxError(f"SOCKS5 greeting failed: {greeting!r}")
        req = b"\x05\x01\x00" + _socks5_address(target_host) + struct.pack("!H", target_port)
        sock.sendall(req)
        _read_socks5_connect_reply(sock)
    except BaseException:
        # Do not leak the socket when the handshake fails part-way through.
        _close_quietly(sock)
        raise
    return sock


def _socks5_http_get_pipelined(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> str:
    """Send SOCKS5 greeting, CONNECT and an HTTP GET in one write; return the body.

    This exercises a proxy that receives the tunnelled payload before it has
    replied to the handshake.

    Raises:
        cmuxError: if the SOCKS5 handshake or the HTTP response is rejected.
    """
    sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
    try:
        sock.settimeout(6)
        greeting = b"\x05\x01\x00"
        connect_req = b"\x05\x01\x00" + _socks5_address(target_host) + struct.pack("!H", target_port)
        http_get = (
            "GET / HTTP/1.1\r\n"
            f"Host: {target_host}:{target_port}\r\n"
            "Connection: close\r\n"
            "\r\n"
        ).encode("utf-8")
        sock.sendall(greeting + connect_req + http_get)

        greeting_reply = _recv_exact(sock, 2)
        if greeting_reply != b"\x05\x00":
            raise cmuxError(f"SOCKS5 greeting failed: {greeting_reply!r}")
        _read_socks5_connect_reply(sock)
        return _read_http_response_from_connected_socket(sock)
    finally:
        _close_quietly(sock)


def _http_connect_tunnel(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
    """Open an HTTP CONNECT tunnel through the proxy and return the socket.

    The caller owns the returned socket and must close it. On any failure the
    socket is closed before the error propagates.

    Raises:
        cmuxError: if the proxy does not answer with status 200.
    """
    sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
    try:
        sock.settimeout(6)
        request = (
            f"CONNECT {target_host}:{target_port} HTTP/1.1\r\n"
            f"Host: {target_host}:{target_port}\r\n"
            "Proxy-Connection: Keep-Alive\r\n"
            "\r\n"
        ).encode("utf-8")
        sock.sendall(request)
        header_blob = _recv_until(sock, b"\r\n\r\n")
        header_text = header_blob.decode("utf-8", errors="replace")
        status_line = header_text.split("\r\n", 1)[0]
        if _http_status_code(status_line) != "200":
            raise cmuxError(f"HTTP CONNECT tunnel failed: {status_line!r}")
    except BaseException:
        # Do not leak the socket when the tunnel cannot be established.
        _close_quietly(sock)
        raise
    return sock


def _encode_client_text_frame(payload: str) -> bytes:
    """Encode *payload* as a single masked WebSocket text frame (FIN set)."""
    data = payload.encode("utf-8")
    first = 0x81  # FIN + text opcode
    mask = secrets.token_bytes(4)
    length = len(data)
    if length < 126:
        header = bytes([first, 0x80 | length])
    elif length <= 0xFFFF:
        header = bytes([first, 0x80 | 126]) + struct.pack("!H", length)
    else:
        header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length)
    masked = bytes(b ^ mask[i % 4] for i, b in enumerate(data))
    return header + mask + masked


def _read_server_text_frame(sock: socket.socket) -> str:
    """Read one WebSocket frame from *sock* and return its text payload.

    Raises:
        cmuxError: if the frame's opcode is not 0x1 (text) or the payload is
            not valid UTF-8.
    """
    first, second = _recv_exact(sock, 2)
    opcode = first & 0x0F
    masked = (second & 0x80) != 0
    length = second & 0x7F
    if length == 126:
        length = struct.unpack("!H", _recv_exact(sock, 2))[0]
    elif length == 127:
        length = struct.unpack("!Q", _recv_exact(sock, 8))[0]

    mask = _recv_exact(sock, 4) if masked else b""
    payload = _recv_exact(sock, length) if length else b""
    if masked and payload:
        payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))

    if opcode != 0x1:
        raise cmuxError(f"Expected websocket text frame opcode=0x1, got opcode=0x{opcode:x}")
    try:
        return payload.decode("utf-8")
    except UnicodeDecodeError as exc:
        raise cmuxError(f"WebSocket response payload is not valid UTF-8: {exc}") from exc


def _websocket_echo_on_connected_socket(
    sock: socket.socket,
    ws_host: str,
    ws_port: int,
    message: str,
    path_label: str,
) -> str:
    """Perform a WebSocket handshake on *sock*, send *message*, return the reply.

    Args:
        sock: socket already tunnelled to the WebSocket server.
        ws_host: host used in the ``Host`` header.
        ws_port: port used in the ``Host`` header.
        message: text sent as a single client frame.
        path_label: proxy description used in error messages.

    Raises:
        cmuxError: if the server does not answer 101, returns a wrong
            ``Sec-WebSocket-Accept``, or replies with a non-text frame.
    """
    ws_key = b64encode(secrets.token_bytes(16)).decode("ascii")
    request = (
        "GET /echo HTTP/1.1\r\n"
        f"Host: {ws_host}:{ws_port}\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        f"Sec-WebSocket-Key: {ws_key}\r\n"
        "Sec-WebSocket-Version: 13\r\n"
        "\r\n"
    ).encode("utf-8")
    sock.sendall(request)

    header_blob = _recv_until(sock, b"\r\n\r\n")
    header_text = header_blob.decode("utf-8", errors="replace")
    header_lines = header_text.split("\r\n")
    status_line = header_lines[0]
    if _http_status_code(status_line) != "101":
        raise cmuxError(f"WebSocket handshake failed over {path_label}: {status_line!r}")

    # The accept token is SHA-1(key + fixed WebSocket GUID), base64-encoded.
    expected_accept = b64encode(
        hashlib.sha1((ws_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")).digest()
    ).decode("ascii")
    lowered_headers: dict[str, str] = {}
    for line in header_lines[1:]:
        if ":" in line:
            name, _, value = line.partition(":")
            lowered_headers[name.strip().lower()] = value.strip()
    if lowered_headers.get("sec-websocket-accept", "") != expected_accept:
        raise cmuxError(f"WebSocket handshake over {path_label} returned invalid Sec-WebSocket-Accept")

    sock.sendall(_encode_client_text_frame(message))
    return _read_server_text_frame(sock)


def _websocket_echo_via_socks(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
    """Echo *message* over WebSocket through the local SOCKS5 proxy."""
    sock = _socks5_connect("127.0.0.1", proxy_port, ws_host, ws_port)
    try:
        return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "SOCKS proxy")
    finally:
        _close_quietly(sock)


def _websocket_echo_via_connect(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
    """Echo *message* over WebSocket through the local HTTP CONNECT proxy."""
    sock = _http_connect_tunnel("127.0.0.1", proxy_port, ws_host, ws_port)
    try:
        return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "HTTP CONNECT proxy")
    finally:
        _close_quietly(sock)


def _start_container(image_tag: str, container_name: str, pubkey: str, host_ssh_port: int) -> None:
    """Start the SSH test container, publishing its port 22 on *host_ssh_port*.

    ``docker run`` is retried up to 20 times, 0.5 s apart.

    Raises:
        cmuxError: if every attempt fails; the message carries the last output.
    """
    for _ in range(20):
        proc = _run(
            [
                "docker",
                "run",
                "-d",
                "--rm",
                "--name",
                container_name,
                "-e",
                f"AUTHORIZED_KEY={pubkey}",
                "-e",
                f"REMOTE_HTTP_PORT={REMOTE_HTTP_PORT}",
                "-e",
                f"REMOTE_WS_PORT={REMOTE_WS_PORT}",
                "-p",
                f"{DOCKER_PUBLISH_ADDR}:{host_ssh_port}:22",
                image_tag,
            ],
            check=False,
        )
        if proc.returncode == 0:
            return
        time.sleep(0.5)
    merged = f"{proc.stdout}\n{proc.stderr}".strip()
    raise cmuxError(f"Failed to start ssh test container on fixed port {host_ssh_port}: {merged}")


def _wait_remote_connected(client: cmux, workspace_id: str, timeout: float) -> dict:
    """Poll ``workspace.remote.status`` until the remote is connected with a proxy port.

    Returns:
        The status payload that satisfied the condition.

    Raises:
        cmuxError: if the condition is not met within *timeout* seconds.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    last_status: dict = {}
    while time.monotonic() < deadline:
        last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
        remote = last_status.get("remote") or {}
        proxy = remote.get("proxy") or {}
        proxy_port = _coerce_port(proxy.get("port"))
        if str(remote.get("state") or "") == "connected" and proxy_port is not None:
            return last_status
        time.sleep(0.5)
    raise cmuxError(f"Remote did not reach connected+proxy-ready state: {last_status}")


def _wait_remote_degraded(client: cmux, workspace_id: str, timeout: float) -> dict:
    """Poll ``workspace.remote.status`` until the state is error/connecting/disconnected.

    Returns:
        The status payload that satisfied the condition.

    Raises:
        cmuxError: if the condition is not met within *timeout* seconds.
    """
    deadline = time.monotonic() + timeout
    last_status: dict = {}
    while time.monotonic() < deadline:
        last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
        remote = last_status.get("remote") or {}
        state = str(remote.get("state") or "")
        if state in {"error", "connecting", "disconnected"}:
            return last_status
        time.sleep(0.5)
    raise cmuxError(f"Remote did not enter reconnecting/degraded state: {last_status}")


def _resolve_workspace_id(client: cmux, payload: dict) -> str:
    """Return the workspace id from the ``cmux ssh`` JSON payload.

    Falls back to looking up ``workspace_ref`` in ``workspace.list`` when the
    payload carries no ``workspace_id``.

    Raises:
        cmuxError: if no workspace id can be determined.
    """
    workspace_id = str(payload.get("workspace_id") or "")
    workspace_ref = str(payload.get("workspace_ref") or "")
    if not workspace_id and workspace_ref.startswith("workspace:"):
        listed = client._call("workspace.list", {}) or {}
        for row in listed.get("workspaces") or []:
            if str(row.get("ref") or "") == workspace_ref:
                workspace_id = str(row.get("id") or "")
                break
    _must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}")
    return workspace_id


def _require_ready_daemon(status: dict, *, reconnected: bool) -> int:
    """Validate daemon state and capabilities in *status*; return the proxy port.

    Args:
        status: payload returned by ``workspace.remote.status``.
        reconnected: selects the wording of failure messages (first connect
            versus after reconnect).

    Raises:
        cmuxError: if the daemon is not ready, lacks a proxy capability, or
            the status has no usable proxy port.
    """
    remote = status.get("remote") or {}
    daemon = remote.get("daemon") or {}
    when = "after reconnect" if reconnected else "after first connect"
    _must(str(daemon.get("state") or "") == "ready", f"daemon should be ready {when}: {status}")

    capabilities = {str(item) for item in (daemon.get("capabilities") or [])}
    suffix = " after reconnect" if reconnected else ""
    for capability in ("proxy.stream", "proxy.socks5", "proxy.http_connect"):
        _must(capability in capabilities, f"daemon should advertise {capability}{suffix}: {status}")

    proxy = remote.get("proxy") or {}
    proxy_port = _coerce_port(proxy.get("port"))
    label = "reconnected" if reconnected else "connected"
    if proxy_port is None:
        raise cmuxError(f"{label} status should include proxy port: {status}")
    return proxy_port


def _fetch_forwarded_http(proxy_port: int, timeout: float) -> str:
    """Poll the forwarded HTTP endpoint via curl/SOCKS until the marker appears.

    Returns:
        The last body fetched; the caller checks it for the marker. Empty if
        every attempt failed.
    """
    body = ""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            body = _curl_via_socks(proxy_port, f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
        except Exception:  # noqa: BLE001 - deliberate retry; proxy may not be ready yet
            time.sleep(0.5)
            continue
        if "cmux-ssh-forward-ok" in body:
            break
        time.sleep(0.3)
    return body


def _verify_proxy_egress(proxy_port: int, phase: str) -> None:
    """Check HTTP and WebSocket egress through the proxy on *proxy_port*.

    Args:
        proxy_port: local proxy port reported by the remote status.
        phase: ``"before"`` or ``"after"``; used in payloads and messages.

    Raises:
        cmuxError: if any of the four egress checks fails.
    """
    body = _fetch_forwarded_http(proxy_port, timeout=15.0)
    _must(
        "cmux-ssh-forward-ok" in body,
        f"Forwarded HTTP endpoint failed {phase} reconnect: {body[:120]!r}",
    )

    pipelined_body = _socks5_http_get_pipelined("127.0.0.1", proxy_port, "127.0.0.1", REMOTE_HTTP_PORT)
    _must(
        "cmux-ssh-forward-ok" in pipelined_body,
        f"SOCKS pipelined greeting/connect+payload failed {phase} reconnect: {pipelined_body[:120]!r}",
    )

    socks_message = f"cmux-reconnect-{phase}-over-socks"
    echoed_socks = _websocket_echo_via_socks(proxy_port, "127.0.0.1", REMOTE_WS_PORT, socks_message)
    _must(
        echoed_socks == socks_message,
        f"WebSocket echo over SOCKS proxy failed {phase} reconnect: {echoed_socks!r} != {socks_message!r}",
    )

    connect_message = f"cmux-reconnect-{phase}-over-connect"
    echoed_connect = _websocket_echo_via_connect(proxy_port, "127.0.0.1", REMOTE_WS_PORT, connect_message)
    _must(
        echoed_connect == connect_message,
        f"WebSocket echo over CONNECT proxy failed {phase} reconnect: {echoed_connect!r} != {connect_message!r}",
    )


def main() -> int:
    """Run the reconnect scenario and return a process exit code.

    Builds and starts the SSH fixture container, opens a remote workspace via
    the cmux CLI, verifies proxy egress, removes and restarts the container on
    the same port, then verifies egress again. Returns 0 on success, and also
    when docker is unavailable (the test is skipped). Failures propagate as
    exceptions. The workspace, container, image and temp directory are cleaned
    up on exit.
    """
    if not _docker_available():
        print("SKIP: docker is not available")
        return 0

    cli = _find_cli_binary()
    repo_root = Path(__file__).resolve().parents[1]
    fixture_dir = repo_root / "tests" / "fixtures" / "ssh-remote"
    _must(fixture_dir.is_dir(), f"Missing docker fixture directory: {fixture_dir}")

    # Probe the port before creating the temp dir so a probe failure leaves
    # nothing behind to clean up.
    host_ssh_port = _find_free_loopback_port()
    image_tag = f"cmux-ssh-test:{secrets.token_hex(4)}"
    container_name = f"cmux-ssh-reconnect-{secrets.token_hex(4)}"
    temp_dir = Path(tempfile.mkdtemp(prefix="cmux-ssh-reconnect-"))

    workspace_id = ""
    container_running = False
    try:
        key_path = temp_dir / "id_ed25519"
        _run(["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_path)])
        pubkey = key_path.with_suffix(".pub").read_text(encoding="utf-8").strip()
        _must(bool(pubkey), "Generated SSH public key was empty")

        _run(["docker", "build", "-t", image_tag, str(fixture_dir)])
        _start_container(image_tag, container_name, pubkey, host_ssh_port)
        container_running = True

        with cmux(SOCKET_PATH) as client:
            payload = _run_cli_json(
                cli,
                [
                    "ssh",
                    f"root@{DOCKER_SSH_HOST}",
                    "--name",
                    "docker-ssh-reconnect",
                    "--port",
                    str(host_ssh_port),
                    "--identity",
                    str(key_path),
                    "--ssh-option",
                    "UserKnownHostsFile=/dev/null",
                    "--ssh-option",
                    "StrictHostKeyChecking=no",
                ],
            )
            workspace_id = _resolve_workspace_id(client, payload)

            # Phase 1: initial connection.
            first_status = _wait_remote_connected(client, workspace_id, timeout=45.0)
            first_proxy_port = _require_ready_daemon(first_status, reconnected=False)
            _verify_proxy_egress(first_proxy_port, "before")

            # Simulate a host restart: remove the container, wait for the
            # remote to degrade, then bring it back on the same port.
            _run(["docker", "rm", "-f", container_name], check=False)
            container_running = False
            _wait_remote_degraded(client, workspace_id, timeout=20.0)

            _start_container(image_tag, container_name, pubkey, host_ssh_port)
            container_running = True

            # Phase 2: after reconnect.
            second_status = _wait_remote_connected(client, workspace_id, timeout=60.0)
            second_proxy_port = _require_ready_daemon(second_status, reconnected=True)
            _verify_proxy_egress(second_proxy_port, "after")

            try:
                client.close_workspace(workspace_id)
            except Exception:  # noqa: BLE001 - best-effort cleanup
                pass
            workspace_id = ""

        print("PASS: docker SSH remote reconnects and re-establishes HTTP + WebSocket egress over SOCKS and CONNECT")
        return 0
    finally:
        if workspace_id:
            # The workspace was not closed on the success path; try again
            # with a fresh client connection.
            try:
                with cmux(SOCKET_PATH) as cleanup_client:
                    cleanup_client.close_workspace(workspace_id)
            except Exception:  # noqa: BLE001 - best-effort cleanup
                pass
        if container_running:
            _run(["docker", "rm", "-f", container_name], check=False)
        _run(["docker", "rmi", "-f", image_tag], check=False)
        shutil.rmtree(temp_dir, ignore_errors=True)


if __name__ == "__main__":
    raise SystemExit(main())