fix: split security hardening and setup mcporter config checks
This commit is contained in:
parent
a5682716ec
commit
4b2e6f2ffb
7 changed files with 348 additions and 114 deletions
|
|
@ -29,7 +29,7 @@ class DouyinChannel(Channel):
|
|||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "douyin" not in r.stdout:
|
||||
return "off", (
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class ExaSearchChannel(Channel):
|
|||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "exa" in r.stdout.lower():
|
||||
return "ok", "全网语义搜索可用(免费,无需 API Key)"
|
||||
|
|
|
|||
|
|
@ -20,10 +20,12 @@ class GitHubChannel(Channel):
|
|||
if not shutil.which("gh"):
|
||||
return "warn", "gh CLI 未安装。安装:https://cli.github.com"
|
||||
try:
|
||||
subprocess.run(
|
||||
r = subprocess.run(
|
||||
["gh", "auth", "status"],
|
||||
capture_output=True, text=True, timeout=5
|
||||
)
|
||||
return "ok", "完整可用(读取、搜索、Fork、Issue、PR 等)"
|
||||
if r.returncode == 0:
|
||||
return "ok", "完整可用(读取、搜索、Fork、Issue、PR 等)"
|
||||
return "warn", "gh CLI 已安装但未认证。运行 gh auth login 可解锁完整功能"
|
||||
except Exception:
|
||||
return "ok", "gh CLI 已装但未认证。运行 gh auth login 可解锁完整功能"
|
||||
return "warn", "gh CLI 状态检查失败,运行 gh auth status 查看详情"
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class LinkedInChannel(Channel):
|
|||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "linkedin" in r.stdout.lower():
|
||||
return "ok", "完整可用(Profile、公司、职位搜索)"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class XiaoHongShuChannel(Channel):
|
|||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "xiaohongshu" not in r.stdout:
|
||||
return "off", (
|
||||
|
|
|
|||
|
|
@ -13,18 +13,31 @@ import sys
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
# Fix Windows console encoding — emoji/CJK characters crash on cp936/cp1252
|
||||
if sys.platform == 'win32':
|
||||
import io
|
||||
if hasattr(sys.stdout, 'buffer'):
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
|
||||
if hasattr(sys.stderr, 'buffer'):
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
|
||||
import time
|
||||
|
||||
from agent_reach import __version__
|
||||
|
||||
|
||||
def _ensure_utf8_console():
|
||||
"""Best-effort Windows console UTF-8 setup for CLI runtime only."""
|
||||
if sys.platform != "win32":
|
||||
return
|
||||
# Avoid interfering with pytest/captured streams.
|
||||
if os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
return
|
||||
if not getattr(sys.stdout, "isatty", lambda: False)():
|
||||
return
|
||||
try:
|
||||
import io
|
||||
if hasattr(sys.stdout, "buffer"):
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
||||
if hasattr(sys.stderr, "buffer"):
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
# Do not crash CLI just because encoding patch failed.
|
||||
pass
|
||||
|
||||
|
||||
def _configure_logging(verbose: bool = False):
|
||||
"""Suppress loguru output unless --verbose is set."""
|
||||
from loguru import logger
|
||||
|
|
@ -34,6 +47,8 @@ def _configure_logging(verbose: bool = False):
|
|||
|
||||
|
||||
def main():
|
||||
_ensure_utf8_console()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="agent-reach",
|
||||
description="👁️ Give your AI Agent eyes to see the entire internet",
|
||||
|
|
@ -279,6 +294,7 @@ def _install_system_deps():
|
|||
import shutil
|
||||
import subprocess
|
||||
import platform
|
||||
import tempfile
|
||||
|
||||
print("🔧 Checking system dependencies...")
|
||||
|
||||
|
|
@ -290,15 +306,25 @@ def _install_system_deps():
|
|||
os_type = platform.system().lower()
|
||||
if os_type == "linux":
|
||||
try:
|
||||
# Official GitHub method for Linux
|
||||
cmds = [
|
||||
"curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg 2>/dev/null",
|
||||
'echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null',
|
||||
"apt-get update -qq 2>/dev/null",
|
||||
"apt-get install -y -qq gh 2>/dev/null",
|
||||
]
|
||||
for cmd in cmds:
|
||||
subprocess.run(cmd, shell=True, capture_output=True, timeout=60)
|
||||
# Official GitHub apt source setup without invoking a shell.
|
||||
keyring_path = "/usr/share/keyrings/githubcli-archive-keyring.gpg"
|
||||
list_path = "/etc/apt/sources.list.d/github-cli.list"
|
||||
arch = subprocess.run(
|
||||
["dpkg", "--print-architecture"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
).stdout.strip() or "amd64"
|
||||
subprocess.run(
|
||||
["curl", "-fsSL", "https://cli.github.com/packages/githubcli-archive-keyring.gpg", "-o", keyring_path],
|
||||
capture_output=True, timeout=60,
|
||||
)
|
||||
repo_line = (
|
||||
f"deb [arch={arch} signed-by={keyring_path}] "
|
||||
"https://cli.github.com/packages stable main\n"
|
||||
)
|
||||
with open(list_path, "w", encoding="utf-8") as f:
|
||||
f.write(repo_line)
|
||||
subprocess.run(["apt-get", "update", "-qq"], capture_output=True, timeout=60)
|
||||
subprocess.run(["apt-get", "install", "-y", "-qq", "gh"], capture_output=True, timeout=60)
|
||||
if shutil.which("gh"):
|
||||
print(" ✅ gh CLI installed")
|
||||
else:
|
||||
|
|
@ -326,10 +352,24 @@ def _install_system_deps():
|
|||
else:
|
||||
print(" 📥 Installing Node.js...")
|
||||
try:
|
||||
# Use NodeSource for quick install
|
||||
# Use NodeSource setup script without invoking a shell pipeline.
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".sh") as tf:
|
||||
script_path = tf.name
|
||||
subprocess.run(
|
||||
"curl -fsSL https://deb.nodesource.com/setup_22.x | bash - 2>/dev/null && apt-get install -y -qq nodejs 2>/dev/null",
|
||||
shell=True, capture_output=True, timeout=120,
|
||||
["curl", "-fsSL", "https://deb.nodesource.com/setup_22.x", "-o", script_path],
|
||||
capture_output=True, timeout=60,
|
||||
)
|
||||
subprocess.run(
|
||||
["bash", script_path],
|
||||
capture_output=True, timeout=120,
|
||||
)
|
||||
try:
|
||||
os.unlink(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
subprocess.run(
|
||||
["apt-get", "install", "-y", "-qq", "nodejs"],
|
||||
capture_output=True, timeout=120,
|
||||
)
|
||||
if shutil.which("node"):
|
||||
print(" ✅ Node.js installed")
|
||||
|
|
@ -453,7 +493,7 @@ def _install_mcporter():
|
|||
# Configure Exa MCP (free, no key needed)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "exa" not in r.stdout:
|
||||
subprocess.run(
|
||||
|
|
@ -469,7 +509,7 @@ def _install_mcporter():
|
|||
# Check XiaoHongShu MCP (only if server is running)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "list"], capture_output=True, text=True, timeout=10
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if "xiaohongshu" in r.stdout:
|
||||
print(" ✅ XiaoHongShu MCP already configured")
|
||||
|
|
@ -710,29 +750,41 @@ def _cmd_setup():
|
|||
print("=" * 40)
|
||||
print()
|
||||
|
||||
# Step 1: Exa
|
||||
print("【推荐】全网搜索 — Exa Search API")
|
||||
print(" 免费 1000 次/月,注册地址: https://exa.ai")
|
||||
current = config.get("exa_api_key")
|
||||
if current:
|
||||
print(f" 当前状态: ✅ 已配置 ({current[:8]}...)")
|
||||
change = input(" 要更换吗?[y/N]: ").strip().lower()
|
||||
if change != "y":
|
||||
print()
|
||||
else:
|
||||
key = input(" EXA_API_KEY: ").strip()
|
||||
if key:
|
||||
config.set("exa_api_key", key)
|
||||
print(" ✅ 已更新!")
|
||||
print()
|
||||
# Step 1: Exa (via mcporter, no API key required)
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
print("【推荐】全网搜索 — Exa(通过 mcporter)")
|
||||
print(" 免费,无需 API Key")
|
||||
|
||||
if not shutil.which("mcporter"):
|
||||
print(" 当前状态: ⬜ mcporter 未安装")
|
||||
print(" 安装:npm install -g mcporter")
|
||||
print(" 然后:mcporter config add exa https://mcp.exa.ai/mcp")
|
||||
print()
|
||||
else:
|
||||
print(" 当前状态: ⬜ 未配置")
|
||||
key = input(" EXA_API_KEY (回车跳过): ").strip()
|
||||
if key:
|
||||
config.set("exa_api_key", key)
|
||||
print(" ✅ 全网搜索 + Reddit搜索 + Twitter搜索 已开启!")
|
||||
else:
|
||||
print(" ℹ️ 跳过。稍后可运行 agent-reach setup 配置")
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["mcporter", "config", "list"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if "exa" in r.stdout.lower():
|
||||
print(" 当前状态: ✅ 已配置")
|
||||
else:
|
||||
print(" 当前状态: ⬜ 未配置")
|
||||
setup_now = input(" 现在自动配置 Exa 吗?[Y/n]: ").strip().lower()
|
||||
if setup_now in ("", "y", "yes"):
|
||||
add_r = subprocess.run(
|
||||
["mcporter", "config", "add", "exa", "https://mcp.exa.ai/mcp"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
if add_r.returncode == 0:
|
||||
print(" ✅ Exa 已配置")
|
||||
else:
|
||||
print(" ⚠️ 自动配置失败,请手动执行:")
|
||||
print(" mcporter config add exa https://mcp.exa.ai/mcp")
|
||||
except Exception:
|
||||
print(" ⚠️ 无法检查 Exa 配置,请手动执行:")
|
||||
print(" mcporter config add exa https://mcp.exa.ai/mcp")
|
||||
print()
|
||||
|
||||
# Step 2: GitHub token
|
||||
|
|
@ -789,62 +841,160 @@ def _cmd_setup():
|
|||
print()
|
||||
|
||||
|
||||
def _classify_update_error(exc):
|
||||
"""Classify update-check errors for user-friendly diagnostics."""
|
||||
import requests
|
||||
|
||||
if isinstance(exc, requests.exceptions.Timeout):
|
||||
return "timeout"
|
||||
if isinstance(exc, requests.exceptions.ConnectionError):
|
||||
msg = str(exc).lower()
|
||||
dns_markers = [
|
||||
"name or service not known",
|
||||
"temporary failure in name resolution",
|
||||
"nodename nor servname",
|
||||
"getaddrinfo failed",
|
||||
"name resolution",
|
||||
"dns",
|
||||
]
|
||||
if any(marker in msg for marker in dns_markers):
|
||||
return "dns"
|
||||
return "connection"
|
||||
if isinstance(exc, requests.exceptions.HTTPError):
|
||||
return "http"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _update_error_text(kind):
|
||||
"""Map internal error kinds to user-facing text."""
|
||||
mapping = {
|
||||
"timeout": "网络超时",
|
||||
"dns": "DNS 解析失败",
|
||||
"rate_limit": "GitHub API 速率限制",
|
||||
"connection": "网络连接失败",
|
||||
"server_error": "GitHub 服务暂时不可用",
|
||||
"http": "HTTP 请求失败",
|
||||
"unknown": "未知网络错误",
|
||||
}
|
||||
return mapping.get(kind, "请求失败")
|
||||
|
||||
|
||||
def _classify_github_response_error(resp):
|
||||
"""Classify non-200 GitHub responses that merit special handling."""
|
||||
if resp is None:
|
||||
return "unknown"
|
||||
if resp.status_code == 429:
|
||||
return "rate_limit"
|
||||
if resp.status_code == 403:
|
||||
remaining = resp.headers.get("X-RateLimit-Remaining", "")
|
||||
if remaining == "0":
|
||||
return "rate_limit"
|
||||
try:
|
||||
message = resp.json().get("message", "").lower()
|
||||
if "rate limit" in message:
|
||||
return "rate_limit"
|
||||
except Exception:
|
||||
pass
|
||||
if 500 <= resp.status_code < 600:
|
||||
return "server_error"
|
||||
return None
|
||||
|
||||
|
||||
def _github_get_with_retry(url, timeout=10, retries=3, sleeper=time.sleep):
|
||||
"""GET GitHub API with retry/backoff and basic error classification."""
|
||||
import requests
|
||||
|
||||
for attempt in range(1, retries + 1):
|
||||
try:
|
||||
resp = requests.get(url, timeout=timeout)
|
||||
except requests.exceptions.RequestException as exc:
|
||||
if attempt >= retries:
|
||||
return None, _classify_update_error(exc), attempt
|
||||
sleeper(2 ** (attempt - 1))
|
||||
continue
|
||||
|
||||
err_kind = _classify_github_response_error(resp)
|
||||
if err_kind in ("rate_limit", "server_error"):
|
||||
if attempt >= retries:
|
||||
return None, err_kind, attempt
|
||||
delay = 2 ** (attempt - 1)
|
||||
retry_after = resp.headers.get("Retry-After")
|
||||
if err_kind == "rate_limit" and retry_after:
|
||||
try:
|
||||
delay = max(delay, float(retry_after))
|
||||
except Exception:
|
||||
pass
|
||||
sleeper(delay)
|
||||
continue
|
||||
|
||||
return resp, None, attempt
|
||||
|
||||
return None, "unknown", retries
|
||||
|
||||
|
||||
def _cmd_check_update():
|
||||
"""Check for newer versions on GitHub."""
|
||||
import requests
|
||||
from agent_reach import __version__
|
||||
|
||||
print(f"📦 当前版本: v{__version__}")
|
||||
release_url = "https://api.github.com/repos/Panniantong/Agent-Reach/releases/latest"
|
||||
commit_url = "https://api.github.com/repos/Panniantong/Agent-Reach/commits/main"
|
||||
|
||||
try:
|
||||
# Fetch latest version from GitHub
|
||||
resp = requests.get(
|
||||
"https://api.github.com/repos/Panniantong/Agent-Reach/releases/latest",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
latest = data.get("tag_name", "").lstrip("v")
|
||||
body = data.get("body", "")
|
||||
|
||||
if latest and latest != __version__:
|
||||
print(f"🆕 最新版本: v{latest} ← 有更新!")
|
||||
if body:
|
||||
print()
|
||||
print("更新内容:")
|
||||
# Show first 20 lines of release notes
|
||||
for line in body.strip().split("\n")[:20]:
|
||||
print(f" {line}")
|
||||
print()
|
||||
print("更新命令:")
|
||||
print(" pip install --upgrade https://github.com/Panniantong/agent-reach/archive/main.zip")
|
||||
return "update_available"
|
||||
else:
|
||||
print(f"✅ 已是最新版本")
|
||||
return "up_to_date"
|
||||
else:
|
||||
# No releases yet, fall back to comparing commit
|
||||
resp2 = requests.get(
|
||||
"https://api.github.com/repos/Panniantong/Agent-Reach/commits/main",
|
||||
timeout=10,
|
||||
)
|
||||
if resp2.status_code == 200:
|
||||
commit = resp2.json()
|
||||
sha = commit.get("sha", "")[:7]
|
||||
msg = commit.get("commit", {}).get("message", "").split("\n")[0]
|
||||
date = commit.get("commit", {}).get("committer", {}).get("date", "")[:10]
|
||||
print(f"🔍 最新提交: {sha} ({date}) {msg}")
|
||||
print()
|
||||
print("更新命令:")
|
||||
print(" pip install --upgrade https://github.com/Panniantong/agent-reach/archive/main.zip")
|
||||
return "unknown"
|
||||
else:
|
||||
print("⚠️ 无法检查更新(网络问题)")
|
||||
return "error"
|
||||
except Exception as e:
|
||||
print(f"⚠️ 无法检查更新: {e}")
|
||||
# Fetch latest release with retry/backoff.
|
||||
resp, err, attempts = _github_get_with_retry(release_url, timeout=10, retries=3)
|
||||
if err:
|
||||
print(f"⚠️ 无法检查更新({_update_error_text(err)},已重试 {attempts} 次)")
|
||||
return "error"
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
latest = data.get("tag_name", "").lstrip("v")
|
||||
body = data.get("body", "")
|
||||
|
||||
if latest and latest != __version__:
|
||||
print(f"🆕 最新版本: v{latest} ← 有更新!")
|
||||
if body:
|
||||
print()
|
||||
print("更新内容:")
|
||||
# Show first 20 lines of release notes
|
||||
for line in body.strip().split("\n")[:20]:
|
||||
print(f" {line}")
|
||||
print()
|
||||
print("更新命令:")
|
||||
print(" pip install --upgrade https://github.com/Panniantong/agent-reach/archive/main.zip")
|
||||
return "update_available"
|
||||
print(f"✅ 已是最新版本")
|
||||
return "up_to_date"
|
||||
|
||||
release_err = _classify_github_response_error(resp)
|
||||
if release_err == "rate_limit":
|
||||
print("⚠️ 无法检查更新(GitHub API 速率限制,请稍后重试)")
|
||||
return "error"
|
||||
|
||||
# No releases yet, fall back to latest main commit.
|
||||
resp2, err2, attempts2 = _github_get_with_retry(commit_url, timeout=10, retries=2)
|
||||
if err2:
|
||||
print(f"⚠️ 无法检查更新({_update_error_text(err2)},已重试 {attempts + attempts2} 次)")
|
||||
return "error"
|
||||
if resp2.status_code == 200:
|
||||
commit = resp2.json()
|
||||
sha = commit.get("sha", "")[:7]
|
||||
msg = commit.get("commit", {}).get("message", "").split("\n")[0]
|
||||
date = commit.get("commit", {}).get("committer", {}).get("date", "")[:10]
|
||||
print(f"🔍 最新提交: {sha} ({date}) {msg}")
|
||||
print()
|
||||
print("更新命令:")
|
||||
print(" pip install --upgrade https://github.com/Panniantong/agent-reach/archive/main.zip")
|
||||
return "unknown"
|
||||
|
||||
commit_err = _classify_github_response_error(resp2)
|
||||
if commit_err == "rate_limit":
|
||||
print("⚠️ 无法检查更新(GitHub API 速率限制,请稍后重试)")
|
||||
return "error"
|
||||
|
||||
print(f"⚠️ 无法检查更新(GitHub 返回 {resp2.status_code})")
|
||||
return "error"
|
||||
|
||||
|
||||
def _cmd_watch():
|
||||
"""Quick health check + update check, designed for scheduled tasks.
|
||||
|
|
@ -853,7 +1003,6 @@ def _cmd_watch():
|
|||
"""
|
||||
from agent_reach.config import Config
|
||||
from agent_reach.doctor import check_all
|
||||
import requests
|
||||
from agent_reach import __version__
|
||||
|
||||
config = Config()
|
||||
|
|
@ -875,20 +1024,18 @@ def _cmd_watch():
|
|||
update_available = False
|
||||
new_version = ""
|
||||
release_body = ""
|
||||
try:
|
||||
resp = requests.get(
|
||||
"https://api.github.com/repos/Panniantong/Agent-Reach/releases/latest",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
latest = data.get("tag_name", "").lstrip("v")
|
||||
if latest and latest != __version__:
|
||||
update_available = True
|
||||
new_version = latest
|
||||
release_body = data.get("body", "")
|
||||
except Exception:
|
||||
pass
|
||||
resp, err, _attempts = _github_get_with_retry(
|
||||
"https://api.github.com/repos/Panniantong/Agent-Reach/releases/latest",
|
||||
timeout=10,
|
||||
retries=2,
|
||||
)
|
||||
if not err and resp and resp.status_code == 200:
|
||||
data = resp.json()
|
||||
latest = data.get("tag_name", "").lstrip("v")
|
||||
if latest and latest != __version__:
|
||||
update_available = True
|
||||
new_version = latest
|
||||
release_body = data.get("body", "")
|
||||
|
||||
# Output
|
||||
if not issues and not update_available:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@
|
|||
"""Tests for Agent Reach CLI."""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from unittest.mock import patch
|
||||
import agent_reach.cli as cli
|
||||
from agent_reach.cli import main
|
||||
|
||||
|
||||
|
|
@ -27,3 +29,86 @@ class TestCLI:
|
|||
captured = capsys.readouterr()
|
||||
assert "Agent Reach" in captured.out
|
||||
assert "✅" in captured.out
|
||||
|
||||
|
||||
class TestCheckUpdateRetry:
|
||||
def test_retry_timeout_classification(self):
|
||||
sleeps = []
|
||||
|
||||
def fake_sleep(seconds):
|
||||
sleeps.append(seconds)
|
||||
|
||||
with patch("requests.get", side_effect=requests.exceptions.Timeout("timed out")):
|
||||
resp, err, attempts = cli._github_get_with_retry(
|
||||
"https://api.github.com/test",
|
||||
timeout=1,
|
||||
retries=3,
|
||||
sleeper=fake_sleep,
|
||||
)
|
||||
|
||||
assert resp is None
|
||||
assert err == "timeout"
|
||||
assert attempts == 3
|
||||
assert sleeps == [1, 2]
|
||||
|
||||
def test_retry_dns_classification(self):
|
||||
error = requests.exceptions.ConnectionError("getaddrinfo failed for api.github.com")
|
||||
with patch("requests.get", side_effect=error):
|
||||
resp, err, attempts = cli._github_get_with_retry(
|
||||
"https://api.github.com/test",
|
||||
retries=1,
|
||||
sleeper=lambda _x: None,
|
||||
)
|
||||
assert resp is None
|
||||
assert err == "dns"
|
||||
assert attempts == 1
|
||||
|
||||
def test_retry_rate_limit_then_success(self):
|
||||
sleeps = []
|
||||
|
||||
class R:
|
||||
def __init__(self, code, payload=None, headers=None):
|
||||
self.status_code = code
|
||||
self._payload = payload or {}
|
||||
self.headers = headers or {}
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
sequence = [
|
||||
R(429, headers={"Retry-After": "3"}),
|
||||
R(200, payload={"tag_name": "v1.2.0"}),
|
||||
]
|
||||
|
||||
with patch("requests.get", side_effect=sequence):
|
||||
resp, err, attempts = cli._github_get_with_retry(
|
||||
"https://api.github.com/test",
|
||||
retries=3,
|
||||
sleeper=lambda s: sleeps.append(s),
|
||||
)
|
||||
|
||||
assert err is None
|
||||
assert resp is not None
|
||||
assert resp.status_code == 200
|
||||
assert attempts == 2
|
||||
assert sleeps == [3.0]
|
||||
|
||||
def test_classify_rate_limit_from_403(self):
|
||||
class R:
|
||||
status_code = 403
|
||||
headers = {"X-RateLimit-Remaining": "0"}
|
||||
|
||||
@staticmethod
|
||||
def json():
|
||||
return {"message": "API rate limit exceeded"}
|
||||
|
||||
assert cli._classify_github_response_error(R()) == "rate_limit"
|
||||
|
||||
def test_check_update_reports_classified_error(self, capsys):
|
||||
with patch("agent_reach.cli._github_get_with_retry", return_value=(None, "timeout", 3)):
|
||||
result = cli._cmd_check_update()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert result == "error"
|
||||
assert "网络超时" in captured.out
|
||||
assert "已重试 3 次" in captured.out
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue