Add OAuth Authorization Code flow with PKCE for claude.ai
claude.ai uses the full OAuth Authorization Code flow, not client_credentials. Flow: GET /authorize → auto-approve → redirect with code → POST /token with code + code_verifier (PKCE S256). Also fixes OAuth metadata URLs to use correct external scheme/host/prefix via X-Forwarded-Proto, Host, and X-Forwarded-Prefix headers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,20 +1,22 @@
|
||||
"""Shared OAuth + user resolution for all MCP servers.
|
||||
"""Shared OAuth for all MCP servers.
|
||||
|
||||
OAuth client_credentials flow:
|
||||
1. claude.ai discovers /.well-known/oauth-authorization-server
|
||||
2. claude.ai POSTs to /token with client_id + client_secret
|
||||
3. Server returns access_token
|
||||
4. claude.ai sends Bearer access_token with MCP requests
|
||||
5. Server resolves user from token
|
||||
Supports OAuth 2.0 Authorization Code flow with PKCE (for claude.ai):
|
||||
1. GET /authorize → auto-approves if client_id+client_secret valid, redirects with code
|
||||
2. POST /token grant_type=authorization_code → exchanges code for access_token
|
||||
3. Bearer access_token on MCP requests
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
import base64
|
||||
import contextvars
|
||||
import time
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
@@ -35,6 +37,7 @@ def load_config():
|
||||
_tokens_cache = None
|
||||
_current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_user", default=None)
|
||||
_access_tokens: dict[str, dict] = {}
|
||||
_auth_codes: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _load_tokens():
|
||||
@@ -71,39 +74,100 @@ def _resolve_access_token(token):
|
||||
return info["user"]
|
||||
|
||||
|
||||
def _verify_pkce(code_verifier, code_challenge, method):
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
return computed == code_challenge
|
||||
return code_verifier == code_challenge
|
||||
|
||||
|
||||
# --- OAuth Endpoints ---
|
||||
|
||||
async def oauth_metadata(request: Request):
|
||||
base = str(request.base_url).rstrip("/")
|
||||
proto = request.headers.get("x-forwarded-proto", request.url.scheme)
|
||||
host = request.headers.get("host", "")
|
||||
prefix = request.headers.get("x-forwarded-prefix", "")
|
||||
base = f"{proto}://{host}{prefix}"
|
||||
return JSONResponse({
|
||||
"issuer": base,
|
||||
"authorization_endpoint": base + "/authorize",
|
||||
"token_endpoint": base + "/token",
|
||||
"response_types_supported": ["token"],
|
||||
"grant_types_supported": ["client_credentials"],
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "client_credentials"],
|
||||
"code_challenge_methods_supported": ["S256", "plain"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post"],
|
||||
})
|
||||
|
||||
|
||||
async def oauth_authorize(request: Request):
|
||||
params = dict(request.query_params)
|
||||
client_id = params.get("client_id", "")
|
||||
redirect_uri = params.get("redirect_uri", "")
|
||||
state = params.get("state", "")
|
||||
code_challenge = params.get("code_challenge", "")
|
||||
code_challenge_method = params.get("code_challenge_method", "plain")
|
||||
|
||||
tokens = _load_tokens()
|
||||
if client_id not in tokens:
|
||||
return HTMLResponse(f"<h1>Fehler</h1><p>Unbekannte Client ID: {client_id}</p>", status_code=400)
|
||||
|
||||
code = secrets.token_urlsafe(32)
|
||||
_auth_codes[code] = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"expires_at": time.time() + 300,
|
||||
}
|
||||
|
||||
sep = "&" if "?" in redirect_uri else "?"
|
||||
location = f"{redirect_uri}{sep}code={code}"
|
||||
if state:
|
||||
location += f"&state={state}"
|
||||
return RedirectResponse(location, status_code=302)
|
||||
|
||||
|
||||
async def oauth_token(request: Request):
|
||||
try:
|
||||
form = await request.form()
|
||||
grant_type = form.get("grant_type", "")
|
||||
client_id = form.get("client_id", "")
|
||||
client_secret = form.get("client_secret", "")
|
||||
data = dict(form)
|
||||
except Exception:
|
||||
body = await request.body()
|
||||
try:
|
||||
data = json.loads(body)
|
||||
grant_type = data.get("grant_type", "")
|
||||
client_id = data.get("client_id", "")
|
||||
client_secret = data.get("client_secret", "")
|
||||
data = json.loads(await request.body())
|
||||
except Exception:
|
||||
return JSONResponse({"error": "invalid_request"}, status_code=400)
|
||||
|
||||
if grant_type != "client_credentials":
|
||||
return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
|
||||
grant_type = data.get("grant_type", "")
|
||||
client_id = data.get("client_id", "")
|
||||
client_secret = data.get("client_secret", "")
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
code = data.get("code", "")
|
||||
code_verifier = data.get("code_verifier", "")
|
||||
|
||||
code_data = _auth_codes.pop(code, None)
|
||||
if not code_data:
|
||||
return JSONResponse({"error": "invalid_grant", "error_description": "Invalid or expired code"}, status_code=400)
|
||||
if code_data["expires_at"] < time.time():
|
||||
return JSONResponse({"error": "invalid_grant", "error_description": "Code expired"}, status_code=400)
|
||||
if code_data["client_id"] != client_id:
|
||||
return JSONResponse({"error": "invalid_grant", "error_description": "Client ID mismatch"}, status_code=400)
|
||||
|
||||
if code_data["code_challenge"]:
|
||||
if not _verify_pkce(code_verifier, code_data["code_challenge"], code_data["code_challenge_method"]):
|
||||
return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400)
|
||||
|
||||
user = client_id
|
||||
if user not in _load_tokens():
|
||||
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
||||
|
||||
elif grant_type == "client_credentials":
|
||||
user = _resolve_client(client_id, client_secret)
|
||||
if not user:
|
||||
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
||||
else:
|
||||
return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
|
||||
|
||||
access_token = secrets.token_urlsafe(48)
|
||||
expires_in = 86400
|
||||
@@ -118,6 +182,7 @@ async def oauth_token(request: Request):
|
||||
|
||||
OAUTH_ROUTES = [
|
||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
||||
Route("/authorize", oauth_authorize, methods=["GET"]),
|
||||
Route("/token", oauth_token, methods=["POST"]),
|
||||
]
|
||||
|
||||
@@ -125,7 +190,7 @@ OAUTH_ROUTES = [
|
||||
class BearerAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request, call_next):
|
||||
path = request.url.path
|
||||
if path.endswith("/token") or "/.well-known/" in path:
|
||||
if path.endswith("/token") or path.endswith("/authorize") or "/.well-known/" in path:
|
||||
return await call_next(request)
|
||||
|
||||
auth = request.headers.get("authorization", "")
|
||||
|
||||
Reference in New Issue
Block a user