Skip to content

Commit 37cdb92

Browse files
committed
Support falling back to OIDC metadata for auth
1 parent 6f43d1f commit 37cdb92

File tree

2 files changed

+252
-60
lines changed

2 files changed

+252
-60
lines changed

src/mcp/client/auth.py

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
175175
return protocol_version >= "2025-06-18"
176176

177177

178+
OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]]
179+
180+
178181
class OAuthClientProvider(httpx.Auth):
179182
"""
180183
OAuth2 authentication for httpx.
@@ -221,32 +224,60 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
221224
except ValidationError:
222225
pass
223226

224-
def _build_well_known_path(self, pathname: str) -> str:
227+
def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str:
225228
"""Construct well-known path for OAuth metadata discovery."""
226-
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
229+
well_known_path = f"/.well-known/{well_known_endpoint}{pathname}"
227230
if pathname.endswith("/"):
228231
# Strip trailing slash from pathname to avoid double slashes
229232
well_known_path = well_known_path[:-1]
230233
return well_known_path
231234

232-
def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
233-
"""Determine if fallback to root discovery should be attempted."""
234-
return response_status == 404 and pathname != "/"
235+
def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str:
236+
"""Construct fallback well-known URL for OAuth metadata discovery in legacy servers."""
237+
base_url = getattr(self.context, "discovery_base_url", "")
238+
if not base_url:
239+
raise OAuthFlowError("No base URL available for fallback discovery")
240+
241+
# Fallback to root discovery for legacy servers
242+
return urljoin(base_url, f"/.well-known/{well_known_endpoint}")
243+
244+
def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str:
245+
"""Construct fallback well-known path for OIDC metadata discovery in legacy servers."""
246+
# Strip trailing slash from pathname to avoid double slashes
247+
clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname
248+
# OIDC 1.0 appends the well-known path to the full AS URL
249+
return f"{clean_pathname}/.well-known/{well_known_endpoint}"
250+
251+
def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str:
252+
"""Construct fallback well-known URL for OIDC metadata discovery in legacy servers."""
253+
if self.context.auth_server_url:
254+
auth_server_url = self.context.auth_server_url
255+
else:
256+
auth_server_url = self.context.server_url
257+
258+
parsed = urlparse(auth_server_url)
259+
well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint)
260+
base_url = f"{parsed.scheme}://{parsed.netloc}"
261+
return urljoin(base_url, well_known_path)
262+
263+
def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool:
264+
"""Determine if further fallback should be attempted."""
265+
return response_status == 404 and len(discovery_stack) > 0
235266

236267
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
237268
"""Build metadata discovery request for a specific URL."""
238269
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
239270

240-
async def _discover_oauth_metadata(self) -> httpx.Request:
241-
"""Build OAuth metadata discovery request with fallback support."""
271+
async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request:
272+
"""Build .well-known metadata discovery request with fallback support."""
242273
if self.context.auth_server_url:
243274
auth_server_url = self.context.auth_server_url
244275
else:
245276
auth_server_url = self.context.server_url
246277

247278
# Per RFC 8414, try path-aware discovery first
248279
parsed = urlparse(auth_server_url)
249-
well_known_path = self._build_well_known_path(parsed.path)
280+
well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint)
250281
base_url = f"{parsed.scheme}://{parsed.netloc}"
251282
url = urljoin(base_url, well_known_path)
252283

@@ -256,17 +287,37 @@ async def _discover_oauth_metadata(self) -> httpx.Request:
256287

257288
return await self._try_metadata_discovery(url)
258289

290+
async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request:
291+
"""Build fallback OAuth metadata discovery request for legacy servers."""
292+
url = self._build_well_known_fallback_url(well_known_endpoint)
293+
return await self._try_metadata_discovery(url)
294+
295+
async def _discover_oauth_metadata(self) -> httpx.Request:
296+
"""Build OAuth metadata discovery request with fallback support."""
297+
return await self._discover_well_known_metadata("oauth-authorization-server")
298+
259299
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
260300
"""Build fallback OAuth metadata discovery request for legacy servers."""
261-
base_url = getattr(self.context, "discovery_base_url", "")
262-
if not base_url:
263-
raise OAuthFlowError("No base URL available for fallback discovery")
301+
return await self._discover_well_known_metadata_fallback("oauth-authorization-server")
264302

265-
# Fallback to root discovery for legacy servers
266-
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
303+
async def _discover_oidc_metadata(self) -> httpx.Request:
304+
"""
305+
Build fallback OIDC metadata discovery request.
306+
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
307+
"""
308+
return await self._discover_well_known_metadata("openid-configuration")
309+
310+
async def _discover_oidc_metadata_fallback(self) -> httpx.Request:
311+
"""
312+
Build fallback OIDC metadata discovery request for legacy servers.
313+
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
314+
"""
315+
url = self._build_oidc_fallback_url("openid-configuration")
267316
return await self._try_metadata_discovery(url)
268317

269-
async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
318+
async def _handle_oauth_metadata_response(
319+
self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack
320+
) -> bool:
270321
"""Handle OAuth metadata response. Returns True if handled successfully."""
271322
if response.status_code == 200:
272323
try:
@@ -280,13 +331,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal
280331
except ValidationError:
281332
pass
282333

283-
# Check if we should attempt fallback (404 on path-aware discovery)
284-
if not is_fallback and self._should_attempt_fallback(
285-
response.status_code, getattr(self.context, "discovery_pathname", "/")
286-
):
287-
return False # Signal that fallback should be attempted
288-
289-
return True # Signal no fallback needed (either success or non-404 error)
334+
# Check if we should attempt fallback
335+
# True: No fallback needed (either success or non-404 error)
336+
# False: Signal that fallback should be attempted
337+
return not self._should_attempt_fallback(response.status_code, discovery_stack)
290338

291339
async def _register_client(self) -> httpx.Request | None:
292340
"""Build registration request or skip if already registered."""
@@ -480,6 +528,26 @@ def _add_auth_header(self, request: httpx.Request) -> None:
480528
if self.context.current_tokens and self.context.current_tokens.access_token:
481529
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
482530

531+
def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack:
532+
"""Create a stack of attempts to discover OAuth metadata."""
533+
discovery_attempts: OAuthDiscoveryStack = [
534+
# Start with path-aware OAuth discovery
535+
self._discover_oauth_metadata,
536+
# If path-aware discovery fails with 404, try fallback to root
537+
self._discover_oauth_metadata_fallback,
538+
# If root discovery fails with 404, fall back to OIDC 1.0 following
539+
# RFC 8414 path-aware semantics (see RFC 8414 section 5)
540+
self._discover_oidc_metadata,
541+
# If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0
542+
# following OIDC 1.0 semantics (see RFC 8414 section 5)
543+
self._discover_oidc_metadata_fallback,
544+
]
545+
546+
# Reverse the list so we can call pop() without remembering we declared
547+
# this stack backwards for readability
548+
discovery_attempts.reverse()
549+
return discovery_attempts
550+
483551
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
484552
"""HTTPX auth flow integration."""
485553
async with self.context.lock:
@@ -499,15 +567,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
499567
await self._handle_protected_resource_response(discovery_response)
500568

501569
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
502-
oauth_request = await self._discover_oauth_metadata()
503-
oauth_response = yield oauth_request
504-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
505-
506-
# If path-aware discovery failed with 404, try fallback to root
507-
if not handled:
508-
fallback_request = await self._discover_oauth_metadata_fallback()
509-
fallback_response = yield fallback_request
510-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
570+
oauth_discovery_stack = self._create_oauth_discovery_stack()
571+
while len(oauth_discovery_stack) > 0:
572+
oauth_discovery = oauth_discovery_stack.pop()
573+
oauth_request = await oauth_discovery()
574+
oauth_response = yield oauth_request
575+
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)
511576

512577
# Step 3: Register client if needed
513578
registration_request = await self._register_client()
@@ -551,15 +616,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
551616
await self._handle_protected_resource_response(discovery_response)
552617

553618
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
554-
oauth_request = await self._discover_oauth_metadata()
555-
oauth_response = yield oauth_request
556-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
557-
558-
# If path-aware discovery failed with 404, try fallback to root
559-
if not handled:
560-
fallback_request = await self._discover_oauth_metadata_fallback()
561-
fallback_response = yield fallback_request
562-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
619+
oauth_discovery_stack = self._create_oauth_discovery_stack()
620+
while len(oauth_discovery_stack) > 0:
621+
oauth_discovery = oauth_discovery_stack.pop()
622+
oauth_request = await oauth_discovery()
623+
oauth_response = yield oauth_request
624+
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)
563625

564626
# Step 3: Register client if needed
565627
registration_request = await self._register_client()

0 commit comments

Comments
 (0)