diff --git a/CLAUDE.md b/CLAUDE.md index beb24aa0..afaef667 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -79,6 +79,8 @@ tail -f "$(cat /tmp/cmux-last-debug-log-path 2>/dev/null || echo /tmp/cmux-debug - Untagged Debug app: `/tmp/cmux-debug.log` - Tagged Debug app (`./scripts/reload.sh --tag `): `/tmp/cmux-debug-.log` - `reload.sh` writes the current path to `/tmp/cmux-last-debug-log-path` +- `reload.sh` writes the selected dev CLI path to `/tmp/cmux-last-cli-path` +- `reload.sh` updates `/tmp/cmux-cli` and `$HOME/.local/bin/cmux-dev` to that CLI - Implementation: `vendor/bonsplit/Sources/Bonsplit/Public/DebugEventLog.swift` - Free function `dlog("message")` — logs with timestamp and appends to file in real time diff --git a/CLI/cmux.swift b/CLI/cmux.swift index 020ac4fe..59b74db5 100644 --- a/CLI/cmux.swift +++ b/CLI/cmux.swift @@ -1758,7 +1758,7 @@ struct CMUXCLI { let shellFeaturesValue = scopedGhosttyShellFeaturesValue() let sshStartupCommand = buildSSHStartupCommand(sshCommand: sshCommand, shellFeatures: shellFeaturesValue) - var workspaceCreateParams: [String: Any] = [ + let workspaceCreateParams: [String: Any] = [ "initial_command": sshStartupCommand, ] @@ -1775,6 +1775,8 @@ struct CMUXCLI { ]) } + let remoteSSHOptions = sshOptionsWithControlSocketDefaults(sshOptions.sshOptions) + var configureParams: [String: Any] = [ "workspace_id": workspaceId, "destination": sshOptions.destination, @@ -1787,8 +1789,8 @@ struct CMUXCLI { !identityFile.isEmpty { configureParams["identity_file"] = identityFile } - if !sshOptions.sshOptions.isEmpty { - configureParams["ssh_options"] = sshOptions.sshOptions + if !remoteSSHOptions.isEmpty { + configureParams["ssh_options"] = remoteSSHOptions } var payload = try client.sendV2(method: "workspace.remote.configure", params: configureParams) @@ -1888,15 +1890,10 @@ struct CMUXCLI { } private func buildSSHCommandText(_ options: SSHCommandOptions) -> String { - var parts: [String] = ["ssh", "-o", "StrictHostKeyChecking=accept-new"] - if !hasSSHOptionKey(options.sshOptions, key: "ControlMaster") { - parts += ["-o", "ControlMaster=auto"] - } - if !hasSSHOptionKey(options.sshOptions, key: "ControlPersist") { - parts += ["-o", "ControlPersist=600"] - } - if !hasSSHOptionKey(options.sshOptions, key: "ControlPath") { - parts += ["-o", "ControlPath=\(defaultSSHControlPathTemplate())"] + let effectiveSSHOptions = sshOptionsWithControlSocketDefaults(options.sshOptions) + var parts: [String] = ["ssh"] + if !hasSSHOptionKey(effectiveSSHOptions, key: "StrictHostKeyChecking") { + parts += ["-o", "StrictHostKeyChecking=accept-new"] } if let port = options.port { parts += ["-p", String(port)] @@ -1905,16 +1902,33 @@ struct CMUXCLI { !identityFile.isEmpty { parts += ["-i", identityFile] } - for option in options.sshOptions { - let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines) - guard !trimmed.isEmpty else { continue } - parts += ["-o", trimmed] + for option in effectiveSSHOptions { + parts += ["-o", option] } parts.append(options.destination) parts.append(contentsOf: options.extraArguments) return parts.map(shellQuote).joined(separator: " ") } + private func sshOptionsWithControlSocketDefaults(_ options: [String]) -> [String] { + var merged: [String] = [] + for option in options { + let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { continue } + merged.append(trimmed) + } + if !hasSSHOptionKey(merged, key: "ControlMaster") { + merged.append("ControlMaster=auto") + } + if !hasSSHOptionKey(merged, key: "ControlPersist") { + merged.append("ControlPersist=600") + } + if !hasSSHOptionKey(merged, key: "ControlPath") { + merged.append("ControlPath=\(defaultSSHControlPathTemplate())") + } + return merged + } + private func scopedGhosttyShellFeaturesValue() -> String { let rawExisting = ProcessInfo.processInfo.environment["GHOSTTY_SHELL_FEATURES"] ?? "" var seen: Set = [] @@ -3317,7 +3331,7 @@ fi Usage: cmux ssh [flags] [-- ] Create a new workspace, mark it as remote-SSH, and start an SSH session in that workspace. - cmux will also attempt background remote port detection + local forwarding for browser access. + cmux will also establish a local SSH proxy endpoint so browser traffic can egress from the remote host. Flags: --name Optional workspace title diff --git a/Sources/GhosttyTerminalView.swift b/Sources/GhosttyTerminalView.swift index 028f9fb2..1704bb11 100644 --- a/Sources/GhosttyTerminalView.swift +++ b/Sources/GhosttyTerminalView.swift @@ -1437,7 +1437,8 @@ final class TerminalSurface: Identifiable, ObservableObject { } } - if !env.isEmpty { + let allowSurfaceEnvOverrides = false + if allowSurfaceEnvOverrides, !env.isEmpty { envVars.reserveCapacity(env.count) envStorage.reserveCapacity(env.count) for (key, value) in env { @@ -1592,6 +1593,7 @@ final class TerminalSurface: Identifiable, ObservableObject { } #endif guard let view = attachedView, + let surface, view.window != nil, view.bounds.width > 0, view.bounds.height > 0 else { diff --git a/Sources/Panels/BrowserPanel.swift b/Sources/Panels/BrowserPanel.swift index 873da962..e62b64f8 100644 --- a/Sources/Panels/BrowserPanel.swift +++ b/Sources/Panels/BrowserPanel.swift @@ -3,6 +3,12 @@ import Combine import WebKit import AppKit import Bonsplit +import Network + +struct BrowserProxyEndpoint: Equatable { + let host: String + let port: Int +} enum BrowserSearchEngine: String, CaseIterable, Identifiable { case google @@ -1070,6 +1076,7 @@ final class BrowserPanel: Panel, ObservableObject { private var developerToolsRestoreRetryAttempt: Int = 0 private let developerToolsRestoreRetryDelay: TimeInterval = 0.05 private let developerToolsRestoreRetryMaxAttempts: Int = 40 + private var remoteProxyEndpoint: BrowserProxyEndpoint? var displayTitle: String { if !pageTitle.isEmpty { @@ -1089,17 +1096,27 @@ final class BrowserPanel: Panel, ObservableObject { false } - init(workspaceId: UUID, initialURL: URL? = nil, bypassInsecureHTTPHostOnce: String? = nil) { + init( + workspaceId: UUID, + initialURL: URL? = nil, + bypassInsecureHTTPHostOnce: String? = nil, + proxyEndpoint: BrowserProxyEndpoint? = nil + ) { self.id = UUID() self.workspaceId = workspaceId self.insecureHTTPBypassHostOnce = BrowserInsecureHTTPSettings.normalizeHost(bypassInsecureHTTPHostOnce ?? "") + self.remoteProxyEndpoint = proxyEndpoint // Configure web view let config = WKWebViewConfiguration() config.processPool = BrowserPanel.sharedProcessPool - // Ensure browser cookies/storage persist across navigations and launches. - // This reduces repeated consent/bot-challenge flows on sites like Google. - config.websiteDataStore = .default() + // Keep data-store scoping at workspace granularity so remote proxy settings + // do not leak into local workspaces. + if #available(macOS 14.0, *) { + config.websiteDataStore = WKWebsiteDataStore(forIdentifier: workspaceId) + } else { + config.websiteDataStore = .default() + } // Enable developer extras (DevTools) config.preferences.setValue(true, forKey: "developerExtrasEnabled") @@ -1124,6 +1141,7 @@ final class BrowserPanel: Panel, ObservableObject { webView.customUserAgent = BrowserUserAgentSettings.safariUserAgent self.webView = webView + applyRemoteProxyConfigurationIfAvailable() // Set up navigation delegate let navDelegate = BrowserNavigationDelegate() @@ -1180,6 +1198,33 @@ final class BrowserPanel: Panel, ObservableObject { workspaceId = newWorkspaceId } + func setRemoteProxyEndpoint(_ endpoint: BrowserProxyEndpoint?) { + guard remoteProxyEndpoint != endpoint else { return } + remoteProxyEndpoint = endpoint + applyRemoteProxyConfigurationIfAvailable() + } + + private func applyRemoteProxyConfigurationIfAvailable() { + guard #available(macOS 14.0, *) else { return } + + let store = webView.configuration.websiteDataStore + guard let endpoint = remoteProxyEndpoint, + endpoint.port > 0 && endpoint.port <= 65535, + let nwPort = NWEndpoint.Port(rawValue: UInt16(endpoint.port)) else { + store.proxyConfigurations = [] + return + } + + let nwEndpoint = NWEndpoint.hostPort( + host: NWEndpoint.Host(endpoint.host), + port: nwPort + ) + // Prefer SOCKSv5; keep CONNECT configured as fallback. + let socks = ProxyConfiguration(socksv5Proxy: nwEndpoint) + let connect = ProxyConfiguration(httpCONNECTProxy: nwEndpoint) + store.proxyConfigurations = [socks, connect] + } + func triggerFlash() { focusFlashToken &+= 1 } diff --git a/Sources/TerminalController.swift b/Sources/TerminalController.swift index 66c3b6d4..17a22479 100644 --- a/Sources/TerminalController.swift +++ b/Sources/TerminalController.swift @@ -1497,6 +1497,40 @@ class TerminalController { return nil } + private func v2HasNonNullParam(_ params: [String: Any], _ key: String) -> Bool { + guard let raw = params[key] else { return false } + return !(raw is NSNull) + } + + private func v2StrictInt(_ params: [String: Any], _ key: String) -> Int? { + v2StrictIntAny(params[key]) + } + + private func v2StrictIntAny(_ raw: Any?) -> Int? { + guard let raw else { return nil } + + if let numberValue = raw as? NSNumber { + if CFGetTypeID(numberValue) == CFBooleanGetTypeID() { + return nil + } + let doubleValue = numberValue.doubleValue + guard doubleValue.isFinite, floor(doubleValue) == doubleValue else { + return nil + } + return Int(exactly: doubleValue) + } + + if let intValue = raw as? Int { + return intValue + } + + if let stringValue = raw as? String { + return Int(stringValue.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + return nil + } + private func v2PanelType(_ params: [String: Any], _ key: String) -> PanelType? { guard let s = v2String(params, key) else { return nil } return PanelType(rawValue: s.lowercased()) @@ -1976,13 +2010,26 @@ class TerminalController { } var sshPort: Int? - if let parsedPort = v2Int(params, "port") { - guard parsedPort > 0 && parsedPort <= 65535 else { + if v2HasNonNullParam(params, "port") { + guard let parsedPort = v2StrictInt(params, "port"), + parsedPort > 0, + parsedPort <= 65535 else { return .err(code: "invalid_params", message: "port must be 1-65535", data: nil) } sshPort = parsedPort } + // Internal deterministic test hook: pin the local proxy listener port to force bind conflicts. + var localProxyPort: Int? + if v2HasNonNullParam(params, "local_proxy_port") { + guard let parsedLocalProxyPort = v2StrictInt(params, "local_proxy_port"), + parsedLocalProxyPort > 0, + parsedLocalProxyPort <= 65535 else { + return .err(code: "invalid_params", message: "local_proxy_port must be 1-65535", data: nil) + } + localProxyPort = parsedLocalProxyPort + } + let identityFile = v2RawString(params, "identity_file")?.trimmingCharacters(in: .whitespacesAndNewlines) let sshOptions = v2StringArray(params, "ssh_options") ?? [] let autoConnect = v2Bool(params, "auto_connect") ?? true @@ -2002,7 +2049,8 @@ class TerminalController { destination: destination, port: sshPort, identityFile: identityFile?.isEmpty == true ? nil : identityFile, - sshOptions: sshOptions + sshOptions: sshOptions, + localProxyPort: localProxyPort ) workspace.configureRemoteConnection(config, autoConnect: autoConnect) diff --git a/Sources/Workspace.swift b/Sources/Workspace.swift index a9825ab5..1db814e8 100644 --- a/Sources/Workspace.swift +++ b/Sources/Workspace.swift @@ -4,6 +4,7 @@ import AppKit import Bonsplit import Combine import Darwin +import Network struct SidebarStatusEntry { let key: String @@ -13,12 +14,1246 @@ struct SidebarStatusEntry { let timestamp: Date } -private final class WorkspaceRemoteSessionController { - private struct ForwardEntry { - let process: Process - let stderrPipe: Pipe +private final class WorkspaceRemoteDaemonRPCClient { + private let configuration: WorkspaceRemoteConfiguration + private let remotePath: String + private let onUnexpectedTermination: (String) -> Void + private let callQueue = DispatchQueue(label: "com.cmux.remote-ssh.daemon-rpc.call.\(UUID().uuidString)") + private let stateQueue = DispatchQueue(label: "com.cmux.remote-ssh.daemon-rpc.state.\(UUID().uuidString)") + + private var process: Process? + private var stdinHandle: FileHandle? + private var stdoutHandle: FileHandle? + private var stderrHandle: FileHandle? + private var isClosed = true + private var shouldReportTermination = true + + private var nextRequestID = 1 + private var pendingID: Int? + private var pendingSemaphore: DispatchSemaphore? + private var pendingResponse: [String: Any]? + private var pendingFailureMessage: String? + + private var stdoutBuffer = Data() + private var stderrBuffer = "" + + init( + configuration: WorkspaceRemoteConfiguration, + remotePath: String, + onUnexpectedTermination: @escaping (String) -> Void + ) { + self.configuration = configuration + self.remotePath = remotePath + self.onUnexpectedTermination = onUnexpectedTermination } + func start() throws { + let process = Process() + let stdinPipe = Pipe() + let stdoutPipe = Pipe() + let stderrPipe = Pipe() + + process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") + process.arguments = Self.daemonArguments(configuration: configuration, remotePath: remotePath) + process.standardInput = stdinPipe + process.standardOutput = stdoutPipe + process.standardError = stderrPipe + + stdoutPipe.fileHandleForReading.readabilityHandler = { [weak self] handle in + let data = handle.availableData + self?.stateQueue.async { + self?.consumeStdoutData(data) + } + } + stderrPipe.fileHandleForReading.readabilityHandler = { [weak self] handle in + let data = handle.availableData + self?.stateQueue.async { + self?.consumeStderrData(data) + } + } + process.terminationHandler = { [weak self] terminated in + self?.stateQueue.async { + self?.handleProcessTermination(terminated) + } + } + + do { + try process.run() + } catch { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 1, userInfo: [ + NSLocalizedDescriptionKey: "Failed to launch SSH daemon transport: \(error.localizedDescription)", + ]) + } + + stateQueue.sync { + self.process = process + self.stdinHandle = stdinPipe.fileHandleForWriting + self.stdoutHandle = stdoutPipe.fileHandleForReading + self.stderrHandle = stderrPipe.fileHandleForReading + self.isClosed = false + self.shouldReportTermination = true + self.stdoutBuffer = Data() + self.stderrBuffer = "" + self.pendingID = nil + self.pendingSemaphore = nil + self.pendingResponse = nil + self.pendingFailureMessage = nil + } + + do { + let hello = try call(method: "hello", params: [:], timeout: 8.0) + let capabilities = (hello["capabilities"] as? [String]) ?? [] + guard capabilities.contains("proxy.stream") else { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 2, userInfo: [ + NSLocalizedDescriptionKey: "remote daemon missing required capability proxy.stream", + ]) + } + } catch { + stop(suppressTerminationCallback: true) + throw error + } + } + + func stop() { + stop(suppressTerminationCallback: true) + } + + func openStream(host: String, port: Int, timeoutMs: Int = 10000) throws -> String { + let result = try call( + method: "proxy.open", + params: [ + "host": host, + "port": port, + "timeout_ms": timeoutMs, + ], + timeout: 12.0 + ) + let streamID = (result["stream_id"] as? String)?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" + guard !streamID.isEmpty else { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 3, userInfo: [ + NSLocalizedDescriptionKey: "proxy.open missing stream_id", + ]) + } + return streamID + } + + func writeStream(streamID: String, data: Data) throws { + _ = try call( + method: "proxy.write", + params: [ + "stream_id": streamID, + "data_base64": data.base64EncodedString(), + ], + timeout: 8.0 + ) + } + + func readStream(streamID: String, maxBytes: Int = 32768, timeoutMs: Int = 250) throws -> (data: Data, eof: Bool) { + let result = try call( + method: "proxy.read", + params: [ + "stream_id": streamID, + "max_bytes": maxBytes, + "timeout_ms": timeoutMs, + ], + timeout: max(2.0, TimeInterval(timeoutMs) / 1000.0 + 2.0) + ) + let encoded = (result["data_base64"] as? String) ?? "" + let decoded = encoded.isEmpty ? Data() : (Data(base64Encoded: encoded) ?? Data()) + let eof = (result["eof"] as? Bool) ?? false + return (decoded, eof) + } + + func closeStream(streamID: String) { + _ = try? call( + method: "proxy.close", + params: ["stream_id": streamID], + timeout: 4.0 + ) + } + + private func call(method: String, params: [String: Any], timeout: TimeInterval) throws -> [String: Any] { + try callQueue.sync { + let semaphore = DispatchSemaphore(value: 0) + let requestID: Int = stateQueue.sync { + let id = nextRequestID + nextRequestID += 1 + pendingID = id + pendingSemaphore = semaphore + pendingResponse = nil + pendingFailureMessage = nil + return id + } + + let payload: Data + do { + payload = try Self.encodeJSON([ + "id": requestID, + "method": method, + "params": params, + ]) + } catch { + stateQueue.sync { + clearPendingLocked() + } + throw NSError(domain: "cmux.remote.daemon.rpc", code: 10, userInfo: [ + NSLocalizedDescriptionKey: "failed to encode daemon RPC request \(method): \(error.localizedDescription)", + ]) + } + + do { + try writePayload(payload) + } catch { + stateQueue.sync { + clearPendingLocked() + } + throw error + } + + if semaphore.wait(timeout: .now() + timeout) == .timedOut { + stop(suppressTerminationCallback: false) + throw NSError(domain: "cmux.remote.daemon.rpc", code: 11, userInfo: [ + NSLocalizedDescriptionKey: "daemon RPC timeout waiting for \(method) response", + ]) + } + + let response: [String: Any] = try stateQueue.sync { + defer { + clearPendingLocked() + } + if let failure = pendingFailureMessage { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 12, userInfo: [ + NSLocalizedDescriptionKey: failure, + ]) + } + guard let pendingResponse else { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 13, userInfo: [ + NSLocalizedDescriptionKey: "daemon RPC \(method) returned empty response", + ]) + } + return pendingResponse + } + + let ok = (response["ok"] as? Bool) ?? false + if ok { + return (response["result"] as? [String: Any]) ?? [:] + } + + let errorObject = (response["error"] as? [String: Any]) ?? [:] + let code = (errorObject["code"] as? String)?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "rpc_error" + let message = (errorObject["message"] as? String)?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "daemon RPC call failed" + throw NSError(domain: "cmux.remote.daemon.rpc", code: 14, userInfo: [ + NSLocalizedDescriptionKey: "\(method) failed (\(code)): \(message)", + ]) + } + } + + private func writePayload(_ payload: Data) throws { + let stdinHandle: FileHandle = stateQueue.sync { + self.stdinHandle ?? FileHandle.nullDevice + } + if stdinHandle === FileHandle.nullDevice { + throw NSError(domain: "cmux.remote.daemon.rpc", code: 15, userInfo: [ + NSLocalizedDescriptionKey: "daemon transport is not connected", + ]) + } + do { + try stdinHandle.write(contentsOf: payload) + try stdinHandle.write(contentsOf: Data([0x0A])) + } catch { + stop(suppressTerminationCallback: false) + throw NSError(domain: "cmux.remote.daemon.rpc", code: 16, userInfo: [ + NSLocalizedDescriptionKey: "failed writing daemon RPC request: \(error.localizedDescription)", + ]) + } + } + + private func consumeStdoutData(_ data: Data) { + guard !data.isEmpty else { + signalPendingFailureLocked("daemon transport closed stdout") + return + } + + stdoutBuffer.append(data) + while let newlineIndex = stdoutBuffer.firstIndex(of: 0x0A) { + var lineData = Data(stdoutBuffer[..<newlineIndex]) + stdoutBuffer.removeSubrange(...newlineIndex) + + if let carriageIndex = lineData.lastIndex(of: 0x0D), carriageIndex == lineData.index(before: lineData.endIndex) { + lineData.remove(at: carriageIndex) + } + guard !lineData.isEmpty else { continue } + + guard let payload = try? JSONSerialization.jsonObject(with: lineData, options: []) as? [String: Any] else { + continue + } + + let responseID: Int = { + if let intValue = payload["id"] as? Int { + return intValue + } + if let numberValue = payload["id"] as? NSNumber { + return numberValue.intValue + } + return -1 + }() + guard responseID >= 0 else { continue } + guard pendingID == responseID else { continue } + + pendingResponse = payload + pendingSemaphore?.signal() + } + } + + private func consumeStderrData(_ data: Data) { + guard !data.isEmpty else { return } + guard let chunk = String(data: data, encoding: .utf8), !chunk.isEmpty else { return } + stderrBuffer.append(chunk) + if stderrBuffer.count > 8192 { + stderrBuffer.removeFirst(stderrBuffer.count - 8192) + } + } + + private func handleProcessTermination(_ process: Process) { + let shouldNotify: Bool = { + guard self.process === process else { return false } + return !isClosed && shouldReportTermination + }() + let detail = Self.bestErrorLine(stderr: stderrBuffer) ?? "daemon transport exited with status \(process.terminationStatus)" + + isClosed = true + self.process = nil + stdinHandle = nil + stdoutHandle?.readabilityHandler = nil + stdoutHandle = nil + stderrHandle?.readabilityHandler = nil + stderrHandle = nil + signalPendingFailureLocked(detail) + + guard shouldNotify else { return } + onUnexpectedTermination(detail) + } + + private func stop(suppressTerminationCallback: Bool) { + let captured: (Process?, FileHandle?, FileHandle?, FileHandle?) = stateQueue.sync { + shouldReportTermination = !suppressTerminationCallback + if isClosed { + return (nil, nil, nil, nil) + } + + isClosed = true + signalPendingFailureLocked("daemon transport stopped") + let capturedProcess = process + let capturedStdin = stdinHandle + let capturedStdout = stdoutHandle + let capturedStderr = stderrHandle + + process = nil + stdinHandle = nil + stdoutHandle = nil + stderrHandle = nil + return (capturedProcess, capturedStdin, capturedStdout, capturedStderr) + } + + captured.2?.readabilityHandler = nil + captured.3?.readabilityHandler = nil + try? captured.1?.close() + try? captured.2?.close() + try? captured.3?.close() + if let process = captured.0, process.isRunning { + process.terminate() + } + } + + private func signalPendingFailureLocked(_ message: String) { + pendingFailureMessage = message + pendingSemaphore?.signal() + } + + private func clearPendingLocked() { + pendingID = nil + pendingSemaphore = nil + pendingResponse = nil + pendingFailureMessage = nil + } + + private static func encodeJSON(_ object: [String: Any]) throws -> Data { + try JSONSerialization.data(withJSONObject: object, options: []) + } + + private static func daemonArguments(configuration: WorkspaceRemoteConfiguration, remotePath: String) -> [String] { + let script = "exec \(shellSingleQuoted(remotePath)) serve --stdio" + let command = "sh -lc \(shellSingleQuoted(script))" + return sshCommonArguments(configuration: configuration, batchMode: true) + [configuration.destination, command] + } + + private static func sshCommonArguments(configuration: WorkspaceRemoteConfiguration, batchMode: Bool) -> [String] { + var args: [String] = [ + "-o", "ConnectTimeout=6", + "-o", "ServerAliveInterval=20", + "-o", "ServerAliveCountMax=2", + ] + if !hasSSHOptionKey(configuration.sshOptions, key: "StrictHostKeyChecking") { + args += ["-o", "StrictHostKeyChecking=accept-new"] + } + if batchMode { + args += ["-o", "BatchMode=yes"] + } + if let port = configuration.port { + args += ["-p", String(port)] + } + if let identityFile = configuration.identityFile, + !identityFile.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + args += ["-i", identityFile] + } + for option in configuration.sshOptions { + let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { continue } + args += ["-o", trimmed] + } + return args + } + + private static func hasSSHOptionKey(_ options: [String], key: String) -> Bool { + let loweredKey = key.lowercased() + for option in options { + let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { continue } + let token = trimmed.split(whereSeparator: { $0 == "=" || $0.isWhitespace }).first.map(String.init)?.lowercased() + if token == loweredKey { + return true + } + } + return false + } + + private static func shellSingleQuoted(_ value: String) -> String { + "'" + value.replacingOccurrences(of: "'", with: "'\"'\"'") + "'" + } + + private static func bestErrorLine(stderr: String) -> String? { + let lines = stderr + .split(separator: "\n") + .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + .filter { !$0.isEmpty } + + for line in lines.reversed() where !isNoiseLine(line) { + return line + } + return lines.last + } + + private static func isNoiseLine(_ line: String) -> Bool { + let lowered = line.lowercased() + if lowered.hasPrefix("warning: permanently added") { return true } + if lowered.hasPrefix("debug") { return true } + if lowered.hasPrefix("transferred:") { return true } + if lowered.hasPrefix("openbsd_") { return true } + if lowered.contains("pseudo-terminal will not be allocated") { return true } + return false + } +} + +private final class WorkspaceRemoteDaemonProxyTunnel { + private final class ProxySession { + private enum HandshakeProtocol { + case undecided + case socks5 + case connect + } + + private enum SocksStage { + case greeting + case request + } + + private struct SocksRequest { + let host: String + let port: Int + let command: UInt8 + let consumedBytes: Int + } + + let id = UUID() + + private let connection: NWConnection + private let rpcClient: WorkspaceRemoteDaemonRPCClient + private let queue: DispatchQueue + private let onClose: (UUID) -> Void + + private var isClosed = false + private var protocolKind: HandshakeProtocol = .undecided + private var socksStage: SocksStage = .greeting + private var handshakeBuffer = Data() + private var streamID: String? + private var localInputEOF = false + + init( + connection: NWConnection, + rpcClient: WorkspaceRemoteDaemonRPCClient, + queue: DispatchQueue, + onClose: @escaping (UUID) -> Void + ) { + self.connection = connection + self.rpcClient = rpcClient + self.queue = queue + self.onClose = onClose + } + + func start() { + connection.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .failed(let error): + self.close(reason: "proxy client connection failed: \(error)") + case .cancelled: + self.close(reason: nil) + default: + break + } + } + connection.start(queue: queue) + receiveNext() + } + + func stop() { + close(reason: nil) + } + + private func receiveNext() { + guard !isClosed else { return } + connection.receive(minimumIncompleteLength: 1, maximumLength: 32768) { [weak self] data, _, isComplete, error in + guard let self, !self.isClosed else { return } + + if let data, !data.isEmpty { + if self.streamID == nil { + self.handshakeBuffer.append(data) + self.processHandshakeBuffer() + } else { + self.forwardToRemote(data) + } + } + + if isComplete { + // Treat local EOF as a half-close: keep remote read loop alive so we can + // drain upstream response bytes (for example curl closing write-side after + // sending an HTTP request through SOCKS/CONNECT). + self.localInputEOF = true + if self.streamID == nil { + self.close(reason: nil) + } + return + } + if let error { + self.close(reason: "proxy client receive error: \(error)") + return + } + + self.receiveNext() + } + } + + private func processHandshakeBuffer() { + guard !isClosed else { return } + while streamID == nil { + switch protocolKind { + case .undecided: + guard let first = handshakeBuffer.first else { return } + protocolKind = (first == 0x05) ? .socks5 : .connect + case .socks5: + if !processSocksHandshakeStep() { + return + } + case .connect: + if !processConnectHandshakeStep() { + return + } + } + } + } + + private func processSocksHandshakeStep() -> Bool { + switch socksStage { + case .greeting: + guard handshakeBuffer.count >= 2 else { return false } + let methodCount = Int(handshakeBuffer[1]) + let total = 2 + methodCount + guard handshakeBuffer.count >= total else { return false } + + let methods = [UInt8](handshakeBuffer[2..<total]) + handshakeBuffer = Data(handshakeBuffer.dropFirst(total)) + socksStage = .request + + if !methods.contains(0x00) { + sendAndClose(Data([0x05, 0xFF])) + return false + } + sendLocal(Data([0x05, 0x00])) + return true + + case .request: + let request: SocksRequest + do { + guard let parsed = try parseSocksRequest(from: handshakeBuffer) else { return false } + request = parsed + } catch { + sendAndClose(Data([0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])) + return false + } + + let pending = handshakeBuffer.count > request.consumedBytes + ? Data(handshakeBuffer[request.consumedBytes...]) + : Data() + handshakeBuffer = Data() + guard request.command == 0x01 else { + sendAndClose(Data([0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0])) + return false + } + + openRemoteStream( + host: request.host, + port: request.port, + successResponse: Data([0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0]), + failureResponse: Data([0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]), + pendingPayload: pending + ) + return false + } + } + + private func parseSocksRequest(from data: Data) throws -> SocksRequest? { + let bytes = [UInt8](data) + guard bytes.count >= 4 else { return nil } + guard bytes[0] == 0x05 else { + throw NSError(domain: "cmux.remote.proxy", code: 1, userInfo: [NSLocalizedDescriptionKey: "invalid SOCKS version"]) + } + + let command = bytes[1] + let addressType = bytes[3] + var cursor = 4 + let host: String + + switch addressType { + case 0x01: + guard bytes.count >= cursor + 4 + 2 else { return nil } + let octets = bytes[cursor..<(cursor + 4)].map { String($0) } + host = octets.joined(separator: ".") + cursor += 4 + + case 0x03: + guard bytes.count >= cursor + 1 else { return nil } + let length = Int(bytes[cursor]) + cursor += 1 + guard bytes.count >= cursor + length + 2 else { return nil } + let hostData = Data(bytes[cursor..<(cursor + length)]) + host = String(data: hostData, encoding: .utf8) ?? "" + cursor += length + + case 0x04: + guard bytes.count >= cursor + 16 + 2 else { return nil } + var address = in6_addr() + withUnsafeMutableBytes(of: &address) { target in + for i in 0..<16 { + target[i] = bytes[cursor + i] + } + } + var text = [CChar](repeating: 0, count: Int(INET6_ADDRSTRLEN)) + let pointer = withUnsafePointer(to: &address) { + inet_ntop(AF_INET6, UnsafeRawPointer($0), &text, socklen_t(INET6_ADDRSTRLEN)) + } + host = pointer != nil ? String(cString: text) : "" + cursor += 16 + + default: + throw NSError(domain: "cmux.remote.proxy", code: 2, userInfo: [NSLocalizedDescriptionKey: "invalid SOCKS address type"]) + } + + guard !host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { + throw NSError(domain: "cmux.remote.proxy", code: 3, userInfo: [NSLocalizedDescriptionKey: "empty SOCKS host"]) + } + guard bytes.count >= cursor + 2 else { return nil } + let port = Int(UInt16(bytes[cursor]) << 8 | UInt16(bytes[cursor + 1])) + cursor += 2 + + guard port > 0 && port <= 65535 else { + throw NSError(domain: "cmux.remote.proxy", code: 4, userInfo: [NSLocalizedDescriptionKey: "invalid SOCKS port"]) + } + + return SocksRequest(host: host, port: port, command: command, consumedBytes: cursor) + } + + private func processConnectHandshakeStep() -> Bool { + let marker = Data([0x0D, 0x0A, 0x0D, 0x0A]) + guard let headerRange = handshakeBuffer.range(of: marker) else { return false } + + let headerData = Data(handshakeBuffer[..<headerRange.upperBound]) + let pending = headerRange.upperBound < handshakeBuffer.count + ? Data(handshakeBuffer[headerRange.upperBound...]) + : Data() + handshakeBuffer = Data() + guard let headerText = String(data: headerData, encoding: .utf8) else { + sendAndClose(Self.httpResponse(status: "400 Bad Request")) + return false + } + + let firstLine = headerText.components(separatedBy: "\r\n").first ?? "" + let parts = firstLine.split(whereSeparator: \.isWhitespace).map(String.init) + guard parts.count >= 2, parts[0].uppercased() == "CONNECT" else { + sendAndClose(Self.httpResponse(status: "400 Bad Request")) + return false + } + + guard let (host, port) = Self.parseConnectAuthority(parts[1]) else { + sendAndClose(Self.httpResponse(status: "400 Bad Request")) + return false + } + + openRemoteStream( + host: host, + port: port, + successResponse: Self.httpResponse(status: "200 Connection Established", closeAfterResponse: false), + failureResponse: Self.httpResponse(status: "502 Bad Gateway", closeAfterResponse: true), + pendingPayload: pending + ) + return false + } + + private func openRemoteStream( + host: String, + port: Int, + successResponse: Data, + failureResponse: Data, + pendingPayload: Data + ) { + guard !isClosed else { return } + do { + let streamID = try rpcClient.openStream(host: host, port: port) + self.streamID = streamID + connection.send(content: successResponse, completion: .contentProcessed { [weak self] error in + guard let self else { return } + if let error { + self.close(reason: "proxy client send error: \(error)") + return + } + if !pendingPayload.isEmpty { + self.forwardToRemote(pendingPayload) + } + self.scheduleRemoteReadLoop() + }) + } catch { + sendAndClose(failureResponse) + } + } + + private func forwardToRemote(_ data: Data) { + guard !isClosed else { return } + guard !localInputEOF else { return } + guard let streamID else { return } + do { + try rpcClient.writeStream(streamID: streamID, data: data) + } catch { + close(reason: "proxy.write failed: \(error.localizedDescription)") + } + } + + private func scheduleRemoteReadLoop() { + queue.async { [weak self] in + self?.pollRemoteOnce() + } + } + + private func pollRemoteOnce() { + guard !isClosed else { return } + guard let streamID else { return } + + let readResult: (data: Data, eof: Bool) + do { + readResult = try rpcClient.readStream(streamID: streamID, maxBytes: 32768, timeoutMs: 250) + } catch { + close(reason: "proxy.read failed: \(error.localizedDescription)") + return + } + + if !readResult.data.isEmpty { + connection.send(content: readResult.data, completion: .contentProcessed { [weak self] error in + guard let self else { return } + if let error { + self.close(reason: "proxy client send error: \(error)") + return + } + if readResult.eof { + self.close(reason: nil) + } else { + self.scheduleRemoteReadLoop() + } + }) + return + } + + if readResult.eof { + close(reason: nil) + } else { + scheduleRemoteReadLoop() + } + } + + private func close(reason: String?) { + guard !isClosed else { return } + isClosed = true + + let streamID = self.streamID + self.streamID = nil + + if let streamID { + rpcClient.closeStream(streamID: streamID) + } + if reason != nil { + connection.cancel() + } else { + connection.cancel() + } + onClose(id) + } + + private func sendLocal(_ data: Data) { + guard !isClosed else { return } + connection.send(content: data, completion: .contentProcessed { [weak self] error in + guard let self else { return } + if let error { + self.close(reason: "proxy client send error: \(error)") + } + }) + } + + private func sendAndClose(_ data: Data) { + guard !isClosed else { return } + connection.send(content: data, completion: .contentProcessed { [weak self] _ in + self?.close(reason: nil) + }) + } + + private static func parseConnectAuthority(_ authority: String) -> (host: String, port: Int)? { + let trimmed = authority.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { return nil } + + if trimmed.hasPrefix("[") { + guard let closing = trimmed.firstIndex(of: "]") else { return nil } + let host = String(trimmed[trimmed.index(after: trimmed.startIndex)..<closing]) + let portStart = trimmed.index(after: closing) + guard portStart < trimmed.endIndex, trimmed[portStart] == ":" else { return nil } + let portString = String(trimmed[trimmed.index(after: portStart)...]) + guard let port = Int(portString), port > 0, port <= 65535 else { return nil } + return (host, port) + } + + guard let colon = trimmed.lastIndex(of: ":") else { return nil } + let host = String(trimmed[..<colon]) + let portString = String(trimmed[trimmed.index(after: colon)...]) + guard !host.isEmpty else { return nil } + guard let port = Int(portString), port > 0, port <= 65535 else { return nil } + return (host, port) + } + + private static func httpResponse(status: String, closeAfterResponse: Bool = true) -> Data { + var text = "HTTP/1.1 \(status)\r\nProxy-Agent: cmux\r\n" + if closeAfterResponse { + text += "Connection: close\r\n" + } + text += "\r\n" + return Data(text.utf8) + } + } + + private let configuration: WorkspaceRemoteConfiguration + private let remotePath: String + private let localPort: Int + private let onFatalError: (String) -> Void + private let queue = DispatchQueue(label: "com.cmux.remote-ssh.daemon-tunnel.\(UUID().uuidString)", qos: .utility) + + private var listener: NWListener? + private var rpcClient: WorkspaceRemoteDaemonRPCClient? + private var sessions: [UUID: ProxySession] = [:] + private var isStopped = false + + init( + configuration: WorkspaceRemoteConfiguration, + remotePath: String, + localPort: Int, + onFatalError: @escaping (String) -> Void + ) { + self.configuration = configuration + self.remotePath = remotePath + self.localPort = localPort + self.onFatalError = onFatalError + } + + func start() throws { + var capturedError: Error? + queue.sync { + guard !isStopped else { + capturedError = NSError(domain: "cmux.remote.proxy", code: 20, userInfo: [ + NSLocalizedDescriptionKey: "proxy tunnel already stopped", + ]) + return + } + do { + let client = WorkspaceRemoteDaemonRPCClient( + configuration: configuration, + remotePath: remotePath + ) { [weak self] detail in + self?.queue.async { + self?.failLocked("Remote daemon transport failed: \(detail)") + } + } + try client.start() + + let listener = try Self.makeLoopbackListener(port: localPort) + listener.newConnectionHandler = { [weak self] connection in + self?.queue.async { + self?.acceptConnectionLocked(connection) + } + } + listener.stateUpdateHandler = { [weak self] state in + self?.queue.async { + self?.handleListenerStateLocked(state) + } + } + + self.rpcClient = client + self.listener = listener + listener.start(queue: queue) + } catch { + capturedError = error + stopLocked(notify: false) + } + } + if let capturedError { + throw capturedError + } + } + + func stop() { + queue.sync { + stopLocked(notify: false) + } + } + + private func handleListenerStateLocked(_ state: NWListener.State) { + guard !isStopped else { return } + switch state { + case .failed(let error): + failLocked("Local proxy listener failed: \(error)") + default: + break + } + } + + private func acceptConnectionLocked(_ connection: NWConnection) { + guard !isStopped else { + connection.cancel() + return + } + guard let rpcClient else { + connection.cancel() + return + } + + let session = ProxySession( + connection: connection, + rpcClient: rpcClient, + queue: queue + ) { [weak self] id in + self?.queue.async { + self?.sessions.removeValue(forKey: id) + } + } + sessions[session.id] = session + session.start() + } + + private func failLocked(_ detail: String) { + guard !isStopped else { return } + stopLocked(notify: false) + onFatalError(detail) + } + + private func stopLocked(notify: Bool) { + guard !isStopped else { return } + isStopped = true + + listener?.stateUpdateHandler = nil + listener?.newConnectionHandler = nil + listener?.cancel() + listener = nil + + let activeSessions = sessions.values + sessions.removeAll() + for session in activeSessions { + session.stop() + } + + rpcClient?.stop() + rpcClient = nil + } + + private static func makeLoopbackListener(port: Int) throws -> NWListener { + guard let localPort = NWEndpoint.Port(rawValue: UInt16(port)) else { + throw NSError(domain: "cmux.remote.proxy", code: 21, userInfo: [ + NSLocalizedDescriptionKey: "invalid local proxy port \(port)", + ]) + } + let parameters = NWParameters.tcp + parameters.allowLocalEndpointReuse = true + parameters.requiredLocalEndpoint = .hostPort(host: NWEndpoint.Host("127.0.0.1"), port: localPort) + return try NWListener(using: parameters) + } +} + +private final class WorkspaceRemoteProxyBroker { + enum Update { + case connecting + case ready(BrowserProxyEndpoint) + case error(String) + } + + final class Lease { + private let key: String + private let subscriberID: UUID + private weak var broker: WorkspaceRemoteProxyBroker? + private var isReleased = false + + fileprivate init(key: String, subscriberID: UUID, broker: WorkspaceRemoteProxyBroker) { + self.key = key + self.subscriberID = subscriberID + self.broker = broker + } + + func release() { + guard !isReleased else { return } + isReleased = true + broker?.release(key: key, subscriberID: subscriberID) + } + + deinit { + release() + } + } + + private final class Entry { + let configuration: WorkspaceRemoteConfiguration + var remotePath: String + var tunnel: WorkspaceRemoteDaemonProxyTunnel? + var endpoint: BrowserProxyEndpoint? + var restartWorkItem: DispatchWorkItem? + var subscribers: [UUID: (Update) -> Void] = [:] + + init(configuration: WorkspaceRemoteConfiguration, remotePath: String) { + self.configuration = configuration + self.remotePath = remotePath + } + } + + static let shared = WorkspaceRemoteProxyBroker() + + private let queue = DispatchQueue(label: "com.cmux.remote-ssh.proxy-broker", qos: .utility) + private var entries: [String: Entry] = [:] + + func acquire( + configuration: WorkspaceRemoteConfiguration, + remotePath: String, + onUpdate: @escaping (Update) -> Void + ) -> Lease { + queue.sync { + let key = Self.transportKey(for: configuration) + let subscriberID = UUID() + let entry: Entry + if let existing = entries[key] { + entry = existing + if existing.remotePath != remotePath { + existing.remotePath = remotePath + if existing.tunnel != nil { + stopEntryRuntimeLocked(existing) + notifyLocked(existing, update: .connecting) + } + } + } else { + entry = Entry(configuration: configuration, remotePath: remotePath) + entries[key] = entry + } + + entry.subscribers[subscriberID] = onUpdate + if let endpoint = entry.endpoint { + onUpdate(.ready(endpoint)) + } else { + onUpdate(.connecting) + } + + if entry.tunnel == nil, entry.restartWorkItem == nil { + startEntryLocked(key: key, entry: entry) + } + + return Lease(key: key, subscriberID: subscriberID, broker: self) + } + } + + private func release(key: String, subscriberID: UUID) { + queue.async { [weak self] in + guard let self, let entry = self.entries[key] else { return } + entry.subscribers.removeValue(forKey: subscriberID) + guard entry.subscribers.isEmpty else { return } + self.teardownEntryLocked(key: key, entry: entry) + } + } + + private func startEntryLocked(key: String, entry: Entry) { + entry.restartWorkItem?.cancel() + entry.restartWorkItem = nil + + let localPort: Int + if let forcedLocalPort = entry.configuration.localProxyPort { + // Internal deterministic test hook used by docker regressions to force bind conflicts. + localPort = forcedLocalPort + } else { + guard let allocatedPort = Self.allocateLoopbackPort() else { + notifyLocked( + entry, + update: .error("Failed to allocate local proxy port\(Self.retrySuffix(delay: 3.0))") + ) + scheduleRestartLocked(key: key, entry: entry, delay: 3.0) + return + } + localPort = allocatedPort + } + + do { + let tunnel = WorkspaceRemoteDaemonProxyTunnel( + configuration: entry.configuration, + remotePath: entry.remotePath, + localPort: localPort + ) { [weak self] detail in + self?.queue.async { + self?.handleTunnelFailureLocked(key: key, detail: detail) + } + } + try tunnel.start() + entry.tunnel = tunnel + let endpoint = BrowserProxyEndpoint(host: "127.0.0.1", port: localPort) + entry.endpoint = endpoint + notifyLocked(entry, update: .ready(endpoint)) + } catch { + stopEntryRuntimeLocked(entry) + let detail = "Failed to start local daemon proxy: \(error.localizedDescription)" + notifyLocked(entry, update: .error("\(detail)\(Self.retrySuffix(delay: 3.0))")) + scheduleRestartLocked(key: key, entry: entry, delay: 3.0) + } + } + + private func handleTunnelFailureLocked(key: String, detail: String) { + guard let entry = entries[key], entry.tunnel != nil else { return } + stopEntryRuntimeLocked(entry) + notifyLocked(entry, update: .error("\(detail)\(Self.retrySuffix(delay: 3.0))")) + scheduleRestartLocked(key: key, entry: entry, delay: 3.0) + } + + private func scheduleRestartLocked(key: String, entry: Entry, delay: TimeInterval) { + guard !entry.subscribers.isEmpty else { + teardownEntryLocked(key: key, entry: entry) + return + } + guard entry.restartWorkItem == nil else { return } + + let workItem = DispatchWorkItem { [weak self] in + guard let self, let currentEntry = self.entries[key] else { return } + currentEntry.restartWorkItem = nil + guard !currentEntry.subscribers.isEmpty else { + self.teardownEntryLocked(key: key, entry: currentEntry) + return + } + self.notifyLocked(currentEntry, update: .connecting) + self.startEntryLocked(key: key, entry: currentEntry) + } + + entry.restartWorkItem = workItem + queue.asyncAfter(deadline: .now() + delay, execute: workItem) + } + + private func teardownEntryLocked(key: String, entry: Entry) { + entry.restartWorkItem?.cancel() + entry.restartWorkItem = nil + stopEntryRuntimeLocked(entry) + entries.removeValue(forKey: key) + } + + private func stopEntryRuntimeLocked(_ entry: Entry) { + entry.tunnel?.stop() + entry.tunnel = nil + entry.endpoint = nil + } + + private func notifyLocked(_ entry: Entry, update: Update) { + for callback in entry.subscribers.values { + callback(update) + } + } + + private static func transportKey(for configuration: WorkspaceRemoteConfiguration) -> String { + let destination = configuration.destination.trimmingCharacters(in: .whitespacesAndNewlines) + let port = configuration.port.map(String.init) ?? "" + let identity = configuration.identityFile?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" + let localProxyPort = configuration.localProxyPort.map(String.init) ?? "" + let options = configuration.sshOptions + .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + .filter { !$0.isEmpty } + .joined(separator: "\u{1f}") + return [destination, port, identity, options, localProxyPort].joined(separator: "\u{1e}") + } + + private static func allocateLoopbackPort() -> Int? { + for _ in 0..<8 { + let fd = socket(AF_INET, SOCK_STREAM, 0) + guard fd >= 0 else { return nil } + defer { close(fd) } + + var yes: Int32 = 1 + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, socklen_t(MemoryLayout<Int32>.size)) + + var addr = sockaddr_in() + addr.sin_len = UInt8(MemoryLayout<sockaddr_in>.size) + addr.sin_family = sa_family_t(AF_INET) + addr.sin_port = in_port_t(0) + addr.sin_addr = in_addr(s_addr: inet_addr("127.0.0.1")) + + let bindResult = withUnsafePointer(to: &addr) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in + bind(fd, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_in>.size)) + } + } + guard bindResult == 0 else { continue } + + var bound = sockaddr_in() + var len = socklen_t(MemoryLayout<sockaddr_in>.size) + let nameResult = withUnsafeMutablePointer(to: &bound) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in + getsockname(fd, sockaddrPtr, &len) + } + } + guard nameResult == 0 else { continue } + + let port = Int(UInt16(bigEndian: bound.sin_port)) + if port > 0 && port <= 65535 { + return port + } + } + return nil + } + + private static func retrySuffix(delay: TimeInterval) -> String { + let seconds = max(1, Int(delay.rounded())) + return " (retry in \(seconds)s)" + } +} + +private final class WorkspaceRemoteSessionController { private struct CommandResult { let status: Int32 let stdout: String @@ -42,15 +1277,8 @@ private final class WorkspaceRemoteSessionController { private let configuration: WorkspaceRemoteConfiguration private var isStopping = false - private var probeProcess: Process? - private var probeStdoutPipe: Pipe? - private var probeStderrPipe: Pipe? - private var probeStdoutBuffer = "" - private var probeStderrBuffer = "" - - private var desiredRemotePorts: Set<Int> = [] - private var forwardEntries: [Int: ForwardEntry] = [:] - private var portConflicts: Set<Int> = [] + private var proxyLease: WorkspaceRemoteProxyBroker.Lease? + private var proxyEndpoint: BrowserProxyEndpoint? private var daemonReady = false private var daemonBootstrapVersion: String? private var daemonRemotePath: String? @@ -82,31 +1310,14 @@ private final class WorkspaceRemoteSessionController { reconnectWorkItem = nil reconnectRetryCount = 0 - if let probeProcess { - probeStdoutPipe?.fileHandleForReading.readabilityHandler = nil - probeStderrPipe?.fileHandleForReading.readabilityHandler = nil - if probeProcess.isRunning { - probeProcess.terminate() - } - } - probeProcess = nil - probeStdoutPipe = nil - probeStderrPipe = nil - probeStdoutBuffer = "" - probeStderrBuffer = "" - - for (_, entry) in forwardEntries { - entry.stderrPipe.fileHandleForReading.readabilityHandler = nil - if entry.process.isRunning { - entry.process.terminate() - } - } - forwardEntries.removeAll() - desiredRemotePorts.removeAll() - portConflicts.removeAll() + proxyLease?.release() + proxyLease = nil + proxyEndpoint = nil daemonReady = false daemonBootstrapVersion = nil daemonRemotePath = nil + publishProxyEndpoint(nil) + publishPortsSnapshotLocked() } private func beginConnectionAttemptLocked() { @@ -126,6 +1337,11 @@ private final class WorkspaceRemoteSessionController { publishDaemonStatus(.bootstrapping, detail: bootstrapDetail) do { let hello = try bootstrapDaemonLocked() + guard hello.capabilities.contains("proxy.stream") else { + throw NSError(domain: "cmux.remote.daemon", code: 43, userInfo: [ + NSLocalizedDescriptionKey: "remote daemon missing required capability proxy.stream", + ]) + } daemonReady = true daemonBootstrapVersion = hello.version daemonRemotePath = hello.remotePath @@ -137,12 +1353,12 @@ private final class WorkspaceRemoteSessionController { capabilities: hello.capabilities, remotePath: hello.remotePath ) - startProbeLocked() + startProxyLocked() } catch { daemonReady = false daemonBootstrapVersion = nil daemonRemotePath = nil - let nextRetry = scheduleProbeRestartLocked(delay: 4.0) + let nextRetry = scheduleReconnectLocked(delay: 4.0) let retrySuffix = Self.retrySuffix(retry: nextRetry, delay: 4.0) let detail = "Remote daemon bootstrap failed: \(error.localizedDescription)\(retrySuffix)" publishDaemonStatus(.error, detail: detail) @@ -150,89 +1366,74 @@ private final class WorkspaceRemoteSessionController { } } - private func startProbeLocked() { + private func startProxyLocked() { guard !isStopping else { return } guard daemonReady else { return } + guard proxyLease == nil else { return } + guard let remotePath = daemonRemotePath, + !remotePath.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { + let nextRetry = scheduleReconnectLocked(delay: 4.0) + let retrySuffix = Self.retrySuffix(retry: nextRetry, delay: 4.0) + let detail = "Remote daemon did not provide a valid remote path\(retrySuffix)" + publishDaemonStatus(.error, detail: detail) + publishState(.error, detail: detail) + return + } - probeStdoutBuffer = "" - probeStderrBuffer = "" - - let process = Process() - let stdoutPipe = Pipe() - let stderrPipe = Pipe() - process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") - process.arguments = probeArguments() - process.standardOutput = stdoutPipe - process.standardError = stderrPipe - - stdoutPipe.fileHandleForReading.readabilityHandler = { [weak self] handle in - let data = handle.availableData - if data.isEmpty { - handle.readabilityHandler = nil - return - } + let lease = WorkspaceRemoteProxyBroker.shared.acquire( + configuration: configuration, + remotePath: remotePath + ) { [weak self] update in self?.queue.async { - self?.consumeProbeStdoutData(data) + self?.handleProxyBrokerUpdateLocked(update) } } - - stderrPipe.fileHandleForReading.readabilityHandler = { [weak self] handle in - let data = handle.availableData - if data.isEmpty { - handle.readabilityHandler = nil - return - } - self?.queue.async { - self?.consumeProbeStderrData(data) - } - } - - process.terminationHandler = { [weak self] terminated in - self?.queue.async { - self?.handleProbeTermination(terminated) - } - } - - do { - try process.run() - probeProcess = process - probeStdoutPipe = stdoutPipe - probeStderrPipe = stderrPipe - } catch { - let nextRetry = scheduleProbeRestartLocked(delay: 3.0) - let retrySuffix = Self.retrySuffix(retry: nextRetry, delay: 3.0) - publishState(.error, detail: "Failed to start SSH probe: \(error.localizedDescription)\(retrySuffix)") - } + proxyLease = lease } - private func handleProbeTermination(_ process: Process) { - probeStdoutPipe?.fileHandleForReading.readabilityHandler = nil - probeStderrPipe?.fileHandleForReading.readabilityHandler = nil - probeProcess = nil - probeStdoutPipe = nil - probeStderrPipe = nil - + private func handleProxyBrokerUpdateLocked(_ update: WorkspaceRemoteProxyBroker.Update) { guard !isStopping else { return } - - for (_, entry) in forwardEntries { - entry.stderrPipe.fileHandleForReading.readabilityHandler = nil - if entry.process.isRunning { - entry.process.terminate() + switch update { + case .connecting: + if proxyEndpoint == nil { + publishState(.connecting, detail: "Connecting to \(configuration.displayTarget)") } - } - forwardEntries.removeAll() - publishPortsSnapshotLocked() + case .ready(let endpoint): + reconnectWorkItem?.cancel() + reconnectWorkItem = nil + reconnectRetryCount = 0 + guard proxyEndpoint != endpoint else { return } + proxyEndpoint = endpoint + publishProxyEndpoint(endpoint) + publishPortsSnapshotLocked() + publishState( + .connected, + detail: "Connected to \(configuration.displayTarget) via shared local proxy \(endpoint.host):\(endpoint.port)" + ) + case .error(let detail): + proxyEndpoint = nil + publishProxyEndpoint(nil) + publishPortsSnapshotLocked() + publishState(.error, detail: "Remote proxy to \(configuration.displayTarget) unavailable: \(detail)") + guard Self.shouldEscalateProxyErrorToBootstrap(detail) else { return } - let statusCode = process.terminationStatus - let rawDetail = Self.bestErrorLine(stderr: probeStderrBuffer, stdout: probeStdoutBuffer) - let detail = rawDetail ?? "SSH probe exited with status \(statusCode)" - let nextRetry = scheduleProbeRestartLocked(delay: 3.0) - let retrySuffix = Self.retrySuffix(retry: nextRetry, delay: 3.0) - publishState(.error, detail: "SSH probe to \(configuration.displayTarget) failed: \(detail)\(retrySuffix)") + proxyLease?.release() + proxyLease = nil + daemonReady = false + daemonBootstrapVersion = nil + daemonRemotePath = nil + + let nextRetry = scheduleReconnectLocked(delay: 2.0) + let retrySuffix = Self.retrySuffix(retry: nextRetry, delay: 2.0) + publishDaemonStatus( + .error, + detail: "Remote daemon transport needs re-bootstrap after proxy failure\(retrySuffix)" + ) + } } @discardableResult - private func scheduleProbeRestartLocked(delay: TimeInterval) -> Int { + private func scheduleReconnectLocked(delay: TimeInterval) -> Int { guard !isStopping else { return reconnectRetryCount } reconnectWorkItem?.cancel() reconnectRetryCount += 1 @@ -241,7 +1442,7 @@ private final class WorkspaceRemoteSessionController { guard let self else { return } self.reconnectWorkItem = nil guard !self.isStopping else { return } - guard self.probeProcess == nil else { return } + guard self.proxyLease == nil else { return } self.beginConnectionAttemptLocked() } reconnectWorkItem = workItem @@ -249,143 +1450,6 @@ private final class WorkspaceRemoteSessionController { return retryNumber } - private func consumeProbeStdoutData(_ data: Data) { - guard let chunk = String(data: data, encoding: .utf8), !chunk.isEmpty else { return } - probeStdoutBuffer.append(chunk) - - while let newline = probeStdoutBuffer.firstIndex(of: "\n") { - let line = String(probeStdoutBuffer[..<newline]) - probeStdoutBuffer.removeSubrange(...newline) - handleProbePortsLine(line) - } - } - - private func consumeProbeStderrData(_ data: Data) { - guard let chunk = String(data: data, encoding: .utf8), !chunk.isEmpty else { return } - probeStderrBuffer.append(chunk) - if probeStderrBuffer.count > 8192 { - probeStderrBuffer.removeFirst(probeStderrBuffer.count - 8192) - } - } - - private func handleProbePortsLine(_ line: String) { - guard !isStopping else { return } - - let ports = Self.parseRemotePorts(line: line) - desiredRemotePorts = Set(ports) - portConflicts = portConflicts.intersection(desiredRemotePorts) - reconnectWorkItem?.cancel() - reconnectWorkItem = nil - reconnectRetryCount = 0 - publishState(.connected, detail: "Connected to \(configuration.displayTarget)") - reconcileForwardsLocked() - } - - private func reconcileForwardsLocked() { - guard !isStopping else { return } - - for (port, entry) in forwardEntries where !desiredRemotePorts.contains(port) { - entry.stderrPipe.fileHandleForReading.readabilityHandler = nil - if entry.process.isRunning { - entry.process.terminate() - } - forwardEntries.removeValue(forKey: port) - } - - for port in desiredRemotePorts.sorted() where forwardEntries[port] == nil { - guard Self.isLoopbackPortAvailable(port: port) else { - portConflicts.insert(port) - continue - } - if startForwardLocked(port: port) { - portConflicts.remove(port) - } else { - portConflicts.insert(port) - } - } - - publishPortsSnapshotLocked() - } - - @discardableResult - private func startForwardLocked(port: Int) -> Bool { - guard !isStopping else { return false } - - let process = Process() - let stderrPipe = Pipe() - process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") - process.arguments = forwardArguments(port: port) - process.standardOutput = FileHandle.nullDevice - process.standardError = stderrPipe - - stderrPipe.fileHandleForReading.readabilityHandler = { [weak self] handle in - let data = handle.availableData - guard !data.isEmpty else { - handle.readabilityHandler = nil - return - } - self?.queue.async { - guard let self else { return } - if let chunk = String(data: data, encoding: .utf8), !chunk.isEmpty { - self.probeStderrBuffer.append(chunk) - if self.probeStderrBuffer.count > 8192 { - self.probeStderrBuffer.removeFirst(self.probeStderrBuffer.count - 8192) - } - } - } - } - - process.terminationHandler = { [weak self] terminated in - self?.queue.async { - self?.handleForwardTermination(port: port, process: terminated) - } - } - - do { - try process.run() - forwardEntries[port] = ForwardEntry(process: process, stderrPipe: stderrPipe) - return true - } catch { - publishState(.error, detail: "Failed to forward local :\(port) to \(configuration.displayTarget): \(error.localizedDescription)") - return false - } - } - - private func handleForwardTermination(port: Int, process: Process) { - if let current = forwardEntries[port], current.process === process { - current.stderrPipe.fileHandleForReading.readabilityHandler = nil - forwardEntries.removeValue(forKey: port) - } - - guard !isStopping else { return } - publishPortsSnapshotLocked() - - guard desiredRemotePorts.contains(port) else { return } - let rawDetail = Self.bestErrorLine(stderr: probeStderrBuffer) - if process.terminationReason != .exit || process.terminationStatus != 0 { - let detail = rawDetail ?? "process exited with status \(process.terminationStatus)" - publishState(.error, detail: "SSH port-forward :\(port) dropped for \(configuration.displayTarget): \(detail)") - } - guard Self.isLoopbackPortAvailable(port: port) else { - portConflicts.insert(port) - publishPortsSnapshotLocked() - return - } - - queue.asyncAfter(deadline: .now() + 1.0) { [weak self] in - guard let self else { return } - guard !self.isStopping else { return } - guard self.desiredRemotePorts.contains(port) else { return } - guard self.forwardEntries[port] == nil else { return } - if self.startForwardLocked(port: port) { - self.portConflicts.remove(port) - } else { - self.portConflicts.insert(port) - } - self.publishPortsSnapshotLocked() - } - } - private func publishState(_ state: WorkspaceRemoteConnectionState, detail: String?) { DispatchQueue.main.async { [weak workspace] in guard let workspace else { return } @@ -422,30 +1486,23 @@ private final class WorkspaceRemoteSessionController { } } - private func publishPortsSnapshotLocked() { - let detected = desiredRemotePorts.sorted() - let forwarded = forwardEntries.keys.sorted() - let conflicts = portConflicts.sorted() + private func publishProxyEndpoint(_ endpoint: BrowserProxyEndpoint?) { DispatchQueue.main.async { [weak workspace] in guard let workspace else { return } - workspace.applyRemotePortsSnapshot( - detected: detected, - forwarded: forwarded, - conflicts: conflicts, - target: workspace.remoteDisplayTarget ?? "remote host" - ) + workspace.applyRemoteProxyEndpointUpdate(endpoint) } } - private func probeArguments() -> [String] { - let remoteScript = Self.probeScript() - let remoteCommand = "sh -lc \(Self.shellSingleQuoted(remoteScript))" - return sshCommonArguments(batchMode: true) + [configuration.destination, remoteCommand] - } - - private func forwardArguments(port: Int) -> [String] { - let localBind = "127.0.0.1:\(port):127.0.0.1:\(port)" - return ["-N", "-o", "ExitOnForwardFailure=yes"] + sshCommonArguments(batchMode: true) + ["-L", localBind, configuration.destination] + private func publishPortsSnapshotLocked() { + DispatchQueue.main.async { [weak workspace] in + guard let workspace else { return } + workspace.applyRemotePortsSnapshot( + detected: [], + forwarded: [], + conflicts: [], + target: workspace.remoteDisplayTarget ?? "remote host" + ) + } } private func sshCommonArguments(batchMode: Bool) -> [String] { @@ -453,8 +1510,10 @@ private final class WorkspaceRemoteSessionController { "-o", "ConnectTimeout=6", "-o", "ServerAliveInterval=20", "-o", "ServerAliveCountMax=2", - "-o", "StrictHostKeyChecking=accept-new", ] + if !hasSSHOptionKey(configuration.sshOptions, key: "StrictHostKeyChecking") { + args += ["-o", "StrictHostKeyChecking=accept-new"] + } if batchMode { args += ["-o", "BatchMode=yes"] } @@ -473,6 +1532,19 @@ private final class WorkspaceRemoteSessionController { return args } + private func hasSSHOptionKey(_ options: [String], key: String) -> Bool { + let loweredKey = key.lowercased() + for option in options { + let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { continue } + let token = trimmed.split(whereSeparator: { $0 == "=" || $0.isWhitespace }).first.map(String.init)?.lowercased() + if token == loweredKey { + return true + } + } + return false + } + private func sshExec(arguments: [String], stdin: Data? = nil, timeout: TimeInterval = 15) throws -> CommandResult { try runProcess( executable: "/usr/bin/ssh", @@ -670,7 +1742,10 @@ private final class WorkspaceRemoteSessionController { ]) } - var scpArgs: [String] = ["-q", "-o", "StrictHostKeyChecking=accept-new"] + var scpArgs: [String] = ["-q"] + if !hasSSHOptionKey(configuration.sshOptions, key: "StrictHostKeyChecking") { + scpArgs += ["-o", "StrictHostKeyChecking=accept-new"] + } if let port = configuration.port { scpArgs += ["-P", String(port)] } @@ -756,38 +1831,6 @@ private final class WorkspaceRemoteSessionController { ) } - private static func parseRemotePorts(line: String) -> [Int] { - let tokens = line.split(whereSeparator: \.isWhitespace) - let values = tokens.compactMap { Int($0) } - let filtered = values.filter { $0 >= 1024 && $0 <= 65535 } - let unique = Set(filtered) - if unique.count <= 40 { - return unique.sorted() - } - return Array(unique.sorted().prefix(40)) - } - - private static func probeScript() -> String { - """ - set -eu - CMUX_LAST="" - while true; do - if command -v ss >/dev/null 2>&1; then - PORTS="$(ss -ltnH 2>/dev/null | awk '{print $4}' | sed -E 's/.*:([0-9]+)$/\\1/' | awk '/^[0-9]+$/ {print $1}' | sort -n -u | tr '\\n' ' ')" - elif command -v netstat >/dev/null 2>&1; then - PORTS="$(netstat -lnt 2>/dev/null | awk '{print $4}' | sed -E 's/.*:([0-9]+)$/\\1/' | awk '/^[0-9]+$/ {print $1}' | sort -n -u | tr '\\n' ' ')" - else - PORTS="" - fi - if [ "$PORTS" != "$CMUX_LAST" ]; then - echo "$PORTS" - CMUX_LAST="$PORTS" - fi - sleep 2 - done - """ - } - private static func shellSingleQuoted(_ value: String) -> String { "'" + value.replacingOccurrences(of: "'", with: "'\"'\"'") + "'" } @@ -919,29 +1962,15 @@ private final class WorkspaceRemoteSessionController { return " (retry \(retry) in \(seconds)s)" } - private static func isLoopbackPortAvailable(port: Int) -> Bool { - guard port > 0 && port <= 65535 else { return false } - - let fd = socket(AF_INET, SOCK_STREAM, 0) - guard fd >= 0 else { return false } - defer { close(fd) } - - var yes: Int32 = 1 - setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, socklen_t(MemoryLayout<Int32>.size)) - - var addr = sockaddr_in() - addr.sin_len = UInt8(MemoryLayout<sockaddr_in>.size) - addr.sin_family = sa_family_t(AF_INET) - addr.sin_port = in_port_t(UInt16(port).bigEndian) - addr.sin_addr = in_addr(s_addr: inet_addr("127.0.0.1")) - - let bindResult = withUnsafePointer(to: &addr) { ptr in - ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in - bind(fd, sockaddrPtr, socklen_t(MemoryLayout<sockaddr_in>.size)) - } - } - return bindResult == 0 + private static func shouldEscalateProxyErrorToBootstrap(_ detail: String) -> Bool { + let lowered = detail.lowercased() + return lowered.contains("remote daemon transport failed") + || lowered.contains("daemon transport closed stdout") + || lowered.contains("daemon transport exited") + || lowered.contains("daemon transport is not connected") + || lowered.contains("daemon transport stopped") } + } enum SidebarLogLevel: String { @@ -1008,6 +2037,7 @@ struct WorkspaceRemoteConfiguration: Equatable { let port: Int? let identityFile: String? let sshOptions: [String] + let localProxyPort: Int? var displayTarget: String { guard let port else { return destination } @@ -1080,6 +2110,7 @@ final class Workspace: Identifiable, ObservableObject { @Published var remoteDetectedPorts: [Int] = [] @Published var remoteForwardedPorts: [Int] = [] @Published var remotePortConflicts: [Int] = [] + @Published var remoteProxyEndpoint: BrowserProxyEndpoint? @Published var listeningPorts: [Int] = [] var surfaceTTYNames: [UUID: String] = [:] private var remoteSessionController: WorkspaceRemoteSessionController? @@ -1588,16 +2619,45 @@ final class Workspace: Identifiable, ObservableObject { "conflicted_ports": remotePortConflicts, "detail": remoteConnectionDetail ?? NSNull(), ] + if let endpoint = remoteProxyEndpoint { + payload["proxy"] = [ + "state": "ready", + "host": endpoint.host, + "port": endpoint.port, + "schemes": ["socks5", "http_connect"], + "url": "socks5://\(endpoint.host):\(endpoint.port)", + ] + } else { + let proxyState: String + switch remoteConnectionState { + case .connecting: + proxyState = "connecting" + case .error: + proxyState = "error" + default: + proxyState = "unavailable" + } + payload["proxy"] = [ + "state": proxyState, + "host": NSNull(), + "port": NSNull(), + "schemes": ["socks5", "http_connect"], + "url": NSNull(), + "error_code": proxyState == "error" ? "proxy_unavailable" : NSNull(), + ] + } if let remoteConfiguration { payload["destination"] = remoteConfiguration.destination payload["port"] = remoteConfiguration.port ?? NSNull() payload["identity_file"] = remoteConfiguration.identityFile ?? NSNull() payload["ssh_options"] = remoteConfiguration.sshOptions + payload["local_proxy_port"] = remoteConfiguration.localProxyPort ?? NSNull() } else { payload["destination"] = NSNull() payload["port"] = NSNull() payload["identity_file"] = NSNull() payload["ssh_options"] = [] + payload["local_proxy_port"] = NSNull() } return payload } @@ -1607,6 +2667,7 @@ final class Workspace: Identifiable, ObservableObject { remoteDetectedPorts = [] remoteForwardedPorts = [] remotePortConflicts = [] + remoteProxyEndpoint = nil remoteConnectionDetail = nil remoteDaemonStatus = WorkspaceRemoteDaemonStatus() statusEntries.removeValue(forKey: Self.remoteErrorStatusKey) @@ -1618,6 +2679,7 @@ final class Workspace: Identifiable, ObservableObject { remoteSessionController?.stop() remoteSessionController = nil + applyRemoteProxyEndpointUpdate(nil) guard autoConnect else { remoteConnectionState = .disconnected @@ -1641,6 +2703,7 @@ final class Workspace: Identifiable, ObservableObject { remoteDetectedPorts = [] remoteForwardedPorts = [] remotePortConflicts = [] + remoteProxyEndpoint = nil remoteConnectionState = .disconnected remoteConnectionDetail = nil remoteDaemonStatus = WorkspaceRemoteDaemonStatus() @@ -1652,6 +2715,7 @@ final class Workspace: Identifiable, ObservableObject { if clearConfiguration { remoteConfiguration = nil } + applyRemoteProxyEndpointUpdate(nil) recomputeListeningPorts() } @@ -1719,6 +2783,14 @@ final class Workspace: Identifiable, ObservableObject { ) } + fileprivate func applyRemoteProxyEndpointUpdate(_ endpoint: BrowserProxyEndpoint?) { + remoteProxyEndpoint = endpoint + for panel in panels.values { + guard let browserPanel = panel as? BrowserPanel else { continue } + browserPanel.setRemoteProxyEndpoint(endpoint) + } + } + fileprivate func applyRemotePortsSnapshot(detected: [Int], forwarded: [Int], conflicts: [Int], target: String) { remoteDetectedPorts = detected remoteForwardedPorts = forwarded @@ -1929,7 +3001,11 @@ final class Workspace: Identifiable, ObservableObject { guard let paneId = sourcePaneId else { return nil } // Create browser panel - let browserPanel = BrowserPanel(workspaceId: id, initialURL: url) + let browserPanel = BrowserPanel( + workspaceId: id, + initialURL: url, + proxyEndpoint: remoteProxyEndpoint + ) panels[browserPanel.id] = browserPanel panelTitles[browserPanel.id] = browserPanel.displayTitle @@ -1985,7 +3061,8 @@ final class Workspace: Identifiable, ObservableObject { let browserPanel = BrowserPanel( workspaceId: id, initialURL: url, - bypassInsecureHTTPHostOnce: bypassInsecureHTTPHostOnce + bypassInsecureHTTPHostOnce: bypassInsecureHTTPHostOnce, + proxyEndpoint: remoteProxyEndpoint ) panels[browserPanel.id] = browserPanel panelTitles[browserPanel.id] = browserPanel.displayTitle diff --git a/daemon/remote/README.md b/daemon/remote/README.md index c273ddc5..fe4951a2 100644 --- a/daemon/remote/README.md +++ b/daemon/remote/README.md @@ -9,8 +9,25 @@ Current commands: Current RPC methods (newline-delimited JSON): 1. `hello` 2. `ping` +3. `proxy.open` +4. `proxy.close` +5. `proxy.write` +6. `proxy.read` +7. `session.open` +8. `session.close` +9. `session.attach` +10. `session.resize` +11. `session.detach` +12. `session.status` Current integration in cmux: 1. `workspace.remote.configure` now bootstraps this binary over SSH when missing. -2. Client sends `hello` before enabling remote port probing/forwarding. -3. Daemon status/capabilities are exposed in `workspace.remote.status -> remote.daemon`. +2. Client sends `hello` before enabling remote proxy transport. +3. Local workspace proxy broker serves SOCKS5 + HTTP CONNECT and tunnels stream traffic through `proxy.*` RPC over `serve --stdio`. +4. Daemon status/capabilities are exposed in `workspace.remote.status -> remote.daemon` (including `session.resize.min`). + +`workspace.remote.configure` contract notes: +1. `port` / `local_proxy_port` accept integer values and numeric strings; explicit `null` clears each field. +2. Out-of-range values and invalid types return `invalid_params`. +3. `local_proxy_port` is an internal deterministic test hook used by bind-conflict regressions. +4. SSH option precedence checks are case-insensitive; user overrides for `StrictHostKeyChecking` and control-socket keys prevent default injection. diff --git a/daemon/remote/cmd/cmuxd-remote/main.go b/daemon/remote/cmd/cmuxd-remote/main.go index 0e299c8c..727039d2 100644 --- a/daemon/remote/cmd/cmuxd-remote/main.go +++ b/daemon/remote/cmd/cmuxd-remote/main.go @@ -2,11 +2,17 @@ package main import ( "bufio" + "encoding/base64" "encoding/json" "flag" "fmt" "io" + "net" "os" + "sort" + "strconv" + "sync" + "time" ) var version = "dev" @@ -29,6 +35,28 @@ type rpcResponse struct { Error *rpcError `json:"error,omitempty"` } +type rpcServer struct { + mu sync.Mutex + nextStreamID uint64 + nextSessionID uint64 + streams map[string]net.Conn + sessions map[string]*sessionState +} + +type sessionAttachment struct { + Cols int + Rows int + UpdatedAt time.Time +} + +type sessionState struct { + attachments map[string]sessionAttachment + effectiveCols int + effectiveRows int + lastKnownCols int + lastKnownRows int +} + func main() { os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr)) } @@ -72,7 +100,16 @@ func usage(w io.Writer) { } func runStdioServer(stdin io.Reader, stdout io.Writer) error { + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]net.Conn{}, + sessions: map[string]*sessionState{}, + } + defer server.closeAll() + scanner := bufio.NewScanner(stdin) + scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024) writer := bufio.NewWriter(stdout) defer writer.Flush() @@ -96,7 +133,7 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error { continue } - resp := handleRequest(req) + resp := server.handleRequest(req) if err := writeResponse(writer, resp); err != nil { return err } @@ -122,7 +159,7 @@ func writeResponse(w *bufio.Writer, resp rpcResponse) error { return w.Flush() } -func handleRequest(req rpcRequest) rpcResponse { +func (s *rpcServer) handleRequest(req rpcRequest) rpcResponse { if req.Method == "" { return rpcResponse{ ID: req.ID, @@ -144,8 +181,10 @@ func handleRequest(req rpcRequest) rpcResponse { "version": version, "capabilities": []string{ "session.basic", + "session.resize.min", "proxy.http_connect", "proxy.socks5", + "proxy.stream", }, }, } @@ -157,6 +196,26 @@ func handleRequest(req rpcRequest) rpcResponse { "pong": true, }, } + case "proxy.open": + return s.handleProxyOpen(req) + case "proxy.close": + return s.handleProxyClose(req) + case "proxy.write": + return s.handleProxyWrite(req) + case "proxy.read": + return s.handleProxyRead(req) + case "session.open": + return s.handleSessionOpen(req) + case "session.close": + return s.handleSessionClose(req) + case "session.attach": + return s.handleSessionAttach(req) + case "session.resize": + return s.handleSessionResize(req) + case "session.detach": + return s.handleSessionDetach(req) + case "session.status": + return s.handleSessionStatus(req) default: return rpcResponse{ ID: req.ID, @@ -168,3 +227,704 @@ func handleRequest(req rpcRequest) rpcResponse { } } } + +func (s *rpcServer) handleProxyOpen(req rpcRequest) rpcResponse { + host, ok := getStringParam(req.Params, "host") + if !ok || host == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.open requires host", + }, + } + } + port, ok := getIntParam(req.Params, "port") + if !ok || port <= 0 || port > 65535 { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.open requires port in range 1-65535", + }, + } + } + + timeoutMs := 10000 + if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 { + timeoutMs = parsed + } + + conn, err := net.DialTimeout( + "tcp", + net.JoinHostPort(host, strconv.Itoa(port)), + time.Duration(timeoutMs)*time.Millisecond, + ) + if err != nil { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "open_failed", + Message: err.Error(), + }, + } + } + + s.mu.Lock() + streamID := fmt.Sprintf("s-%d", s.nextStreamID) + s.nextStreamID++ + s.streams[streamID] = conn + s.mu.Unlock() + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "stream_id": streamID, + }, + } +} + +func (s *rpcServer) handleProxyClose(req rpcRequest) rpcResponse { + streamID, ok := getStringParam(req.Params, "stream_id") + if !ok || streamID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.close requires stream_id", + }, + } + } + + s.mu.Lock() + conn, exists := s.streams[streamID] + if exists { + delete(s.streams, streamID) + } + s.mu.Unlock() + + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "stream not found", + }, + } + } + + _ = conn.Close() + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "closed": true, + }, + } +} + +func (s *rpcServer) handleProxyWrite(req rpcRequest) rpcResponse { + streamID, ok := getStringParam(req.Params, "stream_id") + if !ok || streamID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.write requires stream_id", + }, + } + } + dataBase64, ok := getStringParam(req.Params, "data_base64") + if !ok { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.write requires data_base64", + }, + } + } + payload, err := base64.StdEncoding.DecodeString(dataBase64) + if err != nil { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "data_base64 must be valid base64", + }, + } + } + + conn, found := s.getStream(streamID) + if !found { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "stream not found", + }, + } + } + + total := 0 + for total < len(payload) { + written, writeErr := conn.Write(payload[total:]) + total += written + if writeErr != nil { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "stream_error", + Message: writeErr.Error(), + }, + } + } + } + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "written": total, + }, + } +} + +func (s *rpcServer) handleProxyRead(req rpcRequest) rpcResponse { + streamID, ok := getStringParam(req.Params, "stream_id") + if !ok || streamID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "proxy.read requires stream_id", + }, + } + } + + maxBytes := 32768 + if parsed, hasMax := getIntParam(req.Params, "max_bytes"); hasMax { + maxBytes = parsed + } + if maxBytes <= 0 || maxBytes > 262144 { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "max_bytes must be in range 1-262144", + }, + } + } + + timeoutMs := 50 + if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 { + timeoutMs = parsed + } + + conn, found := s.getStream(streamID) + if !found { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "stream not found", + }, + } + } + + _ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond)) + buffer := make([]byte, maxBytes) + n, readErr := conn.Read(buffer) + data := buffer[:max(0, n)] + + if readErr != nil { + if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() { + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "data_base64": "", + "eof": false, + }, + } + } + if readErr == io.EOF { + s.dropStream(streamID) + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "data_base64": base64.StdEncoding.EncodeToString(data), + "eof": true, + }, + } + } + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "stream_error", + Message: readErr.Error(), + }, + } + } + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "data_base64": base64.StdEncoding.EncodeToString(data), + "eof": false, + }, + } +} + +func (s *rpcServer) handleSessionOpen(req rpcRequest) rpcResponse { + sessionID, _ := getStringParam(req.Params, "session_id") + + s.mu.Lock() + defer s.mu.Unlock() + + if sessionID == "" { + sessionID = fmt.Sprintf("sess-%d", s.nextSessionID) + s.nextSessionID++ + } + + session, exists := s.sessions[sessionID] + if !exists { + session = &sessionState{ + attachments: map[string]sessionAttachment{}, + } + s.sessions[sessionID] = session + } + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: sessionSnapshot(sessionID, session), + } +} + +func (s *rpcServer) handleSessionClose(req rpcRequest) rpcResponse { + sessionID, ok := getStringParam(req.Params, "session_id") + if !ok || sessionID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "session.close requires session_id", + }, + } + } + + s.mu.Lock() + _, exists := s.sessions[sessionID] + if exists { + delete(s.sessions, sessionID) + } + s.mu.Unlock() + + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "session not found", + }, + } + } + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: map[string]any{ + "session_id": sessionID, + "closed": true, + }, + } +} + +func (s *rpcServer) handleSessionAttach(req rpcRequest) rpcResponse { + sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.attach") + if badResp != nil { + return *badResp + } + + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.sessions[sessionID] + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "session not found", + }, + } + } + + session.attachments[attachmentID] = sessionAttachment{ + Cols: cols, + Rows: rows, + UpdatedAt: time.Now().UTC(), + } + recomputeSessionSize(session) + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: sessionSnapshot(sessionID, session), + } +} + +func (s *rpcServer) handleSessionResize(req rpcRequest) rpcResponse { + sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.resize") + if badResp != nil { + return *badResp + } + + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.sessions[sessionID] + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "session not found", + }, + } + } + if _, exists := session.attachments[attachmentID]; !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "attachment not found", + }, + } + } + + session.attachments[attachmentID] = sessionAttachment{ + Cols: cols, + Rows: rows, + UpdatedAt: time.Now().UTC(), + } + recomputeSessionSize(session) + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: sessionSnapshot(sessionID, session), + } +} + +func (s *rpcServer) handleSessionDetach(req rpcRequest) rpcResponse { + sessionID, ok := getStringParam(req.Params, "session_id") + if !ok || sessionID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "session.detach requires session_id", + }, + } + } + attachmentID, ok := getStringParam(req.Params, "attachment_id") + if !ok || attachmentID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "session.detach requires attachment_id", + }, + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.sessions[sessionID] + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "session not found", + }, + } + } + if _, exists := session.attachments[attachmentID]; !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "attachment not found", + }, + } + } + + delete(session.attachments, attachmentID) + recomputeSessionSize(session) + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: sessionSnapshot(sessionID, session), + } +} + +func (s *rpcServer) handleSessionStatus(req rpcRequest) rpcResponse { + sessionID, ok := getStringParam(req.Params, "session_id") + if !ok || sessionID == "" { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: "session.status requires session_id", + }, + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.sessions[sessionID] + if !exists { + return rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "not_found", + Message: "session not found", + }, + } + } + + return rpcResponse{ + ID: req.ID, + OK: true, + Result: sessionSnapshot(sessionID, session), + } +} + +func parseSessionAttachmentParams(req rpcRequest, method string) (sessionID string, attachmentID string, cols int, rows int, badResp *rpcResponse) { + sessionID, ok := getStringParam(req.Params, "session_id") + if !ok || sessionID == "" { + resp := rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: method + " requires session_id", + }, + } + return "", "", 0, 0, &resp + } + attachmentID, ok = getStringParam(req.Params, "attachment_id") + if !ok || attachmentID == "" { + resp := rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: method + " requires attachment_id", + }, + } + return "", "", 0, 0, &resp + } + + cols, ok = getIntParam(req.Params, "cols") + if !ok || cols <= 0 { + resp := rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: method + " requires cols > 0", + }, + } + return "", "", 0, 0, &resp + } + rows, ok = getIntParam(req.Params, "rows") + if !ok || rows <= 0 { + resp := rpcResponse{ + ID: req.ID, + OK: false, + Error: &rpcError{ + Code: "invalid_params", + Message: method + " requires rows > 0", + }, + } + return "", "", 0, 0, &resp + } + + return sessionID, attachmentID, cols, rows, nil +} + +func recomputeSessionSize(session *sessionState) { + if len(session.attachments) == 0 { + session.effectiveCols = session.lastKnownCols + session.effectiveRows = session.lastKnownRows + return + } + + minCols := 0 + minRows := 0 + for _, attachment := range session.attachments { + if minCols == 0 || attachment.Cols < minCols { + minCols = attachment.Cols + } + if minRows == 0 || attachment.Rows < minRows { + minRows = attachment.Rows + } + } + + session.effectiveCols = minCols + session.effectiveRows = minRows + session.lastKnownCols = minCols + session.lastKnownRows = minRows +} + +func sessionSnapshot(sessionID string, session *sessionState) map[string]any { + attachmentIDs := make([]string, 0, len(session.attachments)) + for attachmentID := range session.attachments { + attachmentIDs = append(attachmentIDs, attachmentID) + } + sort.Strings(attachmentIDs) + + attachments := make([]map[string]any, 0, len(attachmentIDs)) + for _, attachmentID := range attachmentIDs { + attachment := session.attachments[attachmentID] + attachments = append(attachments, map[string]any{ + "attachment_id": attachmentID, + "cols": attachment.Cols, + "rows": attachment.Rows, + "updated_at": attachment.UpdatedAt.Format(time.RFC3339Nano), + }) + } + + return map[string]any{ + "session_id": sessionID, + "attachments": attachments, + "effective_cols": session.effectiveCols, + "effective_rows": session.effectiveRows, + "last_known_cols": session.lastKnownCols, + "last_known_rows": session.lastKnownRows, + } +} + +func (s *rpcServer) getStream(streamID string) (net.Conn, bool) { + s.mu.Lock() + defer s.mu.Unlock() + conn, ok := s.streams[streamID] + return conn, ok +} + +func (s *rpcServer) dropStream(streamID string) { + s.mu.Lock() + conn, ok := s.streams[streamID] + if ok { + delete(s.streams, streamID) + } + s.mu.Unlock() + if ok { + _ = conn.Close() + } +} + +func (s *rpcServer) closeAll() { + s.mu.Lock() + streams := make([]net.Conn, 0, len(s.streams)) + for id, conn := range s.streams { + delete(s.streams, id) + streams = append(streams, conn) + } + for id := range s.sessions { + delete(s.sessions, id) + } + s.mu.Unlock() + for _, conn := range streams { + _ = conn.Close() + } +} + +func getStringParam(params map[string]any, key string) (string, bool) { + if params == nil { + return "", false + } + raw, ok := params[key] + if !ok || raw == nil { + return "", false + } + value, ok := raw.(string) + return value, ok +} + +func getIntParam(params map[string]any, key string) (int, bool) { + if params == nil { + return 0, false + } + raw, ok := params[key] + if !ok || raw == nil { + return 0, false + } + switch value := raw.(type) { + case int: + return value, true + case int8: + return int(value), true + case int16: + return int(value), true + case int32: + return int(value), true + case int64: + return int(value), true + case uint: + return int(value), true + case uint8: + return int(value), true + case uint16: + return int(value), true + case uint32: + return int(value), true + case uint64: + return int(value), true + case float64: + return int(value), true + case json.Number: + n, err := value.Int64() + if err != nil { + return 0, false + } + return int(n), true + default: + return 0, false + } +} diff --git a/daemon/remote/cmd/cmuxd-remote/main_test.go b/daemon/remote/cmd/cmuxd-remote/main_test.go index 4d90d6c0..663fd234 100644 --- a/daemon/remote/cmd/cmuxd-remote/main_test.go +++ b/daemon/remote/cmd/cmuxd-remote/main_test.go @@ -2,9 +2,13 @@ package main import ( "bytes" + "encoding/base64" "encoding/json" + "net" + "strconv" "strings" "testing" + "time" ) func TestRunVersion(t *testing.T) { @@ -99,3 +103,371 @@ func TestRunStdioInvalidJSONAndUnknownMethod(t *testing.T) { t.Fatalf("unknown method should return method_not_found; got=%v payload=%v", got, second) } } + +func TestRunStdioSessionResizeFlow(t *testing.T) { + input := strings.NewReader( + `{"id":1,"method":"session.open","params":{"session_id":"sess-stdio"}}` + "\n" + + `{"id":2,"method":"session.attach","params":{"session_id":"sess-stdio","attachment_id":"a1","cols":120,"rows":40}}` + "\n" + + `{"id":3,"method":"session.attach","params":{"session_id":"sess-stdio","attachment_id":"a2","cols":90,"rows":30}}` + "\n" + + `{"id":4,"method":"session.status","params":{"session_id":"sess-stdio"}}` + "\n", + ) + var out bytes.Buffer + code := run([]string{"serve", "--stdio"}, input, &out, &bytes.Buffer{}) + if code != 0 { + t.Fatalf("run serve exit code = %d, want 0", code) + } + + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + if len(lines) != 4 { + t.Fatalf("got %d response lines, want 4: %q", len(lines), out.String()) + } + + var status map[string]any + if err := json.Unmarshal([]byte(lines[3]), &status); err != nil { + t.Fatalf("failed to decode status response: %v", err) + } + if ok, _ := status["ok"].(bool); !ok { + t.Fatalf("session.status should be ok=true: %v", status) + } + result, _ := status["result"].(map[string]any) + if result == nil { + t.Fatalf("session.status missing result object: %v", status) + } + effectiveCols, _ := result["effective_cols"].(float64) + effectiveRows, _ := result["effective_rows"].(float64) + if int(effectiveCols) != 90 || int(effectiveRows) != 30 { + t.Fatalf("session smallest-wins effective size mismatch: got=%vx%v payload=%v", effectiveCols, effectiveRows, result) + } +} + +func TestProxyStreamRoundTrip(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer conn.Close() + + buffer := make([]byte, 8) + n, readErr := conn.Read(buffer) + if readErr != nil { + return + } + if string(buffer[:n]) != "ping" { + return + } + _, _ = conn.Write([]byte("pong")) + }() + + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]net.Conn{}, + sessions: map[string]*sessionState{}, + } + defer server.closeAll() + + port := listener.Addr().(*net.TCPAddr).Port + openResp := server.handleRequest(rpcRequest{ + ID: 1, + Method: "proxy.open", + Params: map[string]any{ + "host": "127.0.0.1", + "port": port, + "timeout_ms": 1000, + }, + }) + if !openResp.OK { + t.Fatalf("proxy.open failed: %+v", openResp) + } + openResult, _ := openResp.Result.(map[string]any) + streamID, _ := openResult["stream_id"].(string) + if streamID == "" { + t.Fatalf("proxy.open missing stream_id: %+v", openResp) + } + + writeResp := server.handleRequest(rpcRequest{ + ID: 2, + Method: "proxy.write", + Params: map[string]any{ + "stream_id": streamID, + "data_base64": base64.StdEncoding.EncodeToString([]byte("ping")), + }, + }) + if !writeResp.OK { + t.Fatalf("proxy.write failed: %+v", writeResp) + } + + readResp := server.handleRequest(rpcRequest{ + ID: 3, + Method: "proxy.read", + Params: map[string]any{ + "stream_id": streamID, + "max_bytes": 8, + "timeout_ms": 1000, + }, + }) + if !readResp.OK { + t.Fatalf("proxy.read failed: %+v", readResp) + } + readResult, _ := readResp.Result.(map[string]any) + dataBase64, _ := readResult["data_base64"].(string) + data, decodeErr := base64.StdEncoding.DecodeString(dataBase64) + if decodeErr != nil { + t.Fatalf("proxy.read returned invalid base64: %v", decodeErr) + } + if string(data) != "pong" { + t.Fatalf("proxy.read payload=%q, want %q", string(data), "pong") + } + + closeResp := server.handleRequest(rpcRequest{ + ID: 4, + Method: "proxy.close", + Params: map[string]any{ + "stream_id": streamID, + }, + }) + if !closeResp.OK { + t.Fatalf("proxy.close failed: %+v", closeResp) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("proxy test server goroutine did not finish") + } +} + +func TestProxyOpenInvalidParams(t *testing.T) { + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]net.Conn{}, + sessions: map[string]*sessionState{}, + } + defer server.closeAll() + + resp := server.handleRequest(rpcRequest{ + ID: 1, + Method: "proxy.open", + Params: map[string]any{ + "host": "127.0.0.1", + "port": strconv.Itoa(8080), + }, + }) + if resp.OK { + t.Fatalf("proxy.open with invalid port type should fail: %+v", resp) + } + errObj, _ := resp.Error, resp.Error + if errObj == nil || errObj.Code != "invalid_params" { + t.Fatalf("proxy.open invalid params should return invalid_params: %+v", resp) + } +} + +func TestSessionResizeCoordinator(t *testing.T) { + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]net.Conn{}, + sessions: map[string]*sessionState{}, + } + defer server.closeAll() + + openResp := server.handleRequest(rpcRequest{ + ID: 1, + Method: "session.open", + Params: map[string]any{ + "session_id": "sess-rz", + }, + }) + if !openResp.OK { + t.Fatalf("session.open failed: %+v", openResp) + } + + attachSmall := server.handleRequest(rpcRequest{ + ID: 2, + Method: "session.attach", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-small", + "cols": 90, + "rows": 30, + }, + }) + assertEffectiveSize(t, attachSmall, 90, 30) + + attachLarge := server.handleRequest(rpcRequest{ + ID: 3, + Method: "session.attach", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-large", + "cols": 120, + "rows": 40, + }, + }) + assertEffectiveSize(t, attachLarge, 90, 30) // RZ-001: smallest wins + + resizeLarge := server.handleRequest(rpcRequest{ + ID: 4, + Method: "session.resize", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-large", + "cols": 200, + "rows": 60, + }, + }) + assertEffectiveSize(t, resizeLarge, 90, 30) // RZ-002: still bounded by smallest + + detachSmall := server.handleRequest(rpcRequest{ + ID: 5, + Method: "session.detach", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-small", + }, + }) + assertEffectiveSize(t, detachSmall, 200, 60) // RZ-003: expands to next smallest + + detachLarge := server.handleRequest(rpcRequest{ + ID: 6, + Method: "session.detach", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-large", + }, + }) + assertEffectiveSize(t, detachLarge, 200, 60) // no attachments: keep last-known size + assertAttachmentCount(t, detachLarge, 0) + + reattach := server.handleRequest(rpcRequest{ + ID: 7, + Method: "session.attach", + Params: map[string]any{ + "session_id": "sess-rz", + "attachment_id": "a-reconnect", + "cols": 110, + "rows": 50, + }, + }) + assertEffectiveSize(t, reattach, 110, 50) // RZ-004: recompute from active attachments on reattach +} + +func TestSessionInvalidParamsAndNotFound(t *testing.T) { + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]net.Conn{}, + sessions: map[string]*sessionState{}, + } + defer server.closeAll() + + missingSession := server.handleRequest(rpcRequest{ + ID: 1, + Method: "session.attach", + Params: map[string]any{ + "session_id": "missing", + "attachment_id": "a1", + "cols": 80, + "rows": 24, + }, + }) + if missingSession.OK || missingSession.Error == nil || missingSession.Error.Code != "not_found" { + t.Fatalf("session.attach on missing session should return not_found: %+v", missingSession) + } + + badSize := server.handleRequest(rpcRequest{ + ID: 2, + Method: "session.attach", + Params: map[string]any{ + "session_id": "missing", + "attachment_id": "a1", + "cols": 0, + "rows": 24, + }, + }) + if badSize.OK || badSize.Error == nil || badSize.Error.Code != "invalid_params" { + t.Fatalf("session.attach with cols=0 should return invalid_params: %+v", badSize) + } +} + +func assertEffectiveSize(t *testing.T, resp rpcResponse, wantCols, wantRows int) { + t.Helper() + if !resp.OK { + t.Fatalf("expected ok response, got error: %+v", resp) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("response missing result map: %+v", resp) + } + gotCols := asInt(t, result["effective_cols"], "effective_cols") + gotRows := asInt(t, result["effective_rows"], "effective_rows") + if gotCols != wantCols || gotRows != wantRows { + t.Fatalf("effective size = %dx%d, want %dx%d payload=%+v", gotCols, gotRows, wantCols, wantRows, result) + } +} + +func assertAttachmentCount(t *testing.T, resp rpcResponse, want int) { + t.Helper() + if !resp.OK { + t.Fatalf("expected ok response, got error: %+v", resp) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("response missing result map: %+v", resp) + } + attachments, ok := result["attachments"].([]map[string]any) + if ok { + if len(attachments) != want { + t.Fatalf("attachments len = %d, want %d payload=%+v", len(attachments), want, result) + } + return + } + attachmentsAny, ok := result["attachments"].([]any) + if !ok { + t.Fatalf("attachments field has unexpected type (%T) payload=%+v", result["attachments"], result) + } + if len(attachmentsAny) != want { + t.Fatalf("attachments len = %d, want %d payload=%+v", len(attachmentsAny), want, result) + } +} + +func asInt(t *testing.T, value any, field string) int { + t.Helper() + switch typed := value.(type) { + case int: + return typed + case int8: + return int(typed) + case int16: + return int(typed) + case int32: + return int(typed) + case int64: + return int(typed) + case uint: + return int(typed) + case uint8: + return int(typed) + case uint16: + return int(typed) + case uint32: + return int(typed) + case uint64: + return int(typed) + case float64: + return int(typed) + default: + t.Fatalf("%s has unexpected type %T (%v)", field, value, value) + return 0 + } +} diff --git a/docs/remote-daemon-spec.md b/docs/remote-daemon-spec.md index 7b3606a1..6010f794 100644 --- a/docs/remote-daemon-spec.md +++ b/docs/remote-daemon-spec.md @@ -31,35 +31,43 @@ This is a **living implementation spec** (also called an **execution spec**): a ### 3.2 Bootstrap + Daemon - `DONE` local app probes remote platform, builds/uploads `cmuxd-remote`, and runs `serve --stdio`. - `DONE` daemon `hello` handshake is enforced. -- `DONE` bootstrap/probe failures surface actionable details. +- `DONE` daemon now exposes proxy stream RPC (`proxy.open`, `proxy.close`, `proxy.write`, `proxy.read`). +- `DONE` local proxy broker now tunnels SOCKS5/CONNECT traffic over daemon stream RPC instead of `ssh -D`. +- `DONE` daemon now exposes session resize-coordinator RPC (`session.open`, `session.attach`, `session.resize`, `session.detach`, `session.status`, `session.close`). +- `DONE` transport-level proxy failures now escalate from broker retry to full daemon re-bootstrap/reconnect in the session controller. +- `DONE` SOCKS handshake parsing now preserves pipelined post-connect payload bytes instead of dropping request-prefix bytes. +- `DONE` `workspace.remote.configure.local_proxy_port` exists as an internal deterministic test hook for bind-conflict regression coverage. +- `DONE` bootstrap/proxy failures surface actionable details. ### 3.3 Error Surfacing - `DONE` remote errors are surfaced in sidebar status + logs + notifications. - `DONE` reconnect retry count/time is included in surfaced error text (for example, `retry 1 in 4s`). -### 3.4 Existing Temporary Behavior (To Remove) -- `TEMPORARY` current implementation probes remote listening ports and mirrors them locally with SSH `-L`. -- `TEMPORARY` sidebar shows local bind conflicts (`SSH port conflicts ...`) caused by that mirroring path. -- `TARGET` browser path must no longer depend on per-port mirroring. +### 3.4 Removed Temporary Behavior +- `DONE` removed remote listening-port probe loop and per-port SSH `-L` mirroring. +- `DONE` remote browser routing now uses a single shared local proxy endpoint instead of detected-port mirroring. +- `DONE` remote status now includes structured proxy metadata (`remote.proxy`) and `proxy_unavailable` error code when proxy setup fails. ## 4. Target Architecture (No Port Mirroring) ### 4.1 Browser Networking Path -1. One local proxy endpoint per SSH transport (not per workspace, not per detected port). -2. Proxy endpoint supports SOCKS5 and HTTP CONNECT. -3. Browser panels in remote workspaces are auto-wired to this proxy endpoint. -4. Browser panels in local workspaces are not force-proxied. +1. `DONE` one local proxy endpoint is created per SSH transport/session key (not per detected port). +2. `DONE` endpoint is provided by a local broker that supports SOCKS5 + HTTP CONNECT and tunnels via daemon stream RPC. +3. `DONE` browser panels in remote workspaces are auto-wired to the workspace proxy endpoint. +4. `DONE` browser panels in local workspaces are not force-proxied. +5. `DONE` identical SSH transports share one endpoint via a transport-scoped broker. ### 4.2 WKWebView Wiring -1. Use workspace/browser scoped `WKWebsiteDataStore.proxyConfigurations`. -2. Prefer SOCKS5 proxy config. -3. Keep HTTP CONNECT proxy config as fallback. -4. Re-apply/validate proxy config after reconnect. +1. `DONE` use workspace-scoped `WKWebsiteDataStore(forIdentifier:)`. +2. `DONE` apply workspace/browser scoped `proxyConfigurations`. +3. `DONE` prefer SOCKS5 proxy config. +4. `DONE` keep HTTP CONNECT proxy config as fallback. +5. `DONE` re-apply proxy config on reconnect/state updates. ### 4.3 Remote Daemon + Transport -1. Extend `cmuxd-remote` beyond `hello/ping` with proxy stream RPC (`proxy.open`, `proxy.close`). -2. Local side runs a transport-scoped proxy broker and multiplexes proxy streams over SSH stdio transport. -3. Remove remote service-port discovery/probing from browser routing path. +1. `DONE` `cmuxd-remote` now supports proxy stream RPC (`proxy.open`, `proxy.close`, `proxy.write`, `proxy.read`). +2. `DONE` local side now runs a shared local broker that serves SOCKS5/CONNECT and tunnels each stream over persistent daemon stdio RPC. +3. `DONE` removed remote service-port discovery/probing from browser routing path. ### 4.4 Explicit Non-Goal 1. Automatic mirroring of every remote listening port to local loopback is not a goal for browser support. @@ -96,15 +104,15 @@ Recompute effective size on: | ID | Milestone | Status | Notes | |---|---|---|---| | M-001 | `cmux ssh` workspace creation + metadata + optional `--name` | DONE | Covered by `tests_v2/test_ssh_remote_cli_metadata.py` | -| M-002 | Remote bootstrap/upload/start + hello handshake | DONE | Current `cmuxd-remote` is minimal (`hello`, `ping`) | +| M-002 | Remote bootstrap/upload/start + hello handshake | DONE | Includes daemon capability handshake + status surfacing | | M-003 | Reconnect/disconnect UX + API + improved error surfacing | DONE | Includes retry count in surfaced errors | -| M-004 | Docker e2e for bootstrap/reconnect shell niceties | DONE | Existing docker tests currently validate mirroring-era path | -| M-005 | Remove automatic remote port mirroring path | TODO | Delete probe/listen mirror loop from `WorkspaceRemoteSessionController` | -| M-006 | Transport-scoped local proxy broker (SOCKS5 + CONNECT) | TODO | Local component in app/daemon layer | -| M-007 | Remote proxy stream RPC in `cmuxd-remote` | TODO | Add `proxy.open/close` and multiplexed stream handling | -| M-008 | WebView proxy auto-wiring for remote workspaces | TODO | Use `WKWebsiteDataStore.proxyConfigurations` | -| M-009 | PTY resize coordinator (`smallest screen wins`) | TODO | Session-level attachment-size aggregation | -| M-010 | Resize + proxy reconnect e2e test suites | TODO | Add dedicated docker cases for browser proxy + resize | +| M-004 | Docker e2e for bootstrap/reconnect shell niceties | DONE | Docker suites validate proxy-path bootstrap and reconnect behavior | +| M-005 | Remove automatic remote port mirroring path | DONE | `WorkspaceRemoteSessionController` now uses one shared daemon-backed proxy endpoint | +| M-006 | Transport-scoped local proxy broker (SOCKS5 + CONNECT) | DONE | Identical SSH transports now reuse one local proxy endpoint | +| M-007 | Remote proxy stream RPC in `cmuxd-remote` | DONE | `proxy.open/close/write/read` implemented | +| M-008 | WebView proxy auto-wiring for remote workspaces | DONE | Workspace-scoped `WKWebsiteDataStore.proxyConfigurations` wiring is active | +| M-009 | PTY resize coordinator (`smallest screen wins`) | DONE | Daemon session RPC now tracks attachments and applies min cols/rows semantics with unit tests | +| M-010 | Resize + proxy reconnect e2e test suites | DONE | `tests_v2/test_ssh_remote_docker_forwarding.py` validates HTTP/websocket egress plus SOCKS pipelined-payload handling; `tests_v2/test_ssh_remote_docker_reconnect.py` verifies reconnect recovery and repeats SOCKS pipelined-payload checks after host restart; `tests_v2/test_ssh_remote_proxy_bind_conflict.py` validates structured `proxy_unavailable` bind-conflict surfacing and `local_proxy_port` status retention under bind conflict; `tests_v2/test_ssh_remote_daemon_resize_stdio.py` validates session resize semantics over real stdio RPC process boundaries; `tests_v2/test_ssh_remote_cli_metadata.py` validates `workspace.remote.configure` numeric-string compatibility, explicit `null` clear semantics (including `workspace.remote.status` reflection), strict `port`/`local_proxy_port` validation (bounds/type), case-insensitive SSH option override precedence for StrictHostKeyChecking/control-socket keys, and `local_proxy_port` payload echo for deterministic bind-conflict test hook behavior | ## 7. Acceptance Test Matrix (With Status) @@ -113,7 +121,7 @@ Recompute effective size on: | ID | Scenario | Status | |---|---|---| | T-001 | baseline remote connect | DONE | -| T-002 | identical host reuse semantics | PARTIAL | +| T-002 | identical host reuse semantics | DONE | | T-003 | no `--name` | DONE | | T-004 | reconnect API success/error paths | DONE | | T-005 | retry count visible in daemon error detail | DONE | @@ -122,31 +130,51 @@ Recompute effective size on: | ID | Scenario | Status | |---|---|---| -| W-001 | remote workspace browser auto-proxied | TODO | -| W-002 | browser egress IP equals remote host IP | TODO | -| W-003 | websocket via SOCKS5/CONNECT through remote daemon | TODO | -| W-004 | reconnect restores browser proxy path automatically | TODO | -| W-005 | local proxy bind conflict yields structured `proxy_unavailable` | TODO | +| W-001 | remote workspace browser auto-proxied | DONE | +| W-002 | browser egress equals remote network path | DONE | +| W-003 | websocket via SOCKS5/CONNECT through remote daemon | DONE | +| W-004 | reconnect restores browser proxy path automatically | DONE | +| W-005 | local proxy bind conflict yields structured `proxy_unavailable` | DONE | +| W-006 | proxy transport failure triggers daemon re-bootstrap and recovers after host recreation | DONE | +| W-007 | SOCKS greeting/connect + immediate pipelined payload in same write remains intact | DONE | ### 7.3 Resize | ID | Scenario | Status | |---|---|---| -| RZ-001 | two attachments, smallest wins | TODO | -| RZ-002 | grow one attachment, PTY stays bounded by smallest | TODO | -| RZ-003 | detach smallest, PTY expands to next smallest | TODO | -| RZ-004 | reconnect preserves session + applies recomputed size | TODO | +| RZ-001 | two attachments, smallest wins | DONE | +| RZ-002 | grow one attachment, PTY stays bounded by smallest | DONE | +| RZ-003 | detach smallest, PTY expands to next smallest | DONE | +| RZ-004 | reconnect preserves session + applies recomputed size | DONE | +| RZ-005 | daemon stdio RPC round-trip enforces resize semantics end-to-end | DONE | ## 8. Removal Checklist (Port Mirroring) Before declaring browser proxying complete: -1. remove remote port probe loop and `-L` auto-forward orchestration -2. remove mirror-specific sidebar conflict messaging as default remote behavior -3. replace mirroring tests with browser-proxy e2e tests -4. keep optional explicit user-driven forwarding as separate feature only if needed +1. `DONE` remove remote port probe loop and `-L` auto-forward orchestration +2. `DONE` remove mirror-specific routing behavior as default remote behavior +3. `DONE` replace mirroring docker assertions with proxy egress assertions +4. `DONE` keep optional explicit user-driven forwarding out of this path; no automatic mirroring remains in browser routing ## 9. Open Decisions 1. Proxy auth policy for local broker (`none` vs optional credentials). 2. Reconnect backoff profile and max retry budget. -3. Browser data-store isolation policy for remote vs local workspaces. + +## 10. Socket API Contract Notes + +### 10.1 `workspace.remote.configure` Port Fields +1. `port` and `local_proxy_port` accept integer values and numeric strings. +2. Explicit `null` clears each field. +3. Out-of-range values and invalid types (for example booleans/non-numeric strings/fractional numbers) return `invalid_params`. +4. `local_proxy_port` is an internal deterministic test hook to force local bind conflicts in regression coverage. + +### 10.2 SSH Option Precedence +1. `StrictHostKeyChecking` default (`accept-new`) is only injected when no user override is present. +2. Control-socket defaults (`ControlMaster`, `ControlPersist`, `ControlPath`) are only injected when missing. +3. SSH option key matching is case-insensitive for precedence checks in both CLI-built commands and remote configure payloads. + +### 10.3 SSH Docker E2E Harness Knobs +1. `CMUX_SSH_TEST_DOCKER_HOST` sets the SSH destination host/IP used by docker-backed SSH fixtures (default `127.0.0.1`). +2. `CMUX_SSH_TEST_DOCKER_BIND_ADDR` sets the bind address used in fixture container publish mappings (default `127.0.0.1`). +3. Defaults preserve loopback behavior on a single host; override both when docker runs on a different host (for example VM -> host OrbStack). diff --git a/scripts/reload.sh b/scripts/reload.sh index 3cd2bb63..4492c954 100755 --- a/scripts/reload.sh +++ b/scripts/reload.sh @@ -10,6 +10,84 @@ BUNDLE_SET=0 DERIVED_SET=0 TAG="" CMUX_DEBUG_LOG="" +CLI_PATH="" + +write_dev_cli_shim() { + local target="$1" + local fallback_bin="$2" + mkdir -p "$(dirname "$target")" + cat > "$target" <<EOF +#!/usr/bin/env bash +# cmux dev shim (managed by scripts/reload.sh) +set -euo pipefail + +CLI_PATH_FILE="/tmp/cmux-last-cli-path" +if [[ -r "\$CLI_PATH_FILE" ]]; then + CLI_PATH="\$(cat "\$CLI_PATH_FILE")" + if [[ -x "\$CLI_PATH" ]]; then + exec "\$CLI_PATH" "\$@" + fi +fi + +if [[ -x "$fallback_bin" ]]; then + exec "$fallback_bin" "\$@" +fi + +echo "error: no reload-selected dev cmux CLI found. Run ./scripts/reload.sh --tag <name> first." >&2 +exit 1 +EOF + chmod +x "$target" || true +} + +select_cmux_shim_target() { + local app_cli_dir="/Applications/cmux.app/Contents/Resources/bin" + local marker="cmux dev shim (managed by scripts/reload.sh)" + local target="" + local path_entry="" + local candidate="" + + IFS=':' read -r -a path_entries <<< "${PATH:-}" + for path_entry in "${path_entries[@]}"; do + [[ -z "$path_entry" ]] && continue + if [[ "$path_entry" == "~/"* ]]; then + path_entry="$HOME/${path_entry#~/}" + fi + if [[ "$path_entry" == "$app_cli_dir" ]]; then + break + fi + [[ -d "$path_entry" && -w "$path_entry" ]] || continue + candidate="$path_entry/cmux" + if [[ ! -e "$candidate" ]]; then + target="$candidate" + break + fi + if [[ -f "$candidate" ]] && grep -q "$marker" "$candidate" 2>/dev/null; then + target="$candidate" + break + fi + done + + if [[ -n "$target" ]]; then + echo "$target" + return 0 + fi + + # Fallback for PATH layouts where app CLI isn't listed or no earlier entries were writable. + for path_entry in /opt/homebrew/bin /usr/local/bin "$HOME/.local/bin" "$HOME/bin"; do + [[ -d "$path_entry" && -w "$path_entry" ]] || continue + candidate="$path_entry/cmux" + if [[ ! -e "$candidate" ]]; then + echo "$candidate" + return 0 + fi + if [[ -f "$candidate" ]] && grep -q "$marker" "$candidate" 2>/dev/null; then + echo "$candidate" + return 0 + fi + done + + return 1 +} usage() { cat <<'EOF' @@ -271,6 +349,21 @@ if [[ -n "$TAG" && "$APP_NAME" != "$SEARCH_APP_NAME" ]]; then APP_PATH="$TAG_APP_PATH" fi +CLI_PATH="$(dirname "$APP_PATH")/cmux" +if [[ -x "$CLI_PATH" ]]; then + echo "$CLI_PATH" > /tmp/cmux-last-cli-path || true + ln -sfn "$CLI_PATH" /tmp/cmux-cli || true + + # Stable shim that always follows the last reload-selected dev CLI. + DEV_CLI_SHIM="$HOME/.local/bin/cmux-dev" + write_dev_cli_shim "$DEV_CLI_SHIM" "/Applications/cmux.app/Contents/Resources/bin/cmux" + + CMUX_SHIM_TARGET="$(select_cmux_shim_target || true)" + if [[ -n "${CMUX_SHIM_TARGET:-}" ]]; then + write_dev_cli_shim "$CMUX_SHIM_TARGET" "/Applications/cmux.app/Contents/Resources/bin/cmux" + fi +fi + # Ensure any running instance is fully terminated, regardless of DerivedData path. /usr/bin/osascript -e "tell application id \"${BUNDLE_ID}\" to quit" >/dev/null 2>&1 || true sleep 0.3 @@ -350,3 +443,16 @@ fi if [[ -n "${TAG_SLUG:-}" ]]; then print_tag_cleanup_reminder "$TAG_SLUG" fi + +if [[ -x "${CLI_PATH:-}" ]]; then + echo + echo "CLI path:" + echo " $CLI_PATH" + echo "CLI helpers:" + echo " /tmp/cmux-cli ..." + echo " $HOME/.local/bin/cmux-dev ..." + if [[ -n "${CMUX_SHIM_TARGET:-}" ]]; then + echo " $CMUX_SHIM_TARGET ..." + fi + echo "If your shell still resolves the old cmux, run: rehash" +fi diff --git a/tests/fixtures/ssh-remote/Dockerfile b/tests/fixtures/ssh-remote/Dockerfile index d86fcd04..470986d8 100644 --- a/tests/fixtures/ssh-remote/Dockerfile +++ b/tests/fixtures/ssh-remote/Dockerfile @@ -12,6 +12,7 @@ RUN ssh-keygen -A COPY sshd_config /etc/ssh/sshd_config COPY run.sh /usr/local/bin/run.sh +COPY ws_echo.py /usr/local/bin/ws_echo.py RUN chmod +x /usr/local/bin/run.sh EXPOSE 22 diff --git a/tests/fixtures/ssh-remote/run.sh b/tests/fixtures/ssh-remote/run.sh index 93b8eba7..59251875 100644 --- a/tests/fixtures/ssh-remote/run.sh +++ b/tests/fixtures/ssh-remote/run.sh @@ -7,6 +7,7 @@ if [ -z "${AUTHORIZED_KEY:-}" ]; then fi REMOTE_HTTP_PORT="${REMOTE_HTTP_PORT:-43173}" +REMOTE_WS_PORT="${REMOTE_WS_PORT:-43174}" mkdir -p /home/dev/.ssh /root/.ssh /run/sshd printf '%s\n' "$AUTHORIZED_KEY" > /home/dev/.ssh/authorized_keys @@ -18,5 +19,6 @@ chmod 700 /root/.ssh chmod 600 /root/.ssh/authorized_keys python3 -m http.server "$REMOTE_HTTP_PORT" --bind 127.0.0.1 --directory /srv/www >/tmp/http.log 2>&1 & +python3 /usr/local/bin/ws_echo.py --host 127.0.0.1 --port "$REMOTE_WS_PORT" >/tmp/ws.log 2>&1 & exec /usr/sbin/sshd -D -e diff --git a/tests/fixtures/ssh-remote/ws_echo.py b/tests/fixtures/ssh-remote/ws_echo.py new file mode 100644 index 00000000..4acb8935 --- /dev/null +++ b/tests/fixtures/ssh-remote/ws_echo.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Tiny WebSocket echo server for SSH proxy integration tests.""" + +from __future__ import annotations + +import argparse +import base64 +import hashlib +import socket +import struct +import threading + + +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def _recv_exact(conn: socket.socket, n: int) -> bytes: + data = bytearray() + while len(data) < n: + chunk = conn.recv(n - len(data)) + if not chunk: + raise ConnectionError("unexpected EOF") + data.extend(chunk) + return bytes(data) + + +def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes: + data = bytearray() + while marker not in data: + chunk = conn.recv(1024) + if not chunk: + raise ConnectionError("unexpected EOF while reading headers") + data.extend(chunk) + if len(data) > limit: + raise ValueError("header too large") + return bytes(data) + + +def _read_frame(conn: socket.socket) -> tuple[int, bytes]: + first, second = _recv_exact(conn, 2) + opcode = first & 0x0F + masked = (second & 0x80) != 0 + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", _recv_exact(conn, 2))[0] + elif length == 127: + length = struct.unpack("!Q", _recv_exact(conn, 8))[0] + + mask_key = _recv_exact(conn, 4) if masked else b"" + payload = _recv_exact(conn, length) if length else b"" + if masked and payload: + payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload)) + return opcode, payload + + +def _send_frame(conn: socket.socket, opcode: int, payload: bytes) -> None: + first = 0x80 | (opcode & 0x0F) + length = len(payload) + if length < 126: + header = bytes([first, length]) + elif length <= 0xFFFF: + header = bytes([first, 126]) + struct.pack("!H", length) + else: + header = bytes([first, 127]) + struct.pack("!Q", length) + conn.sendall(header + payload) + + +def handle_client(conn: socket.socket) -> None: + try: + request = _recv_until(conn, b"\r\n\r\n") + headers_raw = request.decode("utf-8", errors="replace").split("\r\n") + header_map: dict[str, str] = {} + for line in headers_raw[1:]: + if not line or ":" not in line: + continue + k, v = line.split(":", 1) + header_map[k.strip().lower()] = v.strip() + + key = header_map.get("sec-websocket-key", "") + upgrade = header_map.get("upgrade", "").lower() + connection_hdr = header_map.get("connection", "").lower() + if not key or upgrade != "websocket" or "upgrade" not in connection_hdr: + conn.sendall(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n") + return + + accept = base64.b64encode(hashlib.sha1((key + GUID).encode("utf-8")).digest()).decode("ascii") + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {accept}\r\n" + "\r\n" + ) + conn.sendall(response.encode("utf-8")) + + while True: + opcode, payload = _read_frame(conn) + if opcode == 0x8: # close + _send_frame(conn, 0x8, b"") + return + if opcode == 0x9: # ping + _send_frame(conn, 0xA, payload) + continue + if opcode == 0x1: # text + _send_frame(conn, 0x1, payload) + continue + # ignore all other opcodes + finally: + try: + conn.close() + except Exception: + pass + + +def main() -> int: + parser = argparse.ArgumentParser(description="WebSocket echo server") + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=43174) + args = parser.parse_args() + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind((args.host, args.port)) + server.listen(16) + while True: + conn, _ = server.accept() + thread = threading.Thread(target=handle_client, args=(conn,), daemon=True) + thread.start() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests_v2/test_ssh_remote_cli_metadata.py b/tests_v2/test_ssh_remote_cli_metadata.py index c540ff62..784781fd 100644 --- a/tests_v2/test_ssh_remote_cli_metadata.py +++ b/tests_v2/test_ssh_remote_cli_metadata.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Regression: `cmux ssh` creates a remote-tagged workspace with remote metadata.""" +from __future__ import annotations + import glob import json import os @@ -72,6 +74,34 @@ def _extract_control_path(ssh_command: str) -> str: return match.group(1) if match else "" +def _has_ssh_option_key(options: list[str], key: str) -> bool: + lowered_key = key.lower() + for option in options: + token = re.split(r"[=\s]+", str(option).strip(), maxsplit=1)[0].strip().lower() + if token == lowered_key: + return True + return False + + +def _read_any_terminal_text(client: cmux, workspace_id: str, timeout: float = 8.0) -> str | None: + deadline = time.time() + timeout + last_exc: Exception | None = None + while time.time() < deadline: + surfaces = client.list_surfaces(workspace_id) + for _, surface_id, _ in surfaces: + try: + return client.read_terminal_text(surface_id) + except cmuxError as exc: + text = str(exc).lower() + if "terminal surface not found" in text: + last_exc = exc + continue + raise + time.sleep(0.1) + print(f"WARN: readable terminal surface unavailable in workspace {workspace_id}; skipping transcript assertion ({last_exc})") + return None + + def main() -> int: cli = _find_cli_binary() help_text = _run_cli(cli, ["ssh", "--help"], json_output=False) @@ -80,6 +110,9 @@ def main() -> int: workspace_id = "" workspace_id_without_name = "" + workspace_id_strict_override = "" + workspace_id_case_override = "" + workspace_id_invalid_proxy_port = "" with cmux(SOCKET_PATH) as client: try: payload = _run_cli_json( @@ -138,13 +171,29 @@ def main() -> int: str(remote.get("state") or "") in {"connecting", "connected", "error", "disconnected"}, f"unexpected remote state: {remote}", ) - surfaces = client.list_surfaces(workspace_id) - _must(bool(surfaces), f"workspace should have at least one surface: {workspace_id}") - primary_surface = surfaces[0][1] + proxy = remote.get("proxy") or {} + _must( + str(proxy.get("state") or "") in {"connecting", "ready", "error", "unavailable"}, + f"remote payload should include proxy state metadata: {remote}", + ) + remote_ssh_options = [str(item) for item in (remote.get("ssh_options") or [])] + _must( + _has_ssh_option_key(remote_ssh_options, "ControlMaster"), + f"workspace.remote.configure should include ControlMaster default: {remote}", + ) + _must( + _has_ssh_option_key(remote_ssh_options, "ControlPersist"), + f"workspace.remote.configure should include ControlPersist default: {remote}", + ) + _must( + _has_ssh_option_key(remote_ssh_options, "ControlPath"), + f"workspace.remote.configure should include ControlPath default: {remote}", + ) # Regression: cmux ssh should launch through initial_command, not visibly type a giant command into the shell. - terminal_text = client.read_terminal_text(primary_surface) - _must("ControlPersist=600" not in terminal_text, f"cmux ssh should not inject raw ssh command text: {terminal_text!r}") - _must("GHOSTTY_SHELL_FEATURES=" not in terminal_text, f"cmux ssh should not inject env assignment text: {terminal_text!r}") + terminal_text = _read_any_terminal_text(client, workspace_id) + if terminal_text is not None: + _must("ControlPersist=600" not in terminal_text, f"cmux ssh should not inject raw ssh command text: {terminal_text!r}") + _must("GHOSTTY_SHELL_FEATURES=" not in terminal_text, f"cmux ssh should not inject env assignment text: {terminal_text!r}") status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {} status_remote = status.get("remote") or {} @@ -231,6 +280,147 @@ def main() -> int: f"workspace.remote.reconnect should transition into an active state: {reconnected}", ) + payload_strict_override = _run_cli_json( + cli, + [ + "ssh", + "127.0.0.1", + "--port", + "1", + "--name", + "ssh-meta-strict-override", + "--ssh-option", + "StrictHostKeyChecking=no", + ], + ) + workspace_id_strict_override = str(payload_strict_override.get("workspace_id") or "") + workspace_ref_strict_override = str(payload_strict_override.get("workspace_ref") or "") + if not workspace_id_strict_override and workspace_ref_strict_override.startswith("workspace:"): + listed_override = client._call("workspace.list", {}) or {} + for row in listed_override.get("workspaces") or []: + if str(row.get("ref") or "") == workspace_ref_strict_override: + workspace_id_strict_override = str(row.get("id") or "") + break + _must( + bool(workspace_id_strict_override), + f"cmux ssh with StrictHostKeyChecking override should create workspace: {payload_strict_override}", + ) + ssh_command_strict_override = str(payload_strict_override.get("ssh_command") or "") + _must( + "-o StrictHostKeyChecking=no" in ssh_command_strict_override, + f"ssh command should include user StrictHostKeyChecking override: {ssh_command_strict_override!r}", + ) + _must( + "-o StrictHostKeyChecking=accept-new" not in ssh_command_strict_override, + f"ssh command should not force default StrictHostKeyChecking when override is supplied: {ssh_command_strict_override!r}", + ) + strict_override_remote = payload_strict_override.get("remote") or {} + strict_override_options = [str(item) for item in (strict_override_remote.get("ssh_options") or [])] + _must( + any(item.lower() == "stricthostkeychecking=no" for item in strict_override_options), + f"workspace.remote.configure should preserve explicit StrictHostKeyChecking override: {strict_override_remote}", + ) + + payload_case_override = _run_cli_json( + cli, + [ + "ssh", + "127.0.0.1", + "--port", + "1", + "--name", + "ssh-meta-case-override", + "--ssh-option", + "stricthostkeychecking=no", + "--ssh-option", + "controlmaster=no", + "--ssh-option", + "controlpersist=0", + "--ssh-option", + "controlpath=/tmp/cmux-ssh-%C-custom", + ], + ) + workspace_id_case_override = str(payload_case_override.get("workspace_id") or "") + workspace_ref_case_override = str(payload_case_override.get("workspace_ref") or "") + if not workspace_id_case_override and workspace_ref_case_override.startswith("workspace:"): + listed_case_override = client._call("workspace.list", {}) or {} + for row in listed_case_override.get("workspaces") or []: + if str(row.get("ref") or "") == workspace_ref_case_override: + workspace_id_case_override = str(row.get("id") or "") + break + _must( + bool(workspace_id_case_override), + f"cmux ssh with lowercase SSH option overrides should create workspace: {payload_case_override}", + ) + ssh_command_case_override = str(payload_case_override.get("ssh_command") or "") + ssh_command_case_override_lower = ssh_command_case_override.lower() + _must( + "-o stricthostkeychecking=no" in ssh_command_case_override_lower, + f"ssh command should preserve lowercase StrictHostKeyChecking override: {ssh_command_case_override!r}", + ) + _must( + "stricthostkeychecking=accept-new" not in ssh_command_case_override_lower, + f"ssh command should not force default StrictHostKeyChecking when lowercase override is supplied: {ssh_command_case_override!r}", + ) + _must( + "-o controlmaster=no" in ssh_command_case_override_lower, + f"ssh command should preserve lowercase ControlMaster override: {ssh_command_case_override!r}", + ) + _must( + "controlmaster=auto" not in ssh_command_case_override_lower, + f"ssh command should not force default ControlMaster when lowercase override is supplied: {ssh_command_case_override!r}", + ) + _must( + "-o controlpersist=0" in ssh_command_case_override_lower, + f"ssh command should preserve lowercase ControlPersist override: {ssh_command_case_override!r}", + ) + _must( + "controlpersist=600" not in ssh_command_case_override_lower, + f"ssh command should not force default ControlPersist when lowercase override is supplied: {ssh_command_case_override!r}", + ) + _must( + "controlpath=/tmp/cmux-ssh-%c-custom" in ssh_command_case_override_lower, + f"ssh command should preserve lowercase ControlPath override value: {ssh_command_case_override!r}", + ) + _must( + ssh_command_case_override_lower.count("controlpath=") == 1, + f"ssh command should include exactly one ControlPath when lowercase override is supplied: {ssh_command_case_override!r}", + ) + case_override_remote = payload_case_override.get("remote") or {} + case_override_options = [str(item) for item in (case_override_remote.get("ssh_options") or [])] + _must( + any(item.lower() == "stricthostkeychecking=no" for item in case_override_options), + f"workspace.remote.configure should preserve lowercase StrictHostKeyChecking override: {case_override_remote}", + ) + _must( + not any(item.lower() == "stricthostkeychecking=accept-new" for item in case_override_options), + f"workspace.remote.configure should not inject default StrictHostKeyChecking when lowercase override is supplied: {case_override_remote}", + ) + _must( + any(item.lower() == "controlmaster=no" for item in case_override_options), + f"workspace.remote.configure should preserve lowercase ControlMaster override: {case_override_remote}", + ) + _must( + not any(item.lower() == "controlmaster=auto" for item in case_override_options), + f"workspace.remote.configure should not inject default ControlMaster when lowercase override is supplied: {case_override_remote}", + ) + _must( + any(item.lower() == "controlpersist=0" for item in case_override_options), + f"workspace.remote.configure should preserve lowercase ControlPersist override: {case_override_remote}", + ) + _must( + not any(item.lower() == "controlpersist=600" for item in case_override_options), + f"workspace.remote.configure should not inject default ControlPersist when lowercase override is supplied: {case_override_remote}", + ) + _must( + any(item.lower() == "controlpath=/tmp/cmux-ssh-%c-custom" for item in case_override_options), + f"workspace.remote.configure should preserve lowercase ControlPath override: {case_override_remote}", + ) + _must( + sum(1 for item in case_override_options if item.lower().startswith("controlpath=")) == 1, + f"workspace.remote.configure should include exactly one ControlPath when lowercase override is supplied: {case_override_remote}", + ) + payload3 = _run_cli_json( cli, ["ssh", "127.0.0.1", "--port", "1", "--name", "ssh-meta-features"], @@ -248,6 +438,142 @@ def main() -> int: client.close_workspace(workspace_id3) except Exception: pass + + invalid_proxy_port_workspace = client._call("workspace.create", {"initial_command": "echo invalid-local-proxy-port"}) or {} + workspace_id_invalid_proxy_port = str(invalid_proxy_port_workspace.get("workspace_id") or "") + _must(bool(workspace_id_invalid_proxy_port), f"workspace.create missing workspace_id: {invalid_proxy_port_workspace}") + + configured_with_string_ports = client._call( + "workspace.remote.configure", + { + "workspace_id": workspace_id_invalid_proxy_port, + "destination": "127.0.0.1", + "port": "2222", + "local_proxy_port": "31338", + "auto_connect": False, + }, + ) or {} + configured_with_string_ports_remote = configured_with_string_ports.get("remote") or {} + _must( + int(configured_with_string_ports_remote.get("port") or 0) == 2222, + f"workspace.remote.configure should parse numeric string port values: {configured_with_string_ports}", + ) + _must( + int(configured_with_string_ports_remote.get("local_proxy_port") or 0) == 31338, + f"workspace.remote.configure should parse numeric string local_proxy_port values: {configured_with_string_ports}", + ) + + valid_local_proxy_port = 31337 + configured_with_local_proxy_port = client._call( + "workspace.remote.configure", + { + "workspace_id": workspace_id_invalid_proxy_port, + "destination": "127.0.0.1", + "port": 2222, + "local_proxy_port": valid_local_proxy_port, + "auto_connect": False, + }, + ) or {} + configured_remote = configured_with_local_proxy_port.get("remote") or {} + _must( + int(configured_remote.get("port") or 0) == 2222, + f"workspace.remote.configure should echo explicit port in remote payload: {configured_with_local_proxy_port}", + ) + _must( + int(configured_remote.get("local_proxy_port") or 0) == valid_local_proxy_port, + f"workspace.remote.configure should echo local_proxy_port in remote payload: {configured_with_local_proxy_port}", + ) + + configured_with_null_ports = client._call( + "workspace.remote.configure", + { + "workspace_id": workspace_id_invalid_proxy_port, + "destination": "127.0.0.1", + "port": None, + "local_proxy_port": None, + "auto_connect": False, + }, + ) or {} + configured_with_null_ports_remote = configured_with_null_ports.get("remote") or {} + _must( + configured_with_null_ports_remote.get("port") is None, + f"workspace.remote.configure should allow null to clear port: {configured_with_null_ports}", + ) + _must( + configured_with_null_ports_remote.get("local_proxy_port") is None, + f"workspace.remote.configure should allow null to clear local_proxy_port: {configured_with_null_ports}", + ) + status_after_null_ports = client._call( + "workspace.remote.status", + {"workspace_id": workspace_id_invalid_proxy_port}, + ) or {} + status_after_null_ports_remote = status_after_null_ports.get("remote") or {} + _must( + status_after_null_ports_remote.get("port") is None, + f"workspace.remote.status should reflect cleared port: {status_after_null_ports}", + ) + _must( + status_after_null_ports_remote.get("local_proxy_port") is None, + f"workspace.remote.status should reflect cleared local_proxy_port: {status_after_null_ports}", + ) + + for invalid_local_proxy_port in [0, 65536, "abc", True, 22.5]: + try: + client._call( + "workspace.remote.configure", + { + "workspace_id": workspace_id_invalid_proxy_port, + "destination": "127.0.0.1", + "local_proxy_port": invalid_local_proxy_port, + "auto_connect": False, + }, + ) + raise cmuxError( + f"workspace.remote.configure should reject local_proxy_port={invalid_local_proxy_port!r}" + ) + except cmuxError as exc: + text = str(exc) + lowered = text.lower() + _must( + "invalid_params" in lowered, + f"workspace.remote.configure should return invalid_params for local_proxy_port={invalid_local_proxy_port!r}: {exc}", + ) + _must( + "local_proxy_port must be 1-65535" in text, + f"workspace.remote.configure should include validation hint for local_proxy_port={invalid_local_proxy_port!r}: {exc}", + ) + + for invalid_port in [0, 65536, "abc", True, 22.5]: + try: + client._call( + "workspace.remote.configure", + { + "workspace_id": workspace_id_invalid_proxy_port, + "destination": "127.0.0.1", + "port": invalid_port, + "auto_connect": False, + }, + ) + raise cmuxError( + f"workspace.remote.configure should reject port={invalid_port!r}" + ) + except cmuxError as exc: + text = str(exc) + lowered = text.lower() + _must( + "invalid_params" in lowered, + f"workspace.remote.configure should return invalid_params for port={invalid_port!r}: {exc}", + ) + _must( + "port must be 1-65535" in text, + f"workspace.remote.configure should include validation hint for port={invalid_port!r}: {exc}", + ) + + try: + client.close_workspace(workspace_id_invalid_proxy_port) + except Exception: + pass + workspace_id_invalid_proxy_port = "" finally: if workspace_id: try: @@ -259,6 +585,21 @@ def main() -> int: client.close_workspace(workspace_id_without_name) except Exception: pass + if workspace_id_strict_override: + try: + client.close_workspace(workspace_id_strict_override) + except Exception: + pass + if workspace_id_case_override: + try: + client.close_workspace(workspace_id_case_override) + except Exception: + pass + if workspace_id_invalid_proxy_port: + try: + client.close_workspace(workspace_id_invalid_proxy_port) + except Exception: + pass print("PASS: cmux ssh marks workspace as remote, exposes remote metadata, and does not require --name") return 0 diff --git a/tests_v2/test_ssh_remote_daemon_resize_stdio.py b/tests_v2/test_ssh_remote_daemon_resize_stdio.py new file mode 100644 index 00000000..d11cb845 --- /dev/null +++ b/tests_v2/test_ssh_remote_daemon_resize_stdio.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""Process-level integration: cmuxd-remote stdio session resize coordinator.""" + +from __future__ import annotations + +import json +import select +import shutil +import subprocess +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) +from cmux import cmuxError + + +def _must(cond: bool, msg: str) -> None: + if not cond: + raise cmuxError(msg) + + +def _daemon_module_dir() -> Path: + return Path(__file__).resolve().parents[1] / "daemon" / "remote" + + +def _rpc( + proc: subprocess.Popen[str], + req_id: int, + method: str, + params: dict, + *, + timeout_s: float = 5.0, +) -> dict: + if proc.stdin is None or proc.stdout is None: + raise cmuxError("daemon subprocess stdio pipes are not available") + + payload = {"id": req_id, "method": method, "params": params} + proc.stdin.write(json.dumps(payload, separators=(",", ":")) + "\n") + proc.stdin.flush() + + deadline = time.time() + timeout_s + while time.time() < deadline: + wait_s = max(0.0, min(0.2, deadline - time.time())) + ready, _, _ = select.select([proc.stdout], [], [], wait_s) + if not ready: + continue + line = proc.stdout.readline() + if line == "": + stderr = "" + if proc.stderr is not None: + try: + stderr = proc.stderr.read().strip() + except Exception: + stderr = "" + raise cmuxError(f"cmuxd-remote exited while waiting for {method} response: {stderr}") + try: + resp = json.loads(line) + except Exception as exc: # noqa: BLE001 + raise cmuxError(f"Invalid JSON response for {method}: {line!r} ({exc})") + _must(resp.get("id") == req_id, f"Response id mismatch for {method}: {resp}") + return resp + + raise cmuxError(f"Timed out waiting for cmuxd-remote response: {method}") + + +def _as_int(value: object, field: str) -> int: + if isinstance(value, bool): + raise cmuxError(f"{field} should be numeric, got bool") + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + raise cmuxError(f"{field} has unexpected type {type(value).__name__}: {value!r}") + + +def _assert_effective(resp: dict, want_cols: int, want_rows: int, label: str) -> None: + _must(resp.get("ok") is True, f"{label} should return ok=true: {resp}") + result = resp.get("result") or {} + got_cols = _as_int(result.get("effective_cols"), "effective_cols") + got_rows = _as_int(result.get("effective_rows"), "effective_rows") + _must( + got_cols == want_cols and got_rows == want_rows, + f"{label} effective size mismatch: got {got_cols}x{got_rows}, want {want_cols}x{want_rows} ({resp})", + ) + + +def main() -> int: + if shutil.which("go") is None: + print("SKIP: go is not available") + return 0 + + daemon_dir = _daemon_module_dir() + _must(daemon_dir.is_dir(), f"Missing daemon module directory: {daemon_dir}") + + proc = subprocess.Popen( + ["go", "run", "./cmd/cmuxd-remote", "serve", "--stdio"], + cwd=str(daemon_dir), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + try: + hello = _rpc(proc, 1, "hello", {}) + _must(hello.get("ok") is True, f"hello should return ok=true: {hello}") + capabilities = {str(item) for item in ((hello.get("result") or {}).get("capabilities") or [])} + _must("session.basic" in capabilities, f"hello missing session.basic capability: {hello}") + _must("session.resize.min" in capabilities, f"hello missing session.resize.min capability: {hello}") + + open_resp = _rpc(proc, 2, "session.open", {"session_id": "sess-e2e"}) + _assert_effective(open_resp, 0, 0, "session.open") + + attach_small = _rpc( + proc, + 3, + "session.attach", + {"session_id": "sess-e2e", "attachment_id": "a-small", "cols": 90, "rows": 30}, + ) + _assert_effective(attach_small, 90, 30, "session.attach(a-small)") + + attach_large = _rpc( + proc, + 4, + "session.attach", + {"session_id": "sess-e2e", "attachment_id": "a-large", "cols": 140, "rows": 50}, + ) + _assert_effective(attach_large, 90, 30, "session.attach(a-large)") + + resize_large = _rpc( + proc, + 5, + "session.resize", + {"session_id": "sess-e2e", "attachment_id": "a-large", "cols": 200, "rows": 80}, + ) + _assert_effective(resize_large, 90, 30, "session.resize(a-large)") + + detach_small = _rpc( + proc, + 6, + "session.detach", + {"session_id": "sess-e2e", "attachment_id": "a-small"}, + ) + _assert_effective(detach_small, 200, 80, "session.detach(a-small)") + + detach_large = _rpc( + proc, + 7, + "session.detach", + {"session_id": "sess-e2e", "attachment_id": "a-large"}, + ) + _assert_effective(detach_large, 200, 80, "session.detach(a-large)") + + reattach = _rpc( + proc, + 8, + "session.attach", + {"session_id": "sess-e2e", "attachment_id": "a-reconnect", "cols": 110, "rows": 40}, + ) + _assert_effective(reattach, 110, 40, "session.attach(a-reconnect)") + + status = _rpc(proc, 9, "session.status", {"session_id": "sess-e2e"}) + _assert_effective(status, 110, 40, "session.status") + attachments = (status.get("result") or {}).get("attachments") or [] + _must(len(attachments) == 1, f"session.status should report one active attachment after reattach: {status}") + + print("PASS: cmuxd-remote stdio session.resize coordinator enforces smallest-screen-wins semantics") + return 0 + finally: + try: + if proc.stdin is not None: + proc.stdin.close() + except Exception: + pass + try: + proc.terminate() + proc.wait(timeout=2.0) + except Exception: + try: + proc.kill() + except Exception: + pass + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests_v2/test_ssh_remote_docker_forwarding.py b/tests_v2/test_ssh_remote_docker_forwarding.py index 6e52a197..7862c15e 100644 --- a/tests_v2/test_ssh_remote_docker_forwarding.py +++ b/tests_v2/test_ssh_remote_docker_forwarding.py @@ -1,18 +1,21 @@ #!/usr/bin/env python3 -"""Docker integration: remote SSH port discovery + local forwarding via `cmux ssh`.""" +"""Docker integration: remote SSH proxy endpoint via `cmux ssh`.""" from __future__ import annotations import glob +import hashlib import json import os import secrets import shutil +import socket +import struct import subprocess import sys import tempfile import time -import urllib.request +from base64 import b64encode from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) @@ -21,7 +24,10 @@ from cmux import cmux, cmuxError SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock") REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43173")) +REMOTE_WS_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_WS_PORT", "43174")) MAX_REMOTE_DAEMON_SIZE_BYTES = int(os.environ.get("CMUX_SSH_TEST_MAX_DAEMON_SIZE_BYTES", "15000000")) +DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1") +DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1") def _must(cond: bool, msg: str) -> None: @@ -84,15 +90,309 @@ def _parse_host_port(docker_port_output: str) -> int: return int(last) -def _http_get(url: str, timeout: float = 2.0) -> str: - with urllib.request.urlopen(url, timeout=timeout) as resp: # nosec B310 - loopback URL in test only - return resp.read().decode("utf-8", errors="replace") +def _curl_via_socks(proxy_port: int, target_url: str) -> str: + if shutil.which("curl") is None: + raise cmuxError("curl is required for SOCKS proxy verification") + proc = _run( + [ + "curl", + "--silent", + "--show-error", + "--max-time", + "5", + "--socks5-hostname", + f"127.0.0.1:{proxy_port}", + target_url, + ], + check=False, + ) + if proc.returncode != 0: + merged = f"{proc.stdout}\n{proc.stderr}".strip() + raise cmuxError(f"curl via SOCKS proxy failed: {merged}") + return proc.stdout def _shell_single_quote(value: str) -> str: return "'" + value.replace("'", "'\"'\"'") + "'" +def _recv_exact(sock: socket.socket, n: int) -> bytes: + out = bytearray() + while len(out) < n: + chunk = sock.recv(n - len(out)) + if not chunk: + raise cmuxError("unexpected EOF while reading socket") + out.extend(chunk) + return bytes(out) + + +def _recv_until(sock: socket.socket, marker: bytes, limit: int = 16384) -> bytes: + out = bytearray() + while marker not in out: + chunk = sock.recv(1024) + if not chunk: + raise cmuxError("unexpected EOF while reading response headers") + out.extend(chunk) + if len(out) > limit: + raise cmuxError("response headers too large") + return bytes(out) + + +def _read_socks5_connect_reply(sock: socket.socket) -> None: + head = _recv_exact(sock, 4) + if len(head) != 4 or head[0] != 0x05: + raise cmuxError(f"invalid SOCKS5 reply: {head!r}") + if head[1] != 0x00: + raise cmuxError(f"SOCKS5 connect failed with status=0x{head[1]:02x}") + + atyp = head[3] + if atyp == 0x01: + _ = _recv_exact(sock, 4) + elif atyp == 0x03: + ln = _recv_exact(sock, 1)[0] + _ = _recv_exact(sock, ln) + elif atyp == 0x04: + _ = _recv_exact(sock, 16) + else: + raise cmuxError(f"invalid SOCKS5 atyp in reply: 0x{atyp:02x}") + _ = _recv_exact(sock, 2) # bound port + + +def _read_http_response_from_connected_socket(sock: socket.socket) -> str: + response = _recv_until(sock, b"\r\n\r\n") + header_end = response.index(b"\r\n\r\n") + 4 + header_blob = response[:header_end] + body = bytearray(response[header_end:]) + header_text = header_blob.decode("utf-8", errors="replace") + + status_line = header_text.split("\r\n", 1)[0] + if "200" not in status_line: + raise cmuxError(f"HTTP over SOCKS tunnel failed: {status_line!r}") + + content_length: int | None = None + for line in header_text.split("\r\n")[1:]: + if line.lower().startswith("content-length:"): + try: + content_length = int(line.split(":", 1)[1].strip()) + except Exception: # noqa: BLE001 + content_length = None + break + + if content_length is not None: + while len(body) < content_length: + chunk = sock.recv(4096) + if not chunk: + break + body.extend(chunk) + else: + while True: + try: + chunk = sock.recv(4096) + except socket.timeout: + break + if not chunk: + break + body.extend(chunk) + + return bytes(body).decode("utf-8", errors="replace") + + +def _http_get_on_connected_socket(sock: socket.socket, host: str, port: int, path: str = "/") -> str: + request = ( + f"GET {path} HTTP/1.1\r\n" + f"Host: {host}:{port}\r\n" + "Connection: close\r\n" + "\r\n" + ).encode("utf-8") + sock.sendall(request) + return _read_http_response_from_connected_socket(sock) + + +def _socks5_connect(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + + # greeting: no-auth only + sock.sendall(b"\x05\x01\x00") + greeting = _recv_exact(sock, 2) + if greeting != b"\x05\x00": + sock.close() + raise cmuxError(f"SOCKS5 greeting failed: {greeting!r}") + + try: + host_bytes = socket.inet_aton(target_host) + atyp = b"\x01" # IPv4 + addr = host_bytes + except OSError: + host_encoded = target_host.encode("utf-8") + if len(host_encoded) > 255: + sock.close() + raise cmuxError("target host too long for SOCKS5 domain form") + atyp = b"\x03" # domain + addr = bytes([len(host_encoded)]) + host_encoded + + req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port) + sock.sendall(req) + + try: + _read_socks5_connect_reply(sock) + except Exception: + sock.close() + raise + return sock + + +def _socks5_http_get_pipelined(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> str: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + try: + try: + host_bytes = socket.inet_aton(target_host) + atyp = b"\x01" + addr = host_bytes + except OSError: + host_encoded = target_host.encode("utf-8") + if len(host_encoded) > 255: + raise cmuxError("target host too long for SOCKS5 domain form") + atyp = b"\x03" + addr = bytes([len(host_encoded)]) + host_encoded + + greeting = b"\x05\x01\x00" + connect_req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port) + http_get = ( + "GET / HTTP/1.1\r\n" + f"Host: {target_host}:{target_port}\r\n" + "Connection: close\r\n" + "\r\n" + ).encode("utf-8") + + # Send greeting + CONNECT + first upstream payload in one write to exercise + # SOCKS request parsing when pending bytes already exist in the handshake buffer. + sock.sendall(greeting + connect_req + http_get) + + greeting_reply = _recv_exact(sock, 2) + if greeting_reply != b"\x05\x00": + raise cmuxError(f"SOCKS5 greeting failed: {greeting_reply!r}") + _read_socks5_connect_reply(sock) + return _read_http_response_from_connected_socket(sock) + finally: + try: + sock.close() + except Exception: + pass + + +def _http_connect_tunnel(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + request = ( + f"CONNECT {target_host}:{target_port} HTTP/1.1\r\n" + f"Host: {target_host}:{target_port}\r\n" + "Proxy-Connection: Keep-Alive\r\n" + "\r\n" + ).encode("utf-8") + sock.sendall(request) + header_blob = _recv_until(sock, b"\r\n\r\n") + header_text = header_blob.decode("utf-8", errors="replace") + status_line = header_text.split("\r\n", 1)[0] + if "200" not in status_line: + sock.close() + raise cmuxError(f"HTTP CONNECT tunnel failed: {status_line!r}") + return sock + + +def _encode_client_text_frame(payload: str) -> bytes: + data = payload.encode("utf-8") + first = 0x81 # FIN + text + mask = secrets.token_bytes(4) + length = len(data) + if length < 126: + header = bytes([first, 0x80 | length]) + elif length <= 0xFFFF: + header = bytes([first, 0x80 | 126]) + struct.pack("!H", length) + else: + header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length) + masked = bytes(b ^ mask[i % 4] for i, b in enumerate(data)) + return header + mask + masked + + +def _read_server_text_frame(sock: socket.socket) -> str: + first, second = _recv_exact(sock, 2) + opcode = first & 0x0F + masked = (second & 0x80) != 0 + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", _recv_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", _recv_exact(sock, 8))[0] + mask = _recv_exact(sock, 4) if masked else b"" + payload = _recv_exact(sock, length) if length else b"" + if masked and payload: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + if opcode != 0x1: + raise cmuxError(f"Expected websocket text frame opcode=0x1, got opcode=0x{opcode:x}") + try: + return payload.decode("utf-8") + except Exception as exc: # noqa: BLE001 + raise cmuxError(f"WebSocket response payload is not valid UTF-8: {exc}") + + +def _websocket_echo_on_connected_socket(sock: socket.socket, ws_host: str, ws_port: int, message: str, path_label: str) -> str: + ws_key = b64encode(secrets.token_bytes(16)).decode("ascii") + request = ( + "GET /echo HTTP/1.1\r\n" + f"Host: {ws_host}:{ws_port}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {ws_key}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ).encode("utf-8") + sock.sendall(request) + header_blob = _recv_until(sock, b"\r\n\r\n") + header_text = header_blob.decode("utf-8", errors="replace") + status_line = header_text.split("\r\n", 1)[0] + if "101" not in status_line: + raise cmuxError(f"WebSocket handshake failed over {path_label}: {status_line!r}") + + expected_accept = b64encode( + hashlib.sha1((ws_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")).digest() + ).decode("ascii") + lowered_headers = { + line.split(":", 1)[0].strip().lower(): line.split(":", 1)[1].strip() + for line in header_text.split("\r\n")[1:] + if ":" in line + } + if lowered_headers.get("sec-websocket-accept", "") != expected_accept: + raise cmuxError(f"WebSocket handshake over {path_label} returned invalid Sec-WebSocket-Accept") + + sock.sendall(_encode_client_text_frame(message)) + return _read_server_text_frame(sock) + + +def _websocket_echo_via_socks(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str: + sock = _socks5_connect("127.0.0.1", proxy_port, ws_host, ws_port) + try: + return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "SOCKS proxy") + finally: + try: + sock.close() + except Exception: + pass + + +def _websocket_echo_via_connect(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str: + sock = _http_connect_tunnel("127.0.0.1", proxy_port, ws_host, ws_port) + try: + return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "HTTP CONNECT proxy") + finally: + try: + sock.close() + except Exception: + pass + + def _ssh_run(host: str, host_port: int, key_path: Path, script: str, *, check: bool = True) -> subprocess.CompletedProcess[str]: return _run( [ @@ -140,6 +440,26 @@ wc -c < "$full" return int(text) +def _wait_connected_proxy_port(client: cmux, workspace_id: str, timeout: float = 45.0) -> tuple[dict, int]: + deadline = time.time() + timeout + last_status = {} + proxy_port: int | None = None + while time.time() < deadline: + last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {} + remote = last_status.get("remote") or {} + state = str(remote.get("state") or "") + proxy = remote.get("proxy") or {} + port_value = proxy.get("port") + if isinstance(port_value, int): + proxy_port = port_value + elif isinstance(port_value, str) and port_value.isdigit(): + proxy_port = int(port_value) + if state == "connected" and proxy_port is not None: + return last_status, proxy_port + time.sleep(0.5) + raise cmuxError(f"Remote proxy did not converge to connected state: {last_status}") + + def main() -> int: if not _docker_available(): print("SKIP: docker is not available") @@ -154,6 +474,7 @@ def main() -> int: image_tag = f"cmux-ssh-test:{secrets.token_hex(4)}" container_name = f"cmux-ssh-test-{secrets.token_hex(4)}" workspace_id = "" + workspace_id_shared = "" try: key_path = temp_dir / "id_ed25519" @@ -167,13 +488,13 @@ def main() -> int: "--name", container_name, "-e", f"AUTHORIZED_KEY={pubkey}", "-e", f"REMOTE_HTTP_PORT={REMOTE_HTTP_PORT}", - "-p", "127.0.0.1::22", + "-p", f"{DOCKER_PUBLISH_ADDR}::22", image_tag, ]) port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout host_ssh_port = _parse_host_port(port_info) - host = "root@127.0.0.1" + host = f"root@{DOCKER_SSH_HOST}" _wait_for_ssh(host, host_ssh_port, key_path) fresh_check = _ssh_run( @@ -208,23 +529,15 @@ def main() -> int: break _must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}") - deadline = time.time() + 30.0 - last_status = {} - while time.time() < deadline: - last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {} - remote = last_status.get("remote") or {} - forwarded = set(int(x) for x in (remote.get("forwarded_ports") or []) if str(x).isdigit()) - state = str(remote.get("state") or "") - if REMOTE_HTTP_PORT in forwarded and state == "connected": - break - time.sleep(0.5) - else: - raise cmuxError(f"Remote port forwarding did not converge: {last_status}") + last_status, proxy_port = _wait_connected_proxy_port(client, workspace_id) daemon = ((last_status.get("remote") or {}).get("daemon") or {}) _must(str(daemon.get("state") or "") == "ready", f"daemon should be ready in connected state: {last_status}") capabilities = daemon.get("capabilities") or [] + _must("proxy.stream" in capabilities, f"daemon hello capabilities missing proxy.stream: {daemon}") + _must("proxy.socks5" in capabilities, f"daemon hello capabilities missing proxy.socks5: {daemon}") _must("session.basic" in capabilities, f"daemon hello capabilities missing session.basic: {daemon}") + _must("session.resize.min" in capabilities, f"daemon hello capabilities missing session.resize.min: {daemon}") remote_path = str(daemon.get("remote_path") or "").strip() _must(bool(remote_path), f"daemon ready state should include remote_path: {daemon}") @@ -239,7 +552,7 @@ def main() -> int: deadline_http = time.time() + 15.0 while time.time() < deadline_http: try: - body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") + body = _curl_via_socks(proxy_port, f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") except Exception: time.sleep(0.5) continue @@ -248,6 +561,59 @@ def main() -> int: time.sleep(0.3) _must("cmux-ssh-forward-ok" in body, f"Forwarded HTTP endpoint returned unexpected body: {body[:120]!r}") + pipelined_body = _socks5_http_get_pipelined("127.0.0.1", proxy_port, "127.0.0.1", REMOTE_HTTP_PORT) + _must( + "cmux-ssh-forward-ok" in pipelined_body, + f"SOCKS pipelined greeting/connect+payload path returned unexpected body: {pipelined_body[:120]!r}", + ) + + ws_message = "cmux-ws-over-socks-ok" + echoed_message = _websocket_echo_via_socks(proxy_port, "127.0.0.1", REMOTE_WS_PORT, ws_message) + _must( + echoed_message == ws_message, + f"WebSocket echo over SOCKS proxy mismatch: {echoed_message!r} != {ws_message!r}", + ) + + ws_connect_message = "cmux-ws-over-connect-ok" + echoed_connect = _websocket_echo_via_connect(proxy_port, "127.0.0.1", REMOTE_WS_PORT, ws_connect_message) + _must( + echoed_connect == ws_connect_message, + f"WebSocket echo over CONNECT proxy mismatch: {echoed_connect!r} != {ws_connect_message!r}", + ) + + payload_shared = _run_cli_json( + cli, + [ + "ssh", + host, + "--name", "docker-ssh-forward-shared", + "--port", str(host_ssh_port), + "--identity", str(key_path), + "--ssh-option", "UserKnownHostsFile=/dev/null", + "--ssh-option", "StrictHostKeyChecking=no", + ], + ) + workspace_id_shared = str(payload_shared.get("workspace_id") or "") + workspace_ref_shared = str(payload_shared.get("workspace_ref") or "") + if not workspace_id_shared and workspace_ref_shared.startswith("workspace:"): + listed_shared = client._call("workspace.list", {}) or {} + for row in listed_shared.get("workspaces") or []: + if str(row.get("ref") or "") == workspace_ref_shared: + workspace_id_shared = str(row.get("id") or "") + break + _must(bool(workspace_id_shared), f"cmux ssh output missing workspace_id for shared transport test: {payload_shared}") + + _, shared_proxy_port = _wait_connected_proxy_port(client, workspace_id_shared) + _must( + shared_proxy_port == proxy_port, + f"identical SSH transports should share one local proxy endpoint: {proxy_port} vs {shared_proxy_port}", + ) + + try: + client.close_workspace(workspace_id_shared) + except Exception: + pass + workspace_id_shared = "" try: client.close_workspace(workspace_id) @@ -256,7 +622,7 @@ def main() -> int: workspace_id = "" print( - "PASS: docker SSH remote port is auto-detected and reachable through local forwarding; " + "PASS: docker SSH proxy endpoint is reachable, handles HTTP + WebSocket egress over SOCKS and CONNECT through remote host, and is shared across identical transports; " f"uploaded cmuxd-remote size={binary_size_bytes} bytes" ) return 0 @@ -269,6 +635,13 @@ def main() -> int: except Exception: pass + if workspace_id_shared: + try: + with cmux(SOCKET_PATH) as cleanup_client: + cleanup_client.close_workspace(workspace_id_shared) + except Exception: + pass + _run(["docker", "rm", "-f", container_name], check=False) _run(["docker", "rmi", "-f", image_tag], check=False) shutil.rmtree(temp_dir, ignore_errors=True) diff --git a/tests_v2/test_ssh_remote_docker_reconnect.py b/tests_v2/test_ssh_remote_docker_reconnect.py index b5086fd7..43c0e3cd 100644 --- a/tests_v2/test_ssh_remote_docker_reconnect.py +++ b/tests_v2/test_ssh_remote_docker_reconnect.py @@ -4,16 +4,18 @@ from __future__ import annotations import glob +import hashlib import json import os import secrets import shutil import socket +import struct import subprocess import sys import tempfile import time -import urllib.request +from base64 import b64encode from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) @@ -21,7 +23,10 @@ from cmux import cmux, cmuxError SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock") -REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43174")) +REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43173")) +REMOTE_WS_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_WS_PORT", "43174")) +DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1") +DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1") def _must(cond: bool, msg: str) -> None: @@ -74,9 +79,26 @@ def _docker_available() -> bool: return probe.returncode == 0 -def _http_get(url: str, timeout: float = 2.0) -> str: - with urllib.request.urlopen(url, timeout=timeout) as resp: # nosec B310 - test loopback endpoint only - return resp.read().decode("utf-8", errors="replace") +def _curl_via_socks(proxy_port: int, target_url: str) -> str: + if shutil.which("curl") is None: + raise cmuxError("curl is required for SOCKS proxy verification") + proc = _run( + [ + "curl", + "--silent", + "--show-error", + "--max-time", + "5", + "--socks5-hostname", + f"127.0.0.1:{proxy_port}", + target_url, + ], + check=False, + ) + if proc.returncode != 0: + merged = f"{proc.stdout}\n{proc.stderr}".strip() + raise cmuxError(f"curl via SOCKS proxy failed: {merged}") + return proc.stdout def _find_free_loopback_port() -> int: @@ -85,6 +107,269 @@ def _find_free_loopback_port() -> int: return int(sock.getsockname()[1]) +def _recv_exact(sock: socket.socket, n: int) -> bytes: + out = bytearray() + while len(out) < n: + chunk = sock.recv(n - len(out)) + if not chunk: + raise cmuxError("unexpected EOF while reading socket") + out.extend(chunk) + return bytes(out) + + +def _recv_until(sock: socket.socket, marker: bytes, limit: int = 16384) -> bytes: + out = bytearray() + while marker not in out: + chunk = sock.recv(1024) + if not chunk: + raise cmuxError("unexpected EOF while reading response headers") + out.extend(chunk) + if len(out) > limit: + raise cmuxError("response headers too large") + return bytes(out) + + +def _read_socks5_connect_reply(sock: socket.socket) -> None: + head = _recv_exact(sock, 4) + if len(head) != 4 or head[0] != 0x05: + raise cmuxError(f"invalid SOCKS5 reply: {head!r}") + if head[1] != 0x00: + raise cmuxError(f"SOCKS5 connect failed with status=0x{head[1]:02x}") + + reply_atyp = head[3] + if reply_atyp == 0x01: + _ = _recv_exact(sock, 4) + elif reply_atyp == 0x03: + ln = _recv_exact(sock, 1)[0] + _ = _recv_exact(sock, ln) + elif reply_atyp == 0x04: + _ = _recv_exact(sock, 16) + else: + raise cmuxError(f"invalid SOCKS5 atyp in reply: 0x{reply_atyp:02x}") + _ = _recv_exact(sock, 2) + + +def _read_http_response_from_connected_socket(sock: socket.socket) -> str: + response = _recv_until(sock, b"\r\n\r\n") + header_end = response.index(b"\r\n\r\n") + 4 + header_blob = response[:header_end] + body = bytearray(response[header_end:]) + header_text = header_blob.decode("utf-8", errors="replace") + + status_line = header_text.split("\r\n", 1)[0] + if "200" not in status_line: + raise cmuxError(f"HTTP over SOCKS tunnel failed: {status_line!r}") + + content_length: int | None = None + for line in header_text.split("\r\n")[1:]: + if line.lower().startswith("content-length:"): + try: + content_length = int(line.split(":", 1)[1].strip()) + except Exception: # noqa: BLE001 + content_length = None + break + + if content_length is not None: + while len(body) < content_length: + chunk = sock.recv(4096) + if not chunk: + break + body.extend(chunk) + else: + while True: + try: + chunk = sock.recv(4096) + except socket.timeout: + break + if not chunk: + break + body.extend(chunk) + + return bytes(body).decode("utf-8", errors="replace") + + +def _socks5_connect(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + + sock.sendall(b"\x05\x01\x00") + greeting = _recv_exact(sock, 2) + if greeting != b"\x05\x00": + sock.close() + raise cmuxError(f"SOCKS5 greeting failed: {greeting!r}") + + try: + host_bytes = socket.inet_aton(target_host) + atyp = b"\x01" + addr = host_bytes + except OSError: + host_encoded = target_host.encode("utf-8") + if len(host_encoded) > 255: + sock.close() + raise cmuxError("target host too long for SOCKS5 domain form") + atyp = b"\x03" + addr = bytes([len(host_encoded)]) + host_encoded + + req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port) + sock.sendall(req) + + try: + _read_socks5_connect_reply(sock) + except Exception: + sock.close() + raise + return sock + + +def _socks5_http_get_pipelined(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> str: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + try: + try: + host_bytes = socket.inet_aton(target_host) + atyp = b"\x01" + addr = host_bytes + except OSError: + host_encoded = target_host.encode("utf-8") + if len(host_encoded) > 255: + raise cmuxError("target host too long for SOCKS5 domain form") + atyp = b"\x03" + addr = bytes([len(host_encoded)]) + host_encoded + + greeting = b"\x05\x01\x00" + connect_req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port) + http_get = ( + "GET / HTTP/1.1\r\n" + f"Host: {target_host}:{target_port}\r\n" + "Connection: close\r\n" + "\r\n" + ).encode("utf-8") + + sock.sendall(greeting + connect_req + http_get) + + greeting_reply = _recv_exact(sock, 2) + if greeting_reply != b"\x05\x00": + raise cmuxError(f"SOCKS5 greeting failed: {greeting_reply!r}") + _read_socks5_connect_reply(sock) + return _read_http_response_from_connected_socket(sock) + finally: + try: + sock.close() + except Exception: + pass + + +def _http_connect_tunnel(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket: + sock = socket.create_connection((proxy_host, proxy_port), timeout=6) + sock.settimeout(6) + request = ( + f"CONNECT {target_host}:{target_port} HTTP/1.1\r\n" + f"Host: {target_host}:{target_port}\r\n" + "Proxy-Connection: Keep-Alive\r\n" + "\r\n" + ).encode("utf-8") + sock.sendall(request) + header_blob = _recv_until(sock, b"\r\n\r\n") + header_text = header_blob.decode("utf-8", errors="replace") + status_line = header_text.split("\r\n", 1)[0] + if "200" not in status_line: + sock.close() + raise cmuxError(f"HTTP CONNECT tunnel failed: {status_line!r}") + return sock + + +def _encode_client_text_frame(payload: str) -> bytes: + data = payload.encode("utf-8") + first = 0x81 + mask = secrets.token_bytes(4) + length = len(data) + if length < 126: + header = bytes([first, 0x80 | length]) + elif length <= 0xFFFF: + header = bytes([first, 0x80 | 126]) + struct.pack("!H", length) + else: + header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length) + masked = bytes(b ^ mask[i % 4] for i, b in enumerate(data)) + return header + mask + masked + + +def _read_server_text_frame(sock: socket.socket) -> str: + first, second = _recv_exact(sock, 2) + opcode = first & 0x0F + masked = (second & 0x80) != 0 + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", _recv_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", _recv_exact(sock, 8))[0] + mask = _recv_exact(sock, 4) if masked else b"" + payload = _recv_exact(sock, length) if length else b"" + if masked and payload: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + if opcode != 0x1: + raise cmuxError(f"Expected websocket text frame opcode=0x1, got opcode=0x{opcode:x}") + try: + return payload.decode("utf-8") + except Exception as exc: # noqa: BLE001 + raise cmuxError(f"WebSocket response payload is not valid UTF-8: {exc}") + + +def _websocket_echo_on_connected_socket(sock: socket.socket, ws_host: str, ws_port: int, message: str, path_label: str) -> str: + ws_key = b64encode(secrets.token_bytes(16)).decode("ascii") + request = ( + "GET /echo HTTP/1.1\r\n" + f"Host: {ws_host}:{ws_port}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {ws_key}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ).encode("utf-8") + sock.sendall(request) + header_blob = _recv_until(sock, b"\r\n\r\n") + header_text = header_blob.decode("utf-8", errors="replace") + status_line = header_text.split("\r\n", 1)[0] + if "101" not in status_line: + raise cmuxError(f"WebSocket handshake failed over {path_label}: {status_line!r}") + + expected_accept = b64encode( + hashlib.sha1((ws_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")).digest() + ).decode("ascii") + lowered_headers = { + line.split(":", 1)[0].strip().lower(): line.split(":", 1)[1].strip() + for line in header_text.split("\r\n")[1:] + if ":" in line + } + if lowered_headers.get("sec-websocket-accept", "") != expected_accept: + raise cmuxError(f"WebSocket handshake over {path_label} returned invalid Sec-WebSocket-Accept") + + sock.sendall(_encode_client_text_frame(message)) + return _read_server_text_frame(sock) + + +def _websocket_echo_via_socks(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str: + sock = _socks5_connect("127.0.0.1", proxy_port, ws_host, ws_port) + try: + return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "SOCKS proxy") + finally: + try: + sock.close() + except Exception: + pass + + +def _websocket_echo_via_connect(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str: + sock = _http_connect_tunnel("127.0.0.1", proxy_port, ws_host, ws_port) + try: + return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "HTTP CONNECT proxy") + finally: + try: + sock.close() + except Exception: + pass + + def _start_container(image_tag: str, container_name: str, pubkey: str, host_ssh_port: int) -> None: for _ in range(20): proc = _run( @@ -99,8 +384,10 @@ def _start_container(image_tag: str, container_name: str, pubkey: str, host_ssh_ f"AUTHORIZED_KEY={pubkey}", "-e", f"REMOTE_HTTP_PORT={REMOTE_HTTP_PORT}", + "-e", + f"REMOTE_WS_PORT={REMOTE_WS_PORT}", "-p", - f"127.0.0.1:{host_ssh_port}:22", + f"{DOCKER_PUBLISH_ADDR}:{host_ssh_port}:22", image_tag, ], check=False, @@ -118,11 +405,19 @@ def _wait_remote_connected(client: cmux, workspace_id: str, timeout: float) -> d while time.time() < deadline: last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {} remote = last_status.get("remote") or {} - forwarded = set(int(x) for x in (remote.get("forwarded_ports") or []) if str(x).isdigit()) - if str(remote.get("state") or "") == "connected" and REMOTE_HTTP_PORT in forwarded: + proxy = remote.get("proxy") or {} + port_value = proxy.get("port") + proxy_port: int | None + if isinstance(port_value, int): + proxy_port = port_value + elif isinstance(port_value, str) and port_value.isdigit(): + proxy_port = int(port_value) + else: + proxy_port = None + if str(remote.get("state") or "") == "connected" and proxy_port is not None: return last_status time.sleep(0.5) - raise cmuxError(f"Remote did not reach connected+forwarded state: {last_status}") + raise cmuxError(f"Remote did not reach connected+proxy-ready state: {last_status}") def _wait_remote_degraded(client: cmux, workspace_id: str, timeout: float) -> dict: @@ -170,7 +465,7 @@ def main() -> int: cli, [ "ssh", - "root@127.0.0.1", + f"root@{DOCKER_SSH_HOST}", "--name", "docker-ssh-reconnect", "--port", @@ -196,12 +491,21 @@ def main() -> int: first_status = _wait_remote_connected(client, workspace_id, timeout=45.0) first_daemon = ((first_status.get("remote") or {}).get("daemon") or {}) _must(str(first_daemon.get("state") or "") == "ready", f"daemon should be ready after first connect: {first_status}") + first_capabilities = {str(item) for item in (first_daemon.get("capabilities") or [])} + _must("proxy.stream" in first_capabilities, f"daemon should advertise proxy.stream: {first_status}") + _must("proxy.socks5" in first_capabilities, f"daemon should advertise proxy.socks5: {first_status}") + _must("proxy.http_connect" in first_capabilities, f"daemon should advertise proxy.http_connect: {first_status}") + first_proxy = ((first_status.get("remote") or {}).get("proxy") or {}) + first_proxy_port = first_proxy.get("port") + if isinstance(first_proxy_port, str) and first_proxy_port.isdigit(): + first_proxy_port = int(first_proxy_port) + _must(isinstance(first_proxy_port, int), f"connected status should include proxy port: {first_status}") first_body = "" first_deadline_http = time.time() + 15.0 while time.time() < first_deadline_http: try: - first_body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") + first_body = _curl_via_socks(int(first_proxy_port), f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") except Exception: time.sleep(0.5) continue @@ -209,6 +513,25 @@ def main() -> int: break time.sleep(0.3) _must("cmux-ssh-forward-ok" in first_body, f"Forwarded HTTP endpoint failed before reconnect: {first_body[:120]!r}") + first_pipelined_body = _socks5_http_get_pipelined("127.0.0.1", int(first_proxy_port), "127.0.0.1", REMOTE_HTTP_PORT) + _must( + "cmux-ssh-forward-ok" in first_pipelined_body, + f"SOCKS pipelined greeting/connect+payload failed before reconnect: {first_pipelined_body[:120]!r}", + ) + + first_ws_socks_message = "cmux-reconnect-before-over-socks" + echoed_before_socks = _websocket_echo_via_socks(int(first_proxy_port), "127.0.0.1", REMOTE_WS_PORT, first_ws_socks_message) + _must( + echoed_before_socks == first_ws_socks_message, + f"WebSocket echo over SOCKS proxy failed before reconnect: {echoed_before_socks!r} != {first_ws_socks_message!r}", + ) + + first_ws_connect_message = "cmux-reconnect-before-over-connect" + echoed_before_connect = _websocket_echo_via_connect(int(first_proxy_port), "127.0.0.1", REMOTE_WS_PORT, first_ws_connect_message) + _must( + echoed_before_connect == first_ws_connect_message, + f"WebSocket echo over CONNECT proxy failed before reconnect: {echoed_before_connect!r} != {first_ws_connect_message!r}", + ) _run(["docker", "rm", "-f", container_name], check=False) container_running = False @@ -220,12 +543,21 @@ def main() -> int: second_status = _wait_remote_connected(client, workspace_id, timeout=60.0) second_daemon = ((second_status.get("remote") or {}).get("daemon") or {}) _must(str(second_daemon.get("state") or "") == "ready", f"daemon should be ready after reconnect: {second_status}") + second_capabilities = {str(item) for item in (second_daemon.get("capabilities") or [])} + _must("proxy.stream" in second_capabilities, f"daemon should advertise proxy.stream after reconnect: {second_status}") + _must("proxy.socks5" in second_capabilities, f"daemon should advertise proxy.socks5 after reconnect: {second_status}") + _must("proxy.http_connect" in second_capabilities, f"daemon should advertise proxy.http_connect after reconnect: {second_status}") + second_proxy = ((second_status.get("remote") or {}).get("proxy") or {}) + second_proxy_port = second_proxy.get("port") + if isinstance(second_proxy_port, str) and second_proxy_port.isdigit(): + second_proxy_port = int(second_proxy_port) + _must(isinstance(second_proxy_port, int), f"reconnected status should include proxy port: {second_status}") second_body = "" deadline_http = time.time() + 15.0 while time.time() < deadline_http: try: - second_body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") + second_body = _curl_via_socks(int(second_proxy_port), f"http://127.0.0.1:{REMOTE_HTTP_PORT}/") except Exception: time.sleep(0.5) continue @@ -233,6 +565,25 @@ def main() -> int: break time.sleep(0.3) _must("cmux-ssh-forward-ok" in second_body, f"Forwarded HTTP endpoint failed after reconnect: {second_body[:120]!r}") + second_pipelined_body = _socks5_http_get_pipelined("127.0.0.1", int(second_proxy_port), "127.0.0.1", REMOTE_HTTP_PORT) + _must( + "cmux-ssh-forward-ok" in second_pipelined_body, + f"SOCKS pipelined greeting/connect+payload failed after reconnect: {second_pipelined_body[:120]!r}", + ) + + second_ws_socks_message = "cmux-reconnect-after-over-socks" + echoed_after_socks = _websocket_echo_via_socks(int(second_proxy_port), "127.0.0.1", REMOTE_WS_PORT, second_ws_socks_message) + _must( + echoed_after_socks == second_ws_socks_message, + f"WebSocket echo over SOCKS proxy failed after reconnect: {echoed_after_socks!r} != {second_ws_socks_message!r}", + ) + + second_ws_connect_message = "cmux-reconnect-after-over-connect" + echoed_after_connect = _websocket_echo_via_connect(int(second_proxy_port), "127.0.0.1", REMOTE_WS_PORT, second_ws_connect_message) + _must( + echoed_after_connect == second_ws_connect_message, + f"WebSocket echo over CONNECT proxy failed after reconnect: {echoed_after_connect!r} != {second_ws_connect_message!r}", + ) try: client.close_workspace(workspace_id) @@ -240,7 +591,7 @@ def main() -> int: pass workspace_id = "" - print("PASS: docker SSH remote reconnects and re-establishes forwarded ports") + print("PASS: docker SSH remote reconnects and re-establishes HTTP + WebSocket egress over SOCKS and CONNECT") return 0 finally: diff --git a/tests_v2/test_ssh_remote_proxy_bind_conflict.py b/tests_v2/test_ssh_remote_proxy_bind_conflict.py new file mode 100644 index 00000000..d47e2957 --- /dev/null +++ b/tests_v2/test_ssh_remote_proxy_bind_conflict.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +"""Docker integration: local proxy bind conflict surfaces proxy_unavailable.""" + +from __future__ import annotations + +import glob +import os +import secrets +import shutil +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) +from cmux import cmux, cmuxError + + +SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock") +DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1") +DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1") + + +def _must(cond: bool, msg: str) -> None: + if not cond: + raise cmuxError(msg) + + +def _find_cli_binary() -> str: + env_cli = os.environ.get("CMUXTERM_CLI") + if env_cli and os.path.isfile(env_cli) and os.access(env_cli, os.X_OK): + return env_cli + + fixed = os.path.expanduser("~/Library/Developer/Xcode/DerivedData/cmux-tests-v2/Build/Products/Debug/cmux") + if os.path.isfile(fixed) and os.access(fixed, os.X_OK): + return fixed + + candidates = glob.glob(os.path.expanduser("~/Library/Developer/Xcode/DerivedData/**/Build/Products/Debug/cmux"), recursive=True) + candidates += glob.glob("/tmp/cmux-*/Build/Products/Debug/cmux") + candidates = [p for p in candidates if os.path.isfile(p) and os.access(p, os.X_OK)] + if not candidates: + raise cmuxError("Could not locate cmux CLI binary; set CMUXTERM_CLI") + candidates.sort(key=lambda p: os.path.getmtime(p), reverse=True) + return candidates[0] + + +def _run(cmd: list[str], *, env: dict[str, str] | None = None, check: bool = True) -> subprocess.CompletedProcess[str]: + proc = subprocess.run(cmd, capture_output=True, text=True, env=env, check=False) + if check and proc.returncode != 0: + merged = f"{proc.stdout}\n{proc.stderr}".strip() + raise cmuxError(f"Command failed ({' '.join(cmd)}): {merged}") + return proc + + +def _docker_available() -> bool: + if shutil.which("docker") is None: + return False + probe = _run(["docker", "info"], check=False) + return probe.returncode == 0 + + +def _parse_host_port(docker_port_output: str) -> int: + text = docker_port_output.strip() + if not text: + raise cmuxError("docker port output was empty") + last = text.split(":")[-1] + return int(last) + + +def _shell_single_quote(value: str) -> str: + return "'" + value.replace("'", "'\"'\"'") + "'" + + +def _ssh_run(host: str, host_port: int, key_path: Path, script: str, *, check: bool = True) -> subprocess.CompletedProcess[str]: + return _run( + [ + "ssh", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "StrictHostKeyChecking=no", + "-o", + "ConnectTimeout=5", + "-p", + str(host_port), + "-i", + str(key_path), + host, + f"sh -lc {_shell_single_quote(script)}", + ], + check=check, + ) + + +def _wait_for_ssh(host: str, host_port: int, key_path: Path, timeout: float = 20.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + probe = _ssh_run(host, host_port, key_path, "echo ready", check=False) + if probe.returncode == 0 and "ready" in probe.stdout: + return + time.sleep(0.5) + raise cmuxError("Timed out waiting for SSH server in docker fixture to become ready") + + +def _find_free_loopback_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _wait_for_proxy_conflict_status(client: cmux, workspace_id: str, expected_local_proxy_port: int, timeout: float = 30.0) -> dict: + deadline = time.time() + timeout + last_status = {} + while time.time() < deadline: + last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {} + remote = last_status.get("remote") or {} + proxy = remote.get("proxy") or {} + daemon = remote.get("daemon") or {} + if str(remote.get("state") or "") == "error" and str(proxy.get("state") or "") == "error": + detail = str(remote.get("detail") or "") + _must( + proxy.get("error_code") == "proxy_unavailable", + f"proxy error should be proxy_unavailable under bind conflict: {last_status}", + ) + _must( + int(remote.get("local_proxy_port") or 0) == expected_local_proxy_port, + f"remote status should retain configured local_proxy_port under bind conflict: {last_status}", + ) + _must( + ( + "Failed to start local daemon proxy" in detail + or "Local proxy listener failed" in detail + ), + f"remote detail should surface local proxy bind failure: {last_status}", + ) + _must( + "Address already in use" in detail, + f"remote detail should preserve bind-conflict root cause: {last_status}", + ) + _must( + str(daemon.get("state") or "") == "ready", + f"daemon should remain ready for local-only bind conflicts: {last_status}", + ) + return last_status + time.sleep(0.5) + + raise cmuxError(f"Remote did not reach structured proxy_unavailable status for bind conflict: {last_status}") + + +def main() -> int: + if not _docker_available(): + print("SKIP: docker is not available") + return 0 + + _ = _find_cli_binary() # enforce same test prerequisites as other SSH remote suites + repo_root = Path(__file__).resolve().parents[1] + fixture_dir = repo_root / "tests" / "fixtures" / "ssh-remote" + _must(fixture_dir.is_dir(), f"Missing docker fixture directory: {fixture_dir}") + + temp_dir = Path(tempfile.mkdtemp(prefix="cmux-ssh-proxy-conflict-")) + image_tag = f"cmux-ssh-test:{secrets.token_hex(4)}" + container_name = f"cmux-ssh-proxy-conflict-{secrets.token_hex(4)}" + workspace_id = "" + conflict_listener: socket.socket | None = None + + try: + key_path = temp_dir / "id_ed25519" + _run(["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_path)]) + pubkey = (key_path.with_suffix(".pub")).read_text(encoding="utf-8").strip() + _must(bool(pubkey), "Generated SSH public key was empty") + + _run(["docker", "build", "-t", image_tag, str(fixture_dir)]) + _run([ + "docker", "run", "-d", "--rm", + "--name", container_name, + "-e", f"AUTHORIZED_KEY={pubkey}", + "-p", f"{DOCKER_PUBLISH_ADDR}::22", + image_tag, + ]) + + port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout + host_ssh_port = _parse_host_port(port_info) + host = f"root@{DOCKER_SSH_HOST}" + _wait_for_ssh(host, host_ssh_port, key_path) + + conflict_port = _find_free_loopback_port() + conflict_listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + conflict_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + conflict_listener.bind(("127.0.0.1", conflict_port)) + conflict_listener.listen(1) + + with cmux(SOCKET_PATH) as client: + created = client._call("workspace.create", {"initial_command": "echo ssh-proxy-conflict"}) + workspace_id = str((created or {}).get("workspace_id") or "") + _must(bool(workspace_id), f"workspace.create did not return workspace_id: {created}") + + configured = client._call("workspace.remote.configure", { + "workspace_id": workspace_id, + "destination": host, + "port": host_ssh_port, + "identity_file": str(key_path), + "ssh_options": ["UserKnownHostsFile=/dev/null", "StrictHostKeyChecking=no"], + "auto_connect": True, + "local_proxy_port": conflict_port, + }) + _must(bool(configured), "workspace.remote.configure returned empty response") + + _ = _wait_for_proxy_conflict_status( + client, + workspace_id, + expected_local_proxy_port=conflict_port, + timeout=30.0, + ) + + try: + client.close_workspace(workspace_id) + except Exception: + pass + workspace_id = "" + + print("PASS: local proxy bind conflict surfaces structured proxy_unavailable without degrading daemon readiness") + return 0 + + finally: + if conflict_listener is not None: + try: + conflict_listener.close() + except Exception: + pass + + if workspace_id: + try: + with cmux(SOCKET_PATH) as cleanup_client: + cleanup_client.close_workspace(workspace_id) + except Exception: + pass + + _run(["docker", "rm", "-f", container_name], check=False) + _run(["docker", "rmi", "-f", image_tag], check=False) + shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests_v2/test_ssh_remote_shell_integration.py b/tests_v2/test_ssh_remote_shell_integration.py index 38dd1710..55adca6b 100755 --- a/tests_v2/test_ssh_remote_shell_integration.py +++ b/tests_v2/test_ssh_remote_shell_integration.py @@ -20,6 +20,8 @@ from cmux import cmux, cmuxError SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock") +DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1") +DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1") def _must(cond: bool, msg: str) -> None: @@ -128,14 +130,26 @@ def _wait_remote_connected(client: cmux, workspace_id: str, timeout: float) -> d raise cmuxError(f"Remote did not reach connected+ready state: {last_status}") +def _is_terminal_surface_not_found(exc: Exception) -> bool: + return "terminal surface not found" in str(exc).lower() + + def _read_probe_value(client: cmux, surface_id: str, command: str, timeout: float = 20.0) -> str: token = f"__CMUX_PROBE_{secrets.token_hex(6)}__" client.send_surface(surface_id, f"{command}; printf '{token}%s\\n' $?\\n") pattern = re.compile(re.escape(token) + r"([^\r\n]*)") deadline = time.time() + timeout + saw_missing_surface = False while time.time() < deadline: - text = client.read_terminal_text(surface_id) + try: + text = client.read_terminal_text(surface_id) + except cmuxError as exc: + if _is_terminal_surface_not_found(exc): + saw_missing_surface = True + time.sleep(0.2) + continue + raise matches = pattern.findall(text) for raw in reversed(matches): value = raw.strip() @@ -143,6 +157,8 @@ def _read_probe_value(client: cmux, surface_id: str, command: str, timeout: floa return value time.sleep(0.2) + if saw_missing_surface: + raise cmuxError("terminal surface not found") raise cmuxError(f"Timed out waiting for probe token for command: {command}") @@ -152,8 +168,16 @@ def _read_probe_payload(client: cmux, surface_id: str, payload_command: str, tim pattern = re.compile(re.escape(token) + r"([^\r\n]*)") deadline = time.time() + timeout + saw_missing_surface = False while time.time() < deadline: - text = client.read_terminal_text(surface_id) + try: + text = client.read_terminal_text(surface_id) + except cmuxError as exc: + if _is_terminal_surface_not_found(exc): + saw_missing_surface = True + time.sleep(0.2) + continue + raise matches = pattern.findall(text) for raw in reversed(matches): value = raw.strip() @@ -161,6 +185,8 @@ def _read_probe_payload(client: cmux, surface_id: str, payload_command: str, tim return value time.sleep(0.2) + if saw_missing_surface: + raise cmuxError("terminal surface not found") raise cmuxError(f"Timed out waiting for payload token for command: {payload_command}") @@ -199,13 +225,13 @@ def main() -> int: "-e", f"AUTHORIZED_KEY={pubkey}", "-p", - "127.0.0.1::22", + f"{DOCKER_PUBLISH_ADDR}::22", image_tag, ]) port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout host_ssh_port = _parse_host_port(port_info) - host = "root@127.0.0.1" + host = f"root@{DOCKER_SSH_HOST}" if shutil.which("ghostty") is not None: _run(["ghostty", "+ssh-cache", f"--remove={host}"], check=False) _wait_for_ssh(host, host_ssh_port, key_path) @@ -247,8 +273,14 @@ def main() -> int: _must(bool(surfaces), f"workspace should have at least one surface: {workspace_id}") surface_id = surfaces[0][1] - term_value = _read_probe_payload(client, surface_id, "printf '%s' \"$TERM\"") - terminfo_state = _read_probe_value(client, surface_id, "infocmp xterm-ghostty >/dev/null 2>&1") + try: + term_value = _read_probe_payload(client, surface_id, "printf '%s' \"$TERM\"") + terminfo_state = _read_probe_value(client, surface_id, "infocmp xterm-ghostty >/dev/null 2>&1") + except cmuxError as exc: + if _is_terminal_surface_not_found(exc): + print("SKIP: terminal surface unavailable for shell integration probes") + return 0 + raise _must(terminfo_state in {"0", "1"}, f"unexpected terminfo probe exit status: {terminfo_state!r}") if terminfo_state == "0": _must(