WIP: advance ssh remote workspace proxying
This commit is contained in:
parent
19707299f9
commit
6800cd44bb
20 changed files with 4572 additions and 435 deletions
|
|
@ -79,6 +79,8 @@ tail -f "$(cat /tmp/cmux-last-debug-log-path 2>/dev/null || echo /tmp/cmux-debug
|
|||
- Untagged Debug app: `/tmp/cmux-debug.log`
|
||||
- Tagged Debug app (`./scripts/reload.sh --tag <tag>`): `/tmp/cmux-debug-<tag>.log`
|
||||
- `reload.sh` writes the current path to `/tmp/cmux-last-debug-log-path`
|
||||
- `reload.sh` writes the selected dev CLI path to `/tmp/cmux-last-cli-path`
|
||||
- `reload.sh` updates `/tmp/cmux-cli` and `$HOME/.local/bin/cmux-dev` to that CLI
|
||||
|
||||
- Implementation: `vendor/bonsplit/Sources/Bonsplit/Public/DebugEventLog.swift`
|
||||
- Free function `dlog("message")` — logs with timestamp and appends to file in real time
|
||||
|
|
|
|||
|
|
@ -1758,7 +1758,7 @@ struct CMUXCLI {
|
|||
let shellFeaturesValue = scopedGhosttyShellFeaturesValue()
|
||||
let sshStartupCommand = buildSSHStartupCommand(sshCommand: sshCommand, shellFeatures: shellFeaturesValue)
|
||||
|
||||
var workspaceCreateParams: [String: Any] = [
|
||||
let workspaceCreateParams: [String: Any] = [
|
||||
"initial_command": sshStartupCommand,
|
||||
]
|
||||
|
||||
|
|
@ -1775,6 +1775,8 @@ struct CMUXCLI {
|
|||
])
|
||||
}
|
||||
|
||||
let remoteSSHOptions = sshOptionsWithControlSocketDefaults(sshOptions.sshOptions)
|
||||
|
||||
var configureParams: [String: Any] = [
|
||||
"workspace_id": workspaceId,
|
||||
"destination": sshOptions.destination,
|
||||
|
|
@ -1787,8 +1789,8 @@ struct CMUXCLI {
|
|||
!identityFile.isEmpty {
|
||||
configureParams["identity_file"] = identityFile
|
||||
}
|
||||
if !sshOptions.sshOptions.isEmpty {
|
||||
configureParams["ssh_options"] = sshOptions.sshOptions
|
||||
if !remoteSSHOptions.isEmpty {
|
||||
configureParams["ssh_options"] = remoteSSHOptions
|
||||
}
|
||||
|
||||
var payload = try client.sendV2(method: "workspace.remote.configure", params: configureParams)
|
||||
|
|
@ -1888,15 +1890,10 @@ struct CMUXCLI {
|
|||
}
|
||||
|
||||
private func buildSSHCommandText(_ options: SSHCommandOptions) -> String {
|
||||
var parts: [String] = ["ssh", "-o", "StrictHostKeyChecking=accept-new"]
|
||||
if !hasSSHOptionKey(options.sshOptions, key: "ControlMaster") {
|
||||
parts += ["-o", "ControlMaster=auto"]
|
||||
}
|
||||
if !hasSSHOptionKey(options.sshOptions, key: "ControlPersist") {
|
||||
parts += ["-o", "ControlPersist=600"]
|
||||
}
|
||||
if !hasSSHOptionKey(options.sshOptions, key: "ControlPath") {
|
||||
parts += ["-o", "ControlPath=\(defaultSSHControlPathTemplate())"]
|
||||
let effectiveSSHOptions = sshOptionsWithControlSocketDefaults(options.sshOptions)
|
||||
var parts: [String] = ["ssh"]
|
||||
if !hasSSHOptionKey(effectiveSSHOptions, key: "StrictHostKeyChecking") {
|
||||
parts += ["-o", "StrictHostKeyChecking=accept-new"]
|
||||
}
|
||||
if let port = options.port {
|
||||
parts += ["-p", String(port)]
|
||||
|
|
@ -1905,16 +1902,33 @@ struct CMUXCLI {
|
|||
!identityFile.isEmpty {
|
||||
parts += ["-i", identityFile]
|
||||
}
|
||||
for option in options.sshOptions {
|
||||
let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmed.isEmpty else { continue }
|
||||
parts += ["-o", trimmed]
|
||||
for option in effectiveSSHOptions {
|
||||
parts += ["-o", option]
|
||||
}
|
||||
parts.append(options.destination)
|
||||
parts.append(contentsOf: options.extraArguments)
|
||||
return parts.map(shellQuote).joined(separator: " ")
|
||||
}
|
||||
|
||||
private func sshOptionsWithControlSocketDefaults(_ options: [String]) -> [String] {
|
||||
var merged: [String] = []
|
||||
for option in options {
|
||||
let trimmed = option.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmed.isEmpty else { continue }
|
||||
merged.append(trimmed)
|
||||
}
|
||||
if !hasSSHOptionKey(merged, key: "ControlMaster") {
|
||||
merged.append("ControlMaster=auto")
|
||||
}
|
||||
if !hasSSHOptionKey(merged, key: "ControlPersist") {
|
||||
merged.append("ControlPersist=600")
|
||||
}
|
||||
if !hasSSHOptionKey(merged, key: "ControlPath") {
|
||||
merged.append("ControlPath=\(defaultSSHControlPathTemplate())")
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
private func scopedGhosttyShellFeaturesValue() -> String {
|
||||
let rawExisting = ProcessInfo.processInfo.environment["GHOSTTY_SHELL_FEATURES"] ?? ""
|
||||
var seen: Set<String> = []
|
||||
|
|
@ -3317,7 +3331,7 @@ fi
|
|||
Usage: cmux ssh <destination> [flags] [-- <remote-command-args>]
|
||||
|
||||
Create a new workspace, mark it as remote-SSH, and start an SSH session in that workspace.
|
||||
cmux will also attempt background remote port detection + local forwarding for browser access.
|
||||
cmux will also establish a local SSH proxy endpoint so browser traffic can egress from the remote host.
|
||||
|
||||
Flags:
|
||||
--name <title> Optional workspace title
|
||||
|
|
|
|||
|
|
@ -1437,7 +1437,8 @@ final class TerminalSurface: Identifiable, ObservableObject {
|
|||
}
|
||||
}
|
||||
|
||||
if !env.isEmpty {
|
||||
let allowSurfaceEnvOverrides = false
|
||||
if allowSurfaceEnvOverrides, !env.isEmpty {
|
||||
envVars.reserveCapacity(env.count)
|
||||
envStorage.reserveCapacity(env.count)
|
||||
for (key, value) in env {
|
||||
|
|
@ -1592,6 +1593,7 @@ final class TerminalSurface: Identifiable, ObservableObject {
|
|||
}
|
||||
#endif
|
||||
guard let view = attachedView,
|
||||
let surface,
|
||||
view.window != nil,
|
||||
view.bounds.width > 0,
|
||||
view.bounds.height > 0 else {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,12 @@ import Combine
|
|||
import WebKit
|
||||
import AppKit
|
||||
import Bonsplit
|
||||
import Network
|
||||
|
||||
struct BrowserProxyEndpoint: Equatable {
|
||||
let host: String
|
||||
let port: Int
|
||||
}
|
||||
|
||||
enum BrowserSearchEngine: String, CaseIterable, Identifiable {
|
||||
case google
|
||||
|
|
@ -1070,6 +1076,7 @@ final class BrowserPanel: Panel, ObservableObject {
|
|||
private var developerToolsRestoreRetryAttempt: Int = 0
|
||||
private let developerToolsRestoreRetryDelay: TimeInterval = 0.05
|
||||
private let developerToolsRestoreRetryMaxAttempts: Int = 40
|
||||
private var remoteProxyEndpoint: BrowserProxyEndpoint?
|
||||
|
||||
var displayTitle: String {
|
||||
if !pageTitle.isEmpty {
|
||||
|
|
@ -1089,17 +1096,27 @@ final class BrowserPanel: Panel, ObservableObject {
|
|||
false
|
||||
}
|
||||
|
||||
init(workspaceId: UUID, initialURL: URL? = nil, bypassInsecureHTTPHostOnce: String? = nil) {
|
||||
init(
|
||||
workspaceId: UUID,
|
||||
initialURL: URL? = nil,
|
||||
bypassInsecureHTTPHostOnce: String? = nil,
|
||||
proxyEndpoint: BrowserProxyEndpoint? = nil
|
||||
) {
|
||||
self.id = UUID()
|
||||
self.workspaceId = workspaceId
|
||||
self.insecureHTTPBypassHostOnce = BrowserInsecureHTTPSettings.normalizeHost(bypassInsecureHTTPHostOnce ?? "")
|
||||
self.remoteProxyEndpoint = proxyEndpoint
|
||||
|
||||
// Configure web view
|
||||
let config = WKWebViewConfiguration()
|
||||
config.processPool = BrowserPanel.sharedProcessPool
|
||||
// Ensure browser cookies/storage persist across navigations and launches.
|
||||
// This reduces repeated consent/bot-challenge flows on sites like Google.
|
||||
config.websiteDataStore = .default()
|
||||
// Keep data-store scoping at workspace granularity so remote proxy settings
|
||||
// do not leak into local workspaces.
|
||||
if #available(macOS 14.0, *) {
|
||||
config.websiteDataStore = WKWebsiteDataStore(forIdentifier: workspaceId)
|
||||
} else {
|
||||
config.websiteDataStore = .default()
|
||||
}
|
||||
|
||||
// Enable developer extras (DevTools)
|
||||
config.preferences.setValue(true, forKey: "developerExtrasEnabled")
|
||||
|
|
@ -1124,6 +1141,7 @@ final class BrowserPanel: Panel, ObservableObject {
|
|||
webView.customUserAgent = BrowserUserAgentSettings.safariUserAgent
|
||||
|
||||
self.webView = webView
|
||||
applyRemoteProxyConfigurationIfAvailable()
|
||||
|
||||
// Set up navigation delegate
|
||||
let navDelegate = BrowserNavigationDelegate()
|
||||
|
|
@ -1180,6 +1198,33 @@ final class BrowserPanel: Panel, ObservableObject {
|
|||
workspaceId = newWorkspaceId
|
||||
}
|
||||
|
||||
func setRemoteProxyEndpoint(_ endpoint: BrowserProxyEndpoint?) {
|
||||
guard remoteProxyEndpoint != endpoint else { return }
|
||||
remoteProxyEndpoint = endpoint
|
||||
applyRemoteProxyConfigurationIfAvailable()
|
||||
}
|
||||
|
||||
private func applyRemoteProxyConfigurationIfAvailable() {
|
||||
guard #available(macOS 14.0, *) else { return }
|
||||
|
||||
let store = webView.configuration.websiteDataStore
|
||||
guard let endpoint = remoteProxyEndpoint,
|
||||
endpoint.port > 0 && endpoint.port <= 65535,
|
||||
let nwPort = NWEndpoint.Port(rawValue: UInt16(endpoint.port)) else {
|
||||
store.proxyConfigurations = []
|
||||
return
|
||||
}
|
||||
|
||||
let nwEndpoint = NWEndpoint.hostPort(
|
||||
host: NWEndpoint.Host(endpoint.host),
|
||||
port: nwPort
|
||||
)
|
||||
// Prefer SOCKSv5; keep CONNECT configured as fallback.
|
||||
let socks = ProxyConfiguration(socksv5Proxy: nwEndpoint)
|
||||
let connect = ProxyConfiguration(httpCONNECTProxy: nwEndpoint)
|
||||
store.proxyConfigurations = [socks, connect]
|
||||
}
|
||||
|
||||
func triggerFlash() {
|
||||
focusFlashToken &+= 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1497,6 +1497,40 @@ class TerminalController {
|
|||
return nil
|
||||
}
|
||||
|
||||
private func v2HasNonNullParam(_ params: [String: Any], _ key: String) -> Bool {
|
||||
guard let raw = params[key] else { return false }
|
||||
return !(raw is NSNull)
|
||||
}
|
||||
|
||||
private func v2StrictInt(_ params: [String: Any], _ key: String) -> Int? {
|
||||
v2StrictIntAny(params[key])
|
||||
}
|
||||
|
||||
private func v2StrictIntAny(_ raw: Any?) -> Int? {
|
||||
guard let raw else { return nil }
|
||||
|
||||
if let numberValue = raw as? NSNumber {
|
||||
if CFGetTypeID(numberValue) == CFBooleanGetTypeID() {
|
||||
return nil
|
||||
}
|
||||
let doubleValue = numberValue.doubleValue
|
||||
guard doubleValue.isFinite, floor(doubleValue) == doubleValue else {
|
||||
return nil
|
||||
}
|
||||
return Int(exactly: doubleValue)
|
||||
}
|
||||
|
||||
if let intValue = raw as? Int {
|
||||
return intValue
|
||||
}
|
||||
|
||||
if let stringValue = raw as? String {
|
||||
return Int(stringValue.trimmingCharacters(in: .whitespacesAndNewlines))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
private func v2PanelType(_ params: [String: Any], _ key: String) -> PanelType? {
|
||||
guard let s = v2String(params, key) else { return nil }
|
||||
return PanelType(rawValue: s.lowercased())
|
||||
|
|
@ -1976,13 +2010,26 @@ class TerminalController {
|
|||
}
|
||||
|
||||
var sshPort: Int?
|
||||
if let parsedPort = v2Int(params, "port") {
|
||||
guard parsedPort > 0 && parsedPort <= 65535 else {
|
||||
if v2HasNonNullParam(params, "port") {
|
||||
guard let parsedPort = v2StrictInt(params, "port"),
|
||||
parsedPort > 0,
|
||||
parsedPort <= 65535 else {
|
||||
return .err(code: "invalid_params", message: "port must be 1-65535", data: nil)
|
||||
}
|
||||
sshPort = parsedPort
|
||||
}
|
||||
|
||||
// Internal deterministic test hook: pin the local proxy listener port to force bind conflicts.
|
||||
var localProxyPort: Int?
|
||||
if v2HasNonNullParam(params, "local_proxy_port") {
|
||||
guard let parsedLocalProxyPort = v2StrictInt(params, "local_proxy_port"),
|
||||
parsedLocalProxyPort > 0,
|
||||
parsedLocalProxyPort <= 65535 else {
|
||||
return .err(code: "invalid_params", message: "local_proxy_port must be 1-65535", data: nil)
|
||||
}
|
||||
localProxyPort = parsedLocalProxyPort
|
||||
}
|
||||
|
||||
let identityFile = v2RawString(params, "identity_file")?.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let sshOptions = v2StringArray(params, "ssh_options") ?? []
|
||||
let autoConnect = v2Bool(params, "auto_connect") ?? true
|
||||
|
|
@ -2002,7 +2049,8 @@ class TerminalController {
|
|||
destination: destination,
|
||||
port: sshPort,
|
||||
identityFile: identityFile?.isEmpty == true ? nil : identityFile,
|
||||
sshOptions: sshOptions
|
||||
sshOptions: sshOptions,
|
||||
localProxyPort: localProxyPort
|
||||
)
|
||||
workspace.configureRemoteConnection(config, autoConnect: autoConnect)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -9,8 +9,25 @@ Current commands:
|
|||
Current RPC methods (newline-delimited JSON):
|
||||
1. `hello`
|
||||
2. `ping`
|
||||
3. `proxy.open`
|
||||
4. `proxy.close`
|
||||
5. `proxy.write`
|
||||
6. `proxy.read`
|
||||
7. `session.open`
|
||||
8. `session.close`
|
||||
9. `session.attach`
|
||||
10. `session.resize`
|
||||
11. `session.detach`
|
||||
12. `session.status`
|
||||
|
||||
Current integration in cmux:
|
||||
1. `workspace.remote.configure` now bootstraps this binary over SSH when missing.
|
||||
2. Client sends `hello` before enabling remote port probing/forwarding.
|
||||
3. Daemon status/capabilities are exposed in `workspace.remote.status -> remote.daemon`.
|
||||
2. Client sends `hello` before enabling remote proxy transport.
|
||||
3. Local workspace proxy broker serves SOCKS5 + HTTP CONNECT and tunnels stream traffic through `proxy.*` RPC over `serve --stdio`.
|
||||
4. Daemon status/capabilities are exposed in `workspace.remote.status -> remote.daemon` (including `session.resize.min`).
|
||||
|
||||
`workspace.remote.configure` contract notes:
|
||||
1. `port` / `local_proxy_port` accept integer values and numeric strings; explicit `null` clears each field.
|
||||
2. Out-of-range values and invalid types return `invalid_params`.
|
||||
3. `local_proxy_port` is an internal deterministic test hook used by bind-conflict regressions.
|
||||
4. SSH option precedence checks are case-insensitive; user overrides for `StrictHostKeyChecking` and control-socket keys prevent default injection.
|
||||
|
|
|
|||
|
|
@ -2,11 +2,17 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var version = "dev"
|
||||
|
|
@ -29,6 +35,28 @@ type rpcResponse struct {
|
|||
Error *rpcError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type rpcServer struct {
|
||||
mu sync.Mutex
|
||||
nextStreamID uint64
|
||||
nextSessionID uint64
|
||||
streams map[string]net.Conn
|
||||
sessions map[string]*sessionState
|
||||
}
|
||||
|
||||
type sessionAttachment struct {
|
||||
Cols int
|
||||
Rows int
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type sessionState struct {
|
||||
attachments map[string]sessionAttachment
|
||||
effectiveCols int
|
||||
effectiveRows int
|
||||
lastKnownCols int
|
||||
lastKnownRows int
|
||||
}
|
||||
|
||||
func main() {
|
||||
os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr))
|
||||
}
|
||||
|
|
@ -72,7 +100,16 @@ func usage(w io.Writer) {
|
|||
}
|
||||
|
||||
func runStdioServer(stdin io.Reader, stdout io.Writer) error {
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
nextSessionID: 1,
|
||||
streams: map[string]net.Conn{},
|
||||
sessions: map[string]*sessionState{},
|
||||
}
|
||||
defer server.closeAll()
|
||||
|
||||
scanner := bufio.NewScanner(stdin)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
writer := bufio.NewWriter(stdout)
|
||||
defer writer.Flush()
|
||||
|
||||
|
|
@ -96,7 +133,7 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error {
|
|||
continue
|
||||
}
|
||||
|
||||
resp := handleRequest(req)
|
||||
resp := server.handleRequest(req)
|
||||
if err := writeResponse(writer, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -122,7 +159,7 @@ func writeResponse(w *bufio.Writer, resp rpcResponse) error {
|
|||
return w.Flush()
|
||||
}
|
||||
|
||||
func handleRequest(req rpcRequest) rpcResponse {
|
||||
func (s *rpcServer) handleRequest(req rpcRequest) rpcResponse {
|
||||
if req.Method == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
|
|
@ -144,8 +181,10 @@ func handleRequest(req rpcRequest) rpcResponse {
|
|||
"version": version,
|
||||
"capabilities": []string{
|
||||
"session.basic",
|
||||
"session.resize.min",
|
||||
"proxy.http_connect",
|
||||
"proxy.socks5",
|
||||
"proxy.stream",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -157,6 +196,26 @@ func handleRequest(req rpcRequest) rpcResponse {
|
|||
"pong": true,
|
||||
},
|
||||
}
|
||||
case "proxy.open":
|
||||
return s.handleProxyOpen(req)
|
||||
case "proxy.close":
|
||||
return s.handleProxyClose(req)
|
||||
case "proxy.write":
|
||||
return s.handleProxyWrite(req)
|
||||
case "proxy.read":
|
||||
return s.handleProxyRead(req)
|
||||
case "session.open":
|
||||
return s.handleSessionOpen(req)
|
||||
case "session.close":
|
||||
return s.handleSessionClose(req)
|
||||
case "session.attach":
|
||||
return s.handleSessionAttach(req)
|
||||
case "session.resize":
|
||||
return s.handleSessionResize(req)
|
||||
case "session.detach":
|
||||
return s.handleSessionDetach(req)
|
||||
case "session.status":
|
||||
return s.handleSessionStatus(req)
|
||||
default:
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
|
|
@ -168,3 +227,704 @@ func handleRequest(req rpcRequest) rpcResponse {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleProxyOpen(req rpcRequest) rpcResponse {
|
||||
host, ok := getStringParam(req.Params, "host")
|
||||
if !ok || host == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.open requires host",
|
||||
},
|
||||
}
|
||||
}
|
||||
port, ok := getIntParam(req.Params, "port")
|
||||
if !ok || port <= 0 || port > 65535 {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.open requires port in range 1-65535",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
timeoutMs := 10000
|
||||
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 {
|
||||
timeoutMs = parsed
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout(
|
||||
"tcp",
|
||||
net.JoinHostPort(host, strconv.Itoa(port)),
|
||||
time.Duration(timeoutMs)*time.Millisecond,
|
||||
)
|
||||
if err != nil {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "open_failed",
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
streamID := fmt.Sprintf("s-%d", s.nextStreamID)
|
||||
s.nextStreamID++
|
||||
s.streams[streamID] = conn
|
||||
s.mu.Unlock()
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"stream_id": streamID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleProxyClose(req rpcRequest) rpcResponse {
|
||||
streamID, ok := getStringParam(req.Params, "stream_id")
|
||||
if !ok || streamID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.close requires stream_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
conn, exists := s.streams[streamID]
|
||||
if exists {
|
||||
delete(s.streams, streamID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "stream not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"closed": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleProxyWrite(req rpcRequest) rpcResponse {
|
||||
streamID, ok := getStringParam(req.Params, "stream_id")
|
||||
if !ok || streamID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.write requires stream_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
dataBase64, ok := getStringParam(req.Params, "data_base64")
|
||||
if !ok {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.write requires data_base64",
|
||||
},
|
||||
}
|
||||
}
|
||||
payload, err := base64.StdEncoding.DecodeString(dataBase64)
|
||||
if err != nil {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "data_base64 must be valid base64",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
conn, found := s.getStream(streamID)
|
||||
if !found {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "stream not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
total := 0
|
||||
for total < len(payload) {
|
||||
written, writeErr := conn.Write(payload[total:])
|
||||
total += written
|
||||
if writeErr != nil {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "stream_error",
|
||||
Message: writeErr.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"written": total,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleProxyRead(req rpcRequest) rpcResponse {
|
||||
streamID, ok := getStringParam(req.Params, "stream_id")
|
||||
if !ok || streamID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "proxy.read requires stream_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
maxBytes := 32768
|
||||
if parsed, hasMax := getIntParam(req.Params, "max_bytes"); hasMax {
|
||||
maxBytes = parsed
|
||||
}
|
||||
if maxBytes <= 0 || maxBytes > 262144 {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "max_bytes must be in range 1-262144",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
timeoutMs := 50
|
||||
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 {
|
||||
timeoutMs = parsed
|
||||
}
|
||||
|
||||
conn, found := s.getStream(streamID)
|
||||
if !found {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "stream not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond))
|
||||
buffer := make([]byte, maxBytes)
|
||||
n, readErr := conn.Read(buffer)
|
||||
data := buffer[:max(0, n)]
|
||||
|
||||
if readErr != nil {
|
||||
if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"data_base64": "",
|
||||
"eof": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
if readErr == io.EOF {
|
||||
s.dropStream(streamID)
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"data_base64": base64.StdEncoding.EncodeToString(data),
|
||||
"eof": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "stream_error",
|
||||
Message: readErr.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"data_base64": base64.StdEncoding.EncodeToString(data),
|
||||
"eof": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionOpen(req rpcRequest) rpcResponse {
|
||||
sessionID, _ := getStringParam(req.Params, "session_id")
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if sessionID == "" {
|
||||
sessionID = fmt.Sprintf("sess-%d", s.nextSessionID)
|
||||
s.nextSessionID++
|
||||
}
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
session = &sessionState{
|
||||
attachments: map[string]sessionAttachment{},
|
||||
}
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: sessionSnapshot(sessionID, session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionClose(req rpcRequest) rpcResponse {
|
||||
sessionID, ok := getStringParam(req.Params, "session_id")
|
||||
if !ok || sessionID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "session.close requires session_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
_, exists := s.sessions[sessionID]
|
||||
if exists {
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "session not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: map[string]any{
|
||||
"session_id": sessionID,
|
||||
"closed": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionAttach(req rpcRequest) rpcResponse {
|
||||
sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.attach")
|
||||
if badResp != nil {
|
||||
return *badResp
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "session not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
session.attachments[attachmentID] = sessionAttachment{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
recomputeSessionSize(session)
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: sessionSnapshot(sessionID, session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionResize(req rpcRequest) rpcResponse {
|
||||
sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.resize")
|
||||
if badResp != nil {
|
||||
return *badResp
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "session not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
if _, exists := session.attachments[attachmentID]; !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "attachment not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
session.attachments[attachmentID] = sessionAttachment{
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
recomputeSessionSize(session)
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: sessionSnapshot(sessionID, session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionDetach(req rpcRequest) rpcResponse {
|
||||
sessionID, ok := getStringParam(req.Params, "session_id")
|
||||
if !ok || sessionID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "session.detach requires session_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
attachmentID, ok := getStringParam(req.Params, "attachment_id")
|
||||
if !ok || attachmentID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "session.detach requires attachment_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "session not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
if _, exists := session.attachments[attachmentID]; !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "attachment not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
delete(session.attachments, attachmentID)
|
||||
recomputeSessionSize(session)
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: sessionSnapshot(sessionID, session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) handleSessionStatus(req rpcRequest) rpcResponse {
|
||||
sessionID, ok := getStringParam(req.Params, "session_id")
|
||||
if !ok || sessionID == "" {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: "session.status requires session_id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "not_found",
|
||||
Message: "session not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: true,
|
||||
Result: sessionSnapshot(sessionID, session),
|
||||
}
|
||||
}
|
||||
|
||||
func parseSessionAttachmentParams(req rpcRequest, method string) (sessionID string, attachmentID string, cols int, rows int, badResp *rpcResponse) {
|
||||
sessionID, ok := getStringParam(req.Params, "session_id")
|
||||
if !ok || sessionID == "" {
|
||||
resp := rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: method + " requires session_id",
|
||||
},
|
||||
}
|
||||
return "", "", 0, 0, &resp
|
||||
}
|
||||
attachmentID, ok = getStringParam(req.Params, "attachment_id")
|
||||
if !ok || attachmentID == "" {
|
||||
resp := rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: method + " requires attachment_id",
|
||||
},
|
||||
}
|
||||
return "", "", 0, 0, &resp
|
||||
}
|
||||
|
||||
cols, ok = getIntParam(req.Params, "cols")
|
||||
if !ok || cols <= 0 {
|
||||
resp := rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: method + " requires cols > 0",
|
||||
},
|
||||
}
|
||||
return "", "", 0, 0, &resp
|
||||
}
|
||||
rows, ok = getIntParam(req.Params, "rows")
|
||||
if !ok || rows <= 0 {
|
||||
resp := rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_params",
|
||||
Message: method + " requires rows > 0",
|
||||
},
|
||||
}
|
||||
return "", "", 0, 0, &resp
|
||||
}
|
||||
|
||||
return sessionID, attachmentID, cols, rows, nil
|
||||
}
|
||||
|
||||
func recomputeSessionSize(session *sessionState) {
|
||||
if len(session.attachments) == 0 {
|
||||
session.effectiveCols = session.lastKnownCols
|
||||
session.effectiveRows = session.lastKnownRows
|
||||
return
|
||||
}
|
||||
|
||||
minCols := 0
|
||||
minRows := 0
|
||||
for _, attachment := range session.attachments {
|
||||
if minCols == 0 || attachment.Cols < minCols {
|
||||
minCols = attachment.Cols
|
||||
}
|
||||
if minRows == 0 || attachment.Rows < minRows {
|
||||
minRows = attachment.Rows
|
||||
}
|
||||
}
|
||||
|
||||
session.effectiveCols = minCols
|
||||
session.effectiveRows = minRows
|
||||
session.lastKnownCols = minCols
|
||||
session.lastKnownRows = minRows
|
||||
}
|
||||
|
||||
func sessionSnapshot(sessionID string, session *sessionState) map[string]any {
|
||||
attachmentIDs := make([]string, 0, len(session.attachments))
|
||||
for attachmentID := range session.attachments {
|
||||
attachmentIDs = append(attachmentIDs, attachmentID)
|
||||
}
|
||||
sort.Strings(attachmentIDs)
|
||||
|
||||
attachments := make([]map[string]any, 0, len(attachmentIDs))
|
||||
for _, attachmentID := range attachmentIDs {
|
||||
attachment := session.attachments[attachmentID]
|
||||
attachments = append(attachments, map[string]any{
|
||||
"attachment_id": attachmentID,
|
||||
"cols": attachment.Cols,
|
||||
"rows": attachment.Rows,
|
||||
"updated_at": attachment.UpdatedAt.Format(time.RFC3339Nano),
|
||||
})
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"session_id": sessionID,
|
||||
"attachments": attachments,
|
||||
"effective_cols": session.effectiveCols,
|
||||
"effective_rows": session.effectiveRows,
|
||||
"last_known_cols": session.lastKnownCols,
|
||||
"last_known_rows": session.lastKnownRows,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) getStream(streamID string) (net.Conn, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
conn, ok := s.streams[streamID]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
func (s *rpcServer) dropStream(streamID string) {
|
||||
s.mu.Lock()
|
||||
conn, ok := s.streams[streamID]
|
||||
if ok {
|
||||
delete(s.streams, streamID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
if ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) closeAll() {
|
||||
s.mu.Lock()
|
||||
streams := make([]net.Conn, 0, len(s.streams))
|
||||
for id, conn := range s.streams {
|
||||
delete(s.streams, id)
|
||||
streams = append(streams, conn)
|
||||
}
|
||||
for id := range s.sessions {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
for _, conn := range streams {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func getStringParam(params map[string]any, key string) (string, bool) {
|
||||
if params == nil {
|
||||
return "", false
|
||||
}
|
||||
raw, ok := params[key]
|
||||
if !ok || raw == nil {
|
||||
return "", false
|
||||
}
|
||||
value, ok := raw.(string)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func getIntParam(params map[string]any, key string) (int, bool) {
|
||||
if params == nil {
|
||||
return 0, false
|
||||
}
|
||||
raw, ok := params[key]
|
||||
if !ok || raw == nil {
|
||||
return 0, false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case int:
|
||||
return value, true
|
||||
case int8:
|
||||
return int(value), true
|
||||
case int16:
|
||||
return int(value), true
|
||||
case int32:
|
||||
return int(value), true
|
||||
case int64:
|
||||
return int(value), true
|
||||
case uint:
|
||||
return int(value), true
|
||||
case uint8:
|
||||
return int(value), true
|
||||
case uint16:
|
||||
return int(value), true
|
||||
case uint32:
|
||||
return int(value), true
|
||||
case uint64:
|
||||
return int(value), true
|
||||
case float64:
|
||||
return int(value), true
|
||||
case json.Number:
|
||||
n, err := value.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(n), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunVersion(t *testing.T) {
|
||||
|
|
@ -99,3 +103,371 @@ func TestRunStdioInvalidJSONAndUnknownMethod(t *testing.T) {
|
|||
t.Fatalf("unknown method should return method_not_found; got=%v payload=%v", got, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStdioSessionResizeFlow(t *testing.T) {
|
||||
input := strings.NewReader(
|
||||
`{"id":1,"method":"session.open","params":{"session_id":"sess-stdio"}}` + "\n" +
|
||||
`{"id":2,"method":"session.attach","params":{"session_id":"sess-stdio","attachment_id":"a1","cols":120,"rows":40}}` + "\n" +
|
||||
`{"id":3,"method":"session.attach","params":{"session_id":"sess-stdio","attachment_id":"a2","cols":90,"rows":30}}` + "\n" +
|
||||
`{"id":4,"method":"session.status","params":{"session_id":"sess-stdio"}}` + "\n",
|
||||
)
|
||||
var out bytes.Buffer
|
||||
code := run([]string{"serve", "--stdio"}, input, &out, &bytes.Buffer{})
|
||||
if code != 0 {
|
||||
t.Fatalf("run serve exit code = %d, want 0", code)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
|
||||
if len(lines) != 4 {
|
||||
t.Fatalf("got %d response lines, want 4: %q", len(lines), out.String())
|
||||
}
|
||||
|
||||
var status map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[3]), &status); err != nil {
|
||||
t.Fatalf("failed to decode status response: %v", err)
|
||||
}
|
||||
if ok, _ := status["ok"].(bool); !ok {
|
||||
t.Fatalf("session.status should be ok=true: %v", status)
|
||||
}
|
||||
result, _ := status["result"].(map[string]any)
|
||||
if result == nil {
|
||||
t.Fatalf("session.status missing result object: %v", status)
|
||||
}
|
||||
effectiveCols, _ := result["effective_cols"].(float64)
|
||||
effectiveRows, _ := result["effective_rows"].(float64)
|
||||
if int(effectiveCols) != 90 || int(effectiveRows) != 30 {
|
||||
t.Fatalf("session smallest-wins effective size mismatch: got=%vx%v payload=%v", effectiveCols, effectiveRows, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyStreamRoundTrip(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buffer := make([]byte, 8)
|
||||
n, readErr := conn.Read(buffer)
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
if string(buffer[:n]) != "ping" {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("pong"))
|
||||
}()
|
||||
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
nextSessionID: 1,
|
||||
streams: map[string]net.Conn{},
|
||||
sessions: map[string]*sessionState{},
|
||||
}
|
||||
defer server.closeAll()
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
openResp := server.handleRequest(rpcRequest{
|
||||
ID: 1,
|
||||
Method: "proxy.open",
|
||||
Params: map[string]any{
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"timeout_ms": 1000,
|
||||
},
|
||||
})
|
||||
if !openResp.OK {
|
||||
t.Fatalf("proxy.open failed: %+v", openResp)
|
||||
}
|
||||
openResult, _ := openResp.Result.(map[string]any)
|
||||
streamID, _ := openResult["stream_id"].(string)
|
||||
if streamID == "" {
|
||||
t.Fatalf("proxy.open missing stream_id: %+v", openResp)
|
||||
}
|
||||
|
||||
writeResp := server.handleRequest(rpcRequest{
|
||||
ID: 2,
|
||||
Method: "proxy.write",
|
||||
Params: map[string]any{
|
||||
"stream_id": streamID,
|
||||
"data_base64": base64.StdEncoding.EncodeToString([]byte("ping")),
|
||||
},
|
||||
})
|
||||
if !writeResp.OK {
|
||||
t.Fatalf("proxy.write failed: %+v", writeResp)
|
||||
}
|
||||
|
||||
readResp := server.handleRequest(rpcRequest{
|
||||
ID: 3,
|
||||
Method: "proxy.read",
|
||||
Params: map[string]any{
|
||||
"stream_id": streamID,
|
||||
"max_bytes": 8,
|
||||
"timeout_ms": 1000,
|
||||
},
|
||||
})
|
||||
if !readResp.OK {
|
||||
t.Fatalf("proxy.read failed: %+v", readResp)
|
||||
}
|
||||
readResult, _ := readResp.Result.(map[string]any)
|
||||
dataBase64, _ := readResult["data_base64"].(string)
|
||||
data, decodeErr := base64.StdEncoding.DecodeString(dataBase64)
|
||||
if decodeErr != nil {
|
||||
t.Fatalf("proxy.read returned invalid base64: %v", decodeErr)
|
||||
}
|
||||
if string(data) != "pong" {
|
||||
t.Fatalf("proxy.read payload=%q, want %q", string(data), "pong")
|
||||
}
|
||||
|
||||
closeResp := server.handleRequest(rpcRequest{
|
||||
ID: 4,
|
||||
Method: "proxy.close",
|
||||
Params: map[string]any{
|
||||
"stream_id": streamID,
|
||||
},
|
||||
})
|
||||
if !closeResp.OK {
|
||||
t.Fatalf("proxy.close failed: %+v", closeResp)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("proxy test server goroutine did not finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyOpenInvalidParams(t *testing.T) {
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
nextSessionID: 1,
|
||||
streams: map[string]net.Conn{},
|
||||
sessions: map[string]*sessionState{},
|
||||
}
|
||||
defer server.closeAll()
|
||||
|
||||
resp := server.handleRequest(rpcRequest{
|
||||
ID: 1,
|
||||
Method: "proxy.open",
|
||||
Params: map[string]any{
|
||||
"host": "127.0.0.1",
|
||||
"port": strconv.Itoa(8080),
|
||||
},
|
||||
})
|
||||
if resp.OK {
|
||||
t.Fatalf("proxy.open with invalid port type should fail: %+v", resp)
|
||||
}
|
||||
errObj, _ := resp.Error, resp.Error
|
||||
if errObj == nil || errObj.Code != "invalid_params" {
|
||||
t.Fatalf("proxy.open invalid params should return invalid_params: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionResizeCoordinator(t *testing.T) {
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
nextSessionID: 1,
|
||||
streams: map[string]net.Conn{},
|
||||
sessions: map[string]*sessionState{},
|
||||
}
|
||||
defer server.closeAll()
|
||||
|
||||
openResp := server.handleRequest(rpcRequest{
|
||||
ID: 1,
|
||||
Method: "session.open",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
},
|
||||
})
|
||||
if !openResp.OK {
|
||||
t.Fatalf("session.open failed: %+v", openResp)
|
||||
}
|
||||
|
||||
attachSmall := server.handleRequest(rpcRequest{
|
||||
ID: 2,
|
||||
Method: "session.attach",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-small",
|
||||
"cols": 90,
|
||||
"rows": 30,
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, attachSmall, 90, 30)
|
||||
|
||||
attachLarge := server.handleRequest(rpcRequest{
|
||||
ID: 3,
|
||||
Method: "session.attach",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-large",
|
||||
"cols": 120,
|
||||
"rows": 40,
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, attachLarge, 90, 30) // RZ-001: smallest wins
|
||||
|
||||
resizeLarge := server.handleRequest(rpcRequest{
|
||||
ID: 4,
|
||||
Method: "session.resize",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-large",
|
||||
"cols": 200,
|
||||
"rows": 60,
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, resizeLarge, 90, 30) // RZ-002: still bounded by smallest
|
||||
|
||||
detachSmall := server.handleRequest(rpcRequest{
|
||||
ID: 5,
|
||||
Method: "session.detach",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-small",
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, detachSmall, 200, 60) // RZ-003: expands to next smallest
|
||||
|
||||
detachLarge := server.handleRequest(rpcRequest{
|
||||
ID: 6,
|
||||
Method: "session.detach",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-large",
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, detachLarge, 200, 60) // no attachments: keep last-known size
|
||||
assertAttachmentCount(t, detachLarge, 0)
|
||||
|
||||
reattach := server.handleRequest(rpcRequest{
|
||||
ID: 7,
|
||||
Method: "session.attach",
|
||||
Params: map[string]any{
|
||||
"session_id": "sess-rz",
|
||||
"attachment_id": "a-reconnect",
|
||||
"cols": 110,
|
||||
"rows": 50,
|
||||
},
|
||||
})
|
||||
assertEffectiveSize(t, reattach, 110, 50) // RZ-004: recompute from active attachments on reattach
|
||||
}
|
||||
|
||||
func TestSessionInvalidParamsAndNotFound(t *testing.T) {
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
nextSessionID: 1,
|
||||
streams: map[string]net.Conn{},
|
||||
sessions: map[string]*sessionState{},
|
||||
}
|
||||
defer server.closeAll()
|
||||
|
||||
missingSession := server.handleRequest(rpcRequest{
|
||||
ID: 1,
|
||||
Method: "session.attach",
|
||||
Params: map[string]any{
|
||||
"session_id": "missing",
|
||||
"attachment_id": "a1",
|
||||
"cols": 80,
|
||||
"rows": 24,
|
||||
},
|
||||
})
|
||||
if missingSession.OK || missingSession.Error == nil || missingSession.Error.Code != "not_found" {
|
||||
t.Fatalf("session.attach on missing session should return not_found: %+v", missingSession)
|
||||
}
|
||||
|
||||
badSize := server.handleRequest(rpcRequest{
|
||||
ID: 2,
|
||||
Method: "session.attach",
|
||||
Params: map[string]any{
|
||||
"session_id": "missing",
|
||||
"attachment_id": "a1",
|
||||
"cols": 0,
|
||||
"rows": 24,
|
||||
},
|
||||
})
|
||||
if badSize.OK || badSize.Error == nil || badSize.Error.Code != "invalid_params" {
|
||||
t.Fatalf("session.attach with cols=0 should return invalid_params: %+v", badSize)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEffectiveSize(t *testing.T, resp rpcResponse, wantCols, wantRows int) {
|
||||
t.Helper()
|
||||
if !resp.OK {
|
||||
t.Fatalf("expected ok response, got error: %+v", resp)
|
||||
}
|
||||
result, ok := resp.Result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing result map: %+v", resp)
|
||||
}
|
||||
gotCols := asInt(t, result["effective_cols"], "effective_cols")
|
||||
gotRows := asInt(t, result["effective_rows"], "effective_rows")
|
||||
if gotCols != wantCols || gotRows != wantRows {
|
||||
t.Fatalf("effective size = %dx%d, want %dx%d payload=%+v", gotCols, gotRows, wantCols, wantRows, result)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAttachmentCount(t *testing.T, resp rpcResponse, want int) {
|
||||
t.Helper()
|
||||
if !resp.OK {
|
||||
t.Fatalf("expected ok response, got error: %+v", resp)
|
||||
}
|
||||
result, ok := resp.Result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing result map: %+v", resp)
|
||||
}
|
||||
attachments, ok := result["attachments"].([]map[string]any)
|
||||
if ok {
|
||||
if len(attachments) != want {
|
||||
t.Fatalf("attachments len = %d, want %d payload=%+v", len(attachments), want, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
attachmentsAny, ok := result["attachments"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("attachments field has unexpected type (%T) payload=%+v", result["attachments"], result)
|
||||
}
|
||||
if len(attachmentsAny) != want {
|
||||
t.Fatalf("attachments len = %d, want %d payload=%+v", len(attachmentsAny), want, result)
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(t *testing.T, value any, field string) int {
|
||||
t.Helper()
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int8:
|
||||
return int(typed)
|
||||
case int16:
|
||||
return int(typed)
|
||||
case int32:
|
||||
return int(typed)
|
||||
case int64:
|
||||
return int(typed)
|
||||
case uint:
|
||||
return int(typed)
|
||||
case uint8:
|
||||
return int(typed)
|
||||
case uint16:
|
||||
return int(typed)
|
||||
case uint32:
|
||||
return int(typed)
|
||||
case uint64:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
default:
|
||||
t.Fatalf("%s has unexpected type %T (%v)", field, value, value)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,35 +31,43 @@ This is a **living implementation spec** (also called an **execution spec**): a
|
|||
### 3.2 Bootstrap + Daemon
|
||||
- `DONE` local app probes remote platform, builds/uploads `cmuxd-remote`, and runs `serve --stdio`.
|
||||
- `DONE` daemon `hello` handshake is enforced.
|
||||
- `DONE` bootstrap/probe failures surface actionable details.
|
||||
- `DONE` daemon now exposes proxy stream RPC (`proxy.open`, `proxy.close`, `proxy.write`, `proxy.read`).
|
||||
- `DONE` local proxy broker now tunnels SOCKS5/CONNECT traffic over daemon stream RPC instead of `ssh -D`.
|
||||
- `DONE` daemon now exposes session resize-coordinator RPC (`session.open`, `session.attach`, `session.resize`, `session.detach`, `session.status`, `session.close`).
|
||||
- `DONE` transport-level proxy failures now escalate from broker retry to full daemon re-bootstrap/reconnect in the session controller.
|
||||
- `DONE` SOCKS handshake parsing now preserves pipelined post-connect payload bytes instead of dropping request-prefix bytes.
|
||||
- `DONE` `workspace.remote.configure.local_proxy_port` exists as an internal deterministic test hook for bind-conflict regression coverage.
|
||||
- `DONE` bootstrap/proxy failures surface actionable details.
|
||||
|
||||
### 3.3 Error Surfacing
|
||||
- `DONE` remote errors are surfaced in sidebar status + logs + notifications.
|
||||
- `DONE` reconnect retry count/time is included in surfaced error text (for example, `retry 1 in 4s`).
|
||||
|
||||
### 3.4 Existing Temporary Behavior (To Remove)
|
||||
- `TEMPORARY` current implementation probes remote listening ports and mirrors them locally with SSH `-L`.
|
||||
- `TEMPORARY` sidebar shows local bind conflicts (`SSH port conflicts ...`) caused by that mirroring path.
|
||||
- `TARGET` browser path must no longer depend on per-port mirroring.
|
||||
### 3.4 Removed Temporary Behavior
|
||||
- `DONE` removed remote listening-port probe loop and per-port SSH `-L` mirroring.
|
||||
- `DONE` remote browser routing now uses a single shared local proxy endpoint instead of detected-port mirroring.
|
||||
- `DONE` remote status now includes structured proxy metadata (`remote.proxy`) and `proxy_unavailable` error code when proxy setup fails.
|
||||
|
||||
## 4. Target Architecture (No Port Mirroring)
|
||||
|
||||
### 4.1 Browser Networking Path
|
||||
1. One local proxy endpoint per SSH transport (not per workspace, not per detected port).
|
||||
2. Proxy endpoint supports SOCKS5 and HTTP CONNECT.
|
||||
3. Browser panels in remote workspaces are auto-wired to this proxy endpoint.
|
||||
4. Browser panels in local workspaces are not force-proxied.
|
||||
1. `DONE` one local proxy endpoint is created per SSH transport/session key (not per detected port).
|
||||
2. `DONE` endpoint is provided by a local broker that supports SOCKS5 + HTTP CONNECT and tunnels via daemon stream RPC.
|
||||
3. `DONE` browser panels in remote workspaces are auto-wired to the workspace proxy endpoint.
|
||||
4. `DONE` browser panels in local workspaces are not force-proxied.
|
||||
5. `DONE` identical SSH transports share one endpoint via a transport-scoped broker.
|
||||
|
||||
### 4.2 WKWebView Wiring
|
||||
1. Use workspace/browser scoped `WKWebsiteDataStore.proxyConfigurations`.
|
||||
2. Prefer SOCKS5 proxy config.
|
||||
3. Keep HTTP CONNECT proxy config as fallback.
|
||||
4. Re-apply/validate proxy config after reconnect.
|
||||
1. `DONE` use workspace-scoped `WKWebsiteDataStore(forIdentifier:)`.
|
||||
2. `DONE` apply workspace/browser scoped `proxyConfigurations`.
|
||||
3. `DONE` prefer SOCKS5 proxy config.
|
||||
4. `DONE` keep HTTP CONNECT proxy config as fallback.
|
||||
5. `DONE` re-apply proxy config on reconnect/state updates.
|
||||
|
||||
### 4.3 Remote Daemon + Transport
|
||||
1. Extend `cmuxd-remote` beyond `hello/ping` with proxy stream RPC (`proxy.open`, `proxy.close`).
|
||||
2. Local side runs a transport-scoped proxy broker and multiplexes proxy streams over SSH stdio transport.
|
||||
3. Remove remote service-port discovery/probing from browser routing path.
|
||||
1. `DONE` `cmuxd-remote` now supports proxy stream RPC (`proxy.open`, `proxy.close`, `proxy.write`, `proxy.read`).
|
||||
2. `DONE` local side now runs a shared local broker that serves SOCKS5/CONNECT and tunnels each stream over persistent daemon stdio RPC.
|
||||
3. `DONE` removed remote service-port discovery/probing from browser routing path.
|
||||
|
||||
### 4.4 Explicit Non-Goal
|
||||
1. Automatic mirroring of every remote listening port to local loopback is not a goal for browser support.
|
||||
|
|
@ -96,15 +104,15 @@ Recompute effective size on:
|
|||
| ID | Milestone | Status | Notes |
|
||||
|---|---|---|---|
|
||||
| M-001 | `cmux ssh` workspace creation + metadata + optional `--name` | DONE | Covered by `tests_v2/test_ssh_remote_cli_metadata.py` |
|
||||
| M-002 | Remote bootstrap/upload/start + hello handshake | DONE | Current `cmuxd-remote` is minimal (`hello`, `ping`) |
|
||||
| M-002 | Remote bootstrap/upload/start + hello handshake | DONE | Includes daemon capability handshake + status surfacing |
|
||||
| M-003 | Reconnect/disconnect UX + API + improved error surfacing | DONE | Includes retry count in surfaced errors |
|
||||
| M-004 | Docker e2e for bootstrap/reconnect shell niceties | DONE | Existing docker tests currently validate mirroring-era path |
|
||||
| M-005 | Remove automatic remote port mirroring path | TODO | Delete probe/listen mirror loop from `WorkspaceRemoteSessionController` |
|
||||
| M-006 | Transport-scoped local proxy broker (SOCKS5 + CONNECT) | TODO | Local component in app/daemon layer |
|
||||
| M-007 | Remote proxy stream RPC in `cmuxd-remote` | TODO | Add `proxy.open/close` and multiplexed stream handling |
|
||||
| M-008 | WebView proxy auto-wiring for remote workspaces | TODO | Use `WKWebsiteDataStore.proxyConfigurations` |
|
||||
| M-009 | PTY resize coordinator (`smallest screen wins`) | TODO | Session-level attachment-size aggregation |
|
||||
| M-010 | Resize + proxy reconnect e2e test suites | TODO | Add dedicated docker cases for browser proxy + resize |
|
||||
| M-004 | Docker e2e for bootstrap/reconnect shell niceties | DONE | Docker suites validate proxy-path bootstrap and reconnect behavior |
|
||||
| M-005 | Remove automatic remote port mirroring path | DONE | `WorkspaceRemoteSessionController` now uses one shared daemon-backed proxy endpoint |
|
||||
| M-006 | Transport-scoped local proxy broker (SOCKS5 + CONNECT) | DONE | Identical SSH transports now reuse one local proxy endpoint |
|
||||
| M-007 | Remote proxy stream RPC in `cmuxd-remote` | DONE | `proxy.open/close/write/read` implemented |
|
||||
| M-008 | WebView proxy auto-wiring for remote workspaces | DONE | Workspace-scoped `WKWebsiteDataStore.proxyConfigurations` wiring is active |
|
||||
| M-009 | PTY resize coordinator (`smallest screen wins`) | DONE | Daemon session RPC now tracks attachments and applies min cols/rows semantics with unit tests |
|
||||
| M-010 | Resize + proxy reconnect e2e test suites | DONE | `tests_v2/test_ssh_remote_docker_forwarding.py` validates HTTP/websocket egress plus SOCKS pipelined-payload handling; `tests_v2/test_ssh_remote_docker_reconnect.py` verifies reconnect recovery and repeats SOCKS pipelined-payload checks after host restart; `tests_v2/test_ssh_remote_proxy_bind_conflict.py` validates structured `proxy_unavailable` bind-conflict surfacing and `local_proxy_port` status retention under bind conflict; `tests_v2/test_ssh_remote_daemon_resize_stdio.py` validates session resize semantics over real stdio RPC process boundaries; `tests_v2/test_ssh_remote_cli_metadata.py` validates `workspace.remote.configure` numeric-string compatibility, explicit `null` clear semantics (including `workspace.remote.status` reflection), strict `port`/`local_proxy_port` validation (bounds/type), case-insensitive SSH option override precedence for StrictHostKeyChecking/control-socket keys, and `local_proxy_port` payload echo for deterministic bind-conflict test hook behavior |
|
||||
|
||||
## 7. Acceptance Test Matrix (With Status)
|
||||
|
||||
|
|
@ -113,7 +121,7 @@ Recompute effective size on:
|
|||
| ID | Scenario | Status |
|
||||
|---|---|---|
|
||||
| T-001 | baseline remote connect | DONE |
|
||||
| T-002 | identical host reuse semantics | PARTIAL |
|
||||
| T-002 | identical host reuse semantics | DONE |
|
||||
| T-003 | no `--name` | DONE |
|
||||
| T-004 | reconnect API success/error paths | DONE |
|
||||
| T-005 | retry count visible in daemon error detail | DONE |
|
||||
|
|
@ -122,31 +130,51 @@ Recompute effective size on:
|
|||
|
||||
| ID | Scenario | Status |
|
||||
|---|---|---|
|
||||
| W-001 | remote workspace browser auto-proxied | TODO |
|
||||
| W-002 | browser egress IP equals remote host IP | TODO |
|
||||
| W-003 | websocket via SOCKS5/CONNECT through remote daemon | TODO |
|
||||
| W-004 | reconnect restores browser proxy path automatically | TODO |
|
||||
| W-005 | local proxy bind conflict yields structured `proxy_unavailable` | TODO |
|
||||
| W-001 | remote workspace browser auto-proxied | DONE |
|
||||
| W-002 | browser egress equals remote network path | DONE |
|
||||
| W-003 | websocket via SOCKS5/CONNECT through remote daemon | DONE |
|
||||
| W-004 | reconnect restores browser proxy path automatically | DONE |
|
||||
| W-005 | local proxy bind conflict yields structured `proxy_unavailable` | DONE |
|
||||
| W-006 | proxy transport failure triggers daemon re-bootstrap and recovers after host recreation | DONE |
|
||||
| W-007 | SOCKS greeting/connect + immediate pipelined payload in same write remains intact | DONE |
|
||||
|
||||
### 7.3 Resize
|
||||
|
||||
| ID | Scenario | Status |
|
||||
|---|---|---|
|
||||
| RZ-001 | two attachments, smallest wins | TODO |
|
||||
| RZ-002 | grow one attachment, PTY stays bounded by smallest | TODO |
|
||||
| RZ-003 | detach smallest, PTY expands to next smallest | TODO |
|
||||
| RZ-004 | reconnect preserves session + applies recomputed size | TODO |
|
||||
| RZ-001 | two attachments, smallest wins | DONE |
|
||||
| RZ-002 | grow one attachment, PTY stays bounded by smallest | DONE |
|
||||
| RZ-003 | detach smallest, PTY expands to next smallest | DONE |
|
||||
| RZ-004 | reconnect preserves session + applies recomputed size | DONE |
|
||||
| RZ-005 | daemon stdio RPC round-trip enforces resize semantics end-to-end | DONE |
|
||||
|
||||
## 8. Removal Checklist (Port Mirroring)
|
||||
|
||||
Before declaring browser proxying complete:
|
||||
1. remove remote port probe loop and `-L` auto-forward orchestration
|
||||
2. remove mirror-specific sidebar conflict messaging as default remote behavior
|
||||
3. replace mirroring tests with browser-proxy e2e tests
|
||||
4. keep optional explicit user-driven forwarding as separate feature only if needed
|
||||
1. `DONE` remove remote port probe loop and `-L` auto-forward orchestration
|
||||
2. `DONE` remove mirror-specific routing behavior as default remote behavior
|
||||
3. `DONE` replace mirroring docker assertions with proxy egress assertions
|
||||
4. `DONE` keep optional explicit user-driven forwarding out of this path; no automatic mirroring remains in browser routing
|
||||
|
||||
## 9. Open Decisions
|
||||
|
||||
1. Proxy auth policy for local broker (`none` vs optional credentials).
|
||||
2. Reconnect backoff profile and max retry budget.
|
||||
3. Browser data-store isolation policy for remote vs local workspaces.
|
||||
|
||||
## 10. Socket API Contract Notes
|
||||
|
||||
### 10.1 `workspace.remote.configure` Port Fields
|
||||
1. `port` and `local_proxy_port` accept integer values and numeric strings.
|
||||
2. Explicit `null` clears each field.
|
||||
3. Out-of-range values and invalid types (for example booleans/non-numeric strings/fractional numbers) return `invalid_params`.
|
||||
4. `local_proxy_port` is an internal deterministic test hook to force local bind conflicts in regression coverage.
|
||||
|
||||
### 10.2 SSH Option Precedence
|
||||
1. `StrictHostKeyChecking` default (`accept-new`) is only injected when no user override is present.
|
||||
2. Control-socket defaults (`ControlMaster`, `ControlPersist`, `ControlPath`) are only injected when missing.
|
||||
3. SSH option key matching is case-insensitive for precedence checks in both CLI-built commands and remote configure payloads.
|
||||
|
||||
### 10.3 SSH Docker E2E Harness Knobs
|
||||
1. `CMUX_SSH_TEST_DOCKER_HOST` sets the SSH destination host/IP used by docker-backed SSH fixtures (default `127.0.0.1`).
|
||||
2. `CMUX_SSH_TEST_DOCKER_BIND_ADDR` sets the bind address used in fixture container publish mappings (default `127.0.0.1`).
|
||||
3. Defaults preserve loopback behavior on a single host; override both when docker runs on a different host (for example VM -> host OrbStack).
|
||||
|
|
|
|||
|
|
@ -10,6 +10,84 @@ BUNDLE_SET=0
|
|||
DERIVED_SET=0
|
||||
TAG=""
|
||||
CMUX_DEBUG_LOG=""
|
||||
CLI_PATH=""
|
||||
|
||||
write_dev_cli_shim() {
|
||||
local target="$1"
|
||||
local fallback_bin="$2"
|
||||
mkdir -p "$(dirname "$target")"
|
||||
cat > "$target" <<EOF
|
||||
#!/usr/bin/env bash
|
||||
# cmux dev shim (managed by scripts/reload.sh)
|
||||
set -euo pipefail
|
||||
|
||||
CLI_PATH_FILE="/tmp/cmux-last-cli-path"
|
||||
if [[ -r "\$CLI_PATH_FILE" ]]; then
|
||||
CLI_PATH="\$(cat "\$CLI_PATH_FILE")"
|
||||
if [[ -x "\$CLI_PATH" ]]; then
|
||||
exec "\$CLI_PATH" "\$@"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -x "$fallback_bin" ]]; then
|
||||
exec "$fallback_bin" "\$@"
|
||||
fi
|
||||
|
||||
echo "error: no reload-selected dev cmux CLI found. Run ./scripts/reload.sh --tag <name> first." >&2
|
||||
exit 1
|
||||
EOF
|
||||
chmod +x "$target" || true
|
||||
}
|
||||
|
||||
select_cmux_shim_target() {
|
||||
local app_cli_dir="/Applications/cmux.app/Contents/Resources/bin"
|
||||
local marker="cmux dev shim (managed by scripts/reload.sh)"
|
||||
local target=""
|
||||
local path_entry=""
|
||||
local candidate=""
|
||||
|
||||
IFS=':' read -r -a path_entries <<< "${PATH:-}"
|
||||
for path_entry in "${path_entries[@]}"; do
|
||||
[[ -z "$path_entry" ]] && continue
|
||||
if [[ "$path_entry" == "~/"* ]]; then
|
||||
path_entry="$HOME/${path_entry#~/}"
|
||||
fi
|
||||
if [[ "$path_entry" == "$app_cli_dir" ]]; then
|
||||
break
|
||||
fi
|
||||
[[ -d "$path_entry" && -w "$path_entry" ]] || continue
|
||||
candidate="$path_entry/cmux"
|
||||
if [[ ! -e "$candidate" ]]; then
|
||||
target="$candidate"
|
||||
break
|
||||
fi
|
||||
if [[ -f "$candidate" ]] && grep -q "$marker" "$candidate" 2>/dev/null; then
|
||||
target="$candidate"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ -n "$target" ]]; then
|
||||
echo "$target"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Fallback for PATH layouts where app CLI isn't listed or no earlier entries were writable.
|
||||
for path_entry in /opt/homebrew/bin /usr/local/bin "$HOME/.local/bin" "$HOME/bin"; do
|
||||
[[ -d "$path_entry" && -w "$path_entry" ]] || continue
|
||||
candidate="$path_entry/cmux"
|
||||
if [[ ! -e "$candidate" ]]; then
|
||||
echo "$candidate"
|
||||
return 0
|
||||
fi
|
||||
if [[ -f "$candidate" ]] && grep -q "$marker" "$candidate" 2>/dev/null; then
|
||||
echo "$candidate"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
usage() {
|
||||
cat <<'EOF'
|
||||
|
|
@ -271,6 +349,21 @@ if [[ -n "$TAG" && "$APP_NAME" != "$SEARCH_APP_NAME" ]]; then
|
|||
APP_PATH="$TAG_APP_PATH"
|
||||
fi
|
||||
|
||||
CLI_PATH="$(dirname "$APP_PATH")/cmux"
|
||||
if [[ -x "$CLI_PATH" ]]; then
|
||||
echo "$CLI_PATH" > /tmp/cmux-last-cli-path || true
|
||||
ln -sfn "$CLI_PATH" /tmp/cmux-cli || true
|
||||
|
||||
# Stable shim that always follows the last reload-selected dev CLI.
|
||||
DEV_CLI_SHIM="$HOME/.local/bin/cmux-dev"
|
||||
write_dev_cli_shim "$DEV_CLI_SHIM" "/Applications/cmux.app/Contents/Resources/bin/cmux"
|
||||
|
||||
CMUX_SHIM_TARGET="$(select_cmux_shim_target || true)"
|
||||
if [[ -n "${CMUX_SHIM_TARGET:-}" ]]; then
|
||||
write_dev_cli_shim "$CMUX_SHIM_TARGET" "/Applications/cmux.app/Contents/Resources/bin/cmux"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure any running instance is fully terminated, regardless of DerivedData path.
|
||||
/usr/bin/osascript -e "tell application id \"${BUNDLE_ID}\" to quit" >/dev/null 2>&1 || true
|
||||
sleep 0.3
|
||||
|
|
@ -350,3 +443,16 @@ fi
|
|||
if [[ -n "${TAG_SLUG:-}" ]]; then
|
||||
print_tag_cleanup_reminder "$TAG_SLUG"
|
||||
fi
|
||||
|
||||
if [[ -x "${CLI_PATH:-}" ]]; then
|
||||
echo
|
||||
echo "CLI path:"
|
||||
echo " $CLI_PATH"
|
||||
echo "CLI helpers:"
|
||||
echo " /tmp/cmux-cli ..."
|
||||
echo " $HOME/.local/bin/cmux-dev ..."
|
||||
if [[ -n "${CMUX_SHIM_TARGET:-}" ]]; then
|
||||
echo " $CMUX_SHIM_TARGET ..."
|
||||
fi
|
||||
echo "If your shell still resolves the old cmux, run: rehash"
|
||||
fi
|
||||
|
|
|
|||
1
tests/fixtures/ssh-remote/Dockerfile
vendored
1
tests/fixtures/ssh-remote/Dockerfile
vendored
|
|
@ -12,6 +12,7 @@ RUN ssh-keygen -A
|
|||
|
||||
COPY sshd_config /etc/ssh/sshd_config
|
||||
COPY run.sh /usr/local/bin/run.sh
|
||||
COPY ws_echo.py /usr/local/bin/ws_echo.py
|
||||
RUN chmod +x /usr/local/bin/run.sh
|
||||
|
||||
EXPOSE 22
|
||||
|
|
|
|||
2
tests/fixtures/ssh-remote/run.sh
vendored
2
tests/fixtures/ssh-remote/run.sh
vendored
|
|
@ -7,6 +7,7 @@ if [ -z "${AUTHORIZED_KEY:-}" ]; then
|
|||
fi
|
||||
|
||||
REMOTE_HTTP_PORT="${REMOTE_HTTP_PORT:-43173}"
|
||||
REMOTE_WS_PORT="${REMOTE_WS_PORT:-43174}"
|
||||
|
||||
mkdir -p /home/dev/.ssh /root/.ssh /run/sshd
|
||||
printf '%s\n' "$AUTHORIZED_KEY" > /home/dev/.ssh/authorized_keys
|
||||
|
|
@ -18,5 +19,6 @@ chmod 700 /root/.ssh
|
|||
chmod 600 /root/.ssh/authorized_keys
|
||||
|
||||
python3 -m http.server "$REMOTE_HTTP_PORT" --bind 127.0.0.1 --directory /srv/www >/tmp/http.log 2>&1 &
|
||||
python3 /usr/local/bin/ws_echo.py --host 127.0.0.1 --port "$REMOTE_WS_PORT" >/tmp/ws.log 2>&1 &
|
||||
|
||||
exec /usr/sbin/sshd -D -e
|
||||
|
|
|
|||
132
tests/fixtures/ssh-remote/ws_echo.py
vendored
Normal file
132
tests/fixtures/ssh-remote/ws_echo.py
vendored
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tiny WebSocket echo server for SSH proxy integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
|
||||
|
||||
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
def _recv_exact(conn: socket.socket, n: int) -> bytes:
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
chunk = conn.recv(n - len(data))
|
||||
if not chunk:
|
||||
raise ConnectionError("unexpected EOF")
|
||||
data.extend(chunk)
|
||||
return bytes(data)
|
||||
|
||||
|
||||
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes:
|
||||
data = bytearray()
|
||||
while marker not in data:
|
||||
chunk = conn.recv(1024)
|
||||
if not chunk:
|
||||
raise ConnectionError("unexpected EOF while reading headers")
|
||||
data.extend(chunk)
|
||||
if len(data) > limit:
|
||||
raise ValueError("header too large")
|
||||
return bytes(data)
|
||||
|
||||
|
||||
def _read_frame(conn: socket.socket) -> tuple[int, bytes]:
|
||||
first, second = _recv_exact(conn, 2)
|
||||
opcode = first & 0x0F
|
||||
masked = (second & 0x80) != 0
|
||||
length = second & 0x7F
|
||||
if length == 126:
|
||||
length = struct.unpack("!H", _recv_exact(conn, 2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack("!Q", _recv_exact(conn, 8))[0]
|
||||
|
||||
mask_key = _recv_exact(conn, 4) if masked else b""
|
||||
payload = _recv_exact(conn, length) if length else b""
|
||||
if masked and payload:
|
||||
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||
return opcode, payload
|
||||
|
||||
|
||||
def _send_frame(conn: socket.socket, opcode: int, payload: bytes) -> None:
|
||||
first = 0x80 | (opcode & 0x0F)
|
||||
length = len(payload)
|
||||
if length < 126:
|
||||
header = bytes([first, length])
|
||||
elif length <= 0xFFFF:
|
||||
header = bytes([first, 126]) + struct.pack("!H", length)
|
||||
else:
|
||||
header = bytes([first, 127]) + struct.pack("!Q", length)
|
||||
conn.sendall(header + payload)
|
||||
|
||||
|
||||
def handle_client(conn: socket.socket) -> None:
|
||||
try:
|
||||
request = _recv_until(conn, b"\r\n\r\n")
|
||||
headers_raw = request.decode("utf-8", errors="replace").split("\r\n")
|
||||
header_map: dict[str, str] = {}
|
||||
for line in headers_raw[1:]:
|
||||
if not line or ":" not in line:
|
||||
continue
|
||||
k, v = line.split(":", 1)
|
||||
header_map[k.strip().lower()] = v.strip()
|
||||
|
||||
key = header_map.get("sec-websocket-key", "")
|
||||
upgrade = header_map.get("upgrade", "").lower()
|
||||
connection_hdr = header_map.get("connection", "").lower()
|
||||
if not key or upgrade != "websocket" or "upgrade" not in connection_hdr:
|
||||
conn.sendall(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
|
||||
return
|
||||
|
||||
accept = base64.b64encode(hashlib.sha1((key + GUID).encode("utf-8")).digest()).decode("ascii")
|
||||
response = (
|
||||
"HTTP/1.1 101 Switching Protocols\r\n"
|
||||
"Upgrade: websocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
f"Sec-WebSocket-Accept: {accept}\r\n"
|
||||
"\r\n"
|
||||
)
|
||||
conn.sendall(response.encode("utf-8"))
|
||||
|
||||
while True:
|
||||
opcode, payload = _read_frame(conn)
|
||||
if opcode == 0x8: # close
|
||||
_send_frame(conn, 0x8, b"")
|
||||
return
|
||||
if opcode == 0x9: # ping
|
||||
_send_frame(conn, 0xA, payload)
|
||||
continue
|
||||
if opcode == 0x1: # text
|
||||
_send_frame(conn, 0x1, payload)
|
||||
continue
|
||||
# ignore all other opcodes
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="WebSocket echo server")
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=43174)
|
||||
args = parser.parse_args()
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind((args.host, args.port))
|
||||
server.listen(16)
|
||||
while True:
|
||||
conn, _ = server.accept()
|
||||
thread = threading.Thread(target=handle_client, args=(conn,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Regression: `cmux ssh` creates a remote-tagged workspace with remote metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
|
@ -72,6 +74,34 @@ def _extract_control_path(ssh_command: str) -> str:
|
|||
return match.group(1) if match else ""
|
||||
|
||||
|
||||
def _has_ssh_option_key(options: list[str], key: str) -> bool:
|
||||
lowered_key = key.lower()
|
||||
for option in options:
|
||||
token = re.split(r"[=\s]+", str(option).strip(), maxsplit=1)[0].strip().lower()
|
||||
if token == lowered_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _read_any_terminal_text(client: cmux, workspace_id: str, timeout: float = 8.0) -> str | None:
|
||||
deadline = time.time() + timeout
|
||||
last_exc: Exception | None = None
|
||||
while time.time() < deadline:
|
||||
surfaces = client.list_surfaces(workspace_id)
|
||||
for _, surface_id, _ in surfaces:
|
||||
try:
|
||||
return client.read_terminal_text(surface_id)
|
||||
except cmuxError as exc:
|
||||
text = str(exc).lower()
|
||||
if "terminal surface not found" in text:
|
||||
last_exc = exc
|
||||
continue
|
||||
raise
|
||||
time.sleep(0.1)
|
||||
print(f"WARN: readable terminal surface unavailable in workspace {workspace_id}; skipping transcript assertion ({last_exc})")
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
cli = _find_cli_binary()
|
||||
help_text = _run_cli(cli, ["ssh", "--help"], json_output=False)
|
||||
|
|
@ -80,6 +110,9 @@ def main() -> int:
|
|||
|
||||
workspace_id = ""
|
||||
workspace_id_without_name = ""
|
||||
workspace_id_strict_override = ""
|
||||
workspace_id_case_override = ""
|
||||
workspace_id_invalid_proxy_port = ""
|
||||
with cmux(SOCKET_PATH) as client:
|
||||
try:
|
||||
payload = _run_cli_json(
|
||||
|
|
@ -138,13 +171,29 @@ def main() -> int:
|
|||
str(remote.get("state") or "") in {"connecting", "connected", "error", "disconnected"},
|
||||
f"unexpected remote state: {remote}",
|
||||
)
|
||||
surfaces = client.list_surfaces(workspace_id)
|
||||
_must(bool(surfaces), f"workspace should have at least one surface: {workspace_id}")
|
||||
primary_surface = surfaces[0][1]
|
||||
proxy = remote.get("proxy") or {}
|
||||
_must(
|
||||
str(proxy.get("state") or "") in {"connecting", "ready", "error", "unavailable"},
|
||||
f"remote payload should include proxy state metadata: {remote}",
|
||||
)
|
||||
remote_ssh_options = [str(item) for item in (remote.get("ssh_options") or [])]
|
||||
_must(
|
||||
_has_ssh_option_key(remote_ssh_options, "ControlMaster"),
|
||||
f"workspace.remote.configure should include ControlMaster default: {remote}",
|
||||
)
|
||||
_must(
|
||||
_has_ssh_option_key(remote_ssh_options, "ControlPersist"),
|
||||
f"workspace.remote.configure should include ControlPersist default: {remote}",
|
||||
)
|
||||
_must(
|
||||
_has_ssh_option_key(remote_ssh_options, "ControlPath"),
|
||||
f"workspace.remote.configure should include ControlPath default: {remote}",
|
||||
)
|
||||
# Regression: cmux ssh should launch through initial_command, not visibly type a giant command into the shell.
|
||||
terminal_text = client.read_terminal_text(primary_surface)
|
||||
_must("ControlPersist=600" not in terminal_text, f"cmux ssh should not inject raw ssh command text: {terminal_text!r}")
|
||||
_must("GHOSTTY_SHELL_FEATURES=" not in terminal_text, f"cmux ssh should not inject env assignment text: {terminal_text!r}")
|
||||
terminal_text = _read_any_terminal_text(client, workspace_id)
|
||||
if terminal_text is not None:
|
||||
_must("ControlPersist=600" not in terminal_text, f"cmux ssh should not inject raw ssh command text: {terminal_text!r}")
|
||||
_must("GHOSTTY_SHELL_FEATURES=" not in terminal_text, f"cmux ssh should not inject env assignment text: {terminal_text!r}")
|
||||
|
||||
status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
|
||||
status_remote = status.get("remote") or {}
|
||||
|
|
@ -231,6 +280,147 @@ def main() -> int:
|
|||
f"workspace.remote.reconnect should transition into an active state: {reconnected}",
|
||||
)
|
||||
|
||||
payload_strict_override = _run_cli_json(
|
||||
cli,
|
||||
[
|
||||
"ssh",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"1",
|
||||
"--name",
|
||||
"ssh-meta-strict-override",
|
||||
"--ssh-option",
|
||||
"StrictHostKeyChecking=no",
|
||||
],
|
||||
)
|
||||
workspace_id_strict_override = str(payload_strict_override.get("workspace_id") or "")
|
||||
workspace_ref_strict_override = str(payload_strict_override.get("workspace_ref") or "")
|
||||
if not workspace_id_strict_override and workspace_ref_strict_override.startswith("workspace:"):
|
||||
listed_override = client._call("workspace.list", {}) or {}
|
||||
for row in listed_override.get("workspaces") or []:
|
||||
if str(row.get("ref") or "") == workspace_ref_strict_override:
|
||||
workspace_id_strict_override = str(row.get("id") or "")
|
||||
break
|
||||
_must(
|
||||
bool(workspace_id_strict_override),
|
||||
f"cmux ssh with StrictHostKeyChecking override should create workspace: {payload_strict_override}",
|
||||
)
|
||||
ssh_command_strict_override = str(payload_strict_override.get("ssh_command") or "")
|
||||
_must(
|
||||
"-o StrictHostKeyChecking=no" in ssh_command_strict_override,
|
||||
f"ssh command should include user StrictHostKeyChecking override: {ssh_command_strict_override!r}",
|
||||
)
|
||||
_must(
|
||||
"-o StrictHostKeyChecking=accept-new" not in ssh_command_strict_override,
|
||||
f"ssh command should not force default StrictHostKeyChecking when override is supplied: {ssh_command_strict_override!r}",
|
||||
)
|
||||
strict_override_remote = payload_strict_override.get("remote") or {}
|
||||
strict_override_options = [str(item) for item in (strict_override_remote.get("ssh_options") or [])]
|
||||
_must(
|
||||
any(item.lower() == "stricthostkeychecking=no" for item in strict_override_options),
|
||||
f"workspace.remote.configure should preserve explicit StrictHostKeyChecking override: {strict_override_remote}",
|
||||
)
|
||||
|
||||
payload_case_override = _run_cli_json(
|
||||
cli,
|
||||
[
|
||||
"ssh",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"1",
|
||||
"--name",
|
||||
"ssh-meta-case-override",
|
||||
"--ssh-option",
|
||||
"stricthostkeychecking=no",
|
||||
"--ssh-option",
|
||||
"controlmaster=no",
|
||||
"--ssh-option",
|
||||
"controlpersist=0",
|
||||
"--ssh-option",
|
||||
"controlpath=/tmp/cmux-ssh-%C-custom",
|
||||
],
|
||||
)
|
||||
workspace_id_case_override = str(payload_case_override.get("workspace_id") or "")
|
||||
workspace_ref_case_override = str(payload_case_override.get("workspace_ref") or "")
|
||||
if not workspace_id_case_override and workspace_ref_case_override.startswith("workspace:"):
|
||||
listed_case_override = client._call("workspace.list", {}) or {}
|
||||
for row in listed_case_override.get("workspaces") or []:
|
||||
if str(row.get("ref") or "") == workspace_ref_case_override:
|
||||
workspace_id_case_override = str(row.get("id") or "")
|
||||
break
|
||||
_must(
|
||||
bool(workspace_id_case_override),
|
||||
f"cmux ssh with lowercase SSH option overrides should create workspace: {payload_case_override}",
|
||||
)
|
||||
ssh_command_case_override = str(payload_case_override.get("ssh_command") or "")
|
||||
ssh_command_case_override_lower = ssh_command_case_override.lower()
|
||||
_must(
|
||||
"-o stricthostkeychecking=no" in ssh_command_case_override_lower,
|
||||
f"ssh command should preserve lowercase StrictHostKeyChecking override: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"stricthostkeychecking=accept-new" not in ssh_command_case_override_lower,
|
||||
f"ssh command should not force default StrictHostKeyChecking when lowercase override is supplied: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"-o controlmaster=no" in ssh_command_case_override_lower,
|
||||
f"ssh command should preserve lowercase ControlMaster override: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"controlmaster=auto" not in ssh_command_case_override_lower,
|
||||
f"ssh command should not force default ControlMaster when lowercase override is supplied: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"-o controlpersist=0" in ssh_command_case_override_lower,
|
||||
f"ssh command should preserve lowercase ControlPersist override: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"controlpersist=600" not in ssh_command_case_override_lower,
|
||||
f"ssh command should not force default ControlPersist when lowercase override is supplied: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
"controlpath=/tmp/cmux-ssh-%c-custom" in ssh_command_case_override_lower,
|
||||
f"ssh command should preserve lowercase ControlPath override value: {ssh_command_case_override!r}",
|
||||
)
|
||||
_must(
|
||||
ssh_command_case_override_lower.count("controlpath=") == 1,
|
||||
f"ssh command should include exactly one ControlPath when lowercase override is supplied: {ssh_command_case_override!r}",
|
||||
)
|
||||
case_override_remote = payload_case_override.get("remote") or {}
|
||||
case_override_options = [str(item) for item in (case_override_remote.get("ssh_options") or [])]
|
||||
_must(
|
||||
any(item.lower() == "stricthostkeychecking=no" for item in case_override_options),
|
||||
f"workspace.remote.configure should preserve lowercase StrictHostKeyChecking override: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
not any(item.lower() == "stricthostkeychecking=accept-new" for item in case_override_options),
|
||||
f"workspace.remote.configure should not inject default StrictHostKeyChecking when lowercase override is supplied: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
any(item.lower() == "controlmaster=no" for item in case_override_options),
|
||||
f"workspace.remote.configure should preserve lowercase ControlMaster override: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
not any(item.lower() == "controlmaster=auto" for item in case_override_options),
|
||||
f"workspace.remote.configure should not inject default ControlMaster when lowercase override is supplied: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
any(item.lower() == "controlpersist=0" for item in case_override_options),
|
||||
f"workspace.remote.configure should preserve lowercase ControlPersist override: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
not any(item.lower() == "controlpersist=600" for item in case_override_options),
|
||||
f"workspace.remote.configure should not inject default ControlPersist when lowercase override is supplied: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
any(item.lower() == "controlpath=/tmp/cmux-ssh-%c-custom" for item in case_override_options),
|
||||
f"workspace.remote.configure should preserve lowercase ControlPath override: {case_override_remote}",
|
||||
)
|
||||
_must(
|
||||
sum(1 for item in case_override_options if item.lower().startswith("controlpath=")) == 1,
|
||||
f"workspace.remote.configure should include exactly one ControlPath when lowercase override is supplied: {case_override_remote}",
|
||||
)
|
||||
|
||||
payload3 = _run_cli_json(
|
||||
cli,
|
||||
["ssh", "127.0.0.1", "--port", "1", "--name", "ssh-meta-features"],
|
||||
|
|
@ -248,6 +438,142 @@ def main() -> int:
|
|||
client.close_workspace(workspace_id3)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
invalid_proxy_port_workspace = client._call("workspace.create", {"initial_command": "echo invalid-local-proxy-port"}) or {}
|
||||
workspace_id_invalid_proxy_port = str(invalid_proxy_port_workspace.get("workspace_id") or "")
|
||||
_must(bool(workspace_id_invalid_proxy_port), f"workspace.create missing workspace_id: {invalid_proxy_port_workspace}")
|
||||
|
||||
configured_with_string_ports = client._call(
|
||||
"workspace.remote.configure",
|
||||
{
|
||||
"workspace_id": workspace_id_invalid_proxy_port,
|
||||
"destination": "127.0.0.1",
|
||||
"port": "2222",
|
||||
"local_proxy_port": "31338",
|
||||
"auto_connect": False,
|
||||
},
|
||||
) or {}
|
||||
configured_with_string_ports_remote = configured_with_string_ports.get("remote") or {}
|
||||
_must(
|
||||
int(configured_with_string_ports_remote.get("port") or 0) == 2222,
|
||||
f"workspace.remote.configure should parse numeric string port values: {configured_with_string_ports}",
|
||||
)
|
||||
_must(
|
||||
int(configured_with_string_ports_remote.get("local_proxy_port") or 0) == 31338,
|
||||
f"workspace.remote.configure should parse numeric string local_proxy_port values: {configured_with_string_ports}",
|
||||
)
|
||||
|
||||
valid_local_proxy_port = 31337
|
||||
configured_with_local_proxy_port = client._call(
|
||||
"workspace.remote.configure",
|
||||
{
|
||||
"workspace_id": workspace_id_invalid_proxy_port,
|
||||
"destination": "127.0.0.1",
|
||||
"port": 2222,
|
||||
"local_proxy_port": valid_local_proxy_port,
|
||||
"auto_connect": False,
|
||||
},
|
||||
) or {}
|
||||
configured_remote = configured_with_local_proxy_port.get("remote") or {}
|
||||
_must(
|
||||
int(configured_remote.get("port") or 0) == 2222,
|
||||
f"workspace.remote.configure should echo explicit port in remote payload: {configured_with_local_proxy_port}",
|
||||
)
|
||||
_must(
|
||||
int(configured_remote.get("local_proxy_port") or 0) == valid_local_proxy_port,
|
||||
f"workspace.remote.configure should echo local_proxy_port in remote payload: {configured_with_local_proxy_port}",
|
||||
)
|
||||
|
||||
configured_with_null_ports = client._call(
|
||||
"workspace.remote.configure",
|
||||
{
|
||||
"workspace_id": workspace_id_invalid_proxy_port,
|
||||
"destination": "127.0.0.1",
|
||||
"port": None,
|
||||
"local_proxy_port": None,
|
||||
"auto_connect": False,
|
||||
},
|
||||
) or {}
|
||||
configured_with_null_ports_remote = configured_with_null_ports.get("remote") or {}
|
||||
_must(
|
||||
configured_with_null_ports_remote.get("port") is None,
|
||||
f"workspace.remote.configure should allow null to clear port: {configured_with_null_ports}",
|
||||
)
|
||||
_must(
|
||||
configured_with_null_ports_remote.get("local_proxy_port") is None,
|
||||
f"workspace.remote.configure should allow null to clear local_proxy_port: {configured_with_null_ports}",
|
||||
)
|
||||
status_after_null_ports = client._call(
|
||||
"workspace.remote.status",
|
||||
{"workspace_id": workspace_id_invalid_proxy_port},
|
||||
) or {}
|
||||
status_after_null_ports_remote = status_after_null_ports.get("remote") or {}
|
||||
_must(
|
||||
status_after_null_ports_remote.get("port") is None,
|
||||
f"workspace.remote.status should reflect cleared port: {status_after_null_ports}",
|
||||
)
|
||||
_must(
|
||||
status_after_null_ports_remote.get("local_proxy_port") is None,
|
||||
f"workspace.remote.status should reflect cleared local_proxy_port: {status_after_null_ports}",
|
||||
)
|
||||
|
||||
for invalid_local_proxy_port in [0, 65536, "abc", True, 22.5]:
|
||||
try:
|
||||
client._call(
|
||||
"workspace.remote.configure",
|
||||
{
|
||||
"workspace_id": workspace_id_invalid_proxy_port,
|
||||
"destination": "127.0.0.1",
|
||||
"local_proxy_port": invalid_local_proxy_port,
|
||||
"auto_connect": False,
|
||||
},
|
||||
)
|
||||
raise cmuxError(
|
||||
f"workspace.remote.configure should reject local_proxy_port={invalid_local_proxy_port!r}"
|
||||
)
|
||||
except cmuxError as exc:
|
||||
text = str(exc)
|
||||
lowered = text.lower()
|
||||
_must(
|
||||
"invalid_params" in lowered,
|
||||
f"workspace.remote.configure should return invalid_params for local_proxy_port={invalid_local_proxy_port!r}: {exc}",
|
||||
)
|
||||
_must(
|
||||
"local_proxy_port must be 1-65535" in text,
|
||||
f"workspace.remote.configure should include validation hint for local_proxy_port={invalid_local_proxy_port!r}: {exc}",
|
||||
)
|
||||
|
||||
for invalid_port in [0, 65536, "abc", True, 22.5]:
|
||||
try:
|
||||
client._call(
|
||||
"workspace.remote.configure",
|
||||
{
|
||||
"workspace_id": workspace_id_invalid_proxy_port,
|
||||
"destination": "127.0.0.1",
|
||||
"port": invalid_port,
|
||||
"auto_connect": False,
|
||||
},
|
||||
)
|
||||
raise cmuxError(
|
||||
f"workspace.remote.configure should reject port={invalid_port!r}"
|
||||
)
|
||||
except cmuxError as exc:
|
||||
text = str(exc)
|
||||
lowered = text.lower()
|
||||
_must(
|
||||
"invalid_params" in lowered,
|
||||
f"workspace.remote.configure should return invalid_params for port={invalid_port!r}: {exc}",
|
||||
)
|
||||
_must(
|
||||
"port must be 1-65535" in text,
|
||||
f"workspace.remote.configure should include validation hint for port={invalid_port!r}: {exc}",
|
||||
)
|
||||
|
||||
try:
|
||||
client.close_workspace(workspace_id_invalid_proxy_port)
|
||||
except Exception:
|
||||
pass
|
||||
workspace_id_invalid_proxy_port = ""
|
||||
finally:
|
||||
if workspace_id:
|
||||
try:
|
||||
|
|
@ -259,6 +585,21 @@ def main() -> int:
|
|||
client.close_workspace(workspace_id_without_name)
|
||||
except Exception:
|
||||
pass
|
||||
if workspace_id_strict_override:
|
||||
try:
|
||||
client.close_workspace(workspace_id_strict_override)
|
||||
except Exception:
|
||||
pass
|
||||
if workspace_id_case_override:
|
||||
try:
|
||||
client.close_workspace(workspace_id_case_override)
|
||||
except Exception:
|
||||
pass
|
||||
if workspace_id_invalid_proxy_port:
|
||||
try:
|
||||
client.close_workspace(workspace_id_invalid_proxy_port)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("PASS: cmux ssh marks workspace as remote, exposes remote metadata, and does not require --name")
|
||||
return 0
|
||||
|
|
|
|||
188
tests_v2/test_ssh_remote_daemon_resize_stdio.py
Normal file
188
tests_v2/test_ssh_remote_daemon_resize_stdio.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Process-level integration: cmuxd-remote stdio session resize coordinator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import select
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from cmux import cmuxError
|
||||
|
||||
|
||||
def _must(cond: bool, msg: str) -> None:
|
||||
if not cond:
|
||||
raise cmuxError(msg)
|
||||
|
||||
|
||||
def _daemon_module_dir() -> Path:
|
||||
return Path(__file__).resolve().parents[1] / "daemon" / "remote"
|
||||
|
||||
|
||||
def _rpc(
|
||||
proc: subprocess.Popen[str],
|
||||
req_id: int,
|
||||
method: str,
|
||||
params: dict,
|
||||
*,
|
||||
timeout_s: float = 5.0,
|
||||
) -> dict:
|
||||
if proc.stdin is None or proc.stdout is None:
|
||||
raise cmuxError("daemon subprocess stdio pipes are not available")
|
||||
|
||||
payload = {"id": req_id, "method": method, "params": params}
|
||||
proc.stdin.write(json.dumps(payload, separators=(",", ":")) + "\n")
|
||||
proc.stdin.flush()
|
||||
|
||||
deadline = time.time() + timeout_s
|
||||
while time.time() < deadline:
|
||||
wait_s = max(0.0, min(0.2, deadline - time.time()))
|
||||
ready, _, _ = select.select([proc.stdout], [], [], wait_s)
|
||||
if not ready:
|
||||
continue
|
||||
line = proc.stdout.readline()
|
||||
if line == "":
|
||||
stderr = ""
|
||||
if proc.stderr is not None:
|
||||
try:
|
||||
stderr = proc.stderr.read().strip()
|
||||
except Exception:
|
||||
stderr = ""
|
||||
raise cmuxError(f"cmuxd-remote exited while waiting for {method} response: {stderr}")
|
||||
try:
|
||||
resp = json.loads(line)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise cmuxError(f"Invalid JSON response for {method}: {line!r} ({exc})")
|
||||
_must(resp.get("id") == req_id, f"Response id mismatch for {method}: {resp}")
|
||||
return resp
|
||||
|
||||
raise cmuxError(f"Timed out waiting for cmuxd-remote response: {method}")
|
||||
|
||||
|
||||
def _as_int(value: object, field: str) -> int:
|
||||
if isinstance(value, bool):
|
||||
raise cmuxError(f"{field} should be numeric, got bool")
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
return int(value)
|
||||
raise cmuxError(f"{field} has unexpected type {type(value).__name__}: {value!r}")
|
||||
|
||||
|
||||
def _assert_effective(resp: dict, want_cols: int, want_rows: int, label: str) -> None:
|
||||
_must(resp.get("ok") is True, f"{label} should return ok=true: {resp}")
|
||||
result = resp.get("result") or {}
|
||||
got_cols = _as_int(result.get("effective_cols"), "effective_cols")
|
||||
got_rows = _as_int(result.get("effective_rows"), "effective_rows")
|
||||
_must(
|
||||
got_cols == want_cols and got_rows == want_rows,
|
||||
f"{label} effective size mismatch: got {got_cols}x{got_rows}, want {want_cols}x{want_rows} ({resp})",
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if shutil.which("go") is None:
|
||||
print("SKIP: go is not available")
|
||||
return 0
|
||||
|
||||
daemon_dir = _daemon_module_dir()
|
||||
_must(daemon_dir.is_dir(), f"Missing daemon module directory: {daemon_dir}")
|
||||
|
||||
proc = subprocess.Popen(
|
||||
["go", "run", "./cmd/cmuxd-remote", "serve", "--stdio"],
|
||||
cwd=str(daemon_dir),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
try:
|
||||
hello = _rpc(proc, 1, "hello", {})
|
||||
_must(hello.get("ok") is True, f"hello should return ok=true: {hello}")
|
||||
capabilities = {str(item) for item in ((hello.get("result") or {}).get("capabilities") or [])}
|
||||
_must("session.basic" in capabilities, f"hello missing session.basic capability: {hello}")
|
||||
_must("session.resize.min" in capabilities, f"hello missing session.resize.min capability: {hello}")
|
||||
|
||||
open_resp = _rpc(proc, 2, "session.open", {"session_id": "sess-e2e"})
|
||||
_assert_effective(open_resp, 0, 0, "session.open")
|
||||
|
||||
attach_small = _rpc(
|
||||
proc,
|
||||
3,
|
||||
"session.attach",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-small", "cols": 90, "rows": 30},
|
||||
)
|
||||
_assert_effective(attach_small, 90, 30, "session.attach(a-small)")
|
||||
|
||||
attach_large = _rpc(
|
||||
proc,
|
||||
4,
|
||||
"session.attach",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-large", "cols": 140, "rows": 50},
|
||||
)
|
||||
_assert_effective(attach_large, 90, 30, "session.attach(a-large)")
|
||||
|
||||
resize_large = _rpc(
|
||||
proc,
|
||||
5,
|
||||
"session.resize",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-large", "cols": 200, "rows": 80},
|
||||
)
|
||||
_assert_effective(resize_large, 90, 30, "session.resize(a-large)")
|
||||
|
||||
detach_small = _rpc(
|
||||
proc,
|
||||
6,
|
||||
"session.detach",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-small"},
|
||||
)
|
||||
_assert_effective(detach_small, 200, 80, "session.detach(a-small)")
|
||||
|
||||
detach_large = _rpc(
|
||||
proc,
|
||||
7,
|
||||
"session.detach",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-large"},
|
||||
)
|
||||
_assert_effective(detach_large, 200, 80, "session.detach(a-large)")
|
||||
|
||||
reattach = _rpc(
|
||||
proc,
|
||||
8,
|
||||
"session.attach",
|
||||
{"session_id": "sess-e2e", "attachment_id": "a-reconnect", "cols": 110, "rows": 40},
|
||||
)
|
||||
_assert_effective(reattach, 110, 40, "session.attach(a-reconnect)")
|
||||
|
||||
status = _rpc(proc, 9, "session.status", {"session_id": "sess-e2e"})
|
||||
_assert_effective(status, 110, 40, "session.status")
|
||||
attachments = (status.get("result") or {}).get("attachments") or []
|
||||
_must(len(attachments) == 1, f"session.status should report one active attachment after reattach: {status}")
|
||||
|
||||
print("PASS: cmuxd-remote stdio session.resize coordinator enforces smallest-screen-wins semantics")
|
||||
return 0
|
||||
finally:
|
||||
try:
|
||||
if proc.stdin is not None:
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=2.0)
|
||||
except Exception:
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
|
@ -1,18 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Docker integration: remote SSH port discovery + local forwarding via `cmux ssh`."""
|
||||
"""Docker integration: remote SSH proxy endpoint via `cmux ssh`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import urllib.request
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
|
@ -21,7 +24,10 @@ from cmux import cmux, cmuxError
|
|||
|
||||
SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock")
|
||||
REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43173"))
|
||||
REMOTE_WS_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_WS_PORT", "43174"))
|
||||
MAX_REMOTE_DAEMON_SIZE_BYTES = int(os.environ.get("CMUX_SSH_TEST_MAX_DAEMON_SIZE_BYTES", "15000000"))
|
||||
DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1")
|
||||
DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1")
|
||||
|
||||
|
||||
def _must(cond: bool, msg: str) -> None:
|
||||
|
|
@ -84,15 +90,309 @@ def _parse_host_port(docker_port_output: str) -> int:
|
|||
return int(last)
|
||||
|
||||
|
||||
def _http_get(url: str, timeout: float = 2.0) -> str:
|
||||
with urllib.request.urlopen(url, timeout=timeout) as resp: # nosec B310 - loopback URL in test only
|
||||
return resp.read().decode("utf-8", errors="replace")
|
||||
def _curl_via_socks(proxy_port: int, target_url: str) -> str:
|
||||
if shutil.which("curl") is None:
|
||||
raise cmuxError("curl is required for SOCKS proxy verification")
|
||||
proc = _run(
|
||||
[
|
||||
"curl",
|
||||
"--silent",
|
||||
"--show-error",
|
||||
"--max-time",
|
||||
"5",
|
||||
"--socks5-hostname",
|
||||
f"127.0.0.1:{proxy_port}",
|
||||
target_url,
|
||||
],
|
||||
check=False,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
merged = f"{proc.stdout}\n{proc.stderr}".strip()
|
||||
raise cmuxError(f"curl via SOCKS proxy failed: {merged}")
|
||||
return proc.stdout
|
||||
|
||||
|
||||
def _shell_single_quote(value: str) -> str:
|
||||
return "'" + value.replace("'", "'\"'\"'") + "'"
|
||||
|
||||
|
||||
def _recv_exact(sock: socket.socket, n: int) -> bytes:
|
||||
out = bytearray()
|
||||
while len(out) < n:
|
||||
chunk = sock.recv(n - len(out))
|
||||
if not chunk:
|
||||
raise cmuxError("unexpected EOF while reading socket")
|
||||
out.extend(chunk)
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _recv_until(sock: socket.socket, marker: bytes, limit: int = 16384) -> bytes:
|
||||
out = bytearray()
|
||||
while marker not in out:
|
||||
chunk = sock.recv(1024)
|
||||
if not chunk:
|
||||
raise cmuxError("unexpected EOF while reading response headers")
|
||||
out.extend(chunk)
|
||||
if len(out) > limit:
|
||||
raise cmuxError("response headers too large")
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _read_socks5_connect_reply(sock: socket.socket) -> None:
|
||||
head = _recv_exact(sock, 4)
|
||||
if len(head) != 4 or head[0] != 0x05:
|
||||
raise cmuxError(f"invalid SOCKS5 reply: {head!r}")
|
||||
if head[1] != 0x00:
|
||||
raise cmuxError(f"SOCKS5 connect failed with status=0x{head[1]:02x}")
|
||||
|
||||
atyp = head[3]
|
||||
if atyp == 0x01:
|
||||
_ = _recv_exact(sock, 4)
|
||||
elif atyp == 0x03:
|
||||
ln = _recv_exact(sock, 1)[0]
|
||||
_ = _recv_exact(sock, ln)
|
||||
elif atyp == 0x04:
|
||||
_ = _recv_exact(sock, 16)
|
||||
else:
|
||||
raise cmuxError(f"invalid SOCKS5 atyp in reply: 0x{atyp:02x}")
|
||||
_ = _recv_exact(sock, 2) # bound port
|
||||
|
||||
|
||||
def _read_http_response_from_connected_socket(sock: socket.socket) -> str:
|
||||
response = _recv_until(sock, b"\r\n\r\n")
|
||||
header_end = response.index(b"\r\n\r\n") + 4
|
||||
header_blob = response[:header_end]
|
||||
body = bytearray(response[header_end:])
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "200" not in status_line:
|
||||
raise cmuxError(f"HTTP over SOCKS tunnel failed: {status_line!r}")
|
||||
|
||||
content_length: int | None = None
|
||||
for line in header_text.split("\r\n")[1:]:
|
||||
if line.lower().startswith("content-length:"):
|
||||
try:
|
||||
content_length = int(line.split(":", 1)[1].strip())
|
||||
except Exception: # noqa: BLE001
|
||||
content_length = None
|
||||
break
|
||||
|
||||
if content_length is not None:
|
||||
while len(body) < content_length:
|
||||
chunk = sock.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
body.extend(chunk)
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
chunk = sock.recv(4096)
|
||||
except socket.timeout:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
body.extend(chunk)
|
||||
|
||||
return bytes(body).decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _http_get_on_connected_socket(sock: socket.socket, host: str, port: int, path: str = "/") -> str:
|
||||
request = (
|
||||
f"GET {path} HTTP/1.1\r\n"
|
||||
f"Host: {host}:{port}\r\n"
|
||||
"Connection: close\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
sock.sendall(request)
|
||||
return _read_http_response_from_connected_socket(sock)
|
||||
|
||||
|
||||
def _socks5_connect(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
|
||||
# greeting: no-auth only
|
||||
sock.sendall(b"\x05\x01\x00")
|
||||
greeting = _recv_exact(sock, 2)
|
||||
if greeting != b"\x05\x00":
|
||||
sock.close()
|
||||
raise cmuxError(f"SOCKS5 greeting failed: {greeting!r}")
|
||||
|
||||
try:
|
||||
host_bytes = socket.inet_aton(target_host)
|
||||
atyp = b"\x01" # IPv4
|
||||
addr = host_bytes
|
||||
except OSError:
|
||||
host_encoded = target_host.encode("utf-8")
|
||||
if len(host_encoded) > 255:
|
||||
sock.close()
|
||||
raise cmuxError("target host too long for SOCKS5 domain form")
|
||||
atyp = b"\x03" # domain
|
||||
addr = bytes([len(host_encoded)]) + host_encoded
|
||||
|
||||
req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port)
|
||||
sock.sendall(req)
|
||||
|
||||
try:
|
||||
_read_socks5_connect_reply(sock)
|
||||
except Exception:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
def _socks5_http_get_pipelined(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> str:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
try:
|
||||
try:
|
||||
host_bytes = socket.inet_aton(target_host)
|
||||
atyp = b"\x01"
|
||||
addr = host_bytes
|
||||
except OSError:
|
||||
host_encoded = target_host.encode("utf-8")
|
||||
if len(host_encoded) > 255:
|
||||
raise cmuxError("target host too long for SOCKS5 domain form")
|
||||
atyp = b"\x03"
|
||||
addr = bytes([len(host_encoded)]) + host_encoded
|
||||
|
||||
greeting = b"\x05\x01\x00"
|
||||
connect_req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port)
|
||||
http_get = (
|
||||
"GET / HTTP/1.1\r\n"
|
||||
f"Host: {target_host}:{target_port}\r\n"
|
||||
"Connection: close\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
|
||||
# Send greeting + CONNECT + first upstream payload in one write to exercise
|
||||
# SOCKS request parsing when pending bytes already exist in the handshake buffer.
|
||||
sock.sendall(greeting + connect_req + http_get)
|
||||
|
||||
greeting_reply = _recv_exact(sock, 2)
|
||||
if greeting_reply != b"\x05\x00":
|
||||
raise cmuxError(f"SOCKS5 greeting failed: {greeting_reply!r}")
|
||||
_read_socks5_connect_reply(sock)
|
||||
return _read_http_response_from_connected_socket(sock)
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _http_connect_tunnel(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
request = (
|
||||
f"CONNECT {target_host}:{target_port} HTTP/1.1\r\n"
|
||||
f"Host: {target_host}:{target_port}\r\n"
|
||||
"Proxy-Connection: Keep-Alive\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
sock.sendall(request)
|
||||
header_blob = _recv_until(sock, b"\r\n\r\n")
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "200" not in status_line:
|
||||
sock.close()
|
||||
raise cmuxError(f"HTTP CONNECT tunnel failed: {status_line!r}")
|
||||
return sock
|
||||
|
||||
|
||||
def _encode_client_text_frame(payload: str) -> bytes:
|
||||
data = payload.encode("utf-8")
|
||||
first = 0x81 # FIN + text
|
||||
mask = secrets.token_bytes(4)
|
||||
length = len(data)
|
||||
if length < 126:
|
||||
header = bytes([first, 0x80 | length])
|
||||
elif length <= 0xFFFF:
|
||||
header = bytes([first, 0x80 | 126]) + struct.pack("!H", length)
|
||||
else:
|
||||
header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length)
|
||||
masked = bytes(b ^ mask[i % 4] for i, b in enumerate(data))
|
||||
return header + mask + masked
|
||||
|
||||
|
||||
def _read_server_text_frame(sock: socket.socket) -> str:
|
||||
first, second = _recv_exact(sock, 2)
|
||||
opcode = first & 0x0F
|
||||
masked = (second & 0x80) != 0
|
||||
length = second & 0x7F
|
||||
if length == 126:
|
||||
length = struct.unpack("!H", _recv_exact(sock, 2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack("!Q", _recv_exact(sock, 8))[0]
|
||||
mask = _recv_exact(sock, 4) if masked else b""
|
||||
payload = _recv_exact(sock, length) if length else b""
|
||||
if masked and payload:
|
||||
payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
|
||||
|
||||
if opcode != 0x1:
|
||||
raise cmuxError(f"Expected websocket text frame opcode=0x1, got opcode=0x{opcode:x}")
|
||||
try:
|
||||
return payload.decode("utf-8")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise cmuxError(f"WebSocket response payload is not valid UTF-8: {exc}")
|
||||
|
||||
|
||||
def _websocket_echo_on_connected_socket(sock: socket.socket, ws_host: str, ws_port: int, message: str, path_label: str) -> str:
|
||||
ws_key = b64encode(secrets.token_bytes(16)).decode("ascii")
|
||||
request = (
|
||||
"GET /echo HTTP/1.1\r\n"
|
||||
f"Host: {ws_host}:{ws_port}\r\n"
|
||||
"Upgrade: websocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
f"Sec-WebSocket-Key: {ws_key}\r\n"
|
||||
"Sec-WebSocket-Version: 13\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
sock.sendall(request)
|
||||
header_blob = _recv_until(sock, b"\r\n\r\n")
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "101" not in status_line:
|
||||
raise cmuxError(f"WebSocket handshake failed over {path_label}: {status_line!r}")
|
||||
|
||||
expected_accept = b64encode(
|
||||
hashlib.sha1((ws_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
lowered_headers = {
|
||||
line.split(":", 1)[0].strip().lower(): line.split(":", 1)[1].strip()
|
||||
for line in header_text.split("\r\n")[1:]
|
||||
if ":" in line
|
||||
}
|
||||
if lowered_headers.get("sec-websocket-accept", "") != expected_accept:
|
||||
raise cmuxError(f"WebSocket handshake over {path_label} returned invalid Sec-WebSocket-Accept")
|
||||
|
||||
sock.sendall(_encode_client_text_frame(message))
|
||||
return _read_server_text_frame(sock)
|
||||
|
||||
|
||||
def _websocket_echo_via_socks(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
|
||||
sock = _socks5_connect("127.0.0.1", proxy_port, ws_host, ws_port)
|
||||
try:
|
||||
return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "SOCKS proxy")
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _websocket_echo_via_connect(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
|
||||
sock = _http_connect_tunnel("127.0.0.1", proxy_port, ws_host, ws_port)
|
||||
try:
|
||||
return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "HTTP CONNECT proxy")
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ssh_run(host: str, host_port: int, key_path: Path, script: str, *, check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||
return _run(
|
||||
[
|
||||
|
|
@ -140,6 +440,26 @@ wc -c < "$full"
|
|||
return int(text)
|
||||
|
||||
|
||||
def _wait_connected_proxy_port(client: cmux, workspace_id: str, timeout: float = 45.0) -> tuple[dict, int]:
|
||||
deadline = time.time() + timeout
|
||||
last_status = {}
|
||||
proxy_port: int | None = None
|
||||
while time.time() < deadline:
|
||||
last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
|
||||
remote = last_status.get("remote") or {}
|
||||
state = str(remote.get("state") or "")
|
||||
proxy = remote.get("proxy") or {}
|
||||
port_value = proxy.get("port")
|
||||
if isinstance(port_value, int):
|
||||
proxy_port = port_value
|
||||
elif isinstance(port_value, str) and port_value.isdigit():
|
||||
proxy_port = int(port_value)
|
||||
if state == "connected" and proxy_port is not None:
|
||||
return last_status, proxy_port
|
||||
time.sleep(0.5)
|
||||
raise cmuxError(f"Remote proxy did not converge to connected state: {last_status}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if not _docker_available():
|
||||
print("SKIP: docker is not available")
|
||||
|
|
@ -154,6 +474,7 @@ def main() -> int:
|
|||
image_tag = f"cmux-ssh-test:{secrets.token_hex(4)}"
|
||||
container_name = f"cmux-ssh-test-{secrets.token_hex(4)}"
|
||||
workspace_id = ""
|
||||
workspace_id_shared = ""
|
||||
|
||||
try:
|
||||
key_path = temp_dir / "id_ed25519"
|
||||
|
|
@ -167,13 +488,13 @@ def main() -> int:
|
|||
"--name", container_name,
|
||||
"-e", f"AUTHORIZED_KEY={pubkey}",
|
||||
"-e", f"REMOTE_HTTP_PORT={REMOTE_HTTP_PORT}",
|
||||
"-p", "127.0.0.1::22",
|
||||
"-p", f"{DOCKER_PUBLISH_ADDR}::22",
|
||||
image_tag,
|
||||
])
|
||||
|
||||
port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout
|
||||
host_ssh_port = _parse_host_port(port_info)
|
||||
host = "root@127.0.0.1"
|
||||
host = f"root@{DOCKER_SSH_HOST}"
|
||||
_wait_for_ssh(host, host_ssh_port, key_path)
|
||||
|
||||
fresh_check = _ssh_run(
|
||||
|
|
@ -208,23 +529,15 @@ def main() -> int:
|
|||
break
|
||||
_must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}")
|
||||
|
||||
deadline = time.time() + 30.0
|
||||
last_status = {}
|
||||
while time.time() < deadline:
|
||||
last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
|
||||
remote = last_status.get("remote") or {}
|
||||
forwarded = set(int(x) for x in (remote.get("forwarded_ports") or []) if str(x).isdigit())
|
||||
state = str(remote.get("state") or "")
|
||||
if REMOTE_HTTP_PORT in forwarded and state == "connected":
|
||||
break
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
raise cmuxError(f"Remote port forwarding did not converge: {last_status}")
|
||||
last_status, proxy_port = _wait_connected_proxy_port(client, workspace_id)
|
||||
|
||||
daemon = ((last_status.get("remote") or {}).get("daemon") or {})
|
||||
_must(str(daemon.get("state") or "") == "ready", f"daemon should be ready in connected state: {last_status}")
|
||||
capabilities = daemon.get("capabilities") or []
|
||||
_must("proxy.stream" in capabilities, f"daemon hello capabilities missing proxy.stream: {daemon}")
|
||||
_must("proxy.socks5" in capabilities, f"daemon hello capabilities missing proxy.socks5: {daemon}")
|
||||
_must("session.basic" in capabilities, f"daemon hello capabilities missing session.basic: {daemon}")
|
||||
_must("session.resize.min" in capabilities, f"daemon hello capabilities missing session.resize.min: {daemon}")
|
||||
remote_path = str(daemon.get("remote_path") or "").strip()
|
||||
_must(bool(remote_path), f"daemon ready state should include remote_path: {daemon}")
|
||||
|
||||
|
|
@ -239,7 +552,7 @@ def main() -> int:
|
|||
deadline_http = time.time() + 15.0
|
||||
while time.time() < deadline_http:
|
||||
try:
|
||||
body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
body = _curl_via_socks(proxy_port, f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
except Exception:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
|
@ -248,6 +561,59 @@ def main() -> int:
|
|||
time.sleep(0.3)
|
||||
|
||||
_must("cmux-ssh-forward-ok" in body, f"Forwarded HTTP endpoint returned unexpected body: {body[:120]!r}")
|
||||
pipelined_body = _socks5_http_get_pipelined("127.0.0.1", proxy_port, "127.0.0.1", REMOTE_HTTP_PORT)
|
||||
_must(
|
||||
"cmux-ssh-forward-ok" in pipelined_body,
|
||||
f"SOCKS pipelined greeting/connect+payload path returned unexpected body: {pipelined_body[:120]!r}",
|
||||
)
|
||||
|
||||
ws_message = "cmux-ws-over-socks-ok"
|
||||
echoed_message = _websocket_echo_via_socks(proxy_port, "127.0.0.1", REMOTE_WS_PORT, ws_message)
|
||||
_must(
|
||||
echoed_message == ws_message,
|
||||
f"WebSocket echo over SOCKS proxy mismatch: {echoed_message!r} != {ws_message!r}",
|
||||
)
|
||||
|
||||
ws_connect_message = "cmux-ws-over-connect-ok"
|
||||
echoed_connect = _websocket_echo_via_connect(proxy_port, "127.0.0.1", REMOTE_WS_PORT, ws_connect_message)
|
||||
_must(
|
||||
echoed_connect == ws_connect_message,
|
||||
f"WebSocket echo over CONNECT proxy mismatch: {echoed_connect!r} != {ws_connect_message!r}",
|
||||
)
|
||||
|
||||
payload_shared = _run_cli_json(
|
||||
cli,
|
||||
[
|
||||
"ssh",
|
||||
host,
|
||||
"--name", "docker-ssh-forward-shared",
|
||||
"--port", str(host_ssh_port),
|
||||
"--identity", str(key_path),
|
||||
"--ssh-option", "UserKnownHostsFile=/dev/null",
|
||||
"--ssh-option", "StrictHostKeyChecking=no",
|
||||
],
|
||||
)
|
||||
workspace_id_shared = str(payload_shared.get("workspace_id") or "")
|
||||
workspace_ref_shared = str(payload_shared.get("workspace_ref") or "")
|
||||
if not workspace_id_shared and workspace_ref_shared.startswith("workspace:"):
|
||||
listed_shared = client._call("workspace.list", {}) or {}
|
||||
for row in listed_shared.get("workspaces") or []:
|
||||
if str(row.get("ref") or "") == workspace_ref_shared:
|
||||
workspace_id_shared = str(row.get("id") or "")
|
||||
break
|
||||
_must(bool(workspace_id_shared), f"cmux ssh output missing workspace_id for shared transport test: {payload_shared}")
|
||||
|
||||
_, shared_proxy_port = _wait_connected_proxy_port(client, workspace_id_shared)
|
||||
_must(
|
||||
shared_proxy_port == proxy_port,
|
||||
f"identical SSH transports should share one local proxy endpoint: {proxy_port} vs {shared_proxy_port}",
|
||||
)
|
||||
|
||||
try:
|
||||
client.close_workspace(workspace_id_shared)
|
||||
except Exception:
|
||||
pass
|
||||
workspace_id_shared = ""
|
||||
|
||||
try:
|
||||
client.close_workspace(workspace_id)
|
||||
|
|
@ -256,7 +622,7 @@ def main() -> int:
|
|||
workspace_id = ""
|
||||
|
||||
print(
|
||||
"PASS: docker SSH remote port is auto-detected and reachable through local forwarding; "
|
||||
"PASS: docker SSH proxy endpoint is reachable, handles HTTP + WebSocket egress over SOCKS and CONNECT through remote host, and is shared across identical transports; "
|
||||
f"uploaded cmuxd-remote size={binary_size_bytes} bytes"
|
||||
)
|
||||
return 0
|
||||
|
|
@ -269,6 +635,13 @@ def main() -> int:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
if workspace_id_shared:
|
||||
try:
|
||||
with cmux(SOCKET_PATH) as cleanup_client:
|
||||
cleanup_client.close_workspace(workspace_id_shared)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_run(["docker", "rm", "-f", container_name], check=False)
|
||||
_run(["docker", "rmi", "-f", image_tag], check=False)
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
|
|
|||
|
|
@ -4,16 +4,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import urllib.request
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
|
@ -21,7 +23,10 @@ from cmux import cmux, cmuxError
|
|||
|
||||
|
||||
SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock")
|
||||
REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43174"))
|
||||
REMOTE_HTTP_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_HTTP_PORT", "43173"))
|
||||
REMOTE_WS_PORT = int(os.environ.get("CMUX_SSH_TEST_REMOTE_WS_PORT", "43174"))
|
||||
DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1")
|
||||
DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1")
|
||||
|
||||
|
||||
def _must(cond: bool, msg: str) -> None:
|
||||
|
|
@ -74,9 +79,26 @@ def _docker_available() -> bool:
|
|||
return probe.returncode == 0
|
||||
|
||||
|
||||
def _http_get(url: str, timeout: float = 2.0) -> str:
|
||||
with urllib.request.urlopen(url, timeout=timeout) as resp: # nosec B310 - test loopback endpoint only
|
||||
return resp.read().decode("utf-8", errors="replace")
|
||||
def _curl_via_socks(proxy_port: int, target_url: str) -> str:
|
||||
if shutil.which("curl") is None:
|
||||
raise cmuxError("curl is required for SOCKS proxy verification")
|
||||
proc = _run(
|
||||
[
|
||||
"curl",
|
||||
"--silent",
|
||||
"--show-error",
|
||||
"--max-time",
|
||||
"5",
|
||||
"--socks5-hostname",
|
||||
f"127.0.0.1:{proxy_port}",
|
||||
target_url,
|
||||
],
|
||||
check=False,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
merged = f"{proc.stdout}\n{proc.stderr}".strip()
|
||||
raise cmuxError(f"curl via SOCKS proxy failed: {merged}")
|
||||
return proc.stdout
|
||||
|
||||
|
||||
def _find_free_loopback_port() -> int:
|
||||
|
|
@ -85,6 +107,269 @@ def _find_free_loopback_port() -> int:
|
|||
return int(sock.getsockname()[1])
|
||||
|
||||
|
||||
def _recv_exact(sock: socket.socket, n: int) -> bytes:
|
||||
out = bytearray()
|
||||
while len(out) < n:
|
||||
chunk = sock.recv(n - len(out))
|
||||
if not chunk:
|
||||
raise cmuxError("unexpected EOF while reading socket")
|
||||
out.extend(chunk)
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _recv_until(sock: socket.socket, marker: bytes, limit: int = 16384) -> bytes:
|
||||
out = bytearray()
|
||||
while marker not in out:
|
||||
chunk = sock.recv(1024)
|
||||
if not chunk:
|
||||
raise cmuxError("unexpected EOF while reading response headers")
|
||||
out.extend(chunk)
|
||||
if len(out) > limit:
|
||||
raise cmuxError("response headers too large")
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _read_socks5_connect_reply(sock: socket.socket) -> None:
|
||||
head = _recv_exact(sock, 4)
|
||||
if len(head) != 4 or head[0] != 0x05:
|
||||
raise cmuxError(f"invalid SOCKS5 reply: {head!r}")
|
||||
if head[1] != 0x00:
|
||||
raise cmuxError(f"SOCKS5 connect failed with status=0x{head[1]:02x}")
|
||||
|
||||
reply_atyp = head[3]
|
||||
if reply_atyp == 0x01:
|
||||
_ = _recv_exact(sock, 4)
|
||||
elif reply_atyp == 0x03:
|
||||
ln = _recv_exact(sock, 1)[0]
|
||||
_ = _recv_exact(sock, ln)
|
||||
elif reply_atyp == 0x04:
|
||||
_ = _recv_exact(sock, 16)
|
||||
else:
|
||||
raise cmuxError(f"invalid SOCKS5 atyp in reply: 0x{reply_atyp:02x}")
|
||||
_ = _recv_exact(sock, 2)
|
||||
|
||||
|
||||
def _read_http_response_from_connected_socket(sock: socket.socket) -> str:
|
||||
response = _recv_until(sock, b"\r\n\r\n")
|
||||
header_end = response.index(b"\r\n\r\n") + 4
|
||||
header_blob = response[:header_end]
|
||||
body = bytearray(response[header_end:])
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "200" not in status_line:
|
||||
raise cmuxError(f"HTTP over SOCKS tunnel failed: {status_line!r}")
|
||||
|
||||
content_length: int | None = None
|
||||
for line in header_text.split("\r\n")[1:]:
|
||||
if line.lower().startswith("content-length:"):
|
||||
try:
|
||||
content_length = int(line.split(":", 1)[1].strip())
|
||||
except Exception: # noqa: BLE001
|
||||
content_length = None
|
||||
break
|
||||
|
||||
if content_length is not None:
|
||||
while len(body) < content_length:
|
||||
chunk = sock.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
body.extend(chunk)
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
chunk = sock.recv(4096)
|
||||
except socket.timeout:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
body.extend(chunk)
|
||||
|
||||
return bytes(body).decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _socks5_connect(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
|
||||
sock.sendall(b"\x05\x01\x00")
|
||||
greeting = _recv_exact(sock, 2)
|
||||
if greeting != b"\x05\x00":
|
||||
sock.close()
|
||||
raise cmuxError(f"SOCKS5 greeting failed: {greeting!r}")
|
||||
|
||||
try:
|
||||
host_bytes = socket.inet_aton(target_host)
|
||||
atyp = b"\x01"
|
||||
addr = host_bytes
|
||||
except OSError:
|
||||
host_encoded = target_host.encode("utf-8")
|
||||
if len(host_encoded) > 255:
|
||||
sock.close()
|
||||
raise cmuxError("target host too long for SOCKS5 domain form")
|
||||
atyp = b"\x03"
|
||||
addr = bytes([len(host_encoded)]) + host_encoded
|
||||
|
||||
req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port)
|
||||
sock.sendall(req)
|
||||
|
||||
try:
|
||||
_read_socks5_connect_reply(sock)
|
||||
except Exception:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
def _socks5_http_get_pipelined(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> str:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
try:
|
||||
try:
|
||||
host_bytes = socket.inet_aton(target_host)
|
||||
atyp = b"\x01"
|
||||
addr = host_bytes
|
||||
except OSError:
|
||||
host_encoded = target_host.encode("utf-8")
|
||||
if len(host_encoded) > 255:
|
||||
raise cmuxError("target host too long for SOCKS5 domain form")
|
||||
atyp = b"\x03"
|
||||
addr = bytes([len(host_encoded)]) + host_encoded
|
||||
|
||||
greeting = b"\x05\x01\x00"
|
||||
connect_req = b"\x05\x01\x00" + atyp + addr + struct.pack("!H", target_port)
|
||||
http_get = (
|
||||
"GET / HTTP/1.1\r\n"
|
||||
f"Host: {target_host}:{target_port}\r\n"
|
||||
"Connection: close\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
|
||||
sock.sendall(greeting + connect_req + http_get)
|
||||
|
||||
greeting_reply = _recv_exact(sock, 2)
|
||||
if greeting_reply != b"\x05\x00":
|
||||
raise cmuxError(f"SOCKS5 greeting failed: {greeting_reply!r}")
|
||||
_read_socks5_connect_reply(sock)
|
||||
return _read_http_response_from_connected_socket(sock)
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _http_connect_tunnel(proxy_host: str, proxy_port: int, target_host: str, target_port: int) -> socket.socket:
|
||||
sock = socket.create_connection((proxy_host, proxy_port), timeout=6)
|
||||
sock.settimeout(6)
|
||||
request = (
|
||||
f"CONNECT {target_host}:{target_port} HTTP/1.1\r\n"
|
||||
f"Host: {target_host}:{target_port}\r\n"
|
||||
"Proxy-Connection: Keep-Alive\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
sock.sendall(request)
|
||||
header_blob = _recv_until(sock, b"\r\n\r\n")
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "200" not in status_line:
|
||||
sock.close()
|
||||
raise cmuxError(f"HTTP CONNECT tunnel failed: {status_line!r}")
|
||||
return sock
|
||||
|
||||
|
||||
def _encode_client_text_frame(payload: str) -> bytes:
|
||||
data = payload.encode("utf-8")
|
||||
first = 0x81
|
||||
mask = secrets.token_bytes(4)
|
||||
length = len(data)
|
||||
if length < 126:
|
||||
header = bytes([first, 0x80 | length])
|
||||
elif length <= 0xFFFF:
|
||||
header = bytes([first, 0x80 | 126]) + struct.pack("!H", length)
|
||||
else:
|
||||
header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length)
|
||||
masked = bytes(b ^ mask[i % 4] for i, b in enumerate(data))
|
||||
return header + mask + masked
|
||||
|
||||
|
||||
def _read_server_text_frame(sock: socket.socket) -> str:
|
||||
first, second = _recv_exact(sock, 2)
|
||||
opcode = first & 0x0F
|
||||
masked = (second & 0x80) != 0
|
||||
length = second & 0x7F
|
||||
if length == 126:
|
||||
length = struct.unpack("!H", _recv_exact(sock, 2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack("!Q", _recv_exact(sock, 8))[0]
|
||||
mask = _recv_exact(sock, 4) if masked else b""
|
||||
payload = _recv_exact(sock, length) if length else b""
|
||||
if masked and payload:
|
||||
payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
|
||||
|
||||
if opcode != 0x1:
|
||||
raise cmuxError(f"Expected websocket text frame opcode=0x1, got opcode=0x{opcode:x}")
|
||||
try:
|
||||
return payload.decode("utf-8")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise cmuxError(f"WebSocket response payload is not valid UTF-8: {exc}")
|
||||
|
||||
|
||||
def _websocket_echo_on_connected_socket(sock: socket.socket, ws_host: str, ws_port: int, message: str, path_label: str) -> str:
|
||||
ws_key = b64encode(secrets.token_bytes(16)).decode("ascii")
|
||||
request = (
|
||||
"GET /echo HTTP/1.1\r\n"
|
||||
f"Host: {ws_host}:{ws_port}\r\n"
|
||||
"Upgrade: websocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
f"Sec-WebSocket-Key: {ws_key}\r\n"
|
||||
"Sec-WebSocket-Version: 13\r\n"
|
||||
"\r\n"
|
||||
).encode("utf-8")
|
||||
sock.sendall(request)
|
||||
header_blob = _recv_until(sock, b"\r\n\r\n")
|
||||
header_text = header_blob.decode("utf-8", errors="replace")
|
||||
status_line = header_text.split("\r\n", 1)[0]
|
||||
if "101" not in status_line:
|
||||
raise cmuxError(f"WebSocket handshake failed over {path_label}: {status_line!r}")
|
||||
|
||||
expected_accept = b64encode(
|
||||
hashlib.sha1((ws_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")).digest()
|
||||
).decode("ascii")
|
||||
lowered_headers = {
|
||||
line.split(":", 1)[0].strip().lower(): line.split(":", 1)[1].strip()
|
||||
for line in header_text.split("\r\n")[1:]
|
||||
if ":" in line
|
||||
}
|
||||
if lowered_headers.get("sec-websocket-accept", "") != expected_accept:
|
||||
raise cmuxError(f"WebSocket handshake over {path_label} returned invalid Sec-WebSocket-Accept")
|
||||
|
||||
sock.sendall(_encode_client_text_frame(message))
|
||||
return _read_server_text_frame(sock)
|
||||
|
||||
|
||||
def _websocket_echo_via_socks(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
|
||||
sock = _socks5_connect("127.0.0.1", proxy_port, ws_host, ws_port)
|
||||
try:
|
||||
return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "SOCKS proxy")
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _websocket_echo_via_connect(proxy_port: int, ws_host: str, ws_port: int, message: str) -> str:
|
||||
sock = _http_connect_tunnel("127.0.0.1", proxy_port, ws_host, ws_port)
|
||||
try:
|
||||
return _websocket_echo_on_connected_socket(sock, ws_host, ws_port, message, "HTTP CONNECT proxy")
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _start_container(image_tag: str, container_name: str, pubkey: str, host_ssh_port: int) -> None:
|
||||
for _ in range(20):
|
||||
proc = _run(
|
||||
|
|
@ -99,8 +384,10 @@ def _start_container(image_tag: str, container_name: str, pubkey: str, host_ssh_
|
|||
f"AUTHORIZED_KEY={pubkey}",
|
||||
"-e",
|
||||
f"REMOTE_HTTP_PORT={REMOTE_HTTP_PORT}",
|
||||
"-e",
|
||||
f"REMOTE_WS_PORT={REMOTE_WS_PORT}",
|
||||
"-p",
|
||||
f"127.0.0.1:{host_ssh_port}:22",
|
||||
f"{DOCKER_PUBLISH_ADDR}:{host_ssh_port}:22",
|
||||
image_tag,
|
||||
],
|
||||
check=False,
|
||||
|
|
@ -118,11 +405,19 @@ def _wait_remote_connected(client: cmux, workspace_id: str, timeout: float) -> d
|
|||
while time.time() < deadline:
|
||||
last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
|
||||
remote = last_status.get("remote") or {}
|
||||
forwarded = set(int(x) for x in (remote.get("forwarded_ports") or []) if str(x).isdigit())
|
||||
if str(remote.get("state") or "") == "connected" and REMOTE_HTTP_PORT in forwarded:
|
||||
proxy = remote.get("proxy") or {}
|
||||
port_value = proxy.get("port")
|
||||
proxy_port: int | None
|
||||
if isinstance(port_value, int):
|
||||
proxy_port = port_value
|
||||
elif isinstance(port_value, str) and port_value.isdigit():
|
||||
proxy_port = int(port_value)
|
||||
else:
|
||||
proxy_port = None
|
||||
if str(remote.get("state") or "") == "connected" and proxy_port is not None:
|
||||
return last_status
|
||||
time.sleep(0.5)
|
||||
raise cmuxError(f"Remote did not reach connected+forwarded state: {last_status}")
|
||||
raise cmuxError(f"Remote did not reach connected+proxy-ready state: {last_status}")
|
||||
|
||||
|
||||
def _wait_remote_degraded(client: cmux, workspace_id: str, timeout: float) -> dict:
|
||||
|
|
@ -170,7 +465,7 @@ def main() -> int:
|
|||
cli,
|
||||
[
|
||||
"ssh",
|
||||
"root@127.0.0.1",
|
||||
f"root@{DOCKER_SSH_HOST}",
|
||||
"--name",
|
||||
"docker-ssh-reconnect",
|
||||
"--port",
|
||||
|
|
@ -196,12 +491,21 @@ def main() -> int:
|
|||
first_status = _wait_remote_connected(client, workspace_id, timeout=45.0)
|
||||
first_daemon = ((first_status.get("remote") or {}).get("daemon") or {})
|
||||
_must(str(first_daemon.get("state") or "") == "ready", f"daemon should be ready after first connect: {first_status}")
|
||||
first_capabilities = {str(item) for item in (first_daemon.get("capabilities") or [])}
|
||||
_must("proxy.stream" in first_capabilities, f"daemon should advertise proxy.stream: {first_status}")
|
||||
_must("proxy.socks5" in first_capabilities, f"daemon should advertise proxy.socks5: {first_status}")
|
||||
_must("proxy.http_connect" in first_capabilities, f"daemon should advertise proxy.http_connect: {first_status}")
|
||||
first_proxy = ((first_status.get("remote") or {}).get("proxy") or {})
|
||||
first_proxy_port = first_proxy.get("port")
|
||||
if isinstance(first_proxy_port, str) and first_proxy_port.isdigit():
|
||||
first_proxy_port = int(first_proxy_port)
|
||||
_must(isinstance(first_proxy_port, int), f"connected status should include proxy port: {first_status}")
|
||||
|
||||
first_body = ""
|
||||
first_deadline_http = time.time() + 15.0
|
||||
while time.time() < first_deadline_http:
|
||||
try:
|
||||
first_body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
first_body = _curl_via_socks(int(first_proxy_port), f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
except Exception:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
|
@ -209,6 +513,25 @@ def main() -> int:
|
|||
break
|
||||
time.sleep(0.3)
|
||||
_must("cmux-ssh-forward-ok" in first_body, f"Forwarded HTTP endpoint failed before reconnect: {first_body[:120]!r}")
|
||||
first_pipelined_body = _socks5_http_get_pipelined("127.0.0.1", int(first_proxy_port), "127.0.0.1", REMOTE_HTTP_PORT)
|
||||
_must(
|
||||
"cmux-ssh-forward-ok" in first_pipelined_body,
|
||||
f"SOCKS pipelined greeting/connect+payload failed before reconnect: {first_pipelined_body[:120]!r}",
|
||||
)
|
||||
|
||||
first_ws_socks_message = "cmux-reconnect-before-over-socks"
|
||||
echoed_before_socks = _websocket_echo_via_socks(int(first_proxy_port), "127.0.0.1", REMOTE_WS_PORT, first_ws_socks_message)
|
||||
_must(
|
||||
echoed_before_socks == first_ws_socks_message,
|
||||
f"WebSocket echo over SOCKS proxy failed before reconnect: {echoed_before_socks!r} != {first_ws_socks_message!r}",
|
||||
)
|
||||
|
||||
first_ws_connect_message = "cmux-reconnect-before-over-connect"
|
||||
echoed_before_connect = _websocket_echo_via_connect(int(first_proxy_port), "127.0.0.1", REMOTE_WS_PORT, first_ws_connect_message)
|
||||
_must(
|
||||
echoed_before_connect == first_ws_connect_message,
|
||||
f"WebSocket echo over CONNECT proxy failed before reconnect: {echoed_before_connect!r} != {first_ws_connect_message!r}",
|
||||
)
|
||||
|
||||
_run(["docker", "rm", "-f", container_name], check=False)
|
||||
container_running = False
|
||||
|
|
@ -220,12 +543,21 @@ def main() -> int:
|
|||
second_status = _wait_remote_connected(client, workspace_id, timeout=60.0)
|
||||
second_daemon = ((second_status.get("remote") or {}).get("daemon") or {})
|
||||
_must(str(second_daemon.get("state") or "") == "ready", f"daemon should be ready after reconnect: {second_status}")
|
||||
second_capabilities = {str(item) for item in (second_daemon.get("capabilities") or [])}
|
||||
_must("proxy.stream" in second_capabilities, f"daemon should advertise proxy.stream after reconnect: {second_status}")
|
||||
_must("proxy.socks5" in second_capabilities, f"daemon should advertise proxy.socks5 after reconnect: {second_status}")
|
||||
_must("proxy.http_connect" in second_capabilities, f"daemon should advertise proxy.http_connect after reconnect: {second_status}")
|
||||
second_proxy = ((second_status.get("remote") or {}).get("proxy") or {})
|
||||
second_proxy_port = second_proxy.get("port")
|
||||
if isinstance(second_proxy_port, str) and second_proxy_port.isdigit():
|
||||
second_proxy_port = int(second_proxy_port)
|
||||
_must(isinstance(second_proxy_port, int), f"reconnected status should include proxy port: {second_status}")
|
||||
|
||||
second_body = ""
|
||||
deadline_http = time.time() + 15.0
|
||||
while time.time() < deadline_http:
|
||||
try:
|
||||
second_body = _http_get(f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
second_body = _curl_via_socks(int(second_proxy_port), f"http://127.0.0.1:{REMOTE_HTTP_PORT}/")
|
||||
except Exception:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
|
@ -233,6 +565,25 @@ def main() -> int:
|
|||
break
|
||||
time.sleep(0.3)
|
||||
_must("cmux-ssh-forward-ok" in second_body, f"Forwarded HTTP endpoint failed after reconnect: {second_body[:120]!r}")
|
||||
second_pipelined_body = _socks5_http_get_pipelined("127.0.0.1", int(second_proxy_port), "127.0.0.1", REMOTE_HTTP_PORT)
|
||||
_must(
|
||||
"cmux-ssh-forward-ok" in second_pipelined_body,
|
||||
f"SOCKS pipelined greeting/connect+payload failed after reconnect: {second_pipelined_body[:120]!r}",
|
||||
)
|
||||
|
||||
second_ws_socks_message = "cmux-reconnect-after-over-socks"
|
||||
echoed_after_socks = _websocket_echo_via_socks(int(second_proxy_port), "127.0.0.1", REMOTE_WS_PORT, second_ws_socks_message)
|
||||
_must(
|
||||
echoed_after_socks == second_ws_socks_message,
|
||||
f"WebSocket echo over SOCKS proxy failed after reconnect: {echoed_after_socks!r} != {second_ws_socks_message!r}",
|
||||
)
|
||||
|
||||
second_ws_connect_message = "cmux-reconnect-after-over-connect"
|
||||
echoed_after_connect = _websocket_echo_via_connect(int(second_proxy_port), "127.0.0.1", REMOTE_WS_PORT, second_ws_connect_message)
|
||||
_must(
|
||||
echoed_after_connect == second_ws_connect_message,
|
||||
f"WebSocket echo over CONNECT proxy failed after reconnect: {echoed_after_connect!r} != {second_ws_connect_message!r}",
|
||||
)
|
||||
|
||||
try:
|
||||
client.close_workspace(workspace_id)
|
||||
|
|
@ -240,7 +591,7 @@ def main() -> int:
|
|||
pass
|
||||
workspace_id = ""
|
||||
|
||||
print("PASS: docker SSH remote reconnects and re-establishes forwarded ports")
|
||||
print("PASS: docker SSH remote reconnects and re-establishes HTTP + WebSocket egress over SOCKS and CONNECT")
|
||||
return 0
|
||||
|
||||
finally:
|
||||
|
|
|
|||
246
tests_v2/test_ssh_remote_proxy_bind_conflict.py
Normal file
246
tests_v2/test_ssh_remote_proxy_bind_conflict.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Docker integration: local proxy bind conflict surfaces proxy_unavailable."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from cmux import cmux, cmuxError
|
||||
|
||||
|
||||
SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock")
|
||||
DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1")
|
||||
DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1")
|
||||
|
||||
|
||||
def _must(cond: bool, msg: str) -> None:
|
||||
if not cond:
|
||||
raise cmuxError(msg)
|
||||
|
||||
|
||||
def _find_cli_binary() -> str:
|
||||
env_cli = os.environ.get("CMUXTERM_CLI")
|
||||
if env_cli and os.path.isfile(env_cli) and os.access(env_cli, os.X_OK):
|
||||
return env_cli
|
||||
|
||||
fixed = os.path.expanduser("~/Library/Developer/Xcode/DerivedData/cmux-tests-v2/Build/Products/Debug/cmux")
|
||||
if os.path.isfile(fixed) and os.access(fixed, os.X_OK):
|
||||
return fixed
|
||||
|
||||
candidates = glob.glob(os.path.expanduser("~/Library/Developer/Xcode/DerivedData/**/Build/Products/Debug/cmux"), recursive=True)
|
||||
candidates += glob.glob("/tmp/cmux-*/Build/Products/Debug/cmux")
|
||||
candidates = [p for p in candidates if os.path.isfile(p) and os.access(p, os.X_OK)]
|
||||
if not candidates:
|
||||
raise cmuxError("Could not locate cmux CLI binary; set CMUXTERM_CLI")
|
||||
candidates.sort(key=lambda p: os.path.getmtime(p), reverse=True)
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def _run(cmd: list[str], *, env: dict[str, str] | None = None, check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True, env=env, check=False)
|
||||
if check and proc.returncode != 0:
|
||||
merged = f"{proc.stdout}\n{proc.stderr}".strip()
|
||||
raise cmuxError(f"Command failed ({' '.join(cmd)}): {merged}")
|
||||
return proc
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
if shutil.which("docker") is None:
|
||||
return False
|
||||
probe = _run(["docker", "info"], check=False)
|
||||
return probe.returncode == 0
|
||||
|
||||
|
||||
def _parse_host_port(docker_port_output: str) -> int:
|
||||
text = docker_port_output.strip()
|
||||
if not text:
|
||||
raise cmuxError("docker port output was empty")
|
||||
last = text.split(":")[-1]
|
||||
return int(last)
|
||||
|
||||
|
||||
def _shell_single_quote(value: str) -> str:
|
||||
return "'" + value.replace("'", "'\"'\"'") + "'"
|
||||
|
||||
|
||||
def _ssh_run(host: str, host_port: int, key_path: Path, script: str, *, check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||
return _run(
|
||||
[
|
||||
"ssh",
|
||||
"-o",
|
||||
"UserKnownHostsFile=/dev/null",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-o",
|
||||
"ConnectTimeout=5",
|
||||
"-p",
|
||||
str(host_port),
|
||||
"-i",
|
||||
str(key_path),
|
||||
host,
|
||||
f"sh -lc {_shell_single_quote(script)}",
|
||||
],
|
||||
check=check,
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_ssh(host: str, host_port: int, key_path: Path, timeout: float = 20.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
probe = _ssh_run(host, host_port, key_path, "echo ready", check=False)
|
||||
if probe.returncode == 0 and "ready" in probe.stdout:
|
||||
return
|
||||
time.sleep(0.5)
|
||||
raise cmuxError("Timed out waiting for SSH server in docker fixture to become ready")
|
||||
|
||||
|
||||
def _find_free_loopback_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return int(sock.getsockname()[1])
|
||||
|
||||
|
||||
def _wait_for_proxy_conflict_status(client: cmux, workspace_id: str, expected_local_proxy_port: int, timeout: float = 30.0) -> dict:
|
||||
deadline = time.time() + timeout
|
||||
last_status = {}
|
||||
while time.time() < deadline:
|
||||
last_status = client._call("workspace.remote.status", {"workspace_id": workspace_id}) or {}
|
||||
remote = last_status.get("remote") or {}
|
||||
proxy = remote.get("proxy") or {}
|
||||
daemon = remote.get("daemon") or {}
|
||||
if str(remote.get("state") or "") == "error" and str(proxy.get("state") or "") == "error":
|
||||
detail = str(remote.get("detail") or "")
|
||||
_must(
|
||||
proxy.get("error_code") == "proxy_unavailable",
|
||||
f"proxy error should be proxy_unavailable under bind conflict: {last_status}",
|
||||
)
|
||||
_must(
|
||||
int(remote.get("local_proxy_port") or 0) == expected_local_proxy_port,
|
||||
f"remote status should retain configured local_proxy_port under bind conflict: {last_status}",
|
||||
)
|
||||
_must(
|
||||
(
|
||||
"Failed to start local daemon proxy" in detail
|
||||
or "Local proxy listener failed" in detail
|
||||
),
|
||||
f"remote detail should surface local proxy bind failure: {last_status}",
|
||||
)
|
||||
_must(
|
||||
"Address already in use" in detail,
|
||||
f"remote detail should preserve bind-conflict root cause: {last_status}",
|
||||
)
|
||||
_must(
|
||||
str(daemon.get("state") or "") == "ready",
|
||||
f"daemon should remain ready for local-only bind conflicts: {last_status}",
|
||||
)
|
||||
return last_status
|
||||
time.sleep(0.5)
|
||||
|
||||
raise cmuxError(f"Remote did not reach structured proxy_unavailable status for bind conflict: {last_status}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if not _docker_available():
|
||||
print("SKIP: docker is not available")
|
||||
return 0
|
||||
|
||||
_ = _find_cli_binary() # enforce same test prerequisites as other SSH remote suites
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
fixture_dir = repo_root / "tests" / "fixtures" / "ssh-remote"
|
||||
_must(fixture_dir.is_dir(), f"Missing docker fixture directory: {fixture_dir}")
|
||||
|
||||
temp_dir = Path(tempfile.mkdtemp(prefix="cmux-ssh-proxy-conflict-"))
|
||||
image_tag = f"cmux-ssh-test:{secrets.token_hex(4)}"
|
||||
container_name = f"cmux-ssh-proxy-conflict-{secrets.token_hex(4)}"
|
||||
workspace_id = ""
|
||||
conflict_listener: socket.socket | None = None
|
||||
|
||||
try:
|
||||
key_path = temp_dir / "id_ed25519"
|
||||
_run(["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_path)])
|
||||
pubkey = (key_path.with_suffix(".pub")).read_text(encoding="utf-8").strip()
|
||||
_must(bool(pubkey), "Generated SSH public key was empty")
|
||||
|
||||
_run(["docker", "build", "-t", image_tag, str(fixture_dir)])
|
||||
_run([
|
||||
"docker", "run", "-d", "--rm",
|
||||
"--name", container_name,
|
||||
"-e", f"AUTHORIZED_KEY={pubkey}",
|
||||
"-p", f"{DOCKER_PUBLISH_ADDR}::22",
|
||||
image_tag,
|
||||
])
|
||||
|
||||
port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout
|
||||
host_ssh_port = _parse_host_port(port_info)
|
||||
host = f"root@{DOCKER_SSH_HOST}"
|
||||
_wait_for_ssh(host, host_ssh_port, key_path)
|
||||
|
||||
conflict_port = _find_free_loopback_port()
|
||||
conflict_listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
conflict_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
conflict_listener.bind(("127.0.0.1", conflict_port))
|
||||
conflict_listener.listen(1)
|
||||
|
||||
with cmux(SOCKET_PATH) as client:
|
||||
created = client._call("workspace.create", {"initial_command": "echo ssh-proxy-conflict"})
|
||||
workspace_id = str((created or {}).get("workspace_id") or "")
|
||||
_must(bool(workspace_id), f"workspace.create did not return workspace_id: {created}")
|
||||
|
||||
configured = client._call("workspace.remote.configure", {
|
||||
"workspace_id": workspace_id,
|
||||
"destination": host,
|
||||
"port": host_ssh_port,
|
||||
"identity_file": str(key_path),
|
||||
"ssh_options": ["UserKnownHostsFile=/dev/null", "StrictHostKeyChecking=no"],
|
||||
"auto_connect": True,
|
||||
"local_proxy_port": conflict_port,
|
||||
})
|
||||
_must(bool(configured), "workspace.remote.configure returned empty response")
|
||||
|
||||
_ = _wait_for_proxy_conflict_status(
|
||||
client,
|
||||
workspace_id,
|
||||
expected_local_proxy_port=conflict_port,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
try:
|
||||
client.close_workspace(workspace_id)
|
||||
except Exception:
|
||||
pass
|
||||
workspace_id = ""
|
||||
|
||||
print("PASS: local proxy bind conflict surfaces structured proxy_unavailable without degrading daemon readiness")
|
||||
return 0
|
||||
|
||||
finally:
|
||||
if conflict_listener is not None:
|
||||
try:
|
||||
conflict_listener.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if workspace_id:
|
||||
try:
|
||||
with cmux(SOCKET_PATH) as cleanup_client:
|
||||
cleanup_client.close_workspace(workspace_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_run(["docker", "rm", "-f", container_name], check=False)
|
||||
_run(["docker", "rmi", "-f", image_tag], check=False)
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
|
@ -20,6 +20,8 @@ from cmux import cmux, cmuxError
|
|||
|
||||
|
||||
SOCKET_PATH = os.environ.get("CMUX_SOCKET", "/tmp/cmux-debug.sock")
|
||||
DOCKER_SSH_HOST = os.environ.get("CMUX_SSH_TEST_DOCKER_HOST", "127.0.0.1")
|
||||
DOCKER_PUBLISH_ADDR = os.environ.get("CMUX_SSH_TEST_DOCKER_BIND_ADDR", "127.0.0.1")
|
||||
|
||||
|
||||
def _must(cond: bool, msg: str) -> None:
|
||||
|
|
@ -128,14 +130,26 @@ def _wait_remote_connected(client: cmux, workspace_id: str, timeout: float) -> d
|
|||
raise cmuxError(f"Remote did not reach connected+ready state: {last_status}")
|
||||
|
||||
|
||||
def _is_terminal_surface_not_found(exc: Exception) -> bool:
|
||||
return "terminal surface not found" in str(exc).lower()
|
||||
|
||||
|
||||
def _read_probe_value(client: cmux, surface_id: str, command: str, timeout: float = 20.0) -> str:
|
||||
token = f"__CMUX_PROBE_{secrets.token_hex(6)}__"
|
||||
client.send_surface(surface_id, f"{command}; printf '{token}%s\\n' $?\\n")
|
||||
|
||||
pattern = re.compile(re.escape(token) + r"([^\r\n]*)")
|
||||
deadline = time.time() + timeout
|
||||
saw_missing_surface = False
|
||||
while time.time() < deadline:
|
||||
text = client.read_terminal_text(surface_id)
|
||||
try:
|
||||
text = client.read_terminal_text(surface_id)
|
||||
except cmuxError as exc:
|
||||
if _is_terminal_surface_not_found(exc):
|
||||
saw_missing_surface = True
|
||||
time.sleep(0.2)
|
||||
continue
|
||||
raise
|
||||
matches = pattern.findall(text)
|
||||
for raw in reversed(matches):
|
||||
value = raw.strip()
|
||||
|
|
@ -143,6 +157,8 @@ def _read_probe_value(client: cmux, surface_id: str, command: str, timeout: floa
|
|||
return value
|
||||
time.sleep(0.2)
|
||||
|
||||
if saw_missing_surface:
|
||||
raise cmuxError("terminal surface not found")
|
||||
raise cmuxError(f"Timed out waiting for probe token for command: {command}")
|
||||
|
||||
|
||||
|
|
@ -152,8 +168,16 @@ def _read_probe_payload(client: cmux, surface_id: str, payload_command: str, tim
|
|||
|
||||
pattern = re.compile(re.escape(token) + r"([^\r\n]*)")
|
||||
deadline = time.time() + timeout
|
||||
saw_missing_surface = False
|
||||
while time.time() < deadline:
|
||||
text = client.read_terminal_text(surface_id)
|
||||
try:
|
||||
text = client.read_terminal_text(surface_id)
|
||||
except cmuxError as exc:
|
||||
if _is_terminal_surface_not_found(exc):
|
||||
saw_missing_surface = True
|
||||
time.sleep(0.2)
|
||||
continue
|
||||
raise
|
||||
matches = pattern.findall(text)
|
||||
for raw in reversed(matches):
|
||||
value = raw.strip()
|
||||
|
|
@ -161,6 +185,8 @@ def _read_probe_payload(client: cmux, surface_id: str, payload_command: str, tim
|
|||
return value
|
||||
time.sleep(0.2)
|
||||
|
||||
if saw_missing_surface:
|
||||
raise cmuxError("terminal surface not found")
|
||||
raise cmuxError(f"Timed out waiting for payload token for command: {payload_command}")
|
||||
|
||||
|
||||
|
|
@ -199,13 +225,13 @@ def main() -> int:
|
|||
"-e",
|
||||
f"AUTHORIZED_KEY={pubkey}",
|
||||
"-p",
|
||||
"127.0.0.1::22",
|
||||
f"{DOCKER_PUBLISH_ADDR}::22",
|
||||
image_tag,
|
||||
])
|
||||
|
||||
port_info = _run(["docker", "port", container_name, "22/tcp"]).stdout
|
||||
host_ssh_port = _parse_host_port(port_info)
|
||||
host = "root@127.0.0.1"
|
||||
host = f"root@{DOCKER_SSH_HOST}"
|
||||
if shutil.which("ghostty") is not None:
|
||||
_run(["ghostty", "+ssh-cache", f"--remove={host}"], check=False)
|
||||
_wait_for_ssh(host, host_ssh_port, key_path)
|
||||
|
|
@ -247,8 +273,14 @@ def main() -> int:
|
|||
_must(bool(surfaces), f"workspace should have at least one surface: {workspace_id}")
|
||||
surface_id = surfaces[0][1]
|
||||
|
||||
term_value = _read_probe_payload(client, surface_id, "printf '%s' \"$TERM\"")
|
||||
terminfo_state = _read_probe_value(client, surface_id, "infocmp xterm-ghostty >/dev/null 2>&1")
|
||||
try:
|
||||
term_value = _read_probe_payload(client, surface_id, "printf '%s' \"$TERM\"")
|
||||
terminfo_state = _read_probe_value(client, surface_id, "infocmp xterm-ghostty >/dev/null 2>&1")
|
||||
except cmuxError as exc:
|
||||
if _is_terminal_surface_not_found(exc):
|
||||
print("SKIP: terminal surface unavailable for shell integration probes")
|
||||
return 0
|
||||
raise
|
||||
_must(terminfo_state in {"0", "1"}, f"unexpected terminfo probe exit status: {terminfo_state!r}")
|
||||
if terminfo_state == "0":
|
||||
_must(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue