Socket access control: process ancestry check (#58)
* Socket access control: process ancestry check + file permissions Redesign socket control modes from (off, notifications, full) to (off, cmuxOnly, allowAll): - cmuxOnly (default): uses LOCAL_PEERPID + sysctl process tree walk to verify the connecting process is a descendant of cmux. External processes (SSH, other terminals) are rejected. - allowAll: hidden mode accessible only via CMUX_SOCKET_MODE=allowAll env var, skips ancestry check. Legacy "full"/"notifications" env values map here for backward compat. - off: disables socket entirely. Security hardening: - Server: chmod 0600 on socket after bind (owner-only access) - CLI: stat() ownership check before connect (reject fake sockets) Removes per-command allow-list (isCommandAllowed) — once a process passes the ancestry check, all commands are available. Includes migration for persisted UserDefaults values and env var aliases (cmux_only, cmux-only, allow_all, allow-all). * Add /sync-branch skill for submodule + main sync
This commit is contained in:
parent
60978d4d8b
commit
51a67e31fd
8 changed files with 577 additions and 85 deletions
38
.claude/commands/sync-branch.md
Normal file
38
.claude/commands/sync-branch.md
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Sync Branch
|
||||
|
||||
Get the current branch ready: update all submodules to their latest remote main, merge from main, and push.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Update submodules to latest**
|
||||
- For each submodule (ghostty, homebrew-cmux, vendor/bonsplit):
|
||||
- `cd <submodule>`
|
||||
- `git fetch origin`
|
||||
- Check if behind: `git rev-list HEAD..origin/main --count`
|
||||
- If behind, merge: `git merge origin/main --no-edit`
|
||||
- For ghostty specifically, push the merge to the fork: `git push origin HEAD:main`
|
||||
- Verify with: `git merge-base --is-ancestor HEAD origin/main`
|
||||
- Go back to repo root
|
||||
|
||||
2. **Commit submodule updates on main**
|
||||
- `git checkout main && git pull origin main`
|
||||
- Check if any submodules changed: `git diff --name-only` (look for submodule paths)
|
||||
- If changed, stage and commit: `git add ghostty homebrew-cmux vendor/bonsplit && git commit -m "Update submodules: <brief description>"`
|
||||
- Push main: `git push origin main`
|
||||
|
||||
3. **Rebase current branch on main**
|
||||
- `git checkout <original-branch>`
|
||||
- `git rebase main`
|
||||
- If conflicts, resolve them and continue
|
||||
- Force push if branch was already pushed: `git push --force-with-lease origin <branch>`
|
||||
|
||||
4. **Report status**
|
||||
- Show what submodules were updated and by how many commits
|
||||
- Show if rebase was clean or had conflicts
|
||||
- Show current branch and commit
|
||||
|
||||
## Notes
|
||||
|
||||
- Never commit a submodule pointer in the parent repo unless the submodule commit is reachable from the submodule's remote main (per CLAUDE.md pitfall about orphaned commits)
|
||||
- If no submodules need updating and main has no new commits, just say "Already up to date"
|
||||
- If on main already, skip step 3
|
||||
|
|
@ -289,6 +289,16 @@ final class SocketClient {
|
|||
|
||||
func connect() throws {
|
||||
if socketFD >= 0 { return }
|
||||
|
||||
// Verify socket is owned by the current user to prevent fake-socket attacks
|
||||
var st = stat()
|
||||
guard stat(path, &st) == 0 else {
|
||||
throw CLIError(message: "Socket not found at \(path)")
|
||||
}
|
||||
guard st.st_uid == getuid() else {
|
||||
throw CLIError(message: "Socket at \(path) is not owned by the current user — refusing to connect")
|
||||
}
|
||||
|
||||
socketFD = socket(AF_UNIX, SOCK_STREAM, 0)
|
||||
if socketFD < 0 {
|
||||
throw CLIError(message: "Failed to create socket")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ final class AutomationSocketUITests: XCTestCase {
|
|||
|
||||
func testSocketToggleDisablesAndEnables() {
|
||||
let app = XCUIApplication()
|
||||
app.launchArguments += ["-\(modeKey)", "notifications"]
|
||||
app.launchArguments += ["-\(modeKey)", "cmuxOnly"]
|
||||
app.launchEnvironment["CMUX_SOCKET_PATH"] = socketPath
|
||||
app.launch()
|
||||
app.activate()
|
||||
|
|
|
|||
|
|
@ -2,19 +2,24 @@ import Foundation
|
|||
|
||||
enum SocketControlMode: String, CaseIterable, Identifiable {
|
||||
case off
|
||||
case notifications
|
||||
case full
|
||||
case cmuxOnly
|
||||
/// Allow any local process to connect (no ancestry check).
|
||||
/// Only accessible via CMUX_SOCKET_MODE=allowAll env var — not shown in the UI.
|
||||
case allowAll
|
||||
|
||||
var id: String { rawValue }
|
||||
|
||||
/// Cases shown in the Settings UI. `allowAll` is intentionally excluded.
|
||||
static var uiCases: [SocketControlMode] { [.off, .cmuxOnly] }
|
||||
|
||||
var displayName: String {
|
||||
switch self {
|
||||
case .off:
|
||||
return "Off"
|
||||
case .notifications:
|
||||
return "Notifications only"
|
||||
case .full:
|
||||
return "Full control"
|
||||
case .cmuxOnly:
|
||||
return "cmux processes only"
|
||||
case .allowAll:
|
||||
return "Allow all processes"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -22,10 +27,10 @@ enum SocketControlMode: String, CaseIterable, Identifiable {
|
|||
switch self {
|
||||
case .off:
|
||||
return "Disable the local control socket."
|
||||
case .notifications:
|
||||
return "Allow only notification commands over the local socket."
|
||||
case .full:
|
||||
return "Allow all socket commands, including tab and input control."
|
||||
case .cmuxOnly:
|
||||
return "Only processes started inside cmux terminals can send commands."
|
||||
case .allowAll:
|
||||
return "Allow any local process to connect (no ancestry check)."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -34,12 +39,19 @@ struct SocketControlSettings {
|
|||
static let appStorageKey = "socketControlMode"
|
||||
static let legacyEnabledKey = "socketControlEnabled"
|
||||
|
||||
/// Map old persisted rawValues to the new enum.
|
||||
static func migrateMode(_ raw: String) -> SocketControlMode {
|
||||
switch raw {
|
||||
case "off": return .off
|
||||
case "cmuxOnly": return .cmuxOnly
|
||||
// Legacy values:
|
||||
case "notifications", "full": return .cmuxOnly
|
||||
default: return defaultMode
|
||||
}
|
||||
}
|
||||
|
||||
static var defaultMode: SocketControlMode {
|
||||
#if DEBUG
|
||||
return .full
|
||||
#else
|
||||
return .notifications
|
||||
#endif
|
||||
return .cmuxOnly
|
||||
}
|
||||
|
||||
static func socketPath() -> String {
|
||||
|
|
@ -72,7 +84,15 @@ struct SocketControlSettings {
|
|||
guard let raw = ProcessInfo.processInfo.environment["CMUX_SOCKET_MODE"], !raw.isEmpty else {
|
||||
return nil
|
||||
}
|
||||
return SocketControlMode(rawValue: raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased())
|
||||
let cleaned = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
|
||||
switch cleaned {
|
||||
case "off": return .off
|
||||
case "cmuxonly", "cmux_only", "cmux-only": return .cmuxOnly
|
||||
case "allowall", "allow_all", "allow-all": return .allowAll
|
||||
// Legacy env var values — map to allowAll so existing test scripts keep working
|
||||
case "notifications", "full": return .allowAll
|
||||
default: return SocketControlMode(rawValue: cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
static func effectiveMode(userMode: SocketControlMode) -> SocketControlMode {
|
||||
|
|
@ -83,7 +103,7 @@ struct SocketControlSettings {
|
|||
if let overrideMode = envOverrideMode() {
|
||||
return overrideMode
|
||||
}
|
||||
return userMode == .off ? .notifications : userMode
|
||||
return userMode == .off ? .cmuxOnly : userMode
|
||||
}
|
||||
|
||||
if let overrideMode = envOverrideMode() {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ class TerminalController {
|
|||
private nonisolated(unsafe) var acceptLoopAlive = false
|
||||
private var clientHandlers: [Int32: Thread] = [:]
|
||||
private var tabManager: TabManager?
|
||||
private var accessMode: SocketControlMode = .full
|
||||
private var accessMode: SocketControlMode = .cmuxOnly
|
||||
private let myPid = getpid()
|
||||
|
||||
private enum V2HandleKind: String, CaseIterable {
|
||||
case window
|
||||
|
|
@ -73,6 +74,48 @@ class TerminalController {
|
|||
self.tabManager = tabManager
|
||||
}
|
||||
|
||||
// MARK: - Process Ancestry Check
|
||||
|
||||
/// Get the peer PID of a connected Unix domain socket using LOCAL_PEERPID.
|
||||
private func getPeerPid(_ socket: Int32) -> pid_t? {
|
||||
var pid: pid_t = 0
|
||||
var pidSize = socklen_t(MemoryLayout<pid_t>.size)
|
||||
let result = getsockopt(socket, SOL_LOCAL, LOCAL_PEERPID, &pid, &pidSize)
|
||||
guard result == 0, pid > 0 else { return nil }
|
||||
return pid
|
||||
}
|
||||
|
||||
/// Check if `pid` is a descendant of this process by walking the process tree.
|
||||
func isDescendant(_ pid: pid_t) -> Bool {
|
||||
var current = pid
|
||||
// Walk up to 128 levels to avoid infinite loops from kernel bugs
|
||||
for _ in 0..<128 {
|
||||
if current == myPid {
|
||||
return true
|
||||
}
|
||||
if current <= 1 {
|
||||
return false
|
||||
}
|
||||
let parent = parentPid(of: current)
|
||||
if parent == current || parent < 0 {
|
||||
return false
|
||||
}
|
||||
current = parent
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/// Get the parent PID of a process using sysctl.
|
||||
private func parentPid(of pid: pid_t) -> pid_t {
|
||||
var info = kinfo_proc()
|
||||
var size = MemoryLayout<kinfo_proc>.size
|
||||
var mib: [Int32] = [CTL_KERN, KERN_PROC, KERN_PROC_PID, pid]
|
||||
guard sysctl(&mib, 4, &info, &size, nil, 0) == 0 else {
|
||||
return -1
|
||||
}
|
||||
return info.kp_eproc.e_ppid
|
||||
}
|
||||
|
||||
func start(tabManager: TabManager, socketPath: String, accessMode: SocketControlMode) {
|
||||
self.tabManager = tabManager
|
||||
self.accessMode = accessMode
|
||||
|
|
@ -119,6 +162,9 @@ class TerminalController {
|
|||
return
|
||||
}
|
||||
|
||||
// Restrict socket to owner only (0600)
|
||||
chmod(socketPath, 0o600)
|
||||
|
||||
// Listen
|
||||
guard listen(serverSocket, 5) >= 0 else {
|
||||
print("TerminalController: Failed to listen on socket")
|
||||
|
|
@ -187,6 +233,21 @@ class TerminalController {
|
|||
private func handleClient(_ socket: Int32) {
|
||||
defer { close(socket) }
|
||||
|
||||
// In cmuxOnly mode, verify the connecting process is a descendant of cmux.
|
||||
// In allowAll mode (env-var only), skip the ancestry check.
|
||||
if accessMode == .cmuxOnly {
|
||||
guard let peerPid = getPeerPid(socket) else {
|
||||
let msg = "ERROR: Unable to verify client process\n"
|
||||
msg.withCString { ptr in _ = write(socket, ptr, strlen(ptr)) }
|
||||
return
|
||||
}
|
||||
guard isDescendant(peerPid) else {
|
||||
let msg = "ERROR: Access denied — only processes started inside cmux can connect\n"
|
||||
msg.withCString { ptr in _ = write(socket, ptr, strlen(ptr)) }
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var buffer = [UInt8](repeating: 0, count: 4096)
|
||||
var pending = ""
|
||||
|
||||
|
|
@ -226,9 +287,6 @@ class TerminalController {
|
|||
|
||||
let cmd = parts[0].lowercased()
|
||||
let args = parts.count > 1 ? parts[1] : ""
|
||||
if !isCommandAllowed(cmd) {
|
||||
return "ERROR: Command disabled by socket access mode"
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "ping":
|
||||
|
|
@ -512,10 +570,6 @@ class TerminalController {
|
|||
return v2Error(id: id, code: "invalid_request", message: "Missing method")
|
||||
}
|
||||
|
||||
// Apply access-mode restrictions.
|
||||
if !isV2MethodAllowed(method) {
|
||||
return v2Error(id: id, code: "forbidden", message: "Command disabled by socket access mode")
|
||||
}
|
||||
v2MainSync { self.v2RefreshKnownRefs() }
|
||||
|
||||
|
||||
|
|
@ -831,29 +885,6 @@ class TerminalController {
|
|||
}
|
||||
}
|
||||
|
||||
private func isV2MethodAllowed(_ method: String) -> Bool {
|
||||
switch accessMode {
|
||||
case .full:
|
||||
return true
|
||||
case .notifications:
|
||||
let allowed: Set<String> = [
|
||||
"system.ping",
|
||||
"system.capabilities",
|
||||
"system.identify",
|
||||
"notification.create",
|
||||
"notification.create_for_surface",
|
||||
"notification.create_for_target",
|
||||
"notification.list",
|
||||
"notification.clear",
|
||||
"app.focus_override.set",
|
||||
"app.simulate_active"
|
||||
]
|
||||
return allowed.contains(method)
|
||||
case .off:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func v2Capabilities() -> [String: Any] {
|
||||
var methods: [String] = [
|
||||
"system.ping",
|
||||
|
|
@ -6847,29 +6878,6 @@ class TerminalController {
|
|||
}
|
||||
#endif
|
||||
|
||||
private func isCommandAllowed(_ command: String) -> Bool {
|
||||
switch accessMode {
|
||||
case .full:
|
||||
return true
|
||||
case .notifications:
|
||||
let allowed: Set<String> = [
|
||||
"ping",
|
||||
"help",
|
||||
"notify",
|
||||
"notify_surface",
|
||||
"notify_target",
|
||||
"list_notifications",
|
||||
"clear_notifications",
|
||||
"set_status",
|
||||
"clear_status",
|
||||
"list_status"
|
||||
]
|
||||
return allowed.contains(command)
|
||||
case .off:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func listWindows() -> String {
|
||||
let summaries = v2MainSync { AppDelegate.shared?.listMainWindowSummaries() } ?? []
|
||||
guard !summaries.isEmpty else { return "No windows" }
|
||||
|
|
|
|||
|
|
@ -19,12 +19,15 @@ struct cmuxApp: App {
|
|||
|
||||
init() {
|
||||
configureGhosttyEnvironment()
|
||||
// Start the terminal controller for programmatic control
|
||||
// This runs after TabManager is created via @StateObject
|
||||
// Migrate legacy and old-format socket mode values to the new enum.
|
||||
let defaults = UserDefaults.standard
|
||||
if defaults.object(forKey: SocketControlSettings.appStorageKey) == nil,
|
||||
let legacy = defaults.object(forKey: SocketControlSettings.legacyEnabledKey) as? Bool {
|
||||
defaults.set(legacy ? SocketControlMode.full.rawValue : SocketControlMode.off.rawValue,
|
||||
if let stored = defaults.string(forKey: SocketControlSettings.appStorageKey) {
|
||||
let migrated = SocketControlSettings.migrateMode(stored)
|
||||
if migrated.rawValue != stored {
|
||||
defaults.set(migrated.rawValue, forKey: SocketControlSettings.appStorageKey)
|
||||
}
|
||||
} else if let legacy = defaults.object(forKey: SocketControlSettings.legacyEnabledKey) as? Bool {
|
||||
defaults.set(legacy ? SocketControlMode.cmuxOnly.rawValue : SocketControlMode.off.rawValue,
|
||||
forKey: SocketControlSettings.appStorageKey)
|
||||
}
|
||||
migrateSidebarAppearanceDefaultsIfNeeded(defaults: defaults)
|
||||
|
|
@ -522,7 +525,7 @@ struct cmuxApp: App {
|
|||
}
|
||||
|
||||
private var currentSocketMode: SocketControlMode {
|
||||
SocketControlMode(rawValue: socketControlMode) ?? SocketControlSettings.defaultMode
|
||||
SocketControlSettings.migrateMode(socketControlMode)
|
||||
}
|
||||
|
||||
private var splitRightMenuShortcut: StoredShortcut {
|
||||
|
|
@ -2250,7 +2253,7 @@ struct SettingsView: View {
|
|||
}
|
||||
|
||||
private var selectedSocketControlMode: SocketControlMode {
|
||||
SocketControlMode(rawValue: socketControlMode) ?? SocketControlSettings.defaultMode
|
||||
SocketControlSettings.migrateMode(socketControlMode)
|
||||
}
|
||||
|
||||
private var browserHistorySubtitle: String {
|
||||
|
|
@ -2341,7 +2344,7 @@ struct SettingsView: View {
|
|||
controlWidth: pickerColumnWidth
|
||||
) {
|
||||
Picker("", selection: $socketControlMode) {
|
||||
ForEach(SocketControlMode.allCases) { mode in
|
||||
ForEach(SocketControlMode.uiCases) { mode in
|
||||
Text(mode.displayName).tag(mode.rawValue)
|
||||
}
|
||||
}
|
||||
|
|
@ -2352,7 +2355,7 @@ struct SettingsView: View {
|
|||
|
||||
SettingsCardDivider()
|
||||
|
||||
SettingsCardNote("Expose a local Unix socket for programmatic control. This can be a security risk on shared machines.")
|
||||
SettingsCardNote("Controls access to the local Unix socket for programmatic control. In \"cmux processes only\" mode, only processes spawned inside cmux terminals can connect.")
|
||||
SettingsCardNote("Overrides: CMUX_SOCKET_ENABLE, CMUX_SOCKET_MODE, and CMUX_SOCKET_PATH.")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ Values: `1`, `0`, `true`, `false`, `yes`, `no`
|
|||
Override the socket access mode.
|
||||
|
||||
```bash
|
||||
export CMUX_SOCKET_MODE=notifications # Notifications only
|
||||
export CMUX_SOCKET_MODE=full # Full control
|
||||
export CMUX_SOCKET_MODE=off # Disabled
|
||||
export CMUX_SOCKET_MODE=cmuxOnly # cmux processes only (default)
|
||||
export CMUX_SOCKET_MODE=allowAll # Allow any local process (no ancestry check)
|
||||
export CMUX_SOCKET_MODE=off # Disabled
|
||||
```
|
||||
|
||||
## CLI Context
|
||||
|
|
@ -176,4 +176,4 @@ Environment variables override app settings:
|
|||
2. App settings (Settings window)
|
||||
3. Default value
|
||||
|
||||
For example, if `CMUX_SOCKET_MODE=full` is set, it overrides the app's Automation Mode setting.
|
||||
For example, if `CMUX_SOCKET_MODE=cmuxOnly` is set, it overrides the app's Automation Mode setting.
|
||||
|
|
|
|||
413
tests/test_socket_access.py
Normal file
413
tests/test_socket_access.py
Normal file
|
|
@ -0,0 +1,413 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for socket access control (process ancestry check).
|
||||
|
||||
In cmuxOnly mode (default), only processes descended from the cmux
|
||||
app process can connect. External processes (e.g., SSH) are rejected.
|
||||
|
||||
Test strategy:
|
||||
Phase 1: cmuxOnly — external processes get rejected
|
||||
Phase 2: cmuxOnly — internal process CAN connect (inject via shell rc)
|
||||
Phase 3: allowAll env override — existing test commands still work
|
||||
|
||||
Usage:
|
||||
python3 test_socket_access.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from cmux import cmux, cmuxError
|
||||
|
||||
|
||||
class TestResult:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.passed = False
|
||||
self.message = ""
|
||||
|
||||
def success(self, msg: str = ""):
|
||||
self.passed = True
|
||||
self.message = msg
|
||||
|
||||
def failure(self, msg: str):
|
||||
self.passed = False
|
||||
self.message = msg
|
||||
|
||||
|
||||
def _find_socket_path():
|
||||
return cmux().socket_path
|
||||
|
||||
|
||||
def _raw_connect(socket_path: str, timeout: float = 3.0):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
return sock
|
||||
|
||||
|
||||
def _raw_send(sock, command: str, timeout: float = 3.0) -> str:
|
||||
sock.sendall((command + "\n").encode())
|
||||
data = b""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
chunk = sock.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
if b"\n" in data:
|
||||
break
|
||||
except socket.timeout:
|
||||
break
|
||||
return data.decode().strip()
|
||||
|
||||
|
||||
def _find_app():
|
||||
r = subprocess.run(
|
||||
["find", "/Users/cmux/Library/Developer/Xcode/DerivedData",
|
||||
"-path", "*/Build/Products/Debug/cmux DEV.app", "-print", "-quit"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
return r.stdout.strip()
|
||||
|
||||
|
||||
def _wait_for_socket(socket_path: str, timeout: float = 10.0) -> bool:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
if os.path.exists(socket_path):
|
||||
return True
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def _kill_cmux():
|
||||
subprocess.run(["pkill", "-x", "cmux DEV"], capture_output=True)
|
||||
time.sleep(1.5)
|
||||
|
||||
|
||||
def _launch_cmux(app_path: str, socket_path: str, mode: str = None):
|
||||
env_args = []
|
||||
if mode:
|
||||
env_args = ["--env", f"CMUX_SOCKET_MODE={mode}"]
|
||||
subprocess.Popen(["open", "-a", app_path] + env_args)
|
||||
if not _wait_for_socket(socket_path):
|
||||
raise RuntimeError(f"Socket {socket_path} not created after launch")
|
||||
time.sleep(8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External rejection tests (Phase 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_external_rejected(socket_path: str) -> TestResult:
|
||||
result = TestResult("External process rejected")
|
||||
try:
|
||||
sock = _raw_connect(socket_path)
|
||||
try:
|
||||
response = _raw_send(sock, "ping")
|
||||
if "Access denied" in response:
|
||||
result.success(f"Correctly rejected")
|
||||
elif response == "PONG":
|
||||
result.failure("External allowed — ancestry check not working")
|
||||
else:
|
||||
result.failure(f"Unexpected: {response!r}")
|
||||
finally:
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
return result
|
||||
|
||||
|
||||
def test_connection_closed_after_reject(socket_path: str) -> TestResult:
|
||||
result = TestResult("Connection closed after rejection")
|
||||
try:
|
||||
sock = _raw_connect(socket_path)
|
||||
try:
|
||||
_raw_send(sock, "ping")
|
||||
try:
|
||||
sock.sendall(b"list_tabs\n")
|
||||
time.sleep(0.3)
|
||||
data = sock.recv(4096)
|
||||
if data:
|
||||
result.failure(f"Got response after rejection: {data.decode().strip()!r}")
|
||||
else:
|
||||
result.success("Connection properly closed")
|
||||
except (BrokenPipeError, ConnectionResetError, OSError):
|
||||
result.success("Connection properly closed")
|
||||
finally:
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
return result
|
||||
|
||||
|
||||
def test_rapid_reconnect(socket_path: str) -> TestResult:
|
||||
result = TestResult("Rapid reconnect all rejected")
|
||||
try:
|
||||
for i in range(20):
|
||||
try:
|
||||
sock = _raw_connect(socket_path, timeout=2.0)
|
||||
response = _raw_send(sock, "ping", timeout=1.0)
|
||||
sock.close()
|
||||
except (BrokenPipeError, ConnectionResetError, OSError):
|
||||
# Server closed connection before we could read — counts as rejection
|
||||
continue
|
||||
if "Access denied" not in response and "ERROR" not in response:
|
||||
result.failure(f"Iteration {i}: not rejected: {response!r}")
|
||||
return result
|
||||
result.success("All 20 rejected")
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
return result
|
||||
|
||||
|
||||
def test_subprocess_rejected(socket_path: str) -> TestResult:
|
||||
result = TestResult("Subprocess of external rejected")
|
||||
try:
|
||||
script = f"""
|
||||
import socket, sys, time
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(3)
|
||||
sock.connect("{socket_path}")
|
||||
sock.sendall(b"ping\\n")
|
||||
data = b""
|
||||
deadline = time.time() + 3
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
chunk = sock.recv(4096)
|
||||
if not chunk: break
|
||||
data += chunk
|
||||
if b"\\n" in data: break
|
||||
except socket.timeout: break
|
||||
sock.close()
|
||||
resp = data.decode().strip()
|
||||
if "Access denied" in resp or "ERROR" in resp:
|
||||
print("REJECTED"); sys.exit(0)
|
||||
else:
|
||||
print("ALLOWED:" + resp); sys.exit(1)
|
||||
"""
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-c", script],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if proc.returncode == 0 and "REJECTED" in proc.stdout:
|
||||
result.success("Child process rejected")
|
||||
else:
|
||||
result.failure(f"exit={proc.returncode} out={proc.stdout!r}")
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal process test (Phase 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_internal_process_allowed(socket_path: str, app_path: str) -> TestResult:
|
||||
"""
|
||||
Verify a cmux-spawned terminal process CAN connect in cmuxOnly mode.
|
||||
Inject a test via the shell rc file, then launch cmux in cmuxOnly mode.
|
||||
The shell (a descendant of cmux) runs the test on startup.
|
||||
"""
|
||||
result = TestResult("Internal process can connect (cmuxOnly)")
|
||||
marker = os.path.join(tempfile.gettempdir(), f"cmux_internal_{os.getpid()}")
|
||||
hook_file = os.path.join(tempfile.gettempdir(), f"cmux_rc_hook_{os.getpid()}.sh")
|
||||
zprofile_path = os.path.expanduser("~/.zprofile")
|
||||
|
||||
try:
|
||||
for f in [marker, hook_file]:
|
||||
if os.path.exists(f):
|
||||
os.unlink(f)
|
||||
|
||||
# Write test script: connects to socket, sends ping, writes result
|
||||
with open(hook_file, "w") as f:
|
||||
f.write(f"""#!/bin/bash
|
||||
# One-shot test hook — self-removes after running
|
||||
RESULT=$(echo "ping" | nc -U "{socket_path}" 2>/dev/null | head -1)
|
||||
if [ "$RESULT" = "PONG" ]; then
|
||||
echo "OK" > "{marker}"
|
||||
else
|
||||
echo "FAIL:$RESULT" > "{marker}"
|
||||
fi
|
||||
""")
|
||||
os.chmod(hook_file, 0o755)
|
||||
|
||||
# Append hook to .zprofile (runs on terminal startup)
|
||||
zprofile_backup = None
|
||||
if os.path.exists(zprofile_path):
|
||||
with open(zprofile_path) as f:
|
||||
zprofile_backup = f.read()
|
||||
|
||||
hook_line = f'\n[ -f "{hook_file}" ] && bash "{hook_file}" && rm -f "{hook_file}"\n'
|
||||
with open(zprofile_path, "a") as f:
|
||||
f.write(hook_line)
|
||||
|
||||
# Kill existing cmux, launch in cmuxOnly mode (default)
|
||||
_kill_cmux()
|
||||
_launch_cmux(app_path, socket_path)
|
||||
|
||||
# Wait for marker (the shell sources .zprofile on startup)
|
||||
for _ in range(40):
|
||||
if os.path.exists(marker):
|
||||
break
|
||||
time.sleep(0.5)
|
||||
|
||||
if not os.path.exists(marker):
|
||||
result.failure("Marker not created — hook didn't run in terminal")
|
||||
return result
|
||||
|
||||
with open(marker) as f:
|
||||
content = f.read().strip()
|
||||
|
||||
if content == "OK":
|
||||
result.success("Internal process pinged socket successfully in cmuxOnly mode")
|
||||
else:
|
||||
result.failure(f"Internal process got: {content!r}")
|
||||
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
finally:
|
||||
# Restore .zprofile
|
||||
if zprofile_backup is not None:
|
||||
with open(zprofile_path, "w") as f:
|
||||
f.write(zprofile_backup)
|
||||
elif os.path.exists(zprofile_path):
|
||||
# Remove the hook line we added
|
||||
with open(zprofile_path) as f:
|
||||
content = f.read()
|
||||
content = content.replace(hook_line, "")
|
||||
if content.strip():
|
||||
with open(zprofile_path, "w") as f:
|
||||
f.write(content)
|
||||
else:
|
||||
os.unlink(zprofile_path)
|
||||
|
||||
for f in [marker, hook_file]:
|
||||
try:
|
||||
os.unlink(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# allowAll mode test (Phase 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_allowall_mode_works(socket_path: str, app_path: str) -> TestResult:
|
||||
"""Verify CMUX_SOCKET_MODE=allowAll bypasses ancestry check."""
|
||||
result = TestResult("allowAll mode allows external")
|
||||
try:
|
||||
_kill_cmux()
|
||||
_launch_cmux(app_path, socket_path, mode="allowAll")
|
||||
|
||||
sock = _raw_connect(socket_path)
|
||||
response = _raw_send(sock, "ping")
|
||||
sock.close()
|
||||
|
||||
if response == "PONG":
|
||||
result.success("External process allowed in allowAll mode")
|
||||
else:
|
||||
result.failure(f"Unexpected response: {response!r}")
|
||||
except Exception as e:
|
||||
result.failure(f"{type(e).__name__}: {e}")
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_tests():
|
||||
print("=" * 60)
|
||||
print("cmux Socket Access Control Tests")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
app_path = _find_app()
|
||||
if not app_path:
|
||||
print("Error: Could not find cmux DEV.app in DerivedData")
|
||||
return 1
|
||||
print(f"App: {app_path}")
|
||||
|
||||
socket_path = _find_socket_path()
|
||||
print(f"Socket: {socket_path}")
|
||||
print()
|
||||
|
||||
results = []
|
||||
|
||||
def run_test(test_fn, *args):
|
||||
name = test_fn.__name__.replace("test_", "").replace("_", " ").title()
|
||||
print(f" Testing {name}...")
|
||||
r = test_fn(*args)
|
||||
results.append(r)
|
||||
status = "\u2705" if r.passed else "\u274c"
|
||||
print(f" {status} {r.message}")
|
||||
|
||||
# ── Phase 1: cmuxOnly — external rejection ──
|
||||
print("Phase 1: cmuxOnly mode — external rejection")
|
||||
print("-" * 50)
|
||||
|
||||
# Ensure cmux is running in cmuxOnly mode
|
||||
_kill_cmux()
|
||||
print(" Launching cmux in cmuxOnly mode...")
|
||||
_launch_cmux(app_path, socket_path)
|
||||
|
||||
run_test(test_external_rejected, socket_path)
|
||||
run_test(test_connection_closed_after_reject, socket_path)
|
||||
run_test(test_rapid_reconnect, socket_path)
|
||||
run_test(test_subprocess_rejected, socket_path)
|
||||
print()
|
||||
|
||||
# ── Phase 2: cmuxOnly — internal process CAN connect ──
|
||||
print("Phase 2: cmuxOnly mode — internal process allowed")
|
||||
print("-" * 50)
|
||||
|
||||
run_test(test_internal_process_allowed, socket_path, app_path)
|
||||
print()
|
||||
|
||||
# ── Phase 3: allowAll env override ──
|
||||
print("Phase 3: allowAll mode — env override bypasses check")
|
||||
print("-" * 50)
|
||||
|
||||
run_test(test_allowall_mode_works, socket_path, app_path)
|
||||
print()
|
||||
|
||||
# ── Cleanup: leave cmux in cmuxOnly mode ──
|
||||
_kill_cmux()
|
||||
_launch_cmux(app_path, socket_path)
|
||||
|
||||
# ── Summary ──
|
||||
print("=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for r in results if r.passed)
|
||||
total = len(results)
|
||||
|
||||
for r in results:
|
||||
status = "\u2705 PASS" if r.passed else "\u274c FAIL"
|
||||
print(f" {r.name}: {status}")
|
||||
if not r.passed and r.message:
|
||||
print(f" {r.message}")
|
||||
|
||||
print()
|
||||
print(f"Passed: {passed}/{total}")
|
||||
|
||||
if passed == total:
|
||||
print("\n\U0001f389 All tests passed!")
|
||||
return 0
|
||||
else:
|
||||
print(f"\n\u26a0\ufe0f {total - passed} test(s) failed")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(run_tests())
|
||||
Loading…
Add table
Add a link
Reference in a new issue