Add OAuth Authorization Code flow with PKCE for claude.ai

claude.ai uses the full OAuth Authorization Code flow, not client_credentials.
Flow: GET /authorize → auto-approve → redirect with code → POST /token
with code + code_verifier (PKCE S256).

Also fixes OAuth metadata URLs to use correct external scheme/host/prefix
via X-Forwarded-Proto, Host, and X-Forwarded-Prefix headers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Stefan Lohmaier
2026-06-12 08:48:34 +02:00
parent 1f98695821
commit 7f0b03606a
+90 -25
View File
@@ -1,20 +1,22 @@
"""Shared OAuth + user resolution for all MCP servers. """Shared OAuth for all MCP servers.
OAuth client_credentials flow: Supports OAuth 2.0 Authorization Code flow with PKCE (for claude.ai):
1. claude.ai discovers /.well-known/oauth-authorization-server 1. GET /authorize → auto-approves if client_id+client_secret valid, redirects with code
2. claude.ai POSTs to /token with client_id + client_secret 2. POST /token grant_type=authorization_code → exchanges code for access_token
3. Server returns access_token 3. Bearer access_token on MCP requests
4. claude.ai sends Bearer access_token with MCP requests
5. Server resolves user from token
""" """
import json import json
import os import os
import secrets import secrets
import hashlib
import base64
import contextvars import contextvars
import time import time
from urllib.parse import urlencode, parse_qs, urlparse
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
from starlette.routing import Route from starlette.routing import Route
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@@ -35,6 +37,7 @@ def load_config():
_tokens_cache = None _tokens_cache = None
_current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None) _current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None)
_access_tokens: dict[str, dict] = {} _access_tokens: dict[str, dict] = {}
_auth_codes: dict[str, dict] = {}
def _load_tokens(): def _load_tokens():
@@ -71,39 +74,100 @@ def _resolve_access_token(token):
return info["user"] return info["user"]
def _verify_pkce(code_verifier, code_challenge, method):
if method == "S256":
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return computed == code_challenge
return code_verifier == code_challenge
# --- OAuth Endpoints ---
async def oauth_metadata(request: Request): async def oauth_metadata(request: Request):
base = str(request.base_url).rstrip("/") proto = request.headers.get("x-forwarded-proto", request.url.scheme)
host = request.headers.get("host", "")
prefix = request.headers.get("x-forwarded-prefix", "")
base = f"{proto}://{host}{prefix}"
return JSONResponse({ return JSONResponse({
"issuer": base, "issuer": base,
"authorization_endpoint": base + "/authorize",
"token_endpoint": base + "/token", "token_endpoint": base + "/token",
"response_types_supported": ["token"], "response_types_supported": ["code"],
"grant_types_supported": ["client_credentials"], "grant_types_supported": ["authorization_code", "client_credentials"],
"code_challenge_methods_supported": ["S256", "plain"],
"token_endpoint_auth_methods_supported": ["client_secret_post"], "token_endpoint_auth_methods_supported": ["client_secret_post"],
}) })
async def oauth_authorize(request: Request):
params = dict(request.query_params)
client_id = params.get("client_id", "")
redirect_uri = params.get("redirect_uri", "")
state = params.get("state", "")
code_challenge = params.get("code_challenge", "")
code_challenge_method = params.get("code_challenge_method", "plain")
tokens = _load_tokens()
if client_id not in tokens:
return HTMLResponse(f"<h1>Fehler</h1><p>Unbekannte Client ID: {client_id}</p>", status_code=400)
code = secrets.token_urlsafe(32)
_auth_codes[code] = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"expires_at": time.time() + 300,
}
sep = "&" if "?" in redirect_uri else "?"
location = f"{redirect_uri}{sep}code={code}"
if state:
location += f"&state={state}"
return RedirectResponse(location, status_code=302)
async def oauth_token(request: Request): async def oauth_token(request: Request):
try: try:
form = await request.form() form = await request.form()
grant_type = form.get("grant_type", "") data = dict(form)
client_id = form.get("client_id", "")
client_secret = form.get("client_secret", "")
except Exception: except Exception:
body = await request.body()
try: try:
data = json.loads(body) data = json.loads(await request.body())
grant_type = data.get("grant_type", "")
client_id = data.get("client_id", "")
client_secret = data.get("client_secret", "")
except Exception: except Exception:
return JSONResponse({"error": "invalid_request"}, status_code=400) return JSONResponse({"error": "invalid_request"}, status_code=400)
if grant_type != "client_credentials": grant_type = data.get("grant_type", "")
return JSONResponse({"error": "unsupported_grant_type"}, status_code=400) client_id = data.get("client_id", "")
client_secret = data.get("client_secret", "")
user = _resolve_client(client_id, client_secret) if grant_type == "authorization_code":
if not user: code = data.get("code", "")
return JSONResponse({"error": "invalid_client"}, status_code=401) code_verifier = data.get("code_verifier", "")
code_data = _auth_codes.pop(code, None)
if not code_data:
return JSONResponse({"error": "invalid_grant", "error_description": "Invalid or expired code"}, status_code=400)
if code_data["expires_at"] < time.time():
return JSONResponse({"error": "invalid_grant", "error_description": "Code expired"}, status_code=400)
if code_data["client_id"] != client_id:
return JSONResponse({"error": "invalid_grant", "error_description": "Client ID mismatch"}, status_code=400)
if code_data["code_challenge"]:
if not _verify_pkce(code_verifier, code_data["code_challenge"], code_data["code_challenge_method"]):
return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400)
user = client_id
if user not in _load_tokens():
return JSONResponse({"error": "invalid_client"}, status_code=401)
elif grant_type == "client_credentials":
user = _resolve_client(client_id, client_secret)
if not user:
return JSONResponse({"error": "invalid_client"}, status_code=401)
else:
return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
access_token = secrets.token_urlsafe(48) access_token = secrets.token_urlsafe(48)
expires_in = 86400 expires_in = 86400
@@ -118,6 +182,7 @@ async def oauth_token(request: Request):
OAUTH_ROUTES = [ OAUTH_ROUTES = [
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
Route("/authorize", oauth_authorize, methods=["GET"]),
Route("/token", oauth_token, methods=["POST"]), Route("/token", oauth_token, methods=["POST"]),
] ]
@@ -125,7 +190,7 @@ OAUTH_ROUTES = [
class BearerAuthMiddleware(BaseHTTPMiddleware): class BearerAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next): async def dispatch(self, request, call_next):
path = request.url.path path = request.url.path
if path.endswith("/token") or "/.well-known/" in path: if path.endswith("/token") or path.endswith("/authorize") or "/.well-known/" in path:
return await call_next(request) return await call_next(request)
auth = request.headers.get("authorization", "") auth = request.headers.get("authorization", "")