Fix CLI socket autodiscovery for tagged cmux sockets (#832)

This commit is contained in:
Austin Wang 2026-03-04 18:03:25 -08:00 committed by GitHub
parent d72b014d6d
commit 76835662f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 393 additions and 36 deletions

View file

@ -451,9 +451,159 @@ private enum SocketPasswordResolver {
}
}
private enum CLISocketPathSource {
case explicitFlag
case environment
case implicitDefault
}
private enum CLISocketPathResolver {
static let defaultSocketPath = "/tmp/cmux.sock"
private static let fallbackSocketPath = "/tmp/cmux-debug.sock"
private static let stagingSocketPath = "/tmp/cmux-staging.sock"
private static let lastSocketPathFile = "/tmp/cmux-last-socket-path"
static func resolve(
requestedPath: String,
source: CLISocketPathSource,
environment: [String: String] = ProcessInfo.processInfo.environment
) -> String {
guard source == .implicitDefault else {
return requestedPath
}
let candidates = dedupe(candidatePaths(requestedPath: requestedPath, environment: environment))
// Prefer sockets that are currently accepting connections.
for path in candidates where canConnect(to: path) {
return path
}
// If the listener is still starting, prefer existing socket files.
for path in candidates where isSocketFile(path) {
return path
}
return requestedPath
}
private static func candidatePaths(requestedPath: String, environment: [String: String]) -> [String] {
var candidates: [String] = []
if let tag = normalized(environment["CMUX_TAG"]) {
let slug = sanitizeTagSlug(tag)
candidates.append("/tmp/cmux-debug-\(slug).sock")
candidates.append("/tmp/cmux-\(slug).sock")
}
candidates.append(requestedPath)
candidates.append(fallbackSocketPath)
candidates.append(stagingSocketPath)
candidates.append(contentsOf: discoverTaggedSockets(limit: 12))
if let last = readLastSocketPath() {
candidates.append(last)
}
return candidates
}
private static func readLastSocketPath() -> String? {
guard let data = try? String(contentsOfFile: lastSocketPathFile, encoding: .utf8) else {
return nil
}
return normalized(data)
}
private static func discoverTaggedSockets(limit: Int) -> [String] {
guard let entries = try? FileManager.default.contentsOfDirectory(atPath: "/tmp") else {
return []
}
var discovered: [(path: String, mtime: TimeInterval)] = []
discovered.reserveCapacity(min(limit, entries.count))
for name in entries where name.hasPrefix("cmux") && name.hasSuffix(".sock") {
let path = "/tmp/\(name)"
var st = stat()
guard lstat(path, &st) == 0 else { continue }
guard (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) else { continue }
if path == defaultSocketPath || path == fallbackSocketPath || path == stagingSocketPath {
continue
}
let modified = TimeInterval(st.st_mtimespec.tv_sec) + TimeInterval(st.st_mtimespec.tv_nsec) / 1_000_000_000
discovered.append((path: path, mtime: modified))
}
discovered.sort { $0.mtime > $1.mtime }
return discovered.prefix(limit).map(\.path)
}
private static func isSocketFile(_ path: String) -> Bool {
var st = stat()
return lstat(path, &st) == 0 && (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK)
}
private static func canConnect(to path: String) -> Bool {
guard isSocketFile(path) else { return false }
let fd = socket(AF_UNIX, SOCK_STREAM, 0)
guard fd >= 0 else { return false }
defer { Darwin.close(fd) }
var addr = sockaddr_un()
addr.sun_family = sa_family_t(AF_UNIX)
let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
path.withCString { ptr in
withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
strncpy(buf, ptr, maxLength - 1)
}
}
let result = withUnsafePointer(to: &addr) { ptr in
ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
Darwin.connect(fd, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
}
}
return result == 0
}
private static func sanitizeTagSlug(_ raw: String) -> String {
let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
let slug = trimmed
.replacingOccurrences(of: "[^a-z0-9]+", with: "-", options: .regularExpression)
.replacingOccurrences(of: "-+", with: "-", options: .regularExpression)
.trimmingCharacters(in: CharacterSet(charactersIn: "-"))
return slug.isEmpty ? "agent" : slug
}
private static func normalized(_ value: String?) -> String? {
guard let value else { return nil }
let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines)
return trimmed.isEmpty ? nil : trimmed
}
private static func dedupe(_ paths: [String]) -> [String] {
var seen: Set<String> = []
var ordered: [String] = []
ordered.reserveCapacity(paths.count)
for path in paths where !path.isEmpty {
if seen.insert(path).inserted {
ordered.append(path)
}
}
return ordered
}
}
final class SocketClient {
private let path: String
private var socketFD: Int32 = -1
private static let connectRetryWindowSeconds: TimeInterval = 2.0
private static let connectRetryIntervalSeconds: TimeInterval = 0.1
private static let retriableConnectErrnos: Set<Int32> = [
ENOENT,
ECONNREFUSED,
EAGAIN,
EINTR
]
private static let defaultResponseTimeoutSeconds: TimeInterval = 15.0
private static let responseTimeoutSeconds: TimeInterval = {
let env = ProcessInfo.processInfo.environment
@ -472,40 +622,66 @@ final class SocketClient {
func connect() throws {
if socketFD >= 0 { return }
// Verify socket is owned by the current user to prevent fake-socket attacks
var st = stat()
guard stat(path, &st) == 0 else {
throw CLIError(message: "Socket not found at \(path)")
}
guard st.st_uid == getuid() else {
throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
}
let deadline = Date().addingTimeInterval(Self.connectRetryWindowSeconds)
var lastError: CLIError?
socketFD = socket(AF_UNIX, SOCK_STREAM, 0)
if socketFD < 0 {
throw CLIError(message: "Failed to create socket")
}
var addr = sockaddr_un()
addr.sun_family = sa_family_t(AF_UNIX)
let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
path.withCString { ptr in
withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
strncpy(buf, ptr, maxLength - 1)
while true {
// Verify socket is owned by the current user to prevent fake-socket attacks.
var st = stat()
guard stat(path, &st) == 0 else {
let error = CLIError(message: "Socket not found at \(path)")
lastError = error
if errno == ENOENT, Date() < deadline {
Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
continue
}
throw error
}
}
let result = withUnsafePointer(to: &addr) { ptr in
ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
Darwin.connect(socketFD, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
guard (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) else {
throw CLIError(message: "Path exists at \(path) but is not a Unix socket")
}
}
if result != 0 {
guard st.st_uid == getuid() else {
throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
}
socketFD = socket(AF_UNIX, SOCK_STREAM, 0)
if socketFD < 0 {
throw CLIError(message: "Failed to create socket")
}
var addr = sockaddr_un()
addr.sun_family = sa_family_t(AF_UNIX)
let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
path.withCString { ptr in
withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
strncpy(buf, ptr, maxLength - 1)
}
}
let result = withUnsafePointer(to: &addr) { ptr in
ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
Darwin.connect(socketFD, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
}
}
if result == 0 {
return
}
let connectErrno = errno
Darwin.close(socketFD)
socketFD = -1
throw CLIError(message: "Failed to connect to socket at \(path)")
let error = CLIError(message: "Failed to connect to socket at \(path)")
lastError = error
if Self.retriableConnectErrnos.contains(connectErrno), Date() < deadline {
Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
continue
}
throw error
}
throw lastError ?? CLIError(message: "Failed to connect to socket at \(path)")
}
func close() {
@ -614,7 +790,19 @@ struct CMUXCLI {
let args: [String]
func run() throws {
var socketPath = ProcessInfo.processInfo.environment["CMUX_SOCKET_PATH"] ?? "/tmp/cmux.sock"
let processEnv = ProcessInfo.processInfo.environment
let envSocketPath: String? = {
guard let raw = processEnv["CMUX_SOCKET_PATH"] else { return nil }
let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines)
return trimmed.isEmpty ? nil : trimmed
}()
var socketPath = envSocketPath ?? CLISocketPathResolver.defaultSocketPath
var socketPathSource: CLISocketPathSource
if let envSocketPath {
socketPathSource = envSocketPath == CLISocketPathResolver.defaultSocketPath ? .implicitDefault : .environment
} else {
socketPathSource = .implicitDefault
}
var jsonOutput = false
var idFormatArg: String? = nil
var windowId: String? = nil
@ -628,6 +816,7 @@ struct CMUXCLI {
throw CLIError(message: "--socket requires a path")
}
socketPath = args[index + 1]
socketPathSource = .explicitFlag
index += 2
continue
}
@ -682,7 +871,12 @@ struct CMUXCLI {
command: command,
commandArgs: commandArgs,
socketPath: socketPath,
processEnv: ProcessInfo.processInfo.environment
processEnv: processEnv
)
let resolvedSocketPath = CLISocketPathResolver.resolve(
requestedPath: socketPath,
source: socketPathSource,
environment: processEnv
)
if command == "version" {
@ -692,7 +886,7 @@ struct CMUXCLI {
// If the argument looks like a path (not a known command), open a workspace there.
if looksLikePath(command) {
try openPath(command, socketPath: socketPath)
try openPath(command, socketPath: resolvedSocketPath)
return
}
@ -706,16 +900,28 @@ struct CMUXCLI {
return
}
let client = SocketClient(path: socketPath)
let client = SocketClient(path: resolvedSocketPath)
if resolvedSocketPath != socketPath {
cliTelemetry.breadcrumb(
"socket.path.autodiscovered",
data: [
"requested_path": socketPath,
"resolved_path": resolvedSocketPath
]
)
}
cliTelemetry.breadcrumb(
"socket.connect.attempt",
data: ["command": command]
data: [
"command": command,
"path": resolvedSocketPath
]
)
do {
try client.connect()
cliTelemetry.breadcrumb("socket.connect.success")
cliTelemetry.breadcrumb("socket.connect.success", data: ["path": resolvedSocketPath])
} catch {
cliTelemetry.breadcrumb("socket.connect.failure")
cliTelemetry.breadcrumb("socket.connect.failure", data: ["path": resolvedSocketPath])
cliTelemetry.captureError(stage: "socket_connect", error: error)
throw error
}
@ -6613,7 +6819,8 @@ struct CMUXCLI {
ALL commands (send, list-panels, new-split, notify, etc.).
CMUX_TAB_ID Optional alias used by `tab-action`/`rename-tab` as default --tab.
CMUX_SURFACE_ID Auto-set in cmux terminals. Used as default --surface.
CMUX_SOCKET_PATH Override the default Unix socket path (/tmp/cmux.sock).
CMUX_SOCKET_PATH Override the Unix socket path. Without this, the CLI defaults
to /tmp/cmux.sock and auto-discovers tagged/debug sockets.
CMUX_CLI_SENTRY_DISABLED
Set to 1 to disable CLI Sentry socket diagnostics.
"""

View file

@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""Regression test: CLI should auto-discover tagged debug sockets from CMUX_TAG."""
from __future__ import annotations
import glob
import os
import shutil
import socket
import subprocess
import threading
def resolve_cmux_cli() -> str:
explicit = os.environ.get("CMUX_CLI_BIN") or os.environ.get("CMUX_CLI")
if explicit and os.path.exists(explicit) and os.access(explicit, os.X_OK):
return explicit
candidates: list[str] = []
candidates.extend(glob.glob(os.path.expanduser("~/Library/Developer/Xcode/DerivedData/*/Build/Products/Debug/cmux")))
candidates.extend(glob.glob("/tmp/cmux-*/Build/Products/Debug/cmux"))
candidates = [p for p in candidates if os.path.exists(p) and os.access(p, os.X_OK)]
if candidates:
candidates.sort(key=os.path.getmtime, reverse=True)
return candidates[0]
in_path = shutil.which("cmux")
if in_path:
return in_path
raise RuntimeError("Unable to find cmux CLI binary. Set CMUX_CLI_BIN.")
class PingServer:
def __init__(self, socket_path: str):
self.socket_path = socket_path
self.ready = threading.Event()
self.error: Exception | None = None
self._thread = threading.Thread(target=self._run, daemon=True)
def start(self) -> None:
self._thread.start()
def wait_ready(self, timeout: float) -> bool:
return self.ready.wait(timeout)
def join(self, timeout: float) -> None:
self._thread.join(timeout=timeout)
def _run(self) -> None:
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
if os.path.exists(self.socket_path):
os.remove(self.socket_path)
server.bind(self.socket_path)
server.listen(1)
server.settimeout(6.0)
self.ready.set()
# The CLI may probe candidate sockets with a connect-only check before
# issuing the actual command, so handle more than one connection.
for _ in range(4):
conn, _ = server.accept()
with conn:
conn.settimeout(2.0)
data = b""
while b"\n" not in data:
chunk = conn.recv(4096)
if not chunk:
break
data += chunk
if b"ping" in data:
conn.sendall(b"PONG\n")
return
raise RuntimeError("Did not receive ping command on test socket")
except Exception as exc: # pragma: no cover - explicit surface on failure
self.error = exc
self.ready.set()
finally:
server.close()
def main() -> int:
try:
cli_path = resolve_cmux_cli()
except Exception as exc:
print(f"FAIL: {exc}")
return 1
tag = f"cli-autodiscover-{os.getpid()}"
socket_path = f"/tmp/cmux-debug-{tag}.sock"
server = PingServer(socket_path)
server.start()
if not server.wait_ready(2.0):
print("FAIL: socket server did not become ready")
return 1
if server.error is not None:
print(f"FAIL: socket server failed to start: {server.error}")
return 1
env = os.environ.copy()
env["CMUX_SOCKET_PATH"] = "/tmp/cmux.sock"
env["CMUX_TAG"] = tag
env["CMUX_CLI_SENTRY_DISABLED"] = "1"
env["CMUX_CLAUDE_HOOK_SENTRY_DISABLED"] = "1"
try:
proc = subprocess.run(
[cli_path, "ping"],
text=True,
capture_output=True,
env=env,
timeout=8,
check=False,
)
except Exception as exc:
print(f"FAIL: invoking cmux ping failed: {exc}")
return 1
finally:
server.join(timeout=2.0)
try:
os.remove(socket_path)
except OSError:
pass
if server.error is not None:
print(f"FAIL: socket server error: {server.error}")
return 1
if proc.returncode != 0:
print("FAIL: cmux ping returned non-zero status")
print(f"stdout={proc.stdout!r}")
print(f"stderr={proc.stderr!r}")
return 1
if proc.stdout.strip() != "PONG":
print("FAIL: cmux ping did not use auto-discovered socket")
print(f"stdout={proc.stdout!r}")
print(f"stderr={proc.stderr!r}")
return 1
print("PASS: cmux ping auto-discovers tagged socket from CMUX_TAG")
return 0
if __name__ == "__main__":
raise SystemExit(main())