Secure remote daemon distribution and relay auth

This commit is contained in:
Lawrence Chen 2026-03-12 05:04:44 -07:00
parent 76cfe01fa2
commit 8a9e28e129
12 changed files with 1419 additions and 117 deletions

View file

@ -2,7 +2,9 @@ package main
import (
"bufio"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
@ -15,6 +17,11 @@ import (
"time"
)
type relayAuthState struct {
RelayID string `json:"relay_id"`
RelayToken string `json:"relay_token"`
}
// protocolVersion indicates whether a command uses the v1 text or v2 JSON-RPC protocol.
type protocolVersion int
@ -376,13 +383,58 @@ func readSocketAddrFile() string {
return strings.TrimSpace(string(data))
}
func readRelayAuthFile(socketPath string) *relayAuthState {
if strings.Contains(socketPath, ":") && !strings.HasPrefix(socketPath, "/") {
_, port, err := net.SplitHostPort(socketPath)
if err != nil || port == "" {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return nil
}
data, err := os.ReadFile(filepath.Join(home, ".cmux", "relay", port+".auth"))
if err != nil {
return nil
}
var state relayAuthState
if err := json.Unmarshal(data, &state); err != nil {
return nil
}
if state.RelayID == "" || state.RelayToken == "" {
return nil
}
return &state
}
return nil
}
func currentRelayAuth(socketPath string) *relayAuthState {
relayID := strings.TrimSpace(os.Getenv("CMUX_RELAY_ID"))
relayToken := strings.TrimSpace(os.Getenv("CMUX_RELAY_TOKEN"))
if relayID != "" && relayToken != "" {
return &relayAuthState{RelayID: relayID, RelayToken: relayToken}
}
return readRelayAuthFile(socketPath)
}
// dialSocket connects to the cmux socket. If addr contains a colon and doesn't
// start with '/', it's treated as a TCP address (host:port); otherwise Unix socket.
// For TCP connections, it retries briefly to allow the SSH reverse forward to establish.
// refreshAddr, if non-nil, is called on each retry to pick up updated socket_addr files.
func dialSocket(addr string, refreshAddr func() string) (net.Conn, error) {
if strings.Contains(addr, ":") && !strings.HasPrefix(addr, "/") {
return dialTCPRetry(addr, 15*time.Second, refreshAddr)
conn, err := dialTCPRetry(addr, 15*time.Second, refreshAddr)
if err != nil {
return nil, err
}
if auth := currentRelayAuth(addr); auth != nil {
if err := authenticateRelayConn(conn, auth); err != nil {
conn.Close()
return nil, err
}
}
return conn, nil
}
return net.Dial("unix", addr)
}
@ -429,6 +481,66 @@ func isConnectionRefused(err error) bool {
return strings.Contains(err.Error(), "connection refused")
}
func authenticateRelayConn(conn net.Conn, auth *relayAuthState) error {
reader := bufio.NewReader(conn)
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
var challenge struct {
Protocol string `json:"protocol"`
Version int `json:"version"`
RelayID string `json:"relay_id"`
Nonce string `json:"nonce"`
}
line, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read relay auth challenge: %w", err)
}
if err := json.Unmarshal([]byte(line), &challenge); err != nil {
return fmt.Errorf("invalid relay auth challenge")
}
if challenge.Protocol != "cmux-relay-auth" || challenge.Version != 1 || challenge.RelayID != auth.RelayID || challenge.Nonce == "" {
return fmt.Errorf("relay auth challenge mismatch")
}
tokenBytes, err := hex.DecodeString(auth.RelayToken)
if err != nil {
return fmt.Errorf("invalid relay auth token")
}
mac := computeRelayMAC(tokenBytes, auth.RelayID, challenge.Nonce, challenge.Version)
payload, err := json.Marshal(map[string]any{
"relay_id": auth.RelayID,
"mac": hex.EncodeToString(mac),
})
if err != nil {
return fmt.Errorf("failed to encode relay auth response: %w", err)
}
if _, err := conn.Write(append(payload, '\n')); err != nil {
return fmt.Errorf("failed to send relay auth response: %w", err)
}
line, err = reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read relay auth result: %w", err)
}
var result struct {
OK bool `json:"ok"`
}
if err := json.Unmarshal([]byte(line), &result); err != nil {
return fmt.Errorf("invalid relay auth result")
}
if !result.OK {
return fmt.Errorf("relay auth rejected")
}
_ = conn.SetDeadline(time.Time{})
return nil
}
func computeRelayMAC(token []byte, relayID, nonce string, version int) []byte {
mac := hmac.New(sha256.New, token)
_, _ = io.WriteString(mac, fmt.Sprintf("relay_id=%s\nnonce=%s\nversion=%d", relayID, nonce, version))
return mac.Sum(nil)
}
// socketRoundTrip sends a raw text line and reads a raw text response (v1).
func socketRoundTrip(socketPath, command string, refreshAddr func() string) (string, error) {
conn, err := dialSocket(socketPath, refreshAddr)

View file

@ -9,6 +9,7 @@ import (
"flag"
"fmt"
"io"
"math"
"net"
"os"
"path/filepath"
@ -1017,6 +1018,9 @@ func getIntParam(params map[string]any, key string) (int, bool) {
case uint64:
return int(value), true
case float64:
if math.Trunc(value) != value {
return 0, false
}
return int(value), true
case json.Number:
n, err := value.Int64()