cmux/tests/fixtures/ssh-remote/ws_echo.py
2026-03-13 04:14:52 -07:00

138 lines
4.6 KiB
Python

#!/usr/bin/env python3
"""Tiny WebSocket echo server for SSH proxy integration tests."""
from __future__ import annotations
import argparse
import base64
import hashlib
import socket
import struct
import threading
# Fixed GUID from the WebSocket protocol (RFC 6455). The handshake appends it
# to the client's Sec-WebSocket-Key and returns the base64-encoded SHA-1 of the
# result as the Sec-WebSocket-Accept header.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def _recv_exact(conn: socket.socket, n: int, pending: bytearray | None = None) -> bytes:
data = bytearray()
if pending:
take = min(len(pending), n)
if take:
data.extend(pending[:take])
del pending[:take]
while len(data) < n:
chunk = conn.recv(n - len(data))
if not chunk:
raise ConnectionError("unexpected EOF")
data.extend(chunk)
return bytes(data)
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> tuple[bytes, bytearray]:
data = bytearray()
while marker not in data:
chunk = conn.recv(1024)
if not chunk:
raise ConnectionError("unexpected EOF while reading headers")
data.extend(chunk)
if len(data) > limit:
raise ValueError("header too large")
marker_end = data.index(marker) + len(marker)
return bytes(data[:marker_end]), bytearray(data[marker_end:])
def _read_frame(conn: socket.socket, pending: bytearray | None = None) -> tuple[int, bytes]:
    """Read one WebSocket frame and return ``(opcode, payload)``.

    The payload is unmasked when the MASK bit is set. FIN and RSV bits are not
    inspected, so a fragmented message comes back one frame at a time. The
    declared payload length is trusted as sent; there is no size cap.
    ``pending`` is drained before more bytes are read from ``conn``, and EOF
    surfaces as ``ConnectionError`` from ``_recv_exact``.
    """
    header = _recv_exact(conn, 2, pending)
    opcode = header[0] & 0x0F
    is_masked = bool(header[1] & 0x80)
    size = header[1] & 0x7F
    # 126 / 127 are escape values: the real length follows as a big-endian
    # 16-bit / 64-bit unsigned integer.
    extended_fmt = {126: "!H", 127: "!Q"}.get(size)
    if extended_fmt is not None:
        raw_len = _recv_exact(conn, struct.calcsize(extended_fmt), pending)
        (size,) = struct.unpack(extended_fmt, raw_len)
    key = _recv_exact(conn, 4, pending) if is_masked else b""
    if not size:
        return opcode, b""
    body = _recv_exact(conn, size, pending)
    if is_masked:
        body = bytes(byte ^ key[idx & 3] for idx, byte in enumerate(body))
    return opcode, body
def _send_frame(conn: socket.socket, opcode: int, payload: bytes) -> None:
first = 0x80 | (opcode & 0x0F)
length = len(payload)
if length < 126:
header = bytes([first, length])
elif length <= 0xFFFF:
header = bytes([first, 126]) + struct.pack("!H", length)
else:
header = bytes([first, 127]) + struct.pack("!Q", length)
conn.sendall(header + payload)
_BAD_REQUEST = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n"


def _handshake(conn: socket.socket) -> bytearray | None:
    """Run the server side of the WebSocket opening handshake on ``conn``.

    The upgrade is accepted when the request has all of:
    - a non-empty ``Sec-WebSocket-Key``,
    - ``Upgrade: websocket``,
    - a ``Connection`` header that contains ``upgrade``.

    Header names and the Upgrade/Connection values are matched
    case-insensitively; the request line itself is not checked. Any other
    request, including one whose headers exceed ``_recv_until``'s size limit,
    is answered with ``400 Bad Request``.

    Returns the bytes received after the headers (the start of the frame
    stream) on success, or ``None`` after the 400 response was sent.
    """
    try:
        request, pending = _recv_until(conn, b"\r\n\r\n")
    except ValueError:
        # Header section larger than _recv_until's limit.
        conn.sendall(_BAD_REQUEST)
        return None

    headers: dict[str, str] = {}
    # Element 0 is the request line; header lines follow.
    for line in request.decode("utf-8", errors="replace").split("\r\n")[1:]:
        name, sep, value = line.partition(":")
        if sep:  # skips the blank terminator line and colon-less lines
            headers[name.strip().lower()] = value.strip()

    key = headers.get("sec-websocket-key", "")
    upgrade = headers.get("upgrade", "").lower()
    connection_hdr = headers.get("connection", "").lower()
    if not key or upgrade != "websocket" or "upgrade" not in connection_hdr:
        conn.sendall(_BAD_REQUEST)
        return None

    # Sec-WebSocket-Accept = base64(SHA-1(key + GUID)).
    digest = hashlib.sha1((key + GUID).encode("utf-8")).digest()
    accept = base64.b64encode(digest).decode("ascii")
    response = (
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        f"Sec-WebSocket-Accept: {accept}\r\n"
        "\r\n"
    )
    conn.sendall(response.encode("utf-8"))
    return pending


def _echo_frames(conn: socket.socket, pending: bytearray) -> None:
    """Echo data frames back to the client until it sends a close frame.

    ``pending`` holds frame bytes that arrived together with the handshake
    request; they are consumed before more is read from ``conn``.
    """
    while True:
        opcode, payload = _read_frame(conn, pending)
        if opcode == 0x8:  # close: answer with an empty close frame and stop
            _send_frame(conn, 0x8, b"")
            return
        if opcode == 0x9:  # ping: answer with a pong carrying the same data
            _send_frame(conn, 0xA, payload)
        elif opcode in (0x1, 0x2):  # text / binary: echo back unchanged
            _send_frame(conn, opcode, payload)
        # Continuation (0x0), pong (0xA) and reserved opcodes are ignored.


def handle_client(conn: socket.socket) -> None:
    """Serve one client connection, then close it.

    The connection goes through the WebSocket handshake (a bad request gets a
    400). After a successful handshake, the server:
    - echoes text and binary frames,
    - answers pings with pongs,
    - replies to a close frame with a close frame.

    A client that disconnects without a close frame (EOF, reset, broken pipe)
    ends the session quietly instead of raising out of the handler thread.
    The socket is always closed.
    """
    try:
        pending = _handshake(conn)
        if pending is not None:
            _echo_frames(conn, pending)
    except ConnectionError:
        # Peer went away mid-session; nothing left to do but close.
        pass
    finally:
        try:
            conn.close()
        except OSError:
            pass
def main() -> int:
    """Parse ``--host``/``--port`` and serve WebSocket echo clients forever.

    Each accepted connection is handed to ``handle_client`` on its own daemon
    thread. The accept loop has no exit condition, so this function never
    returns normally; it only leaves by raising.
    """
    cli = argparse.ArgumentParser(description="WebSocket echo server")
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=43174)
    opts = cli.parse_args()

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with listener:
        # SO_REUSEADDR lets the port be rebound while earlier connections
        # on it are still lingering in TIME_WAIT.
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((opts.host, opts.port))
        listener.listen(16)
        while True:
            client, _peer = listener.accept()
            worker = threading.Thread(target=handle_client, args=(client,), daemon=True)
            worker.start()
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)