2618ecfc86
- 'test' OAuth client maps to stefan's data via USER_ALIASES - 38 tests covering OAuth (metadata, client_credentials, PKCE, invalid secret, no auth), Mail (accounts, folders, search), Calendar (calendars, tasks, events, search), Contacts (search, empty), Files (list, info), Notes (notebooks) - Daily systemd timer (05:00) with NTFY notification on failure - Shared token store (.active_tokens.json) for cross-process auth Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
240 lines
8.0 KiB
Python
240 lines
8.0 KiB
Python
"""Shared OAuth for all MCP servers.
|
|
|
|
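Supports the OAuth 2.0 Authorization Code flow with PKCE (used by claude.ai)
and the client_credentials grant:

1. GET /authorize → auto-approves any client_id listed in tokens.json (there is
   no consent screen) and redirects to redirect_uri with a single-use code
   that expires after 5 minutes.
2. POST /token grant_type=authorization_code → exchanges that code, plus the
   PKCE code_verifier when a code_challenge was sent, for an access_token.
   POST /token grant_type=client_credentials → issues an access_token directly
   for a valid client_id + client_secret.
3. Bearer access_token on MCP requests (checked by BearerAuthMiddleware).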
|
|
"""
|
|
|
|
import base64
import contextvars
import hashlib
import hmac
import html
import json
import os
import secrets
import tempfile
import time
from urllib.parse import parse_qs, unquote_plus, urlencode, urlparse

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, HTMLResponse
from starlette.routing import Route
|
|
|
|
# Directory holding this module; the JSON files below live alongside it.
BASE_DIR = os.path.dirname(__file__)

# tokens.json is read by _load_tokens(); config.json by load_config().
TOKENS_FILE, CONFIG_FILE = (
    os.path.join(BASE_DIR, "tokens.json"),
    os.path.join(BASE_DIR, "config.json"),
)

VALID_USERS = ["stefan", "kati", "test"]

# Filled lazily by load_config() on its first call.
_config_cache = None
|
|
|
|
def load_config():
    """Return the parsed contents of CONFIG_FILE, reading the file at most once.

    The first successful call stores the result in the module-level
    ``_config_cache``; every later call returns that same object without
    touching the file again.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache
    with open(CONFIG_FILE) as handle:
        _config_cache = json.load(handle)
    return _config_cache
|
|
|
|
# Pending authorization codes: created by /authorize, popped by /token.
# Held in process memory only, so a code must be redeemed by the same process
# that issued it.
_auth_codes: dict[str, dict] = {}

# On-disk store of issued access tokens; being a file, it is visible to every
# process that uses this directory.
_TOKEN_STORE = os.path.join(BASE_DIR, ".active_tokens.json")

# User authenticated for the request being handled; set by
# BearerAuthMiddleware and read through get_current_user().
_current_user: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_user", default=None
)

# Parsed tokens.json, filled lazily by _load_tokens().
_tokens_cache: dict | None = None
|
|
|
|
|
|
def _load_access_tokens():
    """Read the issued-access-token map from ``_TOKEN_STORE``.

    A missing or unparsable store is treated as empty rather than as an error,
    so a fresh install (or a damaged file) simply has no valid tokens.
    """
    try:
        handle = open(_TOKEN_STORE)
    except FileNotFoundError:
        return {}
    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            return {}
|
|
|
|
|
|
def _save_access_tokens(tokens):
    """Persist the access-token map to ``_TOKEN_STORE`` atomically.

    The JSON is written to a temporary file in the same directory and then
    moved over the store with ``os.replace``. A concurrent reader in another
    process therefore sees either the old or the new contents, never a
    half-written file, which ``_load_access_tokens`` would silently treat as
    "no tokens". The same applies if serialisation fails midway.
    ``tempfile.mkstemp`` creates the file with mode 0600, so the bearer tokens
    are not readable by other local users.

    Args:
        tokens: JSON-serialisable mapping of access token -> token info.

    Raises:
        Whatever ``json.dump`` or the filesystem raises; the existing store is
        left untouched in that case.
    """
    directory = os.path.dirname(_TOKEN_STORE) or "."
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix=os.path.basename(_TOKEN_STORE) + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(tokens, f)
        os.replace(tmp_path, _TOKEN_STORE)
    except BaseException:
        # Don't leave stray temp files behind if serialisation or the rename fails.
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass
        raise
|
|
|
|
|
|
def _load_tokens():
    """Return the parsed TOKENS_FILE, a mapping of user name -> entry with a ``token`` secret.

    The file is read on the first call only; the parsed object is kept in
    ``_tokens_cache`` and returned by every later call.
    """
    global _tokens_cache
    if _tokens_cache is not None:
        return _tokens_cache
    with open(TOKENS_FILE) as source:
        _tokens_cache = json.load(source)
    return _tokens_cache
|
|
|
|
|
|
USER_ALIASES = {"test": "stefan"}
|
|
|
|
def get_current_user() -> str | None:
|
|
user = _current_user.get()
|
|
return USER_ALIASES.get(user, user)
|
|
|
|
|
|
def get_user_key(username: str) -> str:
    """Return the ``token`` (client secret) stored for *username* in tokens.json.

    Unknown users, or entries without a ``token``, yield an empty string.
    """
    entry = _load_tokens().get(username, {})
    return entry.get("token", "")
|
|
|
|
|
|
def _resolve_client(client_id, client_secret):
    """Authenticate an OAuth client against tokens.json.

    Args:
        client_id: Client identifier; it is also the user name (tokens.json key).
        client_secret: Secret presented by the client.

    Returns:
        The user name when the secret matches that user's non-empty ``token``,
        otherwise ``None``. Unknown clients, malformed entries and non-string
        inputs are rejected rather than raising.
    """
    if not isinstance(client_id, str) or not isinstance(client_secret, str):
        return None
    entry = _load_tokens().get(client_id)
    if not isinstance(entry, dict):
        return None
    expected = entry.get("token")
    # An empty stored secret must not authenticate an empty presented secret.
    if not isinstance(expected, str) or not expected:
        return None
    # Constant-time comparison so response timing does not reveal how much of
    # the secret was guessed. Compare bytes: hmac.compare_digest raises
    # TypeError for non-ASCII str arguments.
    if hmac.compare_digest(expected.encode("utf-8"), client_secret.encode("utf-8")):
        return client_id
    return None
|
|
|
|
|
|
def _resolve_access_token(token):
    """Return the user owning access *token*, or ``None``.

    Unknown tokens yield ``None``. An expired token also yields ``None`` and,
    as a side effect, is removed from the persisted token store.
    """
    store = _load_access_tokens()
    entry = store.get(token)
    if not entry:
        return None
    expired = entry.get("expires_at", 0) < time.time()
    if not expired:
        return entry["user"]
    del store[token]
    _save_access_tokens(store)
    return None
|
|
|
|
|
|
def _verify_pkce(code_verifier, code_challenge, method):
|
|
if method == "S256":
|
|
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
|
return computed == code_challenge
|
|
return code_verifier == code_challenge
|
|
|
|
|
|
# --- OAuth Endpoints ---
|
|
|
|
async def oauth_metadata(request: Request):
    """Serve the OAuth authorization-server metadata document.

    Endpoint URLs are built from the Host header and from X-Forwarded-Proto
    (falling back to the request's own scheme), presumably so the advertised
    URLs are right when running behind a TLS-terminating reverse proxy.
    """
    scheme = request.headers.get("x-forwarded-proto", request.url.scheme)
    origin = "{}://{}".format(scheme, request.headers.get("host", ""))
    metadata = {
        "issuer": origin,
        "authorization_endpoint": f"{origin}/authorize",
        "token_endpoint": f"{origin}/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "client_credentials"],
        "code_challenge_methods_supported": ["S256", "plain"],
        "token_endpoint_auth_methods_supported": ["client_secret_post"],
    }
    return JSONResponse(metadata)
|
|
|
|
|
|
async def oauth_authorize(request: Request):
    """Authorization endpoint: auto-approve a known client and redirect with a code.

    There is no consent screen; any ``client_id`` present in tokens.json is
    approved immediately. A single-use code, valid for 300 seconds, is stored
    in ``_auth_codes`` together with the PKCE challenge. The browser is then
    redirected to ``redirect_uri`` with ``code``, plus ``state`` if one was
    given.

    Returns:
        A 302 redirect on success, or a 400 HTML error page for an unknown
        client or a missing ``redirect_uri``.
    """
    params = dict(request.query_params)
    client_id = params.get("client_id", "")
    redirect_uri = params.get("redirect_uri", "")
    state = params.get("state", "")
    code_challenge = params.get("code_challenge", "")
    code_challenge_method = params.get("code_challenge_method", "plain")

    if client_id not in _load_tokens():
        # client_id comes straight from the query string: escape it before
        # echoing it into HTML to prevent reflected XSS.
        return HTMLResponse(
            f"<h1>Fehler</h1><p>Unbekannte Client ID: {html.escape(client_id)}</p>",
            status_code=400,
        )
    if not redirect_uri:
        # Without it we would emit a relative redirect ("?code=...") to ourselves.
        return HTMLResponse("<h1>Fehler</h1><p>redirect_uri fehlt</p>", status_code=400)

    now = time.time()
    # Evict codes that expired without being redeemed; otherwise every
    # /authorize call grows _auth_codes for the lifetime of the process.
    for stale in [c for c, entry in _auth_codes.items() if entry["expires_at"] < now]:
        del _auth_codes[stale]

    code = secrets.token_urlsafe(32)
    _auth_codes[code] = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": code_challenge_method,
        "expires_at": now + 300,
    }

    # urlencode escapes state, so a value containing "&", "=" or "#" cannot
    # inject extra parameters into the redirect URL.
    query = {"code": code}
    if state:
        query["state"] = state
    sep = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(f"{redirect_uri}{sep}{urlencode(query)}", status_code=302)
|
|
|
|
|
|
def _token_param(data, key):
    """Return ``data[key]`` when it is a string, else ``""`` (missing, upload, JSON number, ...)."""
    value = data.get(key, "")
    return value if isinstance(value, str) else ""


def _basic_client_credentials(request):
    """Return ``(client_id, client_secret)`` from an HTTP Basic Authorization header, or ``("", "")``."""
    scheme, _, encoded = request.headers.get("authorization", "").partition(" ")
    encoded = encoded.strip()
    if scheme.lower() != "basic" or not encoded:
        return "", ""
    try:
        decoded = base64.b64decode(encoded, validate=True).decode("utf-8")
    except ValueError:  # binascii.Error and UnicodeDecodeError both subclass ValueError
        return "", ""
    client_id, sep, client_secret = decoded.partition(":")
    if not sep:
        return "", ""
    # RFC 6749 §2.3.1: both parts are form-urlencoded before Base64 encoding.
    return unquote_plus(client_id), unquote_plus(client_secret)


async def _read_token_request(request):
    """Parse the /token request body into a dict, or return ``None`` if it is malformed."""
    media_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    # Dispatch on Content-Type: Starlette's request.form() returns an empty form
    # for JSON bodies instead of raising, so "try form, fall back to JSON on
    # exception" never reached the JSON branch.
    if media_type == "application/json":
        try:
            data = json.loads(await request.body())
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return None
        return data if isinstance(data, dict) else None
    try:
        form = await request.form()
    except Exception:
        return None
    return dict(form)


def _issue_access_token(user, expires_in):
    """Create an access token for *user*, persist it (dropping expired entries) and return it."""
    access_token = secrets.token_urlsafe(48)
    now = time.time()
    tokens = {
        token: info
        for token, info in _load_access_tokens().items()
        if info.get("expires_at", 0) > now
    }
    tokens[access_token] = {"user": user, "expires_at": now + expires_in}
    _save_access_tokens(tokens)
    return access_token


async def oauth_token(request: Request):
    """Token endpoint for the authorization_code and client_credentials grants.

    Accepts form-encoded or JSON bodies. Both grants require client
    authentication with client_id + client_secret, sent in the body (the
    advertised client_secret_post) or via HTTP Basic. For authorization_code,
    the code must also be unexpired, bound to the same client, bound to the
    same redirect_uri when one is sent, and pass PKCE verification if a
    challenge was recorded.

    Returns:
        200 JSON with access_token / token_type / expires_in, or an OAuth
        error object with status 400 or 401.
    """
    data = await _read_token_request(request)
    if data is None:
        return JSONResponse({"error": "invalid_request"}, status_code=400)

    grant_type = _token_param(data, "grant_type")
    client_id = _token_param(data, "client_id")
    client_secret = _token_param(data, "client_secret")
    if not client_secret:
        basic_id, basic_secret = _basic_client_credentials(request)
        if basic_secret and client_id in ("", basic_id):
            client_id, client_secret = basic_id, basic_secret

    if grant_type == "authorization_code":
        code_data = _auth_codes.pop(_token_param(data, "code"), None)
        if not code_data:
            return JSONResponse({"error": "invalid_grant", "error_description": "Invalid or expired code"}, status_code=400)
        if code_data["expires_at"] < time.time():
            return JSONResponse({"error": "invalid_grant", "error_description": "Code expired"}, status_code=400)
        if code_data["client_id"] != client_id:
            return JSONResponse({"error": "invalid_grant", "error_description": "Client ID mismatch"}, status_code=400)
        redirect_uri = _token_param(data, "redirect_uri")
        if redirect_uri and redirect_uri != code_data["redirect_uri"]:
            return JSONResponse({"error": "invalid_grant", "error_description": "Redirect URI mismatch"}, status_code=400)
        if code_data["code_challenge"] and not _verify_pkce(
            _token_param(data, "code_verifier"),
            code_data["code_challenge"],
            code_data["code_challenge_method"],
        ):
            return JSONResponse({"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400)

        # /authorize approves any known client_id without authenticating anyone,
        # and client_ids are plain user names. PKCE alone only proves that the
        # redeemer also started the flow, so without the secret anyone who knows
        # a user name could mint a token for that user.
        if not client_secret:
            return JSONResponse({"error": "invalid_client", "error_description": "Client authentication required"}, status_code=401)
        user = _resolve_client(client_id, client_secret)
        if not user:
            return JSONResponse({"error": "invalid_client", "error_description": "Invalid client credentials"}, status_code=401)

    elif grant_type == "client_credentials":
        user = _resolve_client(client_id, client_secret)
        if not user:
            return JSONResponse({"error": "invalid_client"}, status_code=401)

    else:
        return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)

    expires_in = 86400
    access_token = _issue_access_token(user, expires_in)
    return JSONResponse({
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": expires_in,
    })
|
|
|
|
|
|
# One Route per OAuth endpoint: (path, handler, HTTP method).
# BearerAuthMiddleware lets these paths through without a bearer token.
OAUTH_ROUTES = [
    Route(path, endpoint, methods=[method])
    for path, endpoint, method in (
        ("/.well-known/oauth-authorization-server", oauth_metadata, "GET"),
        ("/authorize", oauth_authorize, "GET"),
        ("/token", oauth_token, "POST"),
    )
]
|
|
|
|
|
|
class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid bearer access token on every non-OAuth request.

    Paths ending in ``/token`` or ``/authorize``, or containing
    ``/.well-known/``, pass through untouched. For every other request, the
    ``Authorization: Bearer <token>`` header is resolved against the access-token
    store. On success the user name is placed in ``_current_user`` for the
    downstream call; otherwise a 401 with a ``WWW-Authenticate`` challenge is
    returned.
    """

    async def dispatch(self, request, call_next):
        path = request.url.path
        if path.endswith(("/token", "/authorize")) or "/.well-known/" in path:
            return await call_next(request)

        # The auth-scheme is case-insensitive (RFC 7235 §2.1), so accept
        # "bearer"/"BEARER" too, and ignore extra whitespace around the token.
        scheme, _, token = request.headers.get("authorization", "").partition(" ")
        token = token.strip()
        user = _resolve_access_token(token) if scheme.lower() == "bearer" and token else None
        if user:
            reset_token = _current_user.set(user)
            try:
                return await call_next(request)
            finally:
                _current_user.reset(reset_token)

        return JSONResponse(
            {"error": "unauthorized"},
            status_code=401,
            headers={"WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-authorization-server"'},
        )
|