diff --git a/CLI/cmux.swift b/CLI/cmux.swift index 8d218262..4475f341 100644 --- a/CLI/cmux.swift +++ b/CLI/cmux.swift @@ -672,7 +672,9 @@ final class SocketClient { Darwin.close(socketFD) socketFD = -1 - let error = CLIError(message: "Failed to connect to socket at \(path)") + let error = CLIError( + message: "Failed to connect to socket at \(path) (\(String(cString: strerror(connectErrno))), errno \(connectErrno))" + ) lastError = error if Self.retriableConnectErrnos.contains(connectErrno), Date() < deadline { Thread.sleep(forTimeInterval: Self.connectRetryIntervalSeconds) diff --git a/Resources/bin/claude b/Resources/bin/claude index 2205fe3c..02939248 100755 --- a/Resources/bin/claude +++ b/Resources/bin/claude @@ -18,8 +18,36 @@ find_real_claude() { return 1 } -# Pass through if not in a cmux terminal or hooks are disabled. -if [[ -z "$CMUX_SURFACE_ID" || "$CMUX_CLAUDE_HOOKS_DISABLED" == "1" ]]; then +# Return 0 only when CMUX_SOCKET_PATH points to a live cmux socket. +cmux_socket_available() { + local socket="${CMUX_SOCKET_PATH:-}" + [[ -n "$socket" && -S "$socket" ]] || return 1 + + local self_dir cmux_bin + self_dir="$(cd "$(dirname "$0")" && pwd)" + cmux_bin="$self_dir/cmux" + [[ -x "$cmux_bin" ]] || cmux_bin="$(command -v cmux || true)" + [[ -n "$cmux_bin" ]] || return 1 + + # Keep stale/hung socket checks bounded so claude startup does not block + # behind the CLI default timeout (15s). + CMUXTERM_CLI_RESPONSE_TIMEOUT_SEC=0.75 \ + "$cmux_bin" --socket "$socket" ping >/dev/null 2>&1 +} + +# Pass through if not in a cmux terminal, hooks are disabled, or the cmux +# socket is unavailable (stale env / app not running). +IN_CMUX=0 +if [[ -n "$CMUX_SURFACE_ID" ]]; then + IN_CMUX=1 +fi + +if [[ "$IN_CMUX" == "0" || "$CMUX_CLAUDE_HOOKS_DISABLED" == "1" ]] || ! cmux_socket_available; then + # In cmux-launched shells, preserve old behavior and always clear nested + # Claude session markers, even when we must pass through due to stale socket. + if [[ "$IN_CMUX" == "1" ]]; then + unset CLAUDECODE + fi REAL_CLAUDE="$(find_real_claude)" || { echo "Error: claude not found in PATH" >&2; exit 127; } exec "$REAL_CLAUDE" "$@" fi @@ -50,7 +78,7 @@ done # Build hooks settings JSON. # Claude Code merges --settings additively with the user's own settings.json. -HOOKS_JSON='{"hooks":{"SessionStart":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook session-start","timeout":10}]}],"Stop":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook stop","timeout":10}]}],"Notification":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook notification","timeout":10}]}],"UserPromptSubmit":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook prompt-submit","timeout":10}]}]}}' +HOOKS_JSON='{"hooks":{"SessionStart":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook session-start","timeout":10}]}],"Stop":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook stop","timeout":10}]}],"Notification":[{"matcher":"","hooks":[{"type":"command","command":"cmux claude-hook notification","timeout":10}]}]}}' if [[ "$SKIP_SESSION_ID" == true ]]; then exec "$REAL_CLAUDE" --settings "$HOOKS_JSON" "$@" diff --git a/Sources/TerminalController.swift b/Sources/TerminalController.swift index 43995ccb..4f199842 100644 --- a/Sources/TerminalController.swift +++ b/Sources/TerminalController.swift @@ -34,6 +34,11 @@ class TerminalController { private nonisolated(unsafe) var serverSocket: Int32 = -1 private nonisolated(unsafe) var isRunning = false private nonisolated(unsafe) var acceptLoopAlive = false + private nonisolated(unsafe) var activeAcceptLoopGeneration: UInt64 = 0 + private nonisolated(unsafe) var nextAcceptLoopGeneration: UInt64 = 0 + private nonisolated(unsafe) var pendingAcceptLoopRearmGeneration: UInt64? + private nonisolated(unsafe) var listenerStartInProgress = false + private nonisolated let listenerStateLock = NSLock() private var clientHandlers: [Int32: Thread] = [:] private var tabManager: TabManager? private var accessMode: SocketControlMode = .cmuxOnly @@ -41,6 +46,20 @@ class TerminalController { private nonisolated(unsafe) static var socketCommandPolicyDepth: Int = 0 private nonisolated(unsafe) static var socketCommandFocusAllowanceStack: [Bool] = [] private nonisolated static let socketCommandPolicyLock = NSLock() + private nonisolated static let socketListenBacklog: Int32 = 128 + private nonisolated static let acceptFailureBaseBackoffMs = 10 + private nonisolated static let acceptFailureMaxBackoffMs = 5_000 + private nonisolated static let acceptFailureMinimumRearmDelayMs = 100 + private nonisolated static let acceptFailureRearmThreshold = 50 + + private struct ListenerStateSnapshot { + let socketPath: String + let serverSocket: Int32 + let isRunning: Bool + let acceptLoopAlive: Bool + let activeGeneration: UInt64 + let pendingRearmGeneration: UInt64? + } private static let focusIntentV1Commands: Set = [ "focus_window", @@ -127,6 +146,31 @@ class TerminalController { private init() {} + private nonisolated func withListenerState(_ body: () -> T) -> T { + listenerStateLock.lock() + defer { listenerStateLock.unlock() } + return body() + } + + private nonisolated func listenerStateSnapshot() -> ListenerStateSnapshot { + withListenerState { + ListenerStateSnapshot( + socketPath: socketPath, + serverSocket: serverSocket, + isRunning: isRunning, + acceptLoopAlive: acceptLoopAlive, + activeGeneration: activeAcceptLoopGeneration, + pendingRearmGeneration: pendingAcceptLoopRearmGeneration + ) + } + } + + private nonisolated func shouldContinueAcceptLoop(generation: UInt64) -> Bool { + withListenerState { + isRunning && generation == activeAcceptLoopGeneration + } + } + nonisolated static func shouldSuppressSocketCommandActivation() -> Bool { socketCommandPolicyLock.lock() defer { socketCommandPolicyLock.unlock() } @@ -369,12 +413,14 @@ class TerminalController { errnoCode: Int32? = nil, extra: [String: Any] = [:] ) -> [String: Any] { + let snapshot = listenerStateSnapshot() var data: [String: Any] = [ "stage": stage, - "path": socketPath, - "isRunning": isRunning ? 1 : 0, - "acceptLoopAlive": acceptLoopAlive ? 1 : 0, - "serverSocket": Int(serverSocket) + "path": snapshot.socketPath, + "isRunning": snapshot.isRunning ? 1 : 0, + "acceptLoopAlive": snapshot.acceptLoopAlive ? 1 : 0, + "serverSocket": Int(snapshot.serverSocket), + "activeGeneration": snapshot.activeGeneration ] if let errnoCode { data["errno"] = Int(errnoCode) @@ -397,27 +443,108 @@ class TerminalController { sentryCaptureError(message, category: "socket", data: data, contextKey: "socket_listener") } + nonisolated static func acceptErrorClassification(errnoCode: Int32) -> String { + switch errnoCode { + case EINTR, ECONNABORTED, EAGAIN, EWOULDBLOCK: + return "immediate_retry" + case EMFILE, ENFILE, ENOBUFS, ENOMEM: + return "resource_pressure" + case EBADF, EINVAL, ENOTSOCK: + return "fatal" + default: + return "retry_with_backoff" + } + } + + nonisolated static func shouldRearmListenerForAcceptError(errnoCode: Int32) -> Bool { + acceptErrorClassification(errnoCode: errnoCode) == "fatal" + } + + nonisolated static func shouldRetryAcceptImmediately(errnoCode: Int32) -> Bool { + acceptErrorClassification(errnoCode: errnoCode) == "immediate_retry" + } + + nonisolated static func shouldRearmForConsecutiveAcceptFailures(consecutiveFailures: Int) -> Bool { + consecutiveFailures >= acceptFailureRearmThreshold + } + + nonisolated static func acceptFailureBackoffMilliseconds(consecutiveFailures: Int) -> Int { + guard consecutiveFailures > 0 else { return 0 } + var delay = acceptFailureBaseBackoffMs + var remaining = consecutiveFailures - 1 + while remaining > 0 { + if delay >= acceptFailureMaxBackoffMs { + return acceptFailureMaxBackoffMs + } + delay = min(delay * 2, acceptFailureMaxBackoffMs) + remaining -= 1 + } + return delay + } + + nonisolated static func acceptFailureRearmDelayMilliseconds(consecutiveFailures: Int) -> Int { + max( + acceptFailureBackoffMilliseconds(consecutiveFailures: consecutiveFailures), + acceptFailureMinimumRearmDelayMs + ) + } + + nonisolated static func shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: Int) -> Bool { + guard consecutiveFailures > 0 else { return false } + if consecutiveFailures <= 3 { + return true + } + return (consecutiveFailures & (consecutiveFailures - 1)) == 0 + } + + nonisolated static func shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: Bool, + isRunning: Bool, + activeGeneration: UInt64, + listenerStartInProgress: Bool + ) -> Bool { + guard pathMatches else { return false } + guard !listenerStartInProgress else { return false } + return !isRunning && activeGeneration == 0 + } + func start(tabManager: TabManager, socketPath: String, accessMode: SocketControlMode) { self.tabManager = tabManager self.accessMode = accessMode - if isRunning { - if self.socketPath == socketPath && acceptLoopAlive { - self.accessMode = accessMode - applySocketPermissions() - return - } + let existing = withListenerState { + (isRunning: isRunning, socketPath: self.socketPath, acceptLoopAlive: acceptLoopAlive) + } + + if existing.isRunning && existing.socketPath == socketPath && existing.acceptLoopAlive { + self.accessMode = accessMode + applySocketPermissions() + return + } + + if existing.isRunning { stop() } - self.socketPath = socketPath + withListenerState { + self.socketPath = socketPath + listenerStartInProgress = true + } + var listenerActivated = false + defer { + if !listenerActivated { + withListenerState { + listenerStartInProgress = false + } + } + } // Remove existing socket file unlink(socketPath) // Create socket - serverSocket = socket(AF_UNIX, SOCK_STREAM, 0) - guard serverSocket >= 0 else { + let newServerSocket = socket(AF_UNIX, SOCK_STREAM, 0) + guard newServerSocket >= 0 else { let errnoCode = errno print("TerminalController: Failed to create socket") reportSocketListenerFailure( @@ -440,14 +567,14 @@ class TerminalController { let bindResult = withUnsafePointer(to: &addr) { ptr in ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in - bind(serverSocket, sockaddrPtr, socklen_t(MemoryLayout.size)) + bind(newServerSocket, sockaddrPtr, socklen_t(MemoryLayout.size)) } } guard bindResult >= 0 else { let errnoCode = errno print("TerminalController: Failed to bind socket") - close(serverSocket) + close(newServerSocket) reportSocketListenerFailure( message: "socket.listener.start.failed", stage: "bind", @@ -459,10 +586,10 @@ class TerminalController { applySocketPermissions() // Listen - guard listen(serverSocket, 5) >= 0 else { + guard listen(newServerSocket, Self.socketListenBacklog) >= 0 else { let errnoCode = errno print("TerminalController: Failed to listen on socket") - close(serverSocket) + close(newServerSocket) reportSocketListenerFailure( message: "socket.listener.start.failed", stage: "listen", @@ -471,14 +598,27 @@ class TerminalController { return } - isRunning = true + let generation = withListenerState { + isRunning = true + pendingAcceptLoopRearmGeneration = nil + nextAcceptLoopGeneration &+= 1 + let generation = nextAcceptLoopGeneration + activeAcceptLoopGeneration = generation + serverSocket = newServerSocket + listenerStartInProgress = false + return generation + } + listenerActivated = true + let listenerSocket = newServerSocket print("TerminalController: Listening on \(socketPath)") sentryBreadcrumb( "socket.listener.listening", category: "socket", data: [ "path": socketPath, - "mode": accessMode.rawValue + "mode": accessMode.rawValue, + "generation": generation, + "backlog": Self.socketListenBacklog ] ) @@ -503,40 +643,65 @@ class TerminalController { // Accept connections in background thread Thread.detachNewThread { [weak self] in - self?.acceptLoop() + self?.acceptLoop(listenerSocket: listenerSocket, generation: generation) } } nonisolated func socketListenerHealth(expectedSocketPath: String) -> SocketListenerHealth { - let running = isRunning - let loopAlive = acceptLoopAlive - let pathMatches = socketPath == expectedSocketPath + let snapshot = listenerStateSnapshot() + let pathMatches = snapshot.socketPath == expectedSocketPath var st = stat() let exists = lstat(expectedSocketPath, &st) == 0 && (st.st_mode & S_IFMT) == S_IFSOCK return SocketListenerHealth( - isRunning: running, - acceptLoopAlive: loopAlive, + isRunning: snapshot.isRunning, + acceptLoopAlive: snapshot.acceptLoopAlive, socketPathMatches: pathMatches, socketPathExists: exists ) } nonisolated func stop() { - isRunning = false - if serverSocket >= 0 { - close(serverSocket) + let (socketToClose, socketPathToUnlink) = withListenerState { + isRunning = false + acceptLoopAlive = false + pendingAcceptLoopRearmGeneration = nil + listenerStartInProgress = false + nextAcceptLoopGeneration &+= 1 + activeAcceptLoopGeneration = 0 + let socketToClose = serverSocket serverSocket = -1 + return (socketToClose, socketPath) + } + if socketToClose >= 0 { + close(socketToClose) + } + unlink(socketPathToUnlink) + } + + private nonisolated func unlinkSocketPathIfListenerStillInactive(_ path: String) { + let shouldUnlink = withListenerState { + Self.shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: socketPath == path, + isRunning: isRunning, + activeGeneration: activeAcceptLoopGeneration, + listenerStartInProgress: listenerStartInProgress + ) + } + if shouldUnlink { + unlink(path) } - unlink(socketPath) } private func applySocketPermissions() { let permissions = mode_t(accessMode.socketFilePermissions) - if chmod(socketPath, permissions) != 0 { + let currentSocketPath = withListenerState { socketPath } + if chmod(currentSocketPath, permissions) != 0 { let errnoCode = errno - print("TerminalController: Failed to set socket permissions to \(String(permissions, radix: 8)) for \(socketPath)") + print( + "TerminalController: Failed to set socket permissions to \(String(permissions, radix: 8)) for \(currentSocketPath)" + ) sentryBreadcrumb( "socket.listener.permissions.failed", category: "socket", @@ -640,27 +805,77 @@ class TerminalController { return nil } - private nonisolated func acceptLoop() { - acceptLoopAlive = true + private nonisolated func acceptLoop(listenerSocket: Int32, generation: UInt64) { + let armedAcceptLoop = withListenerState { + guard generation == activeAcceptLoopGeneration else { return false } + acceptLoopAlive = true + return true + } + guard armedAcceptLoop else { + return + } + sentryBreadcrumb( "socket.listener.accept_loop.started", category: "socket", - data: socketListenerEventData(stage: "accept_loop_start") + data: socketListenerEventData( + stage: "accept_loop_start", + extra: [ + "generation": generation, + "listenerSocket": Int(listenerSocket) + ] + ) ) + var exitReason = "stopped" var lastAcceptErrno: Int32? + var lastAcceptErrnoClass = "none" + var rearmRequested = false + defer { - if isRunning && exitReason == "stopped" { - exitReason = "unexpected_loop_exit" + let cleanup = withListenerState { + guard generation == activeAcceptLoopGeneration else { + return (shouldCaptureExit: false, socketToClose: Int32(-1), pathToUnlink: nil as String?) + } + + if isRunning && exitReason == "stopped" { + exitReason = "unexpected_loop_exit" + } + let shouldCaptureExit = exitReason != "stopped" + + acceptLoopAlive = false + isRunning = false + activeAcceptLoopGeneration = 0 + + var socketToClose: Int32 = -1 + var pathToUnlink: String? + if serverSocket == listenerSocket { + socketToClose = serverSocket + serverSocket = -1 + if shouldCaptureExit { + pathToUnlink = socketPath + } + } + return (shouldCaptureExit, socketToClose, pathToUnlink) } - let shouldCaptureExit = exitReason != "stopped" - acceptLoopAlive = false - isRunning = false - if shouldCaptureExit { + + if cleanup.socketToClose >= 0 { + close(cleanup.socketToClose) + } + if let pathToUnlink = cleanup.pathToUnlink { + unlinkSocketPathIfListenerStillInactive(pathToUnlink) + } + + if cleanup.shouldCaptureExit { let data = socketListenerEventData( stage: "accept_loop_exit", errnoCode: lastAcceptErrno, - extra: ["reason": exitReason] + extra: [ + "reason": exitReason, + "generation": generation, + "errnoClass": lastAcceptErrnoClass, + "rearmRequested": rearmRequested ? 1 : 0 + ] ) sentryBreadcrumb("socket.listener.accept_loop.exited", category: "socket", data: data) sentryCaptureError( @@ -673,39 +888,81 @@ class TerminalController { } var consecutiveFailures = 0 - while isRunning { + + while shouldContinueAcceptLoop(generation: generation) { var clientAddr = sockaddr_un() var clientAddrLen = socklen_t(MemoryLayout.size) let clientSocket = withUnsafeMutablePointer(to: &clientAddr) { ptr in ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in - accept(serverSocket, sockaddrPtr, &clientAddrLen) + accept(listenerSocket, sockaddrPtr, &clientAddrLen) } } guard clientSocket >= 0 else { - if isRunning { - let errnoCode = errno - lastAcceptErrno = errnoCode - consecutiveFailures += 1 - print("TerminalController: Accept failed (\(consecutiveFailures) consecutive)") - if consecutiveFailures == 1 || consecutiveFailures % 10 == 0 { - sentryBreadcrumb( - "socket.listener.accept.failed", - category: "socket", - data: socketListenerEventData( - stage: "accept", - errnoCode: errnoCode, - extra: ["consecutiveFailures": consecutiveFailures] - ) + if !shouldContinueAcceptLoop(generation: generation) { + exitReason = "stopped" + break + } + + let errnoCode = errno + lastAcceptErrno = errnoCode + let errnoClass = Self.acceptErrorClassification(errnoCode: errnoCode) + lastAcceptErrnoClass = errnoClass + + if Self.shouldRetryAcceptImmediately(errnoCode: errnoCode) { + continue + } + + consecutiveFailures += 1 + let backoffMs = Self.acceptFailureBackoffMilliseconds( + consecutiveFailures: consecutiveFailures + ) + let rearmDelayMs = Self.acceptFailureRearmDelayMilliseconds( + consecutiveFailures: consecutiveFailures + ) + + if Self.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: consecutiveFailures) { + sentryBreadcrumb( + "socket.listener.accept.failed", + category: "socket", + data: socketListenerEventData( + stage: "accept", + errnoCode: errnoCode, + extra: [ + "consecutiveFailures": consecutiveFailures, + "generation": generation, + "errnoClass": errnoClass, + "backoffMs": backoffMs + ] ) + ) + } + + let shouldRearmForFatalErrno = Self.shouldRearmListenerForAcceptError(errnoCode: errnoCode) + let shouldRearmForPersistentFailures = Self.shouldRearmForConsecutiveAcceptFailures( + consecutiveFailures: consecutiveFailures + ) + + if shouldRearmForFatalErrno || shouldRearmForPersistentFailures { + exitReason = shouldRearmForFatalErrno + ? "fatal_accept_error" + : "persistent_accept_failures" + rearmRequested = true + withListenerState { + pendingAcceptLoopRearmGeneration = generation } - if consecutiveFailures >= 50 { - print("TerminalController: Too many consecutive accept failures, exiting accept loop") - exitReason = "too_many_accept_failures" - break - } - usleep(10_000) // 10ms backoff + scheduleListenerRearm( + generation: generation, + errnoCode: errnoCode, + consecutiveFailures: consecutiveFailures, + delayMs: rearmDelayMs + ) + break + } + + if backoffMs > 0 { + usleep(useconds_t(backoffMs * 1_000)) } continue } @@ -724,6 +981,43 @@ class TerminalController { } } + private nonisolated func scheduleListenerRearm( + generation: UInt64, + errnoCode: Int32, + consecutiveFailures: Int, + delayMs: Int + ) { + let deadline = DispatchTime.now() + .milliseconds(delayMs) + DispatchQueue.main.asyncAfter(deadline: deadline) { [weak self] in + guard let self else { return } + guard let tabManager = self.tabManager else { return } + guard let restartPath = self.withListenerState({ () -> String? in + guard self.pendingAcceptLoopRearmGeneration == generation else { return nil } + self.pendingAcceptLoopRearmGeneration = nil + return self.socketPath + }) else { return } + + let restartMode = self.accessMode + + sentryBreadcrumb( + "socket.listener.rearm.requested", + category: "socket", + data: self.socketListenerEventData( + stage: "accept_rearm", + errnoCode: errnoCode, + extra: [ + "generation": generation, + "consecutiveFailures": consecutiveFailures, + "rearmDelayMs": delayMs + ] + ) + ) + + self.stop() + self.start(tabManager: tabManager, socketPath: restartPath, accessMode: restartMode) + } + } + private func handleClient(_ socket: Int32, peerPid: pid_t? = nil) { defer { close(socket) } @@ -760,7 +1054,7 @@ class TerminalController { var pending = "" var authenticated = false - while isRunning { + while withListenerState({ isRunning }) { let bytesRead = read(socket, &buffer, buffer.count - 1) guard bytesRead > 0 else { break } diff --git a/cmuxTests/SessionPersistenceTests.swift b/cmuxTests/SessionPersistenceTests.swift index ad9f5b3c..88d8f11c 100644 --- a/cmuxTests/SessionPersistenceTests.swift +++ b/cmuxTests/SessionPersistenceTests.swift @@ -733,3 +733,141 @@ final class SessionPersistenceTests: XCTestCase { ) } } + +final class SocketListenerAcceptPolicyTests: XCTestCase { + func testAcceptErrorClassificationBucketsExpectedErrnos() { + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: EINTR), + "immediate_retry" + ) + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: ECONNABORTED), + "immediate_retry" + ) + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: EMFILE), + "resource_pressure" + ) + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: ENOMEM), + "resource_pressure" + ) + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: EBADF), + "fatal" + ) + XCTAssertEqual( + TerminalController.acceptErrorClassification(errnoCode: EINVAL), + "fatal" + ) + } + + func testAcceptErrorPolicySignalsRearmOnlyForFatalErrors() { + XCTAssertTrue(TerminalController.shouldRearmListenerForAcceptError(errnoCode: EBADF)) + XCTAssertTrue(TerminalController.shouldRearmListenerForAcceptError(errnoCode: ENOTSOCK)) + XCTAssertFalse(TerminalController.shouldRearmListenerForAcceptError(errnoCode: EMFILE)) + XCTAssertFalse(TerminalController.shouldRearmListenerForAcceptError(errnoCode: EINTR)) + } + + func testAcceptErrorPolicyRearmsAfterPersistentFailures() { + XCTAssertFalse(TerminalController.shouldRearmForConsecutiveAcceptFailures(consecutiveFailures: 0)) + XCTAssertFalse(TerminalController.shouldRearmForConsecutiveAcceptFailures(consecutiveFailures: 49)) + XCTAssertTrue(TerminalController.shouldRearmForConsecutiveAcceptFailures(consecutiveFailures: 50)) + XCTAssertTrue(TerminalController.shouldRearmForConsecutiveAcceptFailures(consecutiveFailures: 120)) + } + + func testAcceptFailureBackoffIsExponentialAndCapped() { + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 0), + 0 + ) + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 1), + 10 + ) + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 2), + 20 + ) + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 6), + 320 + ) + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 12), + 5_000 + ) + XCTAssertEqual( + TerminalController.acceptFailureBackoffMilliseconds(consecutiveFailures: 50), + 5_000 + ) + } + + func testAcceptFailureRearmDelayAppliesMinimumThrottle() { + XCTAssertEqual( + TerminalController.acceptFailureRearmDelayMilliseconds(consecutiveFailures: 0), + 100 + ) + XCTAssertEqual( + TerminalController.acceptFailureRearmDelayMilliseconds(consecutiveFailures: 1), + 100 + ) + XCTAssertEqual( + TerminalController.acceptFailureRearmDelayMilliseconds(consecutiveFailures: 2), + 100 + ) + XCTAssertEqual( + TerminalController.acceptFailureRearmDelayMilliseconds(consecutiveFailures: 6), + 320 + ) + XCTAssertEqual( + TerminalController.acceptFailureRearmDelayMilliseconds(consecutiveFailures: 12), + 5_000 + ) + } + + func testAcceptFailureBreadcrumbSamplingPrefersEarlyAndPowerOfTwoMilestones() { + XCTAssertTrue(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 1)) + XCTAssertTrue(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 2)) + XCTAssertTrue(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 3)) + XCTAssertFalse(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 5)) + XCTAssertTrue(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 8)) + XCTAssertFalse(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 9)) + XCTAssertTrue(TerminalController.shouldEmitAcceptFailureBreadcrumb(consecutiveFailures: 16)) + } + + func testAcceptLoopCleanupUnlinkPolicySkipsDuringListenerStartup() { + XCTAssertFalse( + TerminalController.shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: true, + isRunning: false, + activeGeneration: 0, + listenerStartInProgress: true + ) + ) + XCTAssertFalse( + TerminalController.shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: false, + isRunning: false, + activeGeneration: 0, + listenerStartInProgress: false + ) + ) + XCTAssertFalse( + TerminalController.shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: true, + isRunning: true, + activeGeneration: 7, + listenerStartInProgress: false + ) + ) + XCTAssertTrue( + TerminalController.shouldUnlinkSocketPathAfterAcceptLoopCleanup( + pathMatches: true, + isRunning: false, + activeGeneration: 0, + listenerStartInProgress: false + ) + ) + } +} diff --git a/tests/test_claude_wrapper_hooks.py b/tests/test_claude_wrapper_hooks.py new file mode 100644 index 00000000..7763bd76 --- /dev/null +++ b/tests/test_claude_wrapper_hooks.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +""" +Regression tests for Resources/bin/claude wrapper hook injection. +""" + +from __future__ import annotations + +import json +import os +import shutil +import socket +import subprocess +import tempfile +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SOURCE_WRAPPER = ROOT / "Resources" / "bin" / "claude" + + +def make_executable(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + path.chmod(0o755) + + +def read_lines(path: Path) -> list[str]: + if not path.exists(): + return [] + return [line.rstrip("\n") for line in path.read_text(encoding="utf-8").splitlines()] + + +def parse_settings_arg(argv: list[str]) -> dict: + if "--settings" not in argv: + return {} + index = argv.index("--settings") + if index + 1 >= len(argv): + return {} + return json.loads(argv[index + 1]) + + +def run_wrapper(*, socket_state: str, argv: list[str]) -> tuple[int, list[str], list[str], str, str]: + with tempfile.TemporaryDirectory(prefix="cmux-claude-wrapper-test-") as td: + tmp = Path(td) + wrapper_dir = tmp / "wrapper-bin" + real_dir = tmp / "real-bin" + wrapper_dir.mkdir(parents=True, exist_ok=True) + real_dir.mkdir(parents=True, exist_ok=True) + + wrapper = wrapper_dir / "claude" + shutil.copy2(SOURCE_WRAPPER, wrapper) + wrapper.chmod(0o755) + + real_args_log = tmp / "real-args.log" + real_claudecode_log = tmp / "real-claudecode.log" + cmux_log = tmp / "cmux.log" + socket_path = str(tmp / "cmux.sock") + + make_executable( + real_dir / "claude", + """#!/usr/bin/env bash +set -euo pipefail +: > "$FAKE_REAL_ARGS_LOG" +printf '%s\\n' "${CLAUDECODE-__UNSET__}" > "$FAKE_REAL_CLAUDECODE_LOG" +for arg in "$@"; do + printf '%s\\n' "$arg" >> "$FAKE_REAL_ARGS_LOG" +done +""", + ) + + make_executable( + wrapper_dir / "cmux", + """#!/usr/bin/env bash +set -euo pipefail +printf '%s timeout=%s\\n' "$*" "${CMUXTERM_CLI_RESPONSE_TIMEOUT_SEC-__UNSET__}" >> "$FAKE_CMUX_LOG" +if [[ "${1:-}" == "--socket" ]]; then + shift 2 +fi +if [[ "${1:-}" == "ping" ]]; then + if [[ "${FAKE_CMUX_PING_OK:-0}" == "1" ]]; then + exit 0 + fi + exit 1 +fi +exit 0 +""", + ) + + test_socket: socket.socket | None = None + if socket_state in {"live", "stale"}: + test_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + test_socket.bind(socket_path) + + env = os.environ.copy() + env["PATH"] = f"{wrapper_dir}:{real_dir}:/usr/bin:/bin" + env["CMUX_SURFACE_ID"] = "surface:test" + env["CMUX_SOCKET_PATH"] = socket_path + env["FAKE_REAL_ARGS_LOG"] = str(real_args_log) + env["FAKE_REAL_CLAUDECODE_LOG"] = str(real_claudecode_log) + env["FAKE_CMUX_LOG"] = str(cmux_log) + env["FAKE_CMUX_PING_OK"] = "1" if socket_state == "live" else "0" + env["CLAUDECODE"] = "nested-session-sentinel" + + try: + proc = subprocess.run( + ["claude", *argv], + cwd=tmp, + env=env, + capture_output=True, + text=True, + check=False, + ) + finally: + if test_socket is not None: + test_socket.close() + + claudecode_lines = read_lines(real_claudecode_log) + claudecode_value = claudecode_lines[0] if claudecode_lines else "" + return proc.returncode, read_lines(real_args_log), read_lines(cmux_log), proc.stderr.strip(), claudecode_value + + +def expect(condition: bool, message: str, failures: list[str]) -> None: + if not condition: + failures.append(message) + + +def test_live_socket_injects_supported_hooks(failures: list[str]) -> None: + code, real_argv, cmux_log, stderr, claudecode = run_wrapper(socket_state="live", argv=["hello"]) + expect(code == 0, f"live socket: wrapper exited {code}: {stderr}", failures) + expect("--settings" in real_argv, f"live socket: missing --settings in args: {real_argv}", failures) + expect("--session-id" in real_argv, f"live socket: missing --session-id in args: {real_argv}", failures) + expect(real_argv[-1] == "hello", f"live socket: expected original arg to pass through, got {real_argv}", failures) + expect(any(" ping" in line for line in cmux_log), f"live socket: expected cmux ping, got {cmux_log}", failures) + expect( + any("timeout=0.75" in line for line in cmux_log), + f"live socket: expected bounded ping timeout, got {cmux_log}", + failures, + ) + expect(claudecode == "__UNSET__", f"live socket: expected CLAUDECODE unset, got {claudecode!r}", failures) + + settings = parse_settings_arg(real_argv) + hooks = settings.get("hooks", {}) + expect(set(hooks.keys()) == {"SessionStart", "Stop", "Notification"}, f"unexpected hook keys: {hooks.keys()}", failures) + serialized = json.dumps(settings, sort_keys=True) + expect("UserPromptSubmit" not in serialized, "UserPromptSubmit hook should not be injected", failures) + expect("prompt-submit" not in serialized, "prompt-submit subcommand should not be injected", failures) + + +def test_missing_socket_skips_hook_injection(failures: list[str]) -> None: + code, real_argv, cmux_log, stderr, claudecode = run_wrapper(socket_state="missing", argv=["hello"]) + expect(code == 0, f"missing socket: wrapper exited {code}: {stderr}", failures) + expect(real_argv == ["hello"], f"missing socket: expected passthrough args, got {real_argv}", failures) + expect(cmux_log == [], f"missing socket: expected no cmux calls, got {cmux_log}", failures) + expect(claudecode == "__UNSET__", f"missing socket: expected CLAUDECODE unset, got {claudecode!r}", failures) + + +def test_stale_socket_skips_hook_injection(failures: list[str]) -> None: + code, real_argv, cmux_log, stderr, claudecode = run_wrapper(socket_state="stale", argv=["hello"]) + expect(code == 0, f"stale socket: wrapper exited {code}: {stderr}", failures) + expect(real_argv == ["hello"], f"stale socket: expected passthrough args, got {real_argv}", failures) + expect(any(" ping" in line for line in cmux_log), f"stale socket: expected cmux ping probe, got {cmux_log}", failures) + expect( + any("timeout=0.75" in line for line in cmux_log), + f"stale socket: expected bounded ping timeout, got {cmux_log}", + failures, + ) + expect(claudecode == "__UNSET__", f"stale socket: expected CLAUDECODE unset, got {claudecode!r}", failures) + + +def main() -> int: + failures: list[str] = [] + test_live_socket_injects_supported_hooks(failures) + test_missing_socket_skips_hook_injection(failures) + test_stale_socket_skips_hook_injection(failures) + + if failures: + print("FAIL: claude wrapper regression checks failed") + for failure in failures: + print(f"- {failure}") + return 1 + + print("PASS: claude wrapper hooks handle missing/stale sockets and inject only supported hooks") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())