diff --git a/common.py b/common.py index 2129575..20ee7fa 100644 --- a/common.py +++ b/common.py @@ -1,20 +1,22 @@ -"""Shared OAuth + user resolution for all MCP servers. +"""Shared OAuth for all MCP servers. -OAuth client_credentials flow: -1. claude.ai discovers /.well-known/oauth-authorization-server -2. claude.ai POSTs to /token with client_id + client_secret -3. Server returns access_token -4. claude.ai sends Bearer access_token with MCP requests -5. Server resolves user from token +Supports OAuth 2.0 Authorization Code flow with PKCE (for claude.ai): +1. GET /authorize → auto-approves if client_id+client_secret valid, redirects with code +2. POST /token grant_type=authorization_code → exchanges code for access_token +3. Bearer access_token on MCP requests """ import json import os import secrets +import hashlib +import base64 import contextvars import time +from urllib.parse import urlencode, parse_qs, urlparse + from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse from starlette.routing import Route from starlette.middleware.base import BaseHTTPMiddleware @@ -35,6 +37,7 @@ def load_config(): _tokens_cache = None _current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None) _access_tokens: dict[str, dict] = {} +_auth_codes: dict[str, dict] = {} def _load_tokens(): @@ -71,39 +74,100 @@ def _resolve_access_token(token): return info["user"] +def _verify_pkce(code_verifier, code_challenge, method): + if method == "S256": + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return computed == code_challenge + return code_verifier == code_challenge + + +# --- OAuth Endpoints --- + async def oauth_metadata(request: Request): - base = str(request.base_url).rstrip("/") + proto = request.headers.get("x-forwarded-proto", request.url.scheme) + host = request.headers.get("host", "") + prefix = request.headers.get("x-forwarded-prefix", "") + base = f"{proto}://{host}{prefix}" return JSONResponse({ "issuer": base, + "authorization_endpoint": base + "/authorize", "token_endpoint": base + "/token", - "response_types_supported": ["token"], - "grant_types_supported": ["client_credentials"], + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "client_credentials"], + "code_challenge_methods_supported": ["S256", "plain"], "token_endpoint_auth_methods_supported": ["client_secret_post"], }) +async def oauth_authorize(request: Request): + params = dict(request.query_params) + client_id = params.get("client_id", "") + redirect_uri = params.get("redirect_uri", "") + state = params.get("state", "") + code_challenge = params.get("code_challenge", "") + code_challenge_method = params.get("code_challenge_method", "plain") + + tokens = _load_tokens() + if client_id not in tokens: + return HTMLResponse(f"
Unbekannte Client ID: {client_id}
", status_code=400) + + code = secrets.token_urlsafe(32) + _auth_codes[code] = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "code_challenge_method": code_challenge_method, + "expires_at": time.time() + 300, + } + + sep = "&" if "?" in redirect_uri else "?" + location = f"{redirect_uri}{sep}code={code}" + if state: + location += f"&state={state}" + return RedirectResponse(location, status_code=302) + + async def oauth_token(request: Request): try: form = await request.form() - grant_type = form.get("grant_type", "") - client_id = form.get("client_id", "") - client_secret = form.get("client_secret", "") + data = dict(form) except Exception: - body = await request.body() try: - data = json.loads(body) - grant_type = data.get("grant_type", "") - client_id = data.get("client_id", "") - client_secret = data.get("client_secret", "") + data = json.loads(await request.body()) except Exception: return JSONResponse({"error": "invalid_request"}, status_code=400) - if grant_type != "client_credentials": - return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) + grant_type = data.get("grant_type", "") + client_id = data.get("client_id", "") + client_secret = data.get("client_secret", "") - user = _resolve_client(client_id, client_secret) - if not user: - return JSONResponse({"error": "invalid_client"}, status_code=401) + if grant_type == "authorization_code": + code = data.get("code", "") + code_verifier = data.get("code_verifier", "") + + code_data = _auth_codes.pop(code, None) + if not code_data: + return JSONResponse({"error": "invalid_grant", "error_description": "Invalid or expired code"}, status_code=400) + if code_data["expires_at"] < time.time(): + return JSONResponse({"error": "invalid_grant", "error_description": "Code expired"}, status_code=400) + if code_data["client_id"] != client_id: + return JSONResponse({"error": "invalid_grant", "error_description": "Client ID mismatch"}, status_code=400) + + if code_data["code_challenge"]: + if not _verify_pkce(code_verifier, code_data["code_challenge"], code_data["code_challenge_method"]): + return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400) + + user = client_id + if user not in _load_tokens(): + return JSONResponse({"error": "invalid_client"}, status_code=401) + + elif grant_type == "client_credentials": + user = _resolve_client(client_id, client_secret) + if not user: + return JSONResponse({"error": "invalid_client"}, status_code=401) + else: + return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) access_token = secrets.token_urlsafe(48) expires_in = 86400 @@ -118,6 +182,7 @@ async def oauth_token(request: Request): OAUTH_ROUTES = [ Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), + Route("/authorize", oauth_authorize, methods=["GET"]), Route("/token", oauth_token, methods=["POST"]), ] @@ -125,7 +190,7 @@ OAUTH_ROUTES = [ class BearerAuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): path = request.url.path - if path.endswith("/token") or "/.well-known/" in path: + if path.endswith("/token") or path.endswith("/authorize") or "/.well-known/" in path: return await call_next(request) auth = request.headers.get("authorization", "")