Fix CLI socket autodiscovery for tagged cmux sockets (#832)
This commit is contained in:
parent
d72b014d6d
commit
76835662f5
2 changed files with 393 additions and 36 deletions
231
CLI/cmux.swift
231
CLI/cmux.swift
|
|
@ -451,9 +451,159 @@ private enum SocketPasswordResolver {
|
|||
}
|
||||
}
|
||||
|
||||
/// Records where the CLI's requested socket path came from, so the resolver
/// can tell a deliberate user choice apart from the built-in default.
private enum CLISocketPathSource {
    /// The path was supplied on the command line via `--socket`.
    case explicitFlag

    /// The path came from `CMUX_SOCKET_PATH` and differs from the default path.
    case environment

    /// No override was given (or the override equals the default path);
    /// auto-discovery is allowed only for this source.
    case implicitDefault
}
|
||||
|
||||
/// Chooses the Unix socket path the CLI should connect to.
///
/// A path whose source is anything other than `.implicitDefault` is returned
/// unchanged. For the implicit default, a list of candidate paths is built
/// (tag-derived paths, the requested path, the debug and staging paths,
/// recently modified `cmux*.sock` files in `/tmp`, and the last recorded path)
/// and the first usable one wins.
///
/// Only sockets owned by the current user are considered usable, because
/// `SocketClient.connect()` refuses any socket with a different owner; picking
/// such a socket here would turn a working setup into a hard failure.
private enum CLISocketPathResolver {
    static let defaultSocketPath = "/tmp/cmux.sock"
    private static let fallbackSocketPath = "/tmp/cmux-debug.sock"
    private static let stagingSocketPath = "/tmp/cmux-staging.sock"
    private static let lastSocketPathFile = "/tmp/cmux-last-socket-path"

    /// Resolves the socket path to use.
    ///
    /// - Parameters:
    ///   - requestedPath: The path selected by flag, environment, or default.
    ///   - source: Where `requestedPath` came from.
    ///   - environment: Environment to read `CMUX_TAG` from.
    /// - Returns: `requestedPath` unless `source` is `.implicitDefault` and a
    ///   better candidate (connectable first, otherwise an existing socket
    ///   file) is found.
    static func resolve(
        requestedPath: String,
        source: CLISocketPathSource,
        environment: [String: String] = ProcessInfo.processInfo.environment
    ) -> String {
        guard source == .implicitDefault else {
            return requestedPath
        }

        let candidates = dedupe(candidatePaths(requestedPath: requestedPath, environment: environment))

        // Prefer sockets that are currently accepting connections.
        for path in candidates where canConnect(to: path) {
            return path
        }

        // If the listener is still starting, prefer existing socket files.
        for path in candidates where isSocketFile(path) {
            return path
        }

        return requestedPath
    }

    /// Builds the ordered (possibly duplicated) candidate list; earlier entries
    /// take priority during resolution.
    private static func candidatePaths(requestedPath: String, environment: [String: String]) -> [String] {
        var candidates: [String] = []

        if let tag = normalized(environment["CMUX_TAG"]) {
            let slug = sanitizeTagSlug(tag)
            candidates.append("/tmp/cmux-debug-\(slug).sock")
            candidates.append("/tmp/cmux-\(slug).sock")
        }

        candidates.append(requestedPath)
        candidates.append(fallbackSocketPath)
        candidates.append(stagingSocketPath)
        candidates.append(contentsOf: discoverTaggedSockets(limit: 12))
        if let last = readLastSocketPath() {
            candidates.append(last)
        }
        return candidates
    }

    /// Reads the trimmed contents of the last-socket-path file, or `nil` when
    /// the file is unreadable or blank.
    private static func readLastSocketPath() -> String? {
        guard let data = try? String(contentsOfFile: lastSocketPathFile, encoding: .utf8) else {
            return nil
        }
        return normalized(data)
    }

    /// Lists up to `limit` `cmux*.sock` sockets in `/tmp` that are owned by the
    /// current user, most recently modified first. The default, debug and
    /// staging paths are excluded because they are already explicit candidates.
    private static func discoverTaggedSockets(limit: Int) -> [String] {
        // `prefix(_:)` traps on a negative count, so reject non-positive limits up front.
        guard limit > 0 else { return [] }
        guard let entries = try? FileManager.default.contentsOfDirectory(atPath: "/tmp") else {
            return []
        }

        var discovered: [(path: String, mtime: TimeInterval)] = []
        discovered.reserveCapacity(min(limit, entries.count))
        for name in entries where name.hasPrefix("cmux") && name.hasSuffix(".sock") {
            let path = "/tmp/\(name)"
            if path == defaultSocketPath || path == fallbackSocketPath || path == stagingSocketPath {
                continue
            }
            var st = stat()
            guard lstat(path, &st) == 0, isOwnedSocket(st) else { continue }
            let modified = TimeInterval(st.st_mtimespec.tv_sec) + TimeInterval(st.st_mtimespec.tv_nsec) / 1_000_000_000
            discovered.append((path: path, mtime: modified))
        }

        discovered.sort { $0.mtime > $1.mtime }
        return discovered.prefix(limit).map(\.path)
    }

    /// Returns `true` when `st` describes a Unix socket owned by the current user.
    private static func isOwnedSocket(_ st: stat) -> Bool {
        (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) && st.st_uid == getuid()
    }

    /// Returns `true` when `path` is a Unix socket (not followed through
    /// symlinks) that belongs to the current user.
    private static func isSocketFile(_ path: String) -> Bool {
        var st = stat()
        return lstat(path, &st) == 0 && isOwnedSocket(st)
    }

    /// Probes `path` with a throwaway connection and reports whether a listener
    /// accepted it. The probe socket is always closed before returning.
    private static func canConnect(to path: String) -> Bool {
        guard isSocketFile(path) else { return false }

        var addr = sockaddr_un()
        addr.sun_family = sa_family_t(AF_UNIX)
        let maxLength = MemoryLayout.size(ofValue: addr.sun_path)
        // A path that does not fit in `sun_path` together with its NUL terminator
        // would be silently truncated by `strncpy`, making the probe hit a
        // different path than the one being evaluated.
        guard path.utf8.count < maxLength else { return false }

        let fd = socket(AF_UNIX, SOCK_STREAM, 0)
        guard fd >= 0 else { return false }
        defer { Darwin.close(fd) }

        path.withCString { ptr in
            withUnsafeMutablePointer(to: &addr.sun_path) { pathPtr in
                let buf = UnsafeMutableRawPointer(pathPtr).assumingMemoryBound(to: CChar.self)
                // `addr` is zero-initialised, so the copied path stays NUL-terminated.
                _ = strncpy(buf, ptr, maxLength - 1)
            }
        }

        let result = withUnsafePointer(to: &addr) { ptr in
            ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in
                Darwin.connect(fd, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
            }
        }
        return result == 0
    }

    /// Converts a raw tag into a lowercase slug of `a-z0-9` runs joined by single
    /// dashes, with no leading or trailing dash. Falls back to `"agent"` when
    /// nothing usable remains.
    private static func sanitizeTagSlug(_ raw: String) -> String {
        let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
        let slug = trimmed
            .replacingOccurrences(of: "[^a-z0-9]+", with: "-", options: .regularExpression)
            .replacingOccurrences(of: "-+", with: "-", options: .regularExpression)
            .trimmingCharacters(in: CharacterSet(charactersIn: "-"))
        return slug.isEmpty ? "agent" : slug
    }

    /// Trims surrounding whitespace and maps `nil` or blank input to `nil`.
    private static func normalized(_ value: String?) -> String? {
        guard let value else { return nil }
        let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines)
        return trimmed.isEmpty ? nil : trimmed
    }

    /// Drops empty strings and repeated paths while keeping first-seen order.
    private static func dedupe(_ paths: [String]) -> [String] {
        var seen: Set<String> = []
        var ordered: [String] = []
        ordered.reserveCapacity(paths.count)
        for path in paths where !path.isEmpty {
            if seen.insert(path).inserted {
                ordered.append(path)
            }
        }
        return ordered
    }
}
|
||||
|
||||
final class SocketClient {
|
||||
private let path: String
|
||||
private var socketFD: Int32 = -1
|
||||
private static let connectRetryWindowSeconds: TimeInterval = 2.0
|
||||
private static let connectRetryIntervalSeconds: TimeInterval = 0.1
|
||||
private static let retriableConnectErrnos: Set<Int32> = [
|
||||
ENOENT,
|
||||
ECONNREFUSED,
|
||||
EAGAIN,
|
||||
EINTR
|
||||
]
|
||||
private static let defaultResponseTimeoutSeconds: TimeInterval = 15.0
|
||||
private static let responseTimeoutSeconds: TimeInterval = {
|
||||
let env = ProcessInfo.processInfo.environment
|
||||
|
|
@ -472,10 +622,23 @@ final class SocketClient {
|
|||
func connect() throws {
|
||||
if socketFD >= 0 { return }
|
||||
|
||||
// Verify socket is owned by the current user to prevent fake-socket attacks
|
||||
let deadline = Date().addingTimeInterval(Self.connectRetryWindowSeconds)
|
||||
var lastError: CLIError?
|
||||
|
||||
while true {
|
||||
// Verify socket is owned by the current user to prevent fake-socket attacks.
|
||||
var st = stat()
|
||||
guard stat(path, &st) == 0 else {
|
||||
throw CLIError(message: "Socket not found at \(path)")
|
||||
let error = CLIError(message: "Socket not found at \(path)")
|
||||
lastError = error
|
||||
if errno == ENOENT, Date() < deadline {
|
||||
Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
|
||||
continue
|
||||
}
|
||||
throw error
|
||||
}
|
||||
guard (st.st_mode & mode_t(S_IFMT)) == mode_t(S_IFSOCK) else {
|
||||
throw CLIError(message: "Path exists at \(path) but is not a Unix socket")
|
||||
}
|
||||
guard st.st_uid == getuid() else {
|
||||
throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
|
||||
|
|
@ -501,11 +664,24 @@ final class SocketClient {
|
|||
Darwin.connect(socketFD, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_un>.size))
|
||||
}
|
||||
}
|
||||
if result != 0 {
|
||||
if result == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
let connectErrno = errno
|
||||
Darwin.close(socketFD)
|
||||
socketFD = -1
|
||||
throw CLIError(message: "Failed to connect to socket at \(path)")
|
||||
|
||||
let error = CLIError(message: "Failed to connect to socket at \(path)")
|
||||
lastError = error
|
||||
if Self.retriableConnectErrnos.contains(connectErrno), Date() < deadline {
|
||||
Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds)
|
||||
continue
|
||||
}
|
||||
throw error
|
||||
}
|
||||
|
||||
throw lastError ?? CLIError(message: "Failed to connect to socket at \(path)")
|
||||
}
|
||||
|
||||
func close() {
|
||||
|
|
@ -614,7 +790,19 @@ struct CMUXCLI {
|
|||
let args: [String]
|
||||
|
||||
func run() throws {
|
||||
var socketPath = ProcessInfo.processInfo.environment["CMUX_SOCKET_PATH"] ?? "/tmp/cmux.sock"
|
||||
let processEnv = ProcessInfo.processInfo.environment
|
||||
let envSocketPath: String? = {
|
||||
guard let raw = processEnv["CMUX_SOCKET_PATH"] else { return nil }
|
||||
let trimmed = raw.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
return trimmed.isEmpty ? nil : trimmed
|
||||
}()
|
||||
var socketPath = envSocketPath ?? CLISocketPathResolver.defaultSocketPath
|
||||
var socketPathSource: CLISocketPathSource
|
||||
if let envSocketPath {
|
||||
socketPathSource = envSocketPath == CLISocketPathResolver.defaultSocketPath ? .implicitDefault : .environment
|
||||
} else {
|
||||
socketPathSource = .implicitDefault
|
||||
}
|
||||
var jsonOutput = false
|
||||
var idFormatArg: String? = nil
|
||||
var windowId: String? = nil
|
||||
|
|
@ -628,6 +816,7 @@ struct CMUXCLI {
|
|||
throw CLIError(message: "--socket requires a path")
|
||||
}
|
||||
socketPath = args[index + 1]
|
||||
socketPathSource = .explicitFlag
|
||||
index += 2
|
||||
continue
|
||||
}
|
||||
|
|
@ -682,7 +871,12 @@ struct CMUXCLI {
|
|||
command: command,
|
||||
commandArgs: commandArgs,
|
||||
socketPath: socketPath,
|
||||
processEnv: ProcessInfo.processInfo.environment
|
||||
processEnv: processEnv
|
||||
)
|
||||
let resolvedSocketPath = CLISocketPathResolver.resolve(
|
||||
requestedPath: socketPath,
|
||||
source: socketPathSource,
|
||||
environment: processEnv
|
||||
)
|
||||
|
||||
if command == "version" {
|
||||
|
|
@ -692,7 +886,7 @@ struct CMUXCLI {
|
|||
|
||||
// If the argument looks like a path (not a known command), open a workspace there.
|
||||
if looksLikePath(command) {
|
||||
try openPath(command, socketPath: socketPath)
|
||||
try openPath(command, socketPath: resolvedSocketPath)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -706,16 +900,28 @@ struct CMUXCLI {
|
|||
return
|
||||
}
|
||||
|
||||
let client = SocketClient(path: socketPath)
|
||||
let client = SocketClient(path: resolvedSocketPath)
|
||||
if resolvedSocketPath != socketPath {
|
||||
cliTelemetry.breadcrumb(
|
||||
"socket.path.autodiscovered",
|
||||
data: [
|
||||
"requested_path": socketPath,
|
||||
"resolved_path": resolvedSocketPath
|
||||
]
|
||||
)
|
||||
}
|
||||
cliTelemetry.breadcrumb(
|
||||
"socket.connect.attempt",
|
||||
data: ["command": command]
|
||||
data: [
|
||||
"command": command,
|
||||
"path": resolvedSocketPath
|
||||
]
|
||||
)
|
||||
do {
|
||||
try client.connect()
|
||||
cliTelemetry.breadcrumb("socket.connect.success")
|
||||
cliTelemetry.breadcrumb("socket.connect.success", data: ["path": resolvedSocketPath])
|
||||
} catch {
|
||||
cliTelemetry.breadcrumb("socket.connect.failure")
|
||||
cliTelemetry.breadcrumb("socket.connect.failure", data: ["path": resolvedSocketPath])
|
||||
cliTelemetry.captureError(stage: "socket_connect", error: error)
|
||||
throw error
|
||||
}
|
||||
|
|
@ -6613,7 +6819,8 @@ struct CMUXCLI {
|
|||
ALL commands (send, list-panels, new-split, notify, etc.).
|
||||
CMUX_TAB_ID Optional alias used by `tab-action`/`rename-tab` as default --tab.
|
||||
CMUX_SURFACE_ID Auto-set in cmux terminals. Used as default --surface.
|
||||
CMUX_SOCKET_PATH Override the default Unix socket path (/tmp/cmux.sock).
|
||||
CMUX_SOCKET_PATH Override the Unix socket path. Without this, the CLI defaults
|
||||
to /tmp/cmux.sock and auto-discovers tagged/debug sockets.
|
||||
CMUX_CLI_SENTRY_DISABLED
|
||||
Set to 1 to disable CLI Sentry socket diagnostics.
|
||||
"""
|
||||
|
|
|
|||
150
tests/test_cli_socket_autodiscovery.py
Executable file
150
tests/test_cli_socket_autodiscovery.py
Executable file
|
|
@ -0,0 +1,150 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Regression test: CLI should auto-discover tagged debug sockets from CMUX_TAG."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
|
||||
|
||||
def resolve_cmux_cli() -> str:
    """Locate the cmux CLI binary to exercise.

    Lookup order:
      1. ``CMUX_CLI_BIN`` (or ``CMUX_CLI``) when it names an existing executable.
      2. The most recently modified Debug build found under Xcode DerivedData
         or ``/tmp/cmux-*``.
      3. ``cmux`` on ``PATH``.

    Raises:
        RuntimeError: if none of the lookups yields a binary.
    """

    def is_executable(path: str) -> bool:
        return os.path.exists(path) and os.access(path, os.X_OK)

    override = os.environ.get("CMUX_CLI_BIN") or os.environ.get("CMUX_CLI")
    if override and is_executable(override):
        return override

    patterns = (
        os.path.expanduser("~/Library/Developer/Xcode/DerivedData/*/Build/Products/Debug/cmux"),
        "/tmp/cmux-*/Build/Products/Debug/cmux",
    )
    builds = [hit for pattern in patterns for hit in glob.glob(pattern) if is_executable(hit)]
    if builds:
        # Newest build wins; on equal mtimes the earliest match is kept.
        return max(builds, key=os.path.getmtime)

    on_path = shutil.which("cmux")
    if on_path:
        return on_path

    raise RuntimeError("Unable to find cmux CLI binary. Set CMUX_CLI_BIN.")
|
||||
|
||||
|
||||
class PingServer:
    """Background Unix-socket server that answers one ``ping`` line with ``PONG``.

    The server runs on a daemon thread. ``ready`` is set once the socket is
    listening, or once startup or serving has failed, in which case ``error``
    holds the exception.
    """

    def __init__(self, socket_path: str):
        self.socket_path = socket_path
        self.ready = threading.Event()
        self.error: Exception | None = None
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self) -> None:
        """Launch the serving thread."""
        self._thread.start()

    def wait_ready(self, timeout: float) -> bool:
        """Block until the server signals readiness; False if ``timeout`` elapses."""
        return self.ready.wait(timeout)

    def join(self, timeout: float) -> None:
        """Wait up to ``timeout`` seconds for the serving thread to finish."""
        self._thread.join(timeout=timeout)

    @staticmethod
    def _receive_line(client: socket.socket) -> bytes:
        """Read from ``client`` until a newline arrives or the peer closes."""
        received = b""
        while b"\n" not in received:
            chunk = client.recv(4096)
            if not chunk:
                break
            received += chunk
        return received

    def _run(self) -> None:
        listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            # Clear any stale file left at the path before binding.
            if os.path.exists(self.socket_path):
                os.remove(self.socket_path)
            listener.bind(self.socket_path)
            listener.listen(1)
            listener.settimeout(6.0)
            self.ready.set()

            # The CLI may probe candidate sockets with a connect-only check before
            # issuing the actual command, so handle more than one connection.
            attempts_left = 4
            while attempts_left > 0:
                attempts_left -= 1
                client, _peer = listener.accept()
                with client:
                    client.settimeout(2.0)
                    payload = self._receive_line(client)
                    if b"ping" in payload:
                        client.sendall(b"PONG\n")
                        return
            raise RuntimeError("Did not receive ping command on test socket")
        except Exception as exc:  # pragma: no cover - explicit surface on failure
            self.error = exc
            self.ready.set()
        finally:
            listener.close()
|
||||
|
||||
|
||||
def main() -> int:
    """Run the regression check and return a process exit code (0 = pass).

    Starts a ``PingServer`` on a tagged debug socket path, then runs
    ``cmux ping`` with ``CMUX_SOCKET_PATH`` set to the default path and
    ``CMUX_TAG`` set to the tag. The check passes when the CLI prints ``PONG``,
    which only the test server on the tagged socket sends.
    """
    try:
        cli_binary = resolve_cmux_cli()
    except Exception as exc:
        print(f"FAIL: {exc}")
        return 1

    tag = f"cli-autodiscover-{os.getpid()}"
    socket_path = f"/tmp/cmux-debug-{tag}.sock"
    server = PingServer(socket_path)
    server.start()

    if not server.wait_ready(2.0):
        print("FAIL: socket server did not become ready")
        return 1
    if server.error is not None:
        print(f"FAIL: socket server failed to start: {server.error}")
        return 1

    env = dict(os.environ)
    env.update(
        {
            "CMUX_SOCKET_PATH": "/tmp/cmux.sock",
            "CMUX_TAG": tag,
            "CMUX_CLI_SENTRY_DISABLED": "1",
            "CMUX_CLAUDE_HOOK_SENTRY_DISABLED": "1",
        }
    )

    try:
        proc = subprocess.run(
            [cli_binary, "ping"],
            env=env,
            capture_output=True,
            text=True,
            check=False,
            timeout=8,
        )
    except Exception as exc:
        print(f"FAIL: invoking cmux ping failed: {exc}")
        return 1
    finally:
        # Runs on success and failure alike: stop the server and drop the socket file.
        server.join(timeout=2.0)
        try:
            os.remove(socket_path)
        except OSError:
            pass

    def fail_with_output(reason: str) -> int:
        print(reason)
        print(f"stdout={proc.stdout!r}")
        print(f"stderr={proc.stderr!r}")
        return 1

    if server.error is not None:
        print(f"FAIL: socket server error: {server.error}")
        return 1
    if proc.returncode != 0:
        return fail_with_output("FAIL: cmux ping returned non-zero status")
    if proc.stdout.strip() != "PONG":
        return fail_with_output("FAIL: cmux ping did not use auto-discovered socket")

    print("PASS: cmux ping auto-discovers tagged socket from CMUX_TAG")
    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s result as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|
||||
Loading…
Add table
Add a link
Reference in a new issue