From 96bd2463b8da4fb8b9e75d57ba90edf50bb0aded Mon Sep 17 00:00:00 2001 From: Lawrence Chen Date: Tue, 17 Mar 2026 01:53:23 -0700 Subject: [PATCH] Add regression tests for SSH remote CLI follow-ups --- cmuxTests/GhosttyConfigTests.swift | 34 +++++ daemon/remote/cmd/cmuxd-remote/main_test.go | 156 ++++++++++++++++++++ 2 files changed, 190 insertions(+) diff --git a/cmuxTests/GhosttyConfigTests.swift b/cmuxTests/GhosttyConfigTests.swift index 367d8d73..9cfd242c 100644 --- a/cmuxTests/GhosttyConfigTests.swift +++ b/cmuxTests/GhosttyConfigTests.swift @@ -867,6 +867,40 @@ final class RemoteLoopbackHTTPRequestRewriterTests: XCTestCase { XCTAssertEqual(rewritten, original) } + func testBuffersSplitLoopbackAliasHeadersUntilFullRequestArrives() { + var streamRewriter = RemoteLoopbackHTTPRequestStreamRewriter( + aliasHost: "cmux-loopback.localtest.me" + ) + + let firstChunk = Data( + ( + "GET /demo HTTP/1.1\r\n" + + "Host: cmux-loop" + ).utf8 + ) + let secondChunk = Data( + ( + "back.localtest.me:3000\r\n" + + "Origin: http://cmux-loopback.localtest.me:3000\r\n" + + "Referer: http://cmux-loopback.localtest.me:3000/app\r\n" + + "\r\n" + + "body=1" + ).utf8 + ) + + let firstOutput = streamRewriter.rewriteNextChunk(firstChunk, eof: false) + let secondOutput = streamRewriter.rewriteNextChunk(secondChunk, eof: false) + + XCTAssertTrue(firstOutput.isEmpty) + + let text = String(decoding: secondOutput, as: UTF8.self) + XCTAssertTrue(text.contains("Host: localhost:3000")) + XCTAssertTrue(text.contains("Origin: http://localhost:3000")) + XCTAssertTrue(text.contains("Referer: http://localhost:3000/app")) + XCTAssertTrue(text.hasSuffix("\r\n\r\nbody=1")) + XCTAssertFalse(text.contains("cmux-loopback.localtest.me")) + } + func testRewritesLoopbackResponseHeadersBackToAlias() { let original = Data( ( diff --git a/daemon/remote/cmd/cmuxd-remote/main_test.go b/daemon/remote/cmd/cmuxd-remote/main_test.go index 3216373d..15301033 100644 --- a/daemon/remote/cmd/cmuxd-remote/main_test.go +++ b/daemon/remote/cmd/cmuxd-remote/main_test.go @@ -8,6 +8,9 @@ import ( "io" "math" "net" + "os" + "os/exec" + "path/filepath" "strconv" "strings" "sync" @@ -44,6 +47,35 @@ func (b *notifyingBuffer) String() string { return b.buffer.String() } +type eofWithPayloadConn struct { + payload []byte + readOnce bool +} + +func (c *eofWithPayloadConn) Read(p []byte) (int, error) { + if c.readOnce { + return 0, io.EOF + } + c.readOnce = true + n := copy(p, c.payload) + return n, io.EOF +} + +func (c *eofWithPayloadConn) Write(p []byte) (int, error) { + return len(p), nil +} + +func (c *eofWithPayloadConn) Close() error { return nil } +func (c *eofWithPayloadConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} +func (c *eofWithPayloadConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} +func (c *eofWithPayloadConn) SetDeadline(time.Time) error { return nil } +func (c *eofWithPayloadConn) SetReadDeadline(time.Time) error { return nil } +func (c *eofWithPayloadConn) SetWriteDeadline(time.Time) error { return nil } + func TestRunVersion(t *testing.T) { var out bytes.Buffer code := run([]string{"version"}, strings.NewReader(""), &out, &bytes.Buffer{}) @@ -55,6 +87,46 @@ func TestRunVersion(t *testing.T) { } } +func TestWrapperBinaryDispatchesIntoCLI(t *testing.T) { + if os.Getenv("CMUXD_REMOTE_MAIN_HELPER") == "1" { + separator := 0 + for i, arg := range os.Args { + if arg == "--" { + separator = i + break + } + } + if separator == 0 { + t.Fatal("helper process missing -- separator") + } + os.Args = append([]string{os.Args[0]}, os.Args[separator+1:]...) + main() + return + } + + sockPath := startMockSocket(t, "PONG") + wrapperPath := filepath.Join(t.TempDir(), "cmuxd-remote-current") + if err := os.Symlink(os.Args[0], wrapperPath); err != nil { + t.Fatalf("symlink wrapper path: %v", err) + } + + cmd := exec.Command( + wrapperPath, + "-test.run=TestWrapperBinaryDispatchesIntoCLI", + "--", + "--socket", sockPath, "ping", + ) + cmd.Env = append(os.Environ(), "CMUXD_REMOTE_MAIN_HELPER=1") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("wrapper invocation failed: %v\n%s", err, output) + } + + if got := strings.TrimSpace(string(output)); got != "PONG" { + t.Fatalf("wrapper invocation output = %q, want %q", got, "PONG") + } +} + func TestRunStdioHelloAndPing(t *testing.T) { input := strings.NewReader( `{"id":1,"method":"hello","params":{}}` + "\n" + @@ -307,6 +379,90 @@ func TestProxyStreamRoundTrip(t *testing.T) { } } +func TestProxyStreamEOFPayloadIsNotDuplicatedAcrossDataAndEOFEvents(t *testing.T) { + eventOutput := newNotifyingBuffer() + server := &rpcServer{ + nextStreamID: 1, + nextSessionID: 1, + streams: map[string]*streamState{ + "stream-1": { + conn: &eofWithPayloadConn{payload: []byte("tail")}, + }, + }, + sessions: map[string]*sessionState{}, + frameWriter: &stdioFrameWriter{ + writer: bufio.NewWriter(eventOutput), + }, + } + defer server.closeAll() + + resp := server.handleRequest(rpcRequest{ + ID: 1, + Method: "proxy.stream.subscribe", + Params: map[string]any{"stream_id": "stream-1"}, + }) + if !resp.OK { + t.Fatalf("proxy.stream.subscribe failed: %+v", resp) + } + + deadline := time.Now().Add(2 * time.Second) + for strings.Count(strings.TrimSpace(eventOutput.String()), "\n")+boolToInt(strings.TrimSpace(eventOutput.String()) != "") < 2 { + remaining := time.Until(deadline) + if remaining <= 0 { + t.Fatalf("timed out waiting for proxy stream events: %q", eventOutput.String()) + } + select { + case <-eventOutput.notify: + case <-time.After(remaining): + t.Fatalf("timed out waiting for proxy stream events: %q", eventOutput.String()) + } + } + + lines := strings.Split(strings.TrimSpace(eventOutput.String()), "\n") + if len(lines) != 2 { + t.Fatalf("expected exactly 2 stream events, got %d: %q", len(lines), eventOutput.String()) + } + + var first map[string]any + if err := json.Unmarshal([]byte(lines[0]), &first); err != nil { + t.Fatalf("decode first event: %v", err) + } + var second map[string]any + if err := json.Unmarshal([]byte(lines[1]), &second); err != nil { + t.Fatalf("decode second event: %v", err) + } + + if got := first["event"]; got != "proxy.stream.data" { + t.Fatalf("first event = %v, want proxy.stream.data", got) + } + if got := second["event"]; got != "proxy.stream.eof" { + t.Fatalf("second event = %v, want proxy.stream.eof", got) + } + + firstPayload, err := base64.StdEncoding.DecodeString(first["data_base64"].(string)) + if err != nil { + t.Fatalf("decode first payload: %v", err) + } + secondPayload, err := base64.StdEncoding.DecodeString(second["data_base64"].(string)) + if err != nil { + t.Fatalf("decode second payload: %v", err) + } + + if string(firstPayload) != "tail" { + t.Fatalf("proxy.stream.data payload = %q, want %q", string(firstPayload), "tail") + } + if len(secondPayload) != 0 { + t.Fatalf("proxy.stream.eof payload = %q, want empty payload after data event", string(secondPayload)) + } +} + +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + func TestGetIntParamRejectsFractionalFloat64(t *testing.T) { params := map[string]any{ "port": 80.9,