Fix ssh stack review regressions

This commit is contained in:
Lawrence Chen 2026-03-13 04:14:52 -07:00
parent 19b59cae37
commit 2e6856ff2f
27 changed files with 1270 additions and 506 deletions

View file

@ -14,8 +14,13 @@ import threading
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def _recv_exact(conn: socket.socket, n: int) -> bytes:
def _recv_exact(conn: socket.socket, n: int, pending: bytearray | None = None) -> bytes:
data = bytearray()
if pending:
take = min(len(pending), n)
if take:
data.extend(pending[:take])
del pending[:take]
while len(data) < n:
chunk = conn.recv(n - len(data))
if not chunk:
@ -24,7 +29,7 @@ def _recv_exact(conn: socket.socket, n: int) -> bytes:
return bytes(data)
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes:
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> tuple[bytes, bytearray]:
data = bytearray()
while marker not in data:
chunk = conn.recv(1024)
@ -33,21 +38,22 @@ def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes:
data.extend(chunk)
if len(data) > limit:
raise ValueError("header too large")
return bytes(data)
marker_end = data.index(marker) + len(marker)
return bytes(data[:marker_end]), bytearray(data[marker_end:])
def _read_frame(conn: socket.socket) -> tuple[int, bytes]:
first, second = _recv_exact(conn, 2)
def _read_frame(conn: socket.socket, pending: bytearray | None = None) -> tuple[int, bytes]:
first, second = _recv_exact(conn, 2, pending)
opcode = first & 0x0F
masked = (second & 0x80) != 0
length = second & 0x7F
if length == 126:
length = struct.unpack("!H", _recv_exact(conn, 2))[0]
length = struct.unpack("!H", _recv_exact(conn, 2, pending))[0]
elif length == 127:
length = struct.unpack("!Q", _recv_exact(conn, 8))[0]
length = struct.unpack("!Q", _recv_exact(conn, 8, pending))[0]
mask_key = _recv_exact(conn, 4) if masked else b""
payload = _recv_exact(conn, length) if length else b""
mask_key = _recv_exact(conn, 4, pending) if masked else b""
payload = _recv_exact(conn, length, pending) if length else b""
if masked and payload:
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
return opcode, payload
@ -67,7 +73,7 @@ def _send_frame(conn: socket.socket, opcode: int, payload: bytes) -> None:
def handle_client(conn: socket.socket) -> None:
try:
request = _recv_until(conn, b"\r\n\r\n")
request, pending = _recv_until(conn, b"\r\n\r\n")
headers_raw = request.decode("utf-8", errors="replace").split("\r\n")
header_map: dict[str, str] = {}
for line in headers_raw[1:]:
@ -94,7 +100,7 @@ def handle_client(conn: socket.socket) -> None:
conn.sendall(response.encode("utf-8"))
while True:
opcode, payload = _read_frame(conn)
opcode, payload = _read_frame(conn, pending)
if opcode == 0x8: # close
_send_frame(conn, 0x8, b"")
return