"""Shared OAuth for all MCP servers.

Implements a small OAuth 2.0 authorization server (used by claude.ai):

1. ``GET /authorize`` auto-approves any client_id listed in ``tokens.json``
   (there is no login page) and redirects to ``redirect_uri`` with a one-time
   ``code`` that stays redeemable for five minutes.
2. ``POST /token`` issues an access token:
   - ``grant_type=authorization_code`` exchanges a code. The client must
     authenticate with its client_secret (``client_secret_post`` or
     ``client_secret_basic``); PKCE is verified on top when /authorize
     received a ``code_challenge``.
   - ``grant_type=client_credentials`` authenticates with client_id and
     client_secret directly.
3. Clients send ``Authorization: Bearer <access_token>`` on MCP requests;
   ``BearerAuthMiddleware`` maps the token to a user for that request.

Access tokens live for 30 days and are persisted in ``.active_tokens.json``
next to this file; authorization codes are kept in memory only.
"""

import base64
import contextvars
import hashlib
import html
import json
import os
import secrets
import tempfile
import time
from urllib.parse import parse_qs, unquote_plus, urlencode, urlparse

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, RedirectResponse
from starlette.routing import Route

BASE_DIR = os.path.dirname(__file__)
TOKENS_FILE = os.path.join(BASE_DIR, "tokens.json")
CONFIG_FILE = os.path.join(BASE_DIR, "config.json")
VALID_USERS = ["stefan", "kati", "test"]

# get_current_user() reports these logins under the aliased name.
USER_ALIASES = {"test": "stefan"}

_TOKEN_STORE = os.path.join(BASE_DIR, ".active_tokens.json")
_AUTH_CODE_TTL = 300  # seconds an authorization code stays redeemable
_ACCESS_TOKEN_TTL = 2592000  # 30 days, in seconds
_PKCE_METHODS = ("S256", "plain")
# RFC 6749 §5.1: responses carrying tokens must not be cached.
_NO_STORE_HEADERS = {"Cache-Control": "no-store", "Pragma": "no-cache"}

_config_cache = None
_tokens_cache = None
_current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None)
# One-time authorization codes issued by /authorize, keyed by the code string.
_auth_codes: dict[str, dict] = {}


def load_config():
    """Return the parsed ``config.json``, reading it from disk only once."""
    global _config_cache
    if _config_cache is None:
        with open(CONFIG_FILE, encoding="utf-8") as f:
            _config_cache = json.load(f)
    return _config_cache


def _load_access_tokens():
    """Return the persisted access-token map ``{token: {"user": ..., "expires_at": ...}}``.

    A missing or malformed store is treated as empty.
    """
    try:
        with open(_TOKEN_STORE, encoding="utf-8") as f:
            tokens = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
        return {}
    return tokens if isinstance(tokens, dict) else {}


def _save_access_tokens(tokens):
    """Persist the access-token map atomically.

    The data goes to a temporary file in the same directory which is then
    moved over the store with ``os.replace``, so readers never see a
    half-written file (``_load_access_tokens`` would treat it as "no tokens").
    ``mkstemp`` creates the file readable by the owner only, since it holds
    live bearer tokens.
    """
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(_TOKEN_STORE) or ".",
        prefix=".active_tokens.",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(tokens, f)
        os.replace(tmp_path, _TOKEN_STORE)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def _load_tokens():
    """Return the parsed ``tokens.json`` (username -> record), reading it only once."""
    global _tokens_cache
    if _tokens_cache is None:
        with open(TOKENS_FILE, encoding="utf-8") as f:
            _tokens_cache = json.load(f)
    return _tokens_cache


def get_current_user() -> str | None:
    """Return the user authenticated for the current request (after aliasing), or None."""
    user = _current_user.get()
    return USER_ALIASES.get(user, user)


def get_user_key(username: str) -> str:
    """Return the ``token`` stored for *username* in tokens.json, or "" if there is none.

    The same value serves as that user's OAuth client_secret.
    """
    return _load_tokens().get(username, {}).get("token", "")


def _resolve_client(client_id, client_secret):
    """Return the username for a valid client_id/client_secret pair, else None.

    The client_id is a username in tokens.json and the client_secret must equal
    that record's ``"token"``. Empty or missing secrets never match, and the
    comparison is constant-time.
    """
    if not client_id or not client_secret:
        return None
    record = _load_tokens().get(client_id)
    if not isinstance(record, dict):
        return None
    expected = record.get("token")
    if not isinstance(expected, str) or not expected:
        return None
    if secrets.compare_digest(expected.encode("utf-8"), client_secret.encode("utf-8")):
        return client_id
    return None


def _resolve_access_token(token):
    """Return the user owning *token*, or None if it is unknown or expired.

    Expired tokens are removed from the store as a side effect.
    """
    tokens = _load_access_tokens()
    info = tokens.get(token)
    if not isinstance(info, dict) or not info:
        return None
    if info.get("expires_at", 0) < time.time():
        del tokens[token]
        _save_access_tokens(tokens)
        return None
    return info.get("user")


def _verify_pkce(code_verifier, code_challenge, method):
    """Check a PKCE code_verifier against the stored code_challenge (RFC 7636 §4.6).

    ``S256`` compares BASE64URL(SHA256(verifier)) without padding; ``plain``
    compares the verifier itself. Any other method, an empty verifier, or a
    non-ASCII verifier (not allowed by RFC 7636) fails instead of raising.
    Comparisons are constant-time.
    """
    if not code_verifier:
        return False
    try:
        verifier = code_verifier.encode("ascii")
    except UnicodeEncodeError:
        return False
    if method == "S256":
        computed = base64.urlsafe_b64encode(hashlib.sha256(verifier).digest()).rstrip(b"=")
    elif method == "plain":
        computed = verifier
    else:
        return False
    return secrets.compare_digest(computed, code_challenge.encode("utf-8"))


def _html_error(message, status_code=400):
    """Return a minimal HTML error page; *message* is HTML-escaped before embedding."""
    return HTMLResponse(
        f"<html><body><h1>Fehler</h1><p>{html.escape(message)}</p></body></html>",
        status_code=status_code,
    )


def _is_valid_redirect_uri(uri):
    """Accept only absolute URIs without a fragment (RFC 6749 §3.1.2)."""
    if not uri or "#" in uri:
        return False
    try:
        return bool(urlparse(uri).scheme)
    except ValueError:
        return False


def _append_query(uri, params):
    """Append URL-encoded *params* to *uri*, extending an existing query string."""
    sep = "&" if "?" in uri else "?"
    return f"{uri}{sep}{urlencode(params)}"


def _purge_expired_auth_codes(now):
    """Drop authorization codes that expired without being redeemed."""
    for code in [c for c, d in _auth_codes.items() if d["expires_at"] < now]:
        del _auth_codes[code]


# --- OAuth Endpoints ---

async def oauth_metadata(request: Request):
    """Serve OAuth authorization-server metadata (RFC 8414) for the requested host."""
    # Behind a proxy chain the header may be a list ("https, http"); the
    # first entry is normally the scheme the client used.
    forwarded_proto = request.headers.get("x-forwarded-proto", "")
    proto = forwarded_proto.split(",", 1)[0].strip() or request.url.scheme
    host = request.headers.get("host", "")
    base = f"{proto}://{host}"
    return JSONResponse({
        "issuer": base,
        "authorization_endpoint": base + "/authorize",
        "token_endpoint": base + "/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "client_credentials"],
        "code_challenge_methods_supported": ["S256", "plain"],
        "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
    })


async def oauth_authorize(request: Request):
    """Handle ``GET /authorize``: auto-approve a known client and redirect with a code.

    Errors about the client or the redirect_uri are shown as an HTML page
    (never redirected, per RFC 6749 §4.1.2.1); an unsupported PKCE method is
    reported back to the redirect_uri as ``error=invalid_request``.
    """
    params = dict(request.query_params)
    client_id = params.get("client_id", "")
    redirect_uri = params.get("redirect_uri", "")
    state = params.get("state", "")
    code_challenge = params.get("code_challenge", "")
    code_challenge_method = params.get("code_challenge_method") or "plain"

    if client_id not in _load_tokens():
        return _html_error(f"Unbekannte Client ID: {client_id}")
    if not _is_valid_redirect_uri(redirect_uri):
        return _html_error("Ungültige redirect_uri")

    if code_challenge and code_challenge_method not in _PKCE_METHODS:
        error = {
            "error": "invalid_request",
            "error_description": "Unsupported code_challenge_method",
        }
        if state:
            error["state"] = state
        return RedirectResponse(_append_query(redirect_uri, error), status_code=302)

    now = time.time()
    # Keep the in-memory code table bounded by the TTL window.
    _purge_expired_auth_codes(now)
    code = secrets.token_urlsafe(32)
    _auth_codes[code] = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": code_challenge_method,
        "expires_at": now + _AUTH_CODE_TTL,
    }

    query = {"code": code}
    if state:
        query["state"] = state  # URL-encoded by urlencode, so '&'/'#' in state are safe
    return RedirectResponse(_append_query(redirect_uri, query), status_code=302)


async def _read_token_request(request):
    """Parse the /token body into a ``{name: str}`` dict, or None if it is unparseable.

    Non-JSON bodies go through Starlette's form parser first; if that raises,
    the raw body is tried as JSON. ``application/json`` bodies are parsed as
    JSON directly. Non-string values (uploaded files, nested JSON) are dropped.
    """
    content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if content_type != "application/json":
        try:
            form = await request.form()
            return {k: v for k, v in dict(form).items() if isinstance(v, str)}
        except Exception:
            pass  # fall back to JSON below
    try:
        data = json.loads(await request.body())
    except Exception:
        return None
    if not isinstance(data, dict):
        return None
    return {k: v for k, v in data.items() if isinstance(v, str)}


def _parse_basic_auth(header):
    """Decode an HTTP Basic ``Authorization`` header into ``(client_id, secret)``, or None."""
    scheme, _, value = header.partition(" ")
    value = value.strip()
    if scheme.lower() != "basic" or not value:
        return None
    try:
        decoded = base64.b64decode(value, validate=True).decode("utf-8")
    except ValueError:  # covers binascii.Error and UnicodeDecodeError
        return None
    basic_id, sep, basic_secret = decoded.partition(":")
    if not sep:
        return None
    # RFC 6749 §2.3.1: both parts are form-urlencoded before base64 encoding.
    return unquote_plus(basic_id), unquote_plus(basic_secret)


def _client_credentials(request, data):
    """Return ``(client_id, client_secret)`` for a /token request.

    Body parameters (client_secret_post) win. When the body has no
    client_secret, a well-formed HTTP Basic header (client_secret_basic) is
    used instead, unless it names a different client than the body's client_id.
    """
    client_id = data.get("client_id", "")
    client_secret = data.get("client_secret", "")
    if client_secret:
        return client_id, client_secret
    basic = _parse_basic_auth(request.headers.get("authorization", ""))
    if basic is None:
        return client_id, client_secret
    basic_id, basic_secret = basic
    if client_id and client_id != basic_id:
        return client_id, ""  # conflicting identities: authenticate nobody
    return basic_id, basic_secret


def _token_error(error, description=None, status_code=400):
    """Build an RFC 6749 §5.2 error response for the token endpoint."""
    body = {"error": error}
    if description:
        body["error_description"] = description
    return JSONResponse(body, status_code=status_code, headers=_NO_STORE_HEADERS)


def _issue_access_token(user):
    """Create, persist and return a new access token for *user*; prunes expired tokens."""
    access_token = secrets.token_urlsafe(48)
    now = time.time()
    tokens = {
        k: v
        for k, v in _load_access_tokens().items()
        if isinstance(v, dict) and v.get("expires_at", 0) > now
    }
    tokens[access_token] = {"user": user, "expires_at": now + _ACCESS_TOKEN_TTL}
    _save_access_tokens(tokens)
    return access_token


async def oauth_token(request: Request):
    """Handle ``POST /token`` for the authorization_code and client_credentials grants.

    Returns an OAuth error JSON (400, or 401 for bad client credentials) on
    failure; otherwise a new bearer token valid for 30 days.
    """
    data = await _read_token_request(request)
    if data is None:
        return _token_error("invalid_request")

    grant_type = data.get("grant_type", "")
    client_id, client_secret = _client_credentials(request, data)

    if grant_type == "authorization_code":
        # Popped before any check, so a code can never be redeemed twice.
        code_data = _auth_codes.pop(data.get("code", ""), None)
        if not code_data:
            return _token_error("invalid_grant", "Invalid or expired code")
        if code_data["expires_at"] < time.time():
            return _token_error("invalid_grant", "Code expired")
        if code_data["client_id"] != client_id:
            return _token_error("invalid_grant", "Client ID mismatch")
        redirect_uri = data.get("redirect_uri", "")
        if redirect_uri and redirect_uri != code_data["redirect_uri"]:
            return _token_error("invalid_grant", "Redirect URI mismatch")
        if code_data["code_challenge"] and not _verify_pkce(
            data.get("code_verifier", ""),
            code_data["code_challenge"],
            code_data["code_challenge_method"],
        ):
            return _token_error("invalid_grant", "PKCE verification failed")
        # Always require the client secret: /authorize approves anyone who knows
        # a client_id (i.e. a username), so PKCE alone would let any caller mint
        # a token for any user.
        user = _resolve_client(client_id, client_secret)
        if not user:
            return _token_error("invalid_client", "Invalid client credentials", 401)
    elif grant_type == "client_credentials":
        user = _resolve_client(client_id, client_secret)
        if not user:
            return _token_error("invalid_client", status_code=401)
    else:
        return _token_error("unsupported_grant_type")

    return JSONResponse(
        {
            "access_token": _issue_access_token(user),
            "token_type": "bearer",
            "expires_in": _ACCESS_TOKEN_TTL,
        },
        headers=_NO_STORE_HEADERS,
    )


OAUTH_ROUTES = [
    Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
    Route("/authorize", oauth_authorize, methods=["GET"]),
    Route("/token", oauth_token, methods=["POST"]),
]


class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid bearer access token on every request except the OAuth endpoints.

    On success the token's user is stored in ``_current_user`` for the
    duration of the downstream call (read it via ``get_current_user()``);
    otherwise a 401 JSON response is returned.
    """

    async def dispatch(self, request, call_next):
        path = request.url.path
        # NOTE(review): suffix/substring matching lets the OAuth routes work under
        # a mount prefix, but any other path ending in "/token" or "/authorize",
        # or containing "/.well-known/", is also passed through unauthenticated.
        if path.endswith(("/token", "/authorize")) or "/.well-known/" in path:
            return await call_next(request)

        # The auth scheme name is case-insensitive (RFC 7235 §2.1).
        scheme, _, token = request.headers.get("authorization", "").partition(" ")
        token = token.strip()
        if scheme.lower() == "bearer" and token:
            user = _resolve_access_token(token)
            if user:
                ctx_token = _current_user.set(user)
                try:
                    return await call_next(request)
                finally:
                    _current_user.reset(ctx_token)

        return JSONResponse(
            {"error": "unauthorized"},
            status_code=401,
            headers={"WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-authorization-server"'},
        )