"""Shared OAuth + user resolution for all MCP servers. OAuth client_credentials flow: 1. claude.ai discovers /.well-known/oauth-authorization-server 2. claude.ai POSTs to /token with client_id + client_secret 3. Server returns access_token 4. claude.ai sends Bearer access_token with MCP requests 5. Server resolves user from token """ import json import os import secrets import contextvars import time from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from starlette.middleware.base import BaseHTTPMiddleware BASE_DIR = os.path.dirname(__file__) TOKENS_FILE = os.path.join(BASE_DIR, "tokens.json") CONFIG_FILE = os.path.join(BASE_DIR, "config.json") VALID_USERS = ["stefan", "kati"] _config_cache = None def load_config(): global _config_cache if _config_cache is None: with open(CONFIG_FILE) as f: _config_cache = json.load(f) return _config_cache _tokens_cache = None _current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None) _access_tokens: dict[str, dict] = {} def _load_tokens(): global _tokens_cache if _tokens_cache is None: with open(TOKENS_FILE) as f: _tokens_cache = json.load(f) return _tokens_cache def get_current_user() -> str | None: return _current_user.get() def get_user_key(username: str) -> str: return _load_tokens().get(username, {}).get("token", "") def _resolve_client(client_id, client_secret): tokens = _load_tokens() for username, data in tokens.items(): if username == client_id and data["token"] == client_secret: return username return None def _resolve_access_token(token): info = _access_tokens.get(token) if not info: return None if info.get("expires_at", 0) < time.time(): del _access_tokens[token] return None return info["user"] async def oauth_metadata(request: Request): base = str(request.base_url).rstrip("/") return JSONResponse({ "issuer": base, "token_endpoint": base + "/token", "response_types_supported": ["token"], "grant_types_supported": ["client_credentials"], "token_endpoint_auth_methods_supported": ["client_secret_post"], }) async def oauth_token(request: Request): try: form = await request.form() grant_type = form.get("grant_type", "") client_id = form.get("client_id", "") client_secret = form.get("client_secret", "") except Exception: body = await request.body() try: data = json.loads(body) grant_type = data.get("grant_type", "") client_id = data.get("client_id", "") client_secret = data.get("client_secret", "") except Exception: return JSONResponse({"error": "invalid_request"}, status_code=400) if grant_type != "client_credentials": return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) user = _resolve_client(client_id, client_secret) if not user: return JSONResponse({"error": "invalid_client"}, status_code=401) access_token = secrets.token_urlsafe(48) expires_in = 86400 _access_tokens[access_token] = {"user": user, "expires_at": time.time() + expires_in} return JSONResponse({ "access_token": access_token, "token_type": "bearer", "expires_in": expires_in, }) OAUTH_ROUTES = [ Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), Route("/token", oauth_token, methods=["POST"]), ] class BearerAuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): path = request.url.path if path.endswith("/token") or "/.well-known/" in path: return await call_next(request) auth = request.headers.get("authorization", "") if auth.startswith("Bearer "): token = auth[7:] user = _resolve_access_token(token) if user: tok = _current_user.set(user) try: return await call_next(request) finally: _current_user.reset(tok) return JSONResponse( {"error": "unauthorized"}, status_code=401, headers={"WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-authorization-server"'}, )