Add regression tests for SSH remote CLI follow-ups

This commit is contained in:
Lawrence Chen 2026-03-17 01:53:23 -07:00
parent 1b4fd602d8
commit 96bd2463b8
No known key found for this signature in database
2 changed files with 190 additions and 0 deletions

View file

@ -867,6 +867,40 @@ final class RemoteLoopbackHTTPRequestRewriterTests: XCTestCase {
XCTAssertEqual(rewritten, original)
}
func testBuffersSplitLoopbackAliasHeadersUntilFullRequestArrives() {
var streamRewriter = RemoteLoopbackHTTPRequestStreamRewriter(
aliasHost: "cmux-loopback.localtest.me"
)
let firstChunk = Data(
(
"GET /demo HTTP/1.1\r\n" +
"Host: cmux-loop"
).utf8
)
let secondChunk = Data(
(
"back.localtest.me:3000\r\n" +
"Origin: http://cmux-loopback.localtest.me:3000\r\n" +
"Referer: http://cmux-loopback.localtest.me:3000/app\r\n" +
"\r\n" +
"body=1"
).utf8
)
let firstOutput = streamRewriter.rewriteNextChunk(firstChunk, eof: false)
let secondOutput = streamRewriter.rewriteNextChunk(secondChunk, eof: false)
XCTAssertTrue(firstOutput.isEmpty)
let text = String(decoding: secondOutput, as: UTF8.self)
XCTAssertTrue(text.contains("Host: localhost:3000"))
XCTAssertTrue(text.contains("Origin: http://localhost:3000"))
XCTAssertTrue(text.contains("Referer: http://localhost:3000/app"))
XCTAssertTrue(text.hasSuffix("\r\n\r\nbody=1"))
XCTAssertFalse(text.contains("cmux-loopback.localtest.me"))
}
func testRewritesLoopbackResponseHeadersBackToAlias() {
let original = Data(
(

View file

@ -8,6 +8,9 @@ import (
"io"
"math"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
@ -44,6 +47,35 @@ func (b *notifyingBuffer) String() string {
return b.buffer.String()
}
type eofWithPayloadConn struct {
payload []byte
readOnce bool
}
func (c *eofWithPayloadConn) Read(p []byte) (int, error) {
if c.readOnce {
return 0, io.EOF
}
c.readOnce = true
n := copy(p, c.payload)
return n, io.EOF
}
func (c *eofWithPayloadConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *eofWithPayloadConn) Close() error { return nil }
func (c *eofWithPayloadConn) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
func (c *eofWithPayloadConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
func (c *eofWithPayloadConn) SetDeadline(time.Time) error { return nil }
func (c *eofWithPayloadConn) SetReadDeadline(time.Time) error { return nil }
func (c *eofWithPayloadConn) SetWriteDeadline(time.Time) error { return nil }
func TestRunVersion(t *testing.T) {
var out bytes.Buffer
code := run([]string{"version"}, strings.NewReader(""), &out, &bytes.Buffer{})
@ -55,6 +87,46 @@ func TestRunVersion(t *testing.T) {
}
}
func TestWrapperBinaryDispatchesIntoCLI(t *testing.T) {
if os.Getenv("CMUXD_REMOTE_MAIN_HELPER") == "1" {
separator := 0
for i, arg := range os.Args {
if arg == "--" {
separator = i
break
}
}
if separator == 0 {
t.Fatal("helper process missing -- separator")
}
os.Args = append([]string{os.Args[0]}, os.Args[separator+1:]...)
main()
return
}
sockPath := startMockSocket(t, "PONG")
wrapperPath := filepath.Join(t.TempDir(), "cmuxd-remote-current")
if err := os.Symlink(os.Args[0], wrapperPath); err != nil {
t.Fatalf("symlink wrapper path: %v", err)
}
cmd := exec.Command(
wrapperPath,
"-test.run=TestWrapperBinaryDispatchesIntoCLI",
"--",
"--socket", sockPath, "ping",
)
cmd.Env = append(os.Environ(), "CMUXD_REMOTE_MAIN_HELPER=1")
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("wrapper invocation failed: %v\n%s", err, output)
}
if got := strings.TrimSpace(string(output)); got != "PONG" {
t.Fatalf("wrapper invocation output = %q, want %q", got, "PONG")
}
}
func TestRunStdioHelloAndPing(t *testing.T) {
input := strings.NewReader(
`{"id":1,"method":"hello","params":{}}` + "\n" +
@ -307,6 +379,90 @@ func TestProxyStreamRoundTrip(t *testing.T) {
}
}
func TestProxyStreamEOFPayloadIsNotDuplicatedAcrossDataAndEOFEvents(t *testing.T) {
eventOutput := newNotifyingBuffer()
server := &rpcServer{
nextStreamID: 1,
nextSessionID: 1,
streams: map[string]*streamState{
"stream-1": {
conn: &eofWithPayloadConn{payload: []byte("tail")},
},
},
sessions: map[string]*sessionState{},
frameWriter: &stdioFrameWriter{
writer: bufio.NewWriter(eventOutput),
},
}
defer server.closeAll()
resp := server.handleRequest(rpcRequest{
ID: 1,
Method: "proxy.stream.subscribe",
Params: map[string]any{"stream_id": "stream-1"},
})
if !resp.OK {
t.Fatalf("proxy.stream.subscribe failed: %+v", resp)
}
deadline := time.Now().Add(2 * time.Second)
for strings.Count(strings.TrimSpace(eventOutput.String()), "\n")+boolToInt(strings.TrimSpace(eventOutput.String()) != "") < 2 {
remaining := time.Until(deadline)
if remaining <= 0 {
t.Fatalf("timed out waiting for proxy stream events: %q", eventOutput.String())
}
select {
case <-eventOutput.notify:
case <-time.After(remaining):
t.Fatalf("timed out waiting for proxy stream events: %q", eventOutput.String())
}
}
lines := strings.Split(strings.TrimSpace(eventOutput.String()), "\n")
if len(lines) != 2 {
t.Fatalf("expected exactly 2 stream events, got %d: %q", len(lines), eventOutput.String())
}
var first map[string]any
if err := json.Unmarshal([]byte(lines[0]), &first); err != nil {
t.Fatalf("decode first event: %v", err)
}
var second map[string]any
if err := json.Unmarshal([]byte(lines[1]), &second); err != nil {
t.Fatalf("decode second event: %v", err)
}
if got := first["event"]; got != "proxy.stream.data" {
t.Fatalf("first event = %v, want proxy.stream.data", got)
}
if got := second["event"]; got != "proxy.stream.eof" {
t.Fatalf("second event = %v, want proxy.stream.eof", got)
}
firstPayload, err := base64.StdEncoding.DecodeString(first["data_base64"].(string))
if err != nil {
t.Fatalf("decode first payload: %v", err)
}
secondPayload, err := base64.StdEncoding.DecodeString(second["data_base64"].(string))
if err != nil {
t.Fatalf("decode second payload: %v", err)
}
if string(firstPayload) != "tail" {
t.Fatalf("proxy.stream.data payload = %q, want %q", string(firstPayload), "tail")
}
if len(secondPayload) != 0 {
t.Fatalf("proxy.stream.eof payload = %q, want empty payload after data event", string(secondPayload))
}
}
func boolToInt(value bool) int {
if value {
return 1
}
return 0
}
func TestGetIntParamRejectsFractionalFloat64(t *testing.T) {
params := map[string]any{
"port": 80.9,