1021 lines
21 KiB
Go
1021 lines
21 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// version is the build identifier printed by the "version" subcommand and
// reported in the result of the "hello" RPC. It defaults to "dev".
// NOTE(review): presumably overridden at build time (e.g. -ldflags "-X main.version=...") — confirm.
var version = "dev"
|
|
|
|
// rpcRequest is one JSON request frame read from the input stream.
//
// ID is echoed back unchanged in the response produced by the handlers,
// Method selects the handler (for example "ping" or "proxy.open"), and
// Params holds the method-specific arguments (nil when the frame omits them).
type rpcRequest struct {
	ID     any            `json:"id"`
	Method string         `json:"method"`
	Params map[string]any `json:"params"`
}
|
|
|
|
// rpcError describes why a request failed: Code is a short machine-readable
// identifier (such as "invalid_params" or "not_found") and Message is the
// human-readable explanation.
type rpcError struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
|
|
|
|
type rpcResponse struct {
|
|
ID any `json:"id,omitempty"`
|
|
OK bool `json:"ok"`
|
|
Result any `json:"result,omitempty"`
|
|
Error *rpcError `json:"error,omitempty"`
|
|
}
|
|
|
|
type rpcServer struct {
|
|
mu sync.Mutex
|
|
nextStreamID uint64
|
|
nextSessionID uint64
|
|
streams map[string]net.Conn
|
|
sessions map[string]*sessionState
|
|
}
|
|
|
|
// sessionAttachment records the terminal size reported by one attachment of
// a session and the time that size was last written.
type sessionAttachment struct {
	Cols      int
	Rows      int
	UpdatedAt time.Time
}
|
|
|
|
type sessionState struct {
|
|
attachments map[string]sessionAttachment
|
|
effectiveCols int
|
|
effectiveRows int
|
|
lastKnownCols int
|
|
lastKnownRows int
|
|
}
|
|
|
|
// maxRPCFrameBytes is the largest request frame, in bytes (4 MiB), that
// runStdioServer accepts; it is passed to readRPCFrame, and longer frames are
// answered with an "invalid_request" error instead of being parsed.
const maxRPCFrameBytes = 4 * 1024 * 1024
|
|
|
|
func main() {
|
|
os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr))
|
|
}
|
|
|
|
func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int {
|
|
if len(args) == 0 {
|
|
usage(stderr)
|
|
return 2
|
|
}
|
|
|
|
switch args[0] {
|
|
case "version":
|
|
_, _ = fmt.Fprintln(stdout, version)
|
|
return 0
|
|
case "serve":
|
|
fs := flag.NewFlagSet("serve", flag.ContinueOnError)
|
|
fs.SetOutput(stderr)
|
|
stdio := fs.Bool("stdio", false, "serve over stdin/stdout")
|
|
if err := fs.Parse(args[1:]); err != nil {
|
|
return 2
|
|
}
|
|
if !*stdio {
|
|
_, _ = fmt.Fprintln(stderr, "serve requires --stdio")
|
|
return 2
|
|
}
|
|
if err := runStdioServer(stdin, stdout); err != nil {
|
|
_, _ = fmt.Fprintf(stderr, "serve failed: %v\n", err)
|
|
return 1
|
|
}
|
|
return 0
|
|
default:
|
|
usage(stderr)
|
|
return 2
|
|
}
|
|
}
|
|
|
|
// usage writes the command-line synopsis to w, one line per write.
// Write errors are deliberately ignored: there is nowhere left to report them.
func usage(w io.Writer) {
	for _, line := range []string{
		"Usage:",
		" cmuxd-remote version",
		" cmuxd-remote serve --stdio",
	} {
		_, _ = fmt.Fprintln(w, line)
	}
}
|
|
|
|
func runStdioServer(stdin io.Reader, stdout io.Writer) error {
|
|
server := &rpcServer{
|
|
nextStreamID: 1,
|
|
nextSessionID: 1,
|
|
streams: map[string]net.Conn{},
|
|
sessions: map[string]*sessionState{},
|
|
}
|
|
defer server.closeAll()
|
|
|
|
reader := bufio.NewReaderSize(stdin, 64*1024)
|
|
writer := bufio.NewWriter(stdout)
|
|
defer writer.Flush()
|
|
|
|
for {
|
|
line, oversized, readErr := readRPCFrame(reader, maxRPCFrameBytes)
|
|
if readErr != nil {
|
|
if errors.Is(readErr, io.EOF) {
|
|
return nil
|
|
}
|
|
return readErr
|
|
}
|
|
if oversized {
|
|
if err := writeResponse(writer, rpcResponse{
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_request",
|
|
Message: "request frame exceeds maximum size",
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
line = bytes.TrimSuffix(line, []byte{'\n'})
|
|
line = bytes.TrimSuffix(line, []byte{'\r'})
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
var req rpcRequest
|
|
if err := json.Unmarshal(line, &req); err != nil {
|
|
if err := writeResponse(writer, rpcResponse{
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_request",
|
|
Message: "invalid JSON request",
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
resp := server.handleRequest(req)
|
|
if err := writeResponse(writer, resp); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
func readRPCFrame(reader *bufio.Reader, maxBytes int) ([]byte, bool, error) {
|
|
frame := make([]byte, 0, 1024)
|
|
for {
|
|
chunk, err := reader.ReadSlice('\n')
|
|
if len(chunk) > 0 {
|
|
if len(frame)+len(chunk) > maxBytes {
|
|
if errors.Is(err, bufio.ErrBufferFull) {
|
|
if drainErr := discardUntilNewline(reader); drainErr != nil && !errors.Is(drainErr, io.EOF) {
|
|
return nil, false, drainErr
|
|
}
|
|
}
|
|
return nil, true, nil
|
|
}
|
|
frame = append(frame, chunk...)
|
|
}
|
|
|
|
if err == nil {
|
|
return frame, false, nil
|
|
}
|
|
if errors.Is(err, bufio.ErrBufferFull) {
|
|
continue
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
if len(frame) == 0 {
|
|
return nil, false, io.EOF
|
|
}
|
|
return frame, false, nil
|
|
}
|
|
return nil, false, err
|
|
}
|
|
}
|
|
|
|
// discardUntilNewline consumes input from reader up to and including the next
// newline. It returns nil once a newline has been consumed, io.EOF if the
// input ends first, or the read error that stopped it.
func discardUntilNewline(reader *bufio.Reader) error {
	for {
		_, err := reader.ReadSlice('\n')
		switch {
		case errors.Is(err, bufio.ErrBufferFull):
			// The line is longer than the buffer; keep draining.
			continue
		case err == nil, errors.Is(err, io.EOF):
			return err
		default:
			return err
		}
	}
}
|
|
|
|
func writeResponse(w *bufio.Writer, resp rpcResponse) error {
|
|
payload, err := json.Marshal(resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := w.Write(payload); err != nil {
|
|
return err
|
|
}
|
|
if err := w.WriteByte('\n'); err != nil {
|
|
return err
|
|
}
|
|
return w.Flush()
|
|
}
|
|
|
|
func (s *rpcServer) handleRequest(req rpcRequest) rpcResponse {
|
|
if req.Method == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_request",
|
|
Message: "method is required",
|
|
},
|
|
}
|
|
}
|
|
|
|
switch req.Method {
|
|
case "hello":
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"name": "cmuxd-remote",
|
|
"version": version,
|
|
"capabilities": []string{
|
|
"session.basic",
|
|
"session.resize.min",
|
|
"proxy.http_connect",
|
|
"proxy.socks5",
|
|
"proxy.stream",
|
|
},
|
|
},
|
|
}
|
|
case "ping":
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"pong": true,
|
|
},
|
|
}
|
|
case "proxy.open":
|
|
return s.handleProxyOpen(req)
|
|
case "proxy.close":
|
|
return s.handleProxyClose(req)
|
|
case "proxy.write":
|
|
return s.handleProxyWrite(req)
|
|
case "proxy.read":
|
|
return s.handleProxyRead(req)
|
|
case "session.open":
|
|
return s.handleSessionOpen(req)
|
|
case "session.close":
|
|
return s.handleSessionClose(req)
|
|
case "session.attach":
|
|
return s.handleSessionAttach(req)
|
|
case "session.resize":
|
|
return s.handleSessionResize(req)
|
|
case "session.detach":
|
|
return s.handleSessionDetach(req)
|
|
case "session.status":
|
|
return s.handleSessionStatus(req)
|
|
default:
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "method_not_found",
|
|
Message: fmt.Sprintf("unknown method %q", req.Method),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleProxyOpen(req rpcRequest) rpcResponse {
|
|
host, ok := getStringParam(req.Params, "host")
|
|
if !ok || host == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.open requires host",
|
|
},
|
|
}
|
|
}
|
|
port, ok := getIntParam(req.Params, "port")
|
|
if !ok || port <= 0 || port > 65535 {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.open requires port in range 1-65535",
|
|
},
|
|
}
|
|
}
|
|
|
|
timeoutMs := 10000
|
|
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 {
|
|
timeoutMs = parsed
|
|
}
|
|
|
|
conn, err := net.DialTimeout(
|
|
"tcp",
|
|
net.JoinHostPort(host, strconv.Itoa(port)),
|
|
time.Duration(timeoutMs)*time.Millisecond,
|
|
)
|
|
if err != nil {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "open_failed",
|
|
Message: err.Error(),
|
|
},
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
streamID := fmt.Sprintf("s-%d", s.nextStreamID)
|
|
s.nextStreamID++
|
|
s.streams[streamID] = conn
|
|
s.mu.Unlock()
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"stream_id": streamID,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleProxyClose(req rpcRequest) rpcResponse {
|
|
streamID, ok := getStringParam(req.Params, "stream_id")
|
|
if !ok || streamID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.close requires stream_id",
|
|
},
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
conn, exists := s.streams[streamID]
|
|
if exists {
|
|
delete(s.streams, streamID)
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "stream not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
_ = conn.Close()
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"closed": true,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleProxyWrite(req rpcRequest) rpcResponse {
|
|
streamID, ok := getStringParam(req.Params, "stream_id")
|
|
if !ok || streamID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.write requires stream_id",
|
|
},
|
|
}
|
|
}
|
|
dataBase64, ok := getStringParam(req.Params, "data_base64")
|
|
if !ok {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.write requires data_base64",
|
|
},
|
|
}
|
|
}
|
|
payload, err := base64.StdEncoding.DecodeString(dataBase64)
|
|
if err != nil {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "data_base64 must be valid base64",
|
|
},
|
|
}
|
|
}
|
|
|
|
conn, found := s.getStream(streamID)
|
|
if !found {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "stream not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
timeoutMs := 8000
|
|
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout {
|
|
timeoutMs = parsed
|
|
}
|
|
if timeoutMs > 0 {
|
|
if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond)); err != nil {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "stream_error",
|
|
Message: err.Error(),
|
|
},
|
|
}
|
|
}
|
|
defer conn.SetWriteDeadline(time.Time{})
|
|
}
|
|
|
|
total := 0
|
|
for total < len(payload) {
|
|
written, writeErr := conn.Write(payload[total:])
|
|
if written == 0 && writeErr == nil {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "stream_error",
|
|
Message: "write made no progress",
|
|
},
|
|
}
|
|
}
|
|
total += written
|
|
if writeErr != nil {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "stream_error",
|
|
Message: writeErr.Error(),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"written": total,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleProxyRead(req rpcRequest) rpcResponse {
|
|
streamID, ok := getStringParam(req.Params, "stream_id")
|
|
if !ok || streamID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "proxy.read requires stream_id",
|
|
},
|
|
}
|
|
}
|
|
|
|
maxBytes := 32768
|
|
if parsed, hasMax := getIntParam(req.Params, "max_bytes"); hasMax {
|
|
maxBytes = parsed
|
|
}
|
|
if maxBytes <= 0 || maxBytes > 262144 {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "max_bytes must be in range 1-262144",
|
|
},
|
|
}
|
|
}
|
|
|
|
timeoutMs := 50
|
|
if parsed, hasTimeout := getIntParam(req.Params, "timeout_ms"); hasTimeout && parsed >= 0 {
|
|
timeoutMs = parsed
|
|
}
|
|
|
|
conn, found := s.getStream(streamID)
|
|
if !found {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "stream not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond))
|
|
buffer := make([]byte, maxBytes)
|
|
n, readErr := conn.Read(buffer)
|
|
data := buffer[:max(0, n)]
|
|
|
|
if readErr != nil {
|
|
if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"data_base64": "",
|
|
"eof": false,
|
|
},
|
|
}
|
|
}
|
|
if readErr == io.EOF {
|
|
s.dropStream(streamID)
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"data_base64": base64.StdEncoding.EncodeToString(data),
|
|
"eof": true,
|
|
},
|
|
}
|
|
}
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "stream_error",
|
|
Message: readErr.Error(),
|
|
},
|
|
}
|
|
}
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"data_base64": base64.StdEncoding.EncodeToString(data),
|
|
"eof": false,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionOpen(req rpcRequest) rpcResponse {
|
|
sessionID, _ := getStringParam(req.Params, "session_id")
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if sessionID == "" {
|
|
sessionID = fmt.Sprintf("sess-%d", s.nextSessionID)
|
|
s.nextSessionID++
|
|
}
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
session = &sessionState{
|
|
attachments: map[string]sessionAttachment{},
|
|
}
|
|
s.sessions[sessionID] = session
|
|
}
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: sessionSnapshot(sessionID, session),
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionClose(req rpcRequest) rpcResponse {
|
|
sessionID, ok := getStringParam(req.Params, "session_id")
|
|
if !ok || sessionID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "session.close requires session_id",
|
|
},
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
_, exists := s.sessions[sessionID]
|
|
if exists {
|
|
delete(s.sessions, sessionID)
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "session not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: map[string]any{
|
|
"session_id": sessionID,
|
|
"closed": true,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionAttach(req rpcRequest) rpcResponse {
|
|
sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.attach")
|
|
if badResp != nil {
|
|
return *badResp
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "session not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
session.attachments[attachmentID] = sessionAttachment{
|
|
Cols: cols,
|
|
Rows: rows,
|
|
UpdatedAt: time.Now().UTC(),
|
|
}
|
|
recomputeSessionSize(session)
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: sessionSnapshot(sessionID, session),
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionResize(req rpcRequest) rpcResponse {
|
|
sessionID, attachmentID, cols, rows, badResp := parseSessionAttachmentParams(req, "session.resize")
|
|
if badResp != nil {
|
|
return *badResp
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "session not found",
|
|
},
|
|
}
|
|
}
|
|
if _, exists := session.attachments[attachmentID]; !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "attachment not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
session.attachments[attachmentID] = sessionAttachment{
|
|
Cols: cols,
|
|
Rows: rows,
|
|
UpdatedAt: time.Now().UTC(),
|
|
}
|
|
recomputeSessionSize(session)
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: sessionSnapshot(sessionID, session),
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionDetach(req rpcRequest) rpcResponse {
|
|
sessionID, ok := getStringParam(req.Params, "session_id")
|
|
if !ok || sessionID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "session.detach requires session_id",
|
|
},
|
|
}
|
|
}
|
|
attachmentID, ok := getStringParam(req.Params, "attachment_id")
|
|
if !ok || attachmentID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "session.detach requires attachment_id",
|
|
},
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "session not found",
|
|
},
|
|
}
|
|
}
|
|
if _, exists := session.attachments[attachmentID]; !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "attachment not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
delete(session.attachments, attachmentID)
|
|
recomputeSessionSize(session)
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: sessionSnapshot(sessionID, session),
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) handleSessionStatus(req rpcRequest) rpcResponse {
|
|
sessionID, ok := getStringParam(req.Params, "session_id")
|
|
if !ok || sessionID == "" {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: "session.status requires session_id",
|
|
},
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "not_found",
|
|
Message: "session not found",
|
|
},
|
|
}
|
|
}
|
|
|
|
return rpcResponse{
|
|
ID: req.ID,
|
|
OK: true,
|
|
Result: sessionSnapshot(sessionID, session),
|
|
}
|
|
}
|
|
|
|
func parseSessionAttachmentParams(req rpcRequest, method string) (sessionID string, attachmentID string, cols int, rows int, badResp *rpcResponse) {
|
|
sessionID, ok := getStringParam(req.Params, "session_id")
|
|
if !ok || sessionID == "" {
|
|
resp := rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: method + " requires session_id",
|
|
},
|
|
}
|
|
return "", "", 0, 0, &resp
|
|
}
|
|
attachmentID, ok = getStringParam(req.Params, "attachment_id")
|
|
if !ok || attachmentID == "" {
|
|
resp := rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: method + " requires attachment_id",
|
|
},
|
|
}
|
|
return "", "", 0, 0, &resp
|
|
}
|
|
|
|
cols, ok = getIntParam(req.Params, "cols")
|
|
if !ok || cols <= 0 {
|
|
resp := rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: method + " requires cols > 0",
|
|
},
|
|
}
|
|
return "", "", 0, 0, &resp
|
|
}
|
|
rows, ok = getIntParam(req.Params, "rows")
|
|
if !ok || rows <= 0 {
|
|
resp := rpcResponse{
|
|
ID: req.ID,
|
|
OK: false,
|
|
Error: &rpcError{
|
|
Code: "invalid_params",
|
|
Message: method + " requires rows > 0",
|
|
},
|
|
}
|
|
return "", "", 0, 0, &resp
|
|
}
|
|
|
|
return sessionID, attachmentID, cols, rows, nil
|
|
}
|
|
|
|
func recomputeSessionSize(session *sessionState) {
|
|
if len(session.attachments) == 0 {
|
|
session.effectiveCols = session.lastKnownCols
|
|
session.effectiveRows = session.lastKnownRows
|
|
return
|
|
}
|
|
|
|
minCols := 0
|
|
minRows := 0
|
|
for _, attachment := range session.attachments {
|
|
if minCols == 0 || attachment.Cols < minCols {
|
|
minCols = attachment.Cols
|
|
}
|
|
if minRows == 0 || attachment.Rows < minRows {
|
|
minRows = attachment.Rows
|
|
}
|
|
}
|
|
|
|
session.effectiveCols = minCols
|
|
session.effectiveRows = minRows
|
|
session.lastKnownCols = minCols
|
|
session.lastKnownRows = minRows
|
|
}
|
|
|
|
func sessionSnapshot(sessionID string, session *sessionState) map[string]any {
|
|
attachmentIDs := make([]string, 0, len(session.attachments))
|
|
for attachmentID := range session.attachments {
|
|
attachmentIDs = append(attachmentIDs, attachmentID)
|
|
}
|
|
sort.Strings(attachmentIDs)
|
|
|
|
attachments := make([]map[string]any, 0, len(attachmentIDs))
|
|
for _, attachmentID := range attachmentIDs {
|
|
attachment := session.attachments[attachmentID]
|
|
attachments = append(attachments, map[string]any{
|
|
"attachment_id": attachmentID,
|
|
"cols": attachment.Cols,
|
|
"rows": attachment.Rows,
|
|
"updated_at": attachment.UpdatedAt.Format(time.RFC3339Nano),
|
|
})
|
|
}
|
|
|
|
return map[string]any{
|
|
"session_id": sessionID,
|
|
"attachments": attachments,
|
|
"effective_cols": session.effectiveCols,
|
|
"effective_rows": session.effectiveRows,
|
|
"last_known_cols": session.lastKnownCols,
|
|
"last_known_rows": session.lastKnownRows,
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) getStream(streamID string) (net.Conn, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
conn, ok := s.streams[streamID]
|
|
return conn, ok
|
|
}
|
|
|
|
func (s *rpcServer) dropStream(streamID string) {
|
|
s.mu.Lock()
|
|
conn, ok := s.streams[streamID]
|
|
if ok {
|
|
delete(s.streams, streamID)
|
|
}
|
|
s.mu.Unlock()
|
|
if ok {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (s *rpcServer) closeAll() {
|
|
s.mu.Lock()
|
|
streams := make([]net.Conn, 0, len(s.streams))
|
|
for id, conn := range s.streams {
|
|
delete(s.streams, id)
|
|
streams = append(streams, conn)
|
|
}
|
|
for id := range s.sessions {
|
|
delete(s.sessions, id)
|
|
}
|
|
s.mu.Unlock()
|
|
for _, conn := range streams {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
// getStringParam returns params[key] when it is present and holds a string.
// It returns ("", false) when params is nil, the key is absent, the value is
// nil, or the value has any other type.
func getStringParam(params map[string]any, key string) (string, bool) {
	// Indexing a nil map and asserting on a nil interface are both safe and
	// simply fail the assertion, which covers every "missing" case at once.
	text, isString := params[key].(string)
	return text, isString
}
|
|
|
|
// getIntParam returns params[key] as an int when it is present and holds an
// integral numeric value that fits in int.
//
// It accepts every Go integer type, float64 (what encoding/json produces for
// numbers decoded into map[string]any) and json.Number. It returns
// (0, false) when params is nil, the key is absent or nil, the value is not
// numeric, or the value cannot be represented exactly: fractional floats,
// NaN, infinities, floats beyond ±2^53, and integers outside the int range
// are all rejected rather than silently truncated or wrapped.
func getIntParam(params map[string]any, key string) (int, bool) {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
	)
	fromInt64 := func(n int64) (int, bool) {
		if n < int64(minInt) || n > int64(maxInt) {
			return 0, false
		}
		return int(n), true
	}
	fromUint64 := func(n uint64) (int, bool) {
		if n > uint64(maxInt) {
			return 0, false
		}
		return int(n), true
	}

	// Indexing a nil map yields a nil interface, which lands in default.
	switch value := params[key].(type) {
	case int:
		return value, true
	case int8:
		return int(value), true
	case int16:
		return int(value), true
	case int32:
		return int(value), true
	case int64:
		return fromInt64(value)
	case uint:
		return fromUint64(uint64(value))
	case uint8:
		return int(value), true
	case uint16:
		return int(value), true
	case uint32:
		return fromUint64(uint64(value))
	case uint64:
		return fromUint64(value)
	case float64:
		// Every integer in [-2^53, 2^53] is exactly representable; outside
		// that range (or for NaN, which fails both comparisons) the
		// float-to-int conversion is not trustworthy.
		const maxExact = 1 << 53
		if !(value >= -maxExact && value <= maxExact) {
			return 0, false
		}
		whole := int64(value)
		if float64(whole) != value {
			return 0, false // fractional part present
		}
		return fromInt64(whole)
	case json.Number:
		n, err := value.Int64()
		if err != nil {
			return 0, false
		}
		return fromInt64(n)
	default:
		return 0, false
	}
}
|