diff --git a/tests_v2/test_ssh_remote_cli_metadata.py b/tests_v2/test_ssh_remote_cli_metadata.py index 59eee991..0b3aabfc 100644 --- a/tests_v2/test_ssh_remote_cli_metadata.py +++ b/tests_v2/test_ssh_remote_cli_metadata.py @@ -102,6 +102,28 @@ def _read_any_terminal_text(client: cmux, workspace_id: str, timeout: float = 8. return None +def _resolve_workspace_id_from_payload(client: cmux, payload: dict) -> str: + workspace_id = str(payload.get("workspace_id") or "") + if workspace_id: + return workspace_id + + workspace_ref = str(payload.get("workspace_ref") or "") + if not workspace_ref.startswith("workspace:"): + return "" + + listed = client._call("workspace.list", {}) or {} + for row in listed.get("workspaces") or []: + if str(row.get("ref") or "") == workspace_ref: + return str(row.get("id") or "") + return "" + + +def _append_workspace_to_cleanup(workspaces_to_close: list[str], workspace_id: str) -> str: + if workspace_id: + workspaces_to_close.append(workspace_id) + return workspace_id + + def main() -> int: cli = _find_cli_binary() help_text = _run_cli(cli, ["ssh", "--help"], json_output=False) @@ -120,16 +142,10 @@ def main() -> int: cli, ["ssh", "127.0.0.1", "--port", "1", "--name", "ssh-meta-test"], ) - workspace_id = str(payload.get("workspace_id") or "") - if workspace_id: - workspaces_to_close.append(workspace_id) - workspace_ref = str(payload.get("workspace_ref") or "") - if not workspace_id and workspace_ref.startswith("workspace:"): - listed = client._call("workspace.list", {}) or {} - for row in listed.get("workspaces") or []: - if str(row.get("ref") or "") == workspace_ref: - workspace_id = str(row.get("id") or "") - break + workspace_id = _append_workspace_to_cleanup( + workspaces_to_close, + _resolve_workspace_id_from_payload(client, payload), + ) _must(bool(workspace_id), f"cmux ssh output missing workspace_id: {payload}") selected_workspace_id = "" deadline_select = time.time() + 5.0 @@ -274,17 +290,11 @@ def main() -> int: cli, ["ssh", "127.0.0.1", "--port", "1"], ) - workspace_id_without_name = str(payload2.get("workspace_id") or "") - if workspace_id_without_name: - workspaces_to_close.append(workspace_id_without_name) + workspace_id_without_name = _append_workspace_to_cleanup( + workspaces_to_close, + _resolve_workspace_id_from_payload(client, payload2), + ) ssh_command_without_name = str(payload2.get("ssh_command") or "") - workspace_ref_without_name = str(payload2.get("workspace_ref") or "") - if not workspace_id_without_name and workspace_ref_without_name.startswith("workspace:"): - listed2 = client._call("workspace.list", {}) or {} - for row in listed2.get("workspaces") or []: - if str(row.get("ref") or "") == workspace_ref_without_name: - workspace_id_without_name = str(row.get("id") or "") - break _must(bool(workspace_id_without_name), f"cmux ssh without --name should still create workspace: {payload2}") _must( @@ -324,16 +334,10 @@ def main() -> int: "StrictHostKeyChecking=no", ], ) - workspace_id_strict_override = str(payload_strict_override.get("workspace_id") or "") - if workspace_id_strict_override: - workspaces_to_close.append(workspace_id_strict_override) - workspace_ref_strict_override = str(payload_strict_override.get("workspace_ref") or "") - if not workspace_id_strict_override and workspace_ref_strict_override.startswith("workspace:"): - listed_override = client._call("workspace.list", {}) or {} - for row in listed_override.get("workspaces") or []: - if str(row.get("ref") or "") == workspace_ref_strict_override: - workspace_id_strict_override = str(row.get("id") or "") - break + workspace_id_strict_override = _append_workspace_to_cleanup( + workspaces_to_close, + _resolve_workspace_id_from_payload(client, payload_strict_override), + ) _must( bool(workspace_id_strict_override), f"cmux ssh with StrictHostKeyChecking override should create workspace: {payload_strict_override}", @@ -373,16 +377,10 @@ def main() -> int: "controlpath=/tmp/cmux-ssh-%C-custom", ], ) - workspace_id_case_override = str(payload_case_override.get("workspace_id") or "") - if workspace_id_case_override: - workspaces_to_close.append(workspace_id_case_override) - workspace_ref_case_override = str(payload_case_override.get("workspace_ref") or "") - if not workspace_id_case_override and workspace_ref_case_override.startswith("workspace:"): - listed_case_override = client._call("workspace.list", {}) or {} - for row in listed_case_override.get("workspaces") or []: - if str(row.get("ref") or "") == workspace_ref_case_override: - workspace_id_case_override = str(row.get("id") or "") - break + workspace_id_case_override = _append_workspace_to_cleanup( + workspaces_to_close, + _resolve_workspace_id_from_payload(client, payload_case_override), + ) _must( bool(workspace_id_case_override), f"cmux ssh with lowercase SSH option overrides should create workspace: {payload_case_override}", @@ -467,9 +465,10 @@ def main() -> int: merged_features == "cursor,title,ssh-env,ssh-terminfo", f"cmux ssh should merge existing shell features when present: {payload3!r}", ) - workspace_id3 = str(payload3.get("workspace_id") or "") - if workspace_id3: - workspaces_to_close.append(workspace_id3) + workspace_id3 = _append_workspace_to_cleanup( + workspaces_to_close, + _resolve_workspace_id_from_payload(client, payload3), + ) if workspace_id3: try: client.close_workspace(workspace_id3) diff --git a/tests_v2/test_ssh_remote_docker_forwarding.py b/tests_v2/test_ssh_remote_docker_forwarding.py index cc7ec4b0..6661aa5c 100644 --- a/tests_v2/test_ssh_remote_docker_forwarding.py +++ b/tests_v2/test_ssh_remote_docker_forwarding.py @@ -473,6 +473,20 @@ def _local_file_sha256(path: Path) -> str: return digest.hexdigest() +def _local_binary_contains_version_marker(path: Path, version: str) -> bool: + marker = version.encode("utf-8") + tail = b"" + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + return False + haystack = tail + chunk + if marker in haystack: + return True + tail = haystack[-max(len(marker) - 1, 0) :] + + def _remote_binary_sha256(host: str, host_port: int, key_path: Path, remote_path: str) -> str: script = f""" set -eu @@ -614,8 +628,8 @@ def main() -> int: f"local daemon cache artifact must be executable: {local_cached_binary}", ) _must( - daemon_version in local_cached_binary.parts, - f"local cached daemon binary path should encode daemon version {daemon_version!r}: {local_cached_binary}", + _local_binary_contains_version_marker(local_cached_binary, daemon_version), + f"local cached daemon binary should embed daemon version marker {daemon_version!r}: {local_cached_binary}", ) local_sha256 = _local_file_sha256(local_cached_binary) remote_sha256 = _remote_binary_sha256(host, host_ssh_port, key_path, remote_path)