Address PR review feedback for SSH remote workspace flow
This commit is contained in:
parent
c179ee74ea
commit
47f4b5e55a
12 changed files with 277 additions and 52 deletions
|
|
@ -2,8 +2,10 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -57,6 +59,8 @@ type sessionState struct {
|
|||
lastKnownRows int
|
||||
}
|
||||
|
||||
const maxRPCFrameBytes = 4 * 1024 * 1024
|
||||
|
||||
func main() {
|
||||
os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr))
|
||||
}
|
||||
|
|
@ -108,13 +112,32 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error {
|
|||
}
|
||||
defer server.closeAll()
|
||||
|
||||
scanner := bufio.NewScanner(stdin)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
reader := bufio.NewReaderSize(stdin, 64*1024)
|
||||
writer := bufio.NewWriter(stdout)
|
||||
defer writer.Flush()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
for {
|
||||
line, oversized, readErr := readRPCFrame(reader, maxRPCFrameBytes)
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return readErr
|
||||
}
|
||||
if oversized {
|
||||
if err := writeResponse(writer, rpcResponse{
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "invalid_request",
|
||||
Message: "request frame exceeds maximum size",
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
line = bytes.TrimSuffix(line, []byte{'\n'})
|
||||
line = bytes.TrimSuffix(line, []byte{'\r'})
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
|
@ -138,11 +161,51 @@ func runStdioServer(stdin io.Reader, stdout io.Writer) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
func readRPCFrame(reader *bufio.Reader, maxBytes int) ([]byte, bool, error) {
|
||||
frame := make([]byte, 0, 1024)
|
||||
for {
|
||||
chunk, err := reader.ReadSlice('\n')
|
||||
if len(chunk) > 0 {
|
||||
if len(frame)+len(chunk) > maxBytes {
|
||||
if errors.Is(err, bufio.ErrBufferFull) {
|
||||
if drainErr := discardUntilNewline(reader); drainErr != nil && !errors.Is(drainErr, io.EOF) {
|
||||
return nil, false, drainErr
|
||||
}
|
||||
}
|
||||
return nil, true, nil
|
||||
}
|
||||
frame = append(frame, chunk...)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return frame, false, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrBufferFull) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
if len(frame) == 0 {
|
||||
return nil, false, io.EOF
|
||||
}
|
||||
return frame, false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
func discardUntilNewline(reader *bufio.Reader) error {
|
||||
for {
|
||||
_, err := reader.ReadSlice('\n')
|
||||
if err == nil || errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, bufio.ErrBufferFull) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeResponse(w *bufio.Writer, resp rpcResponse) error {
|
||||
|
|
@ -376,9 +439,37 @@ func (s *rpcServer) handleProxyWrite(req rpcRequest) rpcResponse {
|
|||
}
|
||||
}
|
||||
|
||||
timeoutMs := 8000
|
||||
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout {
|
||||
timeoutMs = parsed
|
||||
}
|
||||
if timeoutMs > 0 {
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond)); err != nil {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "stream_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
total := 0
|
||||
for total < len(payload) {
|
||||
written, writeErr := conn.Write(payload[total:])
|
||||
if written == 0 && writeErr == nil {
|
||||
return rpcResponse{
|
||||
ID: req.ID,
|
||||
OK: false,
|
||||
Error: &rpcError{
|
||||
Code: "stream_error",
|
||||
Message: "write made no progress",
|
||||
},
|
||||
}
|
||||
}
|
||||
total += written
|
||||
if writeErr != nil {
|
||||
return rpcResponse{
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -156,12 +157,11 @@ func TestProxyStreamRoundTrip(t *testing.T) {
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
buffer := make([]byte, 8)
|
||||
n, readErr := conn.Read(buffer)
|
||||
if readErr != nil {
|
||||
buffer := make([]byte, 4)
|
||||
if _, readErr := io.ReadFull(conn, buffer); readErr != nil {
|
||||
return
|
||||
}
|
||||
if string(buffer[:n]) != "ping" {
|
||||
if string(buffer) != "ping" {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("pong"))
|
||||
|
|
@ -246,6 +246,41 @@ func TestProxyStreamRoundTrip(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRunStdioOversizedFrameContinuesServing(t *testing.T) {
|
||||
oversized := `{"id":1,"method":"ping","params":{"blob":"` + strings.Repeat("a", maxRPCFrameBytes) + `"}}`
|
||||
input := strings.NewReader(oversized + "\n" + `{"id":2,"method":"ping","params":{}}` + "\n")
|
||||
var out bytes.Buffer
|
||||
code := run([]string{"serve", "--stdio"}, input, &out, &bytes.Buffer{})
|
||||
if code != 0 {
|
||||
t.Fatalf("run serve exit code = %d, want 0", code)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("got %d response lines, want 2: %q", len(lines), out.String())
|
||||
}
|
||||
|
||||
var first map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[0]), &first); err != nil {
|
||||
t.Fatalf("failed to decode first response: %v", err)
|
||||
}
|
||||
if ok, _ := first["ok"].(bool); ok {
|
||||
t.Fatalf("first response should be oversized-frame error: %v", first)
|
||||
}
|
||||
firstError, _ := first["error"].(map[string]any)
|
||||
if got := firstError["code"]; got != "invalid_request" {
|
||||
t.Fatalf("oversized frame should return invalid_request; got=%v payload=%v", got, first)
|
||||
}
|
||||
|
||||
var second map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[1]), &second); err != nil {
|
||||
t.Fatalf("failed to decode second response: %v", err)
|
||||
}
|
||||
if ok, _ := second["ok"].(bool); !ok {
|
||||
t.Fatalf("second response should still be handled after oversized frame: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyOpenInvalidParams(t *testing.T) {
|
||||
server := &rpcServer{
|
||||
nextStreamID: 1,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue