Address PR review feedback for SSH remote workspace flow

This commit is contained in:
Lawrence Chen 2026-02-28 17:36:07 -08:00
parent c179ee74ea
commit 47f4b5e55a
12 changed files with 277 additions and 52 deletions

View file

@ -2,8 +2,10 @@ package main
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
@ -57,6 +59,8 @@ type sessionState struct {
lastKnownRows int
}
const maxRPCFrameBytes = 4 * 1024 * 1024
func main() {
os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr))
}
@ -108,13 +112,32 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error {
}
defer server.closeAll()
scanner := bufio.NewScanner(stdin)
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
reader := bufio.NewReaderSize(stdin, 64*1024)
writer := bufio.NewWriter(stdout)
defer writer.Flush()
for scanner.Scan() {
line := scanner.Bytes()
for {
line, oversized, readErr := readRPCFrame(reader, maxRPCFrameBytes)
if readErr != nil {
if errors.Is(readErr, io.EOF) {
return nil
}
return readErr
}
if oversized {
if err := writeResponse(writer, rpcResponse{
OK: false,
Error: &rpcError{
Code: "invalid_request",
Message: "request frame exceeds maximum size",
},
}); err != nil {
return err
}
continue
}
line = bytes.TrimSuffix(line, []byte{'\n'})
line = bytes.TrimSuffix(line, []byte{'\r'})
if len(line) == 0 {
continue
}
@ -138,11 +161,51 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error {
return err
}
}
}
if err := scanner.Err(); err != nil {
func readRPCFrame(reader *bufio.Reader, maxBytes int) ([]byte, bool, error) {
frame := make([]byte, 0, 1024)
for {
chunk, err := reader.ReadSlice('\n')
if len(chunk) > 0 {
if len(frame)+len(chunk) > maxBytes {
if errors.Is(err, bufio.ErrBufferFull) {
if drainErr := discardUntilNewline(reader); drainErr != nil && !errors.Is(drainErr, io.EOF) {
return nil, false, drainErr
}
}
return nil, true, nil
}
frame = append(frame, chunk...)
}
if err == nil {
return frame, false, nil
}
if errors.Is(err, bufio.ErrBufferFull) {
continue
}
if errors.Is(err, io.EOF) {
if len(frame) == 0 {
return nil, false, io.EOF
}
return frame, false, nil
}
return nil, false, err
}
}
func discardUntilNewline(reader *bufio.Reader) error {
for {
_, err := reader.ReadSlice('\n')
if err == nil || errors.Is(err, io.EOF) {
return err
}
if errors.Is(err, bufio.ErrBufferFull) {
continue
}
return err
}
return nil
}
func writeResponse(w *bufio.Writer, resp rpcResponse) error {
@ -376,9 +439,37 @@ func (s *rpcServer) handleProxyWrite(req rpcRequest) rpcResponse {
}
}
timeoutMs := 8000
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout {
timeoutMs = parsed
}
if timeoutMs > 0 {
if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond)); err != nil {
return rpcResponse{
ID: req.ID,
OK: false,
Error: &rpcError{
Code: "stream_error",
Message: err.Error(),
},
}
}
defer conn.SetWriteDeadline(time.Time{})
}
total := 0
for total < len(payload) {
written, writeErr := conn.Write(payload[total:])
if written == 0 && writeErr == nil {
return rpcResponse{
ID: req.ID,
OK: false,
Error: &rpcError{
Code: "stream_error",
Message: "write made no progress",
},
}
}
total += written
if writeErr != nil {
return rpcResponse{

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net"
"strconv"
"strings"
@ -156,12 +157,11 @@ func TestProxyStreamRoundTrip(t *testing.T) {
}
defer conn.Close()
buffer := make([]byte, 8)
n, readErr := conn.Read(buffer)
if readErr != nil {
buffer := make([]byte, 4)
if _, readErr := io.ReadFull(conn, buffer); readErr != nil {
return
}
if string(buffer[:n]) != "ping" {
if string(buffer) != "ping" {
return
}
_, _ = conn.Write([]byte("pong"))
@ -246,6 +246,41 @@ func TestProxyStreamRoundTrip(t *testing.T) {
}
}
func TestRunStdioOversizedFrameContinuesServing(t *testing.T) {
oversized := `{"id":1,"method":"ping","params":{"blob":"` + strings.Repeat("a", maxRPCFrameBytes) + `"}}`
input := strings.NewReader(oversized + "\n" + `{"id":2,"method":"ping","params":{}}` + "\n")
var out bytes.Buffer
code := run([]string{"serve", "--stdio"}, input, &out, &bytes.Buffer{})
if code != 0 {
t.Fatalf("run serve exit code = %d, want 0", code)
}
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
if len(lines) != 2 {
t.Fatalf("got %d response lines, want 2: %q", len(lines), out.String())
}
var first map[string]any
if err := json.Unmarshal([]byte(lines[0]), &first); err != nil {
t.Fatalf("failed to decode first response: %v", err)
}
if ok, _ := first["ok"].(bool); ok {
t.Fatalf("first response should be oversized-frame error: %v", first)
}
firstError, _ := first["error"].(map[string]any)
if got := firstError["code"]; got != "invalid_request" {
t.Fatalf("oversized frame should return invalid_request; got=%v payload=%v", got, first)
}
var second map[string]any
if err := json.Unmarshal([]byte(lines[1]), &second); err != nil {
t.Fatalf("failed to decode second response: %v", err)
}
if ok, _ := second["ok"].(bool); !ok {
t.Fatalf("second response should still be handled after oversized frame: %v", second)
}
}
func TestProxyOpenInvalidParams(t *testing.T) {
server := &rpcServer{
nextStreamID: 1,