reset history
This commit is contained in:
commit
f4e5a2d03b
6 changed files with 682 additions and 0 deletions
442
server.py
Normal file
442
server.py
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
import copy
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socketserver
|
||||
import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from oauthlib.oauth1 import Client as OAuth1Client
|
||||
from requests_oauthlib import OAuth1Session
|
||||
from fastmcp import FastMCP
|
||||
|
||||
HTTP_METHODS = {
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"options",
|
||||
"head",
|
||||
"trace",
|
||||
}
|
||||
|
||||
LOGGER = logging.getLogger("xmcp.x_api")
|
||||
OAUTH_LOGGER = logging.getLogger("xmcp.oauth1")
|
||||
|
||||
REQUEST_TOKEN_URL = "https://api.x.com/oauth/request_token"
|
||||
AUTHORIZE_URL = "https://api.x.com/oauth/authorize"
|
||||
ACCESS_TOKEN_URL = "https://api.x.com/oauth/access_token"
|
||||
|
||||
|
||||
def is_truthy(value: str | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def parse_csv_env(key: str) -> set[str]:
|
||||
raw = os.getenv(key, "")
|
||||
if not raw.strip():
|
||||
return set()
|
||||
return {item.strip() for item in raw.split(",") if item.strip()}
|
||||
|
||||
|
||||
def should_join_query_param(param: dict) -> bool:
|
||||
if param.get("in") != "query":
|
||||
return False
|
||||
schema = param.get("schema", {})
|
||||
if schema.get("type") != "array":
|
||||
return False
|
||||
return param.get("explode") is False
|
||||
|
||||
|
||||
def collect_comma_params(spec: dict) -> set[str]:
|
||||
comma_params: set[str] = set()
|
||||
components = spec.get("components", {}).get("parameters", {})
|
||||
for param in components.values():
|
||||
if isinstance(param, dict) and should_join_query_param(param):
|
||||
name = param.get("name")
|
||||
if isinstance(name, str):
|
||||
comma_params.add(name)
|
||||
|
||||
for item in spec.get("paths", {}).values():
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
for method, operation in item.items():
|
||||
if method.lower() not in HTTP_METHODS or not isinstance(operation, dict):
|
||||
continue
|
||||
for param in operation.get("parameters", []):
|
||||
if not isinstance(param, dict) or "$ref" in param:
|
||||
continue
|
||||
if should_join_query_param(param):
|
||||
name = param.get("name")
|
||||
if isinstance(name, str):
|
||||
comma_params.add(name)
|
||||
|
||||
return comma_params
|
||||
|
||||
|
||||
def load_openapi_spec() -> dict:
|
||||
url = "https://api.twitter.com/2/openapi.json"
|
||||
LOGGER.info("Fetching OpenAPI spec from %s", url)
|
||||
response = requests.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_env_int(key: str, default: int) -> int:
|
||||
raw = os.getenv(key, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except ValueError:
|
||||
raise RuntimeError(f"{key} must be an integer value.")
|
||||
|
||||
|
||||
def _callback_url(host: str, port: int, path: str) -> str:
|
||||
return f"http://{host}:{port}{path}"
|
||||
|
||||
|
||||
def _wait_for_callback(
|
||||
host: str, port: int, path: str, timeout_seconds: int
|
||||
) -> tuple[str, str]:
|
||||
params: dict[str, str | None] = {"oauth_token": None, "oauth_verifier": None}
|
||||
event = threading.Event()
|
||||
|
||||
class _Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802 - required by BaseHTTPRequestHandler
|
||||
parsed = urllib.parse.urlparse(self.path)
|
||||
if parsed.path != path:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not found.")
|
||||
return
|
||||
query = urllib.parse.parse_qs(parsed.query)
|
||||
params["oauth_token"] = (query.get("oauth_token") or [None])[0]
|
||||
params["oauth_verifier"] = (query.get("oauth_verifier") or [None])[0]
|
||||
event.set()
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"OAuth complete. You may close this tab.")
|
||||
|
||||
def log_message(self, format: str, *args: object) -> None: # noqa: A003
|
||||
OAUTH_LOGGER.debug("OAuth1 callback: " + format, *args)
|
||||
|
||||
class _Server(socketserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
|
||||
server = _Server((host, port), _Handler)
|
||||
server.timeout = 1
|
||||
|
||||
deadline = time.time() + timeout_seconds
|
||||
try:
|
||||
while time.time() < deadline:
|
||||
server.handle_request()
|
||||
if event.is_set():
|
||||
break
|
||||
finally:
|
||||
server.server_close()
|
||||
|
||||
oauth_token = params.get("oauth_token")
|
||||
oauth_verifier = params.get("oauth_verifier")
|
||||
if not oauth_token or not oauth_verifier:
|
||||
raise TimeoutError("OAuth callback not received before timeout.")
|
||||
return oauth_token, oauth_verifier
|
||||
|
||||
|
||||
def run_oauth1_flow() -> tuple[str, str]:
|
||||
consumer_key = os.getenv("TWITTER_CONSUMER_KEY")
|
||||
consumer_secret = os.getenv("TWITTER_CONSUMER_SECRET")
|
||||
if not consumer_key or not consumer_secret:
|
||||
raise RuntimeError(
|
||||
"Missing TWITTER_CONSUMER_KEY or TWITTER_CONSUMER_SECRET for OAuth1 flow."
|
||||
)
|
||||
|
||||
callback_host = os.getenv("X_OAUTH_CALLBACK_HOST", "127.0.0.1")
|
||||
callback_port = _get_env_int("X_OAUTH_CALLBACK_PORT", 8976)
|
||||
callback_path = os.getenv("X_OAUTH_CALLBACK_PATH", "/oauth/callback")
|
||||
callback_timeout = _get_env_int("X_OAUTH_CALLBACK_TIMEOUT", 300)
|
||||
|
||||
callback_url = _callback_url(callback_host, callback_port, callback_path)
|
||||
|
||||
oauth = OAuth1Session(
|
||||
client_key=consumer_key,
|
||||
client_secret=consumer_secret,
|
||||
callback_uri=callback_url,
|
||||
)
|
||||
request_token = oauth.fetch_request_token(REQUEST_TOKEN_URL)
|
||||
resource_owner_key = request_token.get("oauth_token")
|
||||
resource_owner_secret = request_token.get("oauth_token_secret")
|
||||
if not resource_owner_key or not resource_owner_secret:
|
||||
raise RuntimeError("Failed to obtain OAuth request token.")
|
||||
|
||||
authorization_url = oauth.authorization_url(AUTHORIZE_URL)
|
||||
OAUTH_LOGGER.info("Opening browser for OAuth1 consent.")
|
||||
webbrowser.open(authorization_url)
|
||||
|
||||
oauth_token, oauth_verifier = _wait_for_callback(
|
||||
callback_host, callback_port, callback_path, callback_timeout
|
||||
)
|
||||
if oauth_token != resource_owner_key:
|
||||
raise RuntimeError("OAuth callback token does not match request token.")
|
||||
|
||||
oauth = OAuth1Session(
|
||||
client_key=consumer_key,
|
||||
client_secret=consumer_secret,
|
||||
resource_owner_key=resource_owner_key,
|
||||
resource_owner_secret=resource_owner_secret,
|
||||
verifier=oauth_verifier,
|
||||
)
|
||||
access_token = oauth.fetch_access_token(ACCESS_TOKEN_URL)
|
||||
access_key = access_token.get("oauth_token")
|
||||
access_secret = access_token.get("oauth_token_secret")
|
||||
if not access_key or not access_secret:
|
||||
raise RuntimeError("Failed to obtain OAuth access token.")
|
||||
return access_key, access_secret
|
||||
|
||||
|
||||
def load_env() -> None:
|
||||
env_path = Path(__file__).resolve().parent / ".env"
|
||||
if not env_path.exists():
|
||||
return
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
return
|
||||
load_dotenv(env_path, override=True)
|
||||
|
||||
|
||||
def setup_logging() -> bool:
|
||||
debug_enabled = is_truthy(os.getenv("X_API_DEBUG", "1"))
|
||||
if debug_enabled:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
LOGGER.setLevel(logging.INFO)
|
||||
return debug_enabled
|
||||
|
||||
|
||||
def should_exclude_operation(path: str, operation: dict) -> bool:
|
||||
if "/webhooks" in path or "/stream" in path:
|
||||
return True
|
||||
|
||||
tags = [tag.lower() for tag in operation.get("tags", []) if isinstance(tag, str)]
|
||||
if "stream" in tags or "webhooks" in tags:
|
||||
return True
|
||||
|
||||
if operation.get("x-twitter-streaming") is True:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def filter_openapi_spec(spec: dict) -> dict:
|
||||
filtered = copy.deepcopy(spec)
|
||||
paths = filtered.get("paths", {})
|
||||
new_paths = {}
|
||||
allow_tags = {tag.lower() for tag in parse_csv_env("X_API_TOOL_TAGS")}
|
||||
allow_ops = parse_csv_env("X_API_TOOL_ALLOWLIST")
|
||||
deny_ops = parse_csv_env("X_API_TOOL_DENYLIST")
|
||||
|
||||
for path, item in paths.items():
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
new_item = {}
|
||||
for key, value in item.items():
|
||||
if key.lower() in HTTP_METHODS:
|
||||
if should_exclude_operation(path, value):
|
||||
continue
|
||||
operation_id = value.get("operationId")
|
||||
operation_tags = [
|
||||
tag.lower()
|
||||
for tag in value.get("tags", [])
|
||||
if isinstance(tag, str)
|
||||
]
|
||||
if allow_tags and not (set(operation_tags) & allow_tags):
|
||||
continue
|
||||
if allow_ops and operation_id not in allow_ops:
|
||||
continue
|
||||
if deny_ops and operation_id in deny_ops:
|
||||
continue
|
||||
new_item[key] = value
|
||||
else:
|
||||
new_item[key] = value
|
||||
|
||||
if any(method.lower() in HTTP_METHODS for method in new_item.keys()):
|
||||
new_paths[path] = new_item
|
||||
|
||||
filtered["paths"] = new_paths
|
||||
return filtered
|
||||
|
||||
|
||||
def print_tool_list(spec: dict) -> None:
|
||||
tools: list[str] = []
|
||||
for path, item in spec.get("paths", {}).items():
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
for method, operation in item.items():
|
||||
if method.lower() not in HTTP_METHODS or not isinstance(operation, dict):
|
||||
continue
|
||||
op_id = operation.get("operationId")
|
||||
if op_id:
|
||||
tools.append(op_id)
|
||||
else:
|
||||
tools.append(f"{method.upper()} {path}")
|
||||
|
||||
tools.sort()
|
||||
print(f"Loaded {len(tools)} tools from OpenAPI:")
|
||||
for tool in tools:
|
||||
print(f"- {tool}")
|
||||
|
||||
|
||||
def get_auth_headers(oauth_token: str | None = None) -> dict:
|
||||
env_oauth_token = os.getenv("X_OAUTH_ACCESS_TOKEN", "").strip()
|
||||
bearer_token = os.getenv("X_BEARER_TOKEN", "").strip()
|
||||
token = oauth_token or env_oauth_token or bearer_token
|
||||
if not token:
|
||||
raise RuntimeError(
|
||||
"Set X_BEARER_TOKEN or provide OAuth1 access token on startup."
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def build_oauth1_client() -> OAuth1Client:
|
||||
consumer_key = os.getenv("TWITTER_CONSUMER_KEY")
|
||||
consumer_secret = os.getenv("TWITTER_CONSUMER_SECRET")
|
||||
if not consumer_key or not consumer_secret:
|
||||
raise RuntimeError(
|
||||
"Missing TWITTER_CONSUMER_KEY or TWITTER_CONSUMER_SECRET for OAuth1 signing."
|
||||
)
|
||||
access_token, access_secret = run_oauth1_flow()
|
||||
LOGGER.info("OAuth1 access token: %s", access_token)
|
||||
return OAuth1Client(
|
||||
client_key=consumer_key,
|
||||
client_secret=consumer_secret,
|
||||
resource_owner_key=access_token,
|
||||
resource_owner_secret=access_secret,
|
||||
signature_type="AUTH_HEADER",
|
||||
)
|
||||
|
||||
|
||||
def create_mcp() -> FastMCP:
|
||||
load_env()
|
||||
debug_enabled = setup_logging()
|
||||
parser_flag = os.getenv("FASTMCP_EXPERIMENTAL_ENABLE_NEW_OPENAPI_PARSER")
|
||||
if parser_flag is not None:
|
||||
os.environ["FASTMCP_EXPERIMENTAL_ENABLE_NEW_OPENAPI_PARSER"] = parser_flag
|
||||
|
||||
base_url = os.getenv("X_API_BASE_URL", "https://api.x.com")
|
||||
timeout = float(os.getenv("X_API_TIMEOUT", "30"))
|
||||
|
||||
oauth1_client = build_oauth1_client()
|
||||
|
||||
spec = load_openapi_spec()
|
||||
filtered_spec = filter_openapi_spec(spec)
|
||||
comma_params = collect_comma_params(filtered_spec)
|
||||
print_tool_list(filtered_spec)
|
||||
async def normalize_query_params(request: httpx.Request) -> None:
|
||||
if not comma_params:
|
||||
return
|
||||
params = list(request.url.params.multi_items())
|
||||
grouped: dict[str, list[str]] = {}
|
||||
ordered: list[str] = []
|
||||
normalized: list[tuple[str, str]] = []
|
||||
|
||||
for key, value in params:
|
||||
if key in comma_params:
|
||||
if key not in grouped:
|
||||
ordered.append(key)
|
||||
grouped.setdefault(key, []).append(value)
|
||||
else:
|
||||
normalized.append((key, value))
|
||||
|
||||
if not grouped:
|
||||
return
|
||||
|
||||
for key in ordered:
|
||||
values: list[str] = []
|
||||
for raw in grouped[key]:
|
||||
for part in raw.split(","):
|
||||
part = part.strip()
|
||||
if part and part not in values:
|
||||
values.append(part)
|
||||
if values:
|
||||
normalized.append((key, ",".join(values)))
|
||||
|
||||
request.url = request.url.copy_with(params=normalized)
|
||||
|
||||
b3_flags = os.getenv("X_B3_FLAGS", "1")
|
||||
|
||||
async def sign_oauth1_request(request: httpx.Request) -> None:
|
||||
request.headers["X-B3-Flags"] = b3_flags
|
||||
headers = dict(request.headers)
|
||||
content_type = headers.get("Content-Type", "")
|
||||
body: str | None = None
|
||||
if content_type.startswith("application/x-www-form-urlencoded"):
|
||||
body_bytes = request.content or b""
|
||||
body = body_bytes.decode("utf-8")
|
||||
signed_url, signed_headers, _ = oauth1_client.sign(
|
||||
str(request.url),
|
||||
http_method=request.method,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
request.url = httpx.URL(signed_url)
|
||||
request.headers.update(signed_headers)
|
||||
|
||||
async def log_request(request: httpx.Request) -> None:
|
||||
if not debug_enabled:
|
||||
return
|
||||
LOGGER.info("X API request %s %s", request.method, request.url)
|
||||
|
||||
async def log_response(response: httpx.Response) -> None:
|
||||
if not debug_enabled:
|
||||
return
|
||||
LOGGER.info(
|
||||
"X API response %s %s -> %s",
|
||||
response.request.method,
|
||||
response.request.url,
|
||||
response.status_code,
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
transaction_id = response.headers.get("x-transaction-id")
|
||||
if transaction_id:
|
||||
LOGGER.warning("X API x-transaction-id: %s", transaction_id)
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
if len(text) > 1000:
|
||||
text = text[:1000] + "...<truncated>"
|
||||
LOGGER.warning("X API error body: %s", text)
|
||||
|
||||
client = httpx.AsyncClient(
|
||||
base_url=base_url,
|
||||
headers={},
|
||||
timeout=timeout,
|
||||
event_hooks={
|
||||
"request": [normalize_query_params, sign_oauth1_request, log_request],
|
||||
"response": [log_response],
|
||||
},
|
||||
)
|
||||
return FastMCP.from_openapi(
|
||||
openapi_spec=filtered_spec,
|
||||
client=client,
|
||||
name="X API MCP",
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
host = os.getenv("MCP_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("MCP_PORT", "8000"))
|
||||
mcp = create_mcp()
|
||||
mcp.run(transport="http", host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue