Tighten ssh remote regression cleanup

This commit is contained in:
Lawrence Chen 2026-03-11 23:22:00 -07:00
parent de47345538
commit b6f0e3a3f6
2 changed files with 58 additions and 45 deletions

View file

@ -102,6 +102,28 @@ def _read_any_terminal_text(client: cmux, workspace_id: str, timeout: float = 8.
return None return None
def _resolve_workspace_id_from_payload(client: cmux, payload: dict) -> str:
workspace_id = str(payload.get("workspace_id") or "")
if workspace_id:
return workspace_id
workspace_ref = str(payload.get("workspace_ref") or "")
if not workspace_ref.startswith("workspace:"):
return ""
listed = client._call("workspace.list", {}) or {}
for row in listed.get("workspaces") or []:
if str(row.get("ref") or "") == workspace_ref:
return str(row.get("id") or "")
return ""
def _append_workspace_to_cleanup(workspaces_to_close: list[str], workspace_id: str) -> str:
if workspace_id:
workspaces_to_close.append(workspace_id)
return workspace_id
def main() -> int: def main() -> int:
cli = _find_cli_binary() cli = _find_cli_binary()
help_text = _run_cli(cli, ["ssh", "--help"], json_output=False) help_text = _run_cli(cli, ["ssh", "--help"], json_output=False)
@ -120,16 +142,10 @@ def main() -> int:
cli, cli,
["ssh", "127.0.0.1", "--port", "1", "--name", "ssh-meta-test"], ["ssh", "127.0.0.1", "--port", "1", "--name", "ssh-meta-test"],
) )
workspace_id = str(payload.get("workspace_id") or "") workspace_id = _append_workspace_to_cleanup(
if workspace_id: workspaces_to_close,
workspaces_to_close.append(workspace_id) _resolve_workspace_id_from_payload(client, payload),
workspace_ref = str(payload.get("workspace_ref") or "") )
if not workspace_id and workspace_ref.startswith("workspace:"):
listed = client._call("workspace.list", {}) or {}
for row in listed.get("workspaces") or []:
if str(row.get("ref") or "") == workspace_ref:
workspace_id = str(row.get("id") or "")
break
_must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}") _must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}")
selected_workspace_id = "" selected_workspace_id = ""
deadline_select = time.time() + 5.0 deadline_select = time.time() + 5.0
@ -274,17 +290,11 @@ def main() -> int:
cli, cli,
["ssh", "127.0.0.1", "--port", "1"], ["ssh", "127.0.0.1", "--port", "1"],
) )
workspace_id_without_name = str(payload2.get("workspace_id") or "") workspace_id_without_name = _append_workspace_to_cleanup(
if workspace_id_without_name: workspaces_to_close,
workspaces_to_close.append(workspace_id_without_name) _resolve_workspace_id_from_payload(client, payload2),
)
ssh_command_without_name = str(payload2.get("ssh_command") or "") ssh_command_without_name = str(payload2.get("ssh_command") or "")
workspace_ref_without_name = str(payload2.get("workspace_ref") or "")
if not workspace_id_without_name and workspace_ref_without_name.startswith("workspace:"):
listed2 = client._call("workspace.list", {}) or {}
for row in listed2.get("workspaces") or []:
if str(row.get("ref") or "") == workspace_ref_without_name:
workspace_id_without_name = str(row.get("id") or "")
break
_must(bool(workspace_id_without_name), f"cmux ssh without --name should still create workspace: {payload2}") _must(bool(workspace_id_without_name), f"cmux ssh without --name should still create workspace: {payload2}")
_must( _must(
@ -324,16 +334,10 @@ def main() -> int:
"StrictHostKeyChecking=no", "StrictHostKeyChecking=no",
], ],
) )
workspace_id_strict_override = str(payload_strict_override.get("workspace_id") or "") workspace_id_strict_override = _append_workspace_to_cleanup(
if workspace_id_strict_override: workspaces_to_close,
workspaces_to_close.append(workspace_id_strict_override) _resolve_workspace_id_from_payload(client, payload_strict_override),
workspace_ref_strict_override = str(payload_strict_override.get("workspace_ref") or "") )
if not workspace_id_strict_override and workspace_ref_strict_override.startswith("workspace:"):
listed_override = client._call("workspace.list", {}) or {}
for row in listed_override.get("workspaces") or []:
if str(row.get("ref") or "") == workspace_ref_strict_override:
workspace_id_strict_override = str(row.get("id") or "")
break
_must( _must(
bool(workspace_id_strict_override), bool(workspace_id_strict_override),
f"cmux ssh with StrictHostKeyChecking override should create workspace: {payload_strict_override}", f"cmux ssh with StrictHostKeyChecking override should create workspace: {payload_strict_override}",
@ -373,16 +377,10 @@ def main() -> int:
"controlpath=/tmp/cmux-ssh-%C-custom", "controlpath=/tmp/cmux-ssh-%C-custom",
], ],
) )
workspace_id_case_override = str(payload_case_override.get("workspace_id") or "") workspace_id_case_override = _append_workspace_to_cleanup(
if workspace_id_case_override: workspaces_to_close,
workspaces_to_close.append(workspace_id_case_override) _resolve_workspace_id_from_payload(client, payload_case_override),
workspace_ref_case_override = str(payload_case_override.get("workspace_ref") or "") )
if not workspace_id_case_override and workspace_ref_case_override.startswith("workspace:"):
listed_case_override = client._call("workspace.list", {}) or {}
for row in listed_case_override.get("workspaces") or []:
if str(row.get("ref") or "") == workspace_ref_case_override:
workspace_id_case_override = str(row.get("id") or "")
break
_must( _must(
bool(workspace_id_case_override), bool(workspace_id_case_override),
f"cmux ssh with lowercase SSH option overrides should create workspace: {payload_case_override}", f"cmux ssh with lowercase SSH option overrides should create workspace: {payload_case_override}",
@ -467,9 +465,10 @@ def main() -> int:
merged_features == "cursor,title,ssh-env,ssh-terminfo", merged_features == "cursor,title,ssh-env,ssh-terminfo",
f"cmux ssh should merge existing shell features when present: {payload3!r}", f"cmux ssh should merge existing shell features when present: {payload3!r}",
) )
workspace_id3 = str(payload3.get("workspace_id") or "") workspace_id3 = _append_workspace_to_cleanup(
if workspace_id3: workspaces_to_close,
workspaces_to_close.append(workspace_id3) _resolve_workspace_id_from_payload(client, payload3),
)
if workspace_id3: if workspace_id3:
try: try:
client.close_workspace(workspace_id3) client.close_workspace(workspace_id3)

View file

@ -473,6 +473,20 @@ def _local_file_sha256(path: Path) -> str:
return digest.hexdigest() return digest.hexdigest()
def _local_binary_contains_version_marker(path: Path, version: str) -> bool:
marker = version.encode("utf-8")
tail = b""
with path.open("rb") as handle:
while True:
chunk = handle.read(1024 * 1024)
if not chunk:
return False
haystack = tail + chunk
if marker in haystack:
return True
tail = haystack[-max(len(marker) - 1, 0) :]
def _remote_binary_sha256(host: str, host_port: int, key_path: Path, remote_path: str) -> str: def _remote_binary_sha256(host: str, host_port: int, key_path: Path, remote_path: str) -> str:
script = f""" script = f"""
set -eu set -eu
@ -614,8 +628,8 @@ def main() -> int:
f"local daemon cache artifact must be executable: {local_cached_binary}", f"local daemon cache artifact must be executable: {local_cached_binary}",
) )
_must( _must(
daemon_version in local_cached_binary.parts, _local_binary_contains_version_marker(local_cached_binary, daemon_version),
f"local cached daemon binary path should encode daemon version {daemon_version!r}: {local_cached_binary}", f"local cached daemon binary should embed daemon version marker {daemon_version!r}: {local_cached_binary}",
) )
local_sha256 = _local_file_sha256(local_cached_binary) local_sha256 = _local_file_sha256(local_cached_binary)
remote_sha256 = _remote_binary_sha256(host, host_ssh_port, key_path, remote_path) remote_sha256 = _remote_binary_sha256(host, host_ssh_port, key_path, remote_path)