Fix ssh stack review regressions
This commit is contained in:
parent
19b59cae37
commit
2e6856ff2f
27 changed files with 1270 additions and 506 deletions
28
tests/fixtures/ssh-remote/ws_echo.py
vendored
28
tests/fixtures/ssh-remote/ws_echo.py
vendored
|
|
@ -14,8 +14,13 @@ import threading
|
|||
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
def _recv_exact(conn: socket.socket, n: int) -> bytes:
|
||||
def _recv_exact(conn: socket.socket, n: int, pending: bytearray | None = None) -> bytes:
|
||||
data = bytearray()
|
||||
if pending:
|
||||
take = min(len(pending), n)
|
||||
if take:
|
||||
data.extend(pending[:take])
|
||||
del pending[:take]
|
||||
while len(data) < n:
|
||||
chunk = conn.recv(n - len(data))
|
||||
if not chunk:
|
||||
|
|
@ -24,7 +29,7 @@ def _recv_exact(conn: socket.socket, n: int) -> bytes:
|
|||
return bytes(data)
|
||||
|
||||
|
||||
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes:
|
||||
def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> tuple[bytes, bytearray]:
|
||||
data = bytearray()
|
||||
while marker not in data:
|
||||
chunk = conn.recv(1024)
|
||||
|
|
@ -33,21 +38,22 @@ def _recv_until(conn: socket.socket, marker: bytes, limit: int = 8192) -> bytes:
|
|||
data.extend(chunk)
|
||||
if len(data) > limit:
|
||||
raise ValueError("header too large")
|
||||
return bytes(data)
|
||||
marker_end = data.index(marker) + len(marker)
|
||||
return bytes(data[:marker_end]), bytearray(data[marker_end:])
|
||||
|
||||
|
||||
def _read_frame(conn: socket.socket) -> tuple[int, bytes]:
|
||||
first, second = _recv_exact(conn, 2)
|
||||
def _read_frame(conn: socket.socket, pending: bytearray | None = None) -> tuple[int, bytes]:
|
||||
first, second = _recv_exact(conn, 2, pending)
|
||||
opcode = first & 0x0F
|
||||
masked = (second & 0x80) != 0
|
||||
length = second & 0x7F
|
||||
if length == 126:
|
||||
length = struct.unpack("!H", _recv_exact(conn, 2))[0]
|
||||
length = struct.unpack("!H", _recv_exact(conn, 2, pending))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack("!Q", _recv_exact(conn, 8))[0]
|
||||
length = struct.unpack("!Q", _recv_exact(conn, 8, pending))[0]
|
||||
|
||||
mask_key = _recv_exact(conn, 4) if masked else b""
|
||||
payload = _recv_exact(conn, length) if length else b""
|
||||
mask_key = _recv_exact(conn, 4, pending) if masked else b""
|
||||
payload = _recv_exact(conn, length, pending) if length else b""
|
||||
if masked and payload:
|
||||
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||
return opcode, payload
|
||||
|
|
@ -67,7 +73,7 @@ def _send_frame(conn: socket.socket, opcode: int, payload: bytes) -> None:
|
|||
|
||||
def handle_client(conn: socket.socket) -> None:
|
||||
try:
|
||||
request = _recv_until(conn, b"\r\n\r\n")
|
||||
request, pending = _recv_until(conn, b"\r\n\r\n")
|
||||
headers_raw = request.decode("utf-8", errors="replace").split("\r\n")
|
||||
header_map: dict[str, str] = {}
|
||||
for line in headers_raw[1:]:
|
||||
|
|
@ -94,7 +100,7 @@ def handle_client(conn: socket.socket) -> None:
|
|||
conn.sendall(response.encode("utf-8"))
|
||||
|
||||
while True:
|
||||
opcode, payload = _read_frame(conn)
|
||||
opcode, payload = _read_frame(conn, pending)
|
||||
if opcode == 0x8: # close
|
||||
_send_frame(conn, 0x8, b"")
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue