From 76835662f582869701ba82dd42b70d314ccd2ae9 Mon Sep 17 00:00:00 2001
From: Austin Wang
Date: Wed, 4 Mar 2026 18:03:25 -0800
Subject: [PATCH] Fix CLI socket autodiscovery for tagged cmux sockets (#832)

---
 CLI/cmux.swift                         | 299 ++++++++++++++++++++++---
 tests/test_cli_socket_autodiscovery.py | 150 +++++++++++++
 2 files changed, 414 insertions(+), 35 deletions(-)
 create mode 100755 tests/test_cli_socket_autodiscovery.py

diff --git a/CLI/cmux.swift b/CLI/cmux.swift
index 17a13f91..8d218262 100644
--- a/CLI/cmux.swift
+++ b/CLI/cmux.swift
@@ -451,9 +451,183 @@ private enum SocketPasswordResolver {
     }
 }
 
+/// Records where the requested socket path came from; only `.implicitDefault` permits auto-discovery.
+private enum CLISocketPathSource {
+    case explicitFlag
+    case environment
+    case implicitDefault
+}
+
+/// Chooses which Unix socket the CLI connects to when the caller did not pin a path.
+private enum CLISocketPathResolver {
+    static let defaultSocketPath = "/tmp/cmux.sock"
+    private static let fallbackSocketPath = "/tmp/cmux-debug.sock"
+    private static let stagingSocketPath = "/tmp/cmux-staging.sock"
+    private static let lastSocketPathFile = "/tmp/cmux-last-socket-path"
+
+    /// Resolves the socket path the CLI should use.
+    ///
+    /// - Parameters:
+    ///   - requestedPath: Path taken from `--socket`, `CMUX_SOCKET_PATH`, or the built-in default.
+    ///   - source: Origin of `requestedPath`; any value other than `.implicitDefault` returns it unchanged.
+    ///   - environment: Environment consulted for `CMUX_TAG`.
+    /// - Returns: The first candidate that accepts a connection, otherwise the first candidate that is a
+    ///   socket file, otherwise `requestedPath`.
+    static func resolve(
+        requestedPath: String,
+        source: CLISocketPathSource,
+        environment: [String: String] = ProcessInfo.processInfo.environment
+    ) -> String {
+        guard source == .implicitDefault else {
+            return requestedPath
+        }
+
+        let candidates = dedupe(candidatePaths(requestedPath: requestedPath, environment: environment))
+
+        // Prefer sockets that are currently accepting connections.
+        for path in candidates where canConnect(to: path) {
+            return path
+        }
+
+        // If the listener is still starting, prefer existing socket files.
+        for path in candidates where isSocketFile(path) {
+            return path
+        }
+
+        return requestedPath
+    }
+
+    /// Builds the candidate list in priority order; the caller removes duplicates.
+    private static func candidatePaths(requestedPath: String, environment: [String: String]) -> [String] {
+        var candidates: [String] = []
+
+        if let tag = normalized(environment["CMUX_TAG"]) {
+            let slug = sanitizeTagSlug(tag)
+            candidates.append("/tmp/cmux-debug-\(slug).sock")
+            candidates.append("/tmp/cmux-\(slug).sock")
+        }
+
+        candidates.append(requestedPath)
+        candidates.append(fallbackSocketPath)
+        candidates.append(stagingSocketPath)
+        candidates.append(contentsOf: discoverTaggedSockets(limit: 12))
+        if let last = readLastSocketPath() {
+            candidates.append(last)
+        }
+        return candidates
+    }
+
+    /// Returns the trimmed contents of `lastSocketPathFile`, or `nil` when it is unreadable or blank.
+    private static func readLastSocketPath() -> String? {
+        guard let data = try? String(contentsOfFile: lastSocketPathFile, encoding: .utf8) else {
+            return nil
+        }
+        return normalized(data)
+    }
+
+    /// Lists up to `limit` sockets named `/tmp/cmux*.sock`, most recently modified first, skipping the
+    /// default, debug, and staging paths because they are already explicit candidates.
+    private static func discoverTaggedSockets(limit: Int) -> [String] {
+        guard let entries = try? FileManager.default.contentsOfDirectory(atPath: "/tmp") else {
+            return []
+        }
+
+        var discovered: [(path: String, mtime: TimeInterval)] = []
+        discovered.reserveCapacity(min(limit, entries.count))
+        for name in entries where name.hasPrefix("cmux") && name.hasSuffix(".sock") {
+            let path = "/tmp/\(name)"
+            var st = stat()
+            guard lstat(path, &st) == 0 else { continue }
+            guard (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) else { continue }
+            if path == defaultSocketPath || path == fallbackSocketPath || path == stagingSocketPath {
+                continue
+            }
+            let modified = TimeInterval(st.st_mtimespec.tv_sec) + TimeInterval(st.st_mtimespec.tv_nsec) / 1_000_000_000
+            discovered.append((path: path, mtime: modified))
+        }
+
+        discovered.sort { $0.mtime > $1.mtime }
+        return discovered.prefix(limit).map(\.path)
+    }
+
+    /// Whether `path` is a Unix socket owned by the current user. Sockets owned by other users are
+    /// rejected here because `SocketClient.connect()` refuses them anyway.
+    private static func isSocketFile(_ path: String) -> Bool {
+        var st = stat()
+        return lstat(path, &st) == 0
+            && (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK)
+            && st.st_uid == getuid()
+    }
+
+    /// Whether a stream connection to `path` succeeds right now; the probe socket is closed on return.
+    private static func canConnect(to path: String) -> Bool {
+        guard isSocketFile(path) else { return false }
+        let fd = socket(AF_UNIX, SOCK_STREAM, 0)
+        guard fd >= 0 else { return false }
+        defer { Darwin.close(fd) }
+
+        var addr = sockaddr_un()
+        addr.sun_family = sa_family_t(AF_UNIX)
+        let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
+        path.withCString { ptr in
+            withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
+                let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
+                strncpy(buf, ptr, maxLength - 1)
+            }
+        }
+
+        let result = withUnsafePointer(to: &addr) { ptr in
+            ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
+                Darwin.connect(fd, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
+            }
+        }
+        return result == 0
+    }
+
+    /// Lowercases `raw`, turns each run of characters outside `[a-z0-9]` into one dash, and trims
+    /// leading/trailing dashes; falls back to "agent" when nothing is left.
+    private static func sanitizeTagSlug(_ raw: String) -> String {
+        let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
+        let slug = trimmed
+            .replacingOccurrences(of: "[^a-z0-9]+", with: "-", options: .regularExpression)
+            .replacingOccurrences(of: "-+", with: "-", options: .regularExpression)
+            .trimmingCharacters(in: CharacterSet(charactersIn: "-"))
+        return slug.isEmpty ? "agent" : slug
+    }
+
+    /// Trims whitespace and newlines; returns `nil` for `nil` or blank input.
+    private static func normalized(_ value: String?) -> String? {
+        guard let value else { return nil }
+        let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines)
+        return trimmed.isEmpty ? nil : trimmed
+    }
+
+    /// Drops empty strings and repeated paths while keeping first-seen order.
+    private static func dedupe(_ paths: [String]) -> [String] {
+        var seen: Set<String> = []
+        var ordered: [String] = []
+        ordered.reserveCapacity(paths.count)
+        for path in paths where !path.isEmpty {
+            if seen.insert(path).inserted {
+                ordered.append(path)
+            }
+        }
+        return ordered
+    }
+}
+
 final class SocketClient {
     private let path: String
     private var socketFD: Int32 = -1
+    private static let connectRetryWindowSeconds: TimeInterval = 2.0
+    private static let connectRetryIntervalSeconds: TimeInterval = 0.1
+    /// errno values from a failed `connect` that are retried until the retry window closes.
+    private static let retriableConnectErrnos: Set<Int32> = [
+        ENOENT,
+        ECONNREFUSED,
+        EAGAIN,
+        EINTR
+    ]
     private static let defaultResponseTimeoutSeconds: TimeInterval = 15.0
     private static let responseTimeoutSeconds: TimeInterval = {
         let env = ProcessInfo.processInfo.environment
@@ -472,40 +646,63 @@ final class SocketClient {
     func connect() throws {
         if socketFD >= 0 { return }
 
-        // Verify socket is owned by the current user to prevent fake-socket attacks
-        var st = stat()
-        guard stat(path, &st) == 0 else {
-            throw CLIError(message: "Socket not found at \(path)")
-        }
-        guard st.st_uid == getuid() else {
-            throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
-        }
+        // Keep retrying for a short window so a listener that is still starting is not reported missing.
+        let deadline = Date().addingTimeInterval(Self.connectRetryWindowSeconds)
 
-        socketFD = socket(AF_UNIX, SOCK_STREAM, 0)
-        if socketFD < 0 {
-            throw CLIError(message: "Failed to create socket")
-        }
-
-        var addr = sockaddr_un()
-        addr.sun_family = sa_family_t(AF_UNIX)
-        let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
-        path.withCString { ptr in
-            withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
-                let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
-                strncpy(buf, ptr, maxLength - 1)
+        while true {
+            // Verify socket is owned by the current user to prevent fake-socket attacks.
+            var st = stat()
+            guard stat(path, &st) == 0 else {
+                // Read errno before any other call can overwrite it.
+                let statErrno = errno
+                if statErrno == ENOENT, Date() < deadline {
+                    Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
+                    continue
+                }
+                throw CLIError(message: "Socket not found at \(path)")
             }
-        }
-
-        let result = withUnsafePointer(to: &addr) { ptr in
-            ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
-                Darwin.connect(socketFD, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
+            guard (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) else {
+                throw CLIError(message: "Path exists at \(path) but is not a Unix socket")
             }
-        }
-        if result != 0 {
+            guard st.st_uid == getuid() else {
+                throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
+            }
+
+            socketFD = socket(AF_UNIX, SOCK_STREAM, 0)
+            if socketFD < 0 {
+                throw CLIError(message: "Failed to create socket")
+            }
+
+            var addr = sockaddr_un()
+            addr.sun_family = sa_family_t(AF_UNIX)
+            let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
+            path.withCString { ptr in
+                withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
+                    let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
+                    strncpy(buf, ptr, maxLength - 1)
+                }
+            }
+
+            let result = withUnsafePointer(to: &addr) { ptr in
+                ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
+                    Darwin.connect(socketFD, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
+                }
+            }
+            if result == 0 {
+                return
+            }
+
+            let connectErrno = errno
             Darwin.close(socketFD)
             socketFD = -1
+
+            // Only transient failures are retried, and only until the deadline passes.
+            if Self.retriableConnectErrnos.contains(connectErrno), Date() < deadline {
+                Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
+                continue
+            }
             throw CLIError(message: "Failed to connect to socket at \(path)")
         }
     }
 
     func close() {
@@ -614,7 +811,20 @@ struct CMUXCLI {
     let args: [String]
 
     func run() throws {
-        var socketPath = ProcessInfo.processInfo.environment["CMUX_SOCKET_PATH"] ?? "/tmp/cmux.sock"
+        let processEnv = ProcessInfo.processInfo.environment
+        let envSocketPath: String? = {
+            guard let raw = processEnv["CMUX_SOCKET_PATH"] else { return nil }
+            let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines)
+            return trimmed.isEmpty ? nil : trimmed
+        }()
+        var socketPath = envSocketPath ?? CLISocketPathResolver.defaultSocketPath
+        // An environment value equal to the default path still counts as implicit, so auto-discovery applies.
+        var socketPathSource: CLISocketPathSource
+        if let envSocketPath {
+            socketPathSource = envSocketPath == CLISocketPathResolver.defaultSocketPath ? .implicitDefault : .environment
+        } else {
+            socketPathSource = .implicitDefault
+        }
         var jsonOutput = false
         var idFormatArg: String? = nil
         var windowId: String? = nil
@@ -628,6 +838,7 @@ struct CMUXCLI {
                     throw CLIError(message: "--socket requires a path")
                 }
                 socketPath = args[index + 1]
+                socketPathSource = .explicitFlag
                 index += 2
                 continue
             }
@@ -682,7 +893,12 @@ struct CMUXCLI {
             command: command,
             commandArgs: commandArgs,
             socketPath: socketPath,
-            processEnv: ProcessInfo.processInfo.environment
+            processEnv: processEnv
+        )
+        let resolvedSocketPath = CLISocketPathResolver.resolve(
+            requestedPath: socketPath,
+            source: socketPathSource,
+            environment: processEnv
         )
 
         if command == "version" {
@@ -692,7 +908,7 @@ struct CMUXCLI {
 
         // If the argument looks like a path (not a known command), open a workspace there.
         if looksLikePath(command) {
-            try openPath(command, socketPath: socketPath)
+            try openPath(command, socketPath: resolvedSocketPath)
             return
         }
 
@@ -706,16 +922,28 @@ struct CMUXCLI {
             return
         }
 
-        let client = SocketClient(path: socketPath)
+        let client = SocketClient(path: resolvedSocketPath)
+        if resolvedSocketPath != socketPath {
+            cliTelemetry.breadcrumb(
+                "socket.path.autodiscovered",
+                data: [
+                    "requested_path": socketPath,
+                    "resolved_path": resolvedSocketPath
+                ]
+            )
+        }
         cliTelemetry.breadcrumb(
             "socket.connect.attempt",
-            data: ["command": command]
+            data: [
+                "command": command,
+                "path": resolvedSocketPath
+            ]
         )
         do {
             try client.connect()
-            cliTelemetry.breadcrumb("socket.connect.success")
+            cliTelemetry.breadcrumb("socket.connect.success", data: ["path": resolvedSocketPath])
         } catch {
-            cliTelemetry.breadcrumb("socket.connect.failure")
+            cliTelemetry.breadcrumb("socket.connect.failure", data: ["path": resolvedSocketPath])
             cliTelemetry.captureError(stage: "socket_connect", error: error)
             throw error
         }
@@ -6613,7 +6841,8 @@ struct CMUXCLI {
                                     ALL commands (send, list-panels, new-split, notify, etc.).
           CMUX_TAB_ID               Optional alias used by `tab-action`/`rename-tab` as default --tab.
           CMUX_SURFACE_ID           Auto-set in cmux terminals. Used as default --surface.
-          CMUX_SOCKET_PATH          Override the default Unix socket path (/tmp/cmux.sock).
+          CMUX_SOCKET_PATH          Override the Unix socket path. Without this, the CLI defaults
+                                    to /tmp/cmux.sock and auto-discovers tagged/debug sockets.
           CMUX_CLI_SENTRY_DISABLED  Set to 1 to disable CLI Sentry socket diagnostics.
         """
 
diff --git a/tests/test_cli_socket_autodiscovery.py b/tests/test_cli_socket_autodiscovery.py
new file mode 100755
index 00000000..6eaa205d
--- /dev/null
+++ b/tests/test_cli_socket_autodiscovery.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+"""Regression test: CLI should auto-discover tagged debug sockets from CMUX_TAG."""
+
+from __future__ import annotations
+
+import glob
+import os
+import shutil
+import socket
+import subprocess
+import threading
+
+
+def resolve_cmux_cli() -> str:
+    explicit = os.environ.get("CMUX_CLI_BIN") or os.environ.get("CMUX_CLI")
+    if explicit and os.path.exists(explicit) and os.access(explicit, os.X_OK):
+        return explicit
+
+    candidates: list[str] = []
+    candidates.extend(glob.glob(os.path.expanduser("~/Library/Developer/Xcode/DerivedData/*/Build/Products/Debug/cmux")))
+    candidates.extend(glob.glob("/tmp/cmux-*/Build/Products/Debug/cmux"))
+    candidates = [p for p in candidates if os.path.exists(p) and os.access(p, os.X_OK)]
+    if candidates:
+        candidates.sort(key=os.path.getmtime, reverse=True)
+        return candidates[0]
+
+    in_path = shutil.which("cmux")
+    if in_path:
+        return in_path
+
+    raise RuntimeError("Unable to find cmux CLI binary. Set CMUX_CLI_BIN.")
+
+
+class PingServer:
+    def __init__(self, socket_path: str):
+        self.socket_path = socket_path
+        self.ready = threading.Event()
+        self.error: Exception | None = None
+        self._thread = threading.Thread(target=self._run, daemon=True)
+
+    def start(self) -> None:
+        self._thread.start()
+
+    def wait_ready(self, timeout: float) -> bool:
+        return self.ready.wait(timeout)
+
+    def join(self, timeout: float) -> None:
+        self._thread.join(timeout=timeout)
+
+    def _run(self) -> None:
+        server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        try:
+            if os.path.exists(self.socket_path):
+                os.remove(self.socket_path)
+            server.bind(self.socket_path)
+            server.listen(1)
+            server.settimeout(6.0)
+            self.ready.set()
+
+            # The CLI may probe candidate sockets with a connect-only check before
+            # issuing the actual command, so handle more than one connection.
+            for _ in range(4):
+                conn, _ = server.accept()
+                with conn:
+                    conn.settimeout(2.0)
+                    data = b""
+                    while b"\n" not in data:
+                        chunk = conn.recv(4096)
+                        if not chunk:
+                            break
+                        data += chunk
+
+                    if b"ping" in data:
+                        conn.sendall(b"PONG\n")
+                        return
+            raise RuntimeError("Did not receive ping command on test socket")
+        except Exception as exc:  # pragma: no cover - explicit surface on failure
+            self.error = exc
+            self.ready.set()
+        finally:
+            server.close()
+
+
+def main() -> int:
+    try:
+        cli_path = resolve_cmux_cli()
+    except Exception as exc:
+        print(f"FAIL: {exc}")
+        return 1
+
+    tag = f"cli-autodiscover-{os.getpid()}"
+    socket_path = f"/tmp/cmux-debug-{tag}.sock"
+    server = PingServer(socket_path)
+    server.start()
+
+    if not server.wait_ready(2.0):
+        print("FAIL: socket server did not become ready")
+        return 1
+
+    if server.error is not None:
+        print(f"FAIL: socket server failed to start: {server.error}")
+        return 1
+
+    env = os.environ.copy()
+    env["CMUX_SOCKET_PATH"] = "/tmp/cmux.sock"
+    env["CMUX_TAG"] = tag
+    env["CMUX_CLI_SENTRY_DISABLED"] = "1"
+    env["CMUX_CLAUDE_HOOK_SENTRY_DISABLED"] = "1"
+
+    try:
+        proc = subprocess.run(
+            [cli_path, "ping"],
+            text=True,
+            capture_output=True,
+            env=env,
+            timeout=8,
+            check=False,
+        )
+    except Exception as exc:
+        print(f"FAIL: invoking cmux ping failed: {exc}")
+        return 1
+    finally:
+        server.join(timeout=2.0)
+        try:
+            os.remove(socket_path)
+        except OSError:
+            pass
+
+    if server.error is not None:
+        print(f"FAIL: socket server error: {server.error}")
+        return 1
+
+    if proc.returncode != 0:
+        print("FAIL: cmux ping returned non-zero status")
+        print(f"stdout={proc.stdout!r}")
+        print(f"stderr={proc.stderr!r}")
+        return 1
+
+    if proc.stdout.strip() != "PONG":
+        print("FAIL: cmux ping did not use auto-discovered socket")
+        print(f"stdout={proc.stdout!r}")
+        print(f"stderr={proc.stderr!r}")
+        return 1
+
+    print("PASS: cmux ping auto-discovers tagged socket from CMUX_TAG")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())